Skip to content

Commit e4fd5d5

Browse files
[3.12] gh-115539: Allow enum.Flag to have None members (GH-115636) (GH-115694)
gh-115539: Allow enum.Flag to have None members (GH-115636) (cherry picked from commit c2cb31b) Co-authored-by: Jason Zhang <[email protected]>
1 parent eb74573 commit e4fd5d5

File tree

2 files changed

+52
-21
lines changed

2 files changed

+52
-21
lines changed

Lib/enum.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -283,9 +283,10 @@ def __set_name__(self, enum_class, member_name):
283283
enum_member._sort_order_ = len(enum_class._member_names_)
284284

285285
if Flag is not None and issubclass(enum_class, Flag):
286-
enum_class._flag_mask_ |= value
287-
if _is_single_bit(value):
288-
enum_class._singles_mask_ |= value
286+
if isinstance(value, int):
287+
enum_class._flag_mask_ |= value
288+
if _is_single_bit(value):
289+
enum_class._singles_mask_ |= value
289290
enum_class._all_bits_ = 2 ** ((enum_class._flag_mask_).bit_length()) - 1
290291

291292
# If another member with the same value was already defined, the
@@ -313,6 +314,7 @@ def __set_name__(self, enum_class, member_name):
313314
elif (
314315
Flag is not None
315316
and issubclass(enum_class, Flag)
317+
and isinstance(value, int)
316318
and _is_single_bit(value)
317319
):
318320
# no other instances found, record this member in _member_names_
@@ -1534,37 +1536,50 @@ def __str__(self):
15341536
def __bool__(self):
15351537
return bool(self._value_)
15361538

1539+
def _get_value(self, flag):
1540+
if isinstance(flag, self.__class__):
1541+
return flag._value_
1542+
elif self._member_type_ is not object and isinstance(flag, self._member_type_):
1543+
return flag
1544+
return NotImplemented
1545+
15371546
def __or__(self, other):
1538-
if isinstance(other, self.__class__):
1539-
other = other._value_
1540-
elif self._member_type_ is not object and isinstance(other, self._member_type_):
1541-
other = other
1542-
else:
1547+
other_value = self._get_value(other)
1548+
if other_value is NotImplemented:
15431549
return NotImplemented
1550+
1551+
for flag in self, other:
1552+
if self._get_value(flag) is None:
1553+
raise TypeError(f"'{flag}' cannot be combined with other flags with |")
15441554
value = self._value_
1545-
return self.__class__(value | other)
1555+
return self.__class__(value | other_value)
15461556

15471557
def __and__(self, other):
1548-
if isinstance(other, self.__class__):
1549-
other = other._value_
1550-
elif self._member_type_ is not object and isinstance(other, self._member_type_):
1551-
other = other
1552-
else:
1558+
other_value = self._get_value(other)
1559+
if other_value is NotImplemented:
15531560
return NotImplemented
1561+
1562+
for flag in self, other:
1563+
if self._get_value(flag) is None:
1564+
raise TypeError(f"'{flag}' cannot be combined with other flags with &")
15541565
value = self._value_
1555-
return self.__class__(value & other)
1566+
return self.__class__(value & other_value)
15561567

15571568
def __xor__(self, other):
1558-
if isinstance(other, self.__class__):
1559-
other = other._value_
1560-
elif self._member_type_ is not object and isinstance(other, self._member_type_):
1561-
other = other
1562-
else:
1569+
other_value = self._get_value(other)
1570+
if other_value is NotImplemented:
15631571
return NotImplemented
1572+
1573+
for flag in self, other:
1574+
if self._get_value(flag) is None:
1575+
raise TypeError(f"'{flag}' cannot be combined with other flags with ^")
15641576
value = self._value_
1565-
return self.__class__(value ^ other)
1577+
return self.__class__(value ^ other_value)
15661578

15671579
def __invert__(self):
1580+
if self._get_value(self) is None:
1581+
raise TypeError(f"'{self}' cannot be inverted")
1582+
15681583
if self._inverted_ is None:
15691584
if self._boundary_ in (EJECT, KEEP):
15701585
self._inverted_ = self.__class__(~self._value_)

Lib/test/test_enum.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,6 +1007,22 @@ class TestPlainEnumFunction(_EnumTests, _PlainOutputTests, unittest.TestCase):
10071007
class TestPlainFlagClass(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase):
10081008
enum_type = Flag
10091009

1010+
def test_none_member(self):
1011+
class FlagWithNoneMember(Flag):
1012+
A = 1
1013+
E = None
1014+
1015+
self.assertEqual(FlagWithNoneMember.A.value, 1)
1016+
self.assertIs(FlagWithNoneMember.E.value, None)
1017+
with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be combined with other flags with |"):
1018+
FlagWithNoneMember.A | FlagWithNoneMember.E
1019+
with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be combined with other flags with &"):
1020+
FlagWithNoneMember.E & FlagWithNoneMember.A
1021+
with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be combined with other flags with \^"):
1022+
FlagWithNoneMember.A ^ FlagWithNoneMember.E
1023+
with self.assertRaisesRegex(TypeError, r"'FlagWithNoneMember.E' cannot be inverted"):
1024+
~FlagWithNoneMember.E
1025+
10101026

10111027
class TestPlainFlagFunction(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase):
10121028
enum_type = Flag

0 commit comments

Comments
 (0)