Skip to content

Commit 583d21d

Browse files
authored
fix(py): allow conditional cases to be defined out of order (#1599)
Closes #1596 Was tempted to just change the type of `cases` since who would be using it anyway but I've tried to be good and deprecate instead
1 parent e04fcc5 commit 583d21d

File tree

2 files changed

+50
-15
lines changed

2 files changed

+50
-15
lines changed

hugr-py/src/hugr/build/cond_loop.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from __future__ import annotations
66

77
from contextlib import AbstractContextManager
8-
from dataclasses import dataclass
8+
from dataclasses import dataclass, field
99
from typing import TYPE_CHECKING
1010

1111
from typing_extensions import Self
@@ -19,6 +19,7 @@
1919
if TYPE_CHECKING:
2020
from hugr.hugr.node_port import Node, ToNode, Wire
2121
from hugr.tys import TypeRow
22+
import warnings
2223

2324

2425
class Case(DfBase[ops.Case]):
@@ -104,8 +105,8 @@ class Conditional(ParentBuilder[ops.Conditional], AbstractContextManager):
104105
Conditional(sum_ty=Bool, other_inputs=[Qubit])
105106
"""
106107

107-
#: map from case index to node holding the :class:`Case <hugr.ops.Case>`
108-
cases: dict[int, Node | None]
108+
#: builders for each case and whether they have been built by the user yet
109+
_case_builders: list[tuple[Case, bool]] = field(default_factory=list)
109110

110111
def __init__(self, sum_ty: Sum, other_inputs: TypeRow) -> None:
111112
root_op = ops.Conditional(sum_ty, other_inputs)
@@ -115,13 +116,40 @@ def __init__(self, sum_ty: Sum, other_inputs: TypeRow) -> None:
115116
def _init_impl(self: Conditional, hugr: Hugr, root: Node, n_cases: int) -> None:
116117
self.hugr = hugr
117118
self.parent_node = root
118-
self.cases = {i: None for i in range(n_cases)}
119+
self._case_builders = []
120+
121+
for case_id in range(n_cases):
122+
new_case = Case.new_nested(
123+
ops.Case(self.parent_op.nth_inputs(case_id)),
124+
self.hugr,
125+
self.parent_node,
126+
)
127+
new_case._parent_cond = self
128+
self._case_builders.append((new_case, False))
129+
130+
@property
131+
def cases(self) -> dict[int, Node | None]:
132+
"""Map from case index to node holding the :class:`Case <hugr.ops.Case>`.
133+
134+
DEPRECATED
135+
"""
136+
# TODO remove in 0.10
137+
warnings.warn(
138+
"The 'cases' property is deprecated and"
139+
" will be removed in a future version.",
140+
DeprecationWarning,
141+
stacklevel=2,
142+
)
143+
return {
144+
i: case.parent_node if b else None
145+
for i, (case, b) in enumerate(self._case_builders)
146+
}
119147

120148
def __enter__(self) -> Self:
121149
return self
122150

123151
def __exit__(self, *args) -> None:
124-
if any(c is None for c in self.cases.values()):
152+
if not all(built for _, built in self._case_builders):
125153
msg = "All cases must be added before exiting context."
126154
raise ConditionalError(msg)
127155
return None
@@ -185,18 +213,15 @@ def add_case(self, case_id: int) -> Case:
185213
>>> with cond.add_case(0) as case:\
186214
case.set_outputs(*case.inputs())
187215
"""
188-
if case_id not in self.cases:
216+
if case_id >= len(self._case_builders):
189217
msg = f"Case {case_id} out of possible range."
190218
raise ConditionalError(msg)
191-
input_types = self.parent_op.nth_inputs(case_id)
192-
new_case = Case.new_nested(
193-
ops.Case(input_types),
194-
self.hugr,
195-
self.parent_node,
196-
)
197-
new_case._parent_cond = self
198-
self.cases[case_id] = new_case.parent_node
199-
return new_case
219+
case, built = self._case_builders[case_id]
220+
if built:
221+
msg = f"Case {case_id} already built."
222+
raise ConditionalError(msg)
223+
self._case_builders[case_id] = (case, True)
224+
return case
200225

201226
# TODO insert_case
202227

hugr-py/tests/test_cond_loop.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,13 @@ def test_complex_tail_loop() -> None:
134134
h.set_outputs(*tl[:3])
135135

136136
validate(h.hugr)
137+
138+
139+
def test_conditional_bug() -> None:
140+
# bug with case ordering https://github.com/CQCL/hugr/issues/1596
141+
cond = Conditional(tys.Either([tys.USize()], [tys.Unit]), [])
142+
with cond.add_case(1) as case:
143+
case.set_outputs()
144+
with cond.add_case(0) as case:
145+
case.set_outputs()
146+
validate(cond.hugr)

0 commit comments

Comments
 (0)