Skip to content

Commit 018a49f

Browse files
Overhaul (deep)copy and pickle support for temporal and spatial types (#1002)
* Add support to pickling DateTime * Fix Duration copy and deepcopy dropping dynamic attributes (in `__dict__`) * Adding more unit tests --------- Co-authored-by: Robsdedude <[email protected]>
1 parent 3b7c790 commit 018a49f

File tree

9 files changed

+432
-88
lines changed

9 files changed

+432
-88
lines changed

src/neo4j/time/__init__.py

Lines changed: 74 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -574,12 +574,17 @@ def __str__(self) -> str:
574574
""""""
575575
return self.iso_format()
576576

577-
def __copy__(self) -> Duration:
578-
return self.__new__(self.__class__, months=self[0], days=self[1],
579-
seconds=self[2], nanoseconds=self[3])
577+
def __reduce__(self):
578+
return (
579+
type(self)._restore, (tuple(self), self.__dict__)
580+
)
580581

581-
def __deepcopy__(self, memo) -> Duration:
582-
return self.__copy__()
582+
@classmethod
583+
def _restore(cls, elements, dict_):
584+
instance = tuple.__new__(cls, elements)
585+
if dict_:
586+
instance.__dict__.update(dict_)
587+
return instance
583588

584589
@classmethod
585590
def from_iso_format(cls, s: str) -> Duration:
@@ -763,6 +768,13 @@ class Date(date_base_class, metaclass=DateType):
763768
# CONSTRUCTOR #
764769

765770
def __new__(cls, year: int, month: int, day: int) -> Date:
771+
# TODO: 6.0 - remove the __new__ magic and ZeroDate being a singleton.
772+
# It's fine to remain as constant. Instead, simply use
773+
# __init__ and simplify pickle/copy (remove __reduce__).
774+
# N.B. this is a breaking change and must be treated as
775+
# such. Also consider introducing __slots__. Potentially
776+
# apply similar treatment to other temporal types as well
777+
# as spatial types.
766778
if year == month == day == 0:
767779
return ZeroDate
768780
year, month, day = _normalize_day(year, month, day)
@@ -1218,11 +1230,17 @@ def __sub__(self, other):
12181230
except TypeError:
12191231
return NotImplemented
12201232

1221-
def __copy__(self) -> Date:
1222-
return self.__new(self.__ordinal, self.__year, self.__month, self.__day)
1233+
def __reduce__(self):
1234+
if self is ZeroDate:
1235+
return "ZeroDate"
1236+
return type(self)._restore, (self.__dict__,)
12231237

1224-
def __deepcopy__(self, *args, **kwargs) -> Date:
1225-
return self.__copy__()
1238+
@classmethod
1239+
def _restore(cls, dict_) -> Date:
1240+
instance = object.__new__(cls)
1241+
if dict_:
1242+
instance.__dict__.update(dict_)
1243+
return instance
12261244

12271245
# INSTANCE METHODS #
12281246

@@ -1396,34 +1414,53 @@ class Time(time_base_class, metaclass=TimeType):
13961414

13971415
# CONSTRUCTOR #
13981416

1399-
def __new__(
1400-
cls,
1417+
def __init__(
1418+
self,
14011419
hour: int = 0,
14021420
minute: int = 0,
14031421
second: int = 0,
14041422
nanosecond: int = 0,
14051423
tzinfo: t.Optional[_tzinfo] = None
1406-
) -> Time:
1407-
hour, minute, second, nanosecond = cls.__normalize_nanosecond(
1424+
) -> None:
1425+
hour, minute, second, nanosecond = self.__normalize_nanosecond(
14081426
hour, minute, second, nanosecond
14091427
)
14101428
ticks = (3600000000000 * hour
14111429
+ 60000000000 * minute
14121430
+ 1000000000 * second
14131431
+ nanosecond)
1414-
return cls.__new(ticks, hour, minute, second, nanosecond, tzinfo)
1432+
self.__unchecked_init(ticks, hour, minute, second, nanosecond, tzinfo)
14151433

14161434
@classmethod
1417-
def __new(cls, ticks, hour, minute, second, nanosecond, tzinfo):
1418-
instance = object.__new__(cls)
1419-
instance.__ticks = int(ticks)
1420-
instance.__hour = int(hour)
1421-
instance.__minute = int(minute)
1422-
instance.__second = int(second)
1423-
instance.__nanosecond = int(nanosecond)
1424-
instance.__tzinfo = tzinfo
1435+
def __unchecked_new(
1436+
cls,
1437+
ticks: int,
1438+
hour: int,
1439+
minutes: int,
1440+
second: int,
1441+
nano: int,
1442+
tz: t.Optional[_tzinfo]
1443+
) -> Time:
1444+
instance = object.__new__(Time)
1445+
instance.__unchecked_init(ticks, hour, minutes, second, nano, tz)
14251446
return instance
14261447

1448+
def __unchecked_init(
1449+
self,
1450+
ticks: int,
1451+
hour: int,
1452+
minutes: int,
1453+
second: int,
1454+
nano: int,
1455+
tz: t.Optional[_tzinfo]
1456+
) -> None:
1457+
self.__ticks = ticks
1458+
self.__hour = hour
1459+
self.__minute = minutes
1460+
self.__second = second
1461+
self.__nanosecond = nano
1462+
self.__tzinfo = tz
1463+
14271464
# CLASS METHODS #
14281465

14291466
@classmethod
@@ -1521,7 +1558,8 @@ def from_ticks(cls, ticks: int, tz: t.Optional[_tzinfo] = None) -> Time:
15211558
second, nanosecond = divmod(ticks, NANO_SECONDS)
15221559
minute, second = divmod(second, 60)
15231560
hour, minute = divmod(minute, 60)
1524-
return cls.__new(ticks, hour, minute, second, nanosecond, tz)
1561+
return cls.__unchecked_new(ticks, hour, minute, second, nanosecond,
1562+
tz)
15251563
raise ValueError("Ticks out of range (0..86400000000000)")
15261564

15271565
@classmethod
@@ -1619,7 +1657,7 @@ def utc_now(cls) -> Time:
16191657

16201658
__nanosecond = 0
16211659

1622-
__tzinfo = None
1660+
__tzinfo: t.Optional[_tzinfo] = None
16231661

16241662
@property
16251663
def ticks(self) -> int:
@@ -1751,13 +1789,6 @@ def __gt__(self, other: t.Union[Time, time]) -> bool:
17511789
return NotImplemented
17521790
return self_ticks > other_ticks
17531791

1754-
def __copy__(self) -> Time:
1755-
return self.__new(self.__ticks, self.__hour, self.__minute,
1756-
self.__second, self.__nanosecond, self.__tzinfo)
1757-
1758-
def __deepcopy__(self, *args, **kwargs) -> Time:
1759-
return self.__copy__()
1760-
17611792
# INSTANCE METHODS #
17621793

17631794
if t.TYPE_CHECKING:
@@ -2126,6 +2157,10 @@ def combine( # type: ignore[override]
21262157
"""
21272158
assert isinstance(date, Date)
21282159
assert isinstance(time, Time)
2160+
return cls._combine(date, time)
2161+
2162+
@classmethod
2163+
def _combine(cls, date: Date, time: Time) -> DateTime:
21292164
instance = object.__new__(cls)
21302165
instance.__date = date
21312166
instance.__time = time
@@ -2491,11 +2526,15 @@ def __sub__(self, other):
24912526
return self.__add__(-other)
24922527
return NotImplemented
24932528

2494-
def __copy__(self) -> DateTime:
2495-
return self.combine(self.__date, self.__time)
2529+
def __reduce__(self):
2530+
return type(self)._restore, (self.__dict__,)
24962531

2497-
def __deepcopy__(self, memo) -> DateTime:
2498-
return self.__copy__()
2532+
@classmethod
2533+
def _restore(cls, dict_):
2534+
instance = object.__new__(cls)
2535+
if dict_:
2536+
instance.__dict__.update(dict_)
2537+
return instance
24992538

25002539
# INSTANCE METHODS #
25012540

tests/unit/common/spatial/test_cartesian_point.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,23 @@
1616

1717
from __future__ import annotations
1818

19+
import copy
20+
import pickle
21+
1922
import pytest
2023

2124
from neo4j.spatial import CartesianPoint
2225

2326

27+
def make_reduce_points():
28+
return (
29+
CartesianPoint((1, 2)),
30+
CartesianPoint((3.2, 4.0)),
31+
CartesianPoint((3, 4, -1)),
32+
CartesianPoint((3.2, 4.0, -1.2)),
33+
)
34+
35+
2436
class TestCartesianPoint:
2537

2638
def test_alias_3d(self) -> None:
@@ -42,3 +54,30 @@ def test_alias_2d(self) -> None:
4254
assert p.y == y
4355
with pytest.raises(AttributeError):
4456
_ = p.z
57+
58+
@pytest.mark.parametrize("p", make_reduce_points())
59+
def test_copy(self, p):
60+
p.foo = [1, 2]
61+
p2 = copy.copy(p)
62+
assert p == p2
63+
assert p is not p2
64+
assert p.foo is p2.foo
65+
66+
@pytest.mark.parametrize("p", make_reduce_points())
67+
def test_deep_copy(self, p):
68+
p.foo = [1, [2]]
69+
p2 = copy.deepcopy(p)
70+
assert p == p2
71+
assert p is not p2
72+
assert p.foo == p2.foo
73+
assert p.foo is not p2.foo
74+
assert p.foo[1] is not p2.foo[1]
75+
76+
@pytest.mark.parametrize("expected", make_reduce_points())
77+
def test_pickle(self, expected):
78+
expected.foo = [1, [2]]
79+
actual = pickle.loads(pickle.dumps(expected))
80+
assert expected == actual
81+
assert expected is not actual
82+
assert expected.foo == actual.foo
83+
assert expected.foo is not actual.foo

tests/unit/common/spatial/test_point.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
from __future__ import annotations
1818

19+
import copy
20+
import pickle
1921
import typing as t
2022

2123
import pytest
@@ -26,6 +28,17 @@
2628
)
2729

2830

31+
def make_reduce_points():
32+
return (
33+
Point((42,)),
34+
Point((69.420,)),
35+
Point((1, 2)),
36+
Point((1.2, 2.3)),
37+
Point((1, 3, 3, 7)),
38+
Point((1.0, 3.0, 3.0, 7.0)),
39+
)
40+
41+
2942
class TestPoint:
3043

3144
@pytest.mark.parametrize("argument", (
@@ -59,3 +72,30 @@ def test_immutable_coordinates(self) -> None:
5972
p[1] = 2.0 # type: ignore[index]
6073
with pytest.raises(TypeError):
6174
p[2] = 2.0 # type: ignore[index]
75+
76+
@pytest.mark.parametrize("p", make_reduce_points())
77+
def test_copy(self, p):
78+
p.foo = [1, 2]
79+
p2 = copy.copy(p)
80+
assert p == p2
81+
assert p is not p2
82+
assert p.foo is p2.foo
83+
84+
@pytest.mark.parametrize("p", make_reduce_points())
85+
def test_deep_copy(self, p):
86+
p.foo = [1, [2]]
87+
p2 = copy.deepcopy(p)
88+
assert p == p2
89+
assert p is not p2
90+
assert p.foo == p2.foo
91+
assert p.foo is not p2.foo
92+
assert p.foo[1] is not p2.foo[1]
93+
94+
@pytest.mark.parametrize("expected", make_reduce_points())
95+
def test_pickle(self, expected):
96+
expected.foo = [1, [2]]
97+
actual = pickle.loads(pickle.dumps(expected))
98+
assert expected == actual
99+
assert expected is not actual
100+
assert expected.foo == actual.foo
101+
assert expected.foo is not actual.foo

tests/unit/common/spatial/test_wgs84_point.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,23 @@
1616

1717
from __future__ import annotations
1818

19+
import copy
20+
import pickle
21+
1922
import pytest
2023

2124
from neo4j.spatial import WGS84Point
2225

2326

27+
def make_reduce_points():
28+
return (
29+
WGS84Point((1, 2)),
30+
WGS84Point((3.2, 4.0)),
31+
WGS84Point((3, 4, -1)),
32+
WGS84Point((3.2, 4.0, -1.2)),
33+
)
34+
35+
2436
class TestWGS84Point:
2537

2638
def test_alias_3d(self) -> None:
@@ -60,3 +72,30 @@ def test_alias_2d(self) -> None:
6072
p.height
6173
with pytest.raises(AttributeError):
6274
p.z
75+
76+
@pytest.mark.parametrize("p", make_reduce_points())
77+
def test_copy(self, p):
78+
p.foo = [1, 2]
79+
p2 = copy.copy(p)
80+
assert p == p2
81+
assert p is not p2
82+
assert p.foo is p2.foo
83+
84+
@pytest.mark.parametrize("p", make_reduce_points())
85+
def test_deep_copy(self, p):
86+
p.foo = [1, [2]]
87+
p2 = copy.deepcopy(p)
88+
assert p == p2
89+
assert p is not p2
90+
assert p.foo == p2.foo
91+
assert p.foo is not p2.foo
92+
assert p.foo[1] is not p2.foo[1]
93+
94+
@pytest.mark.parametrize("expected", make_reduce_points())
95+
def test_pickle(self, expected):
96+
expected.foo = [1, [2]]
97+
actual = pickle.loads(pickle.dumps(expected))
98+
assert expected == actual
99+
assert expected is not actual
100+
assert expected.foo == actual.foo
101+
assert expected.foo is not actual.foo

0 commit comments

Comments
 (0)