Skip to content

Commit 911cd3c

Browse files
committed
feat: support storing tuples in state
1 parent 2b20087 commit 911cd3c

File tree

66 files changed

+12841
-806
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+12841
-806
lines changed

src/_algopy_testing/arc4.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -742,12 +742,18 @@ def __repr__(self) -> str:
742742

743743

744744
class _DynamicArrayTypeInfo(_TypeInfo):
745-
def __init__(self, item_type: _TypeInfo):
745+
_subclass_type: Callable[[], type] | None
746+
747+
def __init__(self, item_type: _TypeInfo, subclass_type: Callable[[], type] | None = None):
748+
self._subclass_type = subclass_type
746749
self.item_type = item_type
747750

748751
@property
749752
def typ(self) -> type:
750-
return _parameterize_type(DynamicArray, self.item_type.typ)
753+
if self._subclass_type is not None:
754+
return self._subclass_type()
755+
else:
756+
return _parameterize_type(DynamicArray, self.item_type.typ)
751757

752758
@property
753759
def arc4_name(self) -> str:
@@ -891,6 +897,10 @@ def __repr__(self) -> str:
891897
class DynamicBytes(DynamicArray[Byte]):
892898
"""A variable sized array of bytes."""
893899

900+
_type_info: _DynamicArrayTypeInfo = _DynamicArrayTypeInfo(
901+
Byte._type_info, lambda: DynamicBytes
902+
)
903+
894904
@typing.overload
895905
def __init__(self, *values: Byte | UInt8 | int): ...
896906

@@ -996,6 +1006,12 @@ def __init__(self, _items: tuple[typing.Unpack[_TTuple]] = (), /): # type: igno
9961006
)
9971007
self._value = _encode(items)
9981008

1009+
def __bool__(self) -> bool:
1010+
try:
1011+
return bool(self.native)
1012+
except ValueError:
1013+
return False
1014+
9991015
def __len__(self) -> int:
10001016
return len(self.native)
10011017

@@ -1103,6 +1119,8 @@ def _update_backing_value(self) -> None:
11031119
def from_bytes(cls, value: algopy.Bytes | bytes, /) -> typing.Self:
11041120
tuple_type = _tuple_type_from_struct(cls)
11051121
tuple_value = tuple_type.from_bytes(value)
1122+
if not tuple_value:
1123+
return typing.cast(typing.Self, tuple_value)
11061124
return cls(*tuple_value.native)
11071125

11081126
@property

src/_algopy_testing/models/contract.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import functools
4+
import inspect
45
import typing
56
from dataclasses import dataclass
67

@@ -201,12 +202,12 @@ def _get_state_totals(contract: Contract, cls_state_totals: StateTotals) -> _Sta
201202

202203
global_bytes = global_uints = local_bytes = local_uints = 0
203204
for type_ in get_global_states(contract).values():
204-
if issubclass(type_, UInt64 | UInt64Backed | bool):
205+
if inspect.isclass(type_) and issubclass(type_, UInt64 | UInt64Backed | bool):
205206
global_uints += 1
206207
else:
207208
global_bytes += 1
208209
for type_ in get_local_states(contract).values():
209-
if issubclass(type_, UInt64 | UInt64Backed | bool):
210+
if inspect.isclass(type_) and issubclass(type_, UInt64 | UInt64Backed | bool):
210211
local_uints += 1
211212
else:
212213
local_bytes += 1

src/_algopy_testing/serialize.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,16 @@ def native_to_arc4(value: object) -> "_ABIEncoded":
136136
return arc4_value
137137

138138

139+
def compare_type(value_type: type, typ: type) -> bool:
140+
if typing.NamedTuple in getattr(typ, "__orig_bases__", []):
141+
tuple_fields: Sequence[type] = list(inspect.get_annotations(typ).values())
142+
typ = tuple[*tuple_fields] # type: ignore[valid-type]
143+
return value_type == typ
144+
145+
139146
def deserialize_from_bytes(typ: type[_T], bites: bytes) -> _T:
140147
serializer = get_native_to_arc4_serializer(typ)
141148
arc4_value = serializer.arc4_type.from_bytes(bites)
142149
native_value = serializer.arc4_to_native(arc4_value)
143-
assert isinstance(native_value, typ)
144-
return native_value
150+
assert compare_type(type_of(native_value), typ) or isinstance(native_value, typ)
151+
return native_value # type: ignore[no-any-return]

