Skip to content

Commit ef8ea5e

Browse files
authored
feat!: Entrypoints in hugr-py (#2148)
Adds entrypoint definitions to `hugr-py`. Has similar caveats to #2147. The main issue here was builders that initialized a hugr for a dataflow op without knowing its output types. The automatic machinery that wraps the op in a function definition needs to connect the outputs to the new function's output once we know the types, so I had to add some extra machinery to `DataflowOp` to signal back to the hugr that it should connect things. This feels a bit hacky, perhaps there's a cleaner way to do it... Blocked by #2147. BREAKING CHANGE: Hugrs now define an `entrypoint` in addition to a module root.
1 parent 125b341 commit ef8ea5e

File tree

18 files changed

+1777
-439
lines changed

18 files changed

+1777
-439
lines changed

hugr-py/src/hugr/_serialization/serial_hugr.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class SerialHugr(ConfiguredBaseModel):
3535
encoder: str | None = Field(
3636
default=None, description="The name of the encoder used to generate the Hugr."
3737
)
38+
entrypoint: NodeIdx | None = None
3839

3940
def to_json(self) -> str:
4041
"""Return a JSON representation of the Hugr."""

hugr-py/src/hugr/build/cfg.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@ class Block(DfBase[ops.DataflowBlock]):
2323
"""Builder class for a basic block in a HUGR control flow graph."""
2424

2525
def set_outputs(self, *outputs: Wire) -> None:
26-
super().set_outputs(*outputs)
27-
2826
assert len(outputs) > 0
2927
branching = outputs[0]
3028
branch_type = self.hugr.port_type(branching.out_port())
3129
assert isinstance(branch_type, tys.Sum)
3230
self._set_parent_output_count(len(branch_type.variant_rows))
3331

32+
super().set_outputs(*outputs)
33+
3434
def set_block_outputs(self, branching: Wire, *other_outputs: Wire) -> None:
3535
self.set_outputs(branching, *other_outputs)
3636

@@ -50,7 +50,7 @@ def _wire_up_port(self, node: Node, offset: PortOffset, p: Wire) -> Type:
5050
# it does not check for valid dominance between basic blocks
5151
# that is deferred to full HUGR validation.
5252
while cfg_node != src_parent:
53-
if src_parent is None or src_parent == self.hugr.root:
53+
if src_parent is None or src_parent == self.hugr.module_root:
5454
raise NotInSameCfg(src.node.idx, node.idx) from e
5555
src_parent = self.hugr[src_parent].parent
5656

@@ -60,7 +60,7 @@ def _wire_up_port(self, node: Node, offset: PortOffset, p: Wire) -> Type:
6060

6161
@dataclass
6262
class Cfg(ParentBuilder[ops.CFG], AbstractContextManager):
63-
"""Builder class for a HUGR control flow graph, with the HUGR root node
63+
"""Builder class for a HUGR control flow graph, with the HUGR entrypoint node
6464
being a :class:`CFG <hugr.ops.CFG>`.
6565
6666
Args:
@@ -85,13 +85,17 @@ def __init__(self, *input_types: Type) -> None:
8585
input_typs = list(input_types)
8686
root_op = ops.CFG(inputs=input_typs)
8787
hugr = Hugr(root_op)
88-
self._init_impl(hugr, hugr.root, input_typs)
88+
self._init_impl(hugr, hugr.entrypoint, input_typs)
8989

90-
def _init_impl(self: Cfg, hugr: Hugr, root: Node, input_types: TypeRow) -> None:
90+
def _init_impl(
91+
self: Cfg, hugr: Hugr, entrypoint: Node, input_types: TypeRow
92+
) -> None:
9193
self.hugr = hugr
92-
self.parent_node = root
94+
self.parent_node = entrypoint
9395
# to ensure entry is first child, add a dummy entry at the start
94-
self._entry_block = Block.new_nested(ops.DataflowBlock(input_types), hugr, root)
96+
self._entry_block = Block.new_nested(
97+
ops.DataflowBlock(input_types), hugr, entrypoint
98+
)
9599

96100
self.exit = self.hugr.add_node(ops.ExitBlock(), self.parent_node)
97101

@@ -128,7 +132,7 @@ def new_nested(
128132
new = cls.__new__(cls)
129133
root = hugr.add_node(
130134
ops.CFG(inputs=input_types),
131-
parent or hugr.root,
135+
parent or hugr.entrypoint,
132136
)
133137
new._init_impl(hugr, root, input_types)
134138
return new
@@ -261,3 +265,8 @@ def branch_exit(self, src: Wire) -> None:
261265
self.parent_node = self.hugr._update_node_outs(
262266
self.parent_node, len(out_types)
263267
)
268+
if (
269+
self.parent_op._entrypoint_requires_wiring
270+
and self.hugr.entrypoint == self.parent_node
271+
):
272+
self.hugr._connect_df_entrypoint_outputs()

hugr-py/src/hugr/build/cond_loop.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,13 @@ class Conditional(ParentBuilder[ops.Conditional], AbstractContextManager):
112112
def __init__(self, sum_ty: Sum, other_inputs: TypeRow) -> None:
113113
root_op = ops.Conditional(sum_ty, other_inputs)
114114
hugr = Hugr(root_op)
115-
self._init_impl(hugr, hugr.root, len(sum_ty.variant_rows))
115+
self._init_impl(hugr, hugr.entrypoint, len(sum_ty.variant_rows))
116116

117-
def _init_impl(self: Conditional, hugr: Hugr, root: Node, n_cases: int) -> None:
117+
def _init_impl(
118+
self: Conditional, hugr: Hugr, entrypoint: Node, n_cases: int
119+
) -> None:
118120
self.hugr = hugr
119-
self.parent_node = root
121+
self.parent_node = entrypoint
120122
self._case_builders = []
121123

122124
for case_id in range(n_cases):
@@ -163,18 +165,18 @@ def new_nested(
163165
other_inputs: The inputs for the conditional that aren't included in the
164166
sum variants. These are passed to all cases.
165167
hugr: The HUGR instance this Conditional is part of.
166-
parent: The parent node for the Conditional: defaults to the root of
168+
parent: The parent node for the Conditional: defaults to the entrypoint of
167169
the HUGR instance.
168170
169171
Returns:
170172
The new Conditional builder.
171173
"""
172174
new = cls.__new__(cls)
173-
root = hugr.add_node(
175+
entrypoint = hugr.add_node(
174176
ops.Conditional(sum_ty, other_inputs),
175-
parent or hugr.root,
177+
parent or hugr.entrypoint,
176178
)
177-
new._init_impl(hugr, root, len(sum_ty.variant_rows))
179+
new._init_impl(hugr, entrypoint, len(sum_ty.variant_rows))
178180
return new
179181

180182
def _update_outputs(self, outputs: TypeRow) -> None:
@@ -183,6 +185,11 @@ def _update_outputs(self, outputs: TypeRow) -> None:
183185
self.parent_node = self.hugr._update_node_outs(
184186
self.parent_node, len(outputs)
185187
)
188+
if (
189+
self.parent_op._entrypoint_requires_wiring
190+
and self.hugr.entrypoint == self.parent_node
191+
):
192+
self.hugr._connect_df_entrypoint_outputs()
186193
else:
187194
if outputs != self.parent_op._outputs:
188195
msg = "Mismatched case outputs."
@@ -239,15 +246,15 @@ def __init__(self, just_inputs: TypeRow, rest: TypeRow) -> None:
239246
super().__init__(root_op)
240247

241248
def set_outputs(self, *outputs: Wire) -> None:
242-
super().set_outputs(*outputs)
243-
244249
assert len(outputs) > 0
245250
sum_wire = outputs[0]
246251
sum_type = self.hugr.port_type(sum_wire.out_port())
247252
assert isinstance(sum_type, Sum)
248253
assert len(sum_type.variant_rows) == 2
249254
self._set_parent_output_count(len(sum_type.variant_rows[1]) + len(outputs) - 1)
250255

256+
super().set_outputs(*outputs)
257+
251258
def set_loop_outputs(self, sum_wire: Wire, *rest: Wire) -> None:
252259
"""Set the outputs of the loop body. The first wire must be the sum type
253260
that controls loop termination.

hugr-py/src/hugr/build/dfg.py

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@
3030
OpVar = TypeVar("OpVar", bound=ops.Op)
3131

3232

33+
class DataflowError(Exception):
34+
"""Error building a :class:`DfBase` dataflow graph."""
35+
36+
3337
@dataclass()
3438
class DefinitionBuilder(Generic[OpVar]):
3539
"""Base class for builders that can define functions, constants, and aliases.
@@ -55,12 +59,12 @@ def define_function(
5559
output_types: The output types for the function.
5660
If not provided, it will be inferred after the function is built.
5761
type_params: The type parameters for the function, if polymorphic.
58-
parent: The parent node of the constant. Defaults to the root node.
62+
parent: The parent node of the constant. Defaults to the entrypoint node.
5963
6064
Returns:
6165
The new function builder.
6266
"""
63-
parent_node = parent or self.hugr.root
67+
parent_node = parent or self.hugr.entrypoint
6468
parent_op = ops.FuncDefn(name, input_types, type_params or [])
6569
func = Function.new_nested(parent_op, self.hugr, parent_node)
6670
if output_types is not None:
@@ -72,7 +76,7 @@ def add_const(self, value: val.Value, parent: ToNode | None = None) -> Node:
7276
7377
Args:
7478
value: The constant value to add.
75-
parent: The parent node of the constant. Defaults to the root node.
79+
parent: The parent node of the constant. Defaults to the entrypoint node.
7680
7781
Returns:
7882
The node holding the :class:`Const <hugr.ops.Const>` operation.
@@ -83,12 +87,12 @@ def add_const(self, value: val.Value, parent: ToNode | None = None) -> Node:
8387
>>> dfg.hugr[const_n].op
8488
Const(TRUE)
8589
"""
86-
parent_node = parent or self.hugr.root
90+
parent_node = parent or self.hugr.entrypoint
8791
return self.hugr.add_node(ops.Const(value), parent_node)
8892

8993
def add_alias_defn(self, name: str, ty: Type, parent: ToNode | None = None) -> Node:
9094
"""Add a type alias definition."""
91-
parent_node = parent or self.hugr.root
95+
parent_node = parent or self.hugr.entrypoint
9296
return self.hugr.add_node(ops.AliasDefn(name, ty), parent_node)
9397

9498

@@ -114,7 +118,7 @@ class DfBase(ParentBuilder[DP], DefinitionBuilder, AbstractContextManager):
114118

115119
def __init__(self, parent_op: DP) -> None:
116120
self.hugr = Hugr(parent_op)
117-
self.parent_node = self.hugr.root
121+
self.parent_node = self.hugr.entrypoint
118122
self._init_io_nodes(parent_op)
119123

120124
def _init_io_nodes(self, parent_op: DP):
@@ -141,7 +145,7 @@ def new_nested(
141145
parent_op: The parent operation of the new dataflow graph.
142146
hugr: The host HUGR instance to build the dataflow graph in.
143147
parent: Parent of new dataflow graph's root node: defaults to the
144-
host HUGR root.
148+
host HUGR entrypoint.
145149
146150
Example:
147151
>>> hugr = Hugr()
@@ -152,10 +156,44 @@ def new_nested(
152156
new = cls.__new__(cls)
153157

154158
new.hugr = hugr
155-
new.parent_node = hugr.add_node(parent_op, parent or hugr.root)
159+
new.parent_node = hugr.add_node(parent_op, parent or hugr.entrypoint)
156160
new._init_io_nodes(parent_op)
157161
return new
158162

163+
@classmethod
164+
def _new_existing(cls, hugr: Hugr, root: ToNode | None = None) -> Self:
165+
"""Start a dataflow graph builder for an existing node.
166+
167+
Args:
168+
hugr: The host HUGR instance to build the dataflow graph in.
169+
root: The dataflow graph's root node.
170+
Defaults to the host HUGR's entrypoint.
171+
172+
Example:
173+
>>> hugr = Hugr(ops.DFG([]))
174+
>>> dfg = Dfg._new_existing(hugr)
175+
>>> dfg.parent_node
176+
Node(4)
177+
178+
Raises:
179+
:class:`DataflowError` if the `root` operation is not a dataflow
180+
parent.
181+
"""
182+
root = root or hugr.entrypoint
183+
184+
if not ops.is_df_parent_op(hugr[root].op):
185+
msg = f"{hugr[root].op} is not a dataflow parent"
186+
raise DataflowError(msg)
187+
188+
new = cls.__new__(cls)
189+
new.hugr = hugr
190+
new.parent_node = root.to_node()
191+
[inp, out] = hugr.children(root)[:2]
192+
new.input_node = inp
193+
new.output_node = out
194+
195+
return new
196+
159197
def _input_op(self) -> ops.Input:
160198
return self.hugr._get_typed_op(self.input_node, ops.Input)
161199

@@ -256,7 +294,7 @@ def insert_nested(self, dfg: Dfg, *args: Wire) -> Node:
256294
args: The input wires to the graph.
257295
258296
Returns:
259-
The root node of the inserted graph.
297+
The entrypoint node of the inserted graph.
260298
261299
Example:
262300
>>> dfg = Dfg(tys.Bool)
@@ -482,7 +520,14 @@ def set_outputs(self, *args: Wire) -> None:
482520
>>> dfg.set_outputs(dfg.inputs()[0]) # connect input to output
483521
"""
484522
self._wire_up(self.output_node, args)
485-
self.parent_op._set_out_types(self._output_op().types)
523+
out_types = self._output_op().types
524+
self.parent_op._set_out_types(out_types)
525+
if (
526+
isinstance(self.parent_op, ops.DataflowOp)
527+
and self.parent_op._entrypoint_requires_wiring
528+
and self.hugr.entrypoint == self.parent_node
529+
):
530+
self.hugr._connect_df_entrypoint_outputs()
486531

487532
def _set_parent_output_count(self, count: int) -> None:
488533
"""Set the final number of output ports on the parent operation.
@@ -644,7 +689,7 @@ def _wire_up_port(self, node: Node, offset: PortOffset, p: Wire) -> tys.Type:
644689

645690

646691
class Dfg(DfBase[ops.DFG]):
647-
"""Builder for a simple nested Dataflow graph, with root node of type
692+
"""Builder for a simple nested Dataflow graph, with entrypoint node of type
648693
:class:`DFG <hugr.ops.DFG>`.
649694
650695
Args:
@@ -662,8 +707,8 @@ def __init__(self, *input_types: tys.Type) -> None:
662707
super().__init__(parent_op)
663708

664709
def set_outputs(self, *outputs: Wire) -> None:
665-
super().set_outputs(*outputs)
666710
self._set_parent_output_count(len(outputs))
711+
super().set_outputs(*outputs)
667712

668713

669714
def _ancestral_sibling(h: Hugr, src: Node, tgt: Node) -> Node | None:

hugr-py/src/hugr/build/function.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class Module(DefinitionBuilder[ops.Module]):
2222
2323
Examples:
2424
>>> m = Module()
25-
>>> m.hugr.root_op()
25+
>>> m.hugr.entrypoint_op()
2626
Module()
2727
"""
2828

@@ -52,13 +52,13 @@ def declare_function(self, name: str, signature: PolyFuncType) -> Node:
5252
>>> m.declare_function("f", sig)
5353
Node(1)
5454
"""
55-
return self.hugr.add_node(ops.FuncDecl(name, signature), self.hugr.root)
55+
return self.hugr.add_node(ops.FuncDecl(name, signature), self.hugr.entrypoint)
5656

5757
def add_alias_decl(self, name: str, bound: TypeBound) -> Node:
5858
"""Add a type alias declaration."""
59-
return self.hugr.add_node(ops.AliasDecl(name, bound), self.hugr.root)
59+
return self.hugr.add_node(ops.AliasDecl(name, bound), self.hugr.entrypoint)
6060

6161
@property
6262
def metadata(self) -> dict[str, object]:
6363
"""Metadata associated with this module."""
64-
return self.hugr.root.metadata
64+
return self.hugr.entrypoint.metadata

0 commit comments

Comments
 (0)