Skip to content

Commit 2f76368

Browse files
authored
[3.10] bpo-43897: ast validation for pattern matching nodes (GH-27074)
(cherry picked from commit 8dcb7d9) Co-authored-by: Batuhan Taskaya <[email protected]>
1 parent 662ace1 commit 2f76368

File tree

2 files changed

+265
-32
lines changed

2 files changed

+265
-32
lines changed

Lib/test/test_ast.py

Lines changed: 142 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,7 @@ def test_constant_as_name(self):
695695
for constant in "True", "False", "None":
696696
expr = ast.Expression(ast.Name(constant, ast.Load()))
697697
ast.fix_missing_locations(expr)
698-
with self.assertRaisesRegex(ValueError, f"Name node can't be used with '{constant}' constant"):
698+
with self.assertRaisesRegex(ValueError, f"identifier field can't represent '{constant}' constant"):
699699
compile(expr, "<test>", "eval")
700700

701701

@@ -1476,6 +1476,147 @@ def test_stdlib_validates(self):
14761476
mod = ast.parse(source, fn)
14771477
compile(mod, fn, "exec")
14781478

1479+
constant_1 = ast.Constant(1)
1480+
pattern_1 = ast.MatchValue(constant_1)
1481+
1482+
constant_x = ast.Constant('x')
1483+
pattern_x = ast.MatchValue(constant_x)
1484+
1485+
constant_true = ast.Constant(True)
1486+
pattern_true = ast.MatchSingleton(True)
1487+
1488+
name_carter = ast.Name('carter', ast.Load())
1489+
1490+
_MATCH_PATTERNS = [
1491+
ast.MatchValue(
1492+
ast.Attribute(
1493+
ast.Attribute(
1494+
ast.Name('x', ast.Store()),
1495+
'y', ast.Load()
1496+
),
1497+
'z', ast.Load()
1498+
)
1499+
),
1500+
ast.MatchValue(
1501+
ast.Attribute(
1502+
ast.Attribute(
1503+
ast.Name('x', ast.Load()),
1504+
'y', ast.Store()
1505+
),
1506+
'z', ast.Load()
1507+
)
1508+
),
1509+
ast.MatchValue(
1510+
ast.Constant(...)
1511+
),
1512+
ast.MatchValue(
1513+
ast.Constant(True)
1514+
),
1515+
ast.MatchValue(
1516+
ast.Constant((1,2,3))
1517+
),
1518+
ast.MatchSingleton('string'),
1519+
ast.MatchSequence([
1520+
ast.MatchSingleton('string')
1521+
]),
1522+
ast.MatchSequence(
1523+
[
1524+
ast.MatchSequence(
1525+
[
1526+
ast.MatchSingleton('string')
1527+
]
1528+
)
1529+
]
1530+
),
1531+
ast.MatchMapping(
1532+
[constant_1, constant_true],
1533+
[pattern_x]
1534+
),
1535+
ast.MatchMapping(
1536+
[constant_true, constant_1],
1537+
[pattern_x, pattern_1],
1538+
rest='True'
1539+
),
1540+
ast.MatchMapping(
1541+
[constant_true, ast.Starred(ast.Name('lol', ast.Load()), ast.Load())],
1542+
[pattern_x, pattern_1],
1543+
rest='legit'
1544+
),
1545+
ast.MatchClass(
1546+
ast.Attribute(
1547+
ast.Attribute(
1548+
constant_x,
1549+
'y', ast.Load()),
1550+
'z', ast.Load()),
1551+
patterns=[], kwd_attrs=[], kwd_patterns=[]
1552+
),
1553+
ast.MatchClass(
1554+
name_carter,
1555+
patterns=[],
1556+
kwd_attrs=['True'],
1557+
kwd_patterns=[pattern_1]
1558+
),
1559+
ast.MatchClass(
1560+
name_carter,
1561+
patterns=[],
1562+
kwd_attrs=[],
1563+
kwd_patterns=[pattern_1]
1564+
),
1565+
ast.MatchClass(
1566+
name_carter,
1567+
patterns=[ast.MatchSingleton('string')],
1568+
kwd_attrs=[],
1569+
kwd_patterns=[]
1570+
),
1571+
ast.MatchClass(
1572+
name_carter,
1573+
patterns=[ast.MatchStar()],
1574+
kwd_attrs=[],
1575+
kwd_patterns=[]
1576+
),
1577+
ast.MatchClass(
1578+
name_carter,
1579+
patterns=[],
1580+
kwd_attrs=[],
1581+
kwd_patterns=[ast.MatchStar()]
1582+
),
1583+
ast.MatchSequence(
1584+
[
1585+
ast.MatchStar("True")
1586+
]
1587+
),
1588+
ast.MatchAs(
1589+
name='False'
1590+
),
1591+
ast.MatchOr(
1592+
[]
1593+
),
1594+
ast.MatchOr(
1595+
[pattern_1]
1596+
),
1597+
ast.MatchOr(
1598+
[pattern_1, pattern_x, ast.MatchSingleton('xxx')]
1599+
)
1600+
]
1601+
1602+
def test_match_validation_pattern(self):
1603+
name_x = ast.Name('x', ast.Load())
1604+
for pattern in self._MATCH_PATTERNS:
1605+
with self.subTest(ast.dump(pattern, indent=4)):
1606+
node = ast.Match(
1607+
subject=name_x,
1608+
cases = [
1609+
ast.match_case(
1610+
pattern=pattern,
1611+
body = [ast.Pass()]
1612+
)
1613+
]
1614+
)
1615+
node = ast.fix_missing_locations(node)
1616+
module = ast.Module([node], [])
1617+
with self.assertRaises(ValueError):
1618+
compile(module, "<test>", "exec")
1619+
14791620