src/_algopy_testing/state/box.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,9 +322,9 @@ def get(self, key: _TKey, *, default: _TValue) -> _TValue:
322322
def maybe(self, key: _TKey) -> tuple[_TValue, bool]:
323323
key_bytes = self._full_key(key)
324324
box_exists = lazy_context.ledger.box_exists(self.app_id, key_bytes)
325-
if not box_exists:
326-
return self._value_type(), False
327-
box_content_bytes = lazy_context.ledger.get_box(self.app_id, key_bytes)
325+
box_content_bytes = (
326+
b"" if not box_exists else lazy_context.ledger.get_box(self.app_id, key_bytes)
327+
)
328328
box_content = cast_from_bytes(self._value_type, box_content_bytes)
329329
return box_content, box_exists
330330

src/_algopy_testing/state/global_state.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from _algopy_testing.context_helpers import lazy_context
77
from _algopy_testing.mutable import set_attr_on_mutate
88
from _algopy_testing.primitives import Bytes, String
9+
from _algopy_testing.serialize import type_of
910
from _algopy_testing.state.utils import deserialize, serialize
1011

1112
if typing.TYPE_CHECKING:
@@ -49,10 +50,10 @@ def __init__(
4950
self._key: Bytes | None = None
5051
self._pending_value: _T | None = None
5152

52-
if isinstance(type_or_value, type):
53-
self.type_: type[_T] = type_or_value
53+
if isinstance(type_or_value, type) or isinstance(typing.get_origin(type_or_value), type):
54+
self.type_: type[_T] = typing.cast(type[_T], type_or_value)
5455
else:
55-
self.type_ = type(type_or_value)
56+
self.type_ = type_of(type_or_value)
5657
self._pending_value = type_or_value
5758

5859
self.set_key(key)
@@ -123,9 +124,7 @@ def get(self, default: _T | None = None) -> _T:
123124
try:
124125
return self.value
125126
except ValueError:
126-
if default is not None:
127-
return default
128-
return self.type_()
127+
return typing.cast(_T, default)
129128

130129
def maybe(self) -> tuple[_T | None, bool]:
131130
try:

src/_algopy_testing/state/local_state.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,14 @@ def get(self, key: algopy.Account | algopy.UInt64 | int, default: _T | None = No
7575
try:
7676
return self[account]
7777
except KeyError:
78-
return default if default is not None else self.type_()
78+
return typing.cast(_T, default)
7979

8080
def maybe(self, key: algopy.Account | algopy.UInt64 | int) -> tuple[_T, bool]:
8181
account = _get_account(key)
8282
try:
8383
return self[account], True
8484
except KeyError:
85-
return self.type_(), False
85+
return typing.cast(_T, None), False
8686

8787

8888
# TODO: make a util function along with one used by ops

src/_algopy_testing/state/utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
from __future__ import annotations
22

3+
import inspect
34
import typing
45

56
from _algopy_testing.primitives.bytes import Bytes
67
from _algopy_testing.primitives.uint64 import UInt64
78
from _algopy_testing.protocols import BytesBacked, Serializable, UInt64Backed
9+
from _algopy_testing.serialize import (
10+
deserialize_from_bytes,
11+
serialize_to_bytes,
12+
)
813

914
_TValue = typing.TypeVar("_TValue")
1015
SerializableValue = int | bytes
@@ -21,12 +26,16 @@ def serialize(value: _TValue) -> SerializableValue:
2126
return value.bytes.value
2227
elif isinstance(value, Serializable):
2328
return value.serialize()
29+
elif isinstance(value, tuple):
30+
return serialize_to_bytes(value)
2431
else:
2532
raise TypeError(f"Unsupported type: {type(value)}")
2633

2734

2835
def deserialize(typ: type[_TValue], value: SerializableValue) -> _TValue:
29-
if issubclass(typ, bool):
36+
if (typing.get_origin(typ) is tuple or issubclass(typ, tuple)) and isinstance(value, bytes):
37+
return () if not value else deserialize_from_bytes(typ, value) # type: ignore[return-value]
38+
elif issubclass(typ, bool):
3039
return value != 0 # type: ignore[return-value]
3140
elif issubclass(typ, UInt64 | Bytes):
3241
return typ(value) # type: ignore[arg-type, return-value]
@@ -55,7 +64,7 @@ def cast_from_bytes(typ: type[_TValue], value: bytes) -> _TValue:
5564
"""
5665
from _algopy_testing.utils import as_int64
5766

58-
if issubclass(typ, bool | UInt64Backed | UInt64):
67+
if inspect.isclass(typ) and issubclass(typ, bool | UInt64Backed | UInt64):
5968
if len(value) > 8:
6069
raise ValueError("uint64 value too big")
6170
serialized: SerializableValue = int.from_bytes(value)

0 commit comments

Comments
 (0)