|
13 | 13 | from tokenize import ENCODING as tk_ENCODING |
14 | 14 | from tokenize import NAME as tk_NAME |
15 | 15 | from tokenize import tokenize as generate_tokens |
| 16 | +from string import Formatter |
16 | 17 |
|
17 | 18 | builtins = __builtins__ |
18 | 19 | if not isinstance(builtins, dict): |
|
33 | 34 | except ImportError: |
34 | 35 | pass |
35 | 36 |
|
| 37 | +# This is a necessary API but it's undocumented and moved around |
| 38 | +# between Python releases |
| 39 | +try: |
| 40 | + from _string import formatter_field_name_split |
| 41 | +except ImportError: |
| 42 | + formatter_field_name_split = lambda \ |
| 43 | + x: x._formatter_field_name_split() |
| 44 | + |
36 | 45 |
|
37 | 46 |
|
38 | 47 | MAX_EXPONENT = 10000 |
|
59 | 68 | '__getattribute__', '__subclasshook__', '__new__', |
60 | 69 | '__init__', 'func_globals', 'func_code', 'func_closure', |
61 | 70 | 'im_class', 'im_func', 'im_self', 'gi_code', 'gi_frame', |
62 | | - 'f_locals', '__asteval__') |
| 71 | + 'f_locals', '__asteval__','mro') |
63 | 72 |
|
64 | 73 | # unsafe attributes for particular objects, by type |
65 | 74 | UNSAFE_ATTRS_DTYPES = {str: ('format', 'format_map')} |
@@ -266,6 +275,45 @@ def safe_lshift(arg1, arg2): |
266 | 275 | ast.UAdd: lambda a: +a, |
267 | 276 | ast.USub: lambda a: -a} |
268 | 277 |
|
| 278 | +# Safe version of getattr |
| 279 | + |
| 280 | +def safe_getattr(obj, attr, raise_exc, node): |
| 281 | + """safe version of getattr""" |
| 282 | + unsafe = (attr in UNSAFE_ATTRS or |
| 283 | + (attr.startswith('__') and attr.endswith('__'))) |
| 284 | + if not unsafe: |
| 285 | + for dtype, attrlist in UNSAFE_ATTRS_DTYPES.items(): |
| 286 | + unsafe = (isinstance(obj, dtype) or obj is dtype) and attr in attrlist |
| 287 | + if unsafe: |
| 288 | + break |
| 289 | + if unsafe: |
| 290 | + msg = f"no safe attribute '{attr}' for {repr(obj)}" |
| 291 | + raise_exc(node, exc=AttributeError, msg=msg) |
| 292 | + else: |
| 293 | + try: |
| 294 | + return getattr(obj, attr) |
| 295 | + except AttributeError: |
| 296 | + pass |
| 297 | + |
| 298 | +class SafeFormatter(Formatter): |
| 299 | + def __init__(self, raise_exc, node): |
| 300 | + self.raise_exc = raise_exc |
| 301 | + self.node = node |
| 302 | + super().__init__() |
| 303 | + |
| 304 | + def get_field(self, field_name, args, kwargs): |
| 305 | + first, rest = formatter_field_name_split(field_name) |
| 306 | + obj = self.get_value(first, args, kwargs) |
| 307 | + for is_attr, i in rest: |
| 308 | + if is_attr: |
| 309 | + obj = safe_getattr(obj, i, self.raise_exc, self.node) |
| 310 | + else: |
| 311 | + obj = obj[i] |
| 312 | + return obj, first |
| 313 | + |
| 314 | +def safe_format(_string, raise_exc, node, *args, **kwargs): |
| 315 | + formatter = SafeFormatter(raise_exc, node) |
| 316 | + return formatter.vformat(_string, args, kwargs) |
269 | 317 |
|
270 | 318 | def valid_symbol_name(name): |
271 | 319 | """Determine whether the input symbol name is a valid name. |
|
0 commit comments