14801621
class ConstantTests(unittest.TestCase):
14811622
"""Tests on the ast.Constant node type."""

Python/ast.c

Lines changed: 123 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ struct validator {
1515
};
1616

1717
static int validate_stmts(struct validator *, asdl_stmt_seq *);
18-
static int validate_exprs(struct validator *, asdl_expr_seq*, expr_context_ty, int);
18+
static int validate_exprs(struct validator *, asdl_expr_seq *, expr_context_ty, int);
19+
static int validate_patterns(struct validator *, asdl_pattern_seq *, int);
1920
static int _validate_nonempty_seq(asdl_seq *, const char *, const char *);
2021
static int validate_stmt(struct validator *, stmt_ty);
2122
static int validate_expr(struct validator *, expr_ty, expr_context_ty);
@@ -33,7 +34,7 @@ validate_name(PyObject *name)
3334
};
3435
for (int i = 0; forbidden[i] != NULL; i++) {
3536
if (_PyUnicode_EqualToASCIIString(name, forbidden[i])) {
36-
PyErr_Format(PyExc_ValueError, "Name node can't be used with '%s' constant", forbidden[i]);
37+
PyErr_Format(PyExc_ValueError, "identifier field can't represent '%s' constant", forbidden[i]);
3738
return 0;
3839
}
3940
}
@@ -448,6 +449,21 @@ validate_pattern_match_value(struct validator *state, expr_ty exp)
448449
switch (exp->kind)
449450
{
450451
case Constant_kind:
452+
/* Ellipsis and immutable sequences are not allowed.
453+
For True, False and None, MatchSingleton() should
454+
be used */
455+
if (!validate_expr(state, exp, Load)) {
456+
return 0;
457+
}
458+
PyObject *literal = exp->v.Constant.value;
459+
if (PyLong_CheckExact(literal) || PyFloat_CheckExact(literal) ||
460+
PyBytes_CheckExact(literal) || PyComplex_CheckExact(literal) ||
461+
PyUnicode_CheckExact(literal)) {
462+
return 1;
463+
}
464+
PyErr_SetString(PyExc_ValueError,
465+
"unexpected constant inside of a literal pattern");
466+
return 0;
451467
case Attribute_kind:
452468
// Constants and attribute lookups are always permitted
453469
return 1;
@@ -465,11 +481,14 @@ validate_pattern_match_value(struct validator *state, expr_ty exp)
465481
return 1;
466482
}
467483
break;
484+
case JoinedStr_kind:
485+
// Handled in the later stages
486+
return 1;
468487
default:
469488
break;
470489
}
471-
PyErr_SetString(PyExc_SyntaxError,
472-
"patterns may only match literals and attribute lookups");
490+
PyErr_SetString(PyExc_ValueError,
491+
"patterns may only match literals and attribute lookups");
473492
return 0;
474493
}
475494

