Skip to content

Commit 2ea2540

Browse files
JelleZijlstradavidism
authored andcommitted
Improve types for AST handling
Adapting to changes in python/typeshed#11880. This mostly adds more precise types for individual pieces of AST.
1 parent c9f7a6b commit 2ea2540

File tree

1 file changed

+22
-16
lines changed

1 file changed

+22
-16
lines changed

src/werkzeug/routing/rules.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -294,11 +294,16 @@ def get_rules(self, map: Map) -> t.Iterator[Rule]:
294294
)
295295

296296

297-
def _prefix_names(src: str) -> ast.stmt:
297+
_ASTT = t.TypeVar("_ASTT", bound=ast.AST)
298+
299+
300+
def _prefix_names(src: str, expected_type: type[_ASTT]) -> _ASTT:
298301
"""ast parse and prefix names with `.` to avoid collision with user vars"""
299-
tree = ast.parse(src).body[0]
302+
tree: ast.AST = ast.parse(src).body[0]
300303
if isinstance(tree, ast.Expr):
301-
tree = tree.value # type: ignore
304+
tree = tree.value
305+
if not isinstance(tree, expected_type):
306+
raise TypeError(f"AST node is of type {type(tree).__name__}, not {expected_type.__name__}")
302307
for node in ast.walk(tree):
303308
if isinstance(node, ast.Name):
304309
node.id = f".{node.id}"
@@ -313,8 +318,8 @@ def _prefix_names(src: str) -> ast.stmt:
313318
else:
314319
q = params = ""
315320
"""
316-
_IF_KWARGS_URL_ENCODE_AST = _prefix_names(_IF_KWARGS_URL_ENCODE_CODE)
317-
_URL_ENCODE_AST_NAMES = (_prefix_names("q"), _prefix_names("params"))
321+
_IF_KWARGS_URL_ENCODE_AST = _prefix_names(_IF_KWARGS_URL_ENCODE_CODE, ast.If)
322+
_URL_ENCODE_AST_NAMES = (_prefix_names("q", ast.Name), _prefix_names("params", ast.Name))
318323

319324

320325
class Rule(RuleFactory):
@@ -751,13 +756,13 @@ def _compile_builder(
751756
else:
752757
opl.append((True, data))
753758

754-
def _convert(elem: str) -> ast.stmt:
755-
ret = _prefix_names(_CALL_CONVERTER_CODE_FMT.format(elem=elem))
756-
ret.args = [ast.Name(str(elem), ast.Load())] # type: ignore # str for py2
759+
def _convert(elem: str) -> ast.Call:
760+
ret = _prefix_names(_CALL_CONVERTER_CODE_FMT.format(elem=elem), ast.Call)
761+
ret.args = [ast.Name(elem, ast.Load())]
757762
return ret
758763

759-
def _parts(ops: list[tuple[bool, str]]) -> list[ast.AST]:
760-
parts = [
764+
def _parts(ops: list[tuple[bool, str]]) -> list[ast.expr]:
765+
parts: list[ast.expr] = [
761766
_convert(elem) if is_dynamic else ast.Constant(elem)
762767
for is_dynamic, elem in ops
763768
]
@@ -773,13 +778,14 @@ def _parts(ops: list[tuple[bool, str]]) -> list[ast.AST]:
773778

774779
dom_parts = _parts(dom_ops)
775780
url_parts = _parts(url_ops)
781+
body: list[ast.stmt]
776782
if not append_unknown:
777783
body = []
778784
else:
779785
body = [_IF_KWARGS_URL_ENCODE_AST]
780786
url_parts.extend(_URL_ENCODE_AST_NAMES)
781787

782-
def _join(parts: list[ast.AST]) -> ast.AST:
788+
def _join(parts: list[ast.expr]) -> ast.expr:
783789
if len(parts) == 1: # shortcut
784790
return parts[0]
785791
return ast.JoinedStr(parts)
@@ -795,7 +801,7 @@ def _join(parts: list[ast.AST]) -> ast.AST:
795801
]
796802
kargs = [str(k) for k in defaults]
797803

798-
func_ast: ast.FunctionDef = _prefix_names("def _(): pass") # type: ignore
804+
func_ast = _prefix_names("def _(): pass", ast.FunctionDef)
799805
func_ast.name = f"<builder:{self.rule!r}>"
800806
func_ast.args.args.append(ast.arg(".self", None))
801807
for arg in pargs + kargs:
@@ -815,13 +821,13 @@ def _join(parts: list[ast.AST]) -> ast.AST:
815821
# bad line numbers cause an assert to fail in debug builds
816822
for node in ast.walk(module):
817823
if "lineno" in node._attributes:
818-
node.lineno = 1
824+
node.lineno = 1 # type: ignore[attr-defined]
819825
if "end_lineno" in node._attributes:
820-
node.end_lineno = node.lineno
826+
node.end_lineno = node.lineno # type: ignore[attr-defined]
821827
if "col_offset" in node._attributes:
822-
node.col_offset = 0
828+
node.col_offset = 0 # type: ignore[attr-defined]
823829
if "end_col_offset" in node._attributes:
824-
node.end_col_offset = node.col_offset
830+
node.end_col_offset = node.col_offset # type: ignore[attr-defined]
825831

826832
code = compile(module, "<werkzeug routing>", "exec")
827833
return self._get_func_code(code, func_ast.name)

0 commit comments

Comments
 (0)