Skip to content

Commit 11cc21c

Browse files
authored
Add basic support for enum literals (#6668)
This pull request adds in basic support for enum literals. Basically, you can construct them and use them as types, but not much else. Some notes: 1. I plan on submitting incremental and fine-grained-mode tests in a future PR. 2. I wanted to add support for aliases to enums, but ran into some difficulty doing so. See #6667 for some analysis on the root cause.
1 parent aa13adf commit 11cc21c

File tree

6 files changed

+259
-14
lines changed

6 files changed

+259
-14
lines changed

mypy/message_registry.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@
1111
from typing_extensions import Final
1212

1313

14+
# Invalid types
15+
16+
INVALID_TYPE_RAW_ENUM_VALUE = "Invalid type: try using Literal[{}.{}] instead?" # type: Final
17+
18+
1419
# Type checker error message constants --
1520

1621
NO_RETURN_VALUE_EXPECTED = 'No return value expected' # type: Final

mypy/messages.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,11 @@ def format_bare(self, typ: Type, verbosity: int = 0) -> str:
233233
s = 'TypedDict({{{}}})'.format(', '.join(items))
234234
return s
235235
elif isinstance(typ, LiteralType):
236-
return str(typ)
236+
if typ.is_enum_literal():
237+
underlying_type = self.format_bare(typ.fallback, verbosity=verbosity)
238+
return 'Literal[{}.{}]'.format(underlying_type, typ.value)
239+
else:
240+
return str(typ)
237241
elif isinstance(typ, UnionType):
238242
# Only print Unions as Optionals if the Optional wouldn't have to contain another Union
239243
print_as_optional = (len(typ.items) -

mypy/newsemanal/typeanal.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,15 +151,15 @@ def __init__(self,
151151
# Names of type aliases encountered while analysing a type will be collected here.
152152
self.aliases_used = set() # type: Set[str]
153153

154-
def visit_unbound_type(self, t: UnboundType) -> Type:
155-
typ = self.visit_unbound_type_nonoptional(t)
154+
def visit_unbound_type(self, t: UnboundType, defining_literal: bool = False) -> Type:
155+
typ = self.visit_unbound_type_nonoptional(t, defining_literal)
156156
if t.optional:
157157
# We don't need to worry about double-wrapping Optionals or
158158
# wrapping Anys: Union simplification will take care of that.
159159
return make_optional_type(typ)
160160
return typ
161161

162-
def visit_unbound_type_nonoptional(self, t: UnboundType) -> Type:
162+
def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) -> Type:
163163
sym = self.lookup_qualified(t.name, t, suppress_errors=self.third_pass)
164164
if sym is not None:
165165
node = sym.node
@@ -217,7 +217,7 @@ def visit_unbound_type_nonoptional(self, t: UnboundType) -> Type:
217217
elif isinstance(node, TypeInfo):
218218
return self.analyze_type_with_type_info(node, t.args, t)
219219
else:
220-
return self.analyze_unbound_type_without_type_info(t, sym)
220+
return self.analyze_unbound_type_without_type_info(t, sym, defining_literal)
221221
else: # sym is None
222222
if self.third_pass:
223223
self.fail('Invalid type "{}"'.format(t.name), t)
@@ -348,7 +348,8 @@ def analyze_type_with_type_info(self, info: TypeInfo, args: List[Type], ctx: Con
348348
fallback=instance)
349349
return instance
350350

351-
def analyze_unbound_type_without_type_info(self, t: UnboundType, sym: SymbolTableNode) -> Type:
351+
def analyze_unbound_type_without_type_info(self, t: UnboundType, sym: SymbolTableNode,
352+
defining_literal: bool) -> Type:
352353
"""Figure out what an unbound type that doesn't refer to a TypeInfo node means.
353354
354355
This is something unusual. We try our best to find out what it is.
@@ -373,6 +374,30 @@ def analyze_unbound_type_without_type_info(self, t: UnboundType, sym: SymbolTabl
373374
if self.allow_unbound_tvars and unbound_tvar and not self.third_pass:
374375
return t
375376

377+
# Option 3:
378+
# Enum value. Note: we only want to return a LiteralType when
379+
# we're using this enum value specifically within context of
380+
# a "Literal[...]" type. So, if `defining_literal` is not set,
381+
# we bail out early with an error.
382+
#
383+
# If, in the distant future, we decide to permit things like
384+
# `def foo(x: Color.RED) -> None: ...`, we can remove that
385+
# check entirely.
386+
if isinstance(sym.node, Var) and sym.node.info and sym.node.info.is_enum:
387+
value = sym.node.name()
388+
base_enum_short_name = sym.node.info.name()
389+
if not defining_literal:
390+
msg = message_registry.INVALID_TYPE_RAW_ENUM_VALUE.format(
391+
base_enum_short_name, value)
392+
self.fail(msg, t)
393+
return AnyType(TypeOfAny.from_error)
394+
return LiteralType(
395+
value=value,
396+
fallback=Instance(sym.node.info, [], line=t.line, column=t.column),
397+
line=t.line,
398+
column=t.column,
399+
)
400+
376401
# None of the above options worked, we give up.
377402
self.fail('Invalid type "{}"'.format(name), t)
378403

@@ -631,7 +656,11 @@ def analyze_literal_param(self, idx: int, arg: Type, ctx: Context) -> Optional[L
631656
# If arg is an UnboundType that was *not* originally defined as
632657
# a string, try expanding it in case it's a type alias or something.
633658
if isinstance(arg, UnboundType):
634-
arg = self.anal_type(arg)
659+
self.nesting_level += 1
660+
try:
661+
arg = self.visit_unbound_type(arg, defining_literal=True)
662+
finally:
663+
self.nesting_level -= 1
635664

636665
# Literal[...] cannot contain Any. Give up and add an error message
637666
# (if we haven't already).

mypy/typeanal.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -187,15 +187,15 @@ def __init__(self,
187187
# Names of type aliases encountered while analysing a type will be collected here.
188188
self.aliases_used = set() # type: Set[str]
189189

190-
def visit_unbound_type(self, t: UnboundType) -> Type:
191-
typ = self.visit_unbound_type_nonoptional(t)
190+
def visit_unbound_type(self, t: UnboundType, defining_literal: bool = False) -> Type:
191+
typ = self.visit_unbound_type_nonoptional(t, defining_literal)
192192
if t.optional:
193193
# We don't need to worry about double-wrapping Optionals or
194194
# wrapping Anys: Union simplification will take care of that.
195195
return make_optional_type(typ)
196196
return typ
197197

198-
def visit_unbound_type_nonoptional(self, t: UnboundType) -> Type:
198+
def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) -> Type:
199199
sym = self.lookup(t.name, t, suppress_errors=self.third_pass)
200200
if '.' in t.name:
201201
# Handle indirect references to imported names.
@@ -249,7 +249,7 @@ def visit_unbound_type_nonoptional(self, t: UnboundType) -> Type:
249249
elif isinstance(node, TypeInfo):
250250
return self.analyze_unbound_type_with_type_info(t, node)
251251
else:
252-
return self.analyze_unbound_type_without_type_info(t, sym)
252+
return self.analyze_unbound_type_without_type_info(t, sym, defining_literal)
253253
else: # sym is None
254254
if self.third_pass:
255255
self.fail('Invalid type "{}"'.format(t.name), t)
@@ -368,7 +368,8 @@ def analyze_unbound_type_with_type_info(self, t: UnboundType, info: TypeInfo) ->
368368
fallback=instance)
369369
return instance
370370

371-
def analyze_unbound_type_without_type_info(self, t: UnboundType, sym: SymbolTableNode) -> Type:
371+
def analyze_unbound_type_without_type_info(self, t: UnboundType, sym: SymbolTableNode,
372+
defining_literal: bool) -> Type:
372373
"""Figure out what an unbound type that doesn't refer to a TypeInfo node means.
373374
374375
This is something unusual. We try our best to find out what it is.
@@ -377,6 +378,7 @@ def analyze_unbound_type_without_type_info(self, t: UnboundType, sym: SymbolTabl
377378
if name is None:
378379
assert sym.node is not None
379380
name = sym.node.name()
381+
380382
# Option 1:
381383
# Something with an Any type -- make it an alias for Any in a type
382384
# context. This is slightly problematic as it allows using the type 'Any'
@@ -385,14 +387,40 @@ def analyze_unbound_type_without_type_info(self, t: UnboundType, sym: SymbolTabl
385387
if isinstance(sym.node, Var) and isinstance(sym.node.type, AnyType):
386388
return AnyType(TypeOfAny.from_unimported_type,
387389
missing_import_name=sym.node.type.missing_import_name)
390+
388391
# Option 2:
389392
# Unbound type variable. Currently these may be still valid,
390393
# for example when defining a generic type alias.
391394
unbound_tvar = (isinstance(sym.node, TypeVarExpr) and
392395
(not self.tvar_scope or self.tvar_scope.get_binding(sym) is None))
393396
if self.allow_unbound_tvars and unbound_tvar and not self.third_pass:
394397
return t
398+
395399
# Option 3:
400+
# Enum value. Note: we only want to return a LiteralType when
401+
# we're using this enum value specifically within context of
402+
# a "Literal[...]" type. So, if `defining_literal` is not set,
403+
# we bail out early with an error.
404+
#
405+
# If, in the distant future, we decide to permit things like
406+
# `def foo(x: Color.RED) -> None: ...`, we can remove that
407+
# check entirely.
408+
if isinstance(sym.node, Var) and not t.args and sym.node.info and sym.node.info.is_enum:
409+
value = sym.node.name()
410+
base_enum_short_name = sym.node.info.name()
411+
if not defining_literal:
412+
msg = message_registry.INVALID_TYPE_RAW_ENUM_VALUE.format(
413+
base_enum_short_name, value)
414+
self.fail(msg, t)
415+
return AnyType(TypeOfAny.from_error)
416+
return LiteralType(
417+
value=value,
418+
fallback=Instance(sym.node.info, [], line=t.line, column=t.column),
419+
line=t.line,
420+
column=t.column,
421+
)
422+
423+
# Option 4:
396424
# If it is not something clearly bad (like a known function, variable,
397425
# type variable, or module), and it is still not too late, we try deferring
398426
# this type using a forward reference wrapper. It will be revisited in
@@ -410,6 +438,7 @@ def analyze_unbound_type_without_type_info(self, t: UnboundType, sym: SymbolTabl
410438
self.fail('Unsupported forward reference to "{}"'.format(t.name), t)
411439
return AnyType(TypeOfAny.from_error)
412440
return ForwardRef(t)
441+
413442
# None of the above options worked, we give up.
414443
self.fail('Invalid type "{}"'.format(name), t)
415444
if self.third_pass and isinstance(sym.node, TypeVarExpr):
@@ -657,7 +686,11 @@ def analyze_literal_param(self, idx: int, arg: Type, ctx: Context) -> Optional[L
657686
# If arg is an UnboundType that was *not* originally defined as
658687
# a string, try expanding it in case it's a type alias or something.
659688
if isinstance(arg, UnboundType):
660-
arg = self.anal_type(arg)
689+
self.nesting_level += 1
690+
try:
691+
arg = self.visit_unbound_type(arg, defining_literal=True)
692+
finally:
693+
self.nesting_level -= 1
661694

662695
# Literal[...] cannot contain Any. Give up and add an error message
663696
# (if we haven't already).

mypy/types.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1443,6 +1443,9 @@ class LiteralType(Type):
14431443
14441444
For example, 'Literal[42]' is represented as
14451445
'LiteralType(value=42, fallback=instance_of_int)'
1446+
1447+
As another example, `Literal[Color.RED]` (where Color is an enum) is
1448+
represented as `LiteralType(value="RED", fallback=instance_of_color)'.
14461449
"""
14471450
__slots__ = ('value', 'fallback')
14481451

@@ -1464,15 +1467,23 @@ def __eq__(self, other: object) -> bool:
14641467
else:
14651468
return NotImplemented
14661469

1470+
def is_enum_literal(self) -> bool:
1471+
return self.fallback.type.is_enum
1472+
14671473
def value_repr(self) -> str:
14681474
"""Returns the string representation of the underlying type.
14691475
14701476
This function is almost equivalent to running `repr(self.value)`,
14711477
except it includes some additional logic to correctly handle cases
1472-
where the value is a string, byte string, or a unicode string.
1478+
where the value is a string, byte string, a unicode string, or an enum.
14731479
"""
14741480
raw = repr(self.value)
14751481
fallback_name = self.fallback.type.fullname()
1482+
1483+
# If this is backed by an enum,
1484+
if self.is_enum_literal():
1485+
return '{}.{}'.format(fallback_name, self.value)
1486+
14761487
if fallback_name == 'builtins.bytes':
14771488
# Note: 'builtins.bytes' only appears in Python 3, so we want to
14781489
# explicitly prefix with a "b"

0 commit comments

Comments
 (0)