@@ -489,51 +508,101 @@ validate_pattern(struct validator *state, pattern_ty p)
489508
ret = validate_pattern_match_value(state, p->v.MatchValue.value);
490509
break;
491510
case MatchSingleton_kind:
492-
// TODO: Check constant is specifically None, True, or False
493-
ret = validate_constant(state, p->v.MatchSingleton.value);
511+
ret = p->v.MatchSingleton.value == Py_None || PyBool_Check(p->v.MatchSingleton.value);
512+
if (!ret) {
513+
PyErr_SetString(PyExc_ValueError,
514+
"MatchSingleton can only contain True, False and None");
515+
}
494516
break;
495517
case MatchSequence_kind:
496-
// TODO: Validate all subpatterns
497-
// return validate_patterns(state, p->v.MatchSequence.patterns);
498-
ret = 1;
518+
ret = validate_patterns(state, p->v.MatchSequence.patterns, /*star_ok=*/1);
499519
break;
500520
case MatchMapping_kind:
501-
// TODO: check "rest" target name is valid
502521
if (asdl_seq_LEN(p->v.MatchMapping.keys) != asdl_seq_LEN(p->v.MatchMapping.patterns)) {
503522
PyErr_SetString(PyExc_ValueError,
504523
"MatchMapping doesn't have the same number of keys as patterns");
505-
return 0;
524+
ret = 0;
525+
break;
506526
}
507-
// null_ok=0 for key expressions, as rest-of-mapping is captured in "rest"
508-
// TODO: replace with more restrictive expression validator, as per MatchValue above
509-
if (!validate_exprs(state, p->v.MatchMapping.keys, Load, /*null_ok=*/ 0)) {
510-
return 0;
527+
528+
if (p->v.MatchMapping.rest && !validate_name(p->v.MatchMapping.rest)) {
529+
ret = 0;
530+
break;
511531
}
512-
// TODO: Validate all subpatterns
513-
// ret = validate_patterns(state, p->v.MatchMapping.patterns);
514-
ret = 1;
532+
533+
asdl_expr_seq *keys = p->v.MatchMapping.keys;
534+
for (Py_ssize_t i = 0; i < asdl_seq_LEN(keys); i++) {
535+
expr_ty key = asdl_seq_GET(keys, i);
536+
if (key->kind == Constant_kind) {
537+
PyObject *literal = key->v.Constant.value;
538+
if (literal == Py_None || PyBool_Check(literal)) {
539+
/* validate_pattern_match_value will ensure the key
540+
doesn't contain True, False and None but it is
541+
syntactically valid, so we will pass those on in
542+
a special case. */
543+
continue;
544+
}
545+
}
546+
if (!validate_pattern_match_value(state, key)) {
547+
ret = 0;
548+
break;
549+
}
550+
}
551+
552+
ret = validate_patterns(state, p->v.MatchMapping.patterns, /*star_ok=*/0);
515553
break;
516554
case MatchClass_kind:
517555
if (asdl_seq_LEN(p->v.MatchClass.kwd_attrs) != asdl_seq_LEN(p->v.MatchClass.kwd_patterns)) {
518556
PyErr_SetString(PyExc_ValueError,
519557
"MatchClass doesn't have the same number of keyword attributes as patterns");
520-
return 0;
558+
ret = 0;
559+
break;
521560
}
522-
// TODO: Restrict cls lookup to being a name or attribute
523561
if (!validate_expr(state, p->v.MatchClass.cls, Load)) {
524-
return 0;
562+
ret = 0;
563+
break;
525564
}
526-
// TODO: Validate all subpatterns
527-
// return validate_patterns(state, p->v.MatchClass.patterns) &&
528-
// validate_patterns(state, p->v.MatchClass.kwd_patterns);
529-
ret = 1;
565+
566+
expr_ty cls = p->v.MatchClass.cls;
567+
while (1) {
568+
if (cls->kind == Name_kind) {
569+
break;
570+
}
571+
else if (cls->kind == Attribute_kind) {
572+
cls = cls->v.Attribute.value;
573+
continue;
574+
}
575+
else {
576+
PyErr_SetString(PyExc_ValueError,
577+
"MatchClass cls field can only contain Name or Attribute nodes.");
578+
state->recursion_depth--;
579+
return 0;
580+
}
581+
}
582+
583+
for (Py_ssize_t i = 0; i < asdl_seq_LEN(p->v.MatchClass.kwd_attrs); i++) {
584+
PyObject *identifier = asdl_seq_GET(p->v.MatchClass.kwd_attrs, i);
585+
if (!validate_name(identifier)) {
586+
state->recursion_depth--;
587+
return 0;
588+
}
589+
}
590+
591+
if (!validate_patterns(state, p->v.MatchClass.patterns, /*star_ok=*/0)) {
592+
ret = 0;
593+
break;
594+
}
595+
596+
ret = validate_patterns(state, p->v.MatchClass.kwd_patterns, /*star_ok=*/0);
530597
break;
531598
case MatchStar_kind:
532-
// TODO: check target name is valid
533-
ret = 1;
599+
ret = p->v.MatchStar.name == NULL || validate_name(p->v.MatchStar.name);
534600
break;
535601
case MatchAs_kind:
536-
// TODO: check target name is valid
602+
if (p->v.MatchAs.name && !validate_name(p->v.MatchAs.name)) {
603+
ret = 0;
604+
break;
605+
}
537606
if (p->v.MatchAs.pattern == NULL) {
538607
ret = 1;
539608
}
@@ -547,9 +616,13 @@ validate_pattern(struct validator *state, pattern_ty p)
547616
}
548617
break;
549618
case MatchOr_kind:
550-
// TODO: Validate all subpatterns
551-
// return validate_patterns(state, p->v.MatchOr.patterns);
552-
ret = 1;
619+
if (asdl_seq_LEN(p->v.MatchOr.patterns) < 2) {
620+
PyErr_SetString(PyExc_ValueError,
621+
"MatchOr requires at least 2 patterns");
622+
ret = 0;
623+
break;
624+
}
625+
ret = validate_patterns(state, p->v.MatchOr.patterns, /*star_ok=*/0);
553626
break;
554627
// No default case, so the compiler will emit a warning if new pattern
555628
// kinds are added without being handled here
@@ -815,6 +888,25 @@ validate_exprs(struct validator *state, asdl_expr_seq *exprs, expr_context_ty ct
815888
return 1;
816889
}
817890

891+
static int
892+
validate_patterns(struct validator *state, asdl_pattern_seq *patterns, int star_ok)
893+
{
894+
Py_ssize_t i;
895+
for (i = 0; i < asdl_seq_LEN(patterns); i++) {
896+
pattern_ty pattern = asdl_seq_GET(patterns, i);
897+
if (pattern->kind == MatchStar_kind && !star_ok) {
898+
PyErr_SetString(PyExc_ValueError,
899+
"Can't use MatchStar within this sequence of patterns");
900+
return 0;
901+
}
902+
if (!validate_pattern(state, pattern)) {
903+
return 0;
904+
}
905+
}
906+
return 1;
907+
}
908+
909+
818910
/* See comments in symtable.c. */
819911
#define COMPILER_STACK_FRAME_SCALE 3
820912

0 commit comments

Comments
 (0)