Skip to content

Commit cf860f3

Browse files
authored
feat(hugr-py): Add StaticArray to standard extensions (#1985)
Closes #1984
1 parent d91dc8a commit cf860f3

File tree

2 files changed

+82
-0
lines changed

2 files changed

+82
-0
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""Statically sized immutable array type and its operations."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass
6+
7+
from hugr import tys, val
8+
from hugr.std import _load_extension
9+
from hugr.utils import comma_sep_str
10+
11+
EXTENSION = _load_extension("collections.static_array")
12+
13+
14+
@dataclass(eq=False)
15+
class StaticArray(tys.ExtType):
16+
"""Fixed size immutable array of `ty` elements."""
17+
18+
def __init__(self, ty: tys.Type) -> None:
19+
self.type_def = EXTENSION.types["static_array"]
20+
if (
21+
tys.TypeBound.join(ty.type_bound(), tys.TypeBound.Copyable)
22+
!= tys.TypeBound.Copyable
23+
):
24+
msg = "Static array elements must be copyable"
25+
raise ValueError(msg)
26+
self.args = [tys.TypeTypeArg(ty)]
27+
28+
@property
29+
def ty(self) -> tys.Type:
30+
assert isinstance(
31+
self.args[0], tys.TypeTypeArg
32+
), "Array elements must have a valid type"
33+
return self.args[0].ty
34+
35+
def type_bound(self) -> tys.TypeBound:
36+
return self.ty.type_bound()
37+
38+
39+
@dataclass
40+
class StaticArrayVal(val.ExtensionValue):
41+
"""Constant value for a statically sized immutable array of elements."""
42+
43+
v: list[val.Value]
44+
ty: tys.Type
45+
name: str
46+
47+
def __init__(self, v: list[val.Value], elem_ty: tys.Type, name: str) -> None:
48+
self.v = v
49+
self.ty = StaticArray(elem_ty)
50+
self.name = name
51+
52+
def to_value(self) -> val.Extension:
53+
# The value list must be serialized at this point, otherwise the
54+
# `Extension` value would not be serializable.
55+
vs = [v._to_serial_root() for v in self.v]
56+
serial_val = {"values": vs, "name": self.name}
57+
return val.Extension(
58+
"StaticArrayValue", typ=self.ty, val=serial_val, extensions=[EXTENSION.name]
59+
)
60+
61+
def __str__(self) -> str:
62+
return f"static_array({comma_sep_str(self.v)})"

hugr-py/tests/test_tys.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from hugr import val
66
from hugr.std.collections.array import Array, ArrayVal
77
from hugr.std.collections.list import List, ListVal
8+
from hugr.std.collections.static_array import StaticArray, StaticArrayVal
89
from hugr.std.float import FLOAT_T
910
from hugr.std.int import INT_T, _int_tv
1011
from hugr.tys import (
@@ -123,6 +124,7 @@ def test_args_str(arg: TypeArg, string: str):
123124
("ty", "string"),
124125
[
125126
(Array(Bool, 3), "array<3, Type(Bool)>"),
127+
(StaticArray(Bool), "static_array<Type(Bool)>"),
126128
(Variable(2, TypeBound.Any), "$2"),
127129
(RowVariable(4, TypeBound.Copyable), "$4"),
128130
(USize(), "USize"),
@@ -174,3 +176,21 @@ def test_array():
174176
ar_val = ArrayVal([val.TRUE, val.FALSE], Bool)
175177
assert ar_val.v == [val.TRUE, val.FALSE]
176178
assert ar_val.ty == Array(Bool, 2)
179+
180+
181+
def test_static_array():
182+
ty_var = Variable(0, TypeBound.Copyable)
183+
184+
ls = StaticArray(Bool)
185+
assert ls.ty == Bool
186+
187+
ls = StaticArray(ty_var)
188+
assert ls.ty == ty_var
189+
190+
name = "array_name"
191+
ar_val = StaticArrayVal([val.TRUE, val.FALSE], Bool, name)
192+
assert ar_val.v == [val.TRUE, val.FALSE]
193+
assert ar_val.ty == StaticArray(Bool)
194+
195+
with pytest.raises(ValueError, match="Static array elements must be copyable"):
196+
StaticArray(Qubit)

0 commit comments

Comments
 (0)