Skip to content

Commit 78e0176

Browse files
authored
feat: add separate DCE pass (#1902)
Still non-breaking, so Constant Folding continues to assume that inputs apply to `main`: the DCE pass allows entry points to be specified explicitly, and constant folding specifies `main` when appropriate. (I could add a flag to ConstantFoldPass, roughly equivalent to "preserve all FuncDefns/FuncDecls"?) The callback mechanism subsumes the previous `might_diverge` mechanism but is significantly more general. (Too general??) Closes #1807
1 parent b760ef8 commit 78e0176

File tree

4 files changed

+372
-116
lines changed

4 files changed

+372
-116
lines changed

hugr-passes/src/const_fold.rs

Lines changed: 35 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,30 @@
33
//! An (example) use of the [dataflow analysis framework](super::dataflow).
44
55
pub mod value_handle;
6-
use std::collections::{HashMap, HashSet, VecDeque};
6+
use std::{collections::HashMap, sync::Arc};
77
use thiserror::Error;
88

99
use hugr_core::{
10-
core::HugrNode,
1110
hugr::{
1211
hugrmut::HugrMut,
1312
views::{DescendantsGraph, ExtractHugr, HierarchyView},
1413
},
1514
ops::{
1615
constant::OpaqueValue, handle::FuncID, Const, DataflowOpTrait, ExtensionOp, LoadConstant,
17-
OpType, Value,
16+
Value,
1817
},
1918
types::{EdgeKind, TypeArg},
2019
HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, PortIndex, Wire,
2120
};
2221
use value_handle::ValueHandle;
2322

2423
use crate::dataflow::{
25-
partial_from_const, AbstractValue, AnalysisResults, ConstLoader, ConstLocation, DFContext,
26-
Machine, PartialValue, TailLoopTermination,
24+
partial_from_const, ConstLoader, ConstLocation, DFContext, Machine, PartialValue,
25+
TailLoopTermination,
2726
};
27+
use crate::dead_code::PreserveNode;
2828
use crate::validation::{ValidatePassError, ValidationLevel};
29+
use crate::{find_main, DeadCodeElimPass};
2930

3031
#[derive(Debug, Clone, Default)]
3132
/// A configuration for the Constant Folding pass.
@@ -89,22 +90,15 @@ impl ConstantFoldPass {
8990
});
9091

9192
let results = Machine::new(&hugr).run(ConstFoldContext(hugr), inputs);
92-
let keep_nodes = self.find_needed_nodes(&results);
9393
let mb_root_inp = hugr.get_io(hugr.root()).map(|[i, _]| i);
9494

95-
let remove_nodes = hugr
95+
let wires_to_break = hugr
9696
.nodes()
97-
.filter(|n| !keep_nodes.contains(n))
98-
.collect::<HashSet<_>>();
99-
let wires_to_break = keep_nodes
100-
.into_iter()
10197
.flat_map(|n| hugr.node_inputs(n).map(move |ip| (n, ip)))
10298
.filter(|(n, ip)| {
10399
*n != hugr.root()
104100
&& matches!(hugr.get_optype(*n).port_kind(*ip), Some(EdgeKind::Value(_)))
105101
})
106-
// Note we COULD filter out (avoid breaking) wires from other nodes that we are keeping.
107-
// This would insert fewer constants, but potentially expose less parallelism.
108102
.filter_map(|(n, ip)| {
109103
let (src, outp) = hugr.single_linked_output(n, ip).unwrap();
110104
// Avoid breaking edges from existing LoadConstant (we'd only add another)
@@ -119,20 +113,42 @@ impl ConstantFoldPass {
119113
))
120114
})
121115
.collect::<Vec<_>>();
116+
// Sadly the results immutably borrow the hugr, so we must extract everything we need before mutation
117+
let terminating_tail_loops = hugr
118+
.nodes()
119+
.filter(|n| {
120+
results.tail_loop_terminates(*n) == Some(TailLoopTermination::NeverContinues)
121+
})
122+
.collect::<Vec<_>>();
122123

123-
for (n, import, v) in wires_to_break {
124+
for (n, inport, v) in wires_to_break {
124125
let parent = hugr.get_parent(n).unwrap();
125126
let datatype = v.get_type();
126127
// We could try hash-consing identical Consts, but not ATM
127128
let cst = hugr.add_node_with_parent(parent, Const::new(v));
128129
let lcst = hugr.add_node_with_parent(parent, LoadConstant { datatype });
129130
hugr.connect(cst, OutgoingPort::from(0), lcst, IncomingPort::from(0));
130-
hugr.disconnect(n, import);
131-
hugr.connect(lcst, OutgoingPort::from(0), n, import);
132-
}
133-
for n in remove_nodes {
134-
hugr.remove_node(n);
131+
hugr.disconnect(n, inport);
132+
hugr.connect(lcst, OutgoingPort::from(0), n, inport);
135133
}
134+
// Dataflow analysis applies our inputs to the 'main' function if this is a Module, so do the same here
135+
DeadCodeElimPass::default()
136+
.with_entry_points(hugr.get_optype(hugr.root()).is_module().then(
137+
// No main => remove everything, so not much use
138+
|| find_main(hugr).unwrap(),
139+
))
140+
.set_preserve_callback(if self.allow_increase_termination {
141+
Arc::new(|_, _| PreserveNode::CanRemoveIgnoringChildren)
142+
} else {
143+
Arc::new(move |h, n| {
144+
if terminating_tail_loops.contains(&n) {
145+
PreserveNode::DeferToChildren
146+
} else {
147+
PreserveNode::default_for(h, n)
148+
}
149+
})
150+
})
151+
.run(hugr)?;
136152
Ok(())
137153
}
138154

@@ -141,97 +157,6 @@ impl ConstantFoldPass {
141157
self.validation
142158
.run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr))
143159
}
144-
145-
fn find_needed_nodes<H: HugrView>(
146-
&self,
147-
results: &AnalysisResults<ValueHandle, H>,
148-
) -> HashSet<H::Node> {
149-
let mut needed = HashSet::new();
150-
let h = results.hugr();
151-
let mut q = VecDeque::from_iter([h.root()]);
152-
while let Some(n) = q.pop_front() {
153-
if !needed.insert(n) {
154-
continue;
155-
};
156-
if h.get_optype(n).is_module() {
157-
for ch in h.children(n) {
158-
match h.get_optype(ch) {
159-
OpType::AliasDecl(_) | OpType::AliasDefn(_) => {
160-
// Use of these is done via names, rather than following edges.
161-
// We could track these as well but for now be conservative.
162-
q.push_back(ch);
163-
}
164-
OpType::FuncDefn(f) if f.name == "main" => {
165-
// Dataflow analysis will have applied any inputs to the 'main' function, so assume reachable.
166-
q.push_back(ch);
167-
}
168-
_ => (),
169-
}
170-
}
171-
} else if h.get_optype(n).is_cfg() {
172-
for bb in h.children(n) {
173-
//if results.bb_reachable(bb).unwrap() { // no, we'd need to patch up predicates
174-
q.push_back(bb);
175-
}
176-
} else if let Some(inout) = h.get_io(n) {
177-
// Dataflow. Find minimal nodes necessary to compute output, including StateOrder edges.
178-
q.extend(inout); // Input also necessary for legality even if unreachable
179-
180-
if !self.allow_increase_termination {
181-
// Also add on anything that might not terminate (even if results not required -
182-
// if its results are required we'll add it by following dataflow, below.)
183-
for ch in h.children(n) {
184-
if might_diverge(results, ch) {
185-
q.push_back(ch);
186-
}
187-
}
188-
}
189-
}
190-
// Also follow dataflow demand
191-
for (src, op) in h.all_linked_outputs(n) {
192-
let needs_predecessor = match h.get_optype(src).port_kind(op).unwrap() {
193-
EdgeKind::Value(_) => {
194-
h.get_optype(src).is_load_constant()
195-
|| results
196-
.try_read_wire_concrete::<Value, _, _>(Wire::new(src, op))
197-
.is_err()
198-
}
199-
EdgeKind::StateOrder | EdgeKind::Const(_) | EdgeKind::Function(_) => true,
200-
EdgeKind::ControlFlow => false, // we always include all children of a CFG above
201-
_ => true, // needed as EdgeKind non-exhaustive; not knowing what it is, assume the worst
202-
};
203-
if needs_predecessor {
204-
q.push_back(src);
205-
}
206-
}
207-
}
208-
needed
209-
}
210-
}
211-
212-
// "Diverge" aka "never-terminate"
213-
// TODO would be more efficient to compute this bottom-up and cache (dynamic programming)
214-
fn might_diverge<V: AbstractValue, N: HugrNode>(
215-
results: &AnalysisResults<V, impl HugrView<Node = N>>,
216-
n: N,
217-
) -> bool {
218-
let op = results.hugr().get_optype(n);
219-
if op.is_cfg() {
220-
// TODO if the CFG has no cycles (that are possible given predicates)
221-
// then we could say it definitely terminates (i.e. return false)
222-
true
223-
} else if op.is_tail_loop()
224-
&& results.tail_loop_terminates(n).unwrap() != TailLoopTermination::NeverContinues
225-
{
226-
// If we can even figure out the number of iterations is bounded that would allow returning false.
227-
true
228-
} else {
229-
// Node does not introduce non-termination, but still non-terminates if any of its children does
230-
results
231-
.hugr()
232-
.children(n)
233-
.any(|ch| might_diverge(results, ch))
234-
}
235160
}
236161

237162
/// Exhaustively apply constant folding to a HUGR.

hugr-passes/src/dataflow/datalog.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ use hugr_core::extension::prelude::{MakeTuple, UnpackTuple};
1010
use hugr_core::ops::{OpTrait, OpType, TailLoop};
1111
use hugr_core::{HugrView, IncomingPort, OutgoingPort, PortIndex as _, Wire};
1212

13+
use crate::find_main;
14+
1315
use super::value_row::ValueRow;
1416
use super::{
1517
partial_from_const, row_contains_bottom, AbstractValue, AnalysisResults, DFContext,
@@ -83,12 +85,7 @@ impl<H: HugrView, V: AbstractValue> Machine<H, V> {
8385
// we must find the corresponding Input node.
8486
let input_node_parent = match self.0.get_optype(root) {
8587
OpType::Module(_) => {
86-
let main = self.0.children(root).find(|n| {
87-
self.0
88-
.get_optype(*n)
89-
.as_func_defn()
90-
.is_some_and(|f| f.name == "main")
91-
});
88+
let main = find_main(&self.0);
9289
if main.is_none() && in_values.next().is_some() {
9390
panic!("Cannot give inputs to module with no 'main'");
9491
}

0 commit comments

Comments
 (0)