Skip to content

Commit 3b65acf

Browse files
committed
Improve zero argument support for super() in dataclasses
1 parent 10d504a commit 3b65acf

File tree

3 files changed

+94
-13
lines changed

3 files changed

+94
-13
lines changed

Lib/dataclasses.py

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,11 +1222,6 @@ def _get_slots(cls):
12221222

12231223

12241224
def _update_func_cell_for__class__(f, oldcls, newcls):
1225-
# Returns True if we update a cell, else False.
1226-
if f is None:
1227-
# f will be None in the case of a property where not all of
1228-
# fget, fset, and fdel are used. Nothing to do in that case.
1229-
return False
12301225
try:
12311226
idx = f.__code__.co_freevars.index("__class__")
12321227
except ValueError:
@@ -1235,13 +1230,36 @@ def _update_func_cell_for__class__(f, oldcls, newcls):
12351230
# Fix the cell to point to the new class, if it's already pointing
12361231
# at the old class. I'm not convinced that the "is oldcls" test
12371232
# is needed, but other than performance can't hurt.
1238-
closure = f.__closure__[idx]
1239-
if closure.cell_contents is oldcls:
1240-
closure.cell_contents = newcls
1233+
cell = f.__closure__[idx]
1234+
if cell.cell_contents is oldcls:
1235+
cell.cell_contents = newcls
12411236
return True
12421237
return False
12431238

12441239

1240+
def _find_inner_functions(obj, _seen=None, _depth=0):
1241+
if _seen is None:
1242+
_seen = set()
1243+
if id(obj) in _seen:
1244+
return None
1245+
_seen.add(id(obj))
1246+
1247+
_depth += 1
1248+
if _depth > 2:
1249+
return None
1250+
1251+
obj = inspect.unwrap(obj)
1252+
1253+
for attr in dir(obj):
1254+
value = getattr(obj, attr, None)
1255+
if value is None:
1256+
continue
1257+
if isinstance(obj, types.FunctionType):
1258+
yield obj
1259+
return
1260+
yield from _find_inner_functions(value, _seen, _depth)
1261+
1262+
12451263
def _create_slots(defined_fields, inherited_slots, field_names, weakref_slot):
12461264
# The slots for our class. Remove slots from our base classes. Add
12471265
# '__weakref__' if weakref_slot was given, unless it is already present.
@@ -1317,7 +1335,10 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields):
13171335
# (the newly created one, which we're returning) and not the
13181336
# original class. We can break out of this loop as soon as we
13191337
# make an update, since all closures for a class will share a
1320-
# given cell.
1338+
# given cell. First we try to find a pure function/properties,
1339+
# and then fallback to inspecting custom descriptors.
1340+
1341+
custom_descriptors_to_check = []
13211342
for member in newcls.__dict__.values():
13221343
# If this is a wrapped function, unwrap it.
13231344
member = inspect.unwrap(member)
@@ -1326,10 +1347,27 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields):
13261347
if _update_func_cell_for__class__(member, cls, newcls):
13271348
break
13281349
elif isinstance(member, property):
1329-
if (_update_func_cell_for__class__(member.fget, cls, newcls)
1330-
or _update_func_cell_for__class__(member.fset, cls, newcls)
1331-
or _update_func_cell_for__class__(member.fdel, cls, newcls)):
1332-
break
1350+
for f in member.fget, member.fset, member.fdel:
1351+
if f is None:
1352+
continue
1353+
# unwrap once more in case function
1354+
# was wrapped before it became property
1355+
f = inspect.unwrap(f)
1356+
if _update_func_cell_for__class__(f, cls, newcls):
1357+
break
1358+
elif hasattr(member, "__get__") and not inspect.ismemberdescriptor(
1359+
member
1360+
):
1361+
# we don't want to inspect custom descriptors just yet
1362+
# there's still a chance we'll encounter a pure function
1363+
# or a property
1364+
custom_descriptors_to_check.append(member)
1365+
else:
1366+
# now let's ensure custom descriptors won't be left out
1367+
for descriptor in custom_descriptors_to_check:
1368+
for f in _find_inner_functions(descriptor):
1369+
if _update_func_cell_for__class__(f, cls, newcls):
1370+
break
13331371

13341372
return newcls
13351373

Lib/test/test_dataclasses/__init__.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5031,6 +5031,47 @@ def foo(self):
50315031

50325032
A().foo()
50335033

5034+
def test_wrapped_property(self):
5035+
def mydecorator(f):
5036+
@wraps(f)
5037+
def wrapper(*args, **kwargs):
5038+
return f(*args, **kwargs)
5039+
return wrapper
5040+
5041+
class B:
5042+
@property
5043+
def foo(self):
5044+
return "bar"
5045+
5046+
@dataclass(slots=True)
5047+
class A(B):
5048+
@property
5049+
@mydecorator
5050+
def foo(self):
5051+
return super().foo
5052+
5053+
self.assertEqual(A().foo, "bar")
5054+
5055+
def test_custom_descriptor(self):
5056+
class CustomDescriptor:
5057+
def __init__(self, f):
5058+
self._f = f
5059+
5060+
def __get__(self, instance, owner):
5061+
return self._f(instance)
5062+
5063+
class B:
5064+
def foo(self):
5065+
return "bar"
5066+
5067+
@dataclass(slots=True)
5068+
class A(B):
5069+
@CustomDescriptor
5070+
def foo(cls):
5071+
return super().foo()
5072+
5073+
self.assertEqual(A().foo, "bar")
5074+
50345075
def test_remembered_class(self):
50355076
# Apply the dataclass decorator manually (not when the class
50365077
# is created), so that we can keep a reference to the
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Modify dataclasses to enable zero argument support for ``super()`` when ``slots=True`` is
2+
specified and custom descriptor is used or `property` function is wrapped.

0 commit comments

Comments
 (0)