Skip to content

Commit eb24868

Browse files
authored
🐛 fix(arrayish): fix length 0 arrays (#121)
Signed-off-by: Nathaniel Starkman <[email protected]>
1 parent f70fa86 commit eb24868

File tree

2 files changed

+23
-5
lines changed

2 files changed

+23
-5
lines changed

src/quaxed/experimental/_arrayish/container.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,14 @@ class LaxLenMixin:
4545
>>> len(x)
4646
3
4747
48+
>>> x = MyArray(jnp.array(1))
49+
>>> len(x)
50+
0
51+
4852
""" # noqa: E501
4953

5054
def __len__(self: HasShape) -> int:
51-
return self.shape[0]
55+
return self.shape[0] if self.shape else 0
5256

5357

5458
class NumpyLenMixin:
@@ -70,10 +74,15 @@ class NumpyLenMixin:
7074
>>> len(x)
7175
3
7276
77+
>>> x = MyArray(jnp.array(1))
78+
>>> len(x)
79+
0
80+
7381
""" # noqa: E501
7482

7583
def __len__(self) -> int:
76-
return qnp.shape(self)[0]
84+
shape = qnp.shape(self)
85+
return shape[0] if shape else 0
7786

7887

7988
# -----------------------------------------------
@@ -99,10 +108,14 @@ class LaxLengthHintMixin:
99108
>>> x.__length_hint__()
100109
3
101110
111+
>>> x = MyArray(jnp.array(0))
112+
>>> x.__length_hint__()
113+
0
114+
102115
""" # noqa: E501
103116

104117
def __length_hint__(self: HasShape) -> int:
105-
return self.shape[0]
118+
return self.shape[0] if self.shape else 0
106119

107120

108121
class NumpyLengthHintMixin:
@@ -124,7 +137,12 @@ class NumpyLengthHintMixin:
124137
>>> x.__length_hint__()
125138
3
126139
140+
>>> x = MyArray(jnp.array(1))
141+
>>> x.__length_hint__()
142+
0
143+
127144
""" # noqa: E501
128145

129146
def __length_hint__(self) -> int:
130-
return qnp.shape(self)[0]
147+
shape = qnp.shape(self)
148+
return shape[0] if shape else 0

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)