Skip to content

Commit 9784454

Browse files
Kiuk Chungfacebook-github-bot
authored andcommitted
(torchx/cli) allow --workspace to be configured through .torchxconfig (#397)
Summary: Pull Request resolved: #397 For `torchx run`, allows `--workspace` to be configured through `.torchxconfig`. Adds a custom `argparse.Action` that can be used to read and default other CLI options through `.torchxconfig`. Only hooked it up to `workspace` for now since there is some cleanup to be done to hook it up for `component`, `component args` (we get this from a different section not `[cli:run]`)) Reviewed By: divchenko Differential Revision: D34317215 fbshipit-source-id: 878f9d75002ab352f4d08d15802cfa405995f544
1 parent 4b09940 commit 9784454

File tree

6 files changed

+312
-38
lines changed

6 files changed

+312
-38
lines changed

torchx/cli/argparse_util.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from argparse import Action, ArgumentParser, Namespace
8+
from pathlib import Path
9+
from typing import Any, Dict, Optional, Sequence, Text
10+
11+
from torchx.runner import config
12+
13+
14+
CONFIG_DIRS = [str(Path.home()), str(Path.cwd())]
15+
16+
17+
class _torchxconfig(Action):
18+
"""
19+
Custom argparse action that loads default torchx CLI options
20+
from .torchxconfig file.
21+
22+
"""
23+
24+
# since this action is used for each argparse argument
25+
# load the config section for the subcmd once
26+
_subcmd_configs: Dict[str, Dict[str, str]] = {}
27+
28+
def __init__(
29+
self,
30+
subcmd: str,
31+
dest: str,
32+
option_strings: Sequence[Text],
33+
required: bool = False,
34+
# pyre-ignore[2] declared as Any in superclass Action
35+
default: Any = None,
36+
**kwargs: Any,
37+
) -> None:
38+
cfg = self._subcmd_configs.setdefault(
39+
subcmd,
40+
config.get_configs(
41+
prefix="cli",
42+
name=subcmd,
43+
dirs=CONFIG_DIRS,
44+
),
45+
)
46+
47+
# if found in .torchxconfig make it the default for this argument
48+
# otherwise use the default defined from add_argument(...)
49+
default = cfg.get(dest, default)
50+
51+
# ``required`` means that it NEEDS to be present in the CLI args
52+
# if we found it in .torchxconfig then we don't "require" it to be
53+
# in the CLI args so set it to False
54+
if default:
55+
required = False
56+
57+
super().__init__(
58+
dest=dest,
59+
default=default,
60+
option_strings=option_strings,
61+
required=required,
62+
**kwargs,
63+
)
64+
65+
def __call__(
66+
self,
67+
parser: ArgumentParser,
68+
namespace: Namespace,
69+
values: Any, # pyre-ignore[2] declared as Any in superclass Action
70+
option_string: Optional[str] = None,
71+
) -> None:
72+
setattr(namespace, self.dest, values)
73+
74+
75+
# argparse takes the action as a Type[Action] so we can't have custom constructors
76+
# hence for each subcommand we need to subclass the base _torchxconfig Action
77+
# this is also how store_true and store_false builtin actions are implemented in argparse
78+
class torchxconfig_run(_torchxconfig):
79+
"""
80+
Custom action that gets the default argument from .torchxconfig.
81+
"""
82+
83+
def __init__(
84+
self,
85+
dest: str,
86+
option_strings: Sequence[Text],
87+
required: bool = False,
88+
# pyre-ignore[2] declared as Any in superclass Action
89+
default: Any = None,
90+
**kwargs: Any,
91+
) -> None:
92+
super().__init__(
93+
"run",
94+
dest=dest,
95+
default=default,
96+
required=required,
97+
option_strings=option_strings,
98+
**kwargs,
99+
)

torchx/cli/cmd_run.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import torchx.specs as specs
1818
from pyre_extensions import none_throws
19+
from torchx.cli.argparse_util import CONFIG_DIRS, torchxconfig_run
1920
from torchx.cli.cmd_base import SubCommand
2021
from torchx.cli.cmd_log import get_logs
2122
from torchx.runner import Runner, config, get_runner
@@ -36,7 +37,6 @@
3637
"missing component name, either provide it from the CLI or in .torchxconfig"
3738
)
3839

39-
CONFIG_DIRS = [str(Path.home()), str(Path.cwd())]
4040

4141
logger: logging.Logger = logging.getLogger(__name__)
4242

@@ -185,6 +185,7 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
185185
"--workspace",
186186
"--buck-target",
187187
default=f"file://{Path.cwd()}",
188+
action=torchxconfig_run,
188189
help="local workspace to build/patch (buck-target of main binary if using buck)",
189190
)
190191
subparser.add_argument(

torchx/cli/conf_helpers.py

Lines changed: 0 additions & 33 deletions
This file was deleted.

torchx/cli/test/argparse_util_test.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import shutil
8+
import tempfile
9+
import unittest
10+
from argparse import ArgumentParser
11+
from pathlib import Path
12+
from unittest import mock
13+
14+
from torchx.cli import argparse_util
15+
from torchx.cli.argparse_util import torchxconfig_run
16+
17+
18+
CONFIG_DIRS = "torchx.cli.argparse_util.CONFIG_DIRS"
19+
20+
21+
class ArgparseUtilTest(unittest.TestCase):
22+
def _write(self, filename: str, content: str) -> Path:
23+
f = Path(self.test_dir) / filename
24+
f.parent.mkdir(parents=True, exist_ok=True)
25+
with open(f, "w") as fp:
26+
fp.write(content)
27+
return f
28+
29+
def setUp(self) -> None:
30+
self.test_dir = tempfile.mkdtemp(prefix="torchx_argparse_util_test")
31+
argparse_util._torchxconfig._subcmd_configs.clear()
32+
33+
def tearDown(self) -> None:
34+
shutil.rmtree(self.test_dir)
35+
36+
def test_torchxconfig_action(self) -> None:
37+
with mock.patch(CONFIG_DIRS, [self.test_dir]):
38+
self._write(
39+
".torchxconfig",
40+
"""
41+
[cli:run]
42+
workspace = baz
43+
""",
44+
)
45+
46+
parser = ArgumentParser()
47+
48+
subparsers = parser.add_subparsers()
49+
run_parser = subparsers.add_parser("run")
50+
51+
run_parser.add_argument(
52+
"--workspace",
53+
default="foo",
54+
type=str,
55+
action=torchxconfig_run,
56+
)
57+
58+
# arguments specified in CLI should take outmost precedence
59+
args = parser.parse_args(["run", "--workspace", "bar"])
60+
self.assertEqual("bar", args.workspace)
61+
62+
# if not specified in CLI, then grab it from .torchxconfig
63+
args = parser.parse_args(["run"])
64+
self.assertEqual("baz", args.workspace)
65+
66+
def test_torchxconfig_action_argparse_default(self) -> None:
67+
with mock.patch(CONFIG_DIRS, [self.test_dir]):
68+
self._write(
69+
".torchxconfig",
70+
"""
71+
[cli:run]
72+
""",
73+
)
74+
75+
parser = ArgumentParser()
76+
77+
subparsers = parser.add_subparsers()
78+
run_parser = subparsers.add_parser("run")
79+
80+
run_parser.add_argument(
81+
"--workspace",
82+
default="foo",
83+
type=str,
84+
action=torchxconfig_run,
85+
)
86+
87+
# if not found in .torchxconfig should use argparse default
88+
args = parser.parse_args(["run"])
89+
self.assertEqual("foo", args.workspace)
90+
91+
def test_torchxconfig_action_required(self) -> None:
92+
with mock.patch(CONFIG_DIRS, [self.test_dir]):
93+
self._write(
94+
".torchxconfig",
95+
"""
96+
[cli:run]
97+
workspace = bazz
98+
""",
99+
)
100+
101+
parser = ArgumentParser()
102+
103+
subparsers = parser.add_subparsers()
104+
run_parser = subparsers.add_parser("run")
105+
106+
run_parser.add_argument(
107+
"--workspace",
108+
required=True,
109+
type=str,
110+
action=torchxconfig_run,
111+
)
112+
113+
# arguments specified in CLI should take outmost precedence
114+
args = parser.parse_args(["run", "--workspace", "bar"])
115+
self.assertEqual("bar", args.workspace)
116+
117+
# if not specified in CLI, then grab it from .torchxconfig
118+
args = parser.parse_args(["run"])
119+
self.assertEqual("bazz", args.workspace)
120+
121+
def test_torchxconfig_action_aliases(self) -> None:
122+
# for aliases, the config file needs to declare the original arg
123+
with mock.patch(CONFIG_DIRS, [self.test_dir]):
124+
self._write(
125+
".torchxconfig",
126+
"""
127+
[cli:run]
128+
workspace = baz
129+
""",
130+
)
131+
132+
parser = ArgumentParser()
133+
134+
subparsers = parser.add_subparsers()
135+
run_parser = subparsers.add_parser("run")
136+
137+
run_parser.add_argument(
138+
"--workspace",
139+
"--buck-target",
140+
type=str,
141+
required=True,
142+
action=torchxconfig_run,
143+
)
144+
145+
args = parser.parse_args(["run"])
146+
self.assertEqual("baz", args.workspace)

torchx/runner/config.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,31 @@ def strip_prefix(section_name: str) -> Optional[str]:
326326
return sections
327327

328328

329+
def get_configs(
330+
prefix: str,
331+
name: str,
332+
dirs: Optional[List[str]],
333+
) -> Dict[str, str]:
334+
"""
335+
Gets all the config values in the section ``["{prefix}:{name}"]``.
336+
Or an empty map if the section does not exist.
337+
338+
Example:
339+
340+
::
341+
342+
# for config file:
343+
# [foo:bar]
344+
# baz = 1
345+
346+
get_configs(prefix="foo", name="bar") # returns {"baz": "1"}
347+
get_config(prefix="foo", name="barr") # returns {}
348+
349+
"""
350+
sections = load_sections(prefix, dirs)
351+
return sections.get(name, {})
352+
353+
329354
def get_config(
330355
prefix: str,
331356
name: str,
@@ -350,9 +375,7 @@ def get_config(
350375
get_config(prefix="fooo", name="bar", key="baz") == None
351376
352377
"""
353-
sections = load_sections(prefix, dirs)
354-
section = sections.get(name, {})
355-
return section.get(key, None)
378+
return get_configs(prefix, name, dirs).get(key, None)
356379

357380

358381
def find_configs(dirs: Optional[Iterable[str]] = None) -> List[str]:

0 commit comments

Comments
 (0)