5
5
from __future__ import annotations
6
6
7
7
from contextlib import AbstractContextManager
8
- from dataclasses import dataclass
8
+ from dataclasses import dataclass , field
9
9
from typing import TYPE_CHECKING
10
10
11
11
from typing_extensions import Self
19
19
if TYPE_CHECKING :
20
20
from hugr .hugr .node_port import Node , ToNode , Wire
21
21
from hugr .tys import TypeRow
22
+ import warnings
22
23
23
24
24
25
class Case (DfBase [ops .Case ]):
@@ -104,8 +105,8 @@ class Conditional(ParentBuilder[ops.Conditional], AbstractContextManager):
104
105
Conditional(sum_ty=Bool, other_inputs=[Qubit])
105
106
"""
106
107
107
- #: map from case index to node holding the :class:`Case <hugr.ops.Case>`
108
- cases : dict [ int , Node | None ]
108
+ #: builders for each case and whether they have been built by the user yet
109
+ _case_builders : list [ tuple [ Case , bool ]] = field ( default_factory = list )
109
110
110
111
def __init__ (self , sum_ty : Sum , other_inputs : TypeRow ) -> None :
111
112
root_op = ops .Conditional (sum_ty , other_inputs )
@@ -115,13 +116,40 @@ def __init__(self, sum_ty: Sum, other_inputs: TypeRow) -> None:
115
116
def _init_impl (self : Conditional , hugr : Hugr , root : Node , n_cases : int ) -> None :
116
117
self .hugr = hugr
117
118
self .parent_node = root
118
- self .cases = {i : None for i in range (n_cases )}
119
+ self ._case_builders = []
120
+
121
+ for case_id in range (n_cases ):
122
+ new_case = Case .new_nested (
123
+ ops .Case (self .parent_op .nth_inputs (case_id )),
124
+ self .hugr ,
125
+ self .parent_node ,
126
+ )
127
+ new_case ._parent_cond = self
128
+ self ._case_builders .append ((new_case , False ))
129
+
130
+ @property
131
+ def cases (self ) -> dict [int , Node | None ]:
132
+ """Map from case index to node holding the :class:`Case <hugr.ops.Case>`.
133
+
134
+ DEPRECATED
135
+ """
136
+ # TODO remove in 0.10
137
+ warnings .warn (
138
+ "The 'cases' property is deprecated and"
139
+ " will be removed in a future version." ,
140
+ DeprecationWarning ,
141
+ stacklevel = 2 ,
142
+ )
143
+ return {
144
+ i : case .parent_node if b else None
145
+ for i , (case , b ) in enumerate (self ._case_builders )
146
+ }
119
147
120
148
def __enter__ (self ) -> Self :
121
149
return self
122
150
123
151
def __exit__ (self , * args ) -> None :
124
- if any ( c is None for c in self .cases . values () ):
152
+ if not all ( built for _ , built in self ._case_builders ):
125
153
msg = "All cases must be added before exiting context."
126
154
raise ConditionalError (msg )
127
155
return None
@@ -185,18 +213,15 @@ def add_case(self, case_id: int) -> Case:
185
213
>>> with cond.add_case(0) as case:\
186
214
case.set_outputs(*case.inputs())
187
215
"""
188
- if case_id not in self .cases :
216
+ if case_id >= len ( self ._case_builders ) :
189
217
msg = f"Case { case_id } out of possible range."
190
218
raise ConditionalError (msg )
191
- input_types = self .parent_op .nth_inputs (case_id )
192
- new_case = Case .new_nested (
193
- ops .Case (input_types ),
194
- self .hugr ,
195
- self .parent_node ,
196
- )
197
- new_case ._parent_cond = self
198
- self .cases [case_id ] = new_case .parent_node
199
- return new_case
219
+ case , built = self ._case_builders [case_id ]
220
+ if built :
221
+ msg = f"Case { case_id } already built."
222
+ raise ConditionalError (msg )
223
+ self ._case_builders [case_id ] = (case , True )
224
+ return case
200
225
201
226
# TODO insert_case
202
227
0 commit comments