3
3
//! An (example) use of the [dataflow analysis framework](super::dataflow).
4
4
5
5
pub mod value_handle;
6
- use std:: collections:: { HashMap , HashSet , VecDeque } ;
6
+ use std:: { collections:: HashMap , sync :: Arc } ;
7
7
use thiserror:: Error ;
8
8
9
9
use hugr_core:: {
10
- core:: HugrNode ,
11
10
hugr:: {
12
11
hugrmut:: HugrMut ,
13
12
views:: { DescendantsGraph , ExtractHugr , HierarchyView } ,
14
13
} ,
15
14
ops:: {
16
15
constant:: OpaqueValue , handle:: FuncID , Const , DataflowOpTrait , ExtensionOp , LoadConstant ,
17
- OpType , Value ,
16
+ Value ,
18
17
} ,
19
18
types:: { EdgeKind , TypeArg } ,
20
19
HugrView , IncomingPort , Node , NodeIndex , OutgoingPort , PortIndex , Wire ,
21
20
} ;
22
21
use value_handle:: ValueHandle ;
23
22
24
23
use crate :: dataflow:: {
25
- partial_from_const, AbstractValue , AnalysisResults , ConstLoader , ConstLocation , DFContext ,
26
- Machine , PartialValue , TailLoopTermination ,
24
+ partial_from_const, ConstLoader , ConstLocation , DFContext , Machine , PartialValue ,
25
+ TailLoopTermination ,
27
26
} ;
27
+ use crate :: dead_code:: PreserveNode ;
28
28
use crate :: validation:: { ValidatePassError , ValidationLevel } ;
29
+ use crate :: { find_main, DeadCodeElimPass } ;
29
30
30
31
#[ derive( Debug , Clone , Default ) ]
31
32
/// A configuration for the Constant Folding pass.
@@ -89,22 +90,15 @@ impl ConstantFoldPass {
89
90
} ) ;
90
91
91
92
let results = Machine :: new ( & hugr) . run ( ConstFoldContext ( hugr) , inputs) ;
92
- let keep_nodes = self . find_needed_nodes ( & results) ;
93
93
let mb_root_inp = hugr. get_io ( hugr. root ( ) ) . map ( |[ i, _] | i) ;
94
94
95
- let remove_nodes = hugr
95
+ let wires_to_break = hugr
96
96
. nodes ( )
97
- . filter ( |n| !keep_nodes. contains ( n) )
98
- . collect :: < HashSet < _ > > ( ) ;
99
- let wires_to_break = keep_nodes
100
- . into_iter ( )
101
97
. flat_map ( |n| hugr. node_inputs ( n) . map ( move |ip| ( n, ip) ) )
102
98
. filter ( |( n, ip) | {
103
99
* n != hugr. root ( )
104
100
&& matches ! ( hugr. get_optype( * n) . port_kind( * ip) , Some ( EdgeKind :: Value ( _) ) )
105
101
} )
106
- // Note we COULD filter out (avoid breaking) wires from other nodes that we are keeping.
107
- // This would insert fewer constants, but potentially expose less parallelism.
108
102
. filter_map ( |( n, ip) | {
109
103
let ( src, outp) = hugr. single_linked_output ( n, ip) . unwrap ( ) ;
110
104
// Avoid breaking edges from existing LoadConstant (we'd only add another)
@@ -119,20 +113,42 @@ impl ConstantFoldPass {
119
113
) )
120
114
} )
121
115
. collect :: < Vec < _ > > ( ) ;
116
+ // Sadly the results immutably borrow the hugr, so we must extract everything we need before mutation
117
+ let terminating_tail_loops = hugr
118
+ . nodes ( )
119
+ . filter ( |n| {
120
+ results. tail_loop_terminates ( * n) == Some ( TailLoopTermination :: NeverContinues )
121
+ } )
122
+ . collect :: < Vec < _ > > ( ) ;
122
123
123
- for ( n, import , v) in wires_to_break {
124
+ for ( n, inport , v) in wires_to_break {
124
125
let parent = hugr. get_parent ( n) . unwrap ( ) ;
125
126
let datatype = v. get_type ( ) ;
126
127
// We could try hash-consing identical Consts, but not ATM
127
128
let cst = hugr. add_node_with_parent ( parent, Const :: new ( v) ) ;
128
129
let lcst = hugr. add_node_with_parent ( parent, LoadConstant { datatype } ) ;
129
130
hugr. connect ( cst, OutgoingPort :: from ( 0 ) , lcst, IncomingPort :: from ( 0 ) ) ;
130
- hugr. disconnect ( n, import) ;
131
- hugr. connect ( lcst, OutgoingPort :: from ( 0 ) , n, import) ;
132
- }
133
- for n in remove_nodes {
134
- hugr. remove_node ( n) ;
131
+ hugr. disconnect ( n, inport) ;
132
+ hugr. connect ( lcst, OutgoingPort :: from ( 0 ) , n, inport) ;
135
133
}
134
+ // Dataflow analysis applies our inputs to the 'main' function if this is a Module, so do the same here
135
+ DeadCodeElimPass :: default ( )
136
+ . with_entry_points ( hugr. get_optype ( hugr. root ( ) ) . is_module ( ) . then (
137
+ // No main => remove everything, so not much use
138
+ || find_main ( hugr) . unwrap ( ) ,
139
+ ) )
140
+ . set_preserve_callback ( if self . allow_increase_termination {
141
+ Arc :: new ( |_, _| PreserveNode :: CanRemoveIgnoringChildren )
142
+ } else {
143
+ Arc :: new ( move |h, n| {
144
+ if terminating_tail_loops. contains ( & n) {
145
+ PreserveNode :: DeferToChildren
146
+ } else {
147
+ PreserveNode :: default_for ( h, n)
148
+ }
149
+ } )
150
+ } )
151
+ . run ( hugr) ?;
136
152
Ok ( ( ) )
137
153
}
138
154
@@ -141,97 +157,6 @@ impl ConstantFoldPass {
141
157
self . validation
142
158
. run_validated_pass ( hugr, |hugr : & mut H , _| self . run_no_validate ( hugr) )
143
159
}
144
-
145
- fn find_needed_nodes < H : HugrView > (
146
- & self ,
147
- results : & AnalysisResults < ValueHandle , H > ,
148
- ) -> HashSet < H :: Node > {
149
- let mut needed = HashSet :: new ( ) ;
150
- let h = results. hugr ( ) ;
151
- let mut q = VecDeque :: from_iter ( [ h. root ( ) ] ) ;
152
- while let Some ( n) = q. pop_front ( ) {
153
- if !needed. insert ( n) {
154
- continue ;
155
- } ;
156
- if h. get_optype ( n) . is_module ( ) {
157
- for ch in h. children ( n) {
158
- match h. get_optype ( ch) {
159
- OpType :: AliasDecl ( _) | OpType :: AliasDefn ( _) => {
160
- // Use of these is done via names, rather than following edges.
161
- // We could track these as well but for now be conservative.
162
- q. push_back ( ch) ;
163
- }
164
- OpType :: FuncDefn ( f) if f. name == "main" => {
165
- // Dataflow analysis will have applied any inputs the 'main' function, so assume reachable.
166
- q. push_back ( ch) ;
167
- }
168
- _ => ( ) ,
169
- }
170
- }
171
- } else if h. get_optype ( n) . is_cfg ( ) {
172
- for bb in h. children ( n) {
173
- //if results.bb_reachable(bb).unwrap() { // no, we'd need to patch up predicates
174
- q. push_back ( bb) ;
175
- }
176
- } else if let Some ( inout) = h. get_io ( n) {
177
- // Dataflow. Find minimal nodes necessary to compute output, including StateOrder edges.
178
- q. extend ( inout) ; // Input also necessary for legality even if unreachable
179
-
180
- if !self . allow_increase_termination {
181
- // Also add on anything that might not terminate (even if results not required -
182
- // if its results are required we'll add it by following dataflow, below.)
183
- for ch in h. children ( n) {
184
- if might_diverge ( results, ch) {
185
- q. push_back ( ch) ;
186
- }
187
- }
188
- }
189
- }
190
- // Also follow dataflow demand
191
- for ( src, op) in h. all_linked_outputs ( n) {
192
- let needs_predecessor = match h. get_optype ( src) . port_kind ( op) . unwrap ( ) {
193
- EdgeKind :: Value ( _) => {
194
- h. get_optype ( src) . is_load_constant ( )
195
- || results
196
- . try_read_wire_concrete :: < Value , _ , _ > ( Wire :: new ( src, op) )
197
- . is_err ( )
198
- }
199
- EdgeKind :: StateOrder | EdgeKind :: Const ( _) | EdgeKind :: Function ( _) => true ,
200
- EdgeKind :: ControlFlow => false , // we always include all children of a CFG above
201
- _ => true , // needed as EdgeKind non-exhaustive; not knowing what it is, assume the worst
202
- } ;
203
- if needs_predecessor {
204
- q. push_back ( src) ;
205
- }
206
- }
207
- }
208
- needed
209
- }
210
- }
211
-
212
- // "Diverge" aka "never-terminate"
213
- // TODO would be more efficient to compute this bottom-up and cache (dynamic programming)
214
- fn might_diverge < V : AbstractValue , N : HugrNode > (
215
- results : & AnalysisResults < V , impl HugrView < Node = N > > ,
216
- n : N ,
217
- ) -> bool {
218
- let op = results. hugr ( ) . get_optype ( n) ;
219
- if op. is_cfg ( ) {
220
- // TODO if the CFG has no cycles (that are possible given predicates)
221
- // then we could say it definitely terminates (i.e. return false)
222
- true
223
- } else if op. is_tail_loop ( )
224
- && results. tail_loop_terminates ( n) . unwrap ( ) != TailLoopTermination :: NeverContinues
225
- {
226
- // If we can even figure out the number of iterations is bounded that would allow returning false.
227
- true
228
- } else {
229
- // Node does not introduce non-termination, but still non-terminates if any of its children does
230
- results
231
- . hugr ( )
232
- . children ( n)
233
- . any ( |ch| might_diverge ( results, ch) )
234
- }
235
160
}
236
161
237
162
/// Exhaustively apply constant folding to a HUGR.
0 commit comments