Skip to content

Commit 28e96f7

Browse files
aivanoufacebook-github-bot
authored andcommitted
Make docstring optional (#259)
Summary: Pull Request resolved: #259 * Refactor docstring functions: combines two functions that retrieve docstring into one * Make docstring optional * Remove docstring validator Differential Revision: D31671125 fbshipit-source-id: 2fb71f6b98e212700479003ca4d15a01cec0e571
1 parent 95ea9f5 commit 28e96f7

File tree

6 files changed

+92
-193
lines changed

6 files changed

+92
-193
lines changed

torchx/specs/api.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@
2828
)
2929

3030
import yaml
31-
from pyre_extensions import none_throws
32-
from torchx.specs.file_linter import parse_fn_docstring
31+
from torchx.specs.file_linter import get_fn_docstring
3332
from torchx.util.types import decode_from_string, decode_optional, is_bool, is_primitive
3433

3534

@@ -748,22 +747,21 @@ def get_argparse_param_type(parameter: inspect.Parameter) -> Callable[[str], obj
748747
return str
749748

750749

751-
def _create_args_parser(
752-
fn_name: str,
753-
parameters: Mapping[str, inspect.Parameter],
754-
function_desc: str,
755-
args_desc: Dict[str, str],
756-
) -> argparse.ArgumentParser:
750+
def _create_args_parser(app_fn: Callable[..., AppDef]) -> argparse.ArgumentParser:
751+
parameters = inspect.signature(app_fn).parameters
752+
function_desc, args_desc = get_fn_docstring(app_fn)
757753
script_parser = argparse.ArgumentParser(
758-
prog=f"torchx run ...torchx_params... {fn_name} ",
754+
prog=f"torchx run <<torchx_params>> {app_fn.__name__} ",
759755
description=f"App spec: {function_desc}",
756+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
760757
)
761758

762759
remainder_arg = []
763760

764761
for param_name, parameter in parameters.items():
762+
param_desc = args_desc[parameter.name]
765763
args: Dict[str, Any] = {
766-
"help": args_desc[param_name],
764+
"help": param_desc,
767765
"type": get_argparse_param_type(parameter),
768766
}
769767
if parameter.default != inspect.Parameter.empty:
@@ -788,20 +786,15 @@ def _create_args_parser(
788786
def _get_function_args(
789787
app_fn: Callable[..., AppDef], app_args: List[str]
790788
) -> Tuple[List[object], List[str], Dict[str, object]]:
791-
docstring = none_throws(inspect.getdoc(app_fn))
792-
function_desc, args_desc = parse_fn_docstring(docstring)
793-
794-
parameters = inspect.signature(app_fn).parameters
795-
script_parser = _create_args_parser(
796-
app_fn.__name__, parameters, function_desc, args_desc
797-
)
789+
script_parser = _create_args_parser(app_fn)
798790

799791
parsed_args = script_parser.parse_args(app_args)
800792

801793
function_args = []
802794
var_arg = []
803795
kwargs = {}
804796

797+
parameters = inspect.signature(app_fn).parameters
805798
for param_name, parameter in parameters.items():
806799
arg_value = getattr(parsed_args, param_name)
807800
parameter_type = parameter.annotation

torchx/specs/file_linter.py

Lines changed: 16 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77

88
import abc
99
import ast
10+
import inspect
1011
from dataclasses import dataclass
11-
from typing import Dict, List, Optional, Tuple, cast
12+
from typing import Dict, List, Optional, Tuple, cast, Callable
1213

1314
from docstring_parser import parse
1415
from pyre_extensions import none_throws
@@ -30,41 +31,29 @@ def get_arg_names(app_specs_func_def: ast.FunctionDef) -> List[str]:
3031
return arg_names
3132

3233

33-
def parse_fn_docstring(func_description: str) -> Tuple[str, Dict[str, str]]:
34+
def _get_default_arguments_descriptions(fn: Callable[..., object]) -> Dict[str, str]:
35+
parameters = inspect.signature(fn).parameters
36+
args_decs = {}
37+
for parameter_name in parameters.keys():
38+
args_decs[parameter_name] = parameter_name
39+
return args_decs
40+
41+
42+
def get_fn_docstring(fn: Callable[..., object]) -> Tuple[str, Dict[str, str]]:
3443
"""
3544
Given a docstring in a google-style format, returns the function description and
3645
description of all arguments.
3746
See: https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html
3847
"""
39-
args_description = {}
48+
args_description = _get_default_arguments_descriptions(fn)
49+
func_description = inspect.getdoc(fn)
50+
if not func_description:
51+
return fn.__name__, args_description
4052
docstring = parse(func_description)
4153
for param in docstring.params:
4254
args_description[param.arg_name] = param.description
4355
short_func_description = docstring.short_description
44-
return (short_func_description or "", args_description)
45-
46-
47-
def _get_fn_docstring(
48-
source: str, function_name: str
49-
) -> Optional[Tuple[str, Dict[str, str]]]:
50-
module = ast.parse(source)
51-
for expr in module.body:
52-
if type(expr) == ast.FunctionDef:
53-
func_def = cast(ast.FunctionDef, expr)
54-
if func_def.name == function_name:
55-
docstring = ast.get_docstring(func_def)
56-
if not docstring:
57-
return None
58-
return parse_fn_docstring(docstring)
59-
return None
60-
61-
62-
def get_short_fn_description(path: str, function_name: str) -> Optional[str]:
63-
source = read_conf_file(path)
64-
docstring = _get_fn_docstring(source, function_name)
65-
if not docstring:
66-
return None
67-
return docstring[0]
56+
return (short_func_description or fn.__name__, args_description)
6857

6958

7059
@dataclass
@@ -91,38 +80,6 @@ def _gen_linter_message(self, description: str, lineno: int) -> LinterMessage:
9180
)
9281

9382

94-
class TorchxDocstringValidator(TorchxFunctionValidator):
95-
def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
96-
"""
97-
Validates the docstring of the `get_app_spec` function. Criteria:
98-
* There mast be google-style docstring
99-
* If there are more than zero arguments, there mast be a `Args:` section defined
100-
with all arguments included.
101-
"""
102-
docsting = ast.get_docstring(app_specs_func_def)
103-
lineno = app_specs_func_def.lineno
104-
if not docsting:
105-
desc = (
106-
f"`{app_specs_func_def.name}` is missing a Google Style docstring, please add one. "
107-
"For more information on the docstring format see: "
108-
"https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html"
109-
)
110-
return [self._gen_linter_message(desc, lineno)]
111-
112-
arg_names = get_arg_names(app_specs_func_def)
113-
_, docstring_arg_defs = parse_fn_docstring(docsting)
114-
missing_args = [
115-
arg_name for arg_name in arg_names if arg_name not in docstring_arg_defs
116-
]
117-
if len(missing_args) > 0:
118-
desc = (
119-
f"`{app_specs_func_def.name}` not all function arguments are present"
120-
f" in the docstring. Missing args: {missing_args}"
121-
)
122-
return [self._gen_linter_message(desc, lineno)]
123-
return []
124-
125-
12683
class TorchxFunctionArgsValidator(TorchxFunctionValidator):
12784
def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
12885
linter_errors = []
@@ -149,7 +106,6 @@ def _validate_arg_def(
149106
)
150107
]
151108
if isinstance(arg_def.annotation, ast.Name):
152-
# TODO(aivanou): add support for primitive type check
153109
return []
154110
complex_type_def = cast(ast.Subscript, none_throws(arg_def.annotation))
155111
if complex_type_def.value.id == "Optional":
@@ -239,12 +195,6 @@ class TorchFunctionVisitor(ast.NodeVisitor):
239195
Visitor that finds the component_function and runs registered validators on it.
240196
Current registered validators:
241197
242-
* TorchxDocstringValidator - validates the docstring of the function.
243-
Criteria:
244-
* There format should be google-python
245-
* If there are more than zero arguments defined, there
246-
should be obligatory `Args:` section that describes each argument on a new line.
247-
248198
* TorchxFunctionArgsValidator - validates arguments of the function.
249199
Criteria:
250200
* Each argument should be annotated with the type
@@ -260,7 +210,6 @@ class TorchFunctionVisitor(ast.NodeVisitor):
260210

261211
def __init__(self, component_function_name: str) -> None:
262212
self.validators = [
263-
TorchxDocstringValidator(),
264213
TorchxFunctionArgsValidator(),
265214
TorchxReturnValidator(),
266215
]

torchx/specs/finder.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from pyre_extensions import none_throws
1919
from torchx.specs import AppDef
20-
from torchx.specs.file_linter import get_short_fn_description, validate
20+
from torchx.specs.file_linter import get_fn_docstring, validate
2121
from torchx.util import entrypoints
2222
from torchx.util.io import read_conf_file
2323

@@ -47,7 +47,7 @@ class _Component:
4747
"""
4848

4949
name: str
50-
description: Optional[str]
50+
description: str
5151
fn_name: str
5252
fn: Callable[..., AppDef]
5353
validation_errors: List[str]
@@ -119,9 +119,10 @@ def _get_components_from_dir(
119119
search_pattern = os.path.join(search_dir, "**", "*.py")
120120
component_defs = []
121121
for filepath in glob.glob(search_pattern, recursive=True):
122-
module = self._try_load_module(
123-
self._get_module_name(filepath, search_dir, base_module)
124-
)
122+
module_name = self._get_module_name(filepath, search_dir, base_module)
123+
if module_name.startswith("torchx.components.base"):
124+
continue
125+
module = self._try_load_module(module_name)
125126
defs = self._get_components_from_module(base_module, module)
126127
component_defs += defs
127128
return component_defs
@@ -146,7 +147,7 @@ def _get_components_from_module(
146147
module_path = os.path.abspath(module.__file__)
147148
for function_name, function in functions:
148149
linter_errors = validate(module_path, function_name)
149-
component_desc = get_short_fn_description(module_path, function_name)
150+
component_desc, _ = get_fn_docstring(function)
150151
component_def = _Component(
151152
name=self._get_component_name(
152153
base_module, module.__name__, function_name
@@ -193,7 +194,6 @@ def find(self) -> List[_Component]:
193194
validation_errors = self._get_validation_errors(
194195
self._filepath, self._function_name
195196
)
196-
fn_desc = get_short_fn_description(self._filepath, self._function_name)
197197

198198
file_source = read_conf_file(self._filepath)
199199
namespace = globals()
@@ -203,6 +203,7 @@ def find(self) -> List[_Component]:
203203
f"Function {self._function_name} does not exist in file {self._filepath}"
204204
)
205205
app_fn = namespace[self._function_name]
206+
fn_desc, _ = get_fn_docstring(app_fn)
206207
return [
207208
_Component(
208209
name=f"{self._filepath}:{self._function_name}",

torchx/specs/test/api_test.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8+
import argparse
89
import sys
910
import unittest
1011
from dataclasses import asdict
@@ -33,6 +34,7 @@
3334
make_app_handle,
3435
parse_app_handle,
3536
runopts,
37+
_create_args_parser,
3638
)
3739

3840

@@ -463,11 +465,6 @@ def _test_complex_fn(
463465
app_name: AppDef name
464466
containers: List of containers
465467
roles_scripts: Dict role_name -> role_script
466-
num_cpus: List of cpus per role
467-
num_gpus: Dict role_name -> gpus used for role
468-
nnodes: Num replicas per role
469-
first_arg: First argument to the user script
470-
roles_args: Roles args
471468
"""
472469
num_roles = len(roles_scripts)
473470
if not num_cpus:
@@ -710,3 +707,28 @@ def test_varargs_only_arg_first(self) -> None:
710707
_TEST_VAR_ARGS_FIRST,
711708
(("fooval", "--foo", "barval", "arg1", "arg2"), "asdf"),
712709
)
710+
711+
def _get_argument_help(
712+
self, parser: argparse.ArgumentParser, name: str
713+
) -> Optional[str]:
714+
actions = parser._actions
715+
for action in actions:
716+
if action.dest == name:
717+
return action.help
718+
return None
719+
720+
def test_argparster_complex_fn_partial(self) -> None:
721+
parser = _create_args_parser(_test_complex_fn)
722+
self.assertEqual("AppDef name", self._get_argument_help(parser, "app_name"))
723+
self.assertEqual(
724+
"List of containers", self._get_argument_help(parser, "containers")
725+
)
726+
self.assertEqual(
727+
"Dict role_name -> role_script",
728+
self._get_argument_help(parser, "roles_scripts"),
729+
)
730+
self.assertEqual("num_cpus", self._get_argument_help(parser, "num_cpus"))
731+
self.assertEqual("num_gpus", self._get_argument_help(parser, "num_gpus"))
732+
self.assertEqual("nnodes", self._get_argument_help(parser, "nnodes"))
733+
self.assertEqual("first_arg", self._get_argument_help(parser, "first_arg"))
734+
self.assertEqual("roles_args", self._get_argument_help(parser, "roles_args"))

0 commit comments

Comments
 (0)