From c7a9d89d85bc5bc7c9c826013515b532cc6ba949 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 13:24:01 +0100 Subject: [PATCH 001/203] Just const_fold2 + inside that partial_value (taken from hugr_core) --- hugr-passes/src/const_fold2.rs | 2 + hugr-passes/src/const_fold2/datalog.rs | 254 ++++++++++ .../src/const_fold2/datalog/context.rs | 67 +++ hugr-passes/src/const_fold2/datalog/test.rs | 232 +++++++++ hugr-passes/src/const_fold2/datalog/utils.rs | 390 +++++++++++++++ hugr-passes/src/const_fold2/partial_value.rs | 454 ++++++++++++++++++ .../src/const_fold2/partial_value/test.rs | 346 +++++++++++++ .../const_fold2/partial_value/value_handle.rs | 245 ++++++++++ hugr-passes/src/lib.rs | 1 + 9 files changed, 1991 insertions(+) create mode 100644 hugr-passes/src/const_fold2.rs create mode 100644 hugr-passes/src/const_fold2/datalog.rs create mode 100644 hugr-passes/src/const_fold2/datalog/context.rs create mode 100644 hugr-passes/src/const_fold2/datalog/test.rs create mode 100644 hugr-passes/src/const_fold2/datalog/utils.rs create mode 100644 hugr-passes/src/const_fold2/partial_value.rs create mode 100644 hugr-passes/src/const_fold2/partial_value/test.rs create mode 100644 hugr-passes/src/const_fold2/partial_value/value_handle.rs diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs new file mode 100644 index 000000000..dbe4464fd --- /dev/null +++ b/hugr-passes/src/const_fold2.rs @@ -0,0 +1,2 @@ +mod datalog; +pub mod partial_value; \ No newline at end of file diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs new file mode 100644 index 000000000..d7df9c1e6 --- /dev/null +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -0,0 +1,254 @@ +use ascent::lattice::{ord_lattice::OrdLattice, BoundedLattice, Dual, Lattice}; +use delegate::delegate; +use itertools::{zip_eq, Itertools}; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +use std::sync::{Arc, Mutex}; + +use either::Either; +use hugr_core::ops::{OpTag, OpTrait, Value}; +use hugr_core::partial_value::{PartialValue, ValueHandle, ValueKey}; +use hugr_core::types::{EdgeKind, FunctionType, SumType, Type, TypeEnum, TypeRow}; +use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; + +mod context; +mod utils; + +use context::DataflowContext; +pub use utils::{TailLoopTermination, ValueRow, IO, PV}; + +pub trait DFContext: AsRef + Clone + Eq + Hash + std::ops::Deref {} + +ascent::ascent! { + // The trait-indirection layer here means we can just write 'C' but in practice ATM + // DataflowContext (for H: HugrView) would be sufficient, there's really no + // point in using anything else yet. However DFContext will be useful when we + // move interpretation of nodes out into a trait method. + struct AscentProgram; + relation context(C); + relation out_wire_value_proto(Node, OutgoingPort, PV); + + relation node(C, Node); + relation in_wire(C, Node, IncomingPort); + relation out_wire(C, Node, OutgoingPort); + relation parent_of_node(C, Node, Node); + relation io_node(C, Node, Node, IO); + lattice out_wire_value(C, Node, OutgoingPort, PV); + lattice node_in_value_row(C, Node, ValueRow); + lattice in_wire_value(C, Node, IncomingPort, PV); + + node(c, n) <-- context(c), for n in c.nodes(); + + in_wire(c, n,p) <-- node(c, n), for p in utils::value_inputs(c, *n); + + out_wire(c, n,p) <-- node(c, n), for p in utils::value_outputs(c, *n); + + parent_of_node(c, parent, child) <-- + node(c, child), if let Some(parent) = c.get_parent(*child); + + io_node(c, parent, child, io) <-- node(c, parent), + if let Some([i,o]) = c.get_io(*parent), + for (child,io) in [(i,IO::Input),(o,IO::Output)]; + // We support prepopulating out_wire_value via out_wire_value_proto. + // + // out wires that do not have prepopulation values are initialised to bottom. + out_wire_value(c, n,p, PV::bottom()) <-- out_wire(c, n,p); + out_wire_value(c, n, p, v) <-- out_wire(c,n,p) , out_wire_value_proto(n, p, v); + + in_wire_value(c, n, ip, v) <-- in_wire(c, n, ip), + if let Some((m,op)) = c.single_linked_output(*n, *ip), + out_wire_value(c, m, op, v); + + + node_in_value_row(c, n, utils::bottom_row(c, *n)) <-- node(c, n); + node_in_value_row(c, n, utils::singleton_in_row(c, n, p, v.clone())) <-- in_wire_value(c, n, p, v); + + + // Per node-type rules + // TODO do all leaf ops with a rule + // define `fn propagate_leaf_op(Context, Node, ValueRow) -> ValueRow + + // LoadConstant + relation load_constant_node(C, Node); + load_constant_node(c, n) <-- node(c, n), if c.get_optype(*n).is_load_constant(); + + out_wire_value(c, n, 0.into(), utils::partial_value_from_load_constant(c, *n)) <-- + load_constant_node(c, n); + + + // MakeTuple + relation make_tuple_node(C, Node); + make_tuple_node(c, n) <-- node(c, n), if c.get_optype(*n).is_make_tuple(); + + out_wire_value(c, n, 0.into(), utils::partial_value_tuple_from_value_row(vs.clone())) <-- + make_tuple_node(c, n), node_in_value_row(c, n, vs); + + + // UnpackTuple + relation unpack_tuple_node(C, Node); + unpack_tuple_node(c,n) <-- node(c, n), if c.get_optype(*n).is_unpack_tuple(); + + out_wire_value(c, n, p, v.tuple_field_value(p.index())) <-- + unpack_tuple_node(c, n), + in_wire_value(c, n, IncomingPort::from(0), v), + out_wire(c, n, p); + + + // DFG + relation dfg_node(C, Node); + dfg_node(c,n) <-- node(c, n), if c.get_optype(*n).is_dfg(); + + out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- dfg_node(c,dfg), + io_node(c, dfg, i, IO::Input), in_wire_value(c, dfg, p, v); + + out_wire_value(c, dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(c,dfg), + io_node(c,dfg,o, IO::Output), in_wire_value(c, o, p, v); + + + // TailLoop + relation tail_loop_node(C, Node); + tail_loop_node(c,n) <-- node(c, n), if c.get_optype(*n).is_tail_loop(); + + // inputs of tail loop propagate to Input node of child region + out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- tail_loop_node(c, tl), + io_node(c,tl,i, IO::Input), in_wire_value(c, tl, p, v); + + // Output node of child region propagate to Input node of child region + out_wire_value(c, in_n, out_p, v) <-- tail_loop_node(c, tl_n), + io_node(c,tl_n,in_n, IO::Input), + io_node(c,tl_n,out_n, IO::Output), + node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node + if out_in_row[0].supports_tag(0), // if it is possible for tag to be 0 + if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), + let variant_len = tailloop.just_inputs.len(), + for (out_p, v) in out_in_row.iter(c, *out_n).flat_map( + |(input_p, v)| utils::outputs_for_variant(input_p, 0, variant_len, v) + ); + + // Output node of child region propagate to outputs of tail loop + out_wire_value(c, tl_n, out_p, v) <-- tail_loop_node(c, tl_n), + io_node(c,tl_n,out_n, IO::Output), + node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node + if out_in_row[0].supports_tag(1), // if it is possible for the tag to be 1 + if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), + let variant_len = tailloop.just_outputs.len(), + for (out_p, v) in out_in_row.iter(c, *out_n).flat_map( + |(input_p, v)| utils::outputs_for_variant(input_p, 1, variant_len, v) + ); + + lattice tail_loop_termination(C,Node,TailLoopTermination); + tail_loop_termination(c,tl_n,TailLoopTermination::bottom()) <-- + tail_loop_node(c,tl_n); + tail_loop_termination(c,tl_n,TailLoopTermination::from_control_value(v)) <-- + tail_loop_node(c,tl_n), + io_node(c,tl,out_n, IO::Output), + in_wire_value(c, out_n, IncomingPort::from(0), v); + + + // Conditional + relation conditional_node(C, Node); + relation case_node(C,Node,usize, Node); + + conditional_node (c,n)<-- node(c, n), if c.get_optype(*n).is_conditional(); + case_node(c,cond,i, case) <-- conditional_node(c,cond), + for (i, case) in c.children(*cond).enumerate(), + if c.get_optype(case).is_case(); + + // inputs of conditional propagate into case nodes + out_wire_value(c, i_node, i_p, v) <-- + case_node(c, cond, case_index, case), + io_node(c, case, i_node, IO::Input), + in_wire_value(c, cond, cond_in_p, cond_in_v), + if let Some(conditional) = c.get_optype(*cond).as_conditional(), + let variant_len = conditional.sum_rows[*case_index].len(), + for (i_p, v) in utils::outputs_for_variant(*cond_in_p, *case_index, variant_len, cond_in_v); + + // outputs of case nodes propagate to outputs of conditional + out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <-- + case_node(c, cond, _, case), + io_node(c, case, o, IO::Output), + in_wire_value(c, o, o_p, v); + + lattice case_reachable(C, Node, Node, bool); + case_reachable(c, cond, case, reachable) <-- case_node(c,cond,i,case), + in_wire_value(c, cond, IncomingPort::from(0), v), + let reachable = v.supports_tag(*i); + +} + +// TODO This should probably be called 'Analyser' or something +struct Machine( + AscentProgram>, + Option>, +); + +/// Usage: +/// 1. [Self::new()] +/// 2. Zero or more [Self::propolutate_out_wires] with initial values +/// 3. Exactly one [Self::run_hugr] to do the analysis +/// 4. Results then available via [Self::read_out_wire_partial_value] and [Self::read_out_wire_value] +impl Machine { + pub fn new() -> Self { + Self(Default::default(), None) + } + + pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator) { + assert!(self.1.is_none()); + self.0.out_wire_value_proto.extend( + wires + .into_iter() + .map(|(w, v)| (w.node(), w.source(), v.into())), + ); + } + + pub fn run_hugr(&mut self, hugr: H) { + assert!(self.1.is_none()); + self.0.context.push((DataflowContext::new(hugr),)); + self.0.run(); + self.1 = Some( + self.0 + .out_wire_value + .iter() + .map(|(_, n, p, v)| (Wire::new(*n, *p), v.clone().into())) + .collect(), + ) + } + + pub fn read_out_wire_partial_value(&self, w: Wire) -> Option { + self.1.as_ref().unwrap().get(&w).cloned() + } + + pub fn read_out_wire_value(&self, hugr: H, w: Wire) -> Option { + // dbg!(&w); + let pv = self.read_out_wire_partial_value(w)?; + // dbg!(&pv); + let (_, typ) = hugr + .out_value_types(w.node()) + .find(|(p, _)| *p == w.source()) + .unwrap(); + pv.try_into_value(&typ).ok() + } + + pub fn tail_loop_terminates(&self, hugr: H, node: Node) -> TailLoopTermination { + assert!(hugr.get_optype(node).is_tail_loop()); + self.0 + .tail_loop_termination + .iter() + .find_map(|(_, n, v)| (n == &node).then_some(*v)) + .unwrap() + } + + pub fn case_reachable(&self, hugr: H, case: Node) -> bool { + assert!(hugr.get_optype(case).is_case()); + let cond = hugr.get_parent(case).unwrap(); + assert!(hugr.get_optype(cond).is_conditional()); + self.0 + .case_reachable + .iter() + .find_map(|(_, cond2, case2, i)| (&cond == cond2 && &case == case2).then_some(*i)) + .unwrap() + } +} + +#[cfg(test)] +mod test; diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/datalog/context.rs new file mode 100644 index 000000000..92c0c3285 --- /dev/null +++ b/hugr-passes/src/const_fold2/datalog/context.rs @@ -0,0 +1,67 @@ +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +use std::ops::Deref; +use std::sync::atomic::AtomicUsize; +use std::sync::{Arc, Mutex}; + +use hugr_core::hugr::internal::HugrInternals; +use hugr_core::ops::Value; +use hugr_core::partial_value::{ValueHandle, ValueKey}; +use hugr_core::{Hugr, HugrView, Node}; + +use super::DFContext; + +#[derive(Debug)] +pub(super) struct DataflowContext(Arc); + +impl DataflowContext { + pub fn new(hugr: H) -> Self { + Self(Arc::new(hugr)) + } +} + +// Deriving Clone requires H:HugrView to implement Clone, +// but we don't need that as we only clone the Arc. +impl Clone for DataflowContext { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl Hash for DataflowContext { + fn hash(&self, state: &mut I) {} +} + +impl PartialEq for DataflowContext { + fn eq(&self, other: &Self) -> bool { + // Any AscentProgram should have only one DataflowContext + assert_eq!(self as *const _, other as *const _); + true + } +} + +impl Eq for DataflowContext {} + +impl PartialOrd for DataflowContext { + fn partial_cmp(&self, other: &Self) -> Option { + // Any AscentProgram should have only one DataflowContext + assert_eq!(self as *const _, other as *const _); + Some(std::cmp::Ordering::Equal) + } +} + +impl Deref for DataflowContext { + type Target = Hugr; + + fn deref(&self) -> &Self::Target { + self.0.base_hugr() + } +} + +impl AsRef for DataflowContext { + fn as_ref(&self) -> &Hugr { + self.base_hugr() + } +} + +impl DFContext for DataflowContext {} diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs new file mode 100644 index 000000000..4e086c4b7 --- /dev/null +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -0,0 +1,232 @@ +use hugr_core::{ + builder::{DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, + extension::{prelude::BOOL_T, ExtensionSet, EMPTY_REG}, + ops::{handle::NodeHandle, OpTrait, UnpackTuple, Value}, + type_row, + types::{FunctionType, SumType}, + Extension, +}; + +use hugr_core::partial_value::PartialValue; + +use super::*; + +#[test] +fn test_make_tuple() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let v1 = builder.add_load_value(Value::false_val()); + let v2 = builder.add_load_value(Value::true_val()); + let v3 = builder.make_tuple([v1, v2]).unwrap(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + machine.run_hugr(&hugr); + + let x = machine.read_out_wire_value(&hugr, v3).unwrap(); + assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); +} + +#[test] +fn test_unpack_tuple() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let v1 = builder.add_load_value(Value::false_val()); + let v2 = builder.add_load_value(Value::true_val()); + let v3 = builder.make_tuple([v1, v2]).unwrap(); + let [o1, o2] = builder + .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T, BOOL_T]), [v3]) + .unwrap() + .outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + machine.run_hugr(&hugr); + + let o1_r = machine.read_out_wire_value(&hugr, o1).unwrap(); + assert_eq!(o1_r, Value::false_val()); + let o2_r = machine.read_out_wire_value(&hugr, o2).unwrap(); + assert_eq!(o2_r, Value::true_val()); +} + +#[test] +fn test_unpack_const() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let v1 = builder.add_load_value(Value::tuple([Value::true_val()])); + let [o] = builder + .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T]), [v1]) + .unwrap() + .outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + let c = machine.run_hugr(&hugr); + + let o_r = machine.read_out_wire_value(&hugr, o).unwrap(); + assert_eq!(o_r, Value::true_val()); +} + +#[test] +fn test_tail_loop_never_iterates() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let r_v = Value::unit_sum(3, 6).unwrap(); + let r_w = builder.add_load_value( + Value::sum( + 1, + [r_v.clone()], + SumType::new([type_row![], r_v.get_type().into()]), + ) + .unwrap(), + ); + let tlb = builder + .tail_loop_builder([], [], vec![r_v.get_type()].into()) + .unwrap(); + let tail_loop = tlb.finish_with_outputs(r_w, []).unwrap(); + let [tl_o] = tail_loop.outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + machine.run_hugr(&hugr); + // dbg!(&machine.tail_loop_io_node); + // dbg!(&machine.out_wire_value); + + let o_r = machine.read_out_wire_value(&hugr, tl_o).unwrap(); + assert_eq!(o_r, r_v); + assert_eq!( + TailLoopTermination::ExactlyZeroContinues, + machine.tail_loop_terminates(&hugr, tail_loop.node()) + ) +} + +#[test] +fn test_tail_loop_always_iterates() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let r_w = builder + .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); + let true_w = builder.add_load_value(Value::true_val()); + + let tlb = builder + .tail_loop_builder([], [(BOOL_T, true_w)], vec![BOOL_T].into()) + .unwrap(); + + // r_w has tag 0, so we always continue; + // we put true in our "other_output", but we should not propagate this to + // output because r_w never supports 1. + let tail_loop = tlb.finish_with_outputs(r_w, [true_w]).unwrap(); + + let [tl_o1, tl_o2] = tail_loop.outputs_arr(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + machine.run_hugr(&hugr); + + let o_r1 = machine.read_out_wire_partial_value(tl_o1).unwrap(); + assert_eq!(o_r1, PartialValue::bottom()); + let o_r2 = machine.read_out_wire_partial_value(tl_o2).unwrap(); + assert_eq!(o_r2, PartialValue::bottom()); + assert_eq!( + TailLoopTermination::bottom(), + machine.tail_loop_terminates(&hugr, tail_loop.node()) + ) +} + +#[test] +fn test_tail_loop_iterates_twice() { + let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + // let var_type = Type::new_sum([type_row![BOOL_T,BOOL_T], type_row![BOOL_T,BOOL_T]]); + + let true_w = builder.add_load_value(Value::true_val()); + let false_w = builder.add_load_value(Value::false_val()); + + // let r_w = builder + // .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); + let tlb = builder + .tail_loop_builder([], [(BOOL_T, false_w), (BOOL_T, true_w)], vec![].into()) + .unwrap(); + assert_eq!( + tlb.loop_signature().unwrap().dataflow_signature().unwrap(), + FunctionType::new_endo(type_row![BOOL_T, BOOL_T]) + ); + let [in_w1, in_w2] = tlb.input_wires_arr(); + let tail_loop = tlb.finish_with_outputs(in_w1, [in_w2, in_w1]).unwrap(); + + // let optype = builder.hugr().get_optype(tail_loop.node()); + // for p in builder.hugr().node_outputs(tail_loop.node()) { + // use hugr_core::ops::OpType; + // println!("{:?}, {:?}", p, optype.port_kind(p)); + + // } + + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + // TODO once we can do conditionals put these wires inside `just_outputs` and + // we should be able to propagate their values + let [o_w1, o_w2, _] = tail_loop.outputs_arr(); + + let mut machine = Machine::new(); + let c = machine.run_hugr(&hugr); + // dbg!(&machine.tail_loop_io_node); + // dbg!(&machine.out_wire_value); + + // TODO these hould be the propagated values for now they will bt join(true,false) + let o_r1 = machine.read_out_wire_partial_value(o_w1).unwrap(); + // assert_eq!(o_r1, PartialValue::top()); + let o_r2 = machine.read_out_wire_partial_value(o_w2).unwrap(); + // assert_eq!(o_r2, Value::true_val()); + assert_eq!( + TailLoopTermination::Top, + machine.tail_loop_terminates(&hugr, tail_loop.node()) + ) +} + +#[test] +fn conditional() { + let variants = vec![type_row![], type_row![], type_row![BOOL_T]]; + let cond_t = Type::new_sum(variants.clone()); + let mut builder = DFGBuilder::new(FunctionType::new( + Into::::into(cond_t), + type_row![], + )) + .unwrap(); + let [arg_w] = builder.input_wires_arr(); + + let true_w = builder.add_load_value(Value::true_val()); + let false_w = builder.add_load_value(Value::false_val()); + + let mut cond_builder = builder + .conditional_builder( + (variants, arg_w), + [(BOOL_T, true_w)], + type_row!(BOOL_T, BOOL_T), + ExtensionSet::default(), + ) + .unwrap(); + // will be unreachable + let case1_b = cond_builder.case_builder(0).unwrap(); + let case1 = case1_b.finish_with_outputs([false_w, false_w]).unwrap(); + + let case2_b = cond_builder.case_builder(1).unwrap(); + let [c2a] = case2_b.input_wires_arr(); + let case2 = case2_b.finish_with_outputs([false_w, c2a]).unwrap(); + + let case3_b = cond_builder.case_builder(2).unwrap(); + let [c3_1, c3_2] = case3_b.input_wires_arr(); + let case3 = case3_b.finish_with_outputs([c3_1, false_w]).unwrap(); + + let cond = cond_builder.finish_sub_container().unwrap(); + + let [cond_o1, cond_o2] = cond.outputs_arr(); + + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let mut machine = Machine::new(); + let arg_pv = + PartialValue::variant(1, []).join(PartialValue::variant(2, [PartialValue::variant(0, [])])); + machine.propolutate_out_wires([(arg_w, arg_pv)]); + machine.run_hugr(&hugr); + + let cond_r1 = machine.read_out_wire_value(&hugr, cond_o1).unwrap(); + assert_eq!(cond_r1, Value::false_val()); + assert!(machine.read_out_wire_value(&hugr, cond_o2).is_none()); + + assert!(!machine.case_reachable(&hugr, case1.node())); + assert!(machine.case_reachable(&hugr, case2.node())); + assert!(machine.case_reachable(&hugr, case3.node())); +} diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs new file mode 100644 index 000000000..9c2e46ae3 --- /dev/null +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -0,0 +1,390 @@ +// proptest-derive generates many of these warnings. +// https://github.com/rust-lang/rust/issues/120363 +// https://github.com/proptest-rs/proptest/issues/447 +#![cfg_attr(test, allow(non_local_definitions))] + +use std::{cmp::Ordering, ops::Index, sync::Arc}; + +use ascent::lattice::{BoundedLattice, Lattice}; +use either::Either; +use hugr_core::{ + ops::OpTrait as _, + partial_value::{PartialValue, ValueHandle}, + types::{EdgeKind, TypeRow}, + HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, +}; +use itertools::zip_eq; + +#[cfg(test)] +use proptest_derive::Arbitrary; + +#[derive(PartialEq, Eq, Hash, PartialOrd, Clone, Debug)] +pub struct PV(PartialValue); + +impl From for PV { + fn from(inner: PartialValue) -> Self { + Self(inner) + } +} + +impl PV { + pub fn tuple_field_value(&self, idx: usize) -> Self { + self.variant_field_value(0, idx) + } + + /// TODO the arguments here are not pretty, two usizes, better not mix them + /// up!!! + pub fn variant_field_value(&self, variant: usize, idx: usize) -> Self { + self.0.variant_field_value(variant, idx).into() + } + + pub fn supports_tag(&self, tag: usize) -> bool { + self.0.supports_tag(tag) + } +} + +impl From for PartialValue { + fn from(value: PV) -> Self { + value.0 + } +} + +impl From for PV { + fn from(inner: ValueHandle) -> Self { + Self(inner.into()) + } +} + +impl Lattice for PV { + fn meet(self, other: Self) -> Self { + self.0.meet(other.0).into() + } + + fn meet_mut(&mut self, other: Self) -> bool { + self.0.meet_mut(other.0) + } + + fn join(self, other: Self) -> Self { + self.0.join(other.0).into() + } + + fn join_mut(&mut self, other: Self) -> bool { + self.0.join_mut(other.0) + } +} + +impl BoundedLattice for PV { + fn bottom() -> Self { + PartialValue::bottom().into() + } + + fn top() -> Self { + PartialValue::top().into() + } +} + +#[derive(PartialEq, Clone, Eq, Hash, PartialOrd)] +pub struct ValueRow(Vec); + +impl ValueRow { + fn new(len: usize) -> Self { + Self(vec![PV::bottom(); len]) + } + + fn singleton(len: usize, idx: usize, v: PV) -> Self { + assert!(idx < len); + let mut r = Self::new(len); + r.0[idx] = v; + r + } + + fn singleton_from_row(r: &TypeRow, idx: usize, v: PV) -> Self { + Self::singleton(r.len(), idx, v) + } + + fn bottom_from_row(r: &TypeRow) -> Self { + Self::new(r.len()) + } + + pub fn iter<'b>( + &'b self, + h: &'b impl HugrView, + n: Node, + ) -> impl Iterator + 'b { + zip_eq(value_inputs(h, n), self.0.iter()) + } + + // fn initialised(&self) -> bool { + // self.0.iter().all(|x| x != &PV::top()) + // } +} + +impl Lattice for ValueRow { + fn meet(mut self, other: Self) -> Self { + self.meet_mut(other); + self + } + + fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + fn join_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.join_mut(v2); + } + changed + } + + fn meet_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.meet_mut(v2); + } + changed + } +} + +impl IntoIterator for ValueRow { + type Item = PV; + + type IntoIter = as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl Index for ValueRow +where + Vec: Index, +{ + type Output = as Index>::Output; + + fn index(&self, index: Idx) -> &Self::Output { + self.0.index(index) + } +} + +pub(super) fn bottom_row(h: &impl HugrView, n: Node) -> ValueRow { + if let Some(sig) = h.signature(n) { + ValueRow::new(sig.input_count()) + } else { + ValueRow::new(0) + } +} + +pub(super) fn singleton_in_row(h: &impl HugrView, n: &Node, ip: &IncomingPort, v: PV) -> ValueRow { + let Some(sig) = h.signature(*n) else { + panic!("dougrulz"); + }; + if sig.input_count() <= ip.index() { + panic!( + "bad port index: {} >= {}: {}", + ip.index(), + sig.input_count(), + h.get_optype(*n).description() + ); + } + ValueRow::singleton_from_row(&h.signature(*n).unwrap().input, ip.index(), v) +} + +pub(super) fn partial_value_from_load_constant(h: &impl HugrView, node: Node) -> PV { + let load_op = h.get_optype(node).as_load_constant().unwrap(); + let const_node = h + .single_linked_output(node, load_op.constant_port()) + .unwrap() + .0; + let const_op = h.get_optype(const_node).as_const().unwrap(); + ValueHandle::new(const_node.into(), Arc::new(const_op.value().clone())).into() +} + +pub(super) fn partial_value_tuple_from_value_row(r: ValueRow) -> PV { + PartialValue::variant(0, r.into_iter().map(|x| x.0)).into() +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum IO { + Input, + Output, +} + +pub(super) fn value_inputs(h: &impl HugrView, n: Node) -> impl Iterator + '_ { + h.in_value_types(n).map(|x| x.0) +} + +pub(super) fn value_outputs(h: &impl HugrView, n: Node) -> impl Iterator + '_ { + h.out_value_types(n).map(|x| x.0) +} + +// We have several cases where sum types propagate to different places depending +// on their variant tag: +// - From the input of a conditional to the inputs of it's case nodes +// - From the input of the output node of a tail loop to the output of the input node of the tail loop +// - From the input of the output node of a tail loop to the output of tail loop node +// - From the input of a the output node of a dataflow block to the output of the input node of a dataflow block +// - From the input of a the output node of a dataflow block to the output of the cfg +// +// For a value `v` on an incoming porg `output_p`, compute the (out port,value) +// pairs that should be propagated for a given variant tag. We must also supply +// the length of this variant because it cannot always be deduced from the other +// inputs. +// +// If `v` does not support `variant_tag`, then all propagated values will be bottom.` +// +// If `output_p.index()` is 0 then the result is the contents of the variant. +// Otherwise, it is the single "other_output". +// +// TODO doctests +pub(super) fn outputs_for_variant<'a>( + output_p: IncomingPort, + variant_tag: usize, + variant_len: usize, + v: &'a PV, +) -> impl Iterator + 'a { + if output_p.index() == 0 { + Either::Left( + (0..variant_len).map(move |i| (i.into(), v.variant_field_value(variant_tag, i))), + ) + } else { + let v = if v.supports_tag(variant_tag) { + v.clone() + } else { + PV::bottom() + }; + Either::Right(std::iter::once(( + (variant_len + output_p.index() - 1).into(), + v, + ))) + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] +#[cfg_attr(test, derive(Arbitrary))] +pub enum TailLoopTermination { + Bottom, + ExactlyZeroContinues, + Top, +} + +impl TailLoopTermination { + pub fn from_control_value(v: &PV) -> Self { + let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); + if may_break && !may_continue { + Self::ExactlyZeroContinues + } else if may_break && may_continue { + Self::top() + } else { + Self::bottom() + } + } +} + +impl PartialOrd for TailLoopTermination { + fn partial_cmp(&self, other: &Self) -> Option { + if self == other { + return Some(std::cmp::Ordering::Equal); + }; + match (self, other) { + (Self::Bottom, _) => Some(Ordering::Less), + (_, Self::Bottom) => Some(Ordering::Greater), + (Self::Top, _) => Some(Ordering::Greater), + (_, Self::Top) => Some(Ordering::Less), + _ => None, + } + } +} + +impl Lattice for TailLoopTermination { + fn meet(mut self, other: Self) -> Self { + self.meet_mut(other); + self + } + + fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + fn meet_mut(&mut self, other: Self) -> bool { + // let new_self = &mut self; + match (*self).partial_cmp(&other) { + Some(Ordering::Greater) => { + *self = other; + true + } + Some(_) => false, + _ => { + *self = Self::Bottom; + true + } + } + } + + fn join_mut(&mut self, other: Self) -> bool { + match (*self).partial_cmp(&other) { + Some(Ordering::Less) => { + *self = other; + true + } + Some(_) => false, + _ => { + *self = Self::Top; + true + } + } + } +} + +impl BoundedLattice for TailLoopTermination { + fn bottom() -> Self { + Self::Bottom + } + + fn top() -> Self { + Self::Top + } +} + +#[cfg(test)] +#[cfg_attr(test, allow(non_local_definitions))] +mod test { + use super::*; + use proptest::prelude::*; + + proptest! { + #[test] + fn bounded_lattice(v: TailLoopTermination) { + prop_assert!(v <= TailLoopTermination::top()); + prop_assert!(v >= TailLoopTermination::bottom()); + } + + #[test] + fn meet_join_self_noop(v1: TailLoopTermination) { + let mut subject = v1.clone(); + + assert_eq!(v1.clone(), v1.clone().join(v1.clone())); + assert!(!subject.join_mut(v1.clone())); + assert_eq!(subject, v1); + + assert_eq!(v1.clone(), v1.clone().meet(v1.clone())); + assert!(!subject.meet_mut(v1.clone())); + assert_eq!(subject, v1); + } + + #[test] + fn lattice(v1: TailLoopTermination, v2: TailLoopTermination) { + let meet = v1.clone().meet(v2.clone()); + prop_assert!(meet <= v1, "meet not less <=: {:#?}", &meet); + prop_assert!(meet <= v2, "meet not less <=: {:#?}", &meet); + + let join = v1.clone().join(v2.clone()); + prop_assert!(join >= v1, "join not >=: {:#?}", &join); + prop_assert!(join >= v2, "join not >=: {:#?}", &join); + } + } +} diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs new file mode 100644 index 000000000..0442aa4c9 --- /dev/null +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -0,0 +1,454 @@ +#![allow(missing_docs)] +use std::cmp::Ordering; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; + +use itertools::{zip_eq, Itertools as _}; + +use crate::ops::Value; +use crate::types::{Type, TypeEnum}; + +mod value_handle; + +pub use value_handle::{ValueHandle, ValueKey}; + +// TODO ALAN inline into PartialValue +#[derive(PartialEq, Clone, Eq)] +struct PartialSum(HashMap>); + +impl PartialSum { + pub fn variant(tag: usize, values: impl IntoIterator) -> Self { + Self([(tag, values.into_iter().collect())].into_iter().collect()) + } + + pub fn num_variants(&self) -> usize { + self.0.len() + } + + fn assert_variants(&self) { + assert_ne!(self.num_variants(), 0); + for pv in self.0.values().flat_map(|x| x.iter()) { + pv.assert_invariants(); + } + } + + pub fn variant_field_value(&self, variant: usize, idx: usize) -> PartialValue { + if let Some(row) = self.0.get(&variant) { + assert!(row.len() > idx); + row[idx].clone() + } else { + PartialValue::bottom() + } + } + + pub fn try_into_value(self, typ: &Type) -> Result { + let Ok((k, v)) = self.0.iter().exactly_one() else { + Err(self)? + }; + + let TypeEnum::Sum(st) = typ.as_type_enum() else { + Err(self)? + }; + let Some(r) = st.get_variant(*k) else { + Err(self)? + }; + if v.len() != r.len() { + return Err(self); + } + match zip_eq(v.into_iter(), r.into_iter()) + .map(|(v, t)| v.clone().try_into_value(t)) + .collect::, _>>() + { + Ok(vs) => Value::sum(*k, vs, st.clone()).map_err(|_| self), + Err(_) => Err(self), + } + } + + // unsafe because we panic if any common rows have different lengths + fn join_mut_unsafe(&mut self, other: Self) -> bool { + let mut changed = false; + + for (k, v) in other.0 { + if let Some(row) = self.0.get_mut(&k) { + for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { + changed |= lhs.join_mut(rhs); + } + } else { + self.0.insert(k, v); + changed = true; + } + } + changed + } + + // unsafe because we panic if any common rows have different lengths + fn meet_mut_unsafe(&mut self, other: Self) -> bool { + let mut changed = false; + let mut keys_to_remove = vec![]; + for k in self.0.keys() { + if !other.0.contains_key(k) { + keys_to_remove.push(*k); + } + } + for (k, v) in other.0 { + if let Some(row) = self.0.get_mut(&k) { + for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { + changed |= lhs.meet_mut(rhs); + } + } else { + keys_to_remove.push(k); + } + } + for k in keys_to_remove { + self.0.remove(&k); + changed = true; + } + changed + } + + pub fn supports_tag(&self, tag: usize) -> bool { + self.0.contains_key(&tag) + } +} + +impl PartialOrd for PartialSum { + fn partial_cmp(&self, other: &Self) -> Option { + let max_key = self.0.keys().chain(other.0.keys()).copied().max().unwrap(); + let (mut keys1, mut keys2) = (vec![0; max_key + 1], vec![0; max_key + 1]); + for k in self.0.keys() { + keys1[*k] = 1; + } + + for k in other.0.keys() { + keys2[*k] = 1; + } + + if let Some(ord) = keys1.partial_cmp(&keys2) { + if ord != Ordering::Equal { + return Some(ord); + } + } else { + return None; + } + for (k, lhs) in &self.0 { + let Some(rhs) = other.0.get(&k) else { + unreachable!() + }; + match lhs.partial_cmp(rhs) { + Some(Ordering::Equal) => continue, + x => { + return x; + } + } + } + Some(Ordering::Equal) + } +} + +impl std::fmt::Debug for PartialSum { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl Hash for PartialSum { + fn hash(&self, state: &mut H) { + for (k, v) in &self.0 { + k.hash(state); + v.hash(state); + } + } +} + +impl TryFrom for PartialSum { + type Error = ValueHandle; + + fn try_from(value: ValueHandle) -> Result { + match value.value() { + Value::Tuple { vs } => { + let vec = (0..vs.len()) + .map(|i| PartialValue::from(value.index(i)).into()) + .collect(); + return Ok(Self([(0, vec)].into_iter().collect())); + } + Value::Sum { tag, values, .. } => { + let vec = (0..values.len()) + .map(|i| PartialValue::from(value.index(i)).into()) + .collect(); + return Ok(Self([(*tag, vec)].into_iter().collect())); + } + _ => (), + }; + Err(value) + } +} + +#[derive(PartialEq, Clone, Eq, Hash, Debug)] +pub enum PartialValue { + Bottom, + Value(ValueHandle), + PartialSum(PartialSum), + Top, +} + +impl From for PartialValue { + fn from(v: ValueHandle) -> Self { + TryInto::::try_into(v).map_or_else(Self::Value, Self::PartialSum) + } +} + +impl From for PartialValue { + fn from(v: PartialSum) -> Self { + Self::PartialSum(v) + } +} + +impl PartialValue { + // const BOTTOM: Self = Self::Bottom; + // const BOTTOM_REF: &'static Self = &Self::BOTTOM; + + // fn initialised(&self) -> bool { + // !self.is_top() + // } + + // fn is_top(&self) -> bool { + // self == &PartialValue::Top + // } + + fn assert_invariants(&self) { + match self { + Self::PartialSum(ps) => { + ps.assert_variants(); + } + Self::Value(v) => { + assert!(matches!(v.clone().into(), Self::Value(_))) + } + _ => {} + } + } + + pub fn try_into_value(self, typ: &Type) -> Result { + let r = match self { + Self::Value(v) => Ok(v.value().clone()), + Self::PartialSum(ps) => ps.try_into_value(typ).map_err(Self::PartialSum), + x => Err(x), + }?; + assert_eq!(typ, &r.get_type()); + Ok(r) + } + + fn join_mut_value_handle(&mut self, vh: ValueHandle) -> bool { + self.assert_invariants(); + match &*self { + Self::Top => return false, + Self::Value(v) if v == &vh => return false, + Self::Value(v) => { + *self = Self::Top; + } + Self::PartialSum(_) => match vh.into() { + Self::Value(_) => { + *self = Self::Top; + } + other => return self.join_mut(other), + }, + Self::Bottom => { + *self = vh.into(); + } + }; + true + } + + fn meet_mut_value_handle(&mut self, vh: ValueHandle) -> bool { + self.assert_invariants(); + match &*self { + Self::Bottom => false, + Self::Value(v) => { + if v == &vh { + false + } else { + *self = Self::Bottom; + true + } + } + Self::PartialSum(_) => match vh.into() { + Self::Value(_) => { + *self = Self::Bottom; + true + } + other => self.meet_mut(other), + }, + Self::Top => { + *self = vh.into(); + true + } + } + } + + pub fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + pub fn join_mut(&mut self, other: Self) -> bool { + // println!("join {self:?}\n{:?}", &other); + let changed = match (&*self, other) { + (Self::Top, _) => false, + (_, other @ Self::Top) => { + *self = other; + true + } + (_, Self::Bottom) => false, + (Self::Bottom, other) => { + *self = other; + true + } + (Self::Value(h1), Self::Value(h2)) => { + if h1 == &h2 { + false + } else { + *self = Self::Top; + true + } + } + (Self::PartialSum(_), Self::PartialSum(ps2)) => { + let Self::PartialSum(ps1) = self else { + unreachable!() + }; + ps1.join_mut_unsafe(ps2) + } + (Self::Value(_), mut other) => { + std::mem::swap(self, &mut other); + let Self::Value(old_self) = other else { + unreachable!() + }; + self.join_mut_value_handle(old_self) + } + (_, Self::Value(h)) => self.join_mut_value_handle(h), + // (new_self, _) => { + // **new_self = Self::Top; + // false + // } + }; + // if changed { + // println!("join new self: {:?}", s); + // } + changed + } + + pub fn meet(mut self, other: Self) -> Self { + self.meet_mut(other); + self + } + + pub fn meet_mut(&mut self, other: Self) -> bool { + let changed = match (&*self, other) { + (Self::Bottom, _) => false, + (_, other @ Self::Bottom) => { + *self = other; + true + } + (_, Self::Top) => false, + (Self::Top, other) => { + *self = other; + true + } + (Self::Value(h1), Self::Value(h2)) => { + if h1 == &h2 { + false + } else { + *self = Self::Bottom; + true + } + } + (Self::PartialSum(_), Self::PartialSum(ps2)) => { + let Self::PartialSum(ps1) = self else { + unreachable!() + }; + ps1.meet_mut_unsafe(ps2) + } + (Self::Value(_), mut other @ Self::PartialSum(_)) => { + std::mem::swap(self, &mut other); + let Self::Value(old_self) = other else { + unreachable!() + }; + self.meet_mut_value_handle(old_self) + } + (Self::PartialSum(_), Self::Value(h)) => self.meet_mut_value_handle(h), + // (new_self, _) => { + // **new_self = Self::Bottom; + // false + // } + }; + // if changed { + // println!("join new self: {:?}", s); + // } + changed + } + + pub fn top() -> Self { + Self::Top + } + + pub fn bottom() -> Self { + Self::Bottom + } + + pub fn variant(tag: usize, values: impl IntoIterator) -> Self { + PartialSum::variant(tag, values).into() + } + + pub fn unit() -> Self { + Self::variant(0, []) + } + + pub fn supports_tag(&self, tag: usize) -> bool { + match self { + PartialValue::Bottom => false, + PartialValue::Value(v) => v.tag() == tag, // can never be a sum or tuple + PartialValue::PartialSum(ps) => ps.supports_tag(tag), + PartialValue::Top => true, + } + } + + /// TODO docs + /// just delegate to variant_field_value + pub fn tuple_field_value(&self, idx: usize) -> Self { + self.variant_field_value(0, idx) + } + + /// TODO docs + pub fn variant_field_value(&self, variant: usize, idx: usize) -> Self { + match self { + Self::Bottom => Self::Bottom, + Self::PartialSum(ps) => ps.variant_field_value(variant, idx), + Self::Value(v) => { + if v.tag() == variant { + Self::Value(v.index(idx)) + } else { + Self::Bottom + } + } + Self::Top => Self::Top, + } + } +} + +impl PartialOrd for PartialValue { + fn partial_cmp(&self, other: &Self) -> Option { + use std::cmp::Ordering; + match (self, other) { + (Self::Bottom, Self::Bottom) => Some(Ordering::Equal), + (Self::Top, Self::Top) => Some(Ordering::Equal), + (Self::Bottom, _) => Some(Ordering::Less), + (_, Self::Bottom) => Some(Ordering::Greater), + (Self::Top, _) => Some(Ordering::Greater), + (_, Self::Top) => Some(Ordering::Less), + (Self::Value(v1), Self::Value(v2)) => (v1 == v2).then_some(Ordering::Equal), + (Self::PartialSum(ps1), Self::PartialSum(ps2)) => ps1.partial_cmp(ps2), + _ => None, + } + } +} + +#[cfg(test)] +mod test; diff --git a/hugr-passes/src/const_fold2/partial_value/test.rs b/hugr-passes/src/const_fold2/partial_value/test.rs new file mode 100644 index 000000000..35fbf5373 --- /dev/null +++ b/hugr-passes/src/const_fold2/partial_value/test.rs @@ -0,0 +1,346 @@ +use std::sync::Arc; + +use itertools::{zip_eq, Either, Itertools as _}; +use lazy_static::lazy_static; +use proptest::prelude::*; + +use crate::{ + ops::Value, + std_extensions::arithmetic::int_types::{ + self, get_log_width, ConstInt, INT_TYPES, LOG_WIDTH_BOUND, + }, + types::{CustomType, Type, TypeEnum}, +}; + +use super::{PartialSum, PartialValue, ValueHandle, ValueKey}; +impl Arbitrary for ValueHandle { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy { + // prop_oneof![ + + // ] + todo!() + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +enum TestSumLeafType { + Int(Type), + Unit, +} + +impl TestSumLeafType { + fn assert_invariants(&self) { + match self { + Self::Int(t) => { + if let TypeEnum::Extension(ct) = t.as_type_enum() { + assert_eq!("int", ct.name()); + assert_eq!(&int_types::EXTENSION_ID, ct.extension()); + } else { + panic!("Expected int type, got {:#?}", t); + } + } + _ => (), + } + } + + fn get_type(&self) -> Type { + match self { + Self::Int(t) => t.clone(), + Self::Unit => Type::UNIT, + } + } + + fn type_check(&self, ps: &PartialSum) -> bool { + match self { + Self::Int(_) => false, + Self::Unit => { + if let Ok((0, v)) = ps.0.iter().exactly_one() { + v.is_empty() + } else { + false + } + } + } + } + + fn partial_value_strategy(self) -> impl Strategy { + match self { + Self::Int(t) => { + let TypeEnum::Extension(ct) = t.as_type_enum() else { + unreachable!() + }; + let lw = get_log_width(&ct.args()[0]).unwrap(); + (0u64..(1 << (2u64.pow(lw as u32) - 1))) + .prop_map(move |x| { + let ki = ConstInt::new_u(lw, x).unwrap(); + ValueHandle::new(ValueKey::new(ki.clone()), Arc::new(ki.into())).into() + }) + .boxed() + } + Self::Unit => Just(PartialSum::unit().into()).boxed(), + } + } +} + +impl Arbitrary for TestSumLeafType { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy { + let int_strat = (0..LOG_WIDTH_BOUND).prop_map(|i| Self::Int(INT_TYPES[i as usize].clone())); + prop_oneof![Just(TestSumLeafType::Unit), int_strat].boxed() + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +enum TestSumType { + Branch(usize, Vec>>), + Leaf(TestSumLeafType), +} + +impl TestSumType { + const UNIT: TestSumLeafType = TestSumLeafType::Unit; + + fn leaf(v: Type) -> Self { + TestSumType::Leaf(TestSumLeafType::Int(v)) + } + + fn branch(vs: impl IntoIterator>>) -> Self { + let vec = vs.into_iter().collect_vec(); + let depth: usize = vec + .iter() + .flat_map(|x| x.iter()) + .map(|x| x.depth() + 1) + .max() + .unwrap_or(0); + Self::Branch(depth, vec) + } + + fn depth(&self) -> usize { + match self { + TestSumType::Branch(x, _) => *x, + TestSumType::Leaf(_) => 0, + } + } + + fn is_leaf(&self) -> bool { + self.depth() == 0 + } + + fn assert_invariants(&self) { + match self { + TestSumType::Branch(d, sop) => { + assert!(!sop.is_empty(), "No variants"); + for v in sop.iter().flat_map(|x| x.iter()) { + assert!(v.depth() < *d); + v.assert_invariants(); + } + } + TestSumType::Leaf(l) => { + l.assert_invariants(); + } + _ => (), + } + } + + fn select(self) -> impl Strategy>)>> { + match self { + TestSumType::Branch(_, sop) => any::() + .prop_map(move |i| { + let index = i.index(sop.len()); + Either::Right((index, sop[index].clone())) + }) + .boxed(), + TestSumType::Leaf(l) => Just(Either::Left(l)).boxed(), + } + } + + fn get_type(&self) -> Type { + match self { + TestSumType::Branch(_, sop) => Type::new_sum( + sop.iter() + .map(|row| row.iter().map(|x| x.get_type()).collect_vec().into()), + ), + TestSumType::Leaf(l) => l.get_type(), + } + } + + fn type_check(&self, pv: &PartialValue) -> bool { + match (self, pv) { + (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, + (_, PartialValue::Value(v)) => self.get_type() == v.get_type(), + (TestSumType::Branch(_, sop), PartialValue::PartialSum(ps)) => { + for (k, v) in &ps.0 { + if *k >= sop.len() { + return false; + } + let prod = &sop[*k]; + if prod.len() != v.len() { + return false; + } + if !zip_eq(prod, v).all(|(lhs, rhs)| lhs.type_check(rhs)) { + return false; + } + } + true + } + (Self::Leaf(l), PartialValue::PartialSum(ps)) => l.type_check(ps), + } + } +} + +impl From for TestSumType { + fn from(value: TestSumLeafType) -> Self { + Self::Leaf(value) + } +} + +#[derive(Clone, PartialEq, Eq, Debug)] +struct UnarySumTypeParams { + depth: usize, + branch_width: usize, +} + +impl UnarySumTypeParams { + pub fn descend(mut self, d: usize) -> Self { + assert!(d < self.depth); + self.depth = d; + self + } +} + +impl Default for UnarySumTypeParams { + fn default() -> Self { + Self { + depth: 3, + branch_width: 3, + } + } +} + +impl Arbitrary for TestSumType { + type Parameters = UnarySumTypeParams; + type Strategy = BoxedStrategy; + fn arbitrary_with( + params @ UnarySumTypeParams { + depth, + branch_width, + }: Self::Parameters, + ) -> Self::Strategy { + if depth == 0 { + any::().prop_map_into().boxed() + } else { + (0..depth) + .prop_flat_map(move |d| { + prop::collection::vec( + prop::collection::vec( + any_with::(params.clone().descend(d)).prop_map_into(), + 0..branch_width, + ), + 1..=branch_width, + ) + .prop_map(TestSumType::branch) + }) + .boxed() + } + } +} + +proptest! { + #[test] + fn unary_sum_type_valid(ust: TestSumType) { + ust.assert_invariants(); + } +} + +fn any_partial_value_of_type(ust: TestSumType) -> impl Strategy { + ust.select().prop_flat_map(|x| match x { + Either::Left(l) => l.partial_value_strategy().boxed(), + Either::Right((index, usts)) => { + let pvs = usts + .into_iter() + .map(|x| { + any_partial_value_of_type( + Arc::::try_unwrap(x).unwrap_or_else(|x| x.as_ref().clone()), + ) + }) + .collect_vec(); + pvs.prop_map(move |pvs| PartialValue::variant(index, pvs)) + .boxed() + } + }) +} + +fn any_partial_value_with( + params: ::Parameters, +) -> impl Strategy { + any_with::(params).prop_flat_map(any_partial_value_of_type) +} + +fn any_partial_value() -> impl Strategy { + any_partial_value_with(Default::default()) +} + +fn any_partial_values() -> impl Strategy { + any::().prop_flat_map(|ust| { + TryInto::<[_; N]>::try_into( + (0..N) + .map(|_| any_partial_value_of_type(ust.clone())) + .collect_vec(), + ) + .unwrap() + }) +} + +fn any_typed_partial_value() -> impl Strategy { + any::() + .prop_flat_map(|t| any_partial_value_of_type(t.clone()).prop_map(move |v| (t.clone(), v))) +} + +proptest! { + #[test] + fn partial_value_type((tst, pv) in any_typed_partial_value()) { + prop_assert!(tst.type_check(&pv)) + } + + // todo: ValidHandle is valid + // todo: ValidHandle eq is an equivalence relation + + // todo: PartialValue PartialOrd is transitive + // todo: PartialValue eq is an equivalence relation + #[test] + fn partial_value_valid(pv in any_partial_value()) { + pv.assert_invariants(); + } + + #[test] + fn bounded_lattice(v in any_partial_value()) { + prop_assert!(v <= PartialValue::top()); + prop_assert!(v >= PartialValue::bottom()); + } + + #[test] + fn meet_join_self_noop(v1 in any_partial_value()) { + let mut subject = v1.clone(); + + assert_eq!(v1.clone(), v1.clone().join(v1.clone())); + assert!(!subject.join_mut(v1.clone())); + assert_eq!(subject, v1); + + assert_eq!(v1.clone(), v1.clone().meet(v1.clone())); + assert!(!subject.meet_mut(v1.clone())); + assert_eq!(subject, v1); + } + + #[test] + fn lattice([v1,v2] in any_partial_values()) { + let meet = v1.clone().meet(v2.clone()); + prop_assert!(meet <= v1, "meet not less <=: {:#?}", &meet); + prop_assert!(meet <= v2, "meet not less <=: {:#?}", &meet); + + let join = v1.clone().join(v2.clone()); + prop_assert!(join >= v1, "join not >=: {:#?}", &join); + prop_assert!(join >= v2, "join not >=: {:#?}", &join); + } +} diff --git a/hugr-passes/src/const_fold2/partial_value/value_handle.rs b/hugr-passes/src/const_fold2/partial_value/value_handle.rs new file mode 100644 index 000000000..dfb019872 --- /dev/null +++ b/hugr-passes/src/const_fold2/partial_value/value_handle.rs @@ -0,0 +1,245 @@ +use std::any::Any; +use std::hash::{DefaultHasher, Hash, Hasher}; +use std::ops::Deref; +use std::sync::Arc; + +use downcast_rs::Downcast; +use itertools::Either; + +use crate::ops::Value; +use crate::std_extensions::arithmetic::int_types::ConstInt; +use crate::Node; + +pub trait ValueName: std::fmt::Debug + Downcast + Any { + fn hash(&self) -> u64; + fn eq(&self, other: &dyn ValueName) -> bool; +} + +fn hash_hash(x: &impl Hash) -> u64 { + let mut hasher = DefaultHasher::new(); + x.hash(&mut hasher); + hasher.finish() +} + +fn value_name_eq(x: &T, other: &dyn ValueName) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + x == other + } else { + false + } +} + +impl ValueName for String { + fn hash(&self) -> u64 { + hash_hash(self) + } + + fn eq(&self, other: &dyn ValueName) -> bool { + value_name_eq(self, other) + } +} + +impl ValueName for ConstInt { + fn hash(&self) -> u64 { + hash_hash(self) + } + + fn eq(&self, other: &dyn ValueName) -> bool { + value_name_eq(self, other) + } +} + +#[derive(Clone, Debug)] +pub struct ValueKey(Vec, Either>); + +impl PartialEq for ValueKey { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + && match (&self.1, &other.1) { + (Either::Left(ref n1), Either::Left(ref n2)) => n1 == n2, + (Either::Right(ref v1), Either::Right(ref v2)) => v1.eq(v2.as_ref()), + _ => false, + } + } +} + +impl Eq for ValueKey {} + +impl Hash for ValueKey { + fn hash(&self, state: &mut H) { + self.0.hash(state); + match &self.1 { + Either::Left(n) => (0, n).hash(state), + Either::Right(v) => (1, v.hash()).hash(state), + } + } +} + +impl From for ValueKey { + fn from(n: Node) -> Self { + Self(vec![], Either::Left(n)) + } +} + +impl ValueKey { + pub fn new(k: impl ValueName) -> Self { + Self(vec![], Either::Right(Arc::new(k))) + } + + pub fn index(self, i: usize) -> Self { + let mut is = self.0; + is.push(i); + Self(is, self.1) + } +} + +#[derive(Clone, Debug)] +pub struct ValueHandle(ValueKey, Arc); + +impl ValueHandle { + pub fn new(key: ValueKey, value: Arc) -> Self { + Self(key, value) + } + + pub fn value(&self) -> &Value { + self.1.as_ref() + } + + pub fn is_compound(&self) -> bool { + match self.value() { + Value::Sum { .. } | Value::Tuple { .. } => true, + _ => false, + } + } + + pub fn num_fields(&self) -> usize { + assert!( + self.is_compound(), + "ValueHandle::num_fields called on non-Sum, non-Tuple value: {:#?}", + self + ); + match self.value() { + Value::Sum { values, .. } => values.len(), + Value::Tuple { vs } => vs.len(), + _ => unreachable!(), + } + } + + pub fn tag(&self) -> usize { + assert!( + self.is_compound(), + "ValueHandle::tag called on non-Sum, non-Tuple value: {:#?}", + self + ); + match self.value() { + Value::Sum { tag, .. } => *tag, + Value::Tuple { .. } => 0, + _ => unreachable!(), + } + } + + pub fn index(self: &ValueHandle, i: usize) -> ValueHandle { + assert!( + i < self.num_fields(), + "ValueHandle::index called with out-of-bounds index {}: {:#?}", + i, + &self + ); + let vs = match self.value() { + Value::Sum { values, .. } => values, + Value::Tuple { vs, .. } => vs, + _ => unreachable!(), + }; + let v = vs[i].clone().into(); + Self(self.0.clone().index(i), v) + } +} + +impl PartialEq for ValueHandle { + fn eq(&self, other: &Self) -> bool { + // If the keys are equal, we return true since the values must have the + // same provenance, and so be equal. If the keys are different but the + // values are equal, we could return true if we didn't impl Eq, but + // since we do impl Eq, the Hash contract prohibits us from having equal + // values with different hashes. + let r = self.0 == other.0; + if r { + debug_assert_eq!(self.get_type(), other.get_type()); + } + r + } +} + +impl Eq for ValueHandle {} + +impl Hash for ValueHandle { + fn hash(&self, state: &mut I) { + self.0.hash(state); + } +} + +/// TODO this is perhaps dodgy +/// we do not hash or compare the value, just the key +/// this means two handles with different keys, but with the same value, will +/// not compare equal. +impl Deref for ValueHandle { + type Target = Value; + + fn deref(&self) -> &Self::Target { + self.value() + } +} + +#[cfg(test)] +mod test { + use crate::{ops::constant::CustomConst as _, types::SumType}; + + use super::*; + + #[test] + fn value_key_eq() { + let k1 = ValueKey::new("foo".to_string()); + let k2 = ValueKey::new("foo".to_string()); + let k3 = ValueKey::new("bar".to_string()); + + assert_eq!(k1, k2); + assert_ne!(k1, k3); + + let k4: ValueKey = From::::from(portgraph::NodeIndex::new(1).into()); + let k5 = From::::from(portgraph::NodeIndex::new(1).into()); + let k6 = From::::from(portgraph::NodeIndex::new(2).into()); + + assert_eq!(&k4, &k5); + assert_ne!(&k4, &k6); + + let k7 = k5.clone().index(3); + let k4 = k4.index(3); + + assert_eq!(&k4, &k7); + + let k5 = k5.index(2); + + assert_ne!(&k5, &k7); + } + + #[test] + fn value_handle_eq() { + let k_i = ConstInt::new_u(4, 2).unwrap(); + let subject_val = Arc::new( + Value::sum( + 0, + [k_i.clone().into()], + SumType::new([vec![k_i.get_type()], vec![]]), + ) + .unwrap(), + ); + + let k1 = ValueKey::new("foo".to_string()); + let v1 = ValueHandle::new(k1.clone(), subject_val.clone()); + let v2 = ValueHandle::new(k1.clone(), Value::extension(k_i).into()); + + // we do not compare the value, just the key + assert_ne!(v1.index(0), v2); + assert_eq!(v1.index(0).value(), v2.value()); + } +} diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 13dd47776..8949d8bd4 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,6 +1,7 @@ //! Compilation passes acting on the HUGR program representation. pub mod const_fold; +pub mod const_fold2; pub mod force_order; mod half_node; pub mod lower; From ac45e53ed2346eec5604ceb79f15ca4cff180a2f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 11 Sep 2024 16:16:29 +0100 Subject: [PATCH 002/203] merge/update+fmt (ValueName for ConstInt non-compiling as ConstInt not Hash) --- hugr-passes/Cargo.toml | 5 ++++ hugr-passes/src/const_fold2.rs | 2 +- hugr-passes/src/const_fold2/datalog.rs | 6 ++-- .../src/const_fold2/datalog/context.rs | 1 - hugr-passes/src/const_fold2/datalog/test.rs | 27 ++++++++--------- hugr-passes/src/const_fold2/datalog/utils.rs | 6 ++-- hugr-passes/src/const_fold2/partial_value.rs | 19 ++++++------ .../src/const_fold2/partial_value/test.rs | 17 ++++++----- .../const_fold2/partial_value/value_handle.rs | 29 ++++++++----------- 9 files changed, 54 insertions(+), 58 deletions(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index f0b09516d..a6ed580c3 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -14,6 +14,9 @@ categories = ["compilers"] [dependencies] hugr-core = { path = "../hugr-core", version = "0.9.1" } +portgraph = { workspace = true } +ascent = "0.6.0" +downcast-rs = { workspace = true } itertools = { workspace = true } lazy_static = { workspace = true } paste = { workspace = true } @@ -25,3 +28,5 @@ extension_inference = ["hugr-core/extension_inference"] [dev-dependencies] rstest = { workspace = true } +proptest = { workspace = true } +proptest-derive = { workspace = true } diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index dbe4464fd..96af004e1 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -1,2 +1,2 @@ mod datalog; -pub mod partial_value; \ No newline at end of file +pub mod partial_value; diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index d7df9c1e6..0aca8e9b8 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -1,14 +1,12 @@ use ascent::lattice::{ord_lattice::OrdLattice, BoundedLattice, Dual, Lattice}; -use delegate::delegate; use itertools::{zip_eq, Itertools}; use std::collections::HashMap; use std::hash::{Hash, Hasher}; use std::sync::{Arc, Mutex}; -use either::Either; +use super::partial_value::{PartialValue, ValueHandle, ValueKey}; use hugr_core::ops::{OpTag, OpTrait, Value}; -use hugr_core::partial_value::{PartialValue, ValueHandle, ValueKey}; -use hugr_core::types::{EdgeKind, FunctionType, SumType, Type, TypeEnum, TypeRow}; +use hugr_core::types::{EdgeKind, SumType, Type, TypeEnum, TypeRow}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; mod context; diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/datalog/context.rs index 92c0c3285..9117cc429 100644 --- a/hugr-passes/src/const_fold2/datalog/context.rs +++ b/hugr-passes/src/const_fold2/datalog/context.rs @@ -6,7 +6,6 @@ use std::sync::{Arc, Mutex}; use hugr_core::hugr::internal::HugrInternals; use hugr_core::ops::Value; -use hugr_core::partial_value::{ValueHandle, ValueKey}; use hugr_core::{Hugr, HugrView, Node}; use super::DFContext; diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index 4e086c4b7..5e70bf8b4 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -3,17 +3,17 @@ use hugr_core::{ extension::{prelude::BOOL_T, ExtensionSet, EMPTY_REG}, ops::{handle::NodeHandle, OpTrait, UnpackTuple, Value}, type_row, - types::{FunctionType, SumType}, + types::{Signature, SumType}, Extension, }; -use hugr_core::partial_value::PartialValue; +use crate::const_fold2::partial_value::PartialValue; use super::*; #[test] fn test_make_tuple() { - let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); let v1 = builder.add_load_value(Value::false_val()); let v2 = builder.add_load_value(Value::true_val()); let v3 = builder.make_tuple([v1, v2]).unwrap(); @@ -28,7 +28,7 @@ fn test_make_tuple() { #[test] fn test_unpack_tuple() { - let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); let v1 = builder.add_load_value(Value::false_val()); let v2 = builder.add_load_value(Value::true_val()); let v3 = builder.make_tuple([v1, v2]).unwrap(); @@ -49,7 +49,7 @@ fn test_unpack_tuple() { #[test] fn test_unpack_const() { - let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); let v1 = builder.add_load_value(Value::tuple([Value::true_val()])); let [o] = builder .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T]), [v1]) @@ -66,7 +66,7 @@ fn test_unpack_const() { #[test] fn test_tail_loop_never_iterates() { - let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); let r_v = Value::unit_sum(3, 6).unwrap(); let r_w = builder.add_load_value( Value::sum( @@ -98,7 +98,7 @@ fn test_tail_loop_never_iterates() { #[test] fn test_tail_loop_always_iterates() { - let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); let r_w = builder .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); let true_w = builder.add_load_value(Value::true_val()); @@ -130,7 +130,7 @@ fn test_tail_loop_always_iterates() { #[test] fn test_tail_loop_iterates_twice() { - let mut builder = DFGBuilder::new(FunctionType::new_endo(&[])).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); // let var_type = Type::new_sum([type_row![BOOL_T,BOOL_T], type_row![BOOL_T,BOOL_T]]); let true_w = builder.add_load_value(Value::true_val()); @@ -143,7 +143,7 @@ fn test_tail_loop_iterates_twice() { .unwrap(); assert_eq!( tlb.loop_signature().unwrap().dataflow_signature().unwrap(), - FunctionType::new_endo(type_row![BOOL_T, BOOL_T]) + Signature::new_endo(type_row![BOOL_T, BOOL_T]) ); let [in_w1, in_w2] = tlb.input_wires_arr(); let tail_loop = tlb.finish_with_outputs(in_w1, [in_w2, in_w1]).unwrap(); @@ -180,18 +180,15 @@ fn test_tail_loop_iterates_twice() { fn conditional() { let variants = vec![type_row![], type_row![], type_row![BOOL_T]]; let cond_t = Type::new_sum(variants.clone()); - let mut builder = DFGBuilder::new(FunctionType::new( - Into::::into(cond_t), - type_row![], - )) - .unwrap(); + let mut builder = + DFGBuilder::new(Signature::new(Into::::into(cond_t), type_row![])).unwrap(); let [arg_w] = builder.input_wires_arr(); let true_w = builder.add_load_value(Value::true_val()); let false_w = builder.add_load_value(Value::false_val()); let mut cond_builder = builder - .conditional_builder( + .conditional_builder_exts( (variants, arg_w), [(BOOL_T, true_w)], type_row!(BOOL_T, BOOL_T), diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 9c2e46ae3..31162a718 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -6,14 +6,14 @@ use std::{cmp::Ordering, ops::Index, sync::Arc}; use ascent::lattice::{BoundedLattice, Lattice}; -use either::Either; +use itertools::{zip_eq, Either}; + +use crate::const_fold2::partial_value::{PartialValue, ValueHandle}; use hugr_core::{ ops::OpTrait as _, - partial_value::{PartialValue, ValueHandle}, types::{EdgeKind, TypeRow}, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, }; -use itertools::zip_eq; #[cfg(test)] use proptest_derive::Arbitrary; diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs index 0442aa4c9..dafc48fce 100644 --- a/hugr-passes/src/const_fold2/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -3,10 +3,11 @@ use std::cmp::Ordering; use std::collections::HashMap; use std::hash::{Hash, Hasher}; +use hugr_core::ops::constant::Sum; use itertools::{zip_eq, Itertools as _}; -use crate::ops::Value; -use crate::types::{Type, TypeEnum}; +use hugr_core::ops::Value; +use hugr_core::types::{Type, TypeEnum, TypeRow}; mod value_handle; @@ -17,6 +18,9 @@ pub use value_handle::{ValueHandle, ValueKey}; struct PartialSum(HashMap>); impl PartialSum { + pub fn unit() -> Self { + Self::variant(0, []) + } pub fn variant(tag: usize, values: impl IntoIterator) -> Self { Self([(tag, values.into_iter().collect())].into_iter().collect()) } @@ -52,6 +56,9 @@ impl PartialSum { let Some(r) = st.get_variant(*k) else { Err(self)? }; + let Ok(r): Result = r.clone().try_into() else { + Err(self)? + }; if v.len() != r.len() { return Err(self); } @@ -165,13 +172,7 @@ impl TryFrom for PartialSum { fn try_from(value: ValueHandle) -> Result { match value.value() { - Value::Tuple { vs } => { - let vec = (0..vs.len()) - .map(|i| PartialValue::from(value.index(i)).into()) - .collect(); - return Ok(Self([(0, vec)].into_iter().collect())); - } - Value::Sum { tag, values, .. } => { + Value::Sum(Sum { tag, values, .. }) => { let vec = (0..values.len()) .map(|i| PartialValue::from(value.index(i)).into()) .collect(); diff --git a/hugr-passes/src/const_fold2/partial_value/test.rs b/hugr-passes/src/const_fold2/partial_value/test.rs index 35fbf5373..227d7aff7 100644 --- a/hugr-passes/src/const_fold2/partial_value/test.rs +++ b/hugr-passes/src/const_fold2/partial_value/test.rs @@ -4,12 +4,10 @@ use itertools::{zip_eq, Either, Itertools as _}; use lazy_static::lazy_static; use proptest::prelude::*; -use crate::{ +use hugr_core::{ ops::Value, - std_extensions::arithmetic::int_types::{ - self, get_log_width, ConstInt, INT_TYPES, LOG_WIDTH_BOUND, - }, - types::{CustomType, Type, TypeEnum}, + std_extensions::arithmetic::int_types::{self, ConstInt, INT_TYPES, LOG_WIDTH_BOUND}, + types::{CustomType, Type, TypeArg, TypeEnum}, }; use super::{PartialSum, PartialValue, ValueHandle, ValueKey}; @@ -71,10 +69,13 @@ impl TestSumLeafType { let TypeEnum::Extension(ct) = t.as_type_enum() else { unreachable!() }; - let lw = get_log_width(&ct.args()[0]).unwrap(); + // TODO this should be get_log_width, but that's not pub + let TypeArg::BoundedNat { n: lw } = ct.args()[0] else { + panic!() + }; (0u64..(1 << (2u64.pow(lw as u32) - 1))) .prop_map(move |x| { - let ki = ConstInt::new_u(lw, x).unwrap(); + let ki = ConstInt::new_u(lw as u8, x).unwrap(); ValueHandle::new(ValueKey::new(ki.clone()), Arc::new(ki.into())).into() }) .boxed() @@ -160,7 +161,7 @@ impl TestSumType { match self { TestSumType::Branch(_, sop) => Type::new_sum( sop.iter() - .map(|row| row.iter().map(|x| x.get_type()).collect_vec().into()), + .map(|row| row.iter().map(|x| x.get_type()).collect_vec()), ), TestSumType::Leaf(l) => l.get_type(), } diff --git a/hugr-passes/src/const_fold2/partial_value/value_handle.rs b/hugr-passes/src/const_fold2/partial_value/value_handle.rs index dfb019872..6a91d513a 100644 --- a/hugr-passes/src/const_fold2/partial_value/value_handle.rs +++ b/hugr-passes/src/const_fold2/partial_value/value_handle.rs @@ -4,11 +4,12 @@ use std::ops::Deref; use std::sync::Arc; use downcast_rs::Downcast; +use hugr_core::ops::constant::Sum; use itertools::Either; -use crate::ops::Value; -use crate::std_extensions::arithmetic::int_types::ConstInt; -use crate::Node; +use hugr_core::ops::Value; +use hugr_core::std_extensions::arithmetic::int_types::ConstInt; +use hugr_core::Node; pub trait ValueName: std::fmt::Debug + Downcast + Any { fn hash(&self) -> u64; @@ -106,10 +107,7 @@ impl ValueHandle { } pub fn is_compound(&self) -> bool { - match self.value() { - Value::Sum { .. } | Value::Tuple { .. } => true, - _ => false, - } + matches!(self.value(), Value::Sum(_)) } pub fn num_fields(&self) -> usize { @@ -119,8 +117,7 @@ impl ValueHandle { self ); match self.value() { - Value::Sum { values, .. } => values.len(), - Value::Tuple { vs } => vs.len(), + Value::Sum(Sum { values, .. }) => values.len(), _ => unreachable!(), } } @@ -132,8 +129,7 @@ impl ValueHandle { self ); match self.value() { - Value::Sum { tag, .. } => *tag, - Value::Tuple { .. } => 0, + Value::Sum(Sum { tag, .. }) => *tag, _ => unreachable!(), } } @@ -146,8 +142,7 @@ impl ValueHandle { &self ); let vs = match self.value() { - Value::Sum { values, .. } => values, - Value::Tuple { vs, .. } => vs, + Value::Sum(Sum { values, .. }) => values, _ => unreachable!(), }; let v = vs[i].clone().into(); @@ -192,7 +187,7 @@ impl Deref for ValueHandle { #[cfg(test)] mod test { - use crate::{ops::constant::CustomConst as _, types::SumType}; + use hugr_core::{ops::constant::CustomConst as _, types::SumType}; use super::*; @@ -205,9 +200,9 @@ mod test { assert_eq!(k1, k2); assert_ne!(k1, k3); - let k4: ValueKey = From::::from(portgraph::NodeIndex::new(1).into()); - let k5 = From::::from(portgraph::NodeIndex::new(1).into()); - let k6 = From::::from(portgraph::NodeIndex::new(2).into()); + let k4: ValueKey = Node::from(portgraph::NodeIndex::new(1)).into(); + let k5: ValueKey = Node::from(portgraph::NodeIndex::new(1)).into(); + let k6: ValueKey = Node::from(portgraph::NodeIndex::new(2)).into(); assert_eq!(&k4, &k5); assert_ne!(&k4, &k6); From 8adaa6e5ae809c958b65cedf0775efcdb1e15c66 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 14:58:55 +0100 Subject: [PATCH 003/203] Missing imports / lints. Now running, but failing w/StackOverflow --- hugr-passes/src/const_fold2/datalog.rs | 11 ++++------- hugr-passes/src/const_fold2/datalog/context.rs | 7 ++----- hugr-passes/src/const_fold2/datalog/test.rs | 7 +++---- hugr-passes/src/const_fold2/datalog/utils.rs | 4 +--- hugr-passes/src/const_fold2/partial_value/test.rs | 5 +---- 5 files changed, 11 insertions(+), 23 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 0aca8e9b8..7e30b29e6 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -1,12 +1,9 @@ -use ascent::lattice::{ord_lattice::OrdLattice, BoundedLattice, Dual, Lattice}; -use itertools::{zip_eq, Itertools}; +use ascent::lattice::BoundedLattice; use std::collections::HashMap; -use std::hash::{Hash, Hasher}; -use std::sync::{Arc, Mutex}; +use std::hash::Hash; -use super::partial_value::{PartialValue, ValueHandle, ValueKey}; -use hugr_core::ops::{OpTag, OpTrait, Value}; -use hugr_core::types::{EdgeKind, SumType, Type, TypeEnum, TypeRow}; +use super::partial_value::PartialValue; +use hugr_core::ops::Value; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; mod context; diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/datalog/context.rs index 9117cc429..81e3709c4 100644 --- a/hugr-passes/src/const_fold2/datalog/context.rs +++ b/hugr-passes/src/const_fold2/datalog/context.rs @@ -1,12 +1,9 @@ -use std::collections::HashMap; use std::hash::{Hash, Hasher}; use std::ops::Deref; -use std::sync::atomic::AtomicUsize; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use hugr_core::hugr::internal::HugrInternals; -use hugr_core::ops::Value; -use hugr_core::{Hugr, HugrView, Node}; +use hugr_core::{Hugr, HugrView}; use super::DFContext; diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index 5e70bf8b4..bea8db857 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -3,8 +3,7 @@ use hugr_core::{ extension::{prelude::BOOL_T, ExtensionSet, EMPTY_REG}, ops::{handle::NodeHandle, OpTrait, UnpackTuple, Value}, type_row, - types::{Signature, SumType}, - Extension, + types::{Signature, SumType, Type, TypeRow}, }; use crate::const_fold2::partial_value::PartialValue; @@ -58,7 +57,7 @@ fn test_unpack_const() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); - let c = machine.run_hugr(&hugr); + machine.run_hugr(&hugr); let o_r = machine.read_out_wire_value(&hugr, o).unwrap(); assert_eq!(o_r, Value::true_val()); @@ -161,7 +160,7 @@ fn test_tail_loop_iterates_twice() { let [o_w1, o_w2, _] = tail_loop.outputs_arr(); let mut machine = Machine::new(); - let c = machine.run_hugr(&hugr); + machine.run_hugr(&hugr); // dbg!(&machine.tail_loop_io_node); // dbg!(&machine.out_wire_value); diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 31162a718..5c2b12730 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -10,9 +10,7 @@ use itertools::{zip_eq, Either}; use crate::const_fold2::partial_value::{PartialValue, ValueHandle}; use hugr_core::{ - ops::OpTrait as _, - types::{EdgeKind, TypeRow}, - HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, + ops::OpTrait as _, types::TypeRow, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, }; #[cfg(test)] diff --git a/hugr-passes/src/const_fold2/partial_value/test.rs b/hugr-passes/src/const_fold2/partial_value/test.rs index 227d7aff7..5e3b861e3 100644 --- a/hugr-passes/src/const_fold2/partial_value/test.rs +++ b/hugr-passes/src/const_fold2/partial_value/test.rs @@ -1,13 +1,11 @@ use std::sync::Arc; use itertools::{zip_eq, Either, Itertools as _}; -use lazy_static::lazy_static; use proptest::prelude::*; use hugr_core::{ - ops::Value, std_extensions::arithmetic::int_types::{self, ConstInt, INT_TYPES, LOG_WIDTH_BOUND}, - types::{CustomType, Type, TypeArg, TypeEnum}, + types::{Type, TypeArg, TypeEnum}, }; use super::{PartialSum, PartialValue, ValueHandle, ValueKey}; @@ -141,7 +139,6 @@ impl TestSumType { TestSumType::Leaf(l) => { l.assert_invariants(); } - _ => (), } } From 098c7350c58fcf9f74b794db5cec2ffdce62f60e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 17:43:51 +0100 Subject: [PATCH 004/203] Fix tests... * DFContext reinstate fn hugr(), drop AsRef requirement (fixes StackOverflow) * test_tail_loop_iterates_twice: use tail_loop_builder_exts, fix from #1332(?) * Fix only-one-DataflowContext asserts using Arc::ptr_eq --- hugr-passes/src/const_fold2/datalog.rs | 18 ++++++++++-------- .../src/const_fold2/datalog/context.rs | 19 ++++++++----------- hugr-passes/src/const_fold2/datalog/test.rs | 9 +++++++-- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 7e30b29e6..96c4dd50c 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -12,7 +12,9 @@ mod utils; use context::DataflowContext; pub use utils::{TailLoopTermination, ValueRow, IO, PV}; -pub trait DFContext: AsRef + Clone + Eq + Hash + std::ops::Deref {} +pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { + fn hugr(&self) -> &impl HugrView; +} ascent::ascent! { // The trait-indirection layer here means we can just write 'C' but in practice ATM @@ -34,9 +36,9 @@ ascent::ascent! { node(c, n) <-- context(c), for n in c.nodes(); - in_wire(c, n,p) <-- node(c, n), for p in utils::value_inputs(c, *n); + in_wire(c, n,p) <-- node(c, n), for p in utils::value_inputs(c.hugr(), *n); - out_wire(c, n,p) <-- node(c, n), for p in utils::value_outputs(c, *n); + out_wire(c, n,p) <-- node(c, n), for p in utils::value_outputs(c.hugr(), *n); parent_of_node(c, parent, child) <-- node(c, child), if let Some(parent) = c.get_parent(*child); @@ -55,8 +57,8 @@ ascent::ascent! { out_wire_value(c, m, op, v); - node_in_value_row(c, n, utils::bottom_row(c, *n)) <-- node(c, n); - node_in_value_row(c, n, utils::singleton_in_row(c, n, p, v.clone())) <-- in_wire_value(c, n, p, v); + node_in_value_row(c, n, utils::bottom_row(c.hugr(), *n)) <-- node(c, n); + node_in_value_row(c, n, utils::singleton_in_row(c.hugr(), n, p, v.clone())) <-- in_wire_value(c, n, p, v); // Per node-type rules @@ -67,7 +69,7 @@ ascent::ascent! { relation load_constant_node(C, Node); load_constant_node(c, n) <-- node(c, n), if c.get_optype(*n).is_load_constant(); - out_wire_value(c, n, 0.into(), utils::partial_value_from_load_constant(c, *n)) <-- + out_wire_value(c, n, 0.into(), utils::partial_value_from_load_constant(c.hugr(), *n)) <-- load_constant_node(c, n); @@ -116,7 +118,7 @@ ascent::ascent! { if out_in_row[0].supports_tag(0), // if it is possible for tag to be 0 if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), let variant_len = tailloop.just_inputs.len(), - for (out_p, v) in out_in_row.iter(c, *out_n).flat_map( + for (out_p, v) in out_in_row.iter(c.hugr(), *out_n).flat_map( |(input_p, v)| utils::outputs_for_variant(input_p, 0, variant_len, v) ); @@ -127,7 +129,7 @@ ascent::ascent! { if out_in_row[0].supports_tag(1), // if it is possible for the tag to be 1 if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), let variant_len = tailloop.just_outputs.len(), - for (out_p, v) in out_in_row.iter(c, *out_n).flat_map( + for (out_p, v) in out_in_row.iter(c.hugr(), *out_n).flat_map( |(input_p, v)| utils::outputs_for_variant(input_p, 1, variant_len, v) ); diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/datalog/context.rs index 81e3709c4..1d77e39eb 100644 --- a/hugr-passes/src/const_fold2/datalog/context.rs +++ b/hugr-passes/src/const_fold2/datalog/context.rs @@ -2,7 +2,6 @@ use std::hash::{Hash, Hasher}; use std::ops::Deref; use std::sync::Arc; -use hugr_core::hugr::internal::HugrInternals; use hugr_core::{Hugr, HugrView}; use super::DFContext; @@ -25,13 +24,13 @@ impl Clone for DataflowContext { } impl Hash for DataflowContext { - fn hash(&self, state: &mut I) {} + fn hash(&self, _state: &mut I) {} } impl PartialEq for DataflowContext { fn eq(&self, other: &Self) -> bool { - // Any AscentProgram should have only one DataflowContext - assert_eq!(self as *const _, other as *const _); + // Any AscentProgram should have only one DataflowContext (maybe cloned) + assert!(Arc::ptr_eq(&self.0, &other.0)); true } } @@ -40,8 +39,8 @@ impl Eq for DataflowContext {} impl PartialOrd for DataflowContext { fn partial_cmp(&self, other: &Self) -> Option { - // Any AscentProgram should have only one DataflowContext - assert_eq!(self as *const _, other as *const _); + // Any AscentProgram should have only one DataflowContext (maybe cloned) + assert!(Arc::ptr_eq(&self.0, &other.0)); Some(std::cmp::Ordering::Equal) } } @@ -54,10 +53,8 @@ impl Deref for DataflowContext { } } -impl AsRef for DataflowContext { - fn as_ref(&self) -> &Hugr { - self.base_hugr() +impl DFContext for DataflowContext { + fn hugr(&self) -> &impl HugrView { + self.0.as_ref() } } - -impl DFContext for DataflowContext {} diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index bea8db857..783171525 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -138,7 +138,12 @@ fn test_tail_loop_iterates_twice() { // let r_w = builder // .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); let tlb = builder - .tail_loop_builder([], [(BOOL_T, false_w), (BOOL_T, true_w)], vec![].into()) + .tail_loop_builder_exts( + [], + [(BOOL_T, false_w), (BOOL_T, true_w)], + vec![].into(), + ExtensionSet::new(), + ) .unwrap(); assert_eq!( tlb.loop_signature().unwrap().dataflow_signature().unwrap(), @@ -157,7 +162,7 @@ fn test_tail_loop_iterates_twice() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); // TODO once we can do conditionals put these wires inside `just_outputs` and // we should be able to propagate their values - let [o_w1, o_w2, _] = tail_loop.outputs_arr(); + let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::new(); machine.run_hugr(&hugr); From 706c89208bea60c423cf7fa8960e5073499ccde6 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 18:00:04 +0100 Subject: [PATCH 005/203] ValueKey using MaybeHash --- .../src/const_fold2/partial_value/test.rs | 3 +- .../const_fold2/partial_value/value_handle.rs | 153 +++++++++--------- 2 files changed, 83 insertions(+), 73 deletions(-) diff --git a/hugr-passes/src/const_fold2/partial_value/test.rs b/hugr-passes/src/const_fold2/partial_value/test.rs index 5e3b861e3..6621f0a69 100644 --- a/hugr-passes/src/const_fold2/partial_value/test.rs +++ b/hugr-passes/src/const_fold2/partial_value/test.rs @@ -74,7 +74,8 @@ impl TestSumLeafType { (0u64..(1 << (2u64.pow(lw as u32) - 1))) .prop_map(move |x| { let ki = ConstInt::new_u(lw as u8, x).unwrap(); - ValueHandle::new(ValueKey::new(ki.clone()), Arc::new(ki.into())).into() + let k = ValueKey::try_new(ki.clone()).unwrap(); + ValueHandle::new(k, Arc::new(ki.into())).into() }) .boxed() } diff --git a/hugr-passes/src/const_fold2/partial_value/value_handle.rs b/hugr-passes/src/const_fold2/partial_value/value_handle.rs index 6a91d513a..5ffe5af21 100644 --- a/hugr-passes/src/const_fold2/partial_value/value_handle.rs +++ b/hugr-passes/src/const_fold2/partial_value/value_handle.rs @@ -1,96 +1,68 @@ -use std::any::Any; use std::hash::{DefaultHasher, Hash, Hasher}; use std::ops::Deref; use std::sync::Arc; -use downcast_rs::Downcast; -use hugr_core::ops::constant::Sum; -use itertools::Either; +use hugr_core::ops::constant::{CustomConst, Sum}; use hugr_core::ops::Value; -use hugr_core::std_extensions::arithmetic::int_types::ConstInt; use hugr_core::Node; -pub trait ValueName: std::fmt::Debug + Downcast + Any { - fn hash(&self) -> u64; - fn eq(&self, other: &dyn ValueName) -> bool; -} - -fn hash_hash(x: &impl Hash) -> u64 { - let mut hasher = DefaultHasher::new(); - x.hash(&mut hasher); - hasher.finish() +#[derive(Clone, Debug)] +pub struct HashedConst { + hash: u64, + val: Arc, } -fn value_name_eq(x: &T, other: &dyn ValueName) -> bool { - if let Some(other) = other.as_any().downcast_ref::() { - x == other - } else { - false +impl PartialEq for HashedConst { + fn eq(&self, other: &Self) -> bool { + self.hash == other.hash && self.val.equal_consts(other.val.as_ref()) } } -impl ValueName for String { - fn hash(&self) -> u64 { - hash_hash(self) - } +impl Eq for HashedConst {} - fn eq(&self, other: &dyn ValueName) -> bool { - value_name_eq(self, other) +impl Hash for HashedConst { + fn hash(&self, state: &mut H) { + state.write_u64(self.hash); } } -impl ValueName for ConstInt { - fn hash(&self) -> u64 { - hash_hash(self) - } - - fn eq(&self, other: &dyn ValueName) -> bool { - value_name_eq(self, other) - } +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum ValueKey { + Select(usize, Box), + Const(HashedConst), + Node(Node), } -#[derive(Clone, Debug)] -pub struct ValueKey(Vec, Either>); - -impl PartialEq for ValueKey { - fn eq(&self, other: &Self) -> bool { - self.0 == other.0 - && match (&self.1, &other.1) { - (Either::Left(ref n1), Either::Left(ref n2)) => n1 == n2, - (Either::Right(ref v1), Either::Right(ref v2)) => v1.eq(v2.as_ref()), - _ => false, - } +impl From for ValueKey { + fn from(n: Node) -> Self { + Self::Node(n) } } -impl Eq for ValueKey {} - -impl Hash for ValueKey { - fn hash(&self, state: &mut H) { - self.0.hash(state); - match &self.1 { - Either::Left(n) => (0, n).hash(state), - Either::Right(v) => (1, v.hash()).hash(state), - } +impl From for ValueKey { + fn from(value: HashedConst) -> Self { + Self::Const(value) } } -impl From for ValueKey { - fn from(n: Node) -> Self { - Self(vec![], Either::Left(n)) +impl ValueKey { + pub fn new(n: Node, k: impl CustomConst) -> Self { + Self::try_new(k).unwrap_or(Self::Node(n)) } -} -impl ValueKey { - pub fn new(k: impl ValueName) -> Self { - Self(vec![], Either::Right(Arc::new(k))) + pub fn try_new(cst: impl CustomConst) -> Option { + let mut hasher = DefaultHasher::new(); + cst.maybe_hash(&mut hasher).then(|| { + Self::Const(HashedConst { + hash: hasher.finish(), + val: Arc::new(cst), + }) + }) } pub fn index(self, i: usize) -> Self { - let mut is = self.0; - is.push(i); - Self(is, self.1) + Self::Select(i, Box::new(self)) } } @@ -187,22 +159,40 @@ impl Deref for ValueHandle { #[cfg(test)] mod test { - use hugr_core::{ops::constant::CustomConst as _, types::SumType}; + use hugr_core::{ + extension::prelude::ConstString, + ops::constant::CustomConst as _, + std_extensions::{ + arithmetic::{ + float_types::{ConstF64, FLOAT64_TYPE}, + int_types::{ConstInt, INT_TYPES}, + }, + collections::ListValue, + }, + types::SumType, + }; use super::*; #[test] fn value_key_eq() { - let k1 = ValueKey::new("foo".to_string()); - let k2 = ValueKey::new("foo".to_string()); - let k3 = ValueKey::new("bar".to_string()); + let n = Node::from(portgraph::NodeIndex::new(0)); + let n2: Node = portgraph::NodeIndex::new(1).into(); + let k1 = ValueKey::new(n, ConstString::new("foo".to_string())); + let k2 = ValueKey::new(n2, ConstString::new("foo".to_string())); + let k3 = ValueKey::new(n, ConstString::new("bar".to_string())); - assert_eq!(k1, k2); + assert_eq!(k1, k2); // Node ignored assert_ne!(k1, k3); - let k4: ValueKey = Node::from(portgraph::NodeIndex::new(1)).into(); - let k5: ValueKey = Node::from(portgraph::NodeIndex::new(1)).into(); - let k6: ValueKey = Node::from(portgraph::NodeIndex::new(2)).into(); + assert_eq!(ValueKey::from(n), ValueKey::from(n)); + let f = ConstF64::new(3.141); + assert_eq!(ValueKey::new(n, f.clone()), ValueKey::from(n)); + + assert_ne!(ValueKey::new(n, f.clone()), ValueKey::new(n2, f)); // Node taken into account + let k4 = ValueKey::from(n); + let k5 = ValueKey::from(n); + let k6: ValueKey = ValueKey::from(n2); assert_eq!(&k4, &k5); assert_ne!(&k4, &k6); @@ -217,6 +207,25 @@ mod test { assert_ne!(&k5, &k7); } + #[test] + fn value_key_list() { + let v1 = ConstInt::new_u(3, 3).unwrap(); + let v2 = ConstInt::new_u(4, 3).unwrap(); + let v3 = ConstF64::new(3.141); + + let n = Node::from(portgraph::NodeIndex::new(0)); + let n2: Node = portgraph::NodeIndex::new(1).into(); + + let lst = ListValue::new(INT_TYPES[0].clone(), [v1.into(), v2.into()]); + assert_eq!(ValueKey::new(n, lst.clone()), ValueKey::new(n2, lst)); + + let lst = ListValue::new(FLOAT64_TYPE, [v3.into()]); + assert_ne!( + ValueKey::new(n, lst.clone()), + ValueKey::new(n2, lst.clone()) + ); + } + #[test] fn value_handle_eq() { let k_i = ConstInt::new_u(4, 2).unwrap(); @@ -229,7 +238,7 @@ mod test { .unwrap(), ); - let k1 = ValueKey::new("foo".to_string()); + let k1 = ValueKey::try_new(ConstString::new("foo".to_string())).unwrap(); let v1 = ValueHandle::new(k1.clone(), subject_val.clone()); let v2 = ValueHandle::new(k1.clone(), Value::extension(k_i).into()); From 63bc944c86dcf98d98414fdfbf117d27817bd6dc Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 20:43:19 +0100 Subject: [PATCH 006/203] tag() does not refer to self.is_compound --- hugr-passes/src/const_fold2/partial_value.rs | 1 - hugr-passes/src/const_fold2/partial_value/value_handle.rs | 8 ++------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs index dafc48fce..b5018ce38 100644 --- a/hugr-passes/src/const_fold2/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -412,7 +412,6 @@ impl PartialValue { } /// TODO docs - /// just delegate to variant_field_value pub fn tuple_field_value(&self, idx: usize) -> Self { self.variant_field_value(0, idx) } diff --git a/hugr-passes/src/const_fold2/partial_value/value_handle.rs b/hugr-passes/src/const_fold2/partial_value/value_handle.rs index 5ffe5af21..ff3a1fa16 100644 --- a/hugr-passes/src/const_fold2/partial_value/value_handle.rs +++ b/hugr-passes/src/const_fold2/partial_value/value_handle.rs @@ -95,14 +95,10 @@ impl ValueHandle { } pub fn tag(&self) -> usize { - assert!( - self.is_compound(), - "ValueHandle::tag called on non-Sum, non-Tuple value: {:#?}", - self - ); match self.value() { Value::Sum(Sum { tag, .. }) => *tag, - _ => unreachable!(), + _ => panic!("ValueHandle::tag called on non-Sum, non-Tuple value: {:#?}", + self), } } From 5fa7edbcd6816df80462f478c6f4e5d1c692e2ad Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 21:13:04 +0100 Subject: [PATCH 007/203] ValueHandle::{is_compound,num_fields,index} => {variant_values, as_sum} --- hugr-passes/src/const_fold2/partial_value.rs | 30 ++++-------- .../const_fold2/partial_value/value_handle.rs | 47 +++++++------------ 2 files changed, 27 insertions(+), 50 deletions(-) diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs index b5018ce38..2e01108e5 100644 --- a/hugr-passes/src/const_fold2/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -1,11 +1,9 @@ #![allow(missing_docs)] +use itertools::{zip_eq, Itertools as _}; use std::cmp::Ordering; use std::collections::HashMap; use std::hash::{Hash, Hasher}; -use hugr_core::ops::constant::Sum; -use itertools::{zip_eq, Itertools as _}; - use hugr_core::ops::Value; use hugr_core::types::{Type, TypeEnum, TypeRow}; @@ -22,7 +20,7 @@ impl PartialSum { Self::variant(0, []) } pub fn variant(tag: usize, values: impl IntoIterator) -> Self { - Self([(tag, values.into_iter().collect())].into_iter().collect()) + Self(HashMap::from([(tag, Vec::from_iter(values))])) } pub fn num_variants(&self) -> usize { @@ -171,16 +169,10 @@ impl TryFrom for PartialSum { type Error = ValueHandle; fn try_from(value: ValueHandle) -> Result { - match value.value() { - Value::Sum(Sum { tag, values, .. }) => { - let vec = (0..values.len()) - .map(|i| PartialValue::from(value.index(i)).into()) - .collect(); - return Ok(Self([(*tag, vec)].into_iter().collect())); - } - _ => (), - }; - Err(value) + value + .as_sum() + .map(|(tag, values)| Self::variant(tag, values.into_iter().map(PartialValue::from))) + .ok_or(value) } } @@ -421,13 +413,9 @@ impl PartialValue { match self { Self::Bottom => Self::Bottom, Self::PartialSum(ps) => ps.variant_field_value(variant, idx), - Self::Value(v) => { - if v.tag() == variant { - Self::Value(v.index(idx)) - } else { - Self::Bottom - } - } + Self::Value(v) => v + .variant_values(variant) + .map_or(Self::Bottom, |vals| Self::Value(vals[idx].clone())), Self::Top => Self::Top, } } diff --git a/hugr-passes/src/const_fold2/partial_value/value_handle.rs b/hugr-passes/src/const_fold2/partial_value/value_handle.rs index ff3a1fa16..ae5facbc5 100644 --- a/hugr-passes/src/const_fold2/partial_value/value_handle.rs +++ b/hugr-passes/src/const_fold2/partial_value/value_handle.rs @@ -78,44 +78,32 @@ impl ValueHandle { self.1.as_ref() } - pub fn is_compound(&self) -> bool { - matches!(self.value(), Value::Sum(_)) + pub fn variant_values(&self, variant: usize) -> Option> { + self.as_sum() + .and_then(|(tag, vals)| (tag == variant).then_some(vals)) } - pub fn num_fields(&self) -> usize { - assert!( - self.is_compound(), - "ValueHandle::num_fields called on non-Sum, non-Tuple value: {:#?}", - self - ); + pub fn as_sum(&self) -> Option<(usize, Vec)> { match self.value() { - Value::Sum(Sum { values, .. }) => values.len(), - _ => unreachable!(), + Value::Sum(Sum { tag, values, .. }) => { + let vals = values.iter().cloned().map(Arc::new); + let keys = (0..).map(|i| self.0.clone().index(i)); + let vec = keys.zip(vals).map(|(i, v)| Self(i, v)).collect(); + Some((*tag, vec)) + } + _ => None, } } pub fn tag(&self) -> usize { match self.value() { Value::Sum(Sum { tag, .. }) => *tag, - _ => panic!("ValueHandle::tag called on non-Sum, non-Tuple value: {:#?}", - self), + _ => panic!( + "ValueHandle::tag called on non-Sum, non-Tuple value: {:#?}", + self + ), } } - - pub fn index(self: &ValueHandle, i: usize) -> ValueHandle { - assert!( - i < self.num_fields(), - "ValueHandle::index called with out-of-bounds index {}: {:#?}", - i, - &self - ); - let vs = match self.value() { - Value::Sum(Sum { values, .. }) => values, - _ => unreachable!(), - }; - let v = vs[i].clone().into(); - Self(self.0.clone().index(i), v) - } } impl PartialEq for ValueHandle { @@ -238,8 +226,9 @@ mod test { let v1 = ValueHandle::new(k1.clone(), subject_val.clone()); let v2 = ValueHandle::new(k1.clone(), Value::extension(k_i).into()); + let (_, fields) = v1.as_sum().unwrap(); // we do not compare the value, just the key - assert_ne!(v1.index(0), v2); - assert_eq!(v1.index(0).value(), v2.value()); + assert_ne!(fields[0], v2); + assert_eq!(fields[0].value(), v2.value()); } } From 98bf94a3b35a5a9f422f03d328900b6224086ff9 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 21:16:46 +0100 Subject: [PATCH 008/203] Rm ValueHandle::tag, use variant_values - inefficient, presume this is what was meant --- hugr-passes/src/const_fold2/partial_value.rs | 2 +- .../src/const_fold2/partial_value/value_handle.rs | 10 ---------- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs index 2e01108e5..0c7c5b4f2 100644 --- a/hugr-passes/src/const_fold2/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -397,7 +397,7 @@ impl PartialValue { pub fn supports_tag(&self, tag: usize) -> bool { match self { PartialValue::Bottom => false, - PartialValue::Value(v) => v.tag() == tag, // can never be a sum or tuple + PartialValue::Value(v) => v.variant_values(tag).is_some(), PartialValue::PartialSum(ps) => ps.supports_tag(tag), PartialValue::Top => true, } diff --git a/hugr-passes/src/const_fold2/partial_value/value_handle.rs b/hugr-passes/src/const_fold2/partial_value/value_handle.rs index ae5facbc5..728caeb33 100644 --- a/hugr-passes/src/const_fold2/partial_value/value_handle.rs +++ b/hugr-passes/src/const_fold2/partial_value/value_handle.rs @@ -94,16 +94,6 @@ impl ValueHandle { _ => None, } } - - pub fn tag(&self) -> usize { - match self.value() { - Value::Sum(Sum { tag, .. }) => *tag, - _ => panic!( - "ValueHandle::tag called on non-Sum, non-Tuple value: {:#?}", - self - ), - } - } } impl PartialEq for ValueHandle { From 295ec3277e180e141ae2a5fded88d894f3ce8848 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 22:09:18 +0100 Subject: [PATCH 009/203] add variant_values, rewrite one use of outputs_for_variant --- hugr-passes/src/const_fold2/datalog.rs | 9 +++------ hugr-passes/src/const_fold2/datalog/utils.rs | 16 ++++++++++++++- hugr-passes/src/const_fold2/partial_value.rs | 21 ++++++++++++++++++++ 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 96c4dd50c..1cf85f18a 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -115,12 +115,9 @@ ascent::ascent! { io_node(c,tl_n,in_n, IO::Input), io_node(c,tl_n,out_n, IO::Output), node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node - if out_in_row[0].supports_tag(0), // if it is possible for tag to be 0 if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), - let variant_len = tailloop.just_inputs.len(), - for (out_p, v) in out_in_row.iter(c.hugr(), *out_n).flat_map( - |(input_p, v)| utils::outputs_for_variant(input_p, 0, variant_len, v) - ); + if let Some(fields) = out_in_row[0].variant_values(0, tailloop.just_inputs.len()), // if it is possible for tag to be 0 + for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields.into_iter().chain(out_in_row.iter().skip(1).cloned())); // Output node of child region propagate to outputs of tail loop out_wire_value(c, tl_n, out_p, v) <-- tail_loop_node(c, tl_n), @@ -129,7 +126,7 @@ ascent::ascent! { if out_in_row[0].supports_tag(1), // if it is possible for the tag to be 1 if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), let variant_len = tailloop.just_outputs.len(), - for (out_p, v) in out_in_row.iter(c.hugr(), *out_n).flat_map( + for (out_p, v) in out_in_row.iter_with_ports(c.hugr(), *out_n).flat_map( |(input_p, v)| utils::outputs_for_variant(input_p, 1, variant_len, v) ); diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 5c2b12730..05ceccf16 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -30,6 +30,16 @@ impl PV { self.variant_field_value(0, idx) } + pub fn variant_values(&self, variant: usize, len: usize) -> Option> { + Some( + self.0 + .variant_values(variant, len)? + .into_iter() + .map(PV::from) + .collect(), + ) + } + /// TODO the arguments here are not pretty, two usizes, better not mix them /// up!!! pub fn variant_field_value(&self, variant: usize, idx: usize) -> Self { @@ -104,7 +114,11 @@ impl ValueRow { Self::new(r.len()) } - pub fn iter<'b>( + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } + + pub fn iter_with_ports<'b>( &'b self, h: &'b impl HugrView, n: Node, diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs index 0c7c5b4f2..e7e532fdf 100644 --- a/hugr-passes/src/const_fold2/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -34,6 +34,12 @@ impl PartialSum { } } + pub fn variant_values(&self, variant: usize, len: usize) -> Option> { + let row = self.0.get(&variant)?; + assert!(row.len() == len); + Some(row.clone()) + } + pub fn variant_field_value(&self, variant: usize, idx: usize) -> PartialValue { if let Some(row) = self.0.get(&variant) { assert!(row.len() > idx); @@ -394,6 +400,21 @@ impl PartialValue { Self::variant(0, []) } + pub fn variant_values(&self, tag: usize, len: usize) -> Option> { + let vals = match self { + PartialValue::Bottom => return None, + PartialValue::Value(v) => v + .variant_values(tag)? + .into_iter() + .map(PartialValue::Value) + .collect(), + PartialValue::PartialSum(ps) => ps.variant_values(tag, len)?, + PartialValue::Top => vec![PartialValue::Top; len], + }; + assert_eq!(vals.len(), len); + Some(vals) + } + pub fn supports_tag(&self, tag: usize) -> bool { match self { PartialValue::Bottom => false, From 5c8289e88f94c8fac2c3a4755f768bd5d89a97ce Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 22:18:52 +0100 Subject: [PATCH 010/203] ...and the other two; remove outputs_for_variant --- hugr-passes/src/const_fold2/datalog.rs | 17 ++++---- hugr-passes/src/const_fold2/datalog/utils.rs | 42 -------------------- 2 files changed, 9 insertions(+), 50 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 1cf85f18a..71dee7483 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -117,17 +117,17 @@ ascent::ascent! { node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), if let Some(fields) = out_in_row[0].variant_values(0, tailloop.just_inputs.len()), // if it is possible for tag to be 0 - for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields.into_iter().chain(out_in_row.iter().skip(1).cloned())); + for (out_p, v) in (0..).map(OutgoingPort::from).zip( + fields.into_iter().chain(out_in_row.iter().skip(1).cloned())); // Output node of child region propagate to outputs of tail loop out_wire_value(c, tl_n, out_p, v) <-- tail_loop_node(c, tl_n), io_node(c,tl_n,out_n, IO::Output), node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node - if out_in_row[0].supports_tag(1), // if it is possible for the tag to be 1 if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), - let variant_len = tailloop.just_outputs.len(), - for (out_p, v) in out_in_row.iter_with_ports(c.hugr(), *out_n).flat_map( - |(input_p, v)| utils::outputs_for_variant(input_p, 1, variant_len, v) + if let Some(fields) = out_in_row[0].variant_values(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 + for (out_p, v) in (0..).map(OutgoingPort::from).zip( + fields.into_iter().chain(out_in_row.iter().skip(1).cloned()) ); lattice tail_loop_termination(C,Node,TailLoopTermination); @@ -152,10 +152,11 @@ ascent::ascent! { out_wire_value(c, i_node, i_p, v) <-- case_node(c, cond, case_index, case), io_node(c, case, i_node, IO::Input), - in_wire_value(c, cond, cond_in_p, cond_in_v), + node_in_value_row(c, cond, in_row), + //in_wire_value(c, cond, cond_in_p, cond_in_v), if let Some(conditional) = c.get_optype(*cond).as_conditional(), - let variant_len = conditional.sum_rows[*case_index].len(), - for (i_p, v) in utils::outputs_for_variant(*cond_in_p, *case_index, variant_len, cond_in_v); + if let Some(fields) = in_row[0].variant_values(*case_index, conditional.sum_rows[*case_index].len()), + for (i_p, v) in (0..).map(OutgoingPort::from).zip(fields.into_iter().chain(in_row.iter().skip(1).cloned())); // outputs of case nodes propagate to outputs of conditional out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <-- diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 05ceccf16..d41e4bceb 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -233,48 +233,6 @@ pub(super) fn value_outputs(h: &impl HugrView, n: Node) -> impl Iterator( - output_p: IncomingPort, - variant_tag: usize, - variant_len: usize, - v: &'a PV, -) -> impl Iterator + 'a { - if output_p.index() == 0 { - Either::Left( - (0..variant_len).map(move |i| (i.into(), v.variant_field_value(variant_tag, i))), - ) - } else { - let v = if v.supports_tag(variant_tag) { - v.clone() - } else { - PV::bottom() - }; - Either::Right(std::iter::once(( - (variant_len + output_p.index() - 1).into(), - v, - ))) - } -} - #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] #[cfg_attr(test, derive(Arbitrary))] pub enum TailLoopTermination { From bf173ab4ae14279c0c32dee70f6cd80bda128ec1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 22:26:22 +0100 Subject: [PATCH 011/203] Rewrite tuple rule to avoid indexing --- hugr-passes/src/const_fold2/datalog.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 71dee7483..0fc56241c 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -85,10 +85,11 @@ ascent::ascent! { relation unpack_tuple_node(C, Node); unpack_tuple_node(c,n) <-- node(c, n), if c.get_optype(*n).is_unpack_tuple(); - out_wire_value(c, n, p, v.tuple_field_value(p.index())) <-- + out_wire_value(c, n, p, v) <-- unpack_tuple_node(c, n), - in_wire_value(c, n, IncomingPort::from(0), v), - out_wire(c, n, p); + in_wire_value(c, n, IncomingPort::from(0), tup), + if let Some(fields) = tup.variant_values(0, utils::value_outputs(c.hugr(),*n).count()), + for (p,v) in (0..).map(OutgoingPort::from).zip(fields); // DFG From 0ae4d196046d99c9ceedec904c06f54328e00eff Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 22:28:44 +0100 Subject: [PATCH 012/203] GC unused (tuple,variant)_field_value, iter_with_ports --- hugr-passes/src/const_fold2/datalog/utils.rs | 18 ------------- hugr-passes/src/const_fold2/partial_value.rs | 27 +------------------- 2 files changed, 1 insertion(+), 44 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index d41e4bceb..bebb741f8 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -26,10 +26,6 @@ impl From for PV { } impl PV { - pub fn tuple_field_value(&self, idx: usize) -> Self { - self.variant_field_value(0, idx) - } - pub fn variant_values(&self, variant: usize, len: usize) -> Option> { Some( self.0 @@ -40,12 +36,6 @@ impl PV { ) } - /// TODO the arguments here are not pretty, two usizes, better not mix them - /// up!!! - pub fn variant_field_value(&self, variant: usize, idx: usize) -> Self { - self.0.variant_field_value(variant, idx).into() - } - pub fn supports_tag(&self, tag: usize) -> bool { self.0.supports_tag(tag) } @@ -118,14 +108,6 @@ impl ValueRow { self.0.iter() } - pub fn iter_with_ports<'b>( - &'b self, - h: &'b impl HugrView, - n: Node, - ) -> impl Iterator + 'b { - zip_eq(value_inputs(h, n), self.0.iter()) - } - // fn initialised(&self) -> bool { // self.0.iter().all(|x| x != &PV::top()) // } diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs index e7e532fdf..4337e5bb6 100644 --- a/hugr-passes/src/const_fold2/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -40,15 +40,6 @@ impl PartialSum { Some(row.clone()) } - pub fn variant_field_value(&self, variant: usize, idx: usize) -> PartialValue { - if let Some(row) = self.0.get(&variant) { - assert!(row.len() > idx); - row[idx].clone() - } else { - PartialValue::bottom() - } - } - pub fn try_into_value(self, typ: &Type) -> Result { let Ok((k, v)) = self.0.iter().exactly_one() else { Err(self)? @@ -418,28 +409,12 @@ impl PartialValue { pub fn supports_tag(&self, tag: usize) -> bool { match self { PartialValue::Bottom => false, + // TODO this is wildly expensive - only used for case reachability but still... PartialValue::Value(v) => v.variant_values(tag).is_some(), PartialValue::PartialSum(ps) => ps.supports_tag(tag), PartialValue::Top => true, } } - - /// TODO docs - pub fn tuple_field_value(&self, idx: usize) -> Self { - self.variant_field_value(0, idx) - } - - /// TODO docs - pub fn variant_field_value(&self, variant: usize, idx: usize) -> Self { - match self { - Self::Bottom => Self::Bottom, - Self::PartialSum(ps) => ps.variant_field_value(variant, idx), - Self::Value(v) => v - .variant_values(variant) - .map_or(Self::Bottom, |vals| Self::Value(vals[idx].clone())), - Self::Top => Self::Top, - } - } } impl PartialOrd for PartialValue { From 8608ba9a9a7e2400f964e878207956f5b23d029e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 22:33:53 +0100 Subject: [PATCH 013/203] Common up via ValueRow.unpack_first --- hugr-passes/src/const_fold2/datalog.rs | 15 ++++++--------- hugr-passes/src/const_fold2/datalog/utils.rs | 10 ++++++++++ 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 0fc56241c..2944503c7 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -117,19 +117,16 @@ ascent::ascent! { io_node(c,tl_n,out_n, IO::Output), node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), - if let Some(fields) = out_in_row[0].variant_values(0, tailloop.just_inputs.len()), // if it is possible for tag to be 0 - for (out_p, v) in (0..).map(OutgoingPort::from).zip( - fields.into_iter().chain(out_in_row.iter().skip(1).cloned())); + if let Some(fields) = out_in_row.unpack_first(0, tailloop.just_inputs.len()), // if it is possible for tag to be 0 + for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields); // Output node of child region propagate to outputs of tail loop out_wire_value(c, tl_n, out_p, v) <-- tail_loop_node(c, tl_n), io_node(c,tl_n,out_n, IO::Output), node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), - if let Some(fields) = out_in_row[0].variant_values(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 - for (out_p, v) in (0..).map(OutgoingPort::from).zip( - fields.into_iter().chain(out_in_row.iter().skip(1).cloned()) - ); + if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 + for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields); lattice tail_loop_termination(C,Node,TailLoopTermination); tail_loop_termination(c,tl_n,TailLoopTermination::bottom()) <-- @@ -156,8 +153,8 @@ ascent::ascent! { node_in_value_row(c, cond, in_row), //in_wire_value(c, cond, cond_in_p, cond_in_v), if let Some(conditional) = c.get_optype(*cond).as_conditional(), - if let Some(fields) = in_row[0].variant_values(*case_index, conditional.sum_rows[*case_index].len()), - for (i_p, v) in (0..).map(OutgoingPort::from).zip(fields.into_iter().chain(in_row.iter().skip(1).cloned())); + if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), + for (i_p, v) in (0..).map(OutgoingPort::from).zip(fields); // outputs of case nodes propagate to outputs of conditional out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <-- diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index bebb741f8..5a9ac8495 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -108,6 +108,16 @@ impl ValueRow { self.0.iter() } + pub fn unpack_first( + &self, + variant: usize, + len: usize, + ) -> Option + '_> { + self[0] + .variant_values(variant, len) + .map(|vals| vals.into_iter().chain(self.iter().skip(1).cloned())) + } + // fn initialised(&self) -> bool { // self.0.iter().all(|x| x != &PV::top()) // } From 2dca3e9fcb06d73cf7c6c8728d11a5a80412433a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 19:25:23 +0100 Subject: [PATCH 014/203] No DeRef for ValueHandle, just add get_type() --- .../const_fold2/partial_value/value_handle.rs | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/hugr-passes/src/const_fold2/partial_value/value_handle.rs b/hugr-passes/src/const_fold2/partial_value/value_handle.rs index 728caeb33..6a5d9dd81 100644 --- a/hugr-passes/src/const_fold2/partial_value/value_handle.rs +++ b/hugr-passes/src/const_fold2/partial_value/value_handle.rs @@ -1,10 +1,10 @@ use std::hash::{DefaultHasher, Hash, Hasher}; -use std::ops::Deref; use std::sync::Arc; use hugr_core::ops::constant::{CustomConst, Sum}; use hugr_core::ops::Value; +use hugr_core::types::Type; use hugr_core::Node; #[derive(Clone, Debug)] @@ -94,6 +94,10 @@ impl ValueHandle { _ => None, } } + + pub fn get_type(&self) -> Type { + self.1.get_type() + } } impl PartialEq for ValueHandle { @@ -119,18 +123,6 @@ impl Hash for ValueHandle { } } -/// TODO this is perhaps dodgy -/// we do not hash or compare the value, just the key -/// this means two handles with different keys, but with the same value, will -/// not compare equal. -impl Deref for ValueHandle { - type Target = Value; - - fn deref(&self) -> &Self::Target { - self.value() - } -} - #[cfg(test)] mod test { use hugr_core::{ From 51e68ea99a44addad85d4aaee18d425b1d52829d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 6 Aug 2024 19:27:22 +0100 Subject: [PATCH 015/203] ValueKey::{Select->Field,index->field} --- .../src/const_fold2/partial_value/value_handle.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/const_fold2/partial_value/value_handle.rs b/hugr-passes/src/const_fold2/partial_value/value_handle.rs index 6a5d9dd81..3b450c178 100644 --- a/hugr-passes/src/const_fold2/partial_value/value_handle.rs +++ b/hugr-passes/src/const_fold2/partial_value/value_handle.rs @@ -29,7 +29,7 @@ impl Hash for HashedConst { #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum ValueKey { - Select(usize, Box), + Field(usize, Box), Const(HashedConst), Node(Node), } @@ -61,8 +61,8 @@ impl ValueKey { }) } - pub fn index(self, i: usize) -> Self { - Self::Select(i, Box::new(self)) + pub fn field(self, i: usize) -> Self { + Self::Field(i, Box::new(self)) } } @@ -87,7 +87,7 @@ impl ValueHandle { match self.value() { Value::Sum(Sum { tag, values, .. }) => { let vals = values.iter().cloned().map(Arc::new); - let keys = (0..).map(|i| self.0.clone().index(i)); + let keys = (0..).map(|i| self.0.clone().field(i)); let vec = keys.zip(vals).map(|(i, v)| Self(i, v)).collect(); Some((*tag, vec)) } @@ -163,12 +163,12 @@ mod test { assert_eq!(&k4, &k5); assert_ne!(&k4, &k6); - let k7 = k5.clone().index(3); - let k4 = k4.index(3); + let k7 = k5.clone().field(3); + let k4 = k4.field(3); assert_eq!(&k4, &k7); - let k5 = k5.index(2); + let k5 = k5.field(2); assert_ne!(&k5, &k7); } From 863547413877b7ca3e305453da1726ef30a5b792 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 7 Aug 2024 11:24:02 +0100 Subject: [PATCH 016/203] (join/meet)_mut_unsafe => try_(join/meet)_mut with Err for conflicting len --- hugr-passes/src/const_fold2/partial_value.rs | 43 +++++++++++++++----- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs index 4337e5bb6..dc0ea005b 100644 --- a/hugr-passes/src/const_fold2/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -66,12 +66,16 @@ impl PartialSum { } } - // unsafe because we panic if any common rows have different lengths - fn join_mut_unsafe(&mut self, other: Self) -> bool { + // Err with key if any common rows have different lengths (self may have been mutated) + fn try_join_mut(&mut self, other: Self) -> Result { let mut changed = false; for (k, v) in other.0 { if let Some(row) = self.0.get_mut(&k) { + if v.len() != row.len() { + // Better to check first and avoid mutation, but fine here + return Err(k); + } for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { changed |= lhs.join_mut(rhs); } @@ -80,16 +84,21 @@ impl PartialSum { changed = true; } } - changed + Ok(changed) } - // unsafe because we panic if any common rows have different lengths - fn meet_mut_unsafe(&mut self, other: Self) -> bool { + // Error with key if any common rows have different lengths ( => Bottom) + fn try_meet_mut(&mut self, other: Self) -> Result { let mut changed = false; let mut keys_to_remove = vec![]; - for k in self.0.keys() { - if !other.0.contains_key(k) { - keys_to_remove.push(*k); + for (k, v) in self.0.iter() { + match other.0.get(k) { + None => keys_to_remove.push(*k), + Some(o_v) => { + if v.len() != o_v.len() { + return Err(*k); + } + } } } for (k, v) in other.0 { @@ -105,7 +114,7 @@ impl PartialSum { self.0.remove(&k); changed = true; } - changed + Ok(changed) } pub fn supports_tag(&self, tag: usize) -> bool { @@ -304,7 +313,13 @@ impl PartialValue { let Self::PartialSum(ps1) = self else { unreachable!() }; - ps1.join_mut_unsafe(ps2) + match ps1.try_join_mut(ps2) { + Ok(ch) => ch, + Err(_) => { + *self = Self::Top; + true + } + } } (Self::Value(_), mut other) => { std::mem::swap(self, &mut other); @@ -354,7 +369,13 @@ impl PartialValue { let Self::PartialSum(ps1) = self else { unreachable!() }; - ps1.meet_mut_unsafe(ps2) + match ps1.try_meet_mut(ps2) { + Ok(ch) => ch, + Err(_) => { + *self = Self::Bottom; + true + } + } } (Self::Value(_), mut other @ Self::PartialSum(_)) => { std::mem::swap(self, &mut other); From 80d5b866903e06e1aa3cc341cf4309e409b1bbc5 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 8 Aug 2024 15:02:03 +0100 Subject: [PATCH 017/203] Remove ValueHandle::variant_values - just have as_sum --- hugr-passes/src/const_fold2/partial_value.rs | 6 ++++-- hugr-passes/src/const_fold2/partial_value/value_handle.rs | 5 ----- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs index dc0ea005b..4bb56222d 100644 --- a/hugr-passes/src/const_fold2/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -416,7 +416,9 @@ impl PartialValue { let vals = match self { PartialValue::Bottom => return None, PartialValue::Value(v) => v - .variant_values(tag)? + .as_sum() + .filter(|(variant, _)| tag == *variant)? + .1 .into_iter() .map(PartialValue::Value) .collect(), @@ -431,7 +433,7 @@ impl PartialValue { match self { PartialValue::Bottom => false, // TODO this is wildly expensive - only used for case reachability but still... - PartialValue::Value(v) => v.variant_values(tag).is_some(), + PartialValue::Value(v) => v.as_sum().is_some_and(|(v, _)| v == tag), PartialValue::PartialSum(ps) => ps.supports_tag(tag), PartialValue::Top => true, } diff --git a/hugr-passes/src/const_fold2/partial_value/value_handle.rs b/hugr-passes/src/const_fold2/partial_value/value_handle.rs index 3b450c178..6a4d70a60 100644 --- a/hugr-passes/src/const_fold2/partial_value/value_handle.rs +++ b/hugr-passes/src/const_fold2/partial_value/value_handle.rs @@ -78,11 +78,6 @@ impl ValueHandle { self.1.as_ref() } - pub fn variant_values(&self, variant: usize) -> Option> { - self.as_sum() - .and_then(|(tag, vals)| (tag == variant).then_some(vals)) - } - pub fn as_sum(&self) -> Option<(usize, Vec)> { match self.value() { Value::Sum(Sum { tag, values, .. }) => { From 1c8be9989380588ed4c48c6220bc2c05d5fa4101 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 8 Aug 2024 15:06:25 +0100 Subject: [PATCH 018/203] Optimize as_sum() by returning impl Iterator not Vec --- hugr-passes/src/const_fold2/partial_value.rs | 4 +--- .../const_fold2/partial_value/value_handle.rs | 17 +++++++++-------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs index 4bb56222d..3ba8e7c57 100644 --- a/hugr-passes/src/const_fold2/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -177,7 +177,7 @@ impl TryFrom for PartialSum { fn try_from(value: ValueHandle) -> Result { value .as_sum() - .map(|(tag, values)| Self::variant(tag, values.into_iter().map(PartialValue::from))) + .map(|(tag, values)| Self::variant(tag, values.map(PartialValue::from))) .ok_or(value) } } @@ -419,7 +419,6 @@ impl PartialValue { .as_sum() .filter(|(variant, _)| tag == *variant)? .1 - .into_iter() .map(PartialValue::Value) .collect(), PartialValue::PartialSum(ps) => ps.variant_values(tag, len)?, @@ -432,7 +431,6 @@ impl PartialValue { pub fn supports_tag(&self, tag: usize) -> bool { match self { PartialValue::Bottom => false, - // TODO this is wildly expensive - only used for case reachability but still... PartialValue::Value(v) => v.as_sum().is_some_and(|(v, _)| v == tag), PartialValue::PartialSum(ps) => ps.supports_tag(tag), PartialValue::Top => true, diff --git a/hugr-passes/src/const_fold2/partial_value/value_handle.rs b/hugr-passes/src/const_fold2/partial_value/value_handle.rs index 6a4d70a60..147048ae7 100644 --- a/hugr-passes/src/const_fold2/partial_value/value_handle.rs +++ b/hugr-passes/src/const_fold2/partial_value/value_handle.rs @@ -78,14 +78,15 @@ impl ValueHandle { self.1.as_ref() } - pub fn as_sum(&self) -> Option<(usize, Vec)> { + pub fn as_sum(&self) -> Option<(usize, impl Iterator + '_)> { match self.value() { - Value::Sum(Sum { tag, values, .. }) => { - let vals = values.iter().cloned().map(Arc::new); - let keys = (0..).map(|i| self.0.clone().field(i)); - let vec = keys.zip(vals).map(|(i, v)| Self(i, v)).collect(); - Some((*tag, vec)) - } + Value::Sum(Sum { tag, values, .. }) => Some(( + *tag, + values + .iter() + .enumerate() + .map(|(i, v)| Self(self.0.clone().field(i), Arc::new(v.clone()))), + )), _ => None, } } @@ -203,7 +204,7 @@ mod test { let v1 = ValueHandle::new(k1.clone(), subject_val.clone()); let v2 = ValueHandle::new(k1.clone(), Value::extension(k_i).into()); - let (_, fields) = v1.as_sum().unwrap(); + let fields = v1.as_sum().unwrap().1.collect::>(); // we do not compare the value, just the key assert_ne!(fields[0], v2); assert_eq!(fields[0].value(), v2.value()); From b0afa54aab5c94cca1e84b44a4734c6aaee5db94 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 27 Aug 2024 17:27:09 +0100 Subject: [PATCH 019/203] Machine uses PV not PartialValue --- hugr-passes/src/const_fold2/datalog.rs | 10 +++------- hugr-passes/src/const_fold2/datalog/test.rs | 6 +++--- hugr-passes/src/const_fold2/datalog/utils.rs | 8 +++++++- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 2944503c7..04a63a608 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -2,7 +2,6 @@ use ascent::lattice::BoundedLattice; use std::collections::HashMap; use std::hash::Hash; -use super::partial_value::PartialValue; use hugr_core::ops::Value; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; @@ -170,10 +169,7 @@ ascent::ascent! { } // TODO This should probably be called 'Analyser' or something -struct Machine( - AscentProgram>, - Option>, -); +struct Machine(AscentProgram>, Option>); /// Usage: /// 1. [Self::new()] @@ -185,7 +181,7 @@ impl Machine { Self(Default::default(), None) } - pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator) { + pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator) { assert!(self.1.is_none()); self.0.out_wire_value_proto.extend( wires @@ -207,7 +203,7 @@ impl Machine { ) } - pub fn read_out_wire_partial_value(&self, w: Wire) -> Option { + pub fn read_out_wire_partial_value(&self, w: Wire) -> Option { self.1.as_ref().unwrap().get(&w).cloned() } diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index 783171525..2f3ad5d5d 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -118,9 +118,9 @@ fn test_tail_loop_always_iterates() { machine.run_hugr(&hugr); let o_r1 = machine.read_out_wire_partial_value(tl_o1).unwrap(); - assert_eq!(o_r1, PartialValue::bottom()); + assert_eq!(o_r1, PartialValue::bottom().into()); let o_r2 = machine.read_out_wire_partial_value(tl_o2).unwrap(); - assert_eq!(o_r2, PartialValue::bottom()); + assert_eq!(o_r2, PartialValue::bottom().into()); assert_eq!( TailLoopTermination::bottom(), machine.tail_loop_terminates(&hugr, tail_loop.node()) @@ -220,7 +220,7 @@ fn conditional() { let mut machine = Machine::new(); let arg_pv = PartialValue::variant(1, []).join(PartialValue::variant(2, [PartialValue::variant(0, [])])); - machine.propolutate_out_wires([(arg_w, arg_pv)]); + machine.propolutate_out_wires([(arg_w, arg_pv.into())]); machine.run_hugr(&hugr); let cond_r1 = machine.read_out_wire_value(&hugr, cond_o1).unwrap(); diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 5a9ac8495..e10e96ed3 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -10,7 +10,9 @@ use itertools::{zip_eq, Either}; use crate::const_fold2::partial_value::{PartialValue, ValueHandle}; use hugr_core::{ - ops::OpTrait as _, types::TypeRow, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, + ops::{OpTrait as _, Value}, + types::{Type, TypeRow}, + HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, }; #[cfg(test)] @@ -39,6 +41,10 @@ impl PV { pub fn supports_tag(&self, tag: usize) -> bool { self.0.supports_tag(tag) } + + pub fn try_into_value(self, ty: &Type) -> Result { + self.0.try_into_value(ty).map_err(Self) + } } impl From for PartialValue { From d09a1fe769e477d82533e1b42ba9d6a00f0d90bb Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 Aug 2024 12:15:40 +0100 Subject: [PATCH 020/203] Parametrize PartialValue+PV+Machine by AbstractValue/Into, Context interprets load_constant --- hugr-passes/src/const_fold2.rs | 5 +- hugr-passes/src/const_fold2/datalog.rs | 61 ++++--- .../src/const_fold2/datalog/context.rs | 15 +- hugr-passes/src/const_fold2/datalog/test.rs | 18 +- hugr-passes/src/const_fold2/datalog/utils.rs | 110 ++++++------ hugr-passes/src/const_fold2/partial_value.rs | 161 +++++++++--------- .../src/const_fold2/partial_value/test.rs | 20 ++- .../{partial_value => }/value_handle.rs | 21 ++- 8 files changed, 227 insertions(+), 184 deletions(-) rename hugr-passes/src/const_fold2/{partial_value => }/value_handle.rs (95%) diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index 96af004e1..13af5c709 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -1,2 +1,3 @@ -mod datalog; -pub mod partial_value; +pub mod datalog; +mod partial_value; +pub mod value_handle; diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 04a63a608..c06b2a285 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -8,11 +8,13 @@ use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _ mod context; mod utils; -use context::DataflowContext; pub use utils::{TailLoopTermination, ValueRow, IO, PV}; -pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { +use super::partial_value::AbstractValue; + +pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { fn hugr(&self) -> &impl HugrView; + fn value_from_load_constant(&self, node: Node) -> V; } ascent::ascent! { @@ -20,18 +22,18 @@ ascent::ascent! { // DataflowContext (for H: HugrView) would be sufficient, there's really no // point in using anything else yet. However DFContext will be useful when we // move interpretation of nodes out into a trait method. - struct AscentProgram; + struct AscentProgram>; relation context(C); - relation out_wire_value_proto(Node, OutgoingPort, PV); + relation out_wire_value_proto(Node, OutgoingPort, PV); relation node(C, Node); relation in_wire(C, Node, IncomingPort); relation out_wire(C, Node, OutgoingPort); relation parent_of_node(C, Node, Node); relation io_node(C, Node, Node, IO); - lattice out_wire_value(C, Node, OutgoingPort, PV); - lattice node_in_value_row(C, Node, ValueRow); - lattice in_wire_value(C, Node, IncomingPort, PV); + lattice out_wire_value(C, Node, OutgoingPort, PV); + lattice node_in_value_row(C, Node, ValueRow); + lattice in_wire_value(C, Node, IncomingPort, PV); node(c, n) <-- context(c), for n in c.nodes(); @@ -68,7 +70,7 @@ ascent::ascent! { relation load_constant_node(C, Node); load_constant_node(c, n) <-- node(c, n), if c.get_optype(*n).is_load_constant(); - out_wire_value(c, n, 0.into(), utils::partial_value_from_load_constant(c.hugr(), *n)) <-- + out_wire_value(c, n, 0.into(), PV::from(c.value_from_load_constant(*n))) <-- load_constant_node(c, n); @@ -169,19 +171,22 @@ ascent::ascent! { } // TODO This should probably be called 'Analyser' or something -struct Machine(AscentProgram>, Option>); +pub struct Machine>( + AscentProgram, + Option>>, +); /// Usage: /// 1. [Self::new()] /// 2. Zero or more [Self::propolutate_out_wires] with initial values /// 3. Exactly one [Self::run_hugr] to do the analysis /// 4. Results then available via [Self::read_out_wire_partial_value] and [Self::read_out_wire_value] -impl Machine { +impl> Machine { pub fn new() -> Self { Self(Default::default(), None) } - pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator) { + pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator)>) { assert!(self.1.is_none()); self.0.out_wire_value_proto.extend( wires @@ -190,9 +195,9 @@ impl Machine { ); } - pub fn run_hugr(&mut self, hugr: H) { + pub fn run(&mut self, context: C) { assert!(self.1.is_none()); - self.0.context.push((DataflowContext::new(hugr),)); + self.0.context.push((context,)); self.0.run(); self.1 = Some( self.0 @@ -203,22 +208,11 @@ impl Machine { ) } - pub fn read_out_wire_partial_value(&self, w: Wire) -> Option { + pub fn read_out_wire_partial_value(&self, w: Wire) -> Option> { self.1.as_ref().unwrap().get(&w).cloned() } - pub fn read_out_wire_value(&self, hugr: H, w: Wire) -> Option { - // dbg!(&w); - let pv = self.read_out_wire_partial_value(w)?; - // dbg!(&pv); - let (_, typ) = hugr - .out_value_types(w.node()) - .find(|(p, _)| *p == w.source()) - .unwrap(); - pv.try_into_value(&typ).ok() - } - - pub fn tail_loop_terminates(&self, hugr: H, node: Node) -> TailLoopTermination { + pub fn tail_loop_terminates(&self, hugr: impl HugrView, node: Node) -> TailLoopTermination { assert!(hugr.get_optype(node).is_tail_loop()); self.0 .tail_loop_termination @@ -227,7 +221,7 @@ impl Machine { .unwrap() } - pub fn case_reachable(&self, hugr: H, case: Node) -> bool { + pub fn case_reachable(&self, hugr: impl HugrView, case: Node) -> bool { assert!(hugr.get_optype(case).is_case()); let cond = hugr.get_parent(case).unwrap(); assert!(hugr.get_optype(cond).is_conditional()); @@ -239,5 +233,18 @@ impl Machine { } } +impl, C: DFContext> Machine { + pub fn read_out_wire_value(&self, hugr: impl HugrView, w: Wire) -> Option { + // dbg!(&w); + let pv = self.read_out_wire_partial_value(w)?; + // dbg!(&pv); + let (_, typ) = hugr + .out_value_types(w.node()) + .find(|(p, _)| *p == w.source()) + .unwrap(); + pv.try_into_value(&typ).ok() + } +} + #[cfg(test)] mod test; diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/datalog/context.rs index 1d77e39eb..31a3233fd 100644 --- a/hugr-passes/src/const_fold2/datalog/context.rs +++ b/hugr-passes/src/const_fold2/datalog/context.rs @@ -4,6 +4,8 @@ use std::sync::Arc; use hugr_core::{Hugr, HugrView}; +use crate::const_fold2::value_handle::ValueHandle; + use super::DFContext; #[derive(Debug)] @@ -53,8 +55,19 @@ impl Deref for DataflowContext { } } -impl DFContext for DataflowContext { +impl DFContext for DataflowContext { fn hugr(&self) -> &impl HugrView { self.0.as_ref() } + + fn value_from_load_constant(&self, node: hugr_core::Node) -> ValueHandle { + let load_op = self.0.get_optype(node).as_load_constant().unwrap(); + let const_node = self + .0 + .single_linked_output(node, load_op.constant_port()) + .unwrap() + .0; + let const_op = self.0.get_optype(const_node).as_const().unwrap(); + ValueHandle::new(const_node.into(), Arc::new(const_op.value().clone())) + } } diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index 2f3ad5d5d..f80cc903e 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -1,3 +1,4 @@ +use context::DataflowContext; use hugr_core::{ builder::{DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{prelude::BOOL_T, ExtensionSet, EMPTY_REG}, @@ -6,8 +7,7 @@ use hugr_core::{ types::{Signature, SumType, Type, TypeRow}, }; -use crate::const_fold2::partial_value::PartialValue; - +use super::super::partial_value::PartialValue; use super::*; #[test] @@ -19,7 +19,7 @@ fn test_make_tuple() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); - machine.run_hugr(&hugr); + machine.run(DataflowContext::new(&hugr)); let x = machine.read_out_wire_value(&hugr, v3).unwrap(); assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); @@ -38,7 +38,7 @@ fn test_unpack_tuple() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); - machine.run_hugr(&hugr); + machine.run(DataflowContext::new(&hugr)); let o1_r = machine.read_out_wire_value(&hugr, o1).unwrap(); assert_eq!(o1_r, Value::false_val()); @@ -57,7 +57,7 @@ fn test_unpack_const() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); - machine.run_hugr(&hugr); + machine.run(DataflowContext::new(&hugr)); let o_r = machine.read_out_wire_value(&hugr, o).unwrap(); assert_eq!(o_r, Value::true_val()); @@ -83,7 +83,7 @@ fn test_tail_loop_never_iterates() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); - machine.run_hugr(&hugr); + machine.run(DataflowContext::new(&hugr)); // dbg!(&machine.tail_loop_io_node); // dbg!(&machine.out_wire_value); @@ -115,7 +115,7 @@ fn test_tail_loop_always_iterates() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); - machine.run_hugr(&hugr); + machine.run(DataflowContext::new(&hugr)); let o_r1 = machine.read_out_wire_partial_value(tl_o1).unwrap(); assert_eq!(o_r1, PartialValue::bottom().into()); @@ -165,7 +165,7 @@ fn test_tail_loop_iterates_twice() { let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::new(); - machine.run_hugr(&hugr); + machine.run(DataflowContext::new(&hugr)); // dbg!(&machine.tail_loop_io_node); // dbg!(&machine.out_wire_value); @@ -221,7 +221,7 @@ fn conditional() { let arg_pv = PartialValue::variant(1, []).join(PartialValue::variant(2, [PartialValue::variant(0, [])])); machine.propolutate_out_wires([(arg_w, arg_pv.into())]); - machine.run_hugr(&hugr); + machine.run(DataflowContext::new(&hugr)); let cond_r1 = machine.read_out_wire_value(&hugr, cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index e10e96ed3..42138396e 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -3,32 +3,40 @@ // https://github.com/proptest-rs/proptest/issues/447 #![cfg_attr(test, allow(non_local_definitions))] -use std::{cmp::Ordering, ops::Index, sync::Arc}; +use std::{cmp::Ordering, ops::Index}; use ascent::lattice::{BoundedLattice, Lattice}; -use itertools::{zip_eq, Either}; +use itertools::zip_eq; -use crate::const_fold2::partial_value::{PartialValue, ValueHandle}; +use crate::const_fold2::partial_value::{AbstractValue, PartialValue}; use hugr_core::{ ops::{OpTrait as _, Value}, - types::{Type, TypeRow}, + types::{Signature, Type, TypeRow}, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, }; #[cfg(test)] use proptest_derive::Arbitrary; -#[derive(PartialEq, Eq, Hash, PartialOrd, Clone, Debug)] -pub struct PV(PartialValue); +#[derive(PartialEq, Eq, Hash, Clone, Debug)] +pub struct PV(PartialValue); -impl From for PV { - fn from(inner: PartialValue) -> Self { +// Implement manually as PartialValue is PartialOrd even when V isn't +// (deriving PartialOrd conditions on V: PartialOrd, which is not necessary) +impl PartialOrd for PV { + fn partial_cmp(&self, other: &Self) -> Option { + self.0.partial_cmp(&other.0) + } +} + +impl From> for PV { + fn from(inner: PartialValue) -> Self { Self(inner) } } -impl PV { - pub fn variant_values(&self, variant: usize, len: usize) -> Option> { +impl PV { + pub fn variant_values(&self, variant: usize, len: usize) -> Option>> { Some( self.0 .variant_values(variant, len)? @@ -41,25 +49,27 @@ impl PV { pub fn supports_tag(&self, tag: usize) -> bool { self.0.supports_tag(tag) } +} +impl> PV { pub fn try_into_value(self, ty: &Type) -> Result { self.0.try_into_value(ty).map_err(Self) } } -impl From for PartialValue { - fn from(value: PV) -> Self { +impl From> for PartialValue { + fn from(value: PV) -> Self { value.0 } } -impl From for PV { - fn from(inner: ValueHandle) -> Self { +impl From for PV { + fn from(inner: V) -> Self { Self(inner.into()) } } -impl Lattice for PV { +impl Lattice for PV { fn meet(self, other: Self) -> Self { self.0.meet(other.0).into() } @@ -77,7 +87,7 @@ impl Lattice for PV { } } -impl BoundedLattice for PV { +impl BoundedLattice for PV { fn bottom() -> Self { PartialValue::bottom().into() } @@ -87,22 +97,22 @@ impl BoundedLattice for PV { } } -#[derive(PartialEq, Clone, Eq, Hash, PartialOrd)] -pub struct ValueRow(Vec); +#[derive(PartialEq, Clone, Eq, Hash)] +pub struct ValueRow(Vec>); -impl ValueRow { +impl ValueRow { fn new(len: usize) -> Self { Self(vec![PV::bottom(); len]) } - fn singleton(len: usize, idx: usize, v: PV) -> Self { + fn singleton(len: usize, idx: usize, v: PV) -> Self { assert!(idx < len); let mut r = Self::new(len); r.0[idx] = v; r } - fn singleton_from_row(r: &TypeRow, idx: usize, v: PV) -> Self { + fn singleton_from_row(r: &TypeRow, idx: usize, v: PV) -> Self { Self::singleton(r.len(), idx, v) } @@ -110,7 +120,7 @@ impl ValueRow { Self::new(r.len()) } - pub fn iter(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator> { self.0.iter() } @@ -118,7 +128,7 @@ impl ValueRow { &self, variant: usize, len: usize, - ) -> Option + '_> { + ) -> Option> + '_> { self[0] .variant_values(variant, len) .map(|vals| vals.into_iter().chain(self.iter().skip(1).cloned())) @@ -129,7 +139,13 @@ impl ValueRow { // } } -impl Lattice for ValueRow { +impl PartialOrd for ValueRow { + fn partial_cmp(&self, other: &Self) -> Option { + self.0.partial_cmp(&other.0) + } +} + +impl Lattice for ValueRow { fn meet(mut self, other: Self) -> Self { self.meet_mut(other); self @@ -159,36 +175,42 @@ impl Lattice for ValueRow { } } -impl IntoIterator for ValueRow { - type Item = PV; +impl IntoIterator for ValueRow { + type Item = PV; - type IntoIter = as IntoIterator>::IntoIter; + type IntoIter = > as IntoIterator>::IntoIter; fn into_iter(self) -> Self::IntoIter { self.0.into_iter() } } -impl Index for ValueRow +impl Index for ValueRow where - Vec: Index, + Vec>: Index, { - type Output = as Index>::Output; + type Output = > as Index>::Output; fn index(&self, index: Idx) -> &Self::Output { self.0.index(index) } } -pub(super) fn bottom_row(h: &impl HugrView, n: Node) -> ValueRow { - if let Some(sig) = h.signature(n) { - ValueRow::new(sig.input_count()) - } else { - ValueRow::new(0) - } +pub(super) fn bottom_row(h: &impl HugrView, n: Node) -> ValueRow { + ValueRow::new( + h.signature(n) + .as_ref() + .map(Signature::input_count) + .unwrap_or(0), + ) } -pub(super) fn singleton_in_row(h: &impl HugrView, n: &Node, ip: &IncomingPort, v: PV) -> ValueRow { +pub(super) fn singleton_in_row( + h: &impl HugrView, + n: &Node, + ip: &IncomingPort, + v: PV, +) -> ValueRow { let Some(sig) = h.signature(*n) else { panic!("dougrulz"); }; @@ -203,17 +225,7 @@ pub(super) fn singleton_in_row(h: &impl HugrView, n: &Node, ip: &IncomingPort, v ValueRow::singleton_from_row(&h.signature(*n).unwrap().input, ip.index(), v) } -pub(super) fn partial_value_from_load_constant(h: &impl HugrView, node: Node) -> PV { - let load_op = h.get_optype(node).as_load_constant().unwrap(); - let const_node = h - .single_linked_output(node, load_op.constant_port()) - .unwrap() - .0; - let const_op = h.get_optype(const_node).as_const().unwrap(); - ValueHandle::new(const_node.into(), Arc::new(const_op.value().clone())).into() -} - -pub(super) fn partial_value_tuple_from_value_row(r: ValueRow) -> PV { +pub(super) fn partial_value_tuple_from_value_row(r: ValueRow) -> PV { PartialValue::variant(0, r.into_iter().map(|x| x.0)).into() } @@ -240,7 +252,7 @@ pub enum TailLoopTermination { } impl TailLoopTermination { - pub fn from_control_value(v: &PV) -> Self { + pub fn from_control_value(v: &PV) -> Self { let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); if may_break && !may_continue { Self::ExactlyZeroContinues diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs index 3ba8e7c57..6a2a614a8 100644 --- a/hugr-passes/src/const_fold2/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -1,32 +1,34 @@ #![allow(missing_docs)] -use itertools::{zip_eq, Itertools as _}; -use std::cmp::Ordering; -use std::collections::HashMap; -use std::hash::{Hash, Hasher}; use hugr_core::ops::Value; use hugr_core::types::{Type, TypeEnum, TypeRow}; +use itertools::{zip_eq, Itertools}; +use std::cmp::Ordering; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; -mod value_handle; - -pub use value_handle::{ValueHandle, ValueKey}; +pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { + fn as_sum(&self) -> Option<(usize, impl Iterator + '_)>; +} -// TODO ALAN inline into PartialValue +// TODO ALAN inline into PartialValue? Has to be public as it's in a pub enum #[derive(PartialEq, Clone, Eq)] -struct PartialSum(HashMap>); +pub struct PartialSum(pub HashMap>>); -impl PartialSum { +impl PartialSum { pub fn unit() -> Self { Self::variant(0, []) } - pub fn variant(tag: usize, values: impl IntoIterator) -> Self { + pub fn variant(tag: usize, values: impl IntoIterator>) -> Self { Self(HashMap::from([(tag, Vec::from_iter(values))])) } pub fn num_variants(&self) -> usize { self.0.len() } +} +impl PartialSum { fn assert_variants(&self) { assert_ne!(self.num_variants(), 0); for pv in self.0.values().flat_map(|x| x.iter()) { @@ -34,38 +36,6 @@ impl PartialSum { } } - pub fn variant_values(&self, variant: usize, len: usize) -> Option> { - let row = self.0.get(&variant)?; - assert!(row.len() == len); - Some(row.clone()) - } - - pub fn try_into_value(self, typ: &Type) -> Result { - let Ok((k, v)) = self.0.iter().exactly_one() else { - Err(self)? - }; - - let TypeEnum::Sum(st) = typ.as_type_enum() else { - Err(self)? - }; - let Some(r) = st.get_variant(*k) else { - Err(self)? - }; - let Ok(r): Result = r.clone().try_into() else { - Err(self)? - }; - if v.len() != r.len() { - return Err(self); - } - match zip_eq(v.into_iter(), r.into_iter()) - .map(|(v, t)| v.clone().try_into_value(t)) - .collect::, _>>() - { - Ok(vs) => Value::sum(*k, vs, st.clone()).map_err(|_| self), - Err(_) => Err(self), - } - } - // Err with key if any common rows have different lengths (self may have been mutated) fn try_join_mut(&mut self, other: Self) -> Result { let mut changed = false; @@ -122,7 +92,43 @@ impl PartialSum { } } -impl PartialOrd for PartialSum { +impl> PartialSum { + pub fn try_into_value(self, typ: &Type) -> Result { + let Ok((k, v)) = self.0.iter().exactly_one() else { + Err(self)? + }; + + let TypeEnum::Sum(st) = typ.as_type_enum() else { + Err(self)? + }; + let Some(r) = st.get_variant(*k) else { + Err(self)? + }; + let Ok(r): Result = r.clone().try_into() else { + Err(self)? + }; + if v.len() != r.len() { + return Err(self); + } + match zip_eq(v.into_iter(), r.into_iter()) + .map(|(v, t)| v.clone().try_into_value(t)) + .collect::, _>>() + { + Ok(vs) => Value::sum(*k, vs, st.clone()).map_err(|_| self), + Err(_) => Err(self), + } + } +} + +impl PartialSum { + pub fn variant_values(&self, variant: usize, len: usize) -> Option>> { + let row = self.0.get(&variant)?; + assert!(row.len() == len); + Some(row.clone()) + } +} + +impl PartialOrd for PartialSum { fn partial_cmp(&self, other: &Self) -> Option { let max_key = self.0.keys().chain(other.0.keys()).copied().max().unwrap(); let (mut keys1, mut keys2) = (vec![0; max_key + 1], vec![0; max_key + 1]); @@ -156,13 +162,13 @@ impl PartialOrd for PartialSum { } } -impl std::fmt::Debug for PartialSum { +impl std::fmt::Debug for PartialSum { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) } } -impl Hash for PartialSum { +impl Hash for PartialSum { fn hash(&self, state: &mut H) { for (k, v) in &self.0 { k.hash(state); @@ -171,38 +177,29 @@ impl Hash for PartialSum { } } -impl TryFrom for PartialSum { - type Error = ValueHandle; - - fn try_from(value: ValueHandle) -> Result { - value - .as_sum() - .map(|(tag, values)| Self::variant(tag, values.map(PartialValue::from))) - .ok_or(value) - } -} - #[derive(PartialEq, Clone, Eq, Hash, Debug)] -pub enum PartialValue { +pub enum PartialValue { Bottom, - Value(ValueHandle), - PartialSum(PartialSum), + Value(V), + PartialSum(PartialSum), Top, } -impl From for PartialValue { - fn from(v: ValueHandle) -> Self { - TryInto::::try_into(v).map_or_else(Self::Value, Self::PartialSum) +impl From for PartialValue { + fn from(v: V) -> Self { + v.as_sum() + .map(|(tag, values)| Self::variant(tag, values.map(Self::Value))) + .unwrap_or(Self::Value(v)) } } -impl From for PartialValue { - fn from(v: PartialSum) -> Self { +impl From> for PartialValue { + fn from(v: PartialSum) -> Self { Self::PartialSum(v) } } -impl PartialValue { +impl PartialValue { // const BOTTOM: Self = Self::Bottom; // const BOTTOM_REF: &'static Self = &Self::BOTTOM; @@ -220,23 +217,13 @@ impl PartialValue { ps.assert_variants(); } Self::Value(v) => { - assert!(matches!(v.clone().into(), Self::Value(_))) + assert!(v.as_sum().is_none()) } _ => {} } } - pub fn try_into_value(self, typ: &Type) -> Result { - let r = match self { - Self::Value(v) => Ok(v.value().clone()), - Self::PartialSum(ps) => ps.try_into_value(typ).map_err(Self::PartialSum), - x => Err(x), - }?; - assert_eq!(typ, &r.get_type()); - Ok(r) - } - - fn join_mut_value_handle(&mut self, vh: ValueHandle) -> bool { + fn join_mut_value_handle(&mut self, vh: V) -> bool { self.assert_invariants(); match &*self { Self::Top => return false, @@ -257,7 +244,7 @@ impl PartialValue { true } - fn meet_mut_value_handle(&mut self, vh: ValueHandle) -> bool { + fn meet_mut_value_handle(&mut self, vh: V) -> bool { self.assert_invariants(); match &*self { Self::Bottom => false, @@ -412,7 +399,7 @@ impl PartialValue { Self::variant(0, []) } - pub fn variant_values(&self, tag: usize, len: usize) -> Option> { + pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { let vals = match self { PartialValue::Bottom => return None, PartialValue::Value(v) => v @@ -438,7 +425,19 @@ impl PartialValue { } } -impl PartialOrd for PartialValue { +impl> PartialValue { + pub fn try_into_value(self, typ: &Type) -> Result { + let r = match self { + Self::Value(v) => Ok(v.into().clone()), + Self::PartialSum(ps) => ps.try_into_value(typ).map_err(Self::PartialSum), + x => Err(x), + }?; + assert_eq!(typ, &r.get_type()); + Ok(r) + } +} + +impl PartialOrd for PartialValue { fn partial_cmp(&self, other: &Self) -> Option { use std::cmp::Ordering; match (self, other) { diff --git a/hugr-passes/src/const_fold2/partial_value/test.rs b/hugr-passes/src/const_fold2/partial_value/test.rs index 6621f0a69..33c8f3c8d 100644 --- a/hugr-passes/src/const_fold2/partial_value/test.rs +++ b/hugr-passes/src/const_fold2/partial_value/test.rs @@ -8,7 +8,9 @@ use hugr_core::{ types::{Type, TypeArg, TypeEnum}, }; -use super::{PartialSum, PartialValue, ValueHandle, ValueKey}; +use super::{PartialSum, PartialValue}; +use crate::const_fold2::value_handle::{ValueHandle, ValueKey}; + impl Arbitrary for ValueHandle { type Parameters = (); type Strategy = BoxedStrategy; @@ -48,7 +50,7 @@ impl TestSumLeafType { } } - fn type_check(&self, ps: &PartialSum) -> bool { + fn type_check(&self, ps: &PartialSum) -> bool { match self { Self::Int(_) => false, Self::Unit => { @@ -61,7 +63,7 @@ impl TestSumLeafType { } } - fn partial_value_strategy(self) -> impl Strategy { + fn partial_value_strategy(self) -> impl Strategy> { match self { Self::Int(t) => { let TypeEnum::Extension(ct) = t.as_type_enum() else { @@ -165,7 +167,7 @@ impl TestSumType { } } - fn type_check(&self, pv: &PartialValue) -> bool { + fn type_check(&self, pv: &PartialValue) -> bool { match (self, pv) { (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, (_, PartialValue::Value(v)) => self.get_type() == v.get_type(), @@ -253,7 +255,7 @@ proptest! { } } -fn any_partial_value_of_type(ust: TestSumType) -> impl Strategy { +fn any_partial_value_of_type(ust: TestSumType) -> impl Strategy> { ust.select().prop_flat_map(|x| match x { Either::Left(l) => l.partial_value_strategy().boxed(), Either::Right((index, usts)) => { @@ -273,15 +275,15 @@ fn any_partial_value_of_type(ust: TestSumType) -> impl Strategy::Parameters, -) -> impl Strategy { +) -> impl Strategy> { any_with::(params).prop_flat_map(any_partial_value_of_type) } -fn any_partial_value() -> impl Strategy { +fn any_partial_value() -> impl Strategy> { any_partial_value_with(Default::default()) } -fn any_partial_values() -> impl Strategy { +fn any_partial_values() -> impl Strategy; N]> { any::().prop_flat_map(|ust| { TryInto::<[_; N]>::try_into( (0..N) @@ -292,7 +294,7 @@ fn any_partial_values() -> impl Strategy impl Strategy { +fn any_typed_partial_value() -> impl Strategy)> { any::() .prop_flat_map(|t| any_partial_value_of_type(t.clone()).prop_map(move |v| (t.clone(), v))) } diff --git a/hugr-passes/src/const_fold2/partial_value/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs similarity index 95% rename from hugr-passes/src/const_fold2/partial_value/value_handle.rs rename to hugr-passes/src/const_fold2/value_handle.rs index 147048ae7..b5af487a8 100644 --- a/hugr-passes/src/const_fold2/partial_value/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -2,11 +2,12 @@ use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; use hugr_core::ops::constant::{CustomConst, Sum}; - use hugr_core::ops::Value; use hugr_core::types::Type; use hugr_core::Node; +use super::partial_value::{AbstractValue, PartialSum, PartialValue}; + #[derive(Clone, Debug)] pub struct HashedConst { hash: u64, @@ -78,7 +79,13 @@ impl ValueHandle { self.1.as_ref() } - pub fn as_sum(&self) -> Option<(usize, impl Iterator + '_)> { + pub fn get_type(&self) -> Type { + self.1.get_type() + } +} + +impl AbstractValue for ValueHandle { + fn as_sum(&self) -> Option<(usize, impl Iterator + '_)> { match self.value() { Value::Sum(Sum { tag, values, .. }) => Some(( *tag, @@ -90,10 +97,6 @@ impl ValueHandle { _ => None, } } - - pub fn get_type(&self) -> Type { - self.1.get_type() - } } impl PartialEq for ValueHandle { @@ -119,6 +122,12 @@ impl Hash for ValueHandle { } } +impl From for Value { + fn from(value: ValueHandle) -> Self { + (*value.1).clone() + } +} + #[cfg(test)] mod test { use hugr_core::{ From af8827b42b25cb4a1f5158844cb0ebbb40c4d49c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 Aug 2024 12:16:41 +0100 Subject: [PATCH 021/203] Move partial_value.rs inside datalog/ --- hugr-passes/src/const_fold2.rs | 1 - hugr-passes/src/const_fold2/datalog.rs | 11 +++++++++-- .../src/const_fold2/{ => datalog}/partial_value.rs | 0 .../const_fold2/{ => datalog}/partial_value/test.rs | 0 hugr-passes/src/const_fold2/datalog/test.rs | 3 ++- hugr-passes/src/const_fold2/datalog/utils.rs | 8 +------- hugr-passes/src/const_fold2/value_handle.rs | 2 +- 7 files changed, 13 insertions(+), 12 deletions(-) rename hugr-passes/src/const_fold2/{ => datalog}/partial_value.rs (100%) rename hugr-passes/src/const_fold2/{ => datalog}/partial_value/test.rs (100%) diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index 13af5c709..7d6725fb1 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -1,3 +1,2 @@ pub mod datalog; -mod partial_value; pub mod value_handle; diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index c06b2a285..fbe008d43 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -6,17 +6,24 @@ use hugr_core::ops::Value; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; mod context; +mod partial_value; mod utils; -pub use utils::{TailLoopTermination, ValueRow, IO, PV}; +use utils::{TailLoopTermination, ValueRow, PV}; -use super::partial_value::AbstractValue; +pub use partial_value::{AbstractValue, PartialSum, PartialValue}; pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { fn hugr(&self) -> &impl HugrView; fn value_from_load_constant(&self, node: Node) -> V; } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum IO { + Input, + Output, +} + ascent::ascent! { // The trait-indirection layer here means we can just write 'C' but in practice ATM // DataflowContext (for H: HugrView) would be sufficient, there's really no diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/datalog/partial_value.rs similarity index 100% rename from hugr-passes/src/const_fold2/partial_value.rs rename to hugr-passes/src/const_fold2/datalog/partial_value.rs diff --git a/hugr-passes/src/const_fold2/partial_value/test.rs b/hugr-passes/src/const_fold2/datalog/partial_value/test.rs similarity index 100% rename from hugr-passes/src/const_fold2/partial_value/test.rs rename to hugr-passes/src/const_fold2/datalog/partial_value/test.rs diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index f80cc903e..7e7057ffa 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -7,7 +7,8 @@ use hugr_core::{ types::{Signature, SumType, Type, TypeRow}, }; -use super::super::partial_value::PartialValue; +use super::partial_value::PartialValue; + use super::*; #[test] diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 42138396e..16b942e1d 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -8,7 +8,7 @@ use std::{cmp::Ordering, ops::Index}; use ascent::lattice::{BoundedLattice, Lattice}; use itertools::zip_eq; -use crate::const_fold2::partial_value::{AbstractValue, PartialValue}; +use crate::const_fold2::datalog::{AbstractValue, PartialValue}; use hugr_core::{ ops::{OpTrait as _, Value}, types::{Signature, Type, TypeRow}, @@ -229,12 +229,6 @@ pub(super) fn partial_value_tuple_from_value_row(r: ValueRow impl Iterator + '_ { h.in_value_types(n).map(|x| x.0) } diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index b5af487a8..b586f4cab 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -6,7 +6,7 @@ use hugr_core::ops::Value; use hugr_core::types::Type; use hugr_core::Node; -use super::partial_value::{AbstractValue, PartialSum, PartialValue}; +use super::datalog::{AbstractValue, PartialSum, PartialValue}; #[derive(Clone, Debug)] pub struct HashedConst { From 4b614365bd1e2bf686eec1e432c66bd4a766ede2 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 Aug 2024 15:26:50 +0100 Subject: [PATCH 022/203] Hide PartialSum/PartialValue --- hugr-passes/src/const_fold2/datalog.rs | 2 +- hugr-passes/src/const_fold2/datalog/utils.rs | 2 +- hugr-passes/src/const_fold2/value_handle.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index fbe008d43..4ffbca32a 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -11,7 +11,7 @@ mod utils; use utils::{TailLoopTermination, ValueRow, PV}; -pub use partial_value::{AbstractValue, PartialSum, PartialValue}; +pub use partial_value::AbstractValue; pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { fn hugr(&self) -> &impl HugrView; diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 16b942e1d..881963666 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -8,7 +8,7 @@ use std::{cmp::Ordering, ops::Index}; use ascent::lattice::{BoundedLattice, Lattice}; use itertools::zip_eq; -use crate::const_fold2::datalog::{AbstractValue, PartialValue}; +use super::{partial_value::PartialValue, AbstractValue}; use hugr_core::{ ops::{OpTrait as _, Value}, types::{Signature, Type, TypeRow}, diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index b586f4cab..2bc16994a 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -6,7 +6,7 @@ use hugr_core::ops::Value; use hugr_core::types::Type; use hugr_core::Node; -use super::datalog::{AbstractValue, PartialSum, PartialValue}; +use super::datalog::AbstractValue; #[derive(Clone, Debug)] pub struct HashedConst { From 8b31d8c8abb847533cf989d8dc82927a18fd7d78 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 Aug 2024 16:08:52 +0100 Subject: [PATCH 023/203] refactor: ValueRow::single_among_bottoms --- hugr-passes/src/const_fold2/datalog/utils.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 881963666..91fb1723c 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -105,17 +105,13 @@ impl ValueRow { Self(vec![PV::bottom(); len]) } - fn singleton(len: usize, idx: usize, v: PV) -> Self { + fn single_among_bottoms(len: usize, idx: usize, v: PV) -> Self { assert!(idx < len); let mut r = Self::new(len); r.0[idx] = v; r } - fn singleton_from_row(r: &TypeRow, idx: usize, v: PV) -> Self { - Self::singleton(r.len(), idx, v) - } - fn bottom_from_row(r: &TypeRow) -> Self { Self::new(r.len()) } @@ -222,7 +218,7 @@ pub(super) fn singleton_in_row( h.get_optype(*n).description() ); } - ValueRow::singleton_from_row(&h.signature(*n).unwrap().input, ip.index(), v) + ValueRow::single_among_bottoms(h.signature(*n).unwrap().input.len(), ip.index(), v) } pub(super) fn partial_value_tuple_from_value_row(r: ValueRow) -> PV { From 6a729613706ee8394027631409f59e57b6cee841 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 Aug 2024 16:11:41 +0100 Subject: [PATCH 024/203] Factor out propagate_leaf_op; add ValueRow::from_iter --- hugr-passes/src/const_fold2/datalog.rs | 58 +++++++++----------- hugr-passes/src/const_fold2/datalog/utils.rs | 6 ++ 2 files changed, 33 insertions(+), 31 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 4ffbca32a..a636293d6 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -2,7 +2,7 @@ use ascent::lattice::BoundedLattice; use std::collections::HashMap; use std::hash::Hash; -use hugr_core::ops::Value; +use hugr_core::ops::{OpType, Value}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; mod context; @@ -68,37 +68,12 @@ ascent::ascent! { node_in_value_row(c, n, utils::bottom_row(c.hugr(), *n)) <-- node(c, n); node_in_value_row(c, n, utils::singleton_in_row(c.hugr(), n, p, v.clone())) <-- in_wire_value(c, n, p, v); - - // Per node-type rules - // TODO do all leaf ops with a rule - // define `fn propagate_leaf_op(Context, Node, ValueRow) -> ValueRow - - // LoadConstant - relation load_constant_node(C, Node); - load_constant_node(c, n) <-- node(c, n), if c.get_optype(*n).is_load_constant(); - - out_wire_value(c, n, 0.into(), PV::from(c.value_from_load_constant(*n))) <-- - load_constant_node(c, n); - - - // MakeTuple - relation make_tuple_node(C, Node); - make_tuple_node(c, n) <-- node(c, n), if c.get_optype(*n).is_make_tuple(); - - out_wire_value(c, n, 0.into(), utils::partial_value_tuple_from_value_row(vs.clone())) <-- - make_tuple_node(c, n), node_in_value_row(c, n, vs); - - - // UnpackTuple - relation unpack_tuple_node(C, Node); - unpack_tuple_node(c,n) <-- node(c, n), if c.get_optype(*n).is_unpack_tuple(); - out_wire_value(c, n, p, v) <-- - unpack_tuple_node(c, n), - in_wire_value(c, n, IncomingPort::from(0), tup), - if let Some(fields) = tup.variant_values(0, utils::value_outputs(c.hugr(),*n).count()), - for (p,v) in (0..).map(OutgoingPort::from).zip(fields); - + node(c, n), + if !c.get_optype(*n).is_container(), + node_in_value_row(c, n, vs), + if let Some(outs) = propagate_leaf_op(c, *n, vs.clone()), + for (p,v) in (0..).map(OutgoingPort::from).zip(outs); // DFG relation dfg_node(C, Node); @@ -177,6 +152,27 @@ ascent::ascent! { } +fn propagate_leaf_op( + c: &impl DFContext, + n: Node, + ins: ValueRow, +) -> Option> { + match c.get_optype(n) { + OpType::LoadConstant(_) => Some(ValueRow::from_iter([PV::from( + c.value_from_load_constant(n), + )])), // ins empty + OpType::MakeTuple(_) => Some(ValueRow::from_iter([ + utils::partial_value_tuple_from_value_row(ins), + ])), + OpType::UnpackTuple(_) => { + let [tup] = ins.into_iter().collect::>().try_into().unwrap(); + tup.variant_values(0, utils::value_outputs(c.hugr(), n).count()) + .map(ValueRow::from_iter) + } + _ => None, + } +} + // TODO This should probably be called 'Analyser' or something pub struct Machine>( AscentProgram, diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 91fb1723c..2719f19ce 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -135,6 +135,12 @@ impl ValueRow { // } } +impl FromIterator> for ValueRow { + fn from_iter>>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + impl PartialOrd for ValueRow { fn partial_cmp(&self, other: &Self) -> Option { self.0.partial_cmp(&other.0) From 780af9b5b838141707d1eb457b4b870b5b08944a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 Aug 2024 16:19:44 +0100 Subject: [PATCH 025/203] Add handling for Tag --- hugr-passes/src/const_fold2/datalog.rs | 5 ++--- hugr-passes/src/const_fold2/datalog/utils.rs | 8 ++++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index a636293d6..2802e37a5 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -161,14 +161,13 @@ fn propagate_leaf_op( OpType::LoadConstant(_) => Some(ValueRow::from_iter([PV::from( c.value_from_load_constant(n), )])), // ins empty - OpType::MakeTuple(_) => Some(ValueRow::from_iter([ - utils::partial_value_tuple_from_value_row(ins), - ])), + OpType::MakeTuple(_) => Some(ValueRow::from_iter([PV::variant(0, ins)])), OpType::UnpackTuple(_) => { let [tup] = ins.into_iter().collect::>().try_into().unwrap(); tup.variant_values(0, utils::value_outputs(c.hugr(), n).count()) .map(ValueRow::from_iter) } + OpType::Tag(t) => Some(ValueRow::from_iter([PV::variant(t.tag, ins)])), _ => None, } } diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 2719f19ce..132172db0 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -36,6 +36,10 @@ impl From> for PV { } impl PV { + pub fn variant(tag: usize, r: impl IntoIterator>) -> Self { + PartialValue::variant(tag, r.into_iter().map(|x| x.0)).into() + } + pub fn variant_values(&self, variant: usize, len: usize) -> Option>> { Some( self.0 @@ -227,10 +231,6 @@ pub(super) fn singleton_in_row( ValueRow::single_among_bottoms(h.signature(*n).unwrap().input.len(), ip.index(), v) } -pub(super) fn partial_value_tuple_from_value_row(r: ValueRow) -> PV { - PartialValue::variant(0, r.into_iter().map(|x| x.0)).into() -} - pub(super) fn value_inputs(h: &impl HugrView, n: Node) -> impl Iterator + '_ { h.in_value_types(n).map(|x| x.0) } From cabcf04ac018574fe06ea8be9b109234893589bf Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 Aug 2024 17:25:55 +0100 Subject: [PATCH 026/203] Remove PV (use typedef in datalog.rs) --- hugr-passes/src/const_fold2/datalog.rs | 13 ++- hugr-passes/src/const_fold2/datalog/utils.rs | 101 +++++-------------- 2 files changed, 29 insertions(+), 85 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 2802e37a5..9de140cb1 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -9,9 +9,10 @@ mod context; mod partial_value; mod utils; -use utils::{TailLoopTermination, ValueRow, PV}; +use utils::{TailLoopTermination, ValueRow}; pub use partial_value::AbstractValue; +type PV = partial_value::PartialValue; pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { fn hugr(&self) -> &impl HugrView; @@ -190,11 +191,9 @@ impl> Machine { pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator)>) { assert!(self.1.is_none()); - self.0.out_wire_value_proto.extend( - wires - .into_iter() - .map(|(w, v)| (w.node(), w.source(), v.into())), - ); + self.0 + .out_wire_value_proto + .extend(wires.into_iter().map(|(w, v)| (w.node(), w.source(), v))); } pub fn run(&mut self, context: C) { @@ -205,7 +204,7 @@ impl> Machine { self.0 .out_wire_value .iter() - .map(|(_, n, p, v)| (Wire::new(*n, *p), v.clone().into())) + .map(|(_, n, p, v)| (Wire::new(*n, *p), v.clone())) .collect(), ) } diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 132172db0..8fbb40c02 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -10,106 +10,51 @@ use itertools::zip_eq; use super::{partial_value::PartialValue, AbstractValue}; use hugr_core::{ - ops::{OpTrait as _, Value}, - types::{Signature, Type, TypeRow}, + ops::OpTrait as _, + types::{Signature, TypeRow}, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, }; #[cfg(test)] use proptest_derive::Arbitrary; -#[derive(PartialEq, Eq, Hash, Clone, Debug)] -pub struct PV(PartialValue); - -// Implement manually as PartialValue is PartialOrd even when V isn't -// (deriving PartialOrd conditions on V: PartialOrd, which is not necessary) -impl PartialOrd for PV { - fn partial_cmp(&self, other: &Self) -> Option { - self.0.partial_cmp(&other.0) - } -} - -impl From> for PV { - fn from(inner: PartialValue) -> Self { - Self(inner) - } -} - -impl PV { - pub fn variant(tag: usize, r: impl IntoIterator>) -> Self { - PartialValue::variant(tag, r.into_iter().map(|x| x.0)).into() - } - - pub fn variant_values(&self, variant: usize, len: usize) -> Option>> { - Some( - self.0 - .variant_values(variant, len)? - .into_iter() - .map(PV::from) - .collect(), - ) - } - - pub fn supports_tag(&self, tag: usize) -> bool { - self.0.supports_tag(tag) - } -} - -impl> PV { - pub fn try_into_value(self, ty: &Type) -> Result { - self.0.try_into_value(ty).map_err(Self) - } -} - -impl From> for PartialValue { - fn from(value: PV) -> Self { - value.0 - } -} - -impl From for PV { - fn from(inner: V) -> Self { - Self(inner.into()) - } -} - -impl Lattice for PV { +impl Lattice for PartialValue { fn meet(self, other: Self) -> Self { - self.0.meet(other.0).into() + self.meet(other) } fn meet_mut(&mut self, other: Self) -> bool { - self.0.meet_mut(other.0) + self.meet_mut(other) } fn join(self, other: Self) -> Self { - self.0.join(other.0).into() + self.join(other) } fn join_mut(&mut self, other: Self) -> bool { - self.0.join_mut(other.0) + self.join_mut(other) } } -impl BoundedLattice for PV { +impl BoundedLattice for PartialValue { fn bottom() -> Self { - PartialValue::bottom().into() + Self::bottom() } fn top() -> Self { - PartialValue::top().into() + Self::top() } } #[derive(PartialEq, Clone, Eq, Hash)] -pub struct ValueRow(Vec>); +pub struct ValueRow(Vec>); impl ValueRow { fn new(len: usize) -> Self { - Self(vec![PV::bottom(); len]) + Self(vec![PartialValue::bottom(); len]) } - fn single_among_bottoms(len: usize, idx: usize, v: PV) -> Self { + fn single_among_bottoms(len: usize, idx: usize, v: PartialValue) -> Self { assert!(idx < len); let mut r = Self::new(len); r.0[idx] = v; @@ -120,7 +65,7 @@ impl ValueRow { Self::new(r.len()) } - pub fn iter(&self) -> impl Iterator> { + pub fn iter(&self) -> impl Iterator> { self.0.iter() } @@ -128,7 +73,7 @@ impl ValueRow { &self, variant: usize, len: usize, - ) -> Option> + '_> { + ) -> Option> + '_> { self[0] .variant_values(variant, len) .map(|vals| vals.into_iter().chain(self.iter().skip(1).cloned())) @@ -139,8 +84,8 @@ impl ValueRow { // } } -impl FromIterator> for ValueRow { - fn from_iter>>(iter: T) -> Self { +impl FromIterator> for ValueRow { + fn from_iter>>(iter: T) -> Self { Self(iter.into_iter().collect()) } } @@ -182,9 +127,9 @@ impl Lattice for ValueRow { } impl IntoIterator for ValueRow { - type Item = PV; + type Item = PartialValue; - type IntoIter = > as IntoIterator>::IntoIter; + type IntoIter = > as IntoIterator>::IntoIter; fn into_iter(self) -> Self::IntoIter { self.0.into_iter() @@ -193,9 +138,9 @@ impl IntoIterator for ValueRow { impl Index for ValueRow where - Vec>: Index, + Vec>: Index, { - type Output = > as Index>::Output; + type Output = > as Index>::Output; fn index(&self, index: Idx) -> &Self::Output { self.0.index(index) @@ -215,7 +160,7 @@ pub(super) fn singleton_in_row( h: &impl HugrView, n: &Node, ip: &IncomingPort, - v: PV, + v: PartialValue, ) -> ValueRow { let Some(sig) = h.signature(*n) else { panic!("dougrulz"); @@ -248,7 +193,7 @@ pub enum TailLoopTermination { } impl TailLoopTermination { - pub fn from_control_value(v: &PV) -> Self { + pub fn from_control_value(v: &PartialValue) -> Self { let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); if may_break && !may_continue { Self::ExactlyZeroContinues From 5e4a04fe1cbbd8487d2a2dd7d53babb3fd2887ae Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 Aug 2024 18:27:32 +0100 Subject: [PATCH 027/203] Allow DFContext to interpret any leaf op (except MakeTuple/etc.); pub PV+PS --- hugr-passes/src/const_fold2/datalog.rs | 30 ++++++++++----- .../src/const_fold2/datalog/context.rs | 38 +++++++++++++------ 2 files changed, 46 insertions(+), 22 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 9de140cb1..e8a4ce621 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -11,12 +11,16 @@ mod utils; use utils::{TailLoopTermination, ValueRow}; -pub use partial_value::AbstractValue; +pub use partial_value::{AbstractValue, PartialSum, PartialValue}; type PV = partial_value::PartialValue; pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { fn hugr(&self) -> &impl HugrView; - fn value_from_load_constant(&self, node: Node) -> V; + fn interpret_leaf_op( + &self, + node: Node, + ins: &[PartialValue], + ) -> Option>>; } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -73,7 +77,7 @@ ascent::ascent! { node(c, n), if !c.get_optype(*n).is_container(), node_in_value_row(c, n, vs), - if let Some(outs) = propagate_leaf_op(c, *n, vs.clone()), + if let Some(outs) = propagate_leaf_op(c, *n, &vs[..]), for (p,v) in (0..).map(OutgoingPort::from).zip(outs); // DFG @@ -156,20 +160,26 @@ ascent::ascent! { fn propagate_leaf_op( c: &impl DFContext, n: Node, - ins: ValueRow, + ins: &[PV], ) -> Option> { match c.get_optype(n) { - OpType::LoadConstant(_) => Some(ValueRow::from_iter([PV::from( - c.value_from_load_constant(n), - )])), // ins empty - OpType::MakeTuple(_) => Some(ValueRow::from_iter([PV::variant(0, ins)])), + // Handle basics here. I guess we could allow DFContext to specify but at the least + // we'd want these ones to be easily available for reuse. + OpType::MakeTuple(_) => Some(ValueRow::from_iter([PV::variant( + 0, + ins.into_iter().cloned(), + )])), OpType::UnpackTuple(_) => { let [tup] = ins.into_iter().collect::>().try_into().unwrap(); tup.variant_values(0, utils::value_outputs(c.hugr(), n).count()) .map(ValueRow::from_iter) } - OpType::Tag(t) => Some(ValueRow::from_iter([PV::variant(t.tag, ins)])), - _ => None, + OpType::Tag(t) => Some(ValueRow::from_iter([PV::variant( + t.tag, + ins.into_iter().cloned(), + )])), + OpType::Input(_) | OpType::Output(_) => None, // handled by parent + _ => c.interpret_leaf_op(n, ins).map(ValueRow::from_iter), } } diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/datalog/context.rs index 31a3233fd..f4d6b7ab8 100644 --- a/hugr-passes/src/const_fold2/datalog/context.rs +++ b/hugr-passes/src/const_fold2/datalog/context.rs @@ -2,12 +2,13 @@ use std::hash::{Hash, Hasher}; use std::ops::Deref; use std::sync::Arc; -use hugr_core::{Hugr, HugrView}; +use hugr_core::ops::OpType; +use hugr_core::{Hugr, HugrView, Node}; +// ALAN Note this probably belongs with ValueHandle, outside datalog +use super::{DFContext, PartialValue}; use crate::const_fold2::value_handle::ValueHandle; -use super::DFContext; - #[derive(Debug)] pub(super) struct DataflowContext(Arc); @@ -60,14 +61,27 @@ impl DFContext for DataflowContext { self.0.as_ref() } - fn value_from_load_constant(&self, node: hugr_core::Node) -> ValueHandle { - let load_op = self.0.get_optype(node).as_load_constant().unwrap(); - let const_node = self - .0 - .single_linked_output(node, load_op.constant_port()) - .unwrap() - .0; - let const_op = self.0.get_optype(const_node).as_const().unwrap(); - ValueHandle::new(const_node.into(), Arc::new(const_op.value().clone())) + fn interpret_leaf_op( + &self, + n: Node, + ins: &[PartialValue], + ) -> Option>> { + match self.0.get_optype(n) { + OpType::LoadConstant(load_op) => { + // ins empty as static edge, we need to find the constant ourselves + let const_node = self + .0 + .single_linked_output(n, load_op.constant_port()) + .unwrap() + .0; + let const_op = self.0.get_optype(const_node).as_const().unwrap(); + Some(vec![ValueHandle::new( + const_node.into(), + Arc::new(const_op.value().clone()), + ) + .into()]) + } + _ => None, + } } } From 145388653efe29df2ac959b1c2813ac39b98146f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 28 Aug 2024 18:31:18 +0100 Subject: [PATCH 028/203] Also fold extension ops --- .../src/const_fold2/datalog/context.rs | 28 ++++++++++++++++--- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/datalog/context.rs index f4d6b7ab8..7985e5410 100644 --- a/hugr-passes/src/const_fold2/datalog/context.rs +++ b/hugr-passes/src/const_fold2/datalog/context.rs @@ -2,12 +2,11 @@ use std::hash::{Hash, Hasher}; use std::ops::Deref; use std::sync::Arc; -use hugr_core::ops::OpType; -use hugr_core::{Hugr, HugrView, Node}; +use hugr_core::ops::{CustomOp, DataflowOpTrait, OpType}; +use hugr_core::{Hugr, HugrView, IncomingPort, Node, PortIndex}; -// ALAN Note this probably belongs with ValueHandle, outside datalog +use crate::const_fold2::value_handle::{ValueHandle, ValueKey}; use super::{DFContext, PartialValue}; -use crate::const_fold2::value_handle::ValueHandle; #[derive(Debug)] pub(super) struct DataflowContext(Arc); @@ -81,6 +80,27 @@ impl DFContext for DataflowContext { ) .into()]) } + OpType::CustomOp(CustomOp::Extension(op)) => { + let sig = op.signature(); + let known_ins = sig + .input_types() + .into_iter() + .enumerate() + .zip(ins.iter()) + .filter_map(|((i, ty), pv)| { + pv.clone() + .try_into_value(ty) + .map(|v| (IncomingPort::from(i), v)) + .ok() + }) + .collect::>(); + let outs = op.constant_fold(&known_ins)?; + let mut res = vec![PartialValue::bottom(); sig.output_count()]; + for (op, v) in outs { + res[op.index()] = ValueHandle::new(ValueKey::Node(n), Arc::new(v)).into() + } + Some(res) + } _ => None, } } From 221e96cc0f6fe571563b19b2ffb4fa0643045cc1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 30 Aug 2024 14:13:08 +0100 Subject: [PATCH 029/203] Comment as_sum --- hugr-passes/src/const_fold2/datalog/partial_value.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/hugr-passes/src/const_fold2/datalog/partial_value.rs b/hugr-passes/src/const_fold2/datalog/partial_value.rs index 6a2a614a8..8441027d4 100644 --- a/hugr-passes/src/const_fold2/datalog/partial_value.rs +++ b/hugr-passes/src/const_fold2/datalog/partial_value.rs @@ -7,7 +7,11 @@ use std::cmp::Ordering; use std::collections::HashMap; use std::hash::{Hash, Hasher}; +/// Aka, deconstructible into Sum (TryIntoSum ?) pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { + /// We write this way to optimize query/inspection (is-it-a-sum), + /// at the cost of requiring more cloning during actual conversion + /// (inside the lazy Iterator, or for the error case, as Self remains) fn as_sum(&self) -> Option<(usize, impl Iterator + '_)>; } From cd4e15c47467ab9a8894e705e140f381ae94478f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 30 Aug 2024 10:44:19 +0100 Subject: [PATCH 030/203] Rename DataflowContext to HugrValueContext --- hugr-passes/src/const_fold2/datalog.rs | 4 --- .../src/const_fold2/datalog/context.rs | 25 +++++++++++-------- hugr-passes/src/const_fold2/datalog/test.rs | 16 ++++++------ 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index e8a4ce621..38a6dadbc 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -30,10 +30,6 @@ pub enum IO { } ascent::ascent! { - // The trait-indirection layer here means we can just write 'C' but in practice ATM - // DataflowContext (for H: HugrView) would be sufficient, there's really no - // point in using anything else yet. However DFContext will be useful when we - // move interpretation of nodes out into a trait method. struct AscentProgram>; relation context(C); relation out_wire_value_proto(Node, OutgoingPort, PV); diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/datalog/context.rs index 7985e5410..a17b69ab2 100644 --- a/hugr-passes/src/const_fold2/datalog/context.rs +++ b/hugr-passes/src/const_fold2/datalog/context.rs @@ -8,10 +8,13 @@ use hugr_core::{Hugr, HugrView, IncomingPort, Node, PortIndex}; use crate::const_fold2::value_handle::{ValueHandle, ValueKey}; use super::{DFContext, PartialValue}; +/// An implementation of [DFContext] with [ValueHandle] +/// that just stores a Hugr (actually any [HugrView]), +/// (there is )no state for operation-interpretation). #[derive(Debug)] -pub(super) struct DataflowContext(Arc); +pub struct HugrValueContext(Arc); -impl DataflowContext { +impl HugrValueContext { pub fn new(hugr: H) -> Self { Self(Arc::new(hugr)) } @@ -19,35 +22,35 @@ impl DataflowContext { // Deriving Clone requires H:HugrView to implement Clone, // but we don't need that as we only clone the Arc. -impl Clone for DataflowContext { +impl Clone for HugrValueContext { fn clone(&self) -> Self { Self(self.0.clone()) } } -impl Hash for DataflowContext { +impl Hash for HugrValueContext { fn hash(&self, _state: &mut I) {} } -impl PartialEq for DataflowContext { +impl PartialEq for HugrValueContext { fn eq(&self, other: &Self) -> bool { - // Any AscentProgram should have only one DataflowContext (maybe cloned) + // Any AscentProgram should have only one DFContext (maybe cloned) assert!(Arc::ptr_eq(&self.0, &other.0)); true } } -impl Eq for DataflowContext {} +impl Eq for HugrValueContext {} -impl PartialOrd for DataflowContext { +impl PartialOrd for HugrValueContext { fn partial_cmp(&self, other: &Self) -> Option { - // Any AscentProgram should have only one DataflowContext (maybe cloned) + // Any AscentProgram should have only one DFContext (maybe cloned) assert!(Arc::ptr_eq(&self.0, &other.0)); Some(std::cmp::Ordering::Equal) } } -impl Deref for DataflowContext { +impl Deref for HugrValueContext { type Target = Hugr; fn deref(&self) -> &Self::Target { @@ -55,7 +58,7 @@ impl Deref for DataflowContext { } } -impl DFContext for DataflowContext { +impl DFContext for HugrValueContext { fn hugr(&self) -> &impl HugrView { self.0.as_ref() } diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index 7e7057ffa..f24ba67c5 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -1,4 +1,4 @@ -use context::DataflowContext; +use context::HugrValueContext; use hugr_core::{ builder::{DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{prelude::BOOL_T, ExtensionSet, EMPTY_REG}, @@ -20,7 +20,7 @@ fn test_make_tuple() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); - machine.run(DataflowContext::new(&hugr)); + machine.run(HugrValueContext::new(&hugr)); let x = machine.read_out_wire_value(&hugr, v3).unwrap(); assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); @@ -39,7 +39,7 @@ fn test_unpack_tuple() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); - machine.run(DataflowContext::new(&hugr)); + machine.run(HugrValueContext::new(&hugr)); let o1_r = machine.read_out_wire_value(&hugr, o1).unwrap(); assert_eq!(o1_r, Value::false_val()); @@ -58,7 +58,7 @@ fn test_unpack_const() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); - machine.run(DataflowContext::new(&hugr)); + machine.run(HugrValueContext::new(&hugr)); let o_r = machine.read_out_wire_value(&hugr, o).unwrap(); assert_eq!(o_r, Value::true_val()); @@ -84,7 +84,7 @@ fn test_tail_loop_never_iterates() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); - machine.run(DataflowContext::new(&hugr)); + machine.run(HugrValueContext::new(&hugr)); // dbg!(&machine.tail_loop_io_node); // dbg!(&machine.out_wire_value); @@ -116,7 +116,7 @@ fn test_tail_loop_always_iterates() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::new(); - machine.run(DataflowContext::new(&hugr)); + machine.run(HugrValueContext::new(&hugr)); let o_r1 = machine.read_out_wire_partial_value(tl_o1).unwrap(); assert_eq!(o_r1, PartialValue::bottom().into()); @@ -166,7 +166,7 @@ fn test_tail_loop_iterates_twice() { let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::new(); - machine.run(DataflowContext::new(&hugr)); + machine.run(HugrValueContext::new(&hugr)); // dbg!(&machine.tail_loop_io_node); // dbg!(&machine.out_wire_value); @@ -222,7 +222,7 @@ fn conditional() { let arg_pv = PartialValue::variant(1, []).join(PartialValue::variant(2, [PartialValue::variant(0, [])])); machine.propolutate_out_wires([(arg_w, arg_pv.into())]); - machine.run(DataflowContext::new(&hugr)); + machine.run(HugrValueContext::new(&hugr)); let cond_r1 = machine.read_out_wire_value(&hugr, cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); From c5ab2a94def4cf80709ac85b01b838c8bfa46ae8 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 30 Aug 2024 10:52:10 +0100 Subject: [PATCH 031/203] Move {datalog=>value_handle}/context.rs - an impl, datalog uses only DFContext --- hugr-passes/src/const_fold2/datalog.rs | 1 - hugr-passes/src/const_fold2/datalog/test.rs | 3 ++- hugr-passes/src/const_fold2/value_handle.rs | 3 +++ .../src/const_fold2/{datalog => value_handle}/context.rs | 4 ++-- 4 files changed, 7 insertions(+), 4 deletions(-) rename hugr-passes/src/const_fold2/{datalog => value_handle}/context.rs (97%) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 38a6dadbc..69156a9d9 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -5,7 +5,6 @@ use std::hash::Hash; use hugr_core::ops::{OpType, Value}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; -mod context; mod partial_value; mod utils; diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index f24ba67c5..1e3cdcc98 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -1,4 +1,5 @@ -use context::HugrValueContext; +use crate::const_fold2::value_handle::HugrValueContext; + use hugr_core::{ builder::{DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{prelude::BOOL_T, ExtensionSet, EMPTY_REG}, diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index 2bc16994a..daf8a98fd 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -8,6 +8,9 @@ use hugr_core::Node; use super::datalog::AbstractValue; +mod context; +pub use context::HugrValueContext; + #[derive(Clone, Debug)] pub struct HashedConst { hash: u64, diff --git a/hugr-passes/src/const_fold2/datalog/context.rs b/hugr-passes/src/const_fold2/value_handle/context.rs similarity index 97% rename from hugr-passes/src/const_fold2/datalog/context.rs rename to hugr-passes/src/const_fold2/value_handle/context.rs index a17b69ab2..06ced3238 100644 --- a/hugr-passes/src/const_fold2/datalog/context.rs +++ b/hugr-passes/src/const_fold2/value_handle/context.rs @@ -5,8 +5,8 @@ use std::sync::Arc; use hugr_core::ops::{CustomOp, DataflowOpTrait, OpType}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, PortIndex}; -use crate::const_fold2::value_handle::{ValueHandle, ValueKey}; -use super::{DFContext, PartialValue}; +use super::{ValueHandle, ValueKey}; +use crate::const_fold2::datalog::{DFContext, PartialValue}; /// An implementation of [DFContext] with [ValueHandle] /// that just stores a Hugr (actually any [HugrView]), From 6c80acf500491f16722767e61d80c4803e2b8bbe Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 30 Aug 2024 10:57:31 +0100 Subject: [PATCH 032/203] Comment re. propagate_leaf_op --- hugr-passes/src/const_fold2/datalog.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 69156a9d9..c924f36e0 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -158,8 +158,9 @@ fn propagate_leaf_op( ins: &[PV], ) -> Option> { match c.get_optype(n) { - // Handle basics here. I guess we could allow DFContext to specify but at the least - // we'd want these ones to be easily available for reuse. + // Handle basics here. I guess (given the current interface) we could allow + // DFContext to handle these but at the least we'd want these impls to be + // easily available for reuse. OpType::MakeTuple(_) => Some(ValueRow::from_iter([PV::variant( 0, ins.into_iter().cloned(), @@ -174,6 +175,9 @@ fn propagate_leaf_op( ins.into_iter().cloned(), )])), OpType::Input(_) | OpType::Output(_) => None, // handled by parent + // It'd be nice to convert these to [(IncomingPort, Value)] to pass to the context, + // thus keeping PartialValue hidden, but AbstractValues + // are not necessarily convertible to Value! _ => c.interpret_leaf_op(n, ins).map(ValueRow::from_iter), } } From 636f14dbece90f238f2cacfe923b6d9f1cab47c5 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 30 Aug 2024 14:28:51 +0100 Subject: [PATCH 033/203] Hide PartialValue by abstracting DFContext::InterpretableVal: FromSum (==Value) --- hugr-passes/src/const_fold2/datalog.rs | 39 ++++++++++++++---- .../src/const_fold2/datalog/partial_value.rs | 24 +++++------ hugr-passes/src/const_fold2/datalog/utils.rs | 12 +++++- hugr-passes/src/const_fold2/value_handle.rs | 8 +++- .../src/const_fold2/value_handle/context.rs | 41 ++++++------------- 5 files changed, 71 insertions(+), 53 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index c924f36e0..c6aad5840 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -2,24 +2,25 @@ use ascent::lattice::BoundedLattice; use std::collections::HashMap; use std::hash::Hash; -use hugr_core::ops::{OpType, Value}; -use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; +use hugr_core::ops::{OpTrait, OpType}; +use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex, Wire}; mod partial_value; mod utils; use utils::{TailLoopTermination, ValueRow}; -pub use partial_value::{AbstractValue, PartialSum, PartialValue}; +pub use partial_value::{AbstractValue, FromSum}; type PV = partial_value::PartialValue; pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { + type InterpretableVal: FromSum + From; fn hugr(&self) -> &impl HugrView; fn interpret_leaf_op( &self, node: Node, - ins: &[PartialValue], - ) -> Option>>; + ins: &[(IncomingPort, Self::InterpretableVal)], + ) -> Vec<(OutgoingPort, V)>; } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -178,7 +179,29 @@ fn propagate_leaf_op( // It'd be nice to convert these to [(IncomingPort, Value)] to pass to the context, // thus keeping PartialValue hidden, but AbstractValues // are not necessarily convertible to Value! - _ => c.interpret_leaf_op(n, ins).map(ValueRow::from_iter), + op => { + let sig = op.dataflow_signature()?; + let known_ins = sig + .input_types() + .into_iter() + .enumerate() + .zip(ins.iter()) + .filter_map(|((i, ty), pv)| { + pv.clone() + .try_into_value(ty) + .ok() + .map(|v| (IncomingPort::from(i), v)) + }) + .collect::>(); + let known_outs = c.interpret_leaf_op(n, &known_ins); + (!known_outs.is_empty()).then(|| { + let mut res = ValueRow::new(sig.output_count()); + for (p, v) in known_outs { + res[p.index()] = v.into(); + } + res + }) + } } } @@ -241,10 +264,8 @@ impl> Machine { .find_map(|(_, cond2, case2, i)| (&cond == cond2 && &case == case2).then_some(*i)) .unwrap() } -} -impl, C: DFContext> Machine { - pub fn read_out_wire_value(&self, hugr: impl HugrView, w: Wire) -> Option { + pub fn read_out_wire_value(&self, hugr: impl HugrView, w: Wire) -> Option { // dbg!(&w); let pv = self.read_out_wire_partial_value(w)?; // dbg!(&pv); diff --git a/hugr-passes/src/const_fold2/datalog/partial_value.rs b/hugr-passes/src/const_fold2/datalog/partial_value.rs index 8441027d4..b8f9067d4 100644 --- a/hugr-passes/src/const_fold2/datalog/partial_value.rs +++ b/hugr-passes/src/const_fold2/datalog/partial_value.rs @@ -1,7 +1,6 @@ #![allow(missing_docs)] -use hugr_core::ops::Value; -use hugr_core::types::{Type, TypeEnum, TypeRow}; +use hugr_core::types::{SumType, Type, TypeEnum, TypeRow}; use itertools::{zip_eq, Itertools}; use std::cmp::Ordering; use std::collections::HashMap; @@ -15,6 +14,11 @@ pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { fn as_sum(&self) -> Option<(usize, impl Iterator + '_)>; } +pub trait FromSum { + fn new_sum(tag: usize, items: impl IntoIterator, st: &SumType) -> Self; + fn debug_check_is_type(&self, _ty: &Type) {} +} + // TODO ALAN inline into PartialValue? Has to be public as it's in a pub enum #[derive(PartialEq, Clone, Eq)] pub struct PartialSum(pub HashMap>>); @@ -94,10 +98,8 @@ impl PartialSum { pub fn supports_tag(&self, tag: usize) -> bool { self.0.contains_key(&tag) } -} -impl> PartialSum { - pub fn try_into_value(self, typ: &Type) -> Result { + pub fn try_into_value>(self, typ: &Type) -> Result { let Ok((k, v)) = self.0.iter().exactly_one() else { Err(self)? }; @@ -118,7 +120,7 @@ impl> PartialSum { .map(|(v, t)| v.clone().try_into_value(t)) .collect::, _>>() { - Ok(vs) => Value::sum(*k, vs, st.clone()).map_err(|_| self), + Ok(vs) => Ok(V2::new_sum(*k, vs, &st)), Err(_) => Err(self), } } @@ -427,16 +429,14 @@ impl PartialValue { PartialValue::Top => true, } } -} -impl> PartialValue { - pub fn try_into_value(self, typ: &Type) -> Result { - let r = match self { - Self::Value(v) => Ok(v.into().clone()), + pub fn try_into_value>(self, typ: &Type) -> Result { + let r: V2 = match self { + Self::Value(v) => Ok(v.clone().into()), Self::PartialSum(ps) => ps.try_into_value(typ).map_err(Self::PartialSum), x => Err(x), }?; - assert_eq!(typ, &r.get_type()); + r.debug_check_is_type(typ); Ok(r) } } diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index 8fbb40c02..c0594fdb7 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -3,7 +3,7 @@ // https://github.com/proptest-rs/proptest/issues/447 #![cfg_attr(test, allow(non_local_definitions))] -use std::{cmp::Ordering, ops::Index}; +use std::{cmp::Ordering, ops::{Index, IndexMut}}; use ascent::lattice::{BoundedLattice, Lattice}; use itertools::zip_eq; @@ -50,7 +50,7 @@ impl BoundedLattice for PartialValue { pub struct ValueRow(Vec>); impl ValueRow { - fn new(len: usize) -> Self { + pub fn new(len: usize) -> Self { Self(vec![PartialValue::bottom(); len]) } @@ -147,6 +147,14 @@ where } } +impl IndexMut for ValueRow +where + Vec>: IndexMut { + fn index_mut(&mut self, index: Idx) -> &mut Self::Output { + self.0.index_mut(index) + } +} + pub(super) fn bottom_row(h: &impl HugrView, n: Node) -> ValueRow { ValueRow::new( h.signature(n) diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index daf8a98fd..2aa1bc2fe 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -6,7 +6,7 @@ use hugr_core::ops::Value; use hugr_core::types::Type; use hugr_core::Node; -use super::datalog::AbstractValue; +use super::datalog::{AbstractValue, FromSum}; mod context; pub use context::HugrValueContext; @@ -102,6 +102,12 @@ impl AbstractValue for ValueHandle { } } +impl FromSum for Value { + fn new_sum(tag: usize, items: impl IntoIterator, st: &hugr_core::types::SumType) -> Self { + Value::Sum(Sum {tag, values: items.into_iter().collect(), sum_type: st.clone()}) + } +} + impl PartialEq for ValueHandle { fn eq(&self, other: &Self) -> bool { // If the keys are equal, we return true since the values must have the diff --git a/hugr-passes/src/const_fold2/value_handle/context.rs b/hugr-passes/src/const_fold2/value_handle/context.rs index 06ced3238..14571bb1b 100644 --- a/hugr-passes/src/const_fold2/value_handle/context.rs +++ b/hugr-passes/src/const_fold2/value_handle/context.rs @@ -2,11 +2,11 @@ use std::hash::{Hash, Hasher}; use std::ops::Deref; use std::sync::Arc; -use hugr_core::ops::{CustomOp, DataflowOpTrait, OpType}; -use hugr_core::{Hugr, HugrView, IncomingPort, Node, PortIndex}; +use hugr_core::ops::{CustomOp, OpType, Value}; +use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort}; use super::{ValueHandle, ValueKey}; -use crate::const_fold2::datalog::{DFContext, PartialValue}; +use crate::const_fold2::datalog::DFContext; /// An implementation of [DFContext] with [ValueHandle] /// that just stores a Hugr (actually any [HugrView]), @@ -59,6 +59,7 @@ impl Deref for HugrValueContext { } impl DFContext for HugrValueContext { + type InterpretableVal = Value; fn hugr(&self) -> &impl HugrView { self.0.as_ref() } @@ -66,45 +67,27 @@ impl DFContext for HugrValueContext { fn interpret_leaf_op( &self, n: Node, - ins: &[PartialValue], - ) -> Option>> { + ins: &[(IncomingPort, Value)], + ) -> Vec<(OutgoingPort,ValueHandle)> { match self.0.get_optype(n) { OpType::LoadConstant(load_op) => { - // ins empty as static edge, we need to find the constant ourselves + assert!(ins.is_empty()); // static edge, so need to find constant let const_node = self .0 .single_linked_output(n, load_op.constant_port()) .unwrap() .0; let const_op = self.0.get_optype(const_node).as_const().unwrap(); - Some(vec![ValueHandle::new( + vec![(OutgoingPort::from(0), ValueHandle::new( const_node.into(), Arc::new(const_op.value().clone()), - ) - .into()]) + ))] } OpType::CustomOp(CustomOp::Extension(op)) => { - let sig = op.signature(); - let known_ins = sig - .input_types() - .into_iter() - .enumerate() - .zip(ins.iter()) - .filter_map(|((i, ty), pv)| { - pv.clone() - .try_into_value(ty) - .map(|v| (IncomingPort::from(i), v)) - .ok() - }) - .collect::>(); - let outs = op.constant_fold(&known_ins)?; - let mut res = vec![PartialValue::bottom(); sig.output_count()]; - for (op, v) in outs { - res[op.index()] = ValueHandle::new(ValueKey::Node(n), Arc::new(v)).into() - } - Some(res) + let ins = ins.into_iter().map(|(p,v)|(*p,v.clone())).collect::>(); + op.constant_fold(&ins).map_or(Vec::new(), |outs|outs.into_iter().map(|(p,v)|(p, ValueHandle::new(ValueKey::Node(n), Arc::new(v)))).collect()) } - _ => None, + _ => vec![], } } } From 0acdcc5767372a290cd58ce6febd95ef152dfb37 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 30 Aug 2024 14:41:15 +0100 Subject: [PATCH 034/203] Use Value::sum, adding FromSum::Err; fmt --- .../src/const_fold2/datalog/partial_value.rs | 11 +++++++--- hugr-passes/src/const_fold2/datalog/utils.rs | 10 ++++++--- hugr-passes/src/const_fold2/value_handle.rs | 11 +++++++--- .../src/const_fold2/value_handle/context.rs | 21 ++++++++++++------- 4 files changed, 37 insertions(+), 16 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog/partial_value.rs b/hugr-passes/src/const_fold2/datalog/partial_value.rs index b8f9067d4..73e962287 100644 --- a/hugr-passes/src/const_fold2/datalog/partial_value.rs +++ b/hugr-passes/src/const_fold2/datalog/partial_value.rs @@ -14,8 +14,13 @@ pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { fn as_sum(&self) -> Option<(usize, impl Iterator + '_)>; } -pub trait FromSum { - fn new_sum(tag: usize, items: impl IntoIterator, st: &SumType) -> Self; +pub trait FromSum: Sized { + type Err: std::error::Error; + fn try_new_sum( + tag: usize, + items: impl IntoIterator, + st: &SumType, + ) -> Result; fn debug_check_is_type(&self, _ty: &Type) {} } @@ -120,7 +125,7 @@ impl PartialSum { .map(|(v, t)| v.clone().try_into_value(t)) .collect::, _>>() { - Ok(vs) => Ok(V2::new_sum(*k, vs, &st)), + Ok(vs) => V2::try_new_sum(*k, vs, &st).map_err(|_| self), Err(_) => Err(self), } } diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index c0594fdb7..a5bc8bfa2 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -3,7 +3,10 @@ // https://github.com/proptest-rs/proptest/issues/447 #![cfg_attr(test, allow(non_local_definitions))] -use std::{cmp::Ordering, ops::{Index, IndexMut}}; +use std::{ + cmp::Ordering, + ops::{Index, IndexMut}, +}; use ascent::lattice::{BoundedLattice, Lattice}; use itertools::zip_eq; @@ -149,8 +152,9 @@ where impl IndexMut for ValueRow where - Vec>: IndexMut { - fn index_mut(&mut self, index: Idx) -> &mut Self::Output { + Vec>: IndexMut, +{ + fn index_mut(&mut self, index: Idx) -> &mut Self::Output { self.0.index_mut(index) } } diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index 2aa1bc2fe..e2bfc1930 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use hugr_core::ops::constant::{CustomConst, Sum}; use hugr_core::ops::Value; -use hugr_core::types::Type; +use hugr_core::types::{ConstTypeError, Type}; use hugr_core::Node; use super::datalog::{AbstractValue, FromSum}; @@ -103,8 +103,13 @@ impl AbstractValue for ValueHandle { } impl FromSum for Value { - fn new_sum(tag: usize, items: impl IntoIterator, st: &hugr_core::types::SumType) -> Self { - Value::Sum(Sum {tag, values: items.into_iter().collect(), sum_type: st.clone()}) + type Err = ConstTypeError; + fn try_new_sum( + tag: usize, + items: impl IntoIterator, + st: &hugr_core::types::SumType, + ) -> Result { + Self::sum(tag, items, st.clone()) } } diff --git a/hugr-passes/src/const_fold2/value_handle/context.rs b/hugr-passes/src/const_fold2/value_handle/context.rs index 14571bb1b..7ac0b6ba5 100644 --- a/hugr-passes/src/const_fold2/value_handle/context.rs +++ b/hugr-passes/src/const_fold2/value_handle/context.rs @@ -68,7 +68,7 @@ impl DFContext for HugrValueContext { &self, n: Node, ins: &[(IncomingPort, Value)], - ) -> Vec<(OutgoingPort,ValueHandle)> { + ) -> Vec<(OutgoingPort, ValueHandle)> { match self.0.get_optype(n) { OpType::LoadConstant(load_op) => { assert!(ins.is_empty()); // static edge, so need to find constant @@ -78,14 +78,21 @@ impl DFContext for HugrValueContext { .unwrap() .0; let const_op = self.0.get_optype(const_node).as_const().unwrap(); - vec![(OutgoingPort::from(0), ValueHandle::new( - const_node.into(), - Arc::new(const_op.value().clone()), - ))] + vec![( + OutgoingPort::from(0), + ValueHandle::new(const_node.into(), Arc::new(const_op.value().clone())), + )] } OpType::CustomOp(CustomOp::Extension(op)) => { - let ins = ins.into_iter().map(|(p,v)|(*p,v.clone())).collect::>(); - op.constant_fold(&ins).map_or(Vec::new(), |outs|outs.into_iter().map(|(p,v)|(p, ValueHandle::new(ValueKey::Node(n), Arc::new(v)))).collect()) + let ins = ins + .into_iter() + .map(|(p, v)| (*p, v.clone())) + .collect::>(); + op.constant_fold(&ins).map_or(Vec::new(), |outs| { + outs.into_iter() + .map(|(p, v)| (p, ValueHandle::new(ValueKey::Node(n), Arc::new(v)))) + .collect() + }) } _ => vec![], } From 1012e0a618608d29b0ad47bbdb187c53945d96e5 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 13:50:59 +0100 Subject: [PATCH 035/203] Cargo.toml: use explicit git= tag for ascent --- hugr-passes/Cargo.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index a6ed580c3..4234f7f95 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -15,7 +15,8 @@ categories = ["compilers"] [dependencies] hugr-core = { path = "../hugr-core", version = "0.9.1" } portgraph = { workspace = true } -ascent = "0.6.0" +# This ascent commit has a fix for unsoundness in release/tag 0.6.0: +ascent = {git = "https://github.com/s-arash/ascent", rev="9805d02cb830b6e66abcd4d48836a14cd98366f3"} downcast-rs = { workspace = true } itertools = { workspace = true } lazy_static = { workspace = true } From 6bd8dbaf39f385eb73707b05d046e377dc91cae0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 14:00:33 +0100 Subject: [PATCH 036/203] Fix rebase: TryHash, UnpackTuple/MakeTuple now in prelude --- hugr-passes/src/const_fold2/datalog.rs | 5 +++-- hugr-passes/src/const_fold2/datalog/test.rs | 5 +++-- hugr-passes/src/const_fold2/value_handle.rs | 2 +- hugr-passes/src/const_fold2/value_handle/context.rs | 4 ++-- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index c6aad5840..a3256d41e 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -1,4 +1,5 @@ use ascent::lattice::BoundedLattice; +use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use std::collections::HashMap; use std::hash::Hash; @@ -162,11 +163,11 @@ fn propagate_leaf_op( // Handle basics here. I guess (given the current interface) we could allow // DFContext to handle these but at the least we'd want these impls to be // easily available for reuse. - OpType::MakeTuple(_) => Some(ValueRow::from_iter([PV::variant( + op if op.cast::().is_some() => Some(ValueRow::from_iter([PV::variant( 0, ins.into_iter().cloned(), )])), - OpType::UnpackTuple(_) => { + op if op.cast::().is_some() => { let [tup] = ins.into_iter().collect::>().try_into().unwrap(); tup.variant_values(0, utils::value_outputs(c.hugr(), n).count()) .map(ValueRow::from_iter) diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index 1e3cdcc98..0e5d0c48e 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -2,8 +2,9 @@ use crate::const_fold2::value_handle::HugrValueContext; use hugr_core::{ builder::{DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, - extension::{prelude::BOOL_T, ExtensionSet, EMPTY_REG}, - ops::{handle::NodeHandle, OpTrait, UnpackTuple, Value}, + extension::prelude::{UnpackTuple, BOOL_T}, + extension::{ExtensionSet, EMPTY_REG}, + ops::{handle::NodeHandle, OpTrait, Value}, type_row, types::{Signature, SumType, Type, TypeRow}, }; diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index e2bfc1930..e3dab5af0 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -57,7 +57,7 @@ impl ValueKey { pub fn try_new(cst: impl CustomConst) -> Option { let mut hasher = DefaultHasher::new(); - cst.maybe_hash(&mut hasher).then(|| { + cst.try_hash(&mut hasher).then(|| { Self::Const(HashedConst { hash: hasher.finish(), val: Arc::new(cst), diff --git a/hugr-passes/src/const_fold2/value_handle/context.rs b/hugr-passes/src/const_fold2/value_handle/context.rs index 7ac0b6ba5..ccb7d27bd 100644 --- a/hugr-passes/src/const_fold2/value_handle/context.rs +++ b/hugr-passes/src/const_fold2/value_handle/context.rs @@ -2,7 +2,7 @@ use std::hash::{Hash, Hasher}; use std::ops::Deref; use std::sync::Arc; -use hugr_core::ops::{CustomOp, OpType, Value}; +use hugr_core::ops::{OpType, Value}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort}; use super::{ValueHandle, ValueKey}; @@ -83,7 +83,7 @@ impl DFContext for HugrValueContext { ValueHandle::new(const_node.into(), Arc::new(const_op.value().clone())), )] } - OpType::CustomOp(CustomOp::Extension(op)) => { + OpType::ExtensionOp(op) => { let ins = ins .into_iter() .map(|(p, v)| (*p, v.clone())) From f7d288f47d74215d0e9e1d3ba59604c793006d95 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 14:40:45 +0100 Subject: [PATCH 037/203] pub ValueRow+Partial(Value/Sum); add TotalContext --- hugr-passes/src/const_fold2.rs | 1 + hugr-passes/src/const_fold2/datalog.rs | 48 ++++++------------ hugr-passes/src/const_fold2/total_context.rs | 50 +++++++++++++++++++ .../src/const_fold2/value_handle/context.rs | 7 +-- 4 files changed, 67 insertions(+), 39 deletions(-) create mode 100644 hugr-passes/src/const_fold2/total_context.rs diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index 7d6725fb1..db1d99467 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -1,2 +1,3 @@ pub mod datalog; +pub mod total_context; pub mod value_handle; diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index a3256d41e..0dc393620 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -3,25 +3,22 @@ use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use std::collections::HashMap; use std::hash::Hash; -use hugr_core::ops::{OpTrait, OpType}; -use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex, Wire}; +use hugr_core::ops::{OpType, Value}; +use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; mod partial_value; mod utils; -use utils::{TailLoopTermination, ValueRow}; +// TODO separate this into its own analysis? +use utils::TailLoopTermination; -pub use partial_value::{AbstractValue, FromSum}; +pub use partial_value::{AbstractValue, FromSum, PartialSum, PartialValue}; +pub use utils::ValueRow; type PV = partial_value::PartialValue; pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { - type InterpretableVal: FromSum + From; fn hugr(&self) -> &impl HugrView; - fn interpret_leaf_op( - &self, - node: Node, - ins: &[(IncomingPort, Self::InterpretableVal)], - ) -> Vec<(OutgoingPort, V)>; + fn interpret_leaf_op(&self, node: Node, ins: &[PartialValue]) -> Option>; } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -180,29 +177,7 @@ fn propagate_leaf_op( // It'd be nice to convert these to [(IncomingPort, Value)] to pass to the context, // thus keeping PartialValue hidden, but AbstractValues // are not necessarily convertible to Value! - op => { - let sig = op.dataflow_signature()?; - let known_ins = sig - .input_types() - .into_iter() - .enumerate() - .zip(ins.iter()) - .filter_map(|((i, ty), pv)| { - pv.clone() - .try_into_value(ty) - .ok() - .map(|v| (IncomingPort::from(i), v)) - }) - .collect::>(); - let known_outs = c.interpret_leaf_op(n, &known_ins); - (!known_outs.is_empty()).then(|| { - let mut res = ValueRow::new(sig.output_count()); - for (p, v) in known_outs { - res[p.index()] = v.into(); - } - res - }) - } + _ => c.interpret_leaf_op(n, ins).map(ValueRow::from_iter), } } @@ -265,8 +240,13 @@ impl> Machine { .find_map(|(_, cond2, case2, i)| (&cond == cond2 && &case == case2).then_some(*i)) .unwrap() } +} - pub fn read_out_wire_value(&self, hugr: impl HugrView, w: Wire) -> Option { +impl> Machine +where + Value: From, +{ + pub fn read_out_wire_value(&self, hugr: impl HugrView, w: Wire) -> Option { // dbg!(&w); let pv = self.read_out_wire_partial_value(w)?; // dbg!(&pv); diff --git a/hugr-passes/src/const_fold2/total_context.rs b/hugr-passes/src/const_fold2/total_context.rs new file mode 100644 index 000000000..69bd516a4 --- /dev/null +++ b/hugr-passes/src/const_fold2/total_context.rs @@ -0,0 +1,50 @@ +use std::hash::Hash; + +use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; + +use super::datalog::{AbstractValue, DFContext, FromSum, PartialValue, ValueRow}; + +/// A simpler interface like [DFContext] but where the context only cares about +/// values that are completely known (in the lattice `V`) +/// rather than e.g. Sums potentially of two variants each of known values. +pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { + type InterpretableVal: FromSum + From; + fn interpret_leaf_op( + &self, + node: Node, + ins: &[(IncomingPort, Self::InterpretableVal)], + ) -> Vec<(OutgoingPort, V)>; +} + +impl> DFContext for T { + fn interpret_leaf_op(&self, node: Node, ins: &[PartialValue]) -> Option> { + let op = self.get_optype(node); + let sig = op.dataflow_signature()?; + let known_ins = sig + .input_types() + .into_iter() + .enumerate() + .zip(ins.iter()) + .filter_map(|((i, ty), pv)| { + pv.clone() + .try_into_value(ty) + .ok() + .map(|v| (IncomingPort::from(i), v)) + }) + .collect::>(); + let known_outs = self.interpret_leaf_op(node, &known_ins); + (!known_outs.is_empty()).then(|| { + let mut res = ValueRow::new(sig.output_count()); + for (p, v) in known_outs { + res[p.index()] = v.into(); + } + res + }) + } + + fn hugr(&self) -> &impl HugrView { + // Adding `fn hugr(&self) -> &impl HugrView` to trait TotalContext + // and calling that here requires a lifetime bound on V, so avoid that + self.as_ref() + } +} diff --git a/hugr-passes/src/const_fold2/value_handle/context.rs b/hugr-passes/src/const_fold2/value_handle/context.rs index ccb7d27bd..b24c57aa8 100644 --- a/hugr-passes/src/const_fold2/value_handle/context.rs +++ b/hugr-passes/src/const_fold2/value_handle/context.rs @@ -6,7 +6,7 @@ use hugr_core::ops::{OpType, Value}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort}; use super::{ValueHandle, ValueKey}; -use crate::const_fold2::datalog::DFContext; +use crate::const_fold2::total_context::TotalContext; /// An implementation of [DFContext] with [ValueHandle] /// that just stores a Hugr (actually any [HugrView]), @@ -58,11 +58,8 @@ impl Deref for HugrValueContext { } } -impl DFContext for HugrValueContext { +impl TotalContext for HugrValueContext { type InterpretableVal = Value; - fn hugr(&self) -> &impl HugrView { - self.0.as_ref() - } fn interpret_leaf_op( &self, From a62eb0f5506b02065fe7723ea0e9c0796feeb6bb Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 14:48:32 +0100 Subject: [PATCH 038/203] Remove DFContext::hugr(), as_ref() does just as well --- hugr-passes/src/const_fold2/datalog.rs | 11 +++++------ hugr-passes/src/const_fold2/total_context.rs | 6 ------ 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 0dc393620..4baef4a9b 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -17,7 +17,6 @@ pub use utils::ValueRow; type PV = partial_value::PartialValue; pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { - fn hugr(&self) -> &impl HugrView; fn interpret_leaf_op(&self, node: Node, ins: &[PartialValue]) -> Option>; } @@ -43,9 +42,9 @@ ascent::ascent! { node(c, n) <-- context(c), for n in c.nodes(); - in_wire(c, n,p) <-- node(c, n), for p in utils::value_inputs(c.hugr(), *n); + in_wire(c, n,p) <-- node(c, n), for p in utils::value_inputs(c.as_ref(), *n); - out_wire(c, n,p) <-- node(c, n), for p in utils::value_outputs(c.hugr(), *n); + out_wire(c, n,p) <-- node(c, n), for p in utils::value_outputs(c.as_ref(), *n); parent_of_node(c, parent, child) <-- node(c, child), if let Some(parent) = c.get_parent(*child); @@ -64,8 +63,8 @@ ascent::ascent! { out_wire_value(c, m, op, v); - node_in_value_row(c, n, utils::bottom_row(c.hugr(), *n)) <-- node(c, n); - node_in_value_row(c, n, utils::singleton_in_row(c.hugr(), n, p, v.clone())) <-- in_wire_value(c, n, p, v); + node_in_value_row(c, n, utils::bottom_row(c.as_ref(), *n)) <-- node(c, n); + node_in_value_row(c, n, utils::singleton_in_row(c.as_ref(), n, p, v.clone())) <-- in_wire_value(c, n, p, v); out_wire_value(c, n, p, v) <-- node(c, n), @@ -166,7 +165,7 @@ fn propagate_leaf_op( )])), op if op.cast::().is_some() => { let [tup] = ins.into_iter().collect::>().try_into().unwrap(); - tup.variant_values(0, utils::value_outputs(c.hugr(), n).count()) + tup.variant_values(0, utils::value_outputs(c.as_ref(), n).count()) .map(ValueRow::from_iter) } OpType::Tag(t) => Some(ValueRow::from_iter([PV::variant( diff --git a/hugr-passes/src/const_fold2/total_context.rs b/hugr-passes/src/const_fold2/total_context.rs index 69bd516a4..2ccb7db88 100644 --- a/hugr-passes/src/const_fold2/total_context.rs +++ b/hugr-passes/src/const_fold2/total_context.rs @@ -41,10 +41,4 @@ impl> DFContext for T { res }) } - - fn hugr(&self) -> &impl HugrView { - // Adding `fn hugr(&self) -> &impl HugrView` to trait TotalContext - // and calling that here requires a lifetime bound on V, so avoid that - self.as_ref() - } } From f0ec2373fa253107cc3b6f5524d1c205cd541997 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 15:13:01 +0100 Subject: [PATCH 039/203] Move partial_value out of datalog, combine tests; move ValueRow out of utils --- hugr-passes/src/const_fold2.rs | 9 +- hugr-passes/src/const_fold2/datalog.rs | 13 +- .../const_fold2/datalog/partial_value/test.rs | 347 ----------------- hugr-passes/src/const_fold2/datalog/test.rs | 2 +- hugr-passes/src/const_fold2/datalog/utils.rs | 156 +------- .../{datalog => }/partial_value.rs | 357 +++++++++++++++++- hugr-passes/src/const_fold2/total_context.rs | 4 +- hugr-passes/src/const_fold2/value_handle.rs | 2 +- hugr-passes/src/const_fold2/value_row.rs | 117 ++++++ 9 files changed, 500 insertions(+), 507 deletions(-) delete mode 100644 hugr-passes/src/const_fold2/datalog/partial_value/test.rs rename hugr-passes/src/const_fold2/{datalog => }/partial_value.rs (54%) create mode 100644 hugr-passes/src/const_fold2/value_row.rs diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index db1d99467..b0ab62fdc 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -1,3 +1,10 @@ -pub mod datalog; +mod datalog; +pub use datalog::Machine; + +pub mod partial_value; + +mod value_row; +pub use value_row::ValueRow; + pub mod total_context; pub mod value_handle; diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/const_fold2/datalog.rs index 4baef4a9b..064aab91f 100644 --- a/hugr-passes/src/const_fold2/datalog.rs +++ b/hugr-passes/src/const_fold2/datalog.rs @@ -6,18 +6,17 @@ use std::hash::Hash; use hugr_core::ops::{OpType, Value}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; -mod partial_value; mod utils; // TODO separate this into its own analysis? use utils::TailLoopTermination; -pub use partial_value::{AbstractValue, FromSum, PartialSum, PartialValue}; -pub use utils::ValueRow; -type PV = partial_value::PartialValue; +use super::partial_value::AbstractValue; +use super::value_row::ValueRow; +type PV = super::partial_value::PartialValue; pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { - fn interpret_leaf_op(&self, node: Node, ins: &[PartialValue]) -> Option>; + fn interpret_leaf_op(&self, node: Node, ins: &[PV]) -> Option>; } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -63,8 +62,8 @@ ascent::ascent! { out_wire_value(c, m, op, v); - node_in_value_row(c, n, utils::bottom_row(c.as_ref(), *n)) <-- node(c, n); - node_in_value_row(c, n, utils::singleton_in_row(c.as_ref(), n, p, v.clone())) <-- in_wire_value(c, n, p, v); + node_in_value_row(c, n, ValueRow::new(utils::input_count(c.as_ref(), *n))) <-- node(c, n); + node_in_value_row(c, n, ValueRow::single_known(c.signature(*n).unwrap().input.len(), p.index(), v.clone())) <-- in_wire_value(c, n, p, v); out_wire_value(c, n, p, v) <-- node(c, n), diff --git a/hugr-passes/src/const_fold2/datalog/partial_value/test.rs b/hugr-passes/src/const_fold2/datalog/partial_value/test.rs deleted file mode 100644 index 33c8f3c8d..000000000 --- a/hugr-passes/src/const_fold2/datalog/partial_value/test.rs +++ /dev/null @@ -1,347 +0,0 @@ -use std::sync::Arc; - -use itertools::{zip_eq, Either, Itertools as _}; -use proptest::prelude::*; - -use hugr_core::{ - std_extensions::arithmetic::int_types::{self, ConstInt, INT_TYPES, LOG_WIDTH_BOUND}, - types::{Type, TypeArg, TypeEnum}, -}; - -use super::{PartialSum, PartialValue}; -use crate::const_fold2::value_handle::{ValueHandle, ValueKey}; - -impl Arbitrary for ValueHandle { - type Parameters = (); - type Strategy = BoxedStrategy; - fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy { - // prop_oneof![ - - // ] - todo!() - } -} - -#[derive(Debug, PartialEq, Eq, Clone)] -enum TestSumLeafType { - Int(Type), - Unit, -} - -impl TestSumLeafType { - fn assert_invariants(&self) { - match self { - Self::Int(t) => { - if let TypeEnum::Extension(ct) = t.as_type_enum() { - assert_eq!("int", ct.name()); - assert_eq!(&int_types::EXTENSION_ID, ct.extension()); - } else { - panic!("Expected int type, got {:#?}", t); - } - } - _ => (), - } - } - - fn get_type(&self) -> Type { - match self { - Self::Int(t) => t.clone(), - Self::Unit => Type::UNIT, - } - } - - fn type_check(&self, ps: &PartialSum) -> bool { - match self { - Self::Int(_) => false, - Self::Unit => { - if let Ok((0, v)) = ps.0.iter().exactly_one() { - v.is_empty() - } else { - false - } - } - } - } - - fn partial_value_strategy(self) -> impl Strategy> { - match self { - Self::Int(t) => { - let TypeEnum::Extension(ct) = t.as_type_enum() else { - unreachable!() - }; - // TODO this should be get_log_width, but that's not pub - let TypeArg::BoundedNat { n: lw } = ct.args()[0] else { - panic!() - }; - (0u64..(1 << (2u64.pow(lw as u32) - 1))) - .prop_map(move |x| { - let ki = ConstInt::new_u(lw as u8, x).unwrap(); - let k = ValueKey::try_new(ki.clone()).unwrap(); - ValueHandle::new(k, Arc::new(ki.into())).into() - }) - .boxed() - } - Self::Unit => Just(PartialSum::unit().into()).boxed(), - } - } -} - -impl Arbitrary for TestSumLeafType { - type Parameters = (); - type Strategy = BoxedStrategy; - fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy { - let int_strat = (0..LOG_WIDTH_BOUND).prop_map(|i| Self::Int(INT_TYPES[i as usize].clone())); - prop_oneof![Just(TestSumLeafType::Unit), int_strat].boxed() - } -} - -#[derive(Debug, PartialEq, Eq, Clone)] -enum TestSumType { - Branch(usize, Vec>>), - Leaf(TestSumLeafType), -} - -impl TestSumType { - const UNIT: TestSumLeafType = TestSumLeafType::Unit; - - fn leaf(v: Type) -> Self { - TestSumType::Leaf(TestSumLeafType::Int(v)) - } - - fn branch(vs: impl IntoIterator>>) -> Self { - let vec = vs.into_iter().collect_vec(); - let depth: usize = vec - .iter() - .flat_map(|x| x.iter()) - .map(|x| x.depth() + 1) - .max() - .unwrap_or(0); - Self::Branch(depth, vec) - } - - fn depth(&self) -> usize { - match self { - TestSumType::Branch(x, _) => *x, - TestSumType::Leaf(_) => 0, - } - } - - fn is_leaf(&self) -> bool { - self.depth() == 0 - } - - fn assert_invariants(&self) { - match self { - TestSumType::Branch(d, sop) => { - assert!(!sop.is_empty(), "No variants"); - for v in sop.iter().flat_map(|x| x.iter()) { - assert!(v.depth() < *d); - v.assert_invariants(); - } - } - TestSumType::Leaf(l) => { - l.assert_invariants(); - } - } - } - - fn select(self) -> impl Strategy>)>> { - match self { - TestSumType::Branch(_, sop) => any::() - .prop_map(move |i| { - let index = i.index(sop.len()); - Either::Right((index, sop[index].clone())) - }) - .boxed(), - TestSumType::Leaf(l) => Just(Either::Left(l)).boxed(), - } - } - - fn get_type(&self) -> Type { - match self { - TestSumType::Branch(_, sop) => Type::new_sum( - sop.iter() - .map(|row| row.iter().map(|x| x.get_type()).collect_vec()), - ), - TestSumType::Leaf(l) => l.get_type(), - } - } - - fn type_check(&self, pv: &PartialValue) -> bool { - match (self, pv) { - (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, - (_, PartialValue::Value(v)) => self.get_type() == v.get_type(), - (TestSumType::Branch(_, sop), PartialValue::PartialSum(ps)) => { - for (k, v) in &ps.0 { - if *k >= sop.len() { - return false; - } - let prod = &sop[*k]; - if prod.len() != v.len() { - return false; - } - if !zip_eq(prod, v).all(|(lhs, rhs)| lhs.type_check(rhs)) { - return false; - } - } - true - } - (Self::Leaf(l), PartialValue::PartialSum(ps)) => l.type_check(ps), - } - } -} - -impl From for TestSumType { - fn from(value: TestSumLeafType) -> Self { - Self::Leaf(value) - } -} - -#[derive(Clone, PartialEq, Eq, Debug)] -struct UnarySumTypeParams { - depth: usize, - branch_width: usize, -} - -impl UnarySumTypeParams { - pub fn descend(mut self, d: usize) -> Self { - assert!(d < self.depth); - self.depth = d; - self - } -} - -impl Default for UnarySumTypeParams { - fn default() -> Self { - Self { - depth: 3, - branch_width: 3, - } - } -} - -impl Arbitrary for TestSumType { - type Parameters = UnarySumTypeParams; - type Strategy = BoxedStrategy; - fn arbitrary_with( - params @ UnarySumTypeParams { - depth, - branch_width, - }: Self::Parameters, - ) -> Self::Strategy { - if depth == 0 { - any::().prop_map_into().boxed() - } else { - (0..depth) - .prop_flat_map(move |d| { - prop::collection::vec( - prop::collection::vec( - any_with::(params.clone().descend(d)).prop_map_into(), - 0..branch_width, - ), - 1..=branch_width, - ) - .prop_map(TestSumType::branch) - }) - .boxed() - } - } -} - -proptest! { - #[test] - fn unary_sum_type_valid(ust: TestSumType) { - ust.assert_invariants(); - } -} - -fn any_partial_value_of_type(ust: TestSumType) -> impl Strategy> { - ust.select().prop_flat_map(|x| match x { - Either::Left(l) => l.partial_value_strategy().boxed(), - Either::Right((index, usts)) => { - let pvs = usts - .into_iter() - .map(|x| { - any_partial_value_of_type( - Arc::::try_unwrap(x).unwrap_or_else(|x| x.as_ref().clone()), - ) - }) - .collect_vec(); - pvs.prop_map(move |pvs| PartialValue::variant(index, pvs)) - .boxed() - } - }) -} - -fn any_partial_value_with( - params: ::Parameters, -) -> impl Strategy> { - any_with::(params).prop_flat_map(any_partial_value_of_type) -} - -fn any_partial_value() -> impl Strategy> { - any_partial_value_with(Default::default()) -} - -fn any_partial_values() -> impl Strategy; N]> { - any::().prop_flat_map(|ust| { - TryInto::<[_; N]>::try_into( - (0..N) - .map(|_| any_partial_value_of_type(ust.clone())) - .collect_vec(), - ) - .unwrap() - }) -} - -fn any_typed_partial_value() -> impl Strategy)> { - any::() - .prop_flat_map(|t| any_partial_value_of_type(t.clone()).prop_map(move |v| (t.clone(), v))) -} - -proptest! { - #[test] - fn partial_value_type((tst, pv) in any_typed_partial_value()) { - prop_assert!(tst.type_check(&pv)) - } - - // todo: ValidHandle is valid - // todo: ValidHandle eq is an equivalence relation - - // todo: PartialValue PartialOrd is transitive - // todo: PartialValue eq is an equivalence relation - #[test] - fn partial_value_valid(pv in any_partial_value()) { - pv.assert_invariants(); - } - - #[test] - fn bounded_lattice(v in any_partial_value()) { - prop_assert!(v <= PartialValue::top()); - prop_assert!(v >= PartialValue::bottom()); - } - - #[test] - fn meet_join_self_noop(v1 in any_partial_value()) { - let mut subject = v1.clone(); - - assert_eq!(v1.clone(), v1.clone().join(v1.clone())); - assert!(!subject.join_mut(v1.clone())); - assert_eq!(subject, v1); - - assert_eq!(v1.clone(), v1.clone().meet(v1.clone())); - assert!(!subject.meet_mut(v1.clone())); - assert_eq!(subject, v1); - } - - #[test] - fn lattice([v1,v2] in any_partial_values()) { - let meet = v1.clone().meet(v2.clone()); - prop_assert!(meet <= v1, "meet not less <=: {:#?}", &meet); - prop_assert!(meet <= v2, "meet not less <=: {:#?}", &meet); - - let join = v1.clone().join(v2.clone()); - prop_assert!(join >= v1, "join not >=: {:#?}", &join); - prop_assert!(join >= v2, "join not >=: {:#?}", &join); - } -} diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/const_fold2/datalog/test.rs index 0e5d0c48e..42c34c451 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/const_fold2/datalog/test.rs @@ -9,7 +9,7 @@ use hugr_core::{ types::{Signature, SumType, Type, TypeRow}, }; -use super::partial_value::PartialValue; +use super::super::partial_value::PartialValue; use super::*; diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/const_fold2/datalog/utils.rs index a5bc8bfa2..ae486e280 100644 --- a/hugr-passes/src/const_fold2/datalog/utils.rs +++ b/hugr-passes/src/const_fold2/datalog/utils.rs @@ -3,20 +3,12 @@ // https://github.com/proptest-rs/proptest/issues/447 #![cfg_attr(test, allow(non_local_definitions))] -use std::{ - cmp::Ordering, - ops::{Index, IndexMut}, -}; +use std::cmp::Ordering; use ascent::lattice::{BoundedLattice, Lattice}; -use itertools::zip_eq; -use super::{partial_value::PartialValue, AbstractValue}; -use hugr_core::{ - ops::OpTrait as _, - types::{Signature, TypeRow}, - HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, -}; +use super::super::partial_value::{AbstractValue, PartialValue}; +use hugr_core::{types::Signature, HugrView, IncomingPort, Node, OutgoingPort}; #[cfg(test)] use proptest_derive::Arbitrary; @@ -49,143 +41,11 @@ impl BoundedLattice for PartialValue { } } -#[derive(PartialEq, Clone, Eq, Hash)] -pub struct ValueRow(Vec>); - -impl ValueRow { - pub fn new(len: usize) -> Self { - Self(vec![PartialValue::bottom(); len]) - } - - fn single_among_bottoms(len: usize, idx: usize, v: PartialValue) -> Self { - assert!(idx < len); - let mut r = Self::new(len); - r.0[idx] = v; - r - } - - fn bottom_from_row(r: &TypeRow) -> Self { - Self::new(r.len()) - } - - pub fn iter(&self) -> impl Iterator> { - self.0.iter() - } - - pub fn unpack_first( - &self, - variant: usize, - len: usize, - ) -> Option> + '_> { - self[0] - .variant_values(variant, len) - .map(|vals| vals.into_iter().chain(self.iter().skip(1).cloned())) - } - - // fn initialised(&self) -> bool { - // self.0.iter().all(|x| x != &PV::top()) - // } -} - -impl FromIterator> for ValueRow { - fn from_iter>>(iter: T) -> Self { - Self(iter.into_iter().collect()) - } -} - -impl PartialOrd for ValueRow { - fn partial_cmp(&self, other: &Self) -> Option { - self.0.partial_cmp(&other.0) - } -} - -impl Lattice for ValueRow { - fn meet(mut self, other: Self) -> Self { - self.meet_mut(other); - self - } - - fn join(mut self, other: Self) -> Self { - self.join_mut(other); - self - } - - fn join_mut(&mut self, other: Self) -> bool { - assert_eq!(self.0.len(), other.0.len()); - let mut changed = false; - for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { - changed |= v1.join_mut(v2); - } - changed - } - - fn meet_mut(&mut self, other: Self) -> bool { - assert_eq!(self.0.len(), other.0.len()); - let mut changed = false; - for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { - changed |= v1.meet_mut(v2); - } - changed - } -} - -impl IntoIterator for ValueRow { - type Item = PartialValue; - - type IntoIter = > as IntoIterator>::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() - } -} - -impl Index for ValueRow -where - Vec>: Index, -{ - type Output = > as Index>::Output; - - fn index(&self, index: Idx) -> &Self::Output { - self.0.index(index) - } -} - -impl IndexMut for ValueRow -where - Vec>: IndexMut, -{ - fn index_mut(&mut self, index: Idx) -> &mut Self::Output { - self.0.index_mut(index) - } -} - -pub(super) fn bottom_row(h: &impl HugrView, n: Node) -> ValueRow { - ValueRow::new( - h.signature(n) - .as_ref() - .map(Signature::input_count) - .unwrap_or(0), - ) -} - -pub(super) fn singleton_in_row( - h: &impl HugrView, - n: &Node, - ip: &IncomingPort, - v: PartialValue, -) -> ValueRow { - let Some(sig) = h.signature(*n) else { - panic!("dougrulz"); - }; - if sig.input_count() <= ip.index() { - panic!( - "bad port index: {} >= {}: {}", - ip.index(), - sig.input_count(), - h.get_optype(*n).description() - ); - } - ValueRow::single_among_bottoms(h.signature(*n).unwrap().input.len(), ip.index(), v) +pub(super) fn input_count(h: &impl HugrView, n: Node) -> usize { + h.signature(n) + .as_ref() + .map(Signature::input_count) + .unwrap_or(0) } pub(super) fn value_inputs(h: &impl HugrView, n: Node) -> impl Iterator + '_ { diff --git a/hugr-passes/src/const_fold2/datalog/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs similarity index 54% rename from hugr-passes/src/const_fold2/datalog/partial_value.rs rename to hugr-passes/src/const_fold2/partial_value.rs index 73e962287..933376027 100644 --- a/hugr-passes/src/const_fold2/datalog/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -464,4 +464,359 @@ impl PartialOrd for PartialValue { } #[cfg(test)] -mod test; +mod test { + use std::sync::Arc; + + use itertools::{zip_eq, Either, Itertools as _}; + use proptest::prelude::*; + + use hugr_core::{ + std_extensions::arithmetic::int_types::{self, ConstInt, INT_TYPES, LOG_WIDTH_BOUND}, + types::{Type, TypeArg, TypeEnum}, + }; + + use super::{PartialSum, PartialValue}; + use crate::const_fold2::value_handle::{ValueHandle, ValueKey}; + + impl Arbitrary for ValueHandle { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy { + // prop_oneof![ + + // ] + todo!() + } + } + + #[derive(Debug, PartialEq, Eq, Clone)] + enum TestSumLeafType { + Int(Type), + Unit, + } + + impl TestSumLeafType { + fn assert_invariants(&self) { + match self { + Self::Int(t) => { + if let TypeEnum::Extension(ct) = t.as_type_enum() { + assert_eq!("int", ct.name()); + assert_eq!(&int_types::EXTENSION_ID, ct.extension()); + } else { + panic!("Expected int type, got {:#?}", t); + } + } + _ => (), + } + } + + fn get_type(&self) -> Type { + match self { + Self::Int(t) => t.clone(), + Self::Unit => Type::UNIT, + } + } + + fn type_check(&self, ps: &PartialSum) -> bool { + match self { + Self::Int(_) => false, + Self::Unit => { + if let Ok((0, v)) = ps.0.iter().exactly_one() { + v.is_empty() + } else { + false + } + } + } + } + + fn partial_value_strategy(self) -> impl Strategy> { + match self { + Self::Int(t) => { + let TypeEnum::Extension(ct) = t.as_type_enum() else { + unreachable!() + }; + // TODO this should be get_log_width, but that's not pub + let TypeArg::BoundedNat { n: lw } = ct.args()[0] else { + panic!() + }; + (0u64..(1 << (2u64.pow(lw as u32) - 1))) + .prop_map(move |x| { + let ki = ConstInt::new_u(lw as u8, x).unwrap(); + let k = ValueKey::try_new(ki.clone()).unwrap(); + ValueHandle::new(k, Arc::new(ki.into())).into() + }) + .boxed() + } + Self::Unit => Just(PartialSum::unit().into()).boxed(), + } + } + } + + impl Arbitrary for TestSumLeafType { + type Parameters = (); + type Strategy = BoxedStrategy; + fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy { + let int_strat = + (0..LOG_WIDTH_BOUND).prop_map(|i| Self::Int(INT_TYPES[i as usize].clone())); + prop_oneof![Just(TestSumLeafType::Unit), int_strat].boxed() + } + } + + #[derive(Debug, PartialEq, Eq, Clone)] + enum TestSumType { + Branch(usize, Vec>>), + Leaf(TestSumLeafType), + } + + impl TestSumType { + const UNIT: TestSumLeafType = TestSumLeafType::Unit; + + fn leaf(v: Type) -> Self { + TestSumType::Leaf(TestSumLeafType::Int(v)) + } + + fn branch(vs: impl IntoIterator>>) -> Self { + let vec = vs.into_iter().collect_vec(); + let depth: usize = vec + .iter() + .flat_map(|x| x.iter()) + .map(|x| x.depth() + 1) + .max() + .unwrap_or(0); + Self::Branch(depth, vec) + } + + fn depth(&self) -> usize { + match self { + TestSumType::Branch(x, _) => *x, + TestSumType::Leaf(_) => 0, + } + } + + fn is_leaf(&self) -> bool { + self.depth() == 0 + } + + fn assert_invariants(&self) { + match self { + TestSumType::Branch(d, sop) => { + assert!(!sop.is_empty(), "No variants"); + for v in sop.iter().flat_map(|x| x.iter()) { + assert!(v.depth() < *d); + v.assert_invariants(); + } + } + TestSumType::Leaf(l) => { + l.assert_invariants(); + } + } + } + + fn select(self) -> impl Strategy>)>> { + match self { + TestSumType::Branch(_, sop) => any::() + .prop_map(move |i| { + let index = i.index(sop.len()); + Either::Right((index, sop[index].clone())) + }) + .boxed(), + TestSumType::Leaf(l) => Just(Either::Left(l)).boxed(), + } + } + + fn get_type(&self) -> Type { + match self { + TestSumType::Branch(_, sop) => Type::new_sum( + sop.iter() + .map(|row| row.iter().map(|x| x.get_type()).collect_vec()), + ), + TestSumType::Leaf(l) => l.get_type(), + } + } + + fn type_check(&self, pv: &PartialValue) -> bool { + match (self, pv) { + (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, + (_, PartialValue::Value(v)) => self.get_type() == v.get_type(), + (TestSumType::Branch(_, sop), PartialValue::PartialSum(ps)) => { + for (k, v) in &ps.0 { + if *k >= sop.len() { + return false; + } + let prod = &sop[*k]; + if prod.len() != v.len() { + return false; + } + if !zip_eq(prod, v).all(|(lhs, rhs)| lhs.type_check(rhs)) { + return false; + } + } + true + } + (Self::Leaf(l), PartialValue::PartialSum(ps)) => l.type_check(ps), + } + } + } + + impl From for TestSumType { + fn from(value: TestSumLeafType) -> Self { + Self::Leaf(value) + } + } + + #[derive(Clone, PartialEq, Eq, Debug)] + struct UnarySumTypeParams { + depth: usize, + branch_width: usize, + } + + impl UnarySumTypeParams { + pub fn descend(mut self, d: usize) -> Self { + assert!(d < self.depth); + self.depth = d; + self + } + } + + impl Default for UnarySumTypeParams { + fn default() -> Self { + Self { + depth: 3, + branch_width: 3, + } + } + } + + impl Arbitrary for TestSumType { + type Parameters = UnarySumTypeParams; + type Strategy = BoxedStrategy; + fn arbitrary_with( + params @ UnarySumTypeParams { + depth, + branch_width, + }: Self::Parameters, + ) -> Self::Strategy { + if depth == 0 { + any::().prop_map_into().boxed() + } else { + (0..depth) + .prop_flat_map(move |d| { + prop::collection::vec( + prop::collection::vec( + any_with::(params.clone().descend(d)).prop_map_into(), + 0..branch_width, + ), + 1..=branch_width, + ) + .prop_map(TestSumType::branch) + }) + .boxed() + } + } + } + + proptest! { + #[test] + fn unary_sum_type_valid(ust: TestSumType) { + ust.assert_invariants(); + } + } + + fn any_partial_value_of_type( + ust: TestSumType, + ) -> impl Strategy> { + ust.select().prop_flat_map(|x| match x { + Either::Left(l) => l.partial_value_strategy().boxed(), + Either::Right((index, usts)) => { + let pvs = usts + .into_iter() + .map(|x| { + any_partial_value_of_type( + Arc::::try_unwrap(x) + .unwrap_or_else(|x| x.as_ref().clone()), + ) + }) + .collect_vec(); + pvs.prop_map(move |pvs| PartialValue::variant(index, pvs)) + .boxed() + } + }) + } + + fn any_partial_value_with( + params: ::Parameters, + ) -> impl Strategy> { + any_with::(params).prop_flat_map(any_partial_value_of_type) + } + + fn any_partial_value() -> impl Strategy> { + any_partial_value_with(Default::default()) + } + + fn any_partial_values() -> impl Strategy; N]> + { + any::().prop_flat_map(|ust| { + TryInto::<[_; N]>::try_into( + (0..N) + .map(|_| any_partial_value_of_type(ust.clone())) + .collect_vec(), + ) + .unwrap() + }) + } + + fn any_typed_partial_value() -> impl Strategy)> + { + any::().prop_flat_map(|t| { + any_partial_value_of_type(t.clone()).prop_map(move |v| (t.clone(), v)) + }) + } + + proptest! { + #[test] + fn partial_value_type((tst, pv) in any_typed_partial_value()) { + prop_assert!(tst.type_check(&pv)) + } + + // todo: ValidHandle is valid + // todo: ValidHandle eq is an equivalence relation + + // todo: PartialValue PartialOrd is transitive + // todo: PartialValue eq is an equivalence relation + #[test] + fn partial_value_valid(pv in any_partial_value()) { + pv.assert_invariants(); + } + + #[test] + fn bounded_lattice(v in any_partial_value()) { + prop_assert!(v <= PartialValue::top()); + prop_assert!(v >= PartialValue::bottom()); + } + + #[test] + fn meet_join_self_noop(v1 in any_partial_value()) { + let mut subject = v1.clone(); + + assert_eq!(v1.clone(), v1.clone().join(v1.clone())); + assert!(!subject.join_mut(v1.clone())); + assert_eq!(subject, v1); + + assert_eq!(v1.clone(), v1.clone().meet(v1.clone())); + assert!(!subject.meet_mut(v1.clone())); + assert_eq!(subject, v1); + } + + #[test] + fn lattice([v1,v2] in any_partial_values()) { + let meet = v1.clone().meet(v2.clone()); + prop_assert!(meet <= v1, "meet not less <=: {:#?}", &meet); + prop_assert!(meet <= v2, "meet not less <=: {:#?}", &meet); + + let join = v1.clone().join(v2.clone()); + prop_assert!(join >= v1, "join not >=: {:#?}", &join); + prop_assert!(join >= v2, "join not >=: {:#?}", &join); + } + } +} diff --git a/hugr-passes/src/const_fold2/total_context.rs b/hugr-passes/src/const_fold2/total_context.rs index 2ccb7db88..63a7b4965 100644 --- a/hugr-passes/src/const_fold2/total_context.rs +++ b/hugr-passes/src/const_fold2/total_context.rs @@ -2,7 +2,9 @@ use std::hash::Hash; use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; -use super::datalog::{AbstractValue, DFContext, FromSum, PartialValue, ValueRow}; +use super::datalog::DFContext; +use super::partial_value::{AbstractValue, FromSum, PartialValue}; +use super::ValueRow; /// A simpler interface like [DFContext] but where the context only cares about /// values that are completely known (in the lattice `V`) diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index e3dab5af0..3d26d5ac8 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -6,7 +6,7 @@ use hugr_core::ops::Value; use hugr_core::types::{ConstTypeError, Type}; use hugr_core::Node; -use super::datalog::{AbstractValue, FromSum}; +use super::partial_value::{AbstractValue, FromSum}; mod context; pub use context::HugrValueContext; diff --git a/hugr-passes/src/const_fold2/value_row.rs b/hugr-passes/src/const_fold2/value_row.rs new file mode 100644 index 000000000..91b45052c --- /dev/null +++ b/hugr-passes/src/const_fold2/value_row.rs @@ -0,0 +1,117 @@ +// Really this is part of partial_value.rs + +use std::{ + cmp::Ordering, + ops::{Index, IndexMut}, +}; + +use ascent::lattice::Lattice; +use itertools::zip_eq; + +use super::partial_value::{AbstractValue, PartialValue}; + +#[derive(PartialEq, Clone, Eq, Hash)] +pub struct ValueRow(Vec>); + +impl ValueRow { + pub fn new(len: usize) -> Self { + Self(vec![PartialValue::bottom(); len]) + } + + pub fn single_known(len: usize, idx: usize, v: PartialValue) -> Self { + assert!(idx < len); + let mut r = Self::new(len); + r.0[idx] = v; + r + } + + pub fn iter(&self) -> impl Iterator> { + self.0.iter() + } + + pub fn unpack_first( + &self, + variant: usize, + len: usize, + ) -> Option> + '_> { + self[0] + .variant_values(variant, len) + .map(|vals| vals.into_iter().chain(self.iter().skip(1).cloned())) + } + + // fn initialised(&self) -> bool { + // self.0.iter().all(|x| x != &PV::top()) + // } +} + +impl FromIterator> for ValueRow { + fn from_iter>>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +impl PartialOrd for ValueRow { + fn partial_cmp(&self, other: &Self) -> Option { + self.0.partial_cmp(&other.0) + } +} + +impl Lattice for ValueRow { + fn meet(mut self, other: Self) -> Self { + self.meet_mut(other); + self + } + + fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + fn join_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.join_mut(v2); + } + changed + } + + fn meet_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.meet_mut(v2); + } + changed + } +} + +impl IntoIterator for ValueRow { + type Item = PartialValue; + + type IntoIter = > as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl Index for ValueRow +where + Vec>: Index, +{ + type Output = > as Index>::Output; + + fn index(&self, index: Idx) -> &Self::Output { + self.0.index(index) + } +} + +impl IndexMut for ValueRow +where + Vec>: IndexMut, +{ + fn index_mut(&mut self, index: Idx) -> &mut Self::Output { + self.0.index_mut(index) + } +} From 82a3f22d256e41ac2a4e86c3dd4948695242bb3c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 15:21:44 +0100 Subject: [PATCH 040/203] Move FromSum and try_into_value into total_context.rs --- hugr-passes/src/const_fold2/partial_value.rs | 49 +------------- hugr-passes/src/const_fold2/total_context.rs | 68 +++++++++++++++++++- hugr-passes/src/const_fold2/value_handle.rs | 15 +---- 3 files changed, 70 insertions(+), 62 deletions(-) diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/const_fold2/partial_value.rs index 933376027..3c20ff965 100644 --- a/hugr-passes/src/const_fold2/partial_value.rs +++ b/hugr-passes/src/const_fold2/partial_value.rs @@ -1,7 +1,6 @@ #![allow(missing_docs)] -use hugr_core::types::{SumType, Type, TypeEnum, TypeRow}; -use itertools::{zip_eq, Itertools}; +use itertools::zip_eq; use std::cmp::Ordering; use std::collections::HashMap; use std::hash::{Hash, Hasher}; @@ -14,16 +13,6 @@ pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { fn as_sum(&self) -> Option<(usize, impl Iterator + '_)>; } -pub trait FromSum: Sized { - type Err: std::error::Error; - fn try_new_sum( - tag: usize, - items: impl IntoIterator, - st: &SumType, - ) -> Result; - fn debug_check_is_type(&self, _ty: &Type) {} -} - // TODO ALAN inline into PartialValue? Has to be public as it's in a pub enum #[derive(PartialEq, Clone, Eq)] pub struct PartialSum(pub HashMap>>); @@ -103,32 +92,6 @@ impl PartialSum { pub fn supports_tag(&self, tag: usize) -> bool { self.0.contains_key(&tag) } - - pub fn try_into_value>(self, typ: &Type) -> Result { - let Ok((k, v)) = self.0.iter().exactly_one() else { - Err(self)? - }; - - let TypeEnum::Sum(st) = typ.as_type_enum() else { - Err(self)? - }; - let Some(r) = st.get_variant(*k) else { - Err(self)? - }; - let Ok(r): Result = r.clone().try_into() else { - Err(self)? - }; - if v.len() != r.len() { - return Err(self); - } - match zip_eq(v.into_iter(), r.into_iter()) - .map(|(v, t)| v.clone().try_into_value(t)) - .collect::, _>>() - { - Ok(vs) => V2::try_new_sum(*k, vs, &st).map_err(|_| self), - Err(_) => Err(self), - } - } } impl PartialSum { @@ -434,16 +397,6 @@ impl PartialValue { PartialValue::Top => true, } } - - pub fn try_into_value>(self, typ: &Type) -> Result { - let r: V2 = match self { - Self::Value(v) => Ok(v.clone().into()), - Self::PartialSum(ps) => ps.try_into_value(typ).map_err(Self::PartialSum), - x => Err(x), - }?; - r.debug_check_is_type(typ); - Ok(r) - } } impl PartialOrd for PartialValue { diff --git a/hugr-passes/src/const_fold2/total_context.rs b/hugr-passes/src/const_fold2/total_context.rs index 63a7b4965..8caafda52 100644 --- a/hugr-passes/src/const_fold2/total_context.rs +++ b/hugr-passes/src/const_fold2/total_context.rs @@ -1,11 +1,24 @@ use std::hash::Hash; +use hugr_core::ops::Value; +use hugr_core::types::{ConstTypeError, SumType, Type, TypeEnum, TypeRow}; use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; +use itertools::{zip_eq, Itertools}; use super::datalog::DFContext; -use super::partial_value::{AbstractValue, FromSum, PartialValue}; +use super::partial_value::{AbstractValue, PartialSum, PartialValue}; use super::ValueRow; +pub trait FromSum: Sized { + type Err: std::error::Error; + fn try_new_sum( + tag: usize, + items: impl IntoIterator, + st: &SumType, + ) -> Result; + fn debug_check_is_type(&self, _ty: &Type) {} +} + /// A simpler interface like [DFContext] but where the context only cares about /// values that are completely known (in the lattice `V`) /// rather than e.g. Sums potentially of two variants each of known values. @@ -18,6 +31,59 @@ pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { ) -> Vec<(OutgoingPort, V)>; } +impl FromSum for Value { + type Err = ConstTypeError; + fn try_new_sum( + tag: usize, + items: impl IntoIterator, + st: &hugr_core::types::SumType, + ) -> Result { + Self::sum(tag, items, st.clone()) + } +} + +// These are here because they rely on FromSum, that they are `impl PartialSum/Value` +// is merely a nice syntax. +impl PartialValue { + pub fn try_into_value>(self, typ: &Type) -> Result { + let r: V2 = match self { + Self::Value(v) => Ok(v.clone().into()), + Self::PartialSum(ps) => ps.try_into_value(typ).map_err(Self::PartialSum), + x => Err(x), + }?; + r.debug_check_is_type(typ); + Ok(r) + } +} + +impl PartialSum { + pub fn try_into_value>(self, typ: &Type) -> Result { + let Ok((k, v)) = self.0.iter().exactly_one() else { + Err(self)? + }; + + let TypeEnum::Sum(st) = typ.as_type_enum() else { + Err(self)? + }; + let Some(r) = st.get_variant(*k) else { + Err(self)? + }; + let Ok(r): Result = r.clone().try_into() else { + Err(self)? + }; + if v.len() != r.len() { + return Err(self); + } + match zip_eq(v.into_iter(), r.into_iter()) + .map(|(v, t)| v.clone().try_into_value(t)) + .collect::, _>>() + { + Ok(vs) => V2::try_new_sum(*k, vs, &st).map_err(|_| self), + Err(_) => Err(self), + } + } +} + impl> DFContext for T { fn interpret_leaf_op(&self, node: Node, ins: &[PartialValue]) -> Option> { let op = self.get_optype(node); diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index 3d26d5ac8..137699763 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -3,10 +3,10 @@ use std::sync::Arc; use hugr_core::ops::constant::{CustomConst, Sum}; use hugr_core::ops::Value; -use hugr_core::types::{ConstTypeError, Type}; +use hugr_core::types::Type; use hugr_core::Node; -use super::partial_value::{AbstractValue, FromSum}; +use super::partial_value::AbstractValue; mod context; pub use context::HugrValueContext; @@ -102,17 +102,6 @@ impl AbstractValue for ValueHandle { } } -impl FromSum for Value { - type Err = ConstTypeError; - fn try_new_sum( - tag: usize, - items: impl IntoIterator, - st: &hugr_core::types::SumType, - ) -> Result { - Self::sum(tag, items, st.clone()) - } -} - impl PartialEq for ValueHandle { fn eq(&self, other: &Self) -> bool { // If the keys are equal, we return true since the values must have the From e6dc114b87fe97cbe45cbe88b17399d4c594a17b Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 15:41:48 +0100 Subject: [PATCH 041/203] Separate mod dataflow from mod const_fold2 --- hugr-passes/src/const_fold2.rs | 13 +++++-------- .../src/const_fold2/{value_handle => }/context.rs | 4 ++-- hugr-passes/src/const_fold2/value_handle.rs | 5 +---- hugr-passes/src/dataflow.rs | 13 +++++++++++++ .../src/{const_fold2 => dataflow}/datalog.rs | 0 .../src/{const_fold2 => dataflow}/datalog/test.rs | 2 +- .../src/{const_fold2 => dataflow}/datalog/utils.rs | 0 .../src/{const_fold2 => dataflow}/partial_value.rs | 0 .../src/{const_fold2 => dataflow}/total_context.rs | 0 .../src/{const_fold2 => dataflow}/value_row.rs | 0 hugr-passes/src/lib.rs | 1 + 11 files changed, 23 insertions(+), 15 deletions(-) rename hugr-passes/src/const_fold2/{value_handle => }/context.rs (97%) create mode 100644 hugr-passes/src/dataflow.rs rename hugr-passes/src/{const_fold2 => dataflow}/datalog.rs (100%) rename hugr-passes/src/{const_fold2 => dataflow}/datalog/test.rs (99%) rename hugr-passes/src/{const_fold2 => dataflow}/datalog/utils.rs (100%) rename hugr-passes/src/{const_fold2 => dataflow}/partial_value.rs (100%) rename hugr-passes/src/{const_fold2 => dataflow}/total_context.rs (100%) rename hugr-passes/src/{const_fold2 => dataflow}/value_row.rs (100%) diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index b0ab62fdc..1fa3498e0 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -1,10 +1,7 @@ -mod datalog; -pub use datalog::Machine; +//! An (example) use of the [super::dataflow](dataflow-analysis framework) +//! to perform constant-folding. -pub mod partial_value; - -mod value_row; -pub use value_row::ValueRow; - -pub mod total_context; +// These are pub because this "example" is used for testing the framework. pub mod value_handle; +mod context; +pub use context::HugrValueContext; diff --git a/hugr-passes/src/const_fold2/value_handle/context.rs b/hugr-passes/src/const_fold2/context.rs similarity index 97% rename from hugr-passes/src/const_fold2/value_handle/context.rs rename to hugr-passes/src/const_fold2/context.rs index b24c57aa8..007fcfa92 100644 --- a/hugr-passes/src/const_fold2/value_handle/context.rs +++ b/hugr-passes/src/const_fold2/context.rs @@ -5,8 +5,8 @@ use std::sync::Arc; use hugr_core::ops::{OpType, Value}; use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort}; -use super::{ValueHandle, ValueKey}; -use crate::const_fold2::total_context::TotalContext; +use super::value_handle::{ValueHandle, ValueKey}; +use crate::dataflow::TotalContext; /// An implementation of [DFContext] with [ValueHandle] /// that just stores a Hugr (actually any [HugrView]), diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index 137699763..4bacef114 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -6,10 +6,7 @@ use hugr_core::ops::Value; use hugr_core::types::Type; use hugr_core::Node; -use super::partial_value::AbstractValue; - -mod context; -pub use context::HugrValueContext; +use crate::dataflow::AbstractValue; #[derive(Clone, Debug)] pub struct HashedConst { diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs new file mode 100644 index 000000000..ec00e7f9a --- /dev/null +++ b/hugr-passes/src/dataflow.rs @@ -0,0 +1,13 @@ +//! Dataflow analysis of Hugrs. + +mod datalog; +pub use datalog::Machine; + +mod partial_value; +pub use partial_value::{PartialValue, AbstractValue}; + +mod value_row; +pub use value_row::ValueRow; + +mod total_context; +pub use total_context::TotalContext; diff --git a/hugr-passes/src/const_fold2/datalog.rs b/hugr-passes/src/dataflow/datalog.rs similarity index 100% rename from hugr-passes/src/const_fold2/datalog.rs rename to hugr-passes/src/dataflow/datalog.rs diff --git a/hugr-passes/src/const_fold2/datalog/test.rs b/hugr-passes/src/dataflow/datalog/test.rs similarity index 99% rename from hugr-passes/src/const_fold2/datalog/test.rs rename to hugr-passes/src/dataflow/datalog/test.rs index 42c34c451..4f3a3b187 100644 --- a/hugr-passes/src/const_fold2/datalog/test.rs +++ b/hugr-passes/src/dataflow/datalog/test.rs @@ -1,4 +1,4 @@ -use crate::const_fold2::value_handle::HugrValueContext; +use crate::const_fold2::HugrValueContext; use hugr_core::{ builder::{DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, diff --git a/hugr-passes/src/const_fold2/datalog/utils.rs b/hugr-passes/src/dataflow/datalog/utils.rs similarity index 100% rename from hugr-passes/src/const_fold2/datalog/utils.rs rename to hugr-passes/src/dataflow/datalog/utils.rs diff --git a/hugr-passes/src/const_fold2/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs similarity index 100% rename from hugr-passes/src/const_fold2/partial_value.rs rename to hugr-passes/src/dataflow/partial_value.rs diff --git a/hugr-passes/src/const_fold2/total_context.rs b/hugr-passes/src/dataflow/total_context.rs similarity index 100% rename from hugr-passes/src/const_fold2/total_context.rs rename to hugr-passes/src/dataflow/total_context.rs diff --git a/hugr-passes/src/const_fold2/value_row.rs b/hugr-passes/src/dataflow/value_row.rs similarity index 100% rename from hugr-passes/src/const_fold2/value_row.rs rename to hugr-passes/src/dataflow/value_row.rs diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 8949d8bd4..9bf576a5e 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,6 +1,7 @@ //! Compilation passes acting on the HUGR program representation. pub mod const_fold; +pub mod dataflow; pub mod const_fold2; pub mod force_order; mod half_node; From 03ff1658a62805ecbd0e2654ccc38035924d41d9 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 15:44:49 +0100 Subject: [PATCH 042/203] fmt --- hugr-passes/src/const_fold2.rs | 2 +- hugr-passes/src/dataflow.rs | 2 +- hugr-passes/src/lib.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index 1fa3498e0..93b772d88 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -2,6 +2,6 @@ //! to perform constant-folding. // These are pub because this "example" is used for testing the framework. -pub mod value_handle; mod context; +pub mod value_handle; pub use context::HugrValueContext; diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index ec00e7f9a..8c0ad1d8c 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -4,7 +4,7 @@ mod datalog; pub use datalog::Machine; mod partial_value; -pub use partial_value::{PartialValue, AbstractValue}; +pub use partial_value::{AbstractValue, PartialValue}; mod value_row; pub use value_row::ValueRow; diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 9bf576a5e..0b73fcbb0 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,8 +1,8 @@ //! Compilation passes acting on the HUGR program representation. pub mod const_fold; -pub mod dataflow; pub mod const_fold2; +pub mod dataflow; pub mod force_order; mod half_node; pub mod lower; From 25ed1fbb4efd7759ae07bf85827e2ef02c65fa61 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 16:11:14 +0100 Subject: [PATCH 043/203] TailLoopTermination just examine whatever PartialValue's we have, remove most --- hugr-passes/src/dataflow/datalog.rs | 23 ++--- hugr-passes/src/dataflow/datalog/test.rs | 2 +- hugr-passes/src/dataflow/datalog/utils.rs | 120 +--------------------- 3 files changed, 11 insertions(+), 134 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 064aab91f..c1655d280 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -1,4 +1,3 @@ -use ascent::lattice::BoundedLattice; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use std::collections::HashMap; use std::hash::Hash; @@ -108,15 +107,6 @@ ascent::ascent! { if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields); - lattice tail_loop_termination(C,Node,TailLoopTermination); - tail_loop_termination(c,tl_n,TailLoopTermination::bottom()) <-- - tail_loop_node(c,tl_n); - tail_loop_termination(c,tl_n,TailLoopTermination::from_control_value(v)) <-- - tail_loop_node(c,tl_n), - io_node(c,tl,out_n, IO::Output), - in_wire_value(c, out_n, IncomingPort::from(0), v); - - // Conditional relation conditional_node(C, Node); relation case_node(C,Node,usize, Node); @@ -221,11 +211,14 @@ impl> Machine { pub fn tail_loop_terminates(&self, hugr: impl HugrView, node: Node) -> TailLoopTermination { assert!(hugr.get_optype(node).is_tail_loop()); - self.0 - .tail_loop_termination - .iter() - .find_map(|(_, n, v)| (n == &node).then_some(*v)) - .unwrap() + let [_, out] = hugr.get_io(node).unwrap(); + TailLoopTermination::from_control_value( + self.0 + .in_wire_value + .iter() + .find_map(|(_, n, p, v)| (*n == out && p.index() == 0).then_some(v)) + .unwrap(), + ) } pub fn case_reachable(&self, hugr: impl HugrView, case: Node) -> bool { diff --git a/hugr-passes/src/dataflow/datalog/test.rs b/hugr-passes/src/dataflow/datalog/test.rs index 4f3a3b187..e9e61fb9e 100644 --- a/hugr-passes/src/dataflow/datalog/test.rs +++ b/hugr-passes/src/dataflow/datalog/test.rs @@ -125,7 +125,7 @@ fn test_tail_loop_always_iterates() { let o_r2 = machine.read_out_wire_partial_value(tl_o2).unwrap(); assert_eq!(o_r2, PartialValue::bottom().into()); assert_eq!( - TailLoopTermination::bottom(), + TailLoopTermination::Bottom, machine.tail_loop_terminates(&hugr, tail_loop.node()) ) } diff --git a/hugr-passes/src/dataflow/datalog/utils.rs b/hugr-passes/src/dataflow/datalog/utils.rs index ae486e280..117d58628 100644 --- a/hugr-passes/src/dataflow/datalog/utils.rs +++ b/hugr-passes/src/dataflow/datalog/utils.rs @@ -1,18 +1,8 @@ -// proptest-derive generates many of these warnings. -// https://github.com/rust-lang/rust/issues/120363 -// https://github.com/proptest-rs/proptest/issues/447 -#![cfg_attr(test, allow(non_local_definitions))] - -use std::cmp::Ordering; - use ascent::lattice::{BoundedLattice, Lattice}; use super::super::partial_value::{AbstractValue, PartialValue}; use hugr_core::{types::Signature, HugrView, IncomingPort, Node, OutgoingPort}; -#[cfg(test)] -use proptest_derive::Arbitrary; - impl Lattice for PartialValue { fn meet(self, other: Self) -> Self { self.meet(other) @@ -57,7 +47,6 @@ pub(super) fn value_outputs(h: &impl HugrView, n: Node) -> impl Iterator Option { - if self == other { - return Some(std::cmp::Ordering::Equal); - }; - match (self, other) { - (Self::Bottom, _) => Some(Ordering::Less), - (_, Self::Bottom) => Some(Ordering::Greater), - (Self::Top, _) => Some(Ordering::Greater), - (_, Self::Top) => Some(Ordering::Less), - _ => None, - } - } -} - -impl Lattice for TailLoopTermination { - fn meet(mut self, other: Self) -> Self { - self.meet_mut(other); - self - } - - fn join(mut self, other: Self) -> Self { - self.join_mut(other); - self - } - - fn meet_mut(&mut self, other: Self) -> bool { - // let new_self = &mut self; - match (*self).partial_cmp(&other) { - Some(Ordering::Greater) => { - *self = other; - true - } - Some(_) => false, - _ => { - *self = Self::Bottom; - true - } - } - } - - fn join_mut(&mut self, other: Self) -> bool { - match (*self).partial_cmp(&other) { - Some(Ordering::Less) => { - *self = other; - true - } - Some(_) => false, - _ => { - *self = Self::Top; - true - } - } - } -} - -impl BoundedLattice for TailLoopTermination { - fn bottom() -> Self { - Self::Bottom - } - - fn top() -> Self { - Self::Top - } -} - -#[cfg(test)] -#[cfg_attr(test, allow(non_local_definitions))] -mod test { - use super::*; - use proptest::prelude::*; - - proptest! { - #[test] - fn bounded_lattice(v: TailLoopTermination) { - prop_assert!(v <= TailLoopTermination::top()); - prop_assert!(v >= TailLoopTermination::bottom()); - } - - #[test] - fn meet_join_self_noop(v1: TailLoopTermination) { - let mut subject = v1.clone(); - - assert_eq!(v1.clone(), v1.clone().join(v1.clone())); - assert!(!subject.join_mut(v1.clone())); - assert_eq!(subject, v1); - - assert_eq!(v1.clone(), v1.clone().meet(v1.clone())); - assert!(!subject.meet_mut(v1.clone())); - assert_eq!(subject, v1); - } - - #[test] - fn lattice(v1: TailLoopTermination, v2: TailLoopTermination) { - let meet = v1.clone().meet(v2.clone()); - prop_assert!(meet <= v1, "meet not less <=: {:#?}", &meet); - prop_assert!(meet <= v2, "meet not less <=: {:#?}", &meet); - - let join = v1.clone().join(v2.clone()); - prop_assert!(join >= v1, "join not >=: {:#?}", &join); - prop_assert!(join >= v2, "join not >=: {:#?}", &join); + Self::Bottom } } } From 09911df3001979314ec1b6aade2356c9ed9e4c2f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 16:21:47 +0100 Subject: [PATCH 044/203] Drop non-(Bounded)Lattice impls of (join/meet)(_mut),top,bottom --- hugr-passes/src/dataflow/datalog.rs | 1 + hugr-passes/src/dataflow/datalog/test.rs | 1 + hugr-passes/src/dataflow/datalog/utils.rs | 30 -------- hugr-passes/src/dataflow/partial_value.rs | 92 +++++++++++------------ hugr-passes/src/dataflow/value_row.rs | 2 +- 5 files changed, 48 insertions(+), 78 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index c1655d280..d0983ee6c 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -1,3 +1,4 @@ +use ascent::lattice::BoundedLattice; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use std::collections::HashMap; use std::hash::Hash; diff --git a/hugr-passes/src/dataflow/datalog/test.rs b/hugr-passes/src/dataflow/datalog/test.rs index e9e61fb9e..c6c4eda4e 100644 --- a/hugr-passes/src/dataflow/datalog/test.rs +++ b/hugr-passes/src/dataflow/datalog/test.rs @@ -1,5 +1,6 @@ use crate::const_fold2::HugrValueContext; +use ascent::lattice::BoundedLattice; use hugr_core::{ builder::{DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::prelude::{UnpackTuple, BOOL_T}, diff --git a/hugr-passes/src/dataflow/datalog/utils.rs b/hugr-passes/src/dataflow/datalog/utils.rs index 117d58628..4d3a056c6 100644 --- a/hugr-passes/src/dataflow/datalog/utils.rs +++ b/hugr-passes/src/dataflow/datalog/utils.rs @@ -1,36 +1,6 @@ -use ascent::lattice::{BoundedLattice, Lattice}; - use super::super::partial_value::{AbstractValue, PartialValue}; use hugr_core::{types::Signature, HugrView, IncomingPort, Node, OutgoingPort}; -impl Lattice for PartialValue { - fn meet(self, other: Self) -> Self { - self.meet(other) - } - - fn meet_mut(&mut self, other: Self) -> bool { - self.meet_mut(other) - } - - fn join(self, other: Self) -> Self { - self.join(other) - } - - fn join_mut(&mut self, other: Self) -> bool { - self.join_mut(other) - } -} - -impl BoundedLattice for PartialValue { - fn bottom() -> Self { - Self::bottom() - } - - fn top() -> Self { - Self::top() - } -} - pub(super) fn input_count(h: &impl HugrView, n: Node) -> usize { h.signature(n) .as_ref() diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 3c20ff965..670a7e38c 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -1,5 +1,7 @@ #![allow(missing_docs)] +use ascent::lattice::BoundedLattice; +use ascent::Lattice; use itertools::zip_eq; use std::cmp::Ordering; use std::collections::HashMap; @@ -176,15 +178,6 @@ impl From> for PartialValue { impl PartialValue { // const BOTTOM: Self = Self::Bottom; // const BOTTOM_REF: &'static Self = &Self::BOTTOM; - - // fn initialised(&self) -> bool { - // !self.is_top() - // } - - // fn is_top(&self) -> bool { - // self == &PartialValue::Top - // } - fn assert_invariants(&self) { match self { Self::PartialSum(ps) => { @@ -249,7 +242,42 @@ impl PartialValue { self } - pub fn join_mut(&mut self, other: Self) -> bool { + pub fn variant(tag: usize, values: impl IntoIterator) -> Self { + PartialSum::variant(tag, values).into() + } + + pub fn unit() -> Self { + Self::variant(0, []) + } + + pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { + let vals = match self { + PartialValue::Bottom => return None, + PartialValue::Value(v) => v + .as_sum() + .filter(|(variant, _)| tag == *variant)? + .1 + .map(PartialValue::Value) + .collect(), + PartialValue::PartialSum(ps) => ps.variant_values(tag, len)?, + PartialValue::Top => vec![PartialValue::Top; len], + }; + assert_eq!(vals.len(), len); + Some(vals) + } + + pub fn supports_tag(&self, tag: usize) -> bool { + match self { + PartialValue::Bottom => false, + PartialValue::Value(v) => v.as_sum().is_some_and(|(v, _)| v == tag), + PartialValue::PartialSum(ps) => ps.supports_tag(tag), + PartialValue::Top => true, + } + } +} + +impl Lattice for PartialValue { + fn join_mut(&mut self, other: Self) -> bool { // println!("join {self:?}\n{:?}", &other); let changed = match (&*self, other) { (Self::Top, _) => false, @@ -301,12 +329,12 @@ impl PartialValue { changed } - pub fn meet(mut self, other: Self) -> Self { + fn meet(mut self, other: Self) -> Self { self.meet_mut(other); self } - pub fn meet_mut(&mut self, other: Self) -> bool { + fn meet_mut(&mut self, other: Self) -> bool { let changed = match (&*self, other) { (Self::Bottom, _) => false, (_, other @ Self::Bottom) => { @@ -356,47 +384,16 @@ impl PartialValue { // } changed } +} - pub fn top() -> Self { +impl BoundedLattice for PartialValue { + fn top() -> Self { Self::Top } - pub fn bottom() -> Self { + fn bottom() -> Self { Self::Bottom } - - pub fn variant(tag: usize, values: impl IntoIterator) -> Self { - PartialSum::variant(tag, values).into() - } - - pub fn unit() -> Self { - Self::variant(0, []) - } - - pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { - let vals = match self { - PartialValue::Bottom => return None, - PartialValue::Value(v) => v - .as_sum() - .filter(|(variant, _)| tag == *variant)? - .1 - .map(PartialValue::Value) - .collect(), - PartialValue::PartialSum(ps) => ps.variant_values(tag, len)?, - PartialValue::Top => vec![PartialValue::Top; len], - }; - assert_eq!(vals.len(), len); - Some(vals) - } - - pub fn supports_tag(&self, tag: usize) -> bool { - match self { - PartialValue::Bottom => false, - PartialValue::Value(v) => v.as_sum().is_some_and(|(v, _)| v == tag), - PartialValue::PartialSum(ps) => ps.supports_tag(tag), - PartialValue::Top => true, - } - } } impl PartialOrd for PartialValue { @@ -420,6 +417,7 @@ impl PartialOrd for PartialValue { mod test { use std::sync::Arc; + use ascent::{lattice::BoundedLattice, Lattice}; use itertools::{zip_eq, Either, Itertools as _}; use proptest::prelude::*; diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs index 91b45052c..9f7b8bef7 100644 --- a/hugr-passes/src/dataflow/value_row.rs +++ b/hugr-passes/src/dataflow/value_row.rs @@ -5,7 +5,7 @@ use std::{ ops::{Index, IndexMut}, }; -use ascent::lattice::Lattice; +use ascent::lattice::{BoundedLattice, Lattice}; use itertools::zip_eq; use super::partial_value::{AbstractValue, PartialValue}; From 248fb0757c357a3168e1db66ea6cb6334d1896c0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 16:25:47 +0100 Subject: [PATCH 045/203] doc fixes, remove missing-docs for partial_value...how does it still work --- hugr-passes/src/const_fold2/context.rs | 8 +++++--- hugr-passes/src/dataflow.rs | 2 +- hugr-passes/src/dataflow/datalog.rs | 2 +- hugr-passes/src/dataflow/partial_value.rs | 2 -- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/const_fold2/context.rs b/hugr-passes/src/const_fold2/context.rs index 007fcfa92..f632427dd 100644 --- a/hugr-passes/src/const_fold2/context.rs +++ b/hugr-passes/src/const_fold2/context.rs @@ -8,9 +8,11 @@ use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort}; use super::value_handle::{ValueHandle, ValueKey}; use crate::dataflow::TotalContext; -/// An implementation of [DFContext] with [ValueHandle] -/// that just stores a Hugr (actually any [HugrView]), -/// (there is )no state for operation-interpretation). +/// A context ([DFContext]) for doing analysis with [ValueHandle]s. +/// Just stores a Hugr (actually any [HugrView]), +/// (there is )no state for operation-interpretation. +/// +/// [DFContext]: crate::dataflow::DFContext #[derive(Debug)] pub struct HugrValueContext(Arc); diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 8c0ad1d8c..3ceda2570 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -1,7 +1,7 @@ //! Dataflow analysis of Hugrs. mod datalog; -pub use datalog::Machine; +pub use datalog::{DFContext, Machine}; mod partial_value; pub use partial_value::{AbstractValue, PartialValue}; diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index d0983ee6c..00cc59279 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -179,7 +179,7 @@ pub struct Machine>( /// Usage: /// 1. [Self::new()] /// 2. Zero or more [Self::propolutate_out_wires] with initial values -/// 3. Exactly one [Self::run_hugr] to do the analysis +/// 3. Exactly one [Self::run] to do the analysis /// 4. Results then available via [Self::read_out_wire_partial_value] and [Self::read_out_wire_value] impl> Machine { pub fn new() -> Self { diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 670a7e38c..ee5c6b8ba 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -1,5 +1,3 @@ -#![allow(missing_docs)] - use ascent::lattice::BoundedLattice; use ascent::Lattice; use itertools::zip_eq; From e2ad079141f9ec528c4f79991a8f95ddd3f16a83 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 16:46:33 +0100 Subject: [PATCH 046/203] fix all warnings (inc Machine::new() -> impl Default) --- hugr-passes/src/const_fold2/context.rs | 5 +--- hugr-passes/src/dataflow/datalog.rs | 29 +++++++++++-------- hugr-passes/src/dataflow/datalog/test.rs | 20 ++++++------- hugr-passes/src/dataflow/partial_value.rs | 34 +++++++---------------- hugr-passes/src/dataflow/total_context.rs | 6 ++-- 5 files changed, 42 insertions(+), 52 deletions(-) diff --git a/hugr-passes/src/const_fold2/context.rs b/hugr-passes/src/const_fold2/context.rs index f632427dd..c18f5430b 100644 --- a/hugr-passes/src/const_fold2/context.rs +++ b/hugr-passes/src/const_fold2/context.rs @@ -83,10 +83,7 @@ impl TotalContext for HugrValueContext { )] } OpType::ExtensionOp(op) => { - let ins = ins - .into_iter() - .map(|(p, v)| (*p, v.clone())) - .collect::>(); + let ins = ins.iter().map(|(p, v)| (*p, v.clone())).collect::>(); op.constant_fold(&ins).map_or(Vec::new(), |outs| { outs.into_iter() .map(|(p, v)| (p, ValueHandle::new(ValueKey::Node(n), Arc::new(v)))) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 00cc59279..17552ad59 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -1,3 +1,9 @@ +#![allow( + clippy::clone_on_copy, + clippy::unused_enumerate_index, + clippy::collapsible_if +)] + use ascent::lattice::BoundedLattice; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use std::collections::HashMap; @@ -149,18 +155,17 @@ fn propagate_leaf_op( // Handle basics here. I guess (given the current interface) we could allow // DFContext to handle these but at the least we'd want these impls to be // easily available for reuse. - op if op.cast::().is_some() => Some(ValueRow::from_iter([PV::variant( - 0, - ins.into_iter().cloned(), - )])), + op if op.cast::().is_some() => { + Some(ValueRow::from_iter([PV::variant(0, ins.iter().cloned())])) + } op if op.cast::().is_some() => { - let [tup] = ins.into_iter().collect::>().try_into().unwrap(); + let [tup] = ins.iter().collect::>().try_into().unwrap(); tup.variant_values(0, utils::value_outputs(c.as_ref(), n).count()) .map(ValueRow::from_iter) } OpType::Tag(t) => Some(ValueRow::from_iter([PV::variant( t.tag, - ins.into_iter().cloned(), + ins.iter().cloned(), )])), OpType::Input(_) | OpType::Output(_) => None, // handled by parent // It'd be nice to convert these to [(IncomingPort, Value)] to pass to the context, @@ -170,22 +175,24 @@ fn propagate_leaf_op( } } -// TODO This should probably be called 'Analyser' or something pub struct Machine>( AscentProgram, Option>>, ); +/// derived-Default requires the context to be Defaultable, which is unnecessary +impl> Default for Machine { + fn default() -> Self { + Self(Default::default(), None) + } +} + /// Usage: /// 1. [Self::new()] /// 2. Zero or more [Self::propolutate_out_wires] with initial values /// 3. Exactly one [Self::run] to do the analysis /// 4. Results then available via [Self::read_out_wire_partial_value] and [Self::read_out_wire_value] impl> Machine { - pub fn new() -> Self { - Self(Default::default(), None) - } - pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator)>) { assert!(self.1.is_none()); self.0 diff --git a/hugr-passes/src/dataflow/datalog/test.rs b/hugr-passes/src/dataflow/datalog/test.rs index c6c4eda4e..d531e0ffd 100644 --- a/hugr-passes/src/dataflow/datalog/test.rs +++ b/hugr-passes/src/dataflow/datalog/test.rs @@ -22,7 +22,7 @@ fn test_make_tuple() { let v3 = builder.make_tuple([v1, v2]).unwrap(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let mut machine = Machine::new(); + let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); let x = machine.read_out_wire_value(&hugr, v3).unwrap(); @@ -41,7 +41,7 @@ fn test_unpack_tuple() { .outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let mut machine = Machine::new(); + let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); let o1_r = machine.read_out_wire_value(&hugr, o1).unwrap(); @@ -60,7 +60,7 @@ fn test_unpack_const() { .outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let mut machine = Machine::new(); + let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); let o_r = machine.read_out_wire_value(&hugr, o).unwrap(); @@ -86,7 +86,7 @@ fn test_tail_loop_never_iterates() { let [tl_o] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let mut machine = Machine::new(); + let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); // dbg!(&machine.tail_loop_io_node); // dbg!(&machine.out_wire_value); @@ -118,7 +118,7 @@ fn test_tail_loop_always_iterates() { let [tl_o1, tl_o2] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let mut machine = Machine::new(); + let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); let o_r1 = machine.read_out_wire_partial_value(tl_o1).unwrap(); @@ -168,15 +168,15 @@ fn test_tail_loop_iterates_twice() { // we should be able to propagate their values let [o_w1, o_w2] = tail_loop.outputs_arr(); - let mut machine = Machine::new(); + let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); // dbg!(&machine.tail_loop_io_node); // dbg!(&machine.out_wire_value); // TODO these hould be the propagated values for now they will bt join(true,false) - let o_r1 = machine.read_out_wire_partial_value(o_w1).unwrap(); + let _ = machine.read_out_wire_partial_value(o_w1).unwrap(); // assert_eq!(o_r1, PartialValue::top()); - let o_r2 = machine.read_out_wire_partial_value(o_w2).unwrap(); + let _ = machine.read_out_wire_partial_value(o_w2).unwrap(); // assert_eq!(o_r2, Value::true_val()); assert_eq!( TailLoopTermination::Top, @@ -212,7 +212,7 @@ fn conditional() { let case2 = case2_b.finish_with_outputs([false_w, c2a]).unwrap(); let case3_b = cond_builder.case_builder(2).unwrap(); - let [c3_1, c3_2] = case3_b.input_wires_arr(); + let [c3_1, _c3_2] = case3_b.input_wires_arr(); let case3 = case3_b.finish_with_outputs([c3_1, false_w]).unwrap(); let cond = cond_builder.finish_sub_container().unwrap(); @@ -221,7 +221,7 @@ fn conditional() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let mut machine = Machine::new(); + let mut machine = Machine::default(); let arg_pv = PartialValue::variant(1, []).join(PartialValue::variant(2, [PartialValue::variant(0, [])])); machine.propolutate_out_wires([(arg_w, arg_pv.into())]); diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index ee5c6b8ba..f1f05f877 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -122,7 +122,7 @@ impl PartialOrd for PartialSum { return None; } for (k, lhs) in &self.0 { - let Some(rhs) = other.0.get(&k) else { + let Some(rhs) = other.0.get(k) else { unreachable!() }; match lhs.partial_cmp(rhs) { @@ -192,8 +192,10 @@ impl PartialValue { self.assert_invariants(); match &*self { Self::Top => return false, - Self::Value(v) if v == &vh => return false, Self::Value(v) => { + if v == &vh { + return false; + }; *self = Self::Top; } Self::PartialSum(_) => match vh.into() { @@ -277,7 +279,7 @@ impl PartialValue { impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { // println!("join {self:?}\n{:?}", &other); - let changed = match (&*self, other) { + match (&*self, other) { (Self::Top, _) => false, (_, other @ Self::Top) => { *self = other; @@ -316,15 +318,7 @@ impl Lattice for PartialValue { self.join_mut_value_handle(old_self) } (_, Self::Value(h)) => self.join_mut_value_handle(h), - // (new_self, _) => { - // **new_self = Self::Top; - // false - // } - }; - // if changed { - // println!("join new self: {:?}", s); - // } - changed + } } fn meet(mut self, other: Self) -> Self { @@ -333,7 +327,7 @@ impl Lattice for PartialValue { } fn meet_mut(&mut self, other: Self) -> bool { - let changed = match (&*self, other) { + match (&*self, other) { (Self::Bottom, _) => false, (_, other @ Self::Bottom) => { *self = other; @@ -372,15 +366,7 @@ impl Lattice for PartialValue { self.meet_mut_value_handle(old_self) } (Self::PartialSum(_), Self::Value(h)) => self.meet_mut_value_handle(h), - // (new_self, _) => { - // **new_self = Self::Bottom; - // false - // } - }; - // if changed { - // println!("join new self: {:?}", s); - // } - changed + } } } @@ -519,8 +505,7 @@ mod test { } impl TestSumType { - const UNIT: TestSumLeafType = TestSumLeafType::Unit; - + #[allow(unused)] // ALAN ? fn leaf(v: Type) -> Self { TestSumType::Leaf(TestSumLeafType::Int(v)) } @@ -543,6 +528,7 @@ mod test { } } + #[allow(unused)] // ALAN ? fn is_leaf(&self) -> bool { self.depth() == 0 } diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index 8caafda52..882175af0 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -74,11 +74,11 @@ impl PartialSum { if v.len() != r.len() { return Err(self); } - match zip_eq(v.into_iter(), r.into_iter()) + match zip_eq(v, r.iter()) .map(|(v, t)| v.clone().try_into_value(t)) .collect::, _>>() { - Ok(vs) => V2::try_new_sum(*k, vs, &st).map_err(|_| self), + Ok(vs) => V2::try_new_sum(*k, vs, st).map_err(|_| self), Err(_) => Err(self), } } @@ -90,7 +90,7 @@ impl> DFContext for T { let sig = op.dataflow_signature()?; let known_ins = sig .input_types() - .into_iter() + .iter() .enumerate() .zip(ins.iter()) .filter_map(|((i, ty), pv)| { From 95e2dd952c6aab27c180eef90db003e19933caab Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 16:59:33 +0100 Subject: [PATCH 047/203] distribute utils.rs -> machine.rs --- hugr-passes/src/dataflow.rs | 10 +- hugr-passes/src/dataflow/datalog.rs | 115 ++++------------------ hugr-passes/src/dataflow/datalog/test.rs | 7 +- hugr-passes/src/dataflow/datalog/utils.rs | 37 ------- hugr-passes/src/dataflow/machine.rs | 110 +++++++++++++++++++++ hugr-passes/src/dataflow/total_context.rs | 2 +- 6 files changed, 145 insertions(+), 136 deletions(-) delete mode 100644 hugr-passes/src/dataflow/datalog/utils.rs create mode 100644 hugr-passes/src/dataflow/machine.rs diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 3ceda2570..452d070be 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -1,7 +1,8 @@ //! Dataflow analysis of Hugrs. mod datalog; -pub use datalog::{DFContext, Machine}; +mod machine; +pub use machine::Machine; mod partial_value; pub use partial_value::{AbstractValue, PartialValue}; @@ -11,3 +12,10 @@ pub use value_row::ValueRow; mod total_context; pub use total_context::TotalContext; + +use hugr_core::{Hugr, Node}; +use std::hash::Hash; + +pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { + fn interpret_leaf_op(&self, node: Node, ins: &[PartialValue]) -> Option>; +} diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 17552ad59..411b1dbcc 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -1,3 +1,6 @@ +//! [ascent] datalog implementation of analysis. +//! Since ascent-(macro-)generated code generates a bunch of warnings, +//! keep code in here to a minimum. #![allow( clippy::clone_on_copy, clippy::unused_enumerate_index, @@ -6,25 +9,16 @@ use ascent::lattice::BoundedLattice; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; -use std::collections::HashMap; +use hugr_core::types::Signature; use std::hash::Hash; -use hugr_core::ops::{OpType, Value}; -use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; +use hugr_core::ops::OpType; +use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _}; -mod utils; - -// TODO separate this into its own analysis? -use utils::TailLoopTermination; - -use super::partial_value::AbstractValue; use super::value_row::ValueRow; +use super::{AbstractValue, DFContext}; type PV = super::partial_value::PartialValue; -pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { - fn interpret_leaf_op(&self, node: Node, ins: &[PV]) -> Option>; -} - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum IO { Input, @@ -32,7 +26,7 @@ pub enum IO { } ascent::ascent! { - struct AscentProgram>; + pub(super) struct AscentProgram>; relation context(C); relation out_wire_value_proto(Node, OutgoingPort, PV); @@ -47,9 +41,9 @@ ascent::ascent! { node(c, n) <-- context(c), for n in c.nodes(); - in_wire(c, n,p) <-- node(c, n), for p in utils::value_inputs(c.as_ref(), *n); + in_wire(c, n,p) <-- node(c, n), for p in value_inputs(c.as_ref(), *n); - out_wire(c, n,p) <-- node(c, n), for p in utils::value_outputs(c.as_ref(), *n); + out_wire(c, n,p) <-- node(c, n), for p in value_outputs(c.as_ref(), *n); parent_of_node(c, parent, child) <-- node(c, child), if let Some(parent) = c.get_parent(*child); @@ -68,7 +62,7 @@ ascent::ascent! { out_wire_value(c, m, op, v); - node_in_value_row(c, n, ValueRow::new(utils::input_count(c.as_ref(), *n))) <-- node(c, n); + node_in_value_row(c, n, ValueRow::new(input_count(c.as_ref(), *n))) <-- node(c, n); node_in_value_row(c, n, ValueRow::single_known(c.signature(*n).unwrap().input.len(), p.index(), v.clone())) <-- in_wire_value(c, n, p, v); out_wire_value(c, n, p, v) <-- @@ -160,7 +154,7 @@ fn propagate_leaf_op( } op if op.cast::().is_some() => { let [tup] = ins.iter().collect::>().try_into().unwrap(); - tup.variant_values(0, utils::value_outputs(c.as_ref(), n).count()) + tup.variant_values(0, value_outputs(c.as_ref(), n).count()) .map(ValueRow::from_iter) } OpType::Tag(t) => Some(ValueRow::from_iter([PV::variant( @@ -175,86 +169,19 @@ fn propagate_leaf_op( } } -pub struct Machine>( - AscentProgram, - Option>>, -); - -/// derived-Default requires the context to be Defaultable, which is unnecessary -impl> Default for Machine { - fn default() -> Self { - Self(Default::default(), None) - } +fn input_count(h: &impl HugrView, n: Node) -> usize { + h.signature(n) + .as_ref() + .map(Signature::input_count) + .unwrap_or(0) } -/// Usage: -/// 1. [Self::new()] -/// 2. Zero or more [Self::propolutate_out_wires] with initial values -/// 3. Exactly one [Self::run] to do the analysis -/// 4. Results then available via [Self::read_out_wire_partial_value] and [Self::read_out_wire_value] -impl> Machine { - pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator)>) { - assert!(self.1.is_none()); - self.0 - .out_wire_value_proto - .extend(wires.into_iter().map(|(w, v)| (w.node(), w.source(), v))); - } - - pub fn run(&mut self, context: C) { - assert!(self.1.is_none()); - self.0.context.push((context,)); - self.0.run(); - self.1 = Some( - self.0 - .out_wire_value - .iter() - .map(|(_, n, p, v)| (Wire::new(*n, *p), v.clone())) - .collect(), - ) - } - - pub fn read_out_wire_partial_value(&self, w: Wire) -> Option> { - self.1.as_ref().unwrap().get(&w).cloned() - } - - pub fn tail_loop_terminates(&self, hugr: impl HugrView, node: Node) -> TailLoopTermination { - assert!(hugr.get_optype(node).is_tail_loop()); - let [_, out] = hugr.get_io(node).unwrap(); - TailLoopTermination::from_control_value( - self.0 - .in_wire_value - .iter() - .find_map(|(_, n, p, v)| (*n == out && p.index() == 0).then_some(v)) - .unwrap(), - ) - } - - pub fn case_reachable(&self, hugr: impl HugrView, case: Node) -> bool { - assert!(hugr.get_optype(case).is_case()); - let cond = hugr.get_parent(case).unwrap(); - assert!(hugr.get_optype(cond).is_conditional()); - self.0 - .case_reachable - .iter() - .find_map(|(_, cond2, case2, i)| (&cond == cond2 && &case == case2).then_some(*i)) - .unwrap() - } +fn value_inputs(h: &impl HugrView, n: Node) -> impl Iterator + '_ { + h.in_value_types(n).map(|x| x.0) } -impl> Machine -where - Value: From, -{ - pub fn read_out_wire_value(&self, hugr: impl HugrView, w: Wire) -> Option { - // dbg!(&w); - let pv = self.read_out_wire_partial_value(w)?; - // dbg!(&pv); - let (_, typ) = hugr - .out_value_types(w.node()) - .find(|(p, _)| *p == w.source()) - .unwrap(); - pv.try_into_value(&typ).ok() - } +fn value_outputs(h: &impl HugrView, n: Node) -> impl Iterator + '_ { + h.out_value_types(n).map(|x| x.0) } #[cfg(test)] diff --git a/hugr-passes/src/dataflow/datalog/test.rs b/hugr-passes/src/dataflow/datalog/test.rs index d531e0ffd..68708084a 100644 --- a/hugr-passes/src/dataflow/datalog/test.rs +++ b/hugr-passes/src/dataflow/datalog/test.rs @@ -1,4 +1,7 @@ -use crate::const_fold2::HugrValueContext; +use crate::{ + const_fold2::HugrValueContext, + dataflow::{machine::TailLoopTermination, Machine}, +}; use ascent::lattice::BoundedLattice; use hugr_core::{ @@ -12,8 +15,6 @@ use hugr_core::{ use super::super::partial_value::PartialValue; -use super::*; - #[test] fn test_make_tuple() { let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); diff --git a/hugr-passes/src/dataflow/datalog/utils.rs b/hugr-passes/src/dataflow/datalog/utils.rs deleted file mode 100644 index 4d3a056c6..000000000 --- a/hugr-passes/src/dataflow/datalog/utils.rs +++ /dev/null @@ -1,37 +0,0 @@ -use super::super::partial_value::{AbstractValue, PartialValue}; -use hugr_core::{types::Signature, HugrView, IncomingPort, Node, OutgoingPort}; - -pub(super) fn input_count(h: &impl HugrView, n: Node) -> usize { - h.signature(n) - .as_ref() - .map(Signature::input_count) - .unwrap_or(0) -} - -pub(super) fn value_inputs(h: &impl HugrView, n: Node) -> impl Iterator + '_ { - h.in_value_types(n).map(|x| x.0) -} - -pub(super) fn value_outputs(h: &impl HugrView, n: Node) -> impl Iterator + '_ { - h.out_value_types(n).map(|x| x.0) -} - -#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] -pub enum TailLoopTermination { - Bottom, - ExactlyZeroContinues, - Top, -} - -impl TailLoopTermination { - pub fn from_control_value(v: &PartialValue) -> Self { - let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); - if may_break && !may_continue { - Self::ExactlyZeroContinues - } else if may_break && may_continue { - Self::Top - } else { - Self::Bottom - } - } -} diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs new file mode 100644 index 000000000..b6736a336 --- /dev/null +++ b/hugr-passes/src/dataflow/machine.rs @@ -0,0 +1,110 @@ +use std::collections::HashMap; + +use hugr_core::{ops::Value, HugrView, Node, PortIndex, Wire}; + +use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue}; + +pub struct Machine>( + AscentProgram, + Option>>, +); + +/// derived-Default requires the context to be Defaultable, which is unnecessary +impl> Default for Machine { + fn default() -> Self { + Self(Default::default(), None) + } +} + +/// Usage: +/// 1. [Self::new()] +/// 2. Zero or more [Self::propolutate_out_wires] with initial values +/// 3. Exactly one [Self::run] to do the analysis +/// 4. Results then available via [Self::read_out_wire_partial_value] and [Self::read_out_wire_value] +impl> Machine { + pub fn propolutate_out_wires( + &mut self, + wires: impl IntoIterator)>, + ) { + assert!(self.1.is_none()); + self.0 + .out_wire_value_proto + .extend(wires.into_iter().map(|(w, v)| (w.node(), w.source(), v))); + } + + pub fn run(&mut self, context: C) { + assert!(self.1.is_none()); + self.0.context.push((context,)); + self.0.run(); + self.1 = Some( + self.0 + .out_wire_value + .iter() + .map(|(_, n, p, v)| (Wire::new(*n, *p), v.clone())) + .collect(), + ) + } + + pub fn read_out_wire_partial_value(&self, w: Wire) -> Option> { + self.1.as_ref().unwrap().get(&w).cloned() + } + + pub fn tail_loop_terminates(&self, hugr: impl HugrView, node: Node) -> TailLoopTermination { + assert!(hugr.get_optype(node).is_tail_loop()); + let [_, out] = hugr.get_io(node).unwrap(); + TailLoopTermination::from_control_value( + self.0 + .in_wire_value + .iter() + .find_map(|(_, n, p, v)| (*n == out && p.index() == 0).then_some(v)) + .unwrap(), + ) + } + + pub fn case_reachable(&self, hugr: impl HugrView, case: Node) -> bool { + assert!(hugr.get_optype(case).is_case()); + let cond = hugr.get_parent(case).unwrap(); + assert!(hugr.get_optype(cond).is_conditional()); + self.0 + .case_reachable + .iter() + .find_map(|(_, cond2, case2, i)| (&cond == cond2 && &case == case2).then_some(*i)) + .unwrap() + } +} + +impl> Machine +where + Value: From, +{ + pub fn read_out_wire_value(&self, hugr: impl HugrView, w: Wire) -> Option { + // dbg!(&w); + let pv = self.read_out_wire_partial_value(w)?; + // dbg!(&pv); + let (_, typ) = hugr + .out_value_types(w.node()) + .find(|(p, _)| *p == w.source()) + .unwrap(); + pv.try_into_value(&typ).ok() + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] +pub enum TailLoopTermination { + Bottom, + ExactlyZeroContinues, + Top, +} + +impl TailLoopTermination { + pub fn from_control_value(v: &PartialValue) -> Self { + let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); + if may_break && !may_continue { + Self::ExactlyZeroContinues + } else if may_break && may_continue { + Self::Top + } else { + Self::Bottom + } + } +} diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index 882175af0..89067105b 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -5,8 +5,8 @@ use hugr_core::types::{ConstTypeError, SumType, Type, TypeEnum, TypeRow}; use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; use itertools::{zip_eq, Itertools}; -use super::datalog::DFContext; use super::partial_value::{AbstractValue, PartialSum, PartialValue}; +use super::DFContext; use super::ValueRow; pub trait FromSum: Sized { From 7f1e122587d4832b5e263d35d99ff205748d7ef1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 17:03:14 +0100 Subject: [PATCH 048/203] Move dataflow{/datalog=>}/test.rs --- hugr-passes/src/dataflow.rs | 3 +++ hugr-passes/src/dataflow/datalog.rs | 3 --- hugr-passes/src/dataflow/{datalog => }/test.rs | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) rename hugr-passes/src/dataflow/{datalog => }/test.rs (99%) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 452d070be..827489144 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -19,3 +19,6 @@ use std::hash::Hash; pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { fn interpret_leaf_op(&self, node: Node, ins: &[PartialValue]) -> Option>; } + +#[cfg(test)] +mod test; diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 411b1dbcc..06f49150f 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -183,6 +183,3 @@ fn value_inputs(h: &impl HugrView, n: Node) -> impl Iterator impl Iterator + '_ { h.out_value_types(n).map(|x| x.0) } - -#[cfg(test)] -mod test; diff --git a/hugr-passes/src/dataflow/datalog/test.rs b/hugr-passes/src/dataflow/test.rs similarity index 99% rename from hugr-passes/src/dataflow/datalog/test.rs rename to hugr-passes/src/dataflow/test.rs index 68708084a..738e64073 100644 --- a/hugr-passes/src/dataflow/datalog/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -13,7 +13,7 @@ use hugr_core::{ types::{Signature, SumType, Type, TypeRow}, }; -use super::super::partial_value::PartialValue; +use super::partial_value::PartialValue; #[test] fn test_make_tuple() { From faff5560dae64f17f17f0c748d3121f4a4c6497c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 17:09:52 +0100 Subject: [PATCH 049/203] and more warnings --- hugr-passes/src/const_fold2/value_handle.rs | 4 ++-- hugr-passes/src/dataflow/partial_value.rs | 15 ++++++--------- hugr-passes/src/dataflow/test.rs | 6 +++--- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index 4bacef114..f24a4b734 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -157,7 +157,7 @@ mod test { assert_ne!(k1, k3); assert_eq!(ValueKey::from(n), ValueKey::from(n)); - let f = ConstF64::new(3.141); + let f = ConstF64::new(std::f64::consts::PI); assert_eq!(ValueKey::new(n, f.clone()), ValueKey::from(n)); assert_ne!(ValueKey::new(n, f.clone()), ValueKey::new(n2, f)); // Node taken into account @@ -182,7 +182,7 @@ mod test { fn value_key_list() { let v1 = ConstInt::new_u(3, 3).unwrap(); let v2 = ConstInt::new_u(4, 3).unwrap(); - let v3 = ConstF64::new(3.141); + let v3 = ConstF64::new(std::f64::consts::PI); let n = Node::from(portgraph::NodeIndex::new(0)); let n2: Node = portgraph::NodeIndex::new(1).into(); diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index f1f05f877..d792154c2 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -432,16 +432,13 @@ mod test { impl TestSumLeafType { fn assert_invariants(&self) { - match self { - Self::Int(t) => { - if let TypeEnum::Extension(ct) = t.as_type_enum() { - assert_eq!("int", ct.name()); - assert_eq!(&int_types::EXTENSION_ID, ct.extension()); - } else { - panic!("Expected int type, got {:#?}", t); - } + if let Self::Int(t) = self { + if let TypeEnum::Extension(ct) = t.as_type_enum() { + assert_eq!("int", ct.name()); + assert_eq!(&int_types::EXTENSION_ID, ct.extension()); + } else { + panic!("Expected int type, got {:#?}", t); } - _ => (), } } diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 738e64073..fd683eaa8 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -123,9 +123,9 @@ fn test_tail_loop_always_iterates() { machine.run(HugrValueContext::new(&hugr)); let o_r1 = machine.read_out_wire_partial_value(tl_o1).unwrap(); - assert_eq!(o_r1, PartialValue::bottom().into()); + assert_eq!(o_r1, PartialValue::bottom()); let o_r2 = machine.read_out_wire_partial_value(tl_o2).unwrap(); - assert_eq!(o_r2, PartialValue::bottom().into()); + assert_eq!(o_r2, PartialValue::bottom()); assert_eq!( TailLoopTermination::Bottom, machine.tail_loop_terminates(&hugr, tail_loop.node()) @@ -225,7 +225,7 @@ fn conditional() { let mut machine = Machine::default(); let arg_pv = PartialValue::variant(1, []).join(PartialValue::variant(2, [PartialValue::variant(0, [])])); - machine.propolutate_out_wires([(arg_w, arg_pv.into())]); + machine.propolutate_out_wires([(arg_w, arg_pv)]); machine.run(HugrValueContext::new(&hugr)); let cond_r1 = machine.read_out_wire_value(&hugr, cond_o1).unwrap(); From 401354dacdbcede64246138914d2894b9e58a7a5 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 17:14:45 +0100 Subject: [PATCH 050/203] fix extension tests --- hugr-passes/src/dataflow/test.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index fd683eaa8..749de5dd3 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -5,7 +5,7 @@ use crate::{ use ascent::lattice::BoundedLattice; use hugr_core::{ - builder::{DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, + builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::prelude::{UnpackTuple, BOOL_T}, extension::{ExtensionSet, EMPTY_REG}, ops::{handle::NodeHandle, OpTrait, Value}, @@ -17,7 +17,7 @@ use super::partial_value::PartialValue; #[test] fn test_make_tuple() { - let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); + let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); let v1 = builder.add_load_value(Value::false_val()); let v2 = builder.add_load_value(Value::true_val()); let v3 = builder.make_tuple([v1, v2]).unwrap(); @@ -32,7 +32,7 @@ fn test_make_tuple() { #[test] fn test_unpack_tuple() { - let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); + let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); let v1 = builder.add_load_value(Value::false_val()); let v2 = builder.add_load_value(Value::true_val()); let v3 = builder.make_tuple([v1, v2]).unwrap(); @@ -53,7 +53,7 @@ fn test_unpack_tuple() { #[test] fn test_unpack_const() { - let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); + let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); let v1 = builder.add_load_value(Value::tuple([Value::true_val()])); let [o] = builder .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T]), [v1]) From c468387e5df3cfbb7e7c7000c4d89eafa76d8ab1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 2 Sep 2024 17:19:11 +0100 Subject: [PATCH 051/203] Fix doclink, fix DefaultHasher pre-1.76 --- hugr-passes/src/const_fold2/value_handle.rs | 3 ++- hugr-passes/src/dataflow/machine.rs | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index f24a4b734..7b6e26106 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -1,4 +1,5 @@ -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::collections::hash_map::DefaultHasher; // Moves into std::hash in Rust 1.76. +use std::hash::{Hash, Hasher}; use std::sync::Arc; use hugr_core::ops::constant::{CustomConst, Sum}; diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index b6736a336..6fe79208b 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -17,7 +17,7 @@ impl> Default for Machine { } /// Usage: -/// 1. [Self::new()] +/// 1. Get a new instance via [Self::default()] /// 2. Zero or more [Self::propolutate_out_wires] with initial values /// 3. Exactly one [Self::run] to do the analysis /// 4. Results then available via [Self::read_out_wire_partial_value] and [Self::read_out_wire_value] From 88db5b18a7cb6b33be8e3229361c7188e3b856d5 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 11 Sep 2024 16:47:50 +0100 Subject: [PATCH 052/203] comment conditional test --- hugr-passes/src/dataflow/test.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 749de5dd3..d8d8698af 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -232,7 +232,7 @@ fn conditional() { assert_eq!(cond_r1, Value::false_val()); assert!(machine.read_out_wire_value(&hugr, cond_o2).is_none()); - assert!(!machine.case_reachable(&hugr, case1.node())); + assert!(!machine.case_reachable(&hugr, case1.node())); // arg_pv is variant 1 or 2 only assert!(machine.case_reachable(&hugr, case2.node())); assert!(machine.case_reachable(&hugr, case3.node())); } From 738b61b168933bfc097eaa286f6d7eafa88cb360 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 11 Sep 2024 16:48:19 +0100 Subject: [PATCH 053/203] Clarify (TODO untested) branches of join_mut --- hugr-passes/src/dataflow/partial_value.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index d792154c2..4a00de876 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -310,14 +310,14 @@ impl Lattice for PartialValue { } } } - (Self::Value(_), mut other) => { + (Self::Value(_), mut other@Self::PartialSum(_)) => { std::mem::swap(self, &mut other); let Self::Value(old_self) = other else { unreachable!() }; self.join_mut_value_handle(old_self) } - (_, Self::Value(h)) => self.join_mut_value_handle(h), + (Self::PartialSum(_), Self::Value(h)) => self.join_mut_value_handle(h), } } From 8f9c1ed852b22c25cb36fac01c1c484bd29206fb Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 11 Sep 2024 17:29:17 +0100 Subject: [PATCH 054/203] Exploit invariant PartialValue::Value is not a sum (even single known variant) --- hugr-passes/src/dataflow/partial_value.rs | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 4a00de876..bcae3fc79 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -151,6 +151,8 @@ impl Hash for PartialSum { } } +/// We really must prevent people from constructing PartialValue::Value of +/// any `value` where `value.as_sum().is_some()`` #[derive(PartialEq, Clone, Eq, Hash, Debug)] pub enum PartialValue { Bottom, @@ -253,12 +255,10 @@ impl PartialValue { pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { let vals = match self { PartialValue::Bottom => return None, - PartialValue::Value(v) => v - .as_sum() - .filter(|(variant, _)| tag == *variant)? - .1 - .map(PartialValue::Value) - .collect(), + PartialValue::Value(v) => { + assert!(v.as_sum().is_none()); + return None; + } PartialValue::PartialSum(ps) => ps.variant_values(tag, len)?, PartialValue::Top => vec![PartialValue::Top; len], }; @@ -269,7 +269,10 @@ impl PartialValue { pub fn supports_tag(&self, tag: usize) -> bool { match self { PartialValue::Bottom => false, - PartialValue::Value(v) => v.as_sum().is_some_and(|(v, _)| v == tag), + PartialValue::Value(v) => { + assert!(v.as_sum().is_none()); + false + } PartialValue::PartialSum(ps) => ps.supports_tag(tag), PartialValue::Top => true, } @@ -310,7 +313,7 @@ impl Lattice for PartialValue { } } } - (Self::Value(_), mut other@Self::PartialSum(_)) => { + (Self::Value(_), mut other @ Self::PartialSum(_)) => { std::mem::swap(self, &mut other); let Self::Value(old_self) = other else { unreachable!() From a71ba9743e81f4d496a076b3df9681d3a6ef78b6 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 11 Sep 2024 17:31:06 +0100 Subject: [PATCH 055/203] Exploit invariant more, RIP join_mut_value_handle --- hugr-passes/src/dataflow/partial_value.rs | 35 ++++------------------- 1 file changed, 5 insertions(+), 30 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index bcae3fc79..ce1b8cd9a 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -190,29 +190,6 @@ impl PartialValue { } } - fn join_mut_value_handle(&mut self, vh: V) -> bool { - self.assert_invariants(); - match &*self { - Self::Top => return false, - Self::Value(v) => { - if v == &vh { - return false; - }; - *self = Self::Top; - } - Self::PartialSum(_) => match vh.into() { - Self::Value(_) => { - *self = Self::Top; - } - other => return self.join_mut(other), - }, - Self::Bottom => { - *self = vh.into(); - } - }; - true - } - fn meet_mut_value_handle(&mut self, vh: V) -> bool { self.assert_invariants(); match &*self { @@ -313,14 +290,12 @@ impl Lattice for PartialValue { } } } - (Self::Value(_), mut other @ Self::PartialSum(_)) => { - std::mem::swap(self, &mut other); - let Self::Value(old_self) = other else { - unreachable!() - }; - self.join_mut_value_handle(old_self) + (Self::Value(ref v), Self::PartialSum(_)) + | (Self::PartialSum(_), Self::Value(ref v)) => { + assert!(v.as_sum().is_none()); + *self = Self::Top; + true } - (Self::PartialSum(_), Self::Value(h)) => self.join_mut_value_handle(h), } } From 05280a8efd30442caf4d020ef0371dbb14ab1ca1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 11 Sep 2024 17:34:29 +0100 Subject: [PATCH 056/203] By similar logic, RIP meet_mut_value_handle; assert_(in)variants now unused --- hugr-passes/src/dataflow/partial_value.rs | 38 +++-------------------- 1 file changed, 5 insertions(+), 33 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index ce1b8cd9a..d9baf2c63 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -190,32 +190,6 @@ impl PartialValue { } } - fn meet_mut_value_handle(&mut self, vh: V) -> bool { - self.assert_invariants(); - match &*self { - Self::Bottom => false, - Self::Value(v) => { - if v == &vh { - false - } else { - *self = Self::Bottom; - true - } - } - Self::PartialSum(_) => match vh.into() { - Self::Value(_) => { - *self = Self::Bottom; - true - } - other => self.meet_mut(other), - }, - Self::Top => { - *self = vh.into(); - true - } - } - } - pub fn join(mut self, other: Self) -> Self { self.join_mut(other); self @@ -336,14 +310,12 @@ impl Lattice for PartialValue { } } } - (Self::Value(_), mut other @ Self::PartialSum(_)) => { - std::mem::swap(self, &mut other); - let Self::Value(old_self) = other else { - unreachable!() - }; - self.meet_mut_value_handle(old_self) + (Self::Value(ref v), Self::PartialSum(_)) + | (Self::PartialSum(_), Self::Value(ref v)) => { + assert!(v.as_sum().is_none()); + *self = Self::Bottom; + true } - (Self::PartialSum(_), Self::Value(h)) => self.meet_mut_value_handle(h), } } } From 96b0856f41bec7ea5c0680bfc8c15993c885b0eb Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 13 Sep 2024 10:50:36 +0100 Subject: [PATCH 057/203] Rename TestSum(,Leaf)Type::assert_{invariants=>valid} --- hugr-passes/src/dataflow/partial_value.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index d9baf2c63..319371a5d 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -381,7 +381,7 @@ mod test { } impl TestSumLeafType { - fn assert_invariants(&self) { + fn assert_valid(&self) { if let Self::Int(t) = self { if let TypeEnum::Extension(ct) = t.as_type_enum() { assert_eq!("int", ct.name()); @@ -480,17 +480,17 @@ mod test { self.depth() == 0 } - fn assert_invariants(&self) { + fn assert_valid(&self) { match self { TestSumType::Branch(d, sop) => { assert!(!sop.is_empty(), "No variants"); for v in sop.iter().flat_map(|x| x.iter()) { assert!(v.depth() < *d); - v.assert_invariants(); + v.assert_valid(); } } TestSumType::Leaf(l) => { - l.assert_invariants(); + l.assert_valid(); } } } @@ -601,7 +601,7 @@ mod test { proptest! { #[test] fn unary_sum_type_valid(ust: TestSumType) { - ust.assert_invariants(); + ust.assert_valid(); } } From f21e278dff9985dd3759c54a271aaa8861d3c604 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 13 Sep 2024 10:55:35 +0100 Subject: [PATCH 058/203] Rename assert_(=>in)variants (i.e. to match); call in (join/meet)_mut; fix! --- hugr-passes/src/dataflow/partial_value.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 319371a5d..3e1b654ed 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -31,7 +31,7 @@ impl PartialSum { } impl PartialSum { - fn assert_variants(&self) { + fn assert_invariants(&self) { assert_ne!(self.num_variants(), 0); for pv in self.0.values().flat_map(|x| x.iter()) { pv.assert_invariants(); @@ -164,7 +164,7 @@ pub enum PartialValue { impl From for PartialValue { fn from(v: V) -> Self { v.as_sum() - .map(|(tag, values)| Self::variant(tag, values.map(Self::Value))) + .map(|(tag, values)| Self::variant(tag, values.map(Self::from))) .unwrap_or(Self::Value(v)) } } @@ -181,7 +181,7 @@ impl PartialValue { fn assert_invariants(&self) { match self { Self::PartialSum(ps) => { - ps.assert_variants(); + ps.assert_invariants(); } Self::Value(v) => { assert!(v.as_sum().is_none()) @@ -232,6 +232,7 @@ impl PartialValue { impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { + self.assert_invariants(); // println!("join {self:?}\n{:?}", &other); match (&*self, other) { (Self::Top, _) => false, @@ -279,6 +280,7 @@ impl Lattice for PartialValue { } fn meet_mut(&mut self, other: Self) -> bool { + self.assert_invariants(); match (&*self, other) { (Self::Bottom, _) => false, (_, other @ Self::Bottom) => { From ce53b1cc4832b1f13f9f171efc87d5bba5bdbd9f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 13 Sep 2024 11:16:40 +0100 Subject: [PATCH 059/203] test unpacking constant tuple --- hugr-passes/src/dataflow/test.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index d8d8698af..e4b3d5c24 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -31,13 +31,11 @@ fn test_make_tuple() { } #[test] -fn test_unpack_tuple() { +fn test_unpack_tuple_const() { let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); - let v1 = builder.add_load_value(Value::false_val()); - let v2 = builder.add_load_value(Value::true_val()); - let v3 = builder.make_tuple([v1, v2]).unwrap(); + let v = builder.add_load_value(Value::tuple([Value::false_val(), Value::true_val()])); let [o1, o2] = builder - .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T, BOOL_T]), [v3]) + .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T, BOOL_T]), [v]) .unwrap() .outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); From 13f29a909e58d4b1cf64fb49df786b6983ede9f9 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 12 Sep 2024 21:25:40 +0100 Subject: [PATCH 060/203] try_into_value returns new enum ValueOrSum; TryFrom replaces FromSum --- hugr-passes/src/dataflow.rs | 2 +- hugr-passes/src/dataflow/machine.rs | 28 ++++++++- hugr-passes/src/dataflow/partial_value.rs | 51 +++++++++++++++- hugr-passes/src/dataflow/total_context.rs | 73 ++--------------------- 4 files changed, 81 insertions(+), 73 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 827489144..52ecca9e1 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -5,7 +5,7 @@ mod machine; pub use machine::Machine; mod partial_value; -pub use partial_value::{AbstractValue, PartialValue}; +pub use partial_value::{AbstractValue, PartialValue, ValueOrSum}; mod value_row; pub use value_row::ValueRow; diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 6fe79208b..d4aa97e78 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -1,8 +1,10 @@ use std::collections::HashMap; -use hugr_core::{ops::Value, HugrView, Node, PortIndex, Wire}; +use hugr_core::{ops::Value, types::ConstTypeError, HugrView, Node, PortIndex, Wire}; -use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue}; +use super::{ + datalog::AscentProgram, partial_value::ValueOrSum, AbstractValue, DFContext, PartialValue, +}; pub struct Machine>( AscentProgram, @@ -85,7 +87,27 @@ where .out_value_types(w.node()) .find(|(p, _)| *p == w.source()) .unwrap(); - pv.try_into_value(&typ).ok() + let v: ValueOrSum = pv.try_into_value(&typ).ok()?; + v.try_into().ok() + } +} + +impl TryFrom> for Value +where + Value: From, +{ + type Error = ConstTypeError; + fn try_from(value: ValueOrSum) -> Result { + match value { + ValueOrSum::Value(v) => Ok(v.into()), + ValueOrSum::Sum { tag, items, st } => { + let items = items + .into_iter() + .map(Value::try_from) + .collect::, _>>()?; + Value::sum(tag, items, st.clone()) + } + } } } diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 3e1b654ed..e6b53b935 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -1,6 +1,7 @@ use ascent::lattice::BoundedLattice; use ascent::Lattice; -use itertools::zip_eq; +use hugr_core::types::{SumType, Type, TypeEnum, TypeRow}; +use itertools::{zip_eq, Itertools}; use std::cmp::Ordering; use std::collections::HashMap; use std::hash::{Hash, Hasher}; @@ -13,6 +14,16 @@ pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { fn as_sum(&self) -> Option<(usize, impl Iterator + '_)>; } +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ValueOrSum { + Value(V), + Sum { + tag: usize, + items: Vec, + st: SumType, + }, +} + // TODO ALAN inline into PartialValue? Has to be public as it's in a pub enum #[derive(PartialEq, Clone, Eq)] pub struct PartialSum(pub HashMap>>); @@ -92,6 +103,36 @@ impl PartialSum { pub fn supports_tag(&self, tag: usize) -> bool { self.0.contains_key(&tag) } + + pub fn try_into_value(self, typ: &Type) -> Result, Self> { + let Ok((k, v)) = self.0.iter().exactly_one() else { + Err(self)? + }; + + let TypeEnum::Sum(st) = typ.as_type_enum() else { + Err(self)? + }; + let Some(r) = st.get_variant(*k) else { + Err(self)? + }; + let Ok(r) = TypeRow::try_from(r.clone()) else { + Err(self)? + }; + if v.len() != r.len() { + return Err(self); + } + match zip_eq(v, r.into_iter()) + .map(|(v, t)| v.clone().try_into_value(t)) + .collect::, _>>() + { + Ok(vs) => Ok(ValueOrSum::Sum { + tag: *k, + items: vs, + st: st.clone(), + }), + Err(_) => Err(self), + } + } } impl PartialSum { @@ -228,6 +269,14 @@ impl PartialValue { PartialValue::Top => true, } } + + pub fn try_into_value(self, typ: &Type) -> Result, Self> { + match self { + Self::Value(v) => Ok(ValueOrSum::Value(v.clone())), + Self::PartialSum(ps) => ps.try_into_value(typ).map_err(Self::PartialSum), + x => Err(x), + } + } } impl Lattice for PartialValue { diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index 89067105b..26acc31ed 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -1,29 +1,16 @@ use std::hash::Hash; -use hugr_core::ops::Value; -use hugr_core::types::{ConstTypeError, SumType, Type, TypeEnum, TypeRow}; use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; -use itertools::{zip_eq, Itertools}; -use super::partial_value::{AbstractValue, PartialSum, PartialValue}; +use super::partial_value::{AbstractValue, PartialValue, ValueOrSum}; use super::DFContext; use super::ValueRow; -pub trait FromSum: Sized { - type Err: std::error::Error; - fn try_new_sum( - tag: usize, - items: impl IntoIterator, - st: &SumType, - ) -> Result; - fn debug_check_is_type(&self, _ty: &Type) {} -} - /// A simpler interface like [DFContext] but where the context only cares about /// values that are completely known (in the lattice `V`) /// rather than e.g. Sums potentially of two variants each of known values. pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { - type InterpretableVal: FromSum + From; + type InterpretableVal: TryFrom>; fn interpret_leaf_op( &self, node: Node, @@ -31,59 +18,6 @@ pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { ) -> Vec<(OutgoingPort, V)>; } -impl FromSum for Value { - type Err = ConstTypeError; - fn try_new_sum( - tag: usize, - items: impl IntoIterator, - st: &hugr_core::types::SumType, - ) -> Result { - Self::sum(tag, items, st.clone()) - } -} - -// These are here because they rely on FromSum, that they are `impl PartialSum/Value` -// is merely a nice syntax. -impl PartialValue { - pub fn try_into_value>(self, typ: &Type) -> Result { - let r: V2 = match self { - Self::Value(v) => Ok(v.clone().into()), - Self::PartialSum(ps) => ps.try_into_value(typ).map_err(Self::PartialSum), - x => Err(x), - }?; - r.debug_check_is_type(typ); - Ok(r) - } -} - -impl PartialSum { - pub fn try_into_value>(self, typ: &Type) -> Result { - let Ok((k, v)) = self.0.iter().exactly_one() else { - Err(self)? - }; - - let TypeEnum::Sum(st) = typ.as_type_enum() else { - Err(self)? - }; - let Some(r) = st.get_variant(*k) else { - Err(self)? - }; - let Ok(r): Result = r.clone().try_into() else { - Err(self)? - }; - if v.len() != r.len() { - return Err(self); - } - match zip_eq(v, r.iter()) - .map(|(v, t)| v.clone().try_into_value(t)) - .collect::, _>>() - { - Ok(vs) => V2::try_new_sum(*k, vs, st).map_err(|_| self), - Err(_) => Err(self), - } - } -} - impl> DFContext for T { fn interpret_leaf_op(&self, node: Node, ins: &[PartialValue]) -> Option> { let op = self.get_optype(node); @@ -96,7 +30,10 @@ impl> DFContext for T { .filter_map(|((i, ty), pv)| { pv.clone() .try_into_value(ty) + // Discard PVs which don't produce ValueOrSum, i.e. Bottom/Top :-) .ok() + // And discard any ValueOrSum that don't produce V - this is a bit silent :-( + .and_then(|v_s| T::InterpretableVal::try_from(v_s).ok()) .map(|v| (IncomingPort::from(i), v)) }) .collect::>(); From 9f1a5cda508988b5883816b62b3b3ad25e6f1c37 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 11:43:26 +0100 Subject: [PATCH 061/203] Hide ValueRow (and move into datalog.rs) --- hugr-passes/src/dataflow.rs | 10 +- hugr-passes/src/dataflow/datalog.rs | 122 +++++++++++++++++++++- hugr-passes/src/dataflow/total_context.rs | 9 +- hugr-passes/src/dataflow/value_row.rs | 117 --------------------- 4 files changed, 129 insertions(+), 129 deletions(-) delete mode 100644 hugr-passes/src/dataflow/value_row.rs diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 52ecca9e1..2e2e95936 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -1,15 +1,13 @@ //! Dataflow analysis of Hugrs. mod datalog; + mod machine; pub use machine::Machine; mod partial_value; pub use partial_value::{AbstractValue, PartialValue, ValueOrSum}; -mod value_row; -pub use value_row::ValueRow; - mod total_context; pub use total_context::TotalContext; @@ -17,7 +15,11 @@ use hugr_core::{Hugr, Node}; use std::hash::Hash; pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { - fn interpret_leaf_op(&self, node: Node, ins: &[PartialValue]) -> Option>; + fn interpret_leaf_op( + &self, + node: Node, + ins: &[PartialValue], + ) -> Option>>; } #[cfg(test)] diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 06f49150f..c4c798324 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -7,16 +7,20 @@ clippy::collapsible_if )] -use ascent::lattice::BoundedLattice; -use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; -use hugr_core::types::Signature; +use ascent::lattice::{BoundedLattice, Lattice}; +use itertools::zip_eq; +use std::cmp::Ordering; use std::hash::Hash; +use std::ops::{Index, IndexMut}; +use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use hugr_core::ops::OpType; +use hugr_core::types::Signature; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _}; -use super::value_row::ValueRow; -use super::{AbstractValue, DFContext}; +use super::partial_value::{AbstractValue, PartialValue}; +use super::DFContext; + type PV = super::partial_value::PartialValue; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -183,3 +187,111 @@ fn value_inputs(h: &impl HugrView, n: Node) -> impl Iterator impl Iterator + '_ { h.out_value_types(n).map(|x| x.0) } + +// Wrap a (known-length) row of values into a lattice. Perhaps could be part of partial_value.rs? + +#[derive(PartialEq, Clone, Eq, Hash)] +struct ValueRow(Vec>); + +impl ValueRow { + pub fn new(len: usize) -> Self { + Self(vec![PartialValue::bottom(); len]) + } + + pub fn single_known(len: usize, idx: usize, v: PartialValue) -> Self { + assert!(idx < len); + let mut r = Self::new(len); + r.0[idx] = v; + r + } + + pub fn iter(&self) -> impl Iterator> { + self.0.iter() + } + + pub fn unpack_first( + &self, + variant: usize, + len: usize, + ) -> Option> + '_> { + self[0] + .variant_values(variant, len) + .map(|vals| vals.into_iter().chain(self.iter().skip(1).cloned())) + } + + // fn initialised(&self) -> bool { + // self.0.iter().all(|x| x != &PV::top()) + // } +} + +impl FromIterator> for ValueRow { + fn from_iter>>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +impl PartialOrd for ValueRow { + fn partial_cmp(&self, other: &Self) -> Option { + self.0.partial_cmp(&other.0) + } +} + +impl Lattice for ValueRow { + fn meet(mut self, other: Self) -> Self { + self.meet_mut(other); + self + } + + fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + + fn join_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.join_mut(v2); + } + changed + } + + fn meet_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.meet_mut(v2); + } + changed + } +} + +impl IntoIterator for ValueRow { + type Item = PartialValue; + + type IntoIter = > as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl Index for ValueRow +where + Vec>: Index, +{ + type Output = > as Index>::Output; + + fn index(&self, index: Idx) -> &Self::Output { + self.0.index(index) + } +} + +impl IndexMut for ValueRow +where + Vec>: IndexMut, +{ + fn index_mut(&mut self, index: Idx) -> &mut Self::Output { + self.0.index_mut(index) + } +} diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index 26acc31ed..dc3c7a69a 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -4,7 +4,6 @@ use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, use super::partial_value::{AbstractValue, PartialValue, ValueOrSum}; use super::DFContext; -use super::ValueRow; /// A simpler interface like [DFContext] but where the context only cares about /// values that are completely known (in the lattice `V`) @@ -19,7 +18,11 @@ pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { } impl> DFContext for T { - fn interpret_leaf_op(&self, node: Node, ins: &[PartialValue]) -> Option> { + fn interpret_leaf_op( + &self, + node: Node, + ins: &[PartialValue], + ) -> Option>> { let op = self.get_optype(node); let sig = op.dataflow_signature()?; let known_ins = sig @@ -39,7 +42,7 @@ impl> DFContext for T { .collect::>(); let known_outs = self.interpret_leaf_op(node, &known_ins); (!known_outs.is_empty()).then(|| { - let mut res = ValueRow::new(sig.output_count()); + let mut res = vec![PartialValue::Bottom; sig.output_count()]; for (p, v) in known_outs { res[p.index()] = v.into(); } diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs deleted file mode 100644 index 9f7b8bef7..000000000 --- a/hugr-passes/src/dataflow/value_row.rs +++ /dev/null @@ -1,117 +0,0 @@ -// Really this is part of partial_value.rs - -use std::{ - cmp::Ordering, - ops::{Index, IndexMut}, -}; - -use ascent::lattice::{BoundedLattice, Lattice}; -use itertools::zip_eq; - -use super::partial_value::{AbstractValue, PartialValue}; - -#[derive(PartialEq, Clone, Eq, Hash)] -pub struct ValueRow(Vec>); - -impl ValueRow { - pub fn new(len: usize) -> Self { - Self(vec![PartialValue::bottom(); len]) - } - - pub fn single_known(len: usize, idx: usize, v: PartialValue) -> Self { - assert!(idx < len); - let mut r = Self::new(len); - r.0[idx] = v; - r - } - - pub fn iter(&self) -> impl Iterator> { - self.0.iter() - } - - pub fn unpack_first( - &self, - variant: usize, - len: usize, - ) -> Option> + '_> { - self[0] - .variant_values(variant, len) - .map(|vals| vals.into_iter().chain(self.iter().skip(1).cloned())) - } - - // fn initialised(&self) -> bool { - // self.0.iter().all(|x| x != &PV::top()) - // } -} - -impl FromIterator> for ValueRow { - fn from_iter>>(iter: T) -> Self { - Self(iter.into_iter().collect()) - } -} - -impl PartialOrd for ValueRow { - fn partial_cmp(&self, other: &Self) -> Option { - self.0.partial_cmp(&other.0) - } -} - -impl Lattice for ValueRow { - fn meet(mut self, other: Self) -> Self { - self.meet_mut(other); - self - } - - fn join(mut self, other: Self) -> Self { - self.join_mut(other); - self - } - - fn join_mut(&mut self, other: Self) -> bool { - assert_eq!(self.0.len(), other.0.len()); - let mut changed = false; - for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { - changed |= v1.join_mut(v2); - } - changed - } - - fn meet_mut(&mut self, other: Self) -> bool { - assert_eq!(self.0.len(), other.0.len()); - let mut changed = false; - for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { - changed |= v1.meet_mut(v2); - } - changed - } -} - -impl IntoIterator for ValueRow { - type Item = PartialValue; - - type IntoIter = > as IntoIterator>::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() - } -} - -impl Index for ValueRow -where - Vec>: Index, -{ - type Output = > as Index>::Output; - - fn index(&self, index: Idx) -> &Self::Output { - self.0.index(index) - } -} - -impl IndexMut for ValueRow -where - Vec>: IndexMut, -{ - fn index_mut(&mut self, index: Idx) -> &mut Self::Output { - self.0.index_mut(index) - } -} From 15e642e7d03762d424536def8d6c0abaeaa70380 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 12:30:40 +0100 Subject: [PATCH 062/203] Remove PartialSum::unit() --- hugr-passes/src/dataflow/partial_value.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index e6b53b935..26de29c3f 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -29,9 +29,6 @@ pub enum ValueOrSum { pub struct PartialSum(pub HashMap>>); impl PartialSum { - pub fn unit() -> Self { - Self::variant(0, []) - } pub fn variant(tag: usize, values: impl IntoIterator>) -> Self { Self(HashMap::from([(tag, Vec::from_iter(values))])) } @@ -481,7 +478,7 @@ mod test { }) .boxed() } - Self::Unit => Just(PartialSum::unit().into()).boxed(), + Self::Unit => Just(PartialValue::unit()).boxed(), } } } From 514af1307e6b84e6c166dd8dd58b958f09fda257 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 13:30:17 +0100 Subject: [PATCH 063/203] PartialValue is private struct containing PVEnum (with ::Sum not ::PartialSum) --- hugr-passes/src/dataflow.rs | 2 +- hugr-passes/src/dataflow/partial_value.rs | 153 ++++++++++++---------- hugr-passes/src/dataflow/total_context.rs | 3 +- 3 files changed, 84 insertions(+), 74 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 2e2e95936..15c08be04 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -6,7 +6,7 @@ mod machine; pub use machine::Machine; mod partial_value; -pub use partial_value::{AbstractValue, PartialValue, ValueOrSum}; +pub use partial_value::{AbstractValue, PVEnum, PartialValue, ValueOrSum}; mod total_context; pub use total_context::TotalContext; diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 26de29c3f..87400759d 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -189,13 +189,23 @@ impl Hash for PartialSum { } } -/// We really must prevent people from constructing PartialValue::Value of -/// any `value` where `value.as_sum().is_some()`` #[derive(PartialEq, Clone, Eq, Hash, Debug)] -pub enum PartialValue { +pub struct PartialValue(PVEnum); + +impl PartialValue { + /// Allows to read the enum, which guarantees that we never return [PVEnum::Value] + /// for a value whose [AbstractValue::as_sum] is `Some` - any such value will be + /// in the form of a [PVEnum::Sum] instead. + pub fn as_enum(&self) -> &PVEnum { + &self.0 + } +} + +#[derive(PartialEq, Clone, Eq, Hash, Debug)] +pub enum PVEnum { Bottom, Value(V), - PartialSum(PartialSum), + Sum(PartialSum), Top, } @@ -203,25 +213,25 @@ impl From for PartialValue { fn from(v: V) -> Self { v.as_sum() .map(|(tag, values)| Self::variant(tag, values.map(Self::from))) - .unwrap_or(Self::Value(v)) + .unwrap_or(Self(PVEnum::Value(v))) } } impl From> for PartialValue { fn from(v: PartialSum) -> Self { - Self::PartialSum(v) + Self(PVEnum::Sum(v)) } } impl PartialValue { - // const BOTTOM: Self = Self::Bottom; + // const BOTTOM: Self = PVEnum::Bottom; // const BOTTOM_REF: &'static Self = &Self::BOTTOM; fn assert_invariants(&self) { - match self { - Self::PartialSum(ps) => { + match &self.0 { + PVEnum::Sum(ps) => { ps.assert_invariants(); } - Self::Value(v) => { + PVEnum::Value(v) => { assert!(v.as_sum().is_none()) } _ => {} @@ -242,36 +252,36 @@ impl PartialValue { } pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { - let vals = match self { - PartialValue::Bottom => return None, - PartialValue::Value(v) => { + let vals = match &self.0 { + PVEnum::Bottom => return None, + PVEnum::Value(v) => { assert!(v.as_sum().is_none()); return None; } - PartialValue::PartialSum(ps) => ps.variant_values(tag, len)?, - PartialValue::Top => vec![PartialValue::Top; len], + PVEnum::Sum(ps) => ps.variant_values(tag, len)?, + PVEnum::Top => vec![PartialValue(PVEnum::Top); len], }; assert_eq!(vals.len(), len); Some(vals) } pub fn supports_tag(&self, tag: usize) -> bool { - match self { - PartialValue::Bottom => false, - PartialValue::Value(v) => { + match &self.0 { + PVEnum::Bottom => false, + PVEnum::Value(v) => { assert!(v.as_sum().is_none()); false } - PartialValue::PartialSum(ps) => ps.supports_tag(tag), - PartialValue::Top => true, + PVEnum::Sum(ps) => ps.supports_tag(tag), + PVEnum::Top => true, } } pub fn try_into_value(self, typ: &Type) -> Result, Self> { - match self { - Self::Value(v) => Ok(ValueOrSum::Value(v.clone())), - Self::PartialSum(ps) => ps.try_into_value(typ).map_err(Self::PartialSum), - x => Err(x), + match self.0 { + PVEnum::Value(v) => return Ok(ValueOrSum::Value(v.clone())), + PVEnum::Sum(ps) => ps.try_into_value(typ).map_err(Self::from), + _ => Err(self), } } } @@ -280,41 +290,40 @@ impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); // println!("join {self:?}\n{:?}", &other); - match (&*self, other) { - (Self::Top, _) => false, - (_, other @ Self::Top) => { - *self = other; + match (&self.0, other.0) { + (PVEnum::Top, _) => false, + (_, other @ PVEnum::Top) => { + self.0 = other; true } - (_, Self::Bottom) => false, - (Self::Bottom, other) => { - *self = other; + (_, PVEnum::Bottom) => false, + (PVEnum::Bottom, other) => { + self.0 = other; true } - (Self::Value(h1), Self::Value(h2)) => { + (PVEnum::Value(h1), PVEnum::Value(h2)) => { if h1 == &h2 { false } else { - *self = Self::Top; + self.0 = PVEnum::Top; true } } - (Self::PartialSum(_), Self::PartialSum(ps2)) => { - let Self::PartialSum(ps1) = self else { + (PVEnum::Sum(_), PVEnum::Sum(ps2)) => { + let Self(PVEnum::Sum(ps1)) = self else { unreachable!() }; match ps1.try_join_mut(ps2) { Ok(ch) => ch, Err(_) => { - *self = Self::Top; + self.0 = PVEnum::Top; true } } } - (Self::Value(ref v), Self::PartialSum(_)) - | (Self::PartialSum(_), Self::Value(ref v)) => { + (PVEnum::Value(ref v), PVEnum::Sum(_)) | (PVEnum::Sum(_), PVEnum::Value(ref v)) => { assert!(v.as_sum().is_none()); - *self = Self::Top; + self.0 = PVEnum::Top; true } } @@ -327,41 +336,41 @@ impl Lattice for PartialValue { fn meet_mut(&mut self, other: Self) -> bool { self.assert_invariants(); - match (&*self, other) { - (Self::Bottom, _) => false, - (_, other @ Self::Bottom) => { - *self = other; + match (&self.0, other.0) { + (PVEnum::Bottom, _) => false, + (_, other @ PVEnum::Bottom) => { + self.0 = other; true } - (_, Self::Top) => false, - (Self::Top, other) => { - *self = other; + (_, PVEnum::Top) => false, + (PVEnum::Top, other) => { + self.0 = other; true } - (Self::Value(h1), Self::Value(h2)) => { + (PVEnum::Value(h1), PVEnum::Value(h2)) => { if h1 == &h2 { false } else { - *self = Self::Bottom; + self.0 = PVEnum::Bottom; true } } - (Self::PartialSum(_), Self::PartialSum(ps2)) => { - let Self::PartialSum(ps1) = self else { - unreachable!() + (PVEnum::Sum(_), PVEnum::Sum(ps2)) => { + let ps1 = match &mut self.0 { + PVEnum::Sum(ps1) => ps1, + _ => unreachable!(), }; match ps1.try_meet_mut(ps2) { Ok(ch) => ch, Err(_) => { - *self = Self::Bottom; + self.0 = PVEnum::Bottom; true } } } - (Self::Value(ref v), Self::PartialSum(_)) - | (Self::PartialSum(_), Self::Value(ref v)) => { + (PVEnum::Value(ref v), PVEnum::Sum(_)) | (PVEnum::Sum(_), PVEnum::Value(ref v)) => { assert!(v.as_sum().is_none()); - *self = Self::Bottom; + self.0 = PVEnum::Bottom; true } } @@ -370,26 +379,26 @@ impl Lattice for PartialValue { impl BoundedLattice for PartialValue { fn top() -> Self { - Self::Top + Self(PVEnum::Top) } fn bottom() -> Self { - Self::Bottom + Self(PVEnum::Bottom) } } impl PartialOrd for PartialValue { fn partial_cmp(&self, other: &Self) -> Option { use std::cmp::Ordering; - match (self, other) { - (Self::Bottom, Self::Bottom) => Some(Ordering::Equal), - (Self::Top, Self::Top) => Some(Ordering::Equal), - (Self::Bottom, _) => Some(Ordering::Less), - (_, Self::Bottom) => Some(Ordering::Greater), - (Self::Top, _) => Some(Ordering::Greater), - (_, Self::Top) => Some(Ordering::Less), - (Self::Value(v1), Self::Value(v2)) => (v1 == v2).then_some(Ordering::Equal), - (Self::PartialSum(ps1), Self::PartialSum(ps2)) => ps1.partial_cmp(ps2), + match (&self.0, &other.0) { + (PVEnum::Bottom, PVEnum::Bottom) => Some(Ordering::Equal), + (PVEnum::Top, PVEnum::Top) => Some(Ordering::Equal), + (PVEnum::Bottom, _) => Some(Ordering::Less), + (_, PVEnum::Bottom) => Some(Ordering::Greater), + (PVEnum::Top, _) => Some(Ordering::Greater), + (_, PVEnum::Top) => Some(Ordering::Less), + (PVEnum::Value(v1), PVEnum::Value(v2)) => (v1 == v2).then_some(Ordering::Equal), + (PVEnum::Sum(ps1), PVEnum::Sum(ps2)) => ps1.partial_cmp(ps2), _ => None, } } @@ -408,7 +417,7 @@ mod test { types::{Type, TypeArg, TypeEnum}, }; - use super::{PartialSum, PartialValue}; + use super::{PVEnum, PartialSum, PartialValue}; use crate::const_fold2::value_handle::{ValueHandle, ValueKey}; impl Arbitrary for ValueHandle { @@ -566,10 +575,10 @@ mod test { } fn type_check(&self, pv: &PartialValue) -> bool { - match (self, pv) { - (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, - (_, PartialValue::Value(v)) => self.get_type() == v.get_type(), - (TestSumType::Branch(_, sop), PartialValue::PartialSum(ps)) => { + match (self, pv.as_enum()) { + (_, PVEnum::Bottom) | (_, PVEnum::Top) => true, + (_, PVEnum::Value(v)) => self.get_type() == v.get_type(), + (TestSumType::Branch(_, sop), PVEnum::Sum(ps)) => { for (k, v) in &ps.0 { if *k >= sop.len() { return false; @@ -584,7 +593,7 @@ mod test { } true } - (Self::Leaf(l), PartialValue::PartialSum(ps)) => l.type_check(ps), + (Self::Leaf(l), PVEnum::Sum(ps)) => l.type_check(ps), } } } diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index dc3c7a69a..cba3f08fe 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -1,5 +1,6 @@ use std::hash::Hash; +use ascent::lattice::BoundedLattice; use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; use super::partial_value::{AbstractValue, PartialValue, ValueOrSum}; @@ -42,7 +43,7 @@ impl> DFContext for T { .collect::>(); let known_outs = self.interpret_leaf_op(node, &known_ins); (!known_outs.is_empty()).then(|| { - let mut res = vec![PartialValue::Bottom; sig.output_count()]; + let mut res = vec![PartialValue::bottom(); sig.output_count()]; for (p, v) in known_outs { res[p.index()] = v.into(); } From 5d86f4669ef6d976b28331ab15c67197bafd138f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 13:38:48 +0100 Subject: [PATCH 064/203] variant => new_variant, unit => new_unit --- hugr-passes/src/dataflow/datalog.rs | 9 +++++---- hugr-passes/src/dataflow/partial_value.rs | 16 ++++++++-------- hugr-passes/src/dataflow/test.rs | 6 ++++-- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index c4c798324..b8f5e7230 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -153,15 +153,16 @@ fn propagate_leaf_op( // Handle basics here. I guess (given the current interface) we could allow // DFContext to handle these but at the least we'd want these impls to be // easily available for reuse. - op if op.cast::().is_some() => { - Some(ValueRow::from_iter([PV::variant(0, ins.iter().cloned())])) - } + op if op.cast::().is_some() => Some(ValueRow::from_iter([PV::new_variant( + 0, + ins.iter().cloned(), + )])), op if op.cast::().is_some() => { let [tup] = ins.iter().collect::>().try_into().unwrap(); tup.variant_values(0, value_outputs(c.as_ref(), n).count()) .map(ValueRow::from_iter) } - OpType::Tag(t) => Some(ValueRow::from_iter([PV::variant( + OpType::Tag(t) => Some(ValueRow::from_iter([PV::new_variant( t.tag, ins.iter().cloned(), )])), diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 87400759d..2d3a83742 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -29,7 +29,7 @@ pub enum ValueOrSum { pub struct PartialSum(pub HashMap>>); impl PartialSum { - pub fn variant(tag: usize, values: impl IntoIterator>) -> Self { + pub fn new_variant(tag: usize, values: impl IntoIterator>) -> Self { Self(HashMap::from([(tag, Vec::from_iter(values))])) } @@ -212,7 +212,7 @@ pub enum PVEnum { impl From for PartialValue { fn from(v: V) -> Self { v.as_sum() - .map(|(tag, values)| Self::variant(tag, values.map(Self::from))) + .map(|(tag, values)| Self::new_variant(tag, values.map(Self::from))) .unwrap_or(Self(PVEnum::Value(v))) } } @@ -243,12 +243,12 @@ impl PartialValue { self } - pub fn variant(tag: usize, values: impl IntoIterator) -> Self { - PartialSum::variant(tag, values).into() + pub fn new_variant(tag: usize, values: impl IntoIterator) -> Self { + PartialSum::new_variant(tag, values).into() } - pub fn unit() -> Self { - Self::variant(0, []) + pub fn new_unit() -> Self { + Self::new_variant(0, []) } pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { @@ -487,7 +487,7 @@ mod test { }) .boxed() } - Self::Unit => Just(PartialValue::unit()).boxed(), + Self::Unit => Just(PartialValue::new_unit()).boxed(), } } } @@ -677,7 +677,7 @@ mod test { ) }) .collect_vec(); - pvs.prop_map(move |pvs| PartialValue::variant(index, pvs)) + pvs.prop_map(move |pvs| PartialValue::new_variant(index, pvs)) .boxed() } }) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index e4b3d5c24..1d9668abf 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -221,8 +221,10 @@ fn conditional() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - let arg_pv = - PartialValue::variant(1, []).join(PartialValue::variant(2, [PartialValue::variant(0, [])])); + let arg_pv = PartialValue::new_variant(1, []).join(PartialValue::new_variant( + 2, + [PartialValue::new_variant(0, [])], + )); machine.propolutate_out_wires([(arg_w, arg_pv)]); machine.run(HugrValueContext::new(&hugr)); From d8c8140b4b7e45b661579c5b27920e7a7217330f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 14:03:09 +0100 Subject: [PATCH 065/203] Simplify PartialOrd for PartialSum, keys(1,2) support cmp --- hugr-passes/src/dataflow/partial_value.rs | 30 ++++++++++------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 2d3a83742..1987eaf2d 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -152,25 +152,21 @@ impl PartialOrd for PartialSum { keys2[*k] = 1; } - if let Some(ord) = keys1.partial_cmp(&keys2) { - if ord != Ordering::Equal { - return Some(ord); - } - } else { - return None; - } - for (k, lhs) in &self.0 { - let Some(rhs) = other.0.get(k) else { - unreachable!() - }; - match lhs.partial_cmp(rhs) { - Some(Ordering::Equal) => continue, - x => { - return x; + Some(match keys1.cmp(&keys2) { + ord @ Ordering::Greater | ord @ Ordering::Less => ord, + Ordering::Equal => { + for (k, lhs) in &self.0 { + let Some(rhs) = other.0.get(k) else { + unreachable!() + }; + let key_cmp = lhs.partial_cmp(rhs); + if key_cmp != Some(Ordering::Equal) { + return key_cmp; + } } + Ordering::Equal } - } - Some(Ordering::Equal) + }) } } From 13156857fdcceebc552757391daa17ac87e4347b Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 14:16:51 +0100 Subject: [PATCH 066/203] PartialSum::variant_values does not take `len` (PartialValue:: still does) --- hugr-passes/src/dataflow/partial_value.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 1987eaf2d..c2eff146e 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -133,10 +133,8 @@ impl PartialSum { } impl PartialSum { - pub fn variant_values(&self, variant: usize, len: usize) -> Option>> { - let row = self.0.get(&variant)?; - assert!(row.len() == len); - Some(row.clone()) + pub fn variant_values(&self, variant: usize) -> Option>> { + self.0.get(&variant).cloned() } } @@ -254,7 +252,7 @@ impl PartialValue { assert!(v.as_sum().is_none()); return None; } - PVEnum::Sum(ps) => ps.variant_values(tag, len)?, + PVEnum::Sum(ps) => ps.variant_values(tag)?, PVEnum::Top => vec![PartialValue(PVEnum::Top); len], }; assert_eq!(vals.len(), len); From 2aaaeb9791526e9f2ec5371ea1f5349dd5f12a22 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 15:22:05 +0100 Subject: [PATCH 067/203] clippy --- hugr-passes/src/dataflow/partial_value.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index c2eff146e..7a31478c4 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -118,7 +118,7 @@ impl PartialSum { if v.len() != r.len() { return Err(self); } - match zip_eq(v, r.into_iter()) + match zip_eq(v, r.iter()) .map(|(v, t)| v.clone().try_into_value(t)) .collect::, _>>() { @@ -273,7 +273,7 @@ impl PartialValue { pub fn try_into_value(self, typ: &Type) -> Result, Self> { match self.0 { - PVEnum::Value(v) => return Ok(ValueOrSum::Value(v.clone())), + PVEnum::Value(v) => Ok(ValueOrSum::Value(v.clone())), PVEnum::Sum(ps) => ps.try_into_value(typ).map_err(Self::from), _ => Err(self), } From 5619761cdec00745c8aba786173078554cb0d762 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 17:41:02 +0100 Subject: [PATCH 068/203] Machine::tail_loop_terminates + case_reachable return Option not panic --- hugr-passes/src/dataflow/machine.rs | 32 +++++++++++++++++------------ hugr-passes/src/dataflow/test.rs | 26 ++++++++++++++--------- 2 files changed, 35 insertions(+), 23 deletions(-) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index d4aa97e78..ad57328d1 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -51,27 +51,33 @@ impl> Machine { self.1.as_ref().unwrap().get(&w).cloned() } - pub fn tail_loop_terminates(&self, hugr: impl HugrView, node: Node) -> TailLoopTermination { - assert!(hugr.get_optype(node).is_tail_loop()); + pub fn tail_loop_terminates( + &self, + hugr: impl HugrView, + node: Node, + ) -> Option { + hugr.get_optype(node).as_tail_loop()?; let [_, out] = hugr.get_io(node).unwrap(); - TailLoopTermination::from_control_value( + Some(TailLoopTermination::from_control_value( self.0 .in_wire_value .iter() .find_map(|(_, n, p, v)| (*n == out && p.index() == 0).then_some(v)) .unwrap(), - ) + )) } - pub fn case_reachable(&self, hugr: impl HugrView, case: Node) -> bool { - assert!(hugr.get_optype(case).is_case()); - let cond = hugr.get_parent(case).unwrap(); - assert!(hugr.get_optype(cond).is_conditional()); - self.0 - .case_reachable - .iter() - .find_map(|(_, cond2, case2, i)| (&cond == cond2 && &case == case2).then_some(*i)) - .unwrap() + pub fn case_reachable(&self, hugr: impl HugrView, case: Node) -> Option { + hugr.get_optype(case).as_case()?; + let cond = hugr.get_parent(case)?; + hugr.get_optype(cond).as_conditional()?; + Some( + self.0 + .case_reachable + .iter() + .find_map(|(_, cond2, case2, i)| (&cond == cond2 && &case == case2).then_some(*i)) + .unwrap(), + ) } } diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 1d9668abf..66b9285a4 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -6,11 +6,14 @@ use crate::{ use ascent::lattice::BoundedLattice; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, - extension::prelude::{UnpackTuple, BOOL_T}, - extension::{ExtensionSet, EMPTY_REG}, + extension::{ + prelude::{UnpackTuple, BOOL_T}, + ExtensionSet, EMPTY_REG, + }, ops::{handle::NodeHandle, OpTrait, Value}, type_row, types::{Signature, SumType, Type, TypeRow}, + HugrView, }; use super::partial_value::PartialValue; @@ -93,7 +96,7 @@ fn test_tail_loop_never_iterates() { let o_r = machine.read_out_wire_value(&hugr, tl_o).unwrap(); assert_eq!(o_r, r_v); assert_eq!( - TailLoopTermination::ExactlyZeroContinues, + Some(TailLoopTermination::ExactlyZeroContinues), machine.tail_loop_terminates(&hugr, tail_loop.node()) ) } @@ -125,9 +128,10 @@ fn test_tail_loop_always_iterates() { let o_r2 = machine.read_out_wire_partial_value(tl_o2).unwrap(); assert_eq!(o_r2, PartialValue::bottom()); assert_eq!( - TailLoopTermination::Bottom, + Some(TailLoopTermination::Bottom), machine.tail_loop_terminates(&hugr, tail_loop.node()) - ) + ); + assert_eq!(machine.tail_loop_terminates(&hugr, hugr.root()), None); } #[test] @@ -178,9 +182,10 @@ fn test_tail_loop_iterates_twice() { let _ = machine.read_out_wire_partial_value(o_w2).unwrap(); // assert_eq!(o_r2, Value::true_val()); assert_eq!( - TailLoopTermination::Top, + Some(TailLoopTermination::Top), machine.tail_loop_terminates(&hugr, tail_loop.node()) - ) + ); + assert_eq!(machine.tail_loop_terminates(&hugr, hugr.root()), None); } #[test] @@ -232,7 +237,8 @@ fn conditional() { assert_eq!(cond_r1, Value::false_val()); assert!(machine.read_out_wire_value(&hugr, cond_o2).is_none()); - assert!(!machine.case_reachable(&hugr, case1.node())); // arg_pv is variant 1 or 2 only - assert!(machine.case_reachable(&hugr, case2.node())); - assert!(machine.case_reachable(&hugr, case3.node())); + assert_eq!(machine.case_reachable(&hugr, case1.node()), Some(false)); // arg_pv is variant 1 or 2 only + assert_eq!(machine.case_reachable(&hugr, case2.node()), Some(true)); + assert_eq!(machine.case_reachable(&hugr, case3.node()), Some(true)); + assert_eq!(machine.case_reachable(&hugr, cond.node()), None); } From f3c175c12e418bf83e0d44429a284691e8bea3b2 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 17:46:37 +0100 Subject: [PATCH 069/203] Machine::read_out_wire_value fails with ConstTypeError if there was one --- hugr-passes/src/dataflow/machine.rs | 15 ++++++++++----- hugr-passes/src/dataflow/test.rs | 2 +- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index ad57328d1..1b169abd2 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -85,16 +85,21 @@ impl> Machine where Value: From, { - pub fn read_out_wire_value(&self, hugr: impl HugrView, w: Wire) -> Option { + pub fn read_out_wire_value( + &self, + hugr: impl HugrView, + w: Wire, + ) -> Result> { // dbg!(&w); - let pv = self.read_out_wire_partial_value(w)?; - // dbg!(&pv); let (_, typ) = hugr .out_value_types(w.node()) .find(|(p, _)| *p == w.source()) .unwrap(); - let v: ValueOrSum = pv.try_into_value(&typ).ok()?; - v.try_into().ok() + let v = self + .read_out_wire_partial_value(w) + .and_then(|pv| pv.try_into_value(&typ).ok()) + .ok_or(None)?; + Ok(v.try_into().map_err(Some)?) } } diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 66b9285a4..844aa47d5 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -235,7 +235,7 @@ fn conditional() { let cond_r1 = machine.read_out_wire_value(&hugr, cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); - assert!(machine.read_out_wire_value(&hugr, cond_o2).is_none()); + assert!(machine.read_out_wire_value(&hugr, cond_o2).is_err()); assert_eq!(machine.case_reachable(&hugr, case1.node()), Some(false)); // arg_pv is variant 1 or 2 only assert_eq!(machine.case_reachable(&hugr, case2.node()), Some(true)); From aad2ef00d0474074f60d3f6cbd89cf828c0a52a0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 17:48:02 +0100 Subject: [PATCH 070/203] PartialValue::try_(join|meet)_mut are pub, don't mutate upon failure --- hugr-passes/src/dataflow/partial_value.rs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 7a31478c4..b5ccc6ade 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -46,16 +46,17 @@ impl PartialSum { } } - // Err with key if any common rows have different lengths (self may have been mutated) - fn try_join_mut(&mut self, other: Self) -> Result { + // Err with key if any common rows have different lengths (self not mutated) + pub fn try_join_mut(&mut self, other: Self) -> Result { + for (k, v) in &other.0 { + if self.0.get(k).is_some_and(|row| row.len() != v.len()) { + return Err(*k); + } + } let mut changed = false; for (k, v) in other.0 { if let Some(row) = self.0.get_mut(&k) { - if v.len() != row.len() { - // Better to check first and avoid mutation, but fine here - return Err(k); - } for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { changed |= lhs.join_mut(rhs); } @@ -68,7 +69,7 @@ impl PartialSum { } // Error with key if any common rows have different lengths ( => Bottom) - fn try_meet_mut(&mut self, other: Self) -> Result { + pub fn try_meet_mut(&mut self, other: Self) -> Result { let mut changed = false; let mut keys_to_remove = vec![]; for (k, v) in self.0.iter() { From 0a8cc12bdabe332047d4d21eecdc9fc5c4eea496 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 17:48:20 +0100 Subject: [PATCH 071/203] Remove some commented-out code --- hugr-passes/src/dataflow/partial_value.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index b5ccc6ade..fb79a93f6 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -219,8 +219,6 @@ impl From> for PartialValue { } impl PartialValue { - // const BOTTOM: Self = PVEnum::Bottom; - // const BOTTOM_REF: &'static Self = &Self::BOTTOM; fn assert_invariants(&self) { match &self.0 { PVEnum::Sum(ps) => { From bfcd0a675109dab997f3f2f5c750ba6c068c9d9a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 17:39:48 +0100 Subject: [PATCH 072/203] Expose PartialSum --- hugr-passes/src/dataflow.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 15c08be04..d80ab275e 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -6,7 +6,7 @@ mod machine; pub use machine::Machine; mod partial_value; -pub use partial_value::{AbstractValue, PVEnum, PartialValue, ValueOrSum}; +pub use partial_value::{AbstractValue, PVEnum, PartialSum, PartialValue, ValueOrSum}; mod total_context; pub use total_context::TotalContext; From 248fb23eaacd880c1695ef13cc6653c3d29cab08 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 17:40:17 +0100 Subject: [PATCH 073/203] dataflow has docs! (enforced) --- hugr-passes/src/dataflow.rs | 5 ++ hugr-passes/src/dataflow/machine.rs | 44 +++++++++++++++-- hugr-passes/src/dataflow/partial_value.rs | 58 +++++++++++++++++++++-- hugr-passes/src/dataflow/total_context.rs | 8 +++- 4 files changed, 103 insertions(+), 12 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index d80ab275e..6a8f94ebd 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -1,3 +1,4 @@ +#![warn(missing_docs)] //! Dataflow analysis of Hugrs. mod datalog; @@ -14,7 +15,11 @@ pub use total_context::TotalContext; use hugr_core::{Hugr, Node}; use std::hash::Hash; +/// Clients of the dataflow framework (particular analyses, such as constant folding) +/// must implement this trait (including providing an appropriate domain type `V`). pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { + /// Given lattice values for each input, produce lattice values for (what we know of) + /// the outputs. Returning `None` indicates nothing can be deduced. fn interpret_leaf_op( &self, node: Node, diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 1b169abd2..d15e49653 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -6,6 +6,11 @@ use super::{ datalog::AscentProgram, partial_value::ValueOrSum, AbstractValue, DFContext, PartialValue, }; +/// Basic structure for performing an analysis. Usage: +/// 1. Get a new instance via [Self::default()] +/// 2. Zero or more [Self::propolutate_out_wires] with initial values +/// 3. Exactly one [Self::run] to do the analysis +/// 4. Results then available via [Self::read_out_wire_partial_value] and [Self::read_out_wire_value] pub struct Machine>( AscentProgram, Option>>, @@ -18,12 +23,9 @@ impl> Default for Machine { } } -/// Usage: -/// 1. Get a new instance via [Self::default()] -/// 2. Zero or more [Self::propolutate_out_wires] with initial values -/// 3. Exactly one [Self::run] to do the analysis -/// 4. Results then available via [Self::read_out_wire_partial_value] and [Self::read_out_wire_value] impl> Machine { + /// Provide initial values for some wires. + /// (For example, if some properties of the Hugr's inputs are known.) pub fn propolutate_out_wires( &mut self, wires: impl IntoIterator)>, @@ -34,6 +36,13 @@ impl> Machine { .extend(wires.into_iter().map(|(w, v)| (w.node(), w.source(), v))); } + /// Run the analysis (iterate until a lattice fixpoint is reached). + /// The context passed in allows interpretation of leaf operations. + /// + /// # Panics + /// + /// If this Machine has been run already. + /// pub fn run(&mut self, context: C) { assert!(self.1.is_none()); self.0.context.push((context,)); @@ -47,10 +56,16 @@ impl> Machine { ) } + /// Gets the lattice value computed by [Self::run] for the given wire pub fn read_out_wire_partial_value(&self, w: Wire) -> Option> { self.1.as_ref().unwrap().get(&w).cloned() } + /// Tells whether a [TailLoop] node can terminate, i.e. whether + /// `Break` and/or `Continue` tags may be returned by the nested DFG. + /// Returns `None` if the specified `node` is not a [TailLoop]. + /// + /// [TailLoop]: hugr_core::ops::TailLoop pub fn tail_loop_terminates( &self, hugr: impl HugrView, @@ -67,6 +82,13 @@ impl> Machine { )) } + /// Tells whether a [Case] node is reachable, i.e. whether the predicate + /// to its parent [Conditional] may possibly have the tag corresponding to the [Case]. + /// Returns `None` if the specified `case` is not a [Case], or is not within a [Conditional] + /// (e.g. a [Case]-rooted Hugr). + /// + /// [Case]: hugr_core::ops::Case + /// [Conditional]: hugr_core::ops::Conditional pub fn case_reachable(&self, hugr: impl HugrView, case: Node) -> Option { hugr.get_optype(case).as_case()?; let cond = hugr.get_parent(case)?; @@ -85,6 +107,18 @@ impl> Machine where Value: From, { + /// Gets the Hugr [Value] computed by [Self::run] for the given wire, if possible. + /// (Only if the analysis determined a single `V`, or a Sum of `V`s with a single + /// possible tag, was present on that wire.) + /// + /// # Errors + /// `None` if the analysis did not result in a single [ValueOrSum] on that wire + /// `Some(e)` if conversion to a [Value] produced a [ConstTypeError] + /// + /// # Panics + /// If a [Type] for the specified wire could not be extracted from the Hugr + /// + /// [Type]: hugr_core::types::Type pub fn read_out_wire_value( &self, hugr: impl HugrView, diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index fb79a93f6..7a2e52f01 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -6,17 +6,31 @@ use std::cmp::Ordering; use std::collections::HashMap; use std::hash::{Hash, Hasher}; -/// Aka, deconstructible into Sum (TryIntoSum ?) +/// Trait for values which can be deconstructed into Sums (with a single known tag). +/// Required for values used in dataflow analysis. pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { - /// We write this way to optimize query/inspection (is-it-a-sum), + /// Deconstruct a value into a single known tag plus a row of values, if it is a [Sum]. + /// Note that one can just always return `None` but this will mean the analysis + /// is unable to understand untupling, and may give inconsistent results wrt. [Tag] + /// operations, etc. + /// + /// The signature is this way to optimize query/inspection (is-it-a-sum), /// at the cost of requiring more cloning during actual conversion /// (inside the lazy Iterator, or for the error case, as Self remains) + /// + /// [Sum]: TypeEnum::Sum + /// [Tag]: hugr_core::ops::Tag fn as_sum(&self) -> Option<(usize, impl Iterator + '_)>; } +/// A struct returned from [PartialValue::try_into_value] and [PartialSum::try_into_value] +/// indicating the value is either a single value or a sum with a single known tag. #[derive(Clone, Debug, PartialEq, Eq)] pub enum ValueOrSum { + /// Single value in the domain `V` Value(V), + /// Sum with a single known Tag + #[allow(missing_docs)] Sum { tag: usize, items: Vec, @@ -24,15 +38,20 @@ pub enum ValueOrSum { }, } -// TODO ALAN inline into PartialValue? Has to be public as it's in a pub enum +/// A representation of a value of [SumType], that may have one or more possible tags, +/// with a [PartialValue] representation of each element-value of each possible tag. #[derive(PartialEq, Clone, Eq)] pub struct PartialSum(pub HashMap>>); impl PartialSum { + /// New instance for a single known tag. + /// (Multi-tag instances can be created via [Self::try_join_mut].) pub fn new_variant(tag: usize, values: impl IntoIterator>) -> Self { Self(HashMap::from([(tag, Vec::from_iter(values))])) } + /// The number of possible variants we know about. (NOT the number + /// of tags possible for the value's type, whatever [SumType] that might be.) pub fn num_variants(&self) -> usize { self.0.len() } @@ -46,7 +65,10 @@ impl PartialSum { } } - // Err with key if any common rows have different lengths (self not mutated) + /// Joins (towards `Top`) self with another [PartialSum]. If successful, returns + /// whether `self` has changed. + /// + /// Fails (without mutation) with the conflicting tag if any common rows have different lengths. pub fn try_join_mut(&mut self, other: Self) -> Result { for (k, v) in &other.0 { if self.0.get(k).is_some_and(|row| row.len() != v.len()) { @@ -68,7 +90,10 @@ impl PartialSum { Ok(changed) } - // Error with key if any common rows have different lengths ( => Bottom) + /// Mutates self according to lattice meet operation (towards `Bottom`). If successful, + /// returns whether `self` has changed. + /// + /// Fails (without mutation) with the conflicting tag if any common rows have different lengths pub fn try_meet_mut(&mut self, other: Self) -> Result { let mut changed = false; let mut keys_to_remove = vec![]; @@ -98,10 +123,14 @@ impl PartialSum { Ok(changed) } + /// Whether this sum might have the specified tag pub fn supports_tag(&self, tag: usize) -> bool { self.0.contains_key(&tag) } + /// Turns this instance into a [ValueOrSum::Sum] if it has exactly one possible tag, + /// otherwise failing and returning itself back unmodified (also if there is another + /// error, e.g. this instance is not described by `typ`). pub fn try_into_value(self, typ: &Type) -> Result, Self> { let Ok((k, v)) = self.0.iter().exactly_one() else { Err(self)? @@ -134,6 +163,7 @@ impl PartialSum { } impl PartialSum { + /// If this Sum might have the specified `tag`, get the elements inside that tag. pub fn variant_values(&self, variant: usize) -> Option>> { self.0.get(&variant).cloned() } @@ -184,6 +214,9 @@ impl Hash for PartialSum { } } +/// Wraps some underlying representation (knowledge) of values into a lattice +/// for use in dataflow analysis, including that an instance may be a [PartialSum] +/// of values of the underlying representation #[derive(PartialEq, Clone, Eq, Hash, Debug)] pub struct PartialValue(PVEnum); @@ -196,11 +229,16 @@ impl PartialValue { } } +/// The contents of a [PartialValue], i.e. used as a view. #[derive(PartialEq, Clone, Eq, Hash, Debug)] pub enum PVEnum { + /// No possibilities known (so far) Bottom, + /// A single value (of the underlying representation) Value(V), + /// Sum (with perhaps several possible tags) of underlying values Sum(PartialSum), + /// Might be more than one distinct value of the underlying type `V` Top, } @@ -231,19 +269,27 @@ impl PartialValue { } } + /// Computes the lattice-join (i.e. towards `Top`) of this [PartialValue] with another. pub fn join(mut self, other: Self) -> Self { self.join_mut(other); self } + /// New instance of a sum with a single known tag. pub fn new_variant(tag: usize, values: impl IntoIterator) -> Self { PartialSum::new_variant(tag, values).into() } + /// New instance of unit type (i.e. the only possible value, with no contents) pub fn new_unit() -> Self { Self::new_variant(0, []) } + /// If this value might be a Sum with the specified `tag`, get the elements inside that tag. + /// + /// # Panics + /// + /// if the value is believed, for that tag, to have a number of values other than `len` pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { let vals = match &self.0 { PVEnum::Bottom => return None, @@ -258,6 +304,7 @@ impl PartialValue { Some(vals) } + /// Tells us whether this value might be a Sum with the specified `tag` pub fn supports_tag(&self, tag: usize) -> bool { match &self.0 { PVEnum::Bottom => false, @@ -270,6 +317,7 @@ impl PartialValue { } } + /// Extracts a [ValueOrSum] if there is such a single representation pub fn try_into_value(self, typ: &Type) -> Result, Self> { match self.0 { PVEnum::Value(v) => Ok(ValueOrSum::Value(v.clone())), diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index cba3f08fe..262c250f9 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -7,10 +7,14 @@ use super::partial_value::{AbstractValue, PartialValue, ValueOrSum}; use super::DFContext; /// A simpler interface like [DFContext] but where the context only cares about -/// values that are completely known (in the lattice `V`) -/// rather than e.g. Sums potentially of two variants each of known values. +/// values that are completely known (as `V`s), i.e. not `Bottom`, `Top`, or +/// Sums of potentially multiple variants. pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { + /// The representation of values on which [Self::interpret_leaf_op] operates type InterpretableVal: TryFrom>; + /// Interpret a leaf op. + /// `ins` gives the input ports for which we know (interpretable) values, and will be non-empty. + /// Returns a list of output ports for which we know (abstract) values (may be empty). fn interpret_leaf_op( &self, node: Node, From 5b8654e4673800fb6d0129635ea2480496b69fa6 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 2 Oct 2024 18:13:44 +0100 Subject: [PATCH 074/203] clippy --- hugr-passes/src/dataflow/machine.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index d15e49653..834e9913b 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -133,7 +133,7 @@ where .read_out_wire_partial_value(w) .and_then(|pv| pv.try_into_value(&typ).ok()) .ok_or(None)?; - Ok(v.try_into().map_err(Some)?) + v.try_into().map_err(Some) } } From c40e718761e235df438835f05ada796c3c47b46e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 7 Oct 2024 11:50:18 +0100 Subject: [PATCH 075/203] Move PartialValue::join into impl Lattice for --- hugr-passes/src/const_fold2.rs | 1 + hugr-passes/src/const_fold2/context.rs | 7 +++++++ hugr-passes/src/const_fold2/value_handle.rs | 2 +- hugr-passes/src/dataflow/partial_value.rs | 11 +++++------ hugr-passes/src/dataflow/test.rs | 2 +- 5 files changed, 15 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs index 93b772d88..58f285d43 100644 --- a/hugr-passes/src/const_fold2.rs +++ b/hugr-passes/src/const_fold2.rs @@ -1,3 +1,4 @@ +#![warn(missing_docs)] //! An (example) use of the [super::dataflow](dataflow-analysis framework) //! to perform constant-folding. diff --git a/hugr-passes/src/const_fold2/context.rs b/hugr-passes/src/const_fold2/context.rs index c18f5430b..32fc57765 100644 --- a/hugr-passes/src/const_fold2/context.rs +++ b/hugr-passes/src/const_fold2/context.rs @@ -9,6 +9,10 @@ use super::value_handle::{ValueHandle, ValueKey}; use crate::dataflow::TotalContext; /// A context ([DFContext]) for doing analysis with [ValueHandle]s. +/// Interprets [LoadConstant](OpType::LoadConstant) nodes, +/// and [ExtensionOp](OpType::ExtensionOp) nodes where the extension does +/// (using [Value]s for extension-op inputs). +/// /// Just stores a Hugr (actually any [HugrView]), /// (there is )no state for operation-interpretation. /// @@ -17,6 +21,7 @@ use crate::dataflow::TotalContext; pub struct HugrValueContext(Arc); impl HugrValueContext { + /// Creates a new instance, given ownership of the [HugrView] pub fn new(hugr: H) -> Self { Self(Arc::new(hugr)) } @@ -30,6 +35,8 @@ impl Clone for HugrValueContext { } } +// Any value used in an Ascent program must be hashable. +// However, there should only be one DFContext, so its hash is immaterial. impl Hash for HugrValueContext { fn hash(&self, _state: &mut I) {} } diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index 7b6e26106..bbcd25129 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -63,7 +63,7 @@ impl ValueKey { }) } - pub fn field(self, i: usize) -> Self { + fn field(self, i: usize) -> Self { Self::Field(i, Box::new(self)) } } diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 7a2e52f01..85b7ed395 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -269,12 +269,6 @@ impl PartialValue { } } - /// Computes the lattice-join (i.e. towards `Top`) of this [PartialValue] with another. - pub fn join(mut self, other: Self) -> Self { - self.join_mut(other); - self - } - /// New instance of a sum with a single known tag. pub fn new_variant(tag: usize, values: impl IntoIterator) -> Self { PartialSum::new_variant(tag, values).into() @@ -328,6 +322,11 @@ impl PartialValue { } impl Lattice for PartialValue { + fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } + fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); // println!("join {self:?}\n{:?}", &other); diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 844aa47d5..a7c90236f 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -3,7 +3,7 @@ use crate::{ dataflow::{machine::TailLoopTermination, Machine}, }; -use ascent::lattice::BoundedLattice; +use ascent::{lattice::BoundedLattice, Lattice}; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ From 8732a637756d511e0a67a1a873e7e444f8293c74 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 7 Oct 2024 12:17:59 +0100 Subject: [PATCH 076/203] Machine::read_out_wire_value => PartialValue::try_into_wire_value Is it worth keeping the ValueOrSum intermediate?? --- hugr-passes/src/dataflow/machine.rs | 36 +------------------ hugr-passes/src/dataflow/partial_value.rs | 35 +++++++++++++++++-- hugr-passes/src/dataflow/test.rs | 42 +++++++++++++++++++---- 3 files changed, 69 insertions(+), 44 deletions(-) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 834e9913b..533ac3a07 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -10,7 +10,7 @@ use super::{ /// 1. Get a new instance via [Self::default()] /// 2. Zero or more [Self::propolutate_out_wires] with initial values /// 3. Exactly one [Self::run] to do the analysis -/// 4. Results then available via [Self::read_out_wire_partial_value] and [Self::read_out_wire_value] +/// 4. Results then available via [Self::read_out_wire_partial_value] pub struct Machine>( AscentProgram, Option>>, @@ -103,40 +103,6 @@ impl> Machine { } } -impl> Machine -where - Value: From, -{ - /// Gets the Hugr [Value] computed by [Self::run] for the given wire, if possible. - /// (Only if the analysis determined a single `V`, or a Sum of `V`s with a single - /// possible tag, was present on that wire.) - /// - /// # Errors - /// `None` if the analysis did not result in a single [ValueOrSum] on that wire - /// `Some(e)` if conversion to a [Value] produced a [ConstTypeError] - /// - /// # Panics - /// If a [Type] for the specified wire could not be extracted from the Hugr - /// - /// [Type]: hugr_core::types::Type - pub fn read_out_wire_value( - &self, - hugr: impl HugrView, - w: Wire, - ) -> Result> { - // dbg!(&w); - let (_, typ) = hugr - .out_value_types(w.node()) - .find(|(p, _)| *p == w.source()) - .unwrap(); - let v = self - .read_out_wire_partial_value(w) - .and_then(|pv| pv.try_into_value(&typ).ok()) - .ok_or(None)?; - v.try_into().map_err(Some) - } -} - impl TryFrom> for Value where Value: From, diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 85b7ed395..727cde174 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -1,6 +1,8 @@ use ascent::lattice::BoundedLattice; use ascent::Lattice; -use hugr_core::types::{SumType, Type, TypeEnum, TypeRow}; +use hugr_core::ops::Value; +use hugr_core::types::{ConstTypeError, SumType, Type, TypeEnum, TypeRow}; +use hugr_core::{HugrView, Wire}; use itertools::{zip_eq, Itertools}; use std::cmp::Ordering; use std::collections::HashMap; @@ -311,7 +313,8 @@ impl PartialValue { } } - /// Extracts a [ValueOrSum] if there is such a single representation + /// Extracts a [ValueOrSum] if there is such a single representation, + /// given a [Type] pub fn try_into_value(self, typ: &Type) -> Result, Self> { match self.0 { PVEnum::Value(v) => Ok(ValueOrSum::Value(v.clone())), @@ -321,6 +324,34 @@ impl PartialValue { } } +impl PartialValue +where + Value: From, +{ + /// Extracts a [ValueOrSum] if there is such a single representation, + /// given a HugrView and Wire that determine the type. + /// + /// # Errors + /// `None` if the analysis did not result in a single [ValueOrSum] on that wire + /// `Some(e)` if conversion to a [Value] produced a [ConstTypeError] + /// + /// # Panics + /// + /// If a [Type] for the specified wire could not be extracted from the Hugr + pub fn try_into_wire_value( + self, + hugr: &impl HugrView, + w: Wire, + ) -> Result> { + let (_, typ) = hugr + .out_value_types(w.node()) + .find(|(p, _)| *p == w.source()) + .unwrap(); + let vs = self.try_into_value(&typ).map_err(|_| None)?; + vs.try_into().map_err(Some) + } +} + impl Lattice for PartialValue { fn join(mut self, other: Self) -> Self { self.join_mut(other); diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index a7c90236f..fcca86b1b 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -29,7 +29,11 @@ fn test_make_tuple() { let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); - let x = machine.read_out_wire_value(&hugr, v3).unwrap(); + let x = machine + .read_out_wire_partial_value(v3) + .unwrap() + .try_into_wire_value(&hugr, v3) + .unwrap(); assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); } @@ -46,9 +50,17 @@ fn test_unpack_tuple_const() { let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); - let o1_r = machine.read_out_wire_value(&hugr, o1).unwrap(); + let o1_r = machine + .read_out_wire_partial_value(o1) + .unwrap() + .try_into_wire_value(&hugr, o1) + .unwrap(); assert_eq!(o1_r, Value::false_val()); - let o2_r = machine.read_out_wire_value(&hugr, o2).unwrap(); + let o2_r = machine + .read_out_wire_partial_value(o2) + .unwrap() + .try_into_wire_value(&hugr, o2) + .unwrap(); assert_eq!(o2_r, Value::true_val()); } @@ -65,7 +77,11 @@ fn test_unpack_const() { let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); - let o_r = machine.read_out_wire_value(&hugr, o).unwrap(); + let o_r = machine + .read_out_wire_partial_value(o) + .unwrap() + .try_into_wire_value(&hugr, o) + .unwrap(); assert_eq!(o_r, Value::true_val()); } @@ -93,7 +109,11 @@ fn test_tail_loop_never_iterates() { // dbg!(&machine.tail_loop_io_node); // dbg!(&machine.out_wire_value); - let o_r = machine.read_out_wire_value(&hugr, tl_o).unwrap(); + let o_r = machine + .read_out_wire_partial_value(tl_o) + .unwrap() + .try_into_wire_value(&hugr, tl_o) + .unwrap(); assert_eq!(o_r, r_v); assert_eq!( Some(TailLoopTermination::ExactlyZeroContinues), @@ -233,9 +253,17 @@ fn conditional() { machine.propolutate_out_wires([(arg_w, arg_pv)]); machine.run(HugrValueContext::new(&hugr)); - let cond_r1 = machine.read_out_wire_value(&hugr, cond_o1).unwrap(); + let cond_r1 = machine + .read_out_wire_partial_value(cond_o1) + .unwrap() + .try_into_wire_value(&hugr, cond_o1) + .unwrap(); assert_eq!(cond_r1, Value::false_val()); - assert!(machine.read_out_wire_value(&hugr, cond_o2).is_err()); + assert!(machine + .read_out_wire_partial_value(cond_o2) + .unwrap() + .try_into_wire_value(&hugr, cond_o2) + .is_err()); assert_eq!(machine.case_reachable(&hugr, case1.node()), Some(false)); // arg_pv is variant 1 or 2 only assert_eq!(machine.case_reachable(&hugr, case2.node()), Some(true)); From 346187d7cd64ae2d84514c177f32727554fe8868 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 7 Oct 2024 12:19:32 +0100 Subject: [PATCH 077/203] read_out_wire_partial_value => read_out_wire --- hugr-passes/src/dataflow/machine.rs | 4 ++-- hugr-passes/src/dataflow/test.rs | 22 +++++++++++----------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 533ac3a07..f1b685fd8 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -10,7 +10,7 @@ use super::{ /// 1. Get a new instance via [Self::default()] /// 2. Zero or more [Self::propolutate_out_wires] with initial values /// 3. Exactly one [Self::run] to do the analysis -/// 4. Results then available via [Self::read_out_wire_partial_value] +/// 4. Results then available via [Self::read_out_wire] pub struct Machine>( AscentProgram, Option>>, @@ -57,7 +57,7 @@ impl> Machine { } /// Gets the lattice value computed by [Self::run] for the given wire - pub fn read_out_wire_partial_value(&self, w: Wire) -> Option> { + pub fn read_out_wire(&self, w: Wire) -> Option> { self.1.as_ref().unwrap().get(&w).cloned() } diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index fcca86b1b..cfc9b8975 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -30,7 +30,7 @@ fn test_make_tuple() { machine.run(HugrValueContext::new(&hugr)); let x = machine - .read_out_wire_partial_value(v3) + .read_out_wire(v3) .unwrap() .try_into_wire_value(&hugr, v3) .unwrap(); @@ -51,13 +51,13 @@ fn test_unpack_tuple_const() { machine.run(HugrValueContext::new(&hugr)); let o1_r = machine - .read_out_wire_partial_value(o1) + .read_out_wire(o1) .unwrap() .try_into_wire_value(&hugr, o1) .unwrap(); assert_eq!(o1_r, Value::false_val()); let o2_r = machine - .read_out_wire_partial_value(o2) + .read_out_wire(o2) .unwrap() .try_into_wire_value(&hugr, o2) .unwrap(); @@ -78,7 +78,7 @@ fn test_unpack_const() { machine.run(HugrValueContext::new(&hugr)); let o_r = machine - .read_out_wire_partial_value(o) + .read_out_wire(o) .unwrap() .try_into_wire_value(&hugr, o) .unwrap(); @@ -110,7 +110,7 @@ fn test_tail_loop_never_iterates() { // dbg!(&machine.out_wire_value); let o_r = machine - .read_out_wire_partial_value(tl_o) + .read_out_wire(tl_o) .unwrap() .try_into_wire_value(&hugr, tl_o) .unwrap(); @@ -143,9 +143,9 @@ fn test_tail_loop_always_iterates() { let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); - let o_r1 = machine.read_out_wire_partial_value(tl_o1).unwrap(); + let o_r1 = machine.read_out_wire(tl_o1).unwrap(); assert_eq!(o_r1, PartialValue::bottom()); - let o_r2 = machine.read_out_wire_partial_value(tl_o2).unwrap(); + let o_r2 = machine.read_out_wire(tl_o2).unwrap(); assert_eq!(o_r2, PartialValue::bottom()); assert_eq!( Some(TailLoopTermination::Bottom), @@ -197,9 +197,9 @@ fn test_tail_loop_iterates_twice() { // dbg!(&machine.out_wire_value); // TODO these hould be the propagated values for now they will bt join(true,false) - let _ = machine.read_out_wire_partial_value(o_w1).unwrap(); + let _ = machine.read_out_wire(o_w1).unwrap(); // assert_eq!(o_r1, PartialValue::top()); - let _ = machine.read_out_wire_partial_value(o_w2).unwrap(); + let _ = machine.read_out_wire(o_w2).unwrap(); // assert_eq!(o_r2, Value::true_val()); assert_eq!( Some(TailLoopTermination::Top), @@ -254,13 +254,13 @@ fn conditional() { machine.run(HugrValueContext::new(&hugr)); let cond_r1 = machine - .read_out_wire_partial_value(cond_o1) + .read_out_wire(cond_o1) .unwrap() .try_into_wire_value(&hugr, cond_o1) .unwrap(); assert_eq!(cond_r1, Value::false_val()); assert!(machine - .read_out_wire_partial_value(cond_o2) + .read_out_wire(cond_o2) .unwrap() .try_into_wire_value(&hugr, cond_o2) .is_err()); From a139f9e0e78eccb49c45f7fa5e8b4f82b46f82d6 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 7 Oct 2024 12:53:16 +0100 Subject: [PATCH 078/203] Remove ValueOrSum (and add Sum) via complex parametrization of try_into_value --- hugr-passes/src/dataflow.rs | 2 +- hugr-passes/src/dataflow/machine.rs | 25 +------- hugr-passes/src/dataflow/partial_value.rs | 70 ++++++++++++++--------- hugr-passes/src/dataflow/total_context.rs | 9 +-- 4 files changed, 49 insertions(+), 57 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 6a8f94ebd..f7c7555fa 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -7,7 +7,7 @@ mod machine; pub use machine::Machine; mod partial_value; -pub use partial_value::{AbstractValue, PVEnum, PartialSum, PartialValue, ValueOrSum}; +pub use partial_value::{AbstractValue, PVEnum, PartialSum, PartialValue, Sum}; mod total_context; pub use total_context::TotalContext; diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index f1b685fd8..acd3ac1ca 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -1,10 +1,8 @@ use std::collections::HashMap; -use hugr_core::{ops::Value, types::ConstTypeError, HugrView, Node, PortIndex, Wire}; +use hugr_core::{HugrView, Node, PortIndex, Wire}; -use super::{ - datalog::AscentProgram, partial_value::ValueOrSum, AbstractValue, DFContext, PartialValue, -}; +use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue}; /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] @@ -103,25 +101,6 @@ impl> Machine { } } -impl TryFrom> for Value -where - Value: From, -{ - type Error = ConstTypeError; - fn try_from(value: ValueOrSum) -> Result { - match value { - ValueOrSum::Value(v) => Ok(v.into()), - ValueOrSum::Sum { tag, items, st } => { - let items = items - .into_iter() - .map(Value::try_from) - .collect::, _>>()?; - Value::sum(tag, items, st.clone()) - } - } - } -} - #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] pub enum TailLoopTermination { Bottom, diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 727cde174..4eef5787d 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -25,19 +25,18 @@ pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { fn as_sum(&self) -> Option<(usize, impl Iterator + '_)>; } -/// A struct returned from [PartialValue::try_into_value] and [PartialSum::try_into_value] -/// indicating the value is either a single value or a sum with a single known tag. +/// Represents a sum with a single/known tag, abstracted over the representation of the elements. +/// (Identical to [Sum](hugr_core::ops::constant::Sum) except for the type abstraction.) #[derive(Clone, Debug, PartialEq, Eq)] -pub enum ValueOrSum { - /// Single value in the domain `V` - Value(V), - /// Sum with a single known Tag - #[allow(missing_docs)] - Sum { - tag: usize, - items: Vec, - st: SumType, - }, +pub struct Sum { + /// The tag index of the variant. + pub tag: usize, + /// The value of the variant. + /// + /// Sum variants are always a row of values, hence the Vec. + pub values: Vec, + /// The full type of the Sum, including the other variants. + pub st: SumType, } /// A representation of a value of [SumType], that may have one or more possible tags, @@ -130,10 +129,14 @@ impl PartialSum { self.0.contains_key(&tag) } - /// Turns this instance into a [ValueOrSum::Sum] if it has exactly one possible tag, + /// Turns this instance into a [Sum] if it has exactly one possible tag, /// otherwise failing and returning itself back unmodified (also if there is another /// error, e.g. this instance is not described by `typ`). - pub fn try_into_value(self, typ: &Type) -> Result, Self> { + // ALAN is this too parametric? Should we fix V2 == Value? Is the 'Self' error useful (no?) + pub fn try_into_value + TryFrom>>( + self, + typ: &Type, + ) -> Result, Self> { let Ok((k, v)) = self.0.iter().exactly_one() else { Err(self)? }; @@ -154,9 +157,9 @@ impl PartialSum { .map(|(v, t)| v.clone().try_into_value(t)) .collect::, _>>() { - Ok(vs) => Ok(ValueOrSum::Sum { + Ok(values) => Ok(Sum { tag: *k, - items: vs, + values, st: st.clone(), }), Err(_) => Err(self), @@ -313,26 +316,40 @@ impl PartialValue { } } - /// Extracts a [ValueOrSum] if there is such a single representation, - /// given a [Type] - pub fn try_into_value(self, typ: &Type) -> Result, Self> { + /// Extracts a value (in any representation supporting both leaf values and sums) + // ALAN is this too parametric? Should we fix V2 == Value? Is the error useful (should we have 'Self') or is it a smell? + pub fn try_into_value + TryFrom>>( + self, + typ: &Type, + ) -> Result>>::Error>> { match self.0 { - PVEnum::Value(v) => Ok(ValueOrSum::Value(v.clone())), - PVEnum::Sum(ps) => ps.try_into_value(typ).map_err(Self::from), - _ => Err(self), + PVEnum::Value(v) => Ok(V2::from(v.clone())), + PVEnum::Sum(ps) => { + let v = ps.try_into_value(typ).map_err(|_| None)?; + V2::try_from(v).map_err(Some) + } + _ => Err(None), } } } +impl TryFrom> for Value { + type Error = ConstTypeError; + + fn try_from(value: Sum) -> Result { + Self::sum(value.tag, value.values, value.st) + } +} + impl PartialValue where Value: From, { - /// Extracts a [ValueOrSum] if there is such a single representation, - /// given a HugrView and Wire that determine the type. + /// Turns this instance into a [Value], if it is either a single [value](PVEnum::Value) or + /// a [sum](PVEnum::Sum) with a single known tag, extracting the desired type from a HugrView and Wire. /// /// # Errors - /// `None` if the analysis did not result in a single [ValueOrSum] on that wire + /// `None` if the analysis did not result in a single value on that wire /// `Some(e)` if conversion to a [Value] produced a [ConstTypeError] /// /// # Panics @@ -347,8 +364,7 @@ where .out_value_types(w.node()) .find(|(p, _)| *p == w.source()) .unwrap(); - let vs = self.try_into_value(&typ).map_err(|_| None)?; - vs.try_into().map_err(Some) + self.try_into_value(&typ) } } diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index 262c250f9..2326d78cd 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -3,7 +3,7 @@ use std::hash::Hash; use ascent::lattice::BoundedLattice; use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; -use super::partial_value::{AbstractValue, PartialValue, ValueOrSum}; +use super::partial_value::{AbstractValue, PartialValue, Sum}; use super::DFContext; /// A simpler interface like [DFContext] but where the context only cares about @@ -11,7 +11,7 @@ use super::DFContext; /// Sums of potentially multiple variants. pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { /// The representation of values on which [Self::interpret_leaf_op] operates - type InterpretableVal: TryFrom>; + type InterpretableVal: From + TryFrom>; /// Interpret a leaf op. /// `ins` gives the input ports for which we know (interpretable) values, and will be non-empty. /// Returns a list of output ports for which we know (abstract) values (may be empty). @@ -37,11 +37,8 @@ impl> DFContext for T { .zip(ins.iter()) .filter_map(|((i, ty), pv)| { pv.clone() - .try_into_value(ty) - // Discard PVs which don't produce ValueOrSum, i.e. Bottom/Top :-) + .try_into_value::<>::InterpretableVal>(ty) .ok() - // And discard any ValueOrSum that don't produce V - this is a bit silent :-( - .and_then(|v_s| T::InterpretableVal::try_from(v_s).ok()) .map(|v| (IncomingPort::from(i), v)) }) .collect::>(); From 7f2a91a5fc5bc26143f5e82543c172e10ebea90d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 7 Oct 2024 17:18:35 +0100 Subject: [PATCH 079/203] Datalog works on any AbstractValue; impl'd by PartialValue for a BaseValue --- hugr-passes/src/const_fold2/value_handle.rs | 4 +- hugr-passes/src/dataflow.rs | 18 +---- hugr-passes/src/dataflow/datalog.rs | 90 +++++++++++++-------- hugr-passes/src/dataflow/machine.rs | 19 ++--- hugr-passes/src/dataflow/partial_value.rs | 79 +++++++++--------- hugr-passes/src/dataflow/test.rs | 2 +- hugr-passes/src/dataflow/total_context.rs | 6 +- 7 files changed, 114 insertions(+), 104 deletions(-) diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs index bbcd25129..59a08b50a 100644 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ b/hugr-passes/src/const_fold2/value_handle.rs @@ -7,7 +7,7 @@ use hugr_core::ops::Value; use hugr_core::types::Type; use hugr_core::Node; -use crate::dataflow::AbstractValue; +use crate::dataflow::BaseValue; #[derive(Clone, Debug)] pub struct HashedConst { @@ -85,7 +85,7 @@ impl ValueHandle { } } -impl AbstractValue for ValueHandle { +impl BaseValue for ValueHandle { fn as_sum(&self) -> Option<(usize, impl Iterator + '_)> { match self.value() { Value::Sum(Sum { tag, values, .. }) => Some(( diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index f7c7555fa..a66edde03 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -2,30 +2,16 @@ //! Dataflow analysis of Hugrs. mod datalog; +pub use datalog::{AbstractValue, DFContext}; mod machine; pub use machine::Machine; mod partial_value; -pub use partial_value::{AbstractValue, PVEnum, PartialSum, PartialValue, Sum}; +pub use partial_value::{BaseValue, PVEnum, PartialSum, PartialValue, Sum}; mod total_context; pub use total_context::TotalContext; -use hugr_core::{Hugr, Node}; -use std::hash::Hash; - -/// Clients of the dataflow framework (particular analyses, such as constant folding) -/// must implement this trait (including providing an appropriate domain type `V`). -pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { - /// Given lattice values for each input, produce lattice values for (what we know of) - /// the outputs. Returning `None` indicates nothing can be deduced. - fn interpret_leaf_op( - &self, - node: Node, - ins: &[PartialValue], - ) -> Option>>; -} - #[cfg(test)] mod test; diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index b8f5e7230..fcde4f96b 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -16,12 +16,7 @@ use std::ops::{Index, IndexMut}; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use hugr_core::ops::OpType; use hugr_core::types::Signature; -use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _}; - -use super::partial_value::{AbstractValue, PartialValue}; -use super::DFContext; - -type PV = super::partial_value::PartialValue; +use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum IO { @@ -29,19 +24,50 @@ pub enum IO { Output, } +/// Clients of the dataflow framework (particular analyses, such as constant folding) +/// must implement this trait (including providing an appropriate domain type `PV`). +pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { + /// Given lattice values for each input, produce lattice values for (what we know of) + /// the outputs. Returning `None` indicates nothing can be deduced. + fn interpret_leaf_op(&self, node: Node, ins: &[PV]) -> Option>; +} + +/// Values which can be the domain for dataflow analysis. Must be able to deconstructed +/// into (and constructed from) Sums as these determine control flow. +pub trait AbstractValue: BoundedLattice + Clone + Eq + Hash + std::fmt::Debug { + /// Create a new instance representing a Sum with a single known tag + /// and (recursive) representations of the elements within that tag. + fn new_variant(tag: usize, values: impl IntoIterator) -> Self; + + /// New instance of unit type (i.e. the only possible value, with no contents) + fn new_unit() -> Self { + Self::new_variant(0, []) + } + + /// Test whether this value *might* be a Sum with the specified tag. + fn supports_tag(&self, tag: usize) -> bool; + + /// If this value might be a Sum with the specified tag, return values + /// describing the elements of the Sum, otherwise `None`. + /// + /// Implementations must hold the invariant that for all `x`, `tag` and `len`: + /// `x.variant_values(tag, len).is_some() == x.supports_tag(tag)` + fn variant_values(&self, tag: usize, len: usize) -> Option>; +} + ascent::ascent! { - pub(super) struct AscentProgram>; + pub(super) struct AscentProgram>; relation context(C); - relation out_wire_value_proto(Node, OutgoingPort, PV); + relation out_wire_value_proto(Node, OutgoingPort, PV); relation node(C, Node); relation in_wire(C, Node, IncomingPort); relation out_wire(C, Node, OutgoingPort); relation parent_of_node(C, Node, Node); relation io_node(C, Node, Node, IO); - lattice out_wire_value(C, Node, OutgoingPort, PV); - lattice node_in_value_row(C, Node, ValueRow); - lattice in_wire_value(C, Node, IncomingPort, PV); + lattice out_wire_value(C, Node, OutgoingPort, PV); + lattice node_in_value_row(C, Node, ValueRow); + lattice in_wire_value(C, Node, IncomingPort, PV); node(c, n) <-- context(c), for n in c.nodes(); @@ -144,11 +170,11 @@ ascent::ascent! { } -fn propagate_leaf_op( - c: &impl DFContext, +fn propagate_leaf_op( + c: &impl DFContext, n: Node, - ins: &[PV], -) -> Option> { + ins: &[PV], +) -> Option> { match c.get_optype(n) { // Handle basics here. I guess (given the current interface) we could allow // DFContext to handle these but at the least we'd want these impls to be @@ -192,21 +218,21 @@ fn value_outputs(h: &impl HugrView, n: Node) -> impl Iterator(Vec>); +struct ValueRow(Vec); -impl ValueRow { +impl ValueRow { pub fn new(len: usize) -> Self { - Self(vec![PartialValue::bottom(); len]) + Self(vec![PV::bottom(); len]) } - pub fn single_known(len: usize, idx: usize, v: PartialValue) -> Self { + pub fn single_known(len: usize, idx: usize, v: PV) -> Self { assert!(idx < len); let mut r = Self::new(len); r.0[idx] = v; r } - pub fn iter(&self) -> impl Iterator> { + pub fn iter(&self) -> impl Iterator { self.0.iter() } @@ -214,7 +240,7 @@ impl ValueRow { &self, variant: usize, len: usize, - ) -> Option> + '_> { + ) -> Option + '_> { self[0] .variant_values(variant, len) .map(|vals| vals.into_iter().chain(self.iter().skip(1).cloned())) @@ -225,13 +251,13 @@ impl ValueRow { // } } -impl FromIterator> for ValueRow { - fn from_iter>>(iter: T) -> Self { +impl FromIterator for ValueRow { + fn from_iter>(iter: T) -> Self { Self(iter.into_iter().collect()) } } -impl PartialOrd for ValueRow { +impl PartialOrd for ValueRow { fn partial_cmp(&self, other: &Self) -> Option { self.0.partial_cmp(&other.0) } @@ -267,30 +293,30 @@ impl Lattice for ValueRow { } } -impl IntoIterator for ValueRow { - type Item = PartialValue; +impl IntoIterator for ValueRow { + type Item = PV; - type IntoIter = > as IntoIterator>::IntoIter; + type IntoIter = as IntoIterator>::IntoIter; fn into_iter(self) -> Self::IntoIter { self.0.into_iter() } } -impl Index for ValueRow +impl Index for ValueRow where - Vec>: Index, + Vec: Index, { - type Output = > as Index>::Output; + type Output = as Index>::Output; fn index(&self, index: Idx) -> &Self::Output { self.0.index(index) } } -impl IndexMut for ValueRow +impl IndexMut for ValueRow where - Vec>: IndexMut, + Vec: IndexMut, { fn index_mut(&mut self, index: Idx) -> &mut Self::Output { self.0.index_mut(index) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index acd3ac1ca..986fafa76 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -2,16 +2,16 @@ use std::collections::HashMap; use hugr_core::{HugrView, Node, PortIndex, Wire}; -use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue}; +use super::{datalog::AscentProgram, AbstractValue, DFContext}; /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] /// 2. Zero or more [Self::propolutate_out_wires] with initial values /// 3. Exactly one [Self::run] to do the analysis /// 4. Results then available via [Self::read_out_wire] -pub struct Machine>( - AscentProgram, - Option>>, +pub struct Machine>( + AscentProgram, + Option>, ); /// derived-Default requires the context to be Defaultable, which is unnecessary @@ -21,13 +21,10 @@ impl> Default for Machine { } } -impl> Machine { +impl> Machine { /// Provide initial values for some wires. /// (For example, if some properties of the Hugr's inputs are known.) - pub fn propolutate_out_wires( - &mut self, - wires: impl IntoIterator)>, - ) { + pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator) { assert!(self.1.is_none()); self.0 .out_wire_value_proto @@ -55,7 +52,7 @@ impl> Machine { } /// Gets the lattice value computed by [Self::run] for the given wire - pub fn read_out_wire(&self, w: Wire) -> Option> { + pub fn read_out_wire(&self, w: Wire) -> Option { self.1.as_ref().unwrap().get(&w).cloned() } @@ -109,7 +106,7 @@ pub enum TailLoopTermination { } impl TailLoopTermination { - pub fn from_control_value(v: &PartialValue) -> Self { + pub fn from_control_value(v: &impl AbstractValue) -> Self { let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); if may_break && !may_continue { Self::ExactlyZeroContinues diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 4eef5787d..e985eceab 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -8,10 +8,12 @@ use std::cmp::Ordering; use std::collections::HashMap; use std::hash::{Hash, Hasher}; -/// Trait for values which can be deconstructed into Sums (with a single known tag). -/// Required for values used in dataflow analysis. -pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { - /// Deconstruct a value into a single known tag plus a row of values, if it is a [Sum]. +use super::AbstractValue; + +/// Trait for abstract values that may represent sums. +/// Can be wrapped into an [AbstractValue] for analysis via [PartialValue]. +pub trait BaseValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { + /// Deconstruct an abstract value into a single known tag plus a row of values, if it is a [Sum]. /// Note that one can just always return `None` but this will mean the analysis /// is unable to understand untupling, and may give inconsistent results wrt. [Tag] /// operations, etc. @@ -58,7 +60,7 @@ impl PartialSum { } } -impl PartialSum { +impl PartialSum { fn assert_invariants(&self) { assert_ne!(self.num_variants(), 0); for pv in self.0.values().flat_map(|x| x.iter()) { @@ -219,15 +221,15 @@ impl Hash for PartialSum { } } -/// Wraps some underlying representation (knowledge) of values into a lattice -/// for use in dataflow analysis, including that an instance may be a [PartialSum] -/// of values of the underlying representation +/// Wraps some underlying representation of values (that `impl`s [BaseValue]) into +/// a lattice for use in dataflow analysis, including that an instance may be +/// a [PartialSum] of values of the underlying representation #[derive(PartialEq, Clone, Eq, Hash, Debug)] pub struct PartialValue(PVEnum); impl PartialValue { /// Allows to read the enum, which guarantees that we never return [PVEnum::Value] - /// for a value whose [AbstractValue::as_sum] is `Some` - any such value will be + /// for a value whose [BaseValue::as_sum] is `Some` - any such value will be /// in the form of a [PVEnum::Sum] instead. pub fn as_enum(&self) -> &PVEnum { &self.0 @@ -247,7 +249,7 @@ pub enum PVEnum { Top, } -impl From for PartialValue { +impl From for PartialValue { fn from(v: V) -> Self { v.as_sum() .map(|(tag, values)| Self::new_variant(tag, values.map(Self::from))) @@ -261,7 +263,7 @@ impl From> for PartialValue { } } -impl PartialValue { +impl PartialValue { fn assert_invariants(&self) { match &self.0 { PVEnum::Sum(ps) => { @@ -274,22 +276,30 @@ impl PartialValue { } } - /// New instance of a sum with a single known tag. - pub fn new_variant(tag: usize, values: impl IntoIterator) -> Self { - PartialSum::new_variant(tag, values).into() - } - - /// New instance of unit type (i.e. the only possible value, with no contents) - pub fn new_unit() -> Self { - Self::new_variant(0, []) + /// Extracts a value (in any representation supporting both leaf values and sums) + // ALAN is this too parametric? Should we fix V2 == Value? Is the error useful (should we have 'Self') or is it a smell? + pub fn try_into_value + TryFrom>>( + self, + typ: &Type, + ) -> Result>>::Error>> { + match self.0 { + PVEnum::Value(v) => Ok(V2::from(v.clone())), + PVEnum::Sum(ps) => { + let v = ps.try_into_value(typ).map_err(|_| None)?; + V2::try_from(v).map_err(Some) + } + _ => Err(None), + } } +} +impl AbstractValue for PartialValue { /// If this value might be a Sum with the specified `tag`, get the elements inside that tag. /// /// # Panics /// /// if the value is believed, for that tag, to have a number of values other than `len` - pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { + fn variant_values(&self, tag: usize, len: usize) -> Option>> { let vals = match &self.0 { PVEnum::Bottom => return None, PVEnum::Value(v) => { @@ -304,7 +314,7 @@ impl PartialValue { } /// Tells us whether this value might be a Sum with the specified `tag` - pub fn supports_tag(&self, tag: usize) -> bool { + fn supports_tag(&self, tag: usize) -> bool { match &self.0 { PVEnum::Bottom => false, PVEnum::Value(v) => { @@ -316,20 +326,8 @@ impl PartialValue { } } - /// Extracts a value (in any representation supporting both leaf values and sums) - // ALAN is this too parametric? Should we fix V2 == Value? Is the error useful (should we have 'Self') or is it a smell? - pub fn try_into_value + TryFrom>>( - self, - typ: &Type, - ) -> Result>>::Error>> { - match self.0 { - PVEnum::Value(v) => Ok(V2::from(v.clone())), - PVEnum::Sum(ps) => { - let v = ps.try_into_value(typ).map_err(|_| None)?; - V2::try_from(v).map_err(Some) - } - _ => Err(None), - } + fn new_variant(tag: usize, values: impl IntoIterator) -> Self { + PartialSum::new_variant(tag, values).into() } } @@ -341,7 +339,7 @@ impl TryFrom> for Value { } } -impl PartialValue +impl PartialValue where Value: From, { @@ -368,7 +366,7 @@ where } } -impl Lattice for PartialValue { +impl Lattice for PartialValue { fn join(mut self, other: Self) -> Self { self.join_mut(other); self @@ -464,7 +462,7 @@ impl Lattice for PartialValue { } } -impl BoundedLattice for PartialValue { +impl BoundedLattice for PartialValue { fn top() -> Self { Self(PVEnum::Top) } @@ -505,7 +503,10 @@ mod test { }; use super::{PVEnum, PartialSum, PartialValue}; - use crate::const_fold2::value_handle::{ValueHandle, ValueKey}; + use crate::{ + const_fold2::value_handle::{ValueHandle, ValueKey}, + dataflow::AbstractValue, + }; impl Arbitrary for ValueHandle { type Parameters = (); diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index cfc9b8975..127dcc373 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -1,6 +1,6 @@ use crate::{ const_fold2::HugrValueContext, - dataflow::{machine::TailLoopTermination, Machine}, + dataflow::{machine::TailLoopTermination, AbstractValue, Machine}, }; use ascent::{lattice::BoundedLattice, Lattice}; diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs index 2326d78cd..d512912d0 100644 --- a/hugr-passes/src/dataflow/total_context.rs +++ b/hugr-passes/src/dataflow/total_context.rs @@ -3,8 +3,8 @@ use std::hash::Hash; use ascent::lattice::BoundedLattice; use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; -use super::partial_value::{AbstractValue, PartialValue, Sum}; -use super::DFContext; +use super::partial_value::{PartialValue, Sum}; +use super::{BaseValue, DFContext}; /// A simpler interface like [DFContext] but where the context only cares about /// values that are completely known (as `V`s), i.e. not `Bottom`, `Top`, or @@ -22,7 +22,7 @@ pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { ) -> Vec<(OutgoingPort, V)>; } -impl> DFContext for T { +impl> DFContext> for T { fn interpret_leaf_op( &self, node: Node, From 1680829179d4223bc728e43b63b60759852bfdd6 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 16:57:28 +0100 Subject: [PATCH 080/203] PartialValue proptests: rm TestSumLeafType, replace ValueHandle with TestValue --- hugr-passes/Cargo.toml | 1 + hugr-passes/src/dataflow/partial_value.rs | 386 ++++++++-------------- 2 files changed, 134 insertions(+), 253 deletions(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 4234f7f95..8e68e9ad7 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -31,3 +31,4 @@ extension_inference = ["hugr-core/extension_inference"] rstest = { workspace = true } proptest = { workspace = true } proptest-derive = { workspace = true } +proptest-recurse = { version = "0.5.0" } \ No newline at end of file diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index e985eceab..272e43f2a 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -10,21 +10,26 @@ use std::hash::{Hash, Hasher}; use super::AbstractValue; -/// Trait for abstract values that may represent sums. -/// Can be wrapped into an [AbstractValue] for analysis via [PartialValue]. +/// Trait for abstract values that can be wrapped by [PartialValue] for dataflow analysis. +/// (Allows the values to represent sums, but does not require this). pub trait BaseValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { - /// Deconstruct an abstract value into a single known tag plus a row of values, if it is a [Sum]. - /// Note that one can just always return `None` but this will mean the analysis - /// is unable to understand untupling, and may give inconsistent results wrt. [Tag] - /// operations, etc. + /// If the abstract value represents a [Sum] with a single known tag, deconstruct it + /// into that tag plus the elements. The default just returns `None` which is + /// appropriate if the abstract value never does (in which case [interpret_leaf_op] + /// must produce a [PartialValue::new_variant] for any operation producing + /// a sum). /// /// The signature is this way to optimize query/inspection (is-it-a-sum), /// at the cost of requiring more cloning during actual conversion /// (inside the lazy Iterator, or for the error case, as Self remains) /// + /// [interpret_leaf_op]: super::DFContext::interpret_leaf_op /// [Sum]: TypeEnum::Sum /// [Tag]: hugr_core::ops::Tag - fn as_sum(&self) -> Option<(usize, impl Iterator + '_)>; + fn as_sum(&self) -> Option<(usize, impl Iterator + '_)> { + let res: Option<(usize, as IntoIterator>::IntoIter)> = None; + res + } } /// Represents a sum with a single/known tag, abstracted over the representation of the elements. @@ -494,179 +499,51 @@ mod test { use std::sync::Arc; use ascent::{lattice::BoundedLattice, Lattice}; - use itertools::{zip_eq, Either, Itertools as _}; + use itertools::{zip_eq, Itertools as _}; + use prop::sample::subsequence; use proptest::prelude::*; - use hugr_core::{ - std_extensions::arithmetic::int_types::{self, ConstInt, INT_TYPES, LOG_WIDTH_BOUND}, - types::{Type, TypeArg, TypeEnum}, - }; - - use super::{PVEnum, PartialSum, PartialValue}; - use crate::{ - const_fold2::value_handle::{ValueHandle, ValueKey}, - dataflow::AbstractValue, - }; - - impl Arbitrary for ValueHandle { - type Parameters = (); - type Strategy = BoxedStrategy; - fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy { - // prop_oneof![ - - // ] - todo!() - } - } + use proptest_recurse::{StrategyExt, StrategySet}; - #[derive(Debug, PartialEq, Eq, Clone)] - enum TestSumLeafType { - Int(Type), - Unit, - } - - impl TestSumLeafType { - fn assert_valid(&self) { - if let Self::Int(t) = self { - if let TypeEnum::Extension(ct) = t.as_type_enum() { - assert_eq!("int", ct.name()); - assert_eq!(&int_types::EXTENSION_ID, ct.extension()); - } else { - panic!("Expected int type, got {:#?}", t); - } - } - } - - fn get_type(&self) -> Type { - match self { - Self::Int(t) => t.clone(), - Self::Unit => Type::UNIT, - } - } - - fn type_check(&self, ps: &PartialSum) -> bool { - match self { - Self::Int(_) => false, - Self::Unit => { - if let Ok((0, v)) = ps.0.iter().exactly_one() { - v.is_empty() - } else { - false - } - } - } - } - - fn partial_value_strategy(self) -> impl Strategy> { - match self { - Self::Int(t) => { - let TypeEnum::Extension(ct) = t.as_type_enum() else { - unreachable!() - }; - // TODO this should be get_log_width, but that's not pub - let TypeArg::BoundedNat { n: lw } = ct.args()[0] else { - panic!() - }; - (0u64..(1 << (2u64.pow(lw as u32) - 1))) - .prop_map(move |x| { - let ki = ConstInt::new_u(lw as u8, x).unwrap(); - let k = ValueKey::try_new(ki.clone()).unwrap(); - ValueHandle::new(k, Arc::new(ki.into())).into() - }) - .boxed() - } - Self::Unit => Just(PartialValue::new_unit()).boxed(), - } - } - } - - impl Arbitrary for TestSumLeafType { - type Parameters = (); - type Strategy = BoxedStrategy; - fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy { - let int_strat = - (0..LOG_WIDTH_BOUND).prop_map(|i| Self::Int(INT_TYPES[i as usize].clone())); - prop_oneof![Just(TestSumLeafType::Unit), int_strat].boxed() - } - } + use super::{BaseValue, PVEnum, PartialSum, PartialValue}; + use crate::dataflow::AbstractValue; #[derive(Debug, PartialEq, Eq, Clone)] enum TestSumType { - Branch(usize, Vec>>), - Leaf(TestSumLeafType), + Branch(Vec>>), + /// None => unit, Some => TestValue <= this *usize* + Leaf(Option), } - impl TestSumType { - #[allow(unused)] // ALAN ? - fn leaf(v: Type) -> Self { - TestSumType::Leaf(TestSumLeafType::Int(v)) - } - - fn branch(vs: impl IntoIterator>>) -> Self { - let vec = vs.into_iter().collect_vec(); - let depth: usize = vec - .iter() - .flat_map(|x| x.iter()) - .map(|x| x.depth() + 1) - .max() - .unwrap_or(0); - Self::Branch(depth, vec) - } - - fn depth(&self) -> usize { - match self { - TestSumType::Branch(x, _) => *x, - TestSumType::Leaf(_) => 0, - } - } - - #[allow(unused)] // ALAN ? - fn is_leaf(&self) -> bool { - self.depth() == 0 - } + #[derive(Clone, Debug, PartialEq, Eq, Hash)] + struct TestValue(usize); - fn assert_valid(&self) { - match self { - TestSumType::Branch(d, sop) => { - assert!(!sop.is_empty(), "No variants"); - for v in sop.iter().flat_map(|x| x.iter()) { - assert!(v.depth() < *d); - v.assert_valid(); - } - } - TestSumType::Leaf(l) => { - l.assert_valid(); - } - } - } + impl BaseValue for TestValue {} - fn select(self) -> impl Strategy>)>> { - match self { - TestSumType::Branch(_, sop) => any::() - .prop_map(move |i| { - let index = i.index(sop.len()); - Either::Right((index, sop[index].clone())) - }) - .boxed(), - TestSumType::Leaf(l) => Just(Either::Left(l)).boxed(), - } - } + #[derive(Clone)] + struct SumTypeParams { + depth: usize, + desired_size: usize, + expected_branch_size: usize, + } - fn get_type(&self) -> Type { - match self { - TestSumType::Branch(_, sop) => Type::new_sum( - sop.iter() - .map(|row| row.iter().map(|x| x.get_type()).collect_vec()), - ), - TestSumType::Leaf(l) => l.get_type(), + impl Default for SumTypeParams { + fn default() -> Self { + Self { + depth: 5, + desired_size: 20, + expected_branch_size: 5, } } + } - fn type_check(&self, pv: &PartialValue) -> bool { + impl TestSumType { + fn type_check(&self, pv: &PartialValue) -> bool { match (self, pv.as_enum()) { (_, PVEnum::Bottom) | (_, PVEnum::Top) => true, - (_, PVEnum::Value(v)) => self.get_type() == v.get_type(), - (TestSumType::Branch(_, sop), PVEnum::Sum(ps)) => { + (Self::Leaf(None), _) => pv == &PartialValue::new_unit(), + (Self::Leaf(Some(max)), PVEnum::Value(TestValue(val))) => val <= max, + (Self::Branch(sop), PVEnum::Sum(ps)) => { for (k, v) in &ps.0 { if *k >= sop.len() { return false; @@ -681,123 +558,126 @@ mod test { } true } - (Self::Leaf(l), PVEnum::Sum(ps)) => l.type_check(ps), - } - } - } - - impl From for TestSumType { - fn from(value: TestSumLeafType) -> Self { - Self::Leaf(value) - } - } - - #[derive(Clone, PartialEq, Eq, Debug)] - struct UnarySumTypeParams { - depth: usize, - branch_width: usize, - } - - impl UnarySumTypeParams { - pub fn descend(mut self, d: usize) -> Self { - assert!(d < self.depth); - self.depth = d; - self - } - } - - impl Default for UnarySumTypeParams { - fn default() -> Self { - Self { - depth: 3, - branch_width: 3, + _ => false, } } } impl Arbitrary for TestSumType { - type Parameters = UnarySumTypeParams; - type Strategy = BoxedStrategy; - fn arbitrary_with( - params @ UnarySumTypeParams { - depth, - branch_width, - }: Self::Parameters, - ) -> Self::Strategy { - if depth == 0 { - any::().prop_map_into().boxed() - } else { - (0..depth) - .prop_flat_map(move |d| { - prop::collection::vec( - prop::collection::vec( - any_with::(params.clone().descend(d)).prop_map_into(), - 0..branch_width, + type Parameters = SumTypeParams; + type Strategy = SBoxedStrategy; + fn arbitrary_with(params: Self::Parameters) -> Self::Strategy { + fn arb(params: SumTypeParams, set: &mut StrategySet) -> SBoxedStrategy { + use proptest::collection::vec; + let int_strat = (0..usize::MAX).prop_map(|i| TestSumType::Leaf(Some(i))); + let leaf_strat = prop_oneof![Just(TestSumType::Leaf(None)), int_strat].sboxed(); + leaf_strat.prop_mutually_recursive( + params.depth as u32, + params.desired_size as u32, + params.expected_branch_size as u32, + set, + move |set| { + let self2 = params.clone(); + vec( + vec( + set.get::(move |set| arb(self2, set)) + .prop_map(Arc::new), + 1..=params.expected_branch_size, ), - 1..=branch_width, + 1..=params.expected_branch_size, ) - .prop_map(TestSumType::branch) - }) - .boxed() - } - } - } - - proptest! { - #[test] - fn unary_sum_type_valid(ust: TestSumType) { - ust.assert_valid(); - } + .prop_map(TestSumType::Branch) + .sboxed() + }, + ) + } + + arb(params, &mut StrategySet::default()) + } + } + + fn partial_sum_strat( + tag: usize, + elems_strat: impl Strategy>>, + ) -> impl Strategy> { + elems_strat.prop_map(move |elems| PartialSum::new_variant(tag, elems)) + } + + // Result gets fed into partial_sum_strat along with tag, so probably inline this into that + fn vec_strat( + elems: &Vec>, + ) -> impl Strategy>> { + elems + .into_iter() + .map(Arc::as_ref) + .map(any_partial_value_of_type) + .collect::>() + } + + fn multi_sum_strat( + variants: &Vec>>, + ) -> impl Strategy> { + let num_tags = variants.len(); + // We have to clone the `variants` here but only as far as the Vec>> + let s = subsequence( + variants.iter().cloned().enumerate().collect::>(), + 1..=num_tags, + ); + let sum_strat: BoxedStrategy>> = s + .prop_flat_map(|selected_tagged_variants| { + selected_tagged_variants + .into_iter() + .map(|(tag, elems)| partial_sum_strat(tag, vec_strat(&elems)).boxed()) + .collect::>() + }) + .boxed(); + sum_strat.prop_map(|psums: Vec>| { + let mut psums = psums.into_iter(); + let first = psums.next().unwrap(); + psums.fold(first, |mut a, b| { + a.try_join_mut(b).unwrap(); + a + }) + }) } fn any_partial_value_of_type( - ust: TestSumType, - ) -> impl Strategy> { - ust.select().prop_flat_map(|x| match x { - Either::Left(l) => l.partial_value_strategy().boxed(), - Either::Right((index, usts)) => { - let pvs = usts - .into_iter() - .map(|x| { - any_partial_value_of_type( - Arc::::try_unwrap(x) - .unwrap_or_else(|x| x.as_ref().clone()), - ) - }) - .collect_vec(); - pvs.prop_map(move |pvs| PartialValue::new_variant(index, pvs)) - .boxed() - } - }) + ust: &TestSumType, + ) -> impl Strategy> { + match ust { + TestSumType::Leaf(None) => Just(PartialValue::new_unit()).boxed(), + TestSumType::Leaf(Some(i)) => (0..*i) + .prop_map(TestValue) + .prop_map(PartialValue::from) + .boxed(), + TestSumType::Branch(sop) => multi_sum_strat(sop).prop_map(PartialValue::from).boxed(), + } } fn any_partial_value_with( params: ::Parameters, - ) -> impl Strategy> { - any_with::(params).prop_flat_map(any_partial_value_of_type) + ) -> impl Strategy> { + any_with::(params).prop_flat_map(|t| any_partial_value_of_type(&t)) } - fn any_partial_value() -> impl Strategy> { + fn any_partial_value() -> impl Strategy> { any_partial_value_with(Default::default()) } - fn any_partial_values() -> impl Strategy; N]> - { + fn any_partial_values() -> impl Strategy; N]> { any::().prop_flat_map(|ust| { TryInto::<[_; N]>::try_into( (0..N) - .map(|_| any_partial_value_of_type(ust.clone())) + .map(|_| any_partial_value_of_type(&ust)) .collect_vec(), ) .unwrap() }) } - fn any_typed_partial_value() -> impl Strategy)> - { - any::().prop_flat_map(|t| { - any_partial_value_of_type(t.clone()).prop_map(move |v| (t.clone(), v)) - }) + fn any_typed_partial_value() -> impl Strategy)> { + any::() + .prop_flat_map(|t| any_partial_value_of_type(&t).prop_map(move |v| (t.clone(), v))) } proptest! { From 2a57a1517ada040bee083519d232d0339f6c5579 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 17:06:43 +0100 Subject: [PATCH 081/203] tests: Rename type_check -> check_value --- hugr-passes/src/dataflow/partial_value.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 272e43f2a..905f2547c 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -538,7 +538,7 @@ mod test { } impl TestSumType { - fn type_check(&self, pv: &PartialValue) -> bool { + fn check_value(&self, pv: &PartialValue) -> bool { match (self, pv.as_enum()) { (_, PVEnum::Bottom) | (_, PVEnum::Top) => true, (Self::Leaf(None), _) => pv == &PartialValue::new_unit(), @@ -552,7 +552,7 @@ mod test { if prod.len() != v.len() { return false; } - if !zip_eq(prod, v).all(|(lhs, rhs)| lhs.type_check(rhs)) { + if !zip_eq(prod, v).all(|(lhs, rhs)| lhs.check_value(rhs)) { return false; } } @@ -683,7 +683,7 @@ mod test { proptest! { #[test] fn partial_value_type((tst, pv) in any_typed_partial_value()) { - prop_assert!(tst.type_check(&pv)) + prop_assert!(tst.check_value(&pv)) } // todo: ValidHandle is valid From fcfcb6b8b748a44f668d067b016e0ea227904903 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 17:09:30 +0100 Subject: [PATCH 082/203] tidies --- hugr-passes/src/dataflow/partial_value.rs | 50 ++++++++++------------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 905f2547c..a3bea6b19 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -570,17 +570,17 @@ mod test { fn arb(params: SumTypeParams, set: &mut StrategySet) -> SBoxedStrategy { use proptest::collection::vec; let int_strat = (0..usize::MAX).prop_map(|i| TestSumType::Leaf(Some(i))); - let leaf_strat = prop_oneof![Just(TestSumType::Leaf(None)), int_strat].sboxed(); + let leaf_strat = prop_oneof![Just(TestSumType::Leaf(None)), int_strat]; leaf_strat.prop_mutually_recursive( params.depth as u32, params.desired_size as u32, params.expected_branch_size as u32, set, move |set| { - let self2 = params.clone(); + let params2 = params.clone(); vec( vec( - set.get::(move |set| arb(self2, set)) + set.get::(move |set| arb(params2, set)) .prop_map(Arc::new), 1..=params.expected_branch_size, ), @@ -596,42 +596,34 @@ mod test { } } - fn partial_sum_strat( + fn single_sum_strat( tag: usize, - elems_strat: impl Strategy>>, + elems: Vec>, ) -> impl Strategy> { - elems_strat.prop_map(move |elems| PartialSum::new_variant(tag, elems)) - } - - // Result gets fed into partial_sum_strat along with tag, so probably inline this into that - fn vec_strat( - elems: &Vec>, - ) -> impl Strategy>> { elems - .into_iter() + .iter() .map(Arc::as_ref) .map(any_partial_value_of_type) .collect::>() + .prop_map(move |elems| PartialSum::new_variant(tag, elems)) } - fn multi_sum_strat( + fn partial_sum_strat( variants: &Vec>>, ) -> impl Strategy> { - let num_tags = variants.len(); // We have to clone the `variants` here but only as far as the Vec>> - let s = subsequence( - variants.iter().cloned().enumerate().collect::>(), - 1..=num_tags, - ); - let sum_strat: BoxedStrategy>> = s - .prop_flat_map(|selected_tagged_variants| { - selected_tagged_variants - .into_iter() - .map(|(tag, elems)| partial_sum_strat(tag, vec_strat(&elems)).boxed()) - .collect::>() - }) - .boxed(); - sum_strat.prop_map(|psums: Vec>| { + let tagged_variants = variants.iter().cloned().enumerate().collect::>(); + // The type annotation here (and the .boxed() enabling it) are just for documentation + let sum_variants_strat: BoxedStrategy>> = + subsequence(tagged_variants, 1..=variants.len()) + .prop_flat_map(|selected_variants| { + selected_variants + .into_iter() + .map(|(tag, elems)| single_sum_strat(tag, elems)) + .collect::>() + }) + .boxed(); + sum_variants_strat.prop_map(|psums: Vec>| { let mut psums = psums.into_iter(); let first = psums.next().unwrap(); psums.fold(first, |mut a, b| { @@ -650,7 +642,7 @@ mod test { .prop_map(TestValue) .prop_map(PartialValue::from) .boxed(), - TestSumType::Branch(sop) => multi_sum_strat(sop).prop_map(PartialValue::from).boxed(), + TestSumType::Branch(sop) => partial_sum_strat(sop).prop_map(PartialValue::from).boxed(), } } From dcaa928517c5914518c24b26edaea784b8b46c0a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 17:56:37 +0100 Subject: [PATCH 083/203] Add a couple more proptests, and a TEMPORARY FIX for a BUG pending better answer --- hugr-passes/src/dataflow/partial_value.rs | 28 ++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index a3bea6b19..dcce791db 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -451,7 +451,16 @@ impl Lattice for PartialValue { _ => unreachable!(), }; match ps1.try_meet_mut(ps2) { - Ok(ch) => ch, + Ok(ch) => { + // ALAN the 'invariant' that a PartialSum always has >=1 tag can be broken here. + // Fix this by rewriting to Bottom, but should probably be refactored - at the + // least, it seems dangerous to expose a potentially-invalidating try_meet_mut. + if ps1.0.is_empty() { + assert!(ch); + self.0 = PVEnum::Bottom + } + ch + } Err(_) => { self.0 = PVEnum::Bottom; true @@ -712,10 +721,27 @@ mod test { let meet = v1.clone().meet(v2.clone()); prop_assert!(meet <= v1, "meet not less <=: {:#?}", &meet); prop_assert!(meet <= v2, "meet not less <=: {:#?}", &meet); + prop_assert!(meet == v2.clone().meet(v1.clone()), "meet not symmetric"); + prop_assert!(meet == meet.clone().meet(v1.clone()), "repeated meet should be a no-op"); + prop_assert!(meet == meet.clone().meet(v2.clone()), "repeated meet should be a no-op"); let join = v1.clone().join(v2.clone()); prop_assert!(join >= v1, "join not >=: {:#?}", &join); prop_assert!(join >= v2, "join not >=: {:#?}", &join); + prop_assert!(join == v2.clone().join(v1.clone()), "join not symmetric"); + prop_assert!(join == join.clone().join(v1.clone()), "repeated join should be a no-op"); + prop_assert!(join == join.clone().join(v2.clone()), "repeated join should be a no-op"); + } + + #[test] + fn lattice_associative([v1, v2, v3] in any_partial_values()) { + let a = v1.clone().meet(v2.clone()).meet(v3.clone()); + let b = v1.clone().meet(v2.clone().meet(v3.clone())); + prop_assert!(a==b, "meet not associative"); + + let a = v1.clone().join(v2.clone()).join(v3.clone()); + let b = v1.clone().join(v2.clone().join(v3.clone())); + prop_assert!(a==b, "join not associative") } } } From bcacbcca2d23991771bcc49d711161943c547be3 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 20:07:56 +0100 Subject: [PATCH 084/203] Remove redundant test --- hugr-passes/src/dataflow/test.rs | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 127dcc373..6f7b2ef52 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -64,27 +64,6 @@ fn test_unpack_tuple_const() { assert_eq!(o2_r, Value::true_val()); } -#[test] -fn test_unpack_const() { - let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); - let v1 = builder.add_load_value(Value::tuple([Value::true_val()])); - let [o] = builder - .add_dataflow_op(UnpackTuple::new(type_row![BOOL_T]), [v1]) - .unwrap() - .outputs_arr(); - let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - - let mut machine = Machine::default(); - machine.run(HugrValueContext::new(&hugr)); - - let o_r = machine - .read_out_wire(o) - .unwrap() - .try_into_wire_value(&hugr, o) - .unwrap(); - assert_eq!(o_r, Value::true_val()); -} - #[test] fn test_tail_loop_never_iterates() { let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); From 5192ed8ab8ed1b66d0c1ae9afd8da408f0361abd Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 20:48:57 +0100 Subject: [PATCH 085/203] Refactor TailLoopTermination::from_control_value --- hugr-passes/src/dataflow/machine.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 986fafa76..5a86dd98e 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -108,10 +108,12 @@ pub enum TailLoopTermination { impl TailLoopTermination { pub fn from_control_value(v: &impl AbstractValue) -> Self { let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); - if may_break && !may_continue { - Self::ExactlyZeroContinues - } else if may_break && may_continue { - Self::Top + if may_break { + if may_continue { + Self::Top + } else { + Self::ExactlyZeroContinues + } } else { Self::Bottom } From e21bbd759ea13176764200086863dfc5f78075e2 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 20:57:40 +0100 Subject: [PATCH 086/203] pub TailLoopTermination, rename members, doc --- hugr-passes/src/dataflow.rs | 2 +- hugr-passes/src/dataflow/machine.rs | 22 ++++++++++++++++------ hugr-passes/src/dataflow/test.rs | 6 +++--- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index a66edde03..6f437f882 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -5,7 +5,7 @@ mod datalog; pub use datalog::{AbstractValue, DFContext}; mod machine; -pub use machine::Machine; +pub use machine::{Machine, TailLoopTermination}; mod partial_value; pub use partial_value::{BaseValue, PVEnum, PartialSum, PartialValue, Sum}; diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 5a86dd98e..aa9408cdb 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -98,24 +98,34 @@ impl> Machine { } } +/// Tells whether a loop iterates (never, always, sometimes) #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] pub enum TailLoopTermination { - Bottom, - ExactlyZeroContinues, - Top, + /// The loop never exits (is an infinite loop); no value is ever + /// returned out of the loop. (aka, Bottom.) + // TODO what about a loop that never exits OR continues because of a nested infinite loop? + NeverBreaks, + /// The loop never iterates (so is equivalent to a [DFG](hugr_core::ops::DFG), + /// modulo untupling of the control value) + NeverContinues, + /// The loop might iterate and/or exit. (aka, Top) + BreaksAndContinues, } impl TailLoopTermination { + /// Extracts the relevant information from a value that should represent + /// the value provided to the [Output](hugr_core::ops::Output) node child + /// of the [TailLoop](hugr_core::ops::TailLoop) pub fn from_control_value(v: &impl AbstractValue) -> Self { let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); if may_break { if may_continue { - Self::Top + Self::BreaksAndContinues } else { - Self::ExactlyZeroContinues + Self::NeverContinues } } else { - Self::Bottom + Self::NeverBreaks } } } diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 6f7b2ef52..e9aeb4c57 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -95,7 +95,7 @@ fn test_tail_loop_never_iterates() { .unwrap(); assert_eq!(o_r, r_v); assert_eq!( - Some(TailLoopTermination::ExactlyZeroContinues), + Some(TailLoopTermination::NeverContinues), machine.tail_loop_terminates(&hugr, tail_loop.node()) ) } @@ -127,7 +127,7 @@ fn test_tail_loop_always_iterates() { let o_r2 = machine.read_out_wire(tl_o2).unwrap(); assert_eq!(o_r2, PartialValue::bottom()); assert_eq!( - Some(TailLoopTermination::Bottom), + Some(TailLoopTermination::NeverBreaks), machine.tail_loop_terminates(&hugr, tail_loop.node()) ); assert_eq!(machine.tail_loop_terminates(&hugr, hugr.root()), None); @@ -181,7 +181,7 @@ fn test_tail_loop_iterates_twice() { let _ = machine.read_out_wire(o_w2).unwrap(); // assert_eq!(o_r2, Value::true_val()); assert_eq!( - Some(TailLoopTermination::Top), + Some(TailLoopTermination::BreaksAndContinues), machine.tail_loop_terminates(&hugr, tail_loop.node()) ); assert_eq!(machine.tail_loop_terminates(&hugr, hugr.root()), None); From 22e0192dbbf21cd8156871cecc9c791151b76cae Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 21:53:56 +0100 Subject: [PATCH 087/203] Test tidies (and some ALAN wtf? comments) --- hugr-passes/src/dataflow/test.rs | 45 +++++++++++--------------------- 1 file changed, 15 insertions(+), 30 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index e9aeb4c57..408c9949d 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -1,7 +1,4 @@ -use crate::{ - const_fold2::HugrValueContext, - dataflow::{machine::TailLoopTermination, AbstractValue, Machine}, -}; +use crate::const_fold2::HugrValueContext; use ascent::{lattice::BoundedLattice, Lattice}; use hugr_core::{ @@ -10,13 +7,13 @@ use hugr_core::{ prelude::{UnpackTuple, BOOL_T}, ExtensionSet, EMPTY_REG, }, - ops::{handle::NodeHandle, OpTrait, Value}, + ops::{handle::NodeHandle, DataflowOpTrait, Value}, type_row, - types::{Signature, SumType, Type, TypeRow}, + types::{Signature, SumType, Type}, HugrView, }; -use super::partial_value::PartialValue; +use super::{AbstractValue, Machine, PartialValue, TailLoopTermination}; #[test] fn test_make_tuple() { @@ -85,8 +82,6 @@ fn test_tail_loop_never_iterates() { let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); - // dbg!(&machine.tail_loop_io_node); - // dbg!(&machine.out_wire_value); let o_r = machine .read_out_wire(tl_o) @@ -152,34 +147,26 @@ fn test_tail_loop_iterates_twice() { ) .unwrap(); assert_eq!( - tlb.loop_signature().unwrap().dataflow_signature().unwrap(), + tlb.loop_signature().unwrap().signature(), Signature::new_endo(type_row![BOOL_T, BOOL_T]) ); let [in_w1, in_w2] = tlb.input_wires_arr(); let tail_loop = tlb.finish_with_outputs(in_w1, [in_w2, in_w1]).unwrap(); - // let optype = builder.hugr().get_optype(tail_loop.node()); - // for p in builder.hugr().node_outputs(tail_loop.node()) { - // use hugr_core::ops::OpType; - // println!("{:?}, {:?}", p, optype.port_kind(p)); - - // } - let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); // TODO once we can do conditionals put these wires inside `just_outputs` and - // we should be able to propagate their values + // we should be able to propagate their values...ALAN wtf? loop control type IS bool ATM let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::default(); machine.run(HugrValueContext::new(&hugr)); - // dbg!(&machine.tail_loop_io_node); - // dbg!(&machine.out_wire_value); - - // TODO these hould be the propagated values for now they will bt join(true,false) - let _ = machine.read_out_wire(o_w1).unwrap(); - // assert_eq!(o_r1, PartialValue::top()); - let _ = machine.read_out_wire(o_w2).unwrap(); - // assert_eq!(o_r2, Value::true_val()); + + let true_or_false = PartialValue::new_variant(0, []).join(PartialValue::new_variant(1, [])); + // TODO these should be the propagated values for now they will be join(true,false) - ALAN wtf? + let o_r1 = machine.read_out_wire(o_w1).unwrap(); + assert_eq!(o_r1, true_or_false); + let o_r2 = machine.read_out_wire(o_w2).unwrap(); + assert_eq!(o_r2, true_or_false); assert_eq!( Some(TailLoopTermination::BreaksAndContinues), machine.tail_loop_terminates(&hugr, tail_loop.node()) @@ -191,19 +178,17 @@ fn test_tail_loop_iterates_twice() { fn conditional() { let variants = vec![type_row![], type_row![], type_row![BOOL_T]]; let cond_t = Type::new_sum(variants.clone()); - let mut builder = - DFGBuilder::new(Signature::new(Into::::into(cond_t), type_row![])).unwrap(); + let mut builder = DFGBuilder::new(Signature::new(cond_t, type_row![])).unwrap(); let [arg_w] = builder.input_wires_arr(); let true_w = builder.add_load_value(Value::true_val()); let false_w = builder.add_load_value(Value::false_val()); let mut cond_builder = builder - .conditional_builder_exts( + .conditional_builder( (variants, arg_w), [(BOOL_T, true_w)], type_row!(BOOL_T, BOOL_T), - ExtensionSet::default(), ) .unwrap(); // will be unreachable From cae5e4fe44e5a5ba68a63fce9814a87e15a6e74d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 23:07:39 +0100 Subject: [PATCH 088/203] Use Tag --- hugr-passes/src/dataflow/test.rs | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 408c9949d..4cbc3b893 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -7,7 +7,7 @@ use hugr_core::{ prelude::{UnpackTuple, BOOL_T}, ExtensionSet, EMPTY_REG, }, - ops::{handle::NodeHandle, DataflowOpTrait, Value}, + ops::{handle::NodeHandle, DataflowOpTrait, Tag, Value}, type_row, types::{Signature, SumType, Type}, HugrView, @@ -65,18 +65,14 @@ fn test_unpack_tuple_const() { fn test_tail_loop_never_iterates() { let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); let r_v = Value::unit_sum(3, 6).unwrap(); - let r_w = builder.add_load_value( - Value::sum( - 1, - [r_v.clone()], - SumType::new([type_row![], r_v.get_type().into()]), - ) - .unwrap(), - ); + let r_w = builder.add_load_value(r_v.clone()); + let tag = Tag::new(1, vec![type_row![], r_v.get_type().into()]); + let tagged = builder.add_dataflow_op(tag, [r_w]).unwrap(); + let tlb = builder .tail_loop_builder([], [], vec![r_v.get_type()].into()) .unwrap(); - let tail_loop = tlb.finish_with_outputs(r_w, []).unwrap(); + let tail_loop = tlb.finish_with_outputs(tagged.out_wire(0), []).unwrap(); let [tl_o] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); From 3014827c94002d9be8c859d628e19502016c0b1f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 23:08:38 +0100 Subject: [PATCH 089/203] Add TestContext (no interpret_leaf_op), propolutate, avoid HugrValueContext --- hugr-passes/src/dataflow/test.rs | 94 +++++++++++++++++++++++++++++--- 1 file changed, 86 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 4cbc3b893..4f42b4a4e 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -1,4 +1,5 @@ -use crate::const_fold2::HugrValueContext; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; use ascent::{lattice::BoundedLattice, Lattice}; use hugr_core::{ @@ -13,7 +14,80 @@ use hugr_core::{ HugrView, }; -use super::{AbstractValue, Machine, PartialValue, TailLoopTermination}; +use super::{AbstractValue, BaseValue, DFContext, Machine, PartialValue, TailLoopTermination}; + +// ------- Minimal implementation of DFContext and BaseValue ------- +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum Void {} + +impl BaseValue for Void {} + +struct TestContext(Arc); + +// Deriving Clone requires H:HugrView to implement Clone, +// but we don't need that as we only clone the Arc. +impl Clone for TestContext { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl std::ops::Deref for TestContext { + type Target = hugr_core::Hugr; + + fn deref(&self) -> &Self::Target { + self.0.base_hugr() + } +} + +// Any value used in an Ascent program must be hashable. +// However, there should only be one DFContext, so its hash is immaterial. +impl Hash for TestContext { + fn hash(&self, _state: &mut I) {} +} + +impl PartialEq for TestContext { + fn eq(&self, other: &Self) -> bool { + // Any AscentProgram should have only one DFContext (maybe cloned) + assert!(Arc::ptr_eq(&self.0, &other.0)); + true + } +} + +impl Eq for TestContext {} + +impl PartialOrd for TestContext { + fn partial_cmp(&self, other: &Self) -> Option { + // Any AscentProgram should have only one DFContext (maybe cloned) + assert!(Arc::ptr_eq(&self.0, &other.0)); + Some(std::cmp::Ordering::Equal) + } +} + +impl DFContext> for TestContext { + fn interpret_leaf_op( + &self, + _node: hugr_core::Node, + _ins: &[PartialValue], + ) -> Option>> { + None + } +} + +// This allows testing creation of tuple/sum Values (only) +impl From for Value { + fn from(v: Void) -> Self { + match v {} + } +} + +fn pv_false() -> PartialValue { + PartialValue::new_variant(0, []) +} + +fn pv_true() -> PartialValue { + PartialValue::new_variant(1, []) +} #[test] fn test_make_tuple() { @@ -24,7 +98,8 @@ fn test_make_tuple() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.run(HugrValueContext::new(&hugr)); + machine.propolutate_out_wires([(v1, pv_false()), (v2, pv_true())]); + machine.run(TestContext(Arc::new(&hugr))); let x = machine .read_out_wire(v3) @@ -45,7 +120,8 @@ fn test_unpack_tuple_const() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.run(HugrValueContext::new(&hugr)); + machine.propolutate_out_wires([(v, PartialValue::new_variant(0, [pv_false(), pv_true()]))]); + machine.run(TestContext(Arc::new(&hugr))); let o1_r = machine .read_out_wire(o1) @@ -77,7 +153,8 @@ fn test_tail_loop_never_iterates() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.run(HugrValueContext::new(&hugr)); + machine.propolutate_out_wires([(r_w, PartialValue::new_variant(3, []))]); + machine.run(TestContext(Arc::new(&hugr))); let o_r = machine .read_out_wire(tl_o) @@ -111,7 +188,7 @@ fn test_tail_loop_always_iterates() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.run(HugrValueContext::new(&hugr)); + machine.run(TestContext(Arc::new(&hugr))); let o_r1 = machine.read_out_wire(tl_o1).unwrap(); assert_eq!(o_r1, PartialValue::bottom()); @@ -155,7 +232,8 @@ fn test_tail_loop_iterates_twice() { let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::default(); - machine.run(HugrValueContext::new(&hugr)); + machine.propolutate_out_wires([(true_w, pv_true()), (false_w, pv_false())]); + machine.run(TestContext(Arc::new(&hugr))); let true_or_false = PartialValue::new_variant(0, []).join(PartialValue::new_variant(1, [])); // TODO these should be the propagated values for now they will be join(true,false) - ALAN wtf? @@ -211,7 +289,7 @@ fn conditional() { [PartialValue::new_variant(0, [])], )); machine.propolutate_out_wires([(arg_w, arg_pv)]); - machine.run(HugrValueContext::new(&hugr)); + machine.run(TestContext(Arc::new(&hugr))); let cond_r1 = machine .read_out_wire(cond_o1) From 64b9bb753b4904237569f2ebc39f80dab7c97dca Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 23:33:07 +0100 Subject: [PATCH 090/203] Avoid propolutate by interpreting LoadConstant (only) --- hugr-passes/src/dataflow/test.rs | 34 +++++++++++++++++++------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 4f42b4a4e..82a8f6f89 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -13,6 +13,7 @@ use hugr_core::{ types::{Signature, SumType, Type}, HugrView, }; +use itertools::Itertools; use super::{AbstractValue, BaseValue, DFContext, Machine, PartialValue, TailLoopTermination}; @@ -67,10 +68,27 @@ impl PartialOrd for TestContext { impl DFContext> for TestContext { fn interpret_leaf_op( &self, - _node: hugr_core::Node, + node: hugr_core::Node, _ins: &[PartialValue], ) -> Option>> { - None + // Interpret LoadConstants of sums of sums (without leaves), only + fn try_into_pv(v: &Value) -> Option> { + let Value::Sum(hugr_core::ops::constant::Sum { tag, values, .. }) = v else { + return None; + }; + Some(PartialValue::new_variant( + *tag, + values + .iter() + .map(try_into_pv) + .collect::>>>()?, + )) + } + self.0.get_optype(node).as_load_constant().and_then(|_| { + let const_node = self.0.input_neighbours(node).exactly_one().ok().unwrap(); + let v = self.0.get_optype(const_node).as_const().unwrap().value(); + try_into_pv(v).map(|v| vec![v]) + }) } } @@ -81,14 +99,6 @@ impl From for Value { } } -fn pv_false() -> PartialValue { - PartialValue::new_variant(0, []) -} - -fn pv_true() -> PartialValue { - PartialValue::new_variant(1, []) -} - #[test] fn test_make_tuple() { let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); @@ -98,7 +108,6 @@ fn test_make_tuple() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.propolutate_out_wires([(v1, pv_false()), (v2, pv_true())]); machine.run(TestContext(Arc::new(&hugr))); let x = machine @@ -120,7 +129,6 @@ fn test_unpack_tuple_const() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.propolutate_out_wires([(v, PartialValue::new_variant(0, [pv_false(), pv_true()]))]); machine.run(TestContext(Arc::new(&hugr))); let o1_r = machine @@ -153,7 +161,6 @@ fn test_tail_loop_never_iterates() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.propolutate_out_wires([(r_w, PartialValue::new_variant(3, []))]); machine.run(TestContext(Arc::new(&hugr))); let o_r = machine @@ -232,7 +239,6 @@ fn test_tail_loop_iterates_twice() { let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::default(); - machine.propolutate_out_wires([(true_w, pv_true()), (false_w, pv_false())]); machine.run(TestContext(Arc::new(&hugr))); let true_or_false = PartialValue::new_variant(0, []).join(PartialValue::new_variant(1, [])); From 8bc5e122e6e8b88ff74e5e8775cbb9ba574da221 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 23:36:34 +0100 Subject: [PATCH 091/203] Revert "Avoid propolutate by interpreting LoadConstant (only)" This reverts commit 64b9bb753b4904237569f2ebc39f80dab7c97dca. --- hugr-passes/src/dataflow/test.rs | 34 +++++++++++++------------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 82a8f6f89..4f42b4a4e 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -13,7 +13,6 @@ use hugr_core::{ types::{Signature, SumType, Type}, HugrView, }; -use itertools::Itertools; use super::{AbstractValue, BaseValue, DFContext, Machine, PartialValue, TailLoopTermination}; @@ -68,27 +67,10 @@ impl PartialOrd for TestContext { impl DFContext> for TestContext { fn interpret_leaf_op( &self, - node: hugr_core::Node, + _node: hugr_core::Node, _ins: &[PartialValue], ) -> Option>> { - // Interpret LoadConstants of sums of sums (without leaves), only - fn try_into_pv(v: &Value) -> Option> { - let Value::Sum(hugr_core::ops::constant::Sum { tag, values, .. }) = v else { - return None; - }; - Some(PartialValue::new_variant( - *tag, - values - .iter() - .map(try_into_pv) - .collect::>>>()?, - )) - } - self.0.get_optype(node).as_load_constant().and_then(|_| { - let const_node = self.0.input_neighbours(node).exactly_one().ok().unwrap(); - let v = self.0.get_optype(const_node).as_const().unwrap().value(); - try_into_pv(v).map(|v| vec![v]) - }) + None } } @@ -99,6 +81,14 @@ impl From for Value { } } +fn pv_false() -> PartialValue { + PartialValue::new_variant(0, []) +} + +fn pv_true() -> PartialValue { + PartialValue::new_variant(1, []) +} + #[test] fn test_make_tuple() { let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); @@ -108,6 +98,7 @@ fn test_make_tuple() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); + machine.propolutate_out_wires([(v1, pv_false()), (v2, pv_true())]); machine.run(TestContext(Arc::new(&hugr))); let x = machine @@ -129,6 +120,7 @@ fn test_unpack_tuple_const() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); + machine.propolutate_out_wires([(v, PartialValue::new_variant(0, [pv_false(), pv_true()]))]); machine.run(TestContext(Arc::new(&hugr))); let o1_r = machine @@ -161,6 +153,7 @@ fn test_tail_loop_never_iterates() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); + machine.propolutate_out_wires([(r_w, PartialValue::new_variant(3, []))]); machine.run(TestContext(Arc::new(&hugr))); let o_r = machine @@ -239,6 +232,7 @@ fn test_tail_loop_iterates_twice() { let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::default(); + machine.propolutate_out_wires([(true_w, pv_true()), (false_w, pv_false())]); machine.run(TestContext(Arc::new(&hugr))); let true_or_false = PartialValue::new_variant(0, []).join(PartialValue::new_variant(1, [])); From a3a6213f40d4ac6b22bb7b70f66f00778fe4851b Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 23:20:48 +0100 Subject: [PATCH 092/203] tiny const_fold2 doc tweaks --- hugr-passes/src/const_fold2/context.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/hugr-passes/src/const_fold2/context.rs b/hugr-passes/src/const_fold2/context.rs index 32fc57765..6338629df 100644 --- a/hugr-passes/src/const_fold2/context.rs +++ b/hugr-passes/src/const_fold2/context.rs @@ -8,15 +8,13 @@ use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort}; use super::value_handle::{ValueHandle, ValueKey}; use crate::dataflow::TotalContext; -/// A context ([DFContext]) for doing analysis with [ValueHandle]s. +/// A [context](crate::dataflow::DFContext) for doing analysis with [ValueHandle]s. /// Interprets [LoadConstant](OpType::LoadConstant) nodes, /// and [ExtensionOp](OpType::ExtensionOp) nodes where the extension does /// (using [Value]s for extension-op inputs). /// /// Just stores a Hugr (actually any [HugrView]), /// (there is )no state for operation-interpretation. -/// -/// [DFContext]: crate::dataflow::DFContext #[derive(Debug)] pub struct HugrValueContext(Arc); From a96ab20ac7f24001c81ccbd846b94bbc86efc722 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 23:22:04 +0100 Subject: [PATCH 093/203] (TEMP) remove const_fold2 module --- hugr-passes/src/const_fold2.rs | 8 - hugr-passes/src/const_fold2/context.rs | 101 --------- hugr-passes/src/const_fold2/value_handle.rs | 222 -------------------- hugr-passes/src/lib.rs | 1 - 4 files changed, 332 deletions(-) delete mode 100644 hugr-passes/src/const_fold2.rs delete mode 100644 hugr-passes/src/const_fold2/context.rs delete mode 100644 hugr-passes/src/const_fold2/value_handle.rs diff --git a/hugr-passes/src/const_fold2.rs b/hugr-passes/src/const_fold2.rs deleted file mode 100644 index 58f285d43..000000000 --- a/hugr-passes/src/const_fold2.rs +++ /dev/null @@ -1,8 +0,0 @@ -#![warn(missing_docs)] -//! An (example) use of the [super::dataflow](dataflow-analysis framework) -//! to perform constant-folding. - -// These are pub because this "example" is used for testing the framework. -mod context; -pub mod value_handle; -pub use context::HugrValueContext; diff --git a/hugr-passes/src/const_fold2/context.rs b/hugr-passes/src/const_fold2/context.rs deleted file mode 100644 index 6338629df..000000000 --- a/hugr-passes/src/const_fold2/context.rs +++ /dev/null @@ -1,101 +0,0 @@ -use std::hash::{Hash, Hasher}; -use std::ops::Deref; -use std::sync::Arc; - -use hugr_core::ops::{OpType, Value}; -use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort}; - -use super::value_handle::{ValueHandle, ValueKey}; -use crate::dataflow::TotalContext; - -/// A [context](crate::dataflow::DFContext) for doing analysis with [ValueHandle]s. -/// Interprets [LoadConstant](OpType::LoadConstant) nodes, -/// and [ExtensionOp](OpType::ExtensionOp) nodes where the extension does -/// (using [Value]s for extension-op inputs). -/// -/// Just stores a Hugr (actually any [HugrView]), -/// (there is )no state for operation-interpretation. -#[derive(Debug)] -pub struct HugrValueContext(Arc); - -impl HugrValueContext { - /// Creates a new instance, given ownership of the [HugrView] - pub fn new(hugr: H) -> Self { - Self(Arc::new(hugr)) - } -} - -// Deriving Clone requires H:HugrView to implement Clone, -// but we don't need that as we only clone the Arc. -impl Clone for HugrValueContext { - fn clone(&self) -> Self { - Self(self.0.clone()) - } -} - -// Any value used in an Ascent program must be hashable. -// However, there should only be one DFContext, so its hash is immaterial. -impl Hash for HugrValueContext { - fn hash(&self, _state: &mut I) {} -} - -impl PartialEq for HugrValueContext { - fn eq(&self, other: &Self) -> bool { - // Any AscentProgram should have only one DFContext (maybe cloned) - assert!(Arc::ptr_eq(&self.0, &other.0)); - true - } -} - -impl Eq for HugrValueContext {} - -impl PartialOrd for HugrValueContext { - fn partial_cmp(&self, other: &Self) -> Option { - // Any AscentProgram should have only one DFContext (maybe cloned) - assert!(Arc::ptr_eq(&self.0, &other.0)); - Some(std::cmp::Ordering::Equal) - } -} - -impl Deref for HugrValueContext { - type Target = Hugr; - - fn deref(&self) -> &Self::Target { - self.0.base_hugr() - } -} - -impl TotalContext for HugrValueContext { - type InterpretableVal = Value; - - fn interpret_leaf_op( - &self, - n: Node, - ins: &[(IncomingPort, Value)], - ) -> Vec<(OutgoingPort, ValueHandle)> { - match self.0.get_optype(n) { - OpType::LoadConstant(load_op) => { - assert!(ins.is_empty()); // static edge, so need to find constant - let const_node = self - .0 - .single_linked_output(n, load_op.constant_port()) - .unwrap() - .0; - let const_op = self.0.get_optype(const_node).as_const().unwrap(); - vec![( - OutgoingPort::from(0), - ValueHandle::new(const_node.into(), Arc::new(const_op.value().clone())), - )] - } - OpType::ExtensionOp(op) => { - let ins = ins.iter().map(|(p, v)| (*p, v.clone())).collect::>(); - op.constant_fold(&ins).map_or(Vec::new(), |outs| { - outs.into_iter() - .map(|(p, v)| (p, ValueHandle::new(ValueKey::Node(n), Arc::new(v)))) - .collect() - }) - } - _ => vec![], - } - } -} diff --git a/hugr-passes/src/const_fold2/value_handle.rs b/hugr-passes/src/const_fold2/value_handle.rs deleted file mode 100644 index 59a08b50a..000000000 --- a/hugr-passes/src/const_fold2/value_handle.rs +++ /dev/null @@ -1,222 +0,0 @@ -use std::collections::hash_map::DefaultHasher; // Moves into std::hash in Rust 1.76. -use std::hash::{Hash, Hasher}; -use std::sync::Arc; - -use hugr_core::ops::constant::{CustomConst, Sum}; -use hugr_core::ops::Value; -use hugr_core::types::Type; -use hugr_core::Node; - -use crate::dataflow::BaseValue; - -#[derive(Clone, Debug)] -pub struct HashedConst { - hash: u64, - val: Arc, -} - -impl PartialEq for HashedConst { - fn eq(&self, other: &Self) -> bool { - self.hash == other.hash && self.val.equal_consts(other.val.as_ref()) - } -} - -impl Eq for HashedConst {} - -impl Hash for HashedConst { - fn hash(&self, state: &mut H) { - state.write_u64(self.hash); - } -} - -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub enum ValueKey { - Field(usize, Box), - Const(HashedConst), - Node(Node), -} - -impl From for ValueKey { - fn from(n: Node) -> Self { - Self::Node(n) - } -} - -impl From for ValueKey { - fn from(value: HashedConst) -> Self { - Self::Const(value) - } -} - -impl ValueKey { - pub fn new(n: Node, k: impl CustomConst) -> Self { - Self::try_new(k).unwrap_or(Self::Node(n)) - } - - pub fn try_new(cst: impl CustomConst) -> Option { - let mut hasher = DefaultHasher::new(); - cst.try_hash(&mut hasher).then(|| { - Self::Const(HashedConst { - hash: hasher.finish(), - val: Arc::new(cst), - }) - }) - } - - fn field(self, i: usize) -> Self { - Self::Field(i, Box::new(self)) - } -} - -#[derive(Clone, Debug)] -pub struct ValueHandle(ValueKey, Arc); - -impl ValueHandle { - pub fn new(key: ValueKey, value: Arc) -> Self { - Self(key, value) - } - - pub fn value(&self) -> &Value { - self.1.as_ref() - } - - pub fn get_type(&self) -> Type { - self.1.get_type() - } -} - -impl BaseValue for ValueHandle { - fn as_sum(&self) -> Option<(usize, impl Iterator + '_)> { - match self.value() { - Value::Sum(Sum { tag, values, .. }) => Some(( - *tag, - values - .iter() - .enumerate() - .map(|(i, v)| Self(self.0.clone().field(i), Arc::new(v.clone()))), - )), - _ => None, - } - } -} - -impl PartialEq for ValueHandle { - fn eq(&self, other: &Self) -> bool { - // If the keys are equal, we return true since the values must have the - // same provenance, and so be equal. If the keys are different but the - // values are equal, we could return true if we didn't impl Eq, but - // since we do impl Eq, the Hash contract prohibits us from having equal - // values with different hashes. - let r = self.0 == other.0; - if r { - debug_assert_eq!(self.get_type(), other.get_type()); - } - r - } -} - -impl Eq for ValueHandle {} - -impl Hash for ValueHandle { - fn hash(&self, state: &mut I) { - self.0.hash(state); - } -} - -impl From for Value { - fn from(value: ValueHandle) -> Self { - (*value.1).clone() - } -} - -#[cfg(test)] -mod test { - use hugr_core::{ - extension::prelude::ConstString, - ops::constant::CustomConst as _, - std_extensions::{ - arithmetic::{ - float_types::{ConstF64, FLOAT64_TYPE}, - int_types::{ConstInt, INT_TYPES}, - }, - collections::ListValue, - }, - types::SumType, - }; - - use super::*; - - #[test] - fn value_key_eq() { - let n = Node::from(portgraph::NodeIndex::new(0)); - let n2: Node = portgraph::NodeIndex::new(1).into(); - let k1 = ValueKey::new(n, ConstString::new("foo".to_string())); - let k2 = ValueKey::new(n2, ConstString::new("foo".to_string())); - let k3 = ValueKey::new(n, ConstString::new("bar".to_string())); - - assert_eq!(k1, k2); // Node ignored - assert_ne!(k1, k3); - - assert_eq!(ValueKey::from(n), ValueKey::from(n)); - let f = ConstF64::new(std::f64::consts::PI); - assert_eq!(ValueKey::new(n, f.clone()), ValueKey::from(n)); - - assert_ne!(ValueKey::new(n, f.clone()), ValueKey::new(n2, f)); // Node taken into account - let k4 = ValueKey::from(n); - let k5 = ValueKey::from(n); - let k6: ValueKey = ValueKey::from(n2); - - assert_eq!(&k4, &k5); - assert_ne!(&k4, &k6); - - let k7 = k5.clone().field(3); - let k4 = k4.field(3); - - assert_eq!(&k4, &k7); - - let k5 = k5.field(2); - - assert_ne!(&k5, &k7); - } - - #[test] - fn value_key_list() { - let v1 = ConstInt::new_u(3, 3).unwrap(); - let v2 = ConstInt::new_u(4, 3).unwrap(); - let v3 = ConstF64::new(std::f64::consts::PI); - - let n = Node::from(portgraph::NodeIndex::new(0)); - let n2: Node = portgraph::NodeIndex::new(1).into(); - - let lst = ListValue::new(INT_TYPES[0].clone(), [v1.into(), v2.into()]); - assert_eq!(ValueKey::new(n, lst.clone()), ValueKey::new(n2, lst)); - - let lst = ListValue::new(FLOAT64_TYPE, [v3.into()]); - assert_ne!( - ValueKey::new(n, lst.clone()), - ValueKey::new(n2, lst.clone()) - ); - } - - #[test] - fn value_handle_eq() { - let k_i = ConstInt::new_u(4, 2).unwrap(); - let subject_val = Arc::new( - Value::sum( - 0, - [k_i.clone().into()], - SumType::new([vec![k_i.get_type()], vec![]]), - ) - .unwrap(), - ); - - let k1 = ValueKey::try_new(ConstString::new("foo".to_string())).unwrap(); - let v1 = ValueHandle::new(k1.clone(), subject_val.clone()); - let v2 = ValueHandle::new(k1.clone(), Value::extension(k_i).into()); - - let fields = v1.as_sum().unwrap().1.collect::>(); - // we do not compare the value, just the key - assert_ne!(fields[0], v2); - assert_eq!(fields[0].value(), v2.value()); - } -} diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 0b73fcbb0..06781f7c5 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,7 +1,6 @@ //! Compilation passes acting on the HUGR program representation. pub mod const_fold; -pub mod const_fold2; pub mod dataflow; pub mod force_order; mod half_node; From 777694ca8db5568d6da13d191591e2b4806c6508 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 23:55:27 +0100 Subject: [PATCH 094/203] (TEMP) Rm total_context --- hugr-passes/src/dataflow.rs | 3 -- hugr-passes/src/dataflow/total_context.rs | 54 ----------------------- 2 files changed, 57 deletions(-) delete mode 100644 hugr-passes/src/dataflow/total_context.rs diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 6f437f882..6085c3e92 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -10,8 +10,5 @@ pub use machine::{Machine, TailLoopTermination}; mod partial_value; pub use partial_value::{BaseValue, PVEnum, PartialSum, PartialValue, Sum}; -mod total_context; -pub use total_context::TotalContext; - #[cfg(test)] mod test; diff --git a/hugr-passes/src/dataflow/total_context.rs b/hugr-passes/src/dataflow/total_context.rs deleted file mode 100644 index d512912d0..000000000 --- a/hugr-passes/src/dataflow/total_context.rs +++ /dev/null @@ -1,54 +0,0 @@ -use std::hash::Hash; - -use ascent::lattice::BoundedLattice; -use hugr_core::{ops::OpTrait, Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex}; - -use super::partial_value::{PartialValue, Sum}; -use super::{BaseValue, DFContext}; - -/// A simpler interface like [DFContext] but where the context only cares about -/// values that are completely known (as `V`s), i.e. not `Bottom`, `Top`, or -/// Sums of potentially multiple variants. -pub trait TotalContext: Clone + Eq + Hash + std::ops::Deref { - /// The representation of values on which [Self::interpret_leaf_op] operates - type InterpretableVal: From + TryFrom>; - /// Interpret a leaf op. - /// `ins` gives the input ports for which we know (interpretable) values, and will be non-empty. - /// Returns a list of output ports for which we know (abstract) values (may be empty). - fn interpret_leaf_op( - &self, - node: Node, - ins: &[(IncomingPort, Self::InterpretableVal)], - ) -> Vec<(OutgoingPort, V)>; -} - -impl> DFContext> for T { - fn interpret_leaf_op( - &self, - node: Node, - ins: &[PartialValue], - ) -> Option>> { - let op = self.get_optype(node); - let sig = op.dataflow_signature()?; - let known_ins = sig - .input_types() - .iter() - .enumerate() - .zip(ins.iter()) - .filter_map(|((i, ty), pv)| { - pv.clone() - .try_into_value::<>::InterpretableVal>(ty) - .ok() - .map(|v| (IncomingPort::from(i), v)) - }) - .collect::>(); - let known_outs = self.interpret_leaf_op(node, &known_ins); - (!known_outs.is_empty()).then(|| { - let mut res = vec![PartialValue::bottom(); sig.output_count()]; - for (p, v) in known_outs { - res[p.index()] = v.into(); - } - res - }) - } -} From 5a16e6bd740cff1d55fd598295b725629187745a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 8 Oct 2024 23:59:30 +0100 Subject: [PATCH 095/203] clippy --- hugr-passes/src/dataflow/partial_value.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index dcce791db..de1cd55d9 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -618,7 +618,7 @@ mod test { } fn partial_sum_strat( - variants: &Vec>>, + variants: &[Vec>], ) -> impl Strategy> { // We have to clone the `variants` here but only as far as the Vec>> let tagged_variants = variants.iter().cloned().enumerate().collect::>(); From 4f311782a4482639cd7afc451174bed8842c847d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 10:11:10 +0100 Subject: [PATCH 096/203] Better fix for PartialSum::try_meet_mut --- hugr-passes/src/dataflow/partial_value.rs | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index de1cd55d9..0fe37af43 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -101,8 +101,11 @@ impl PartialSum { /// Mutates self according to lattice meet operation (towards `Bottom`). If successful, /// returns whether `self` has changed. /// - /// Fails (without mutation) with the conflicting tag if any common rows have different lengths - pub fn try_meet_mut(&mut self, other: Self) -> Result { + /// # Errors + /// Fails without mutation, either: + /// * `Some(tag)` if the two [PartialSum]s both had rows with that `tag` but of different lengths + /// * `None` if the two instances had no rows in common (i.e., the result is "Bottom") + pub fn try_meet_mut(&mut self, other: Self) -> Result> { let mut changed = false; let mut keys_to_remove = vec![]; for (k, v) in self.0.iter() { @@ -110,11 +113,14 @@ impl PartialSum { None => keys_to_remove.push(*k), Some(o_v) => { if v.len() != o_v.len() { - return Err(*k); + return Err(Some(*k)); } } } } + if keys_to_remove.len() == self.0.len() { + return Err(None); + } for (k, v) in other.0 { if let Some(row) = self.0.get_mut(&k) { for (lhs, rhs) in zip_eq(row.iter_mut(), v.into_iter()) { @@ -451,16 +457,7 @@ impl Lattice for PartialValue { _ => unreachable!(), }; match ps1.try_meet_mut(ps2) { - Ok(ch) => { - // ALAN the 'invariant' that a PartialSum always has >=1 tag can be broken here. - // Fix this by rewriting to Bottom, but should probably be refactored - at the - // least, it seems dangerous to expose a potentially-invalidating try_meet_mut. - if ps1.0.is_empty() { - assert!(ch); - self.0 = PVEnum::Bottom - } - ch - } + Ok(ch) => ch, Err(_) => { self.0 = PVEnum::Bottom; true From 2b523c90f590b35dd691075fa651416504926b6d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 10:19:55 +0100 Subject: [PATCH 097/203] true_or_false uses pv_true+pv_false --- hugr-passes/src/dataflow/test.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 4f42b4a4e..cb131f8d4 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -235,7 +235,7 @@ fn test_tail_loop_iterates_twice() { machine.propolutate_out_wires([(true_w, pv_true()), (false_w, pv_false())]); machine.run(TestContext(Arc::new(&hugr))); - let true_or_false = PartialValue::new_variant(0, []).join(PartialValue::new_variant(1, [])); + let true_or_false = pv_true().join(pv_false()); // TODO these should be the propagated values for now they will be join(true,false) - ALAN wtf? let o_r1 = machine.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, true_or_false); From 1d2cb9bad3a8b9385fdf5bbce383ea6dfb0063c6 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 12:24:24 +0100 Subject: [PATCH 098/203] Update to ascent 0.7.0, drop fn join/meet as these are now trait-default --- hugr-passes/Cargo.toml | 2 +- hugr-passes/src/dataflow/datalog.rs | 10 ---------- hugr-passes/src/dataflow/partial_value.rs | 10 ---------- 3 files changed, 1 insertion(+), 21 deletions(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index ff234494e..77b185d31 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -16,7 +16,7 @@ categories = ["compilers"] hugr-core = { path = "../hugr-core", version = "0.10.0" } portgraph = { workspace = true } # This ascent commit has a fix for unsoundness in release/tag 0.6.0: -ascent = {git = "https://github.com/s-arash/ascent", rev="9805d02cb830b6e66abcd4d48836a14cd98366f3"} +ascent = { version = "0.7.0" } downcast-rs = { workspace = true } itertools = { workspace = true } lazy_static = { workspace = true } diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index fcde4f96b..3d23ca269 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -264,16 +264,6 @@ impl PartialOrd for ValueRow { } impl Lattice for ValueRow { - fn meet(mut self, other: Self) -> Self { - self.meet_mut(other); - self - } - - fn join(mut self, other: Self) -> Self { - self.join_mut(other); - self - } - fn join_mut(&mut self, other: Self) -> bool { assert_eq!(self.0.len(), other.0.len()); let mut changed = false; diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 0fe37af43..2e67d0bb2 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -378,11 +378,6 @@ where } impl Lattice for PartialValue { - fn join(mut self, other: Self) -> Self { - self.join_mut(other); - self - } - fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); // println!("join {self:?}\n{:?}", &other); @@ -425,11 +420,6 @@ impl Lattice for PartialValue { } } - fn meet(mut self, other: Self) -> Self { - self.meet_mut(other); - self - } - fn meet_mut(&mut self, other: Self) -> bool { self.assert_invariants(); match (&self.0, other.0) { From ee91bbeae480e4d6d4b7d1f003d06cb43c357c7e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 12:26:56 +0100 Subject: [PATCH 099/203] Cargo.toml: oops, remove obsolete comment --- hugr-passes/Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 77b185d31..cdf782ff3 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -15,7 +15,6 @@ categories = ["compilers"] [dependencies] hugr-core = { path = "../hugr-core", version = "0.10.0" } portgraph = { workspace = true } -# This ascent commit has a fix for unsoundness in release/tag 0.6.0: ascent = { version = "0.7.0" } downcast-rs = { workspace = true } itertools = { workspace = true } From e67051ff6daca7f3958a9af51bf33bba3a8ce7a0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 14:59:19 +0100 Subject: [PATCH 100/203] ValueRow cleanups (remove misleading 'pub's) --- hugr-passes/src/dataflow/datalog.rs | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 3d23ca269..d4d040af5 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -221,22 +221,22 @@ fn value_outputs(h: &impl HugrView, n: Node) -> impl Iterator(Vec); impl ValueRow { - pub fn new(len: usize) -> Self { + fn new(len: usize) -> Self { Self(vec![PV::bottom(); len]) } - pub fn single_known(len: usize, idx: usize, v: PV) -> Self { + fn single_known(len: usize, idx: usize, v: PV) -> Self { assert!(idx < len); let mut r = Self::new(len); r.0[idx] = v; r } - pub fn iter(&self) -> impl Iterator { + fn iter(&self) -> impl Iterator { self.0.iter() } - pub fn unpack_first( + fn unpack_first( &self, variant: usize, len: usize, @@ -245,10 +245,6 @@ impl ValueRow { .variant_values(variant, len) .map(|vals| vals.into_iter().chain(self.iter().skip(1).cloned())) } - - // fn initialised(&self) -> bool { - // self.0.iter().all(|x| x != &PV::top()) - // } } impl FromIterator for ValueRow { From 94cee551202af0ba67941819ae0a7582baf12597 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 15:00:28 +0100 Subject: [PATCH 101/203] Refactor: rm tail_node, clone earlier in ValueRow::unpack_first, rm ValueRow::iter Only one use of tail_node that didn't *also* filter to TailLoop nodes itself. Total amount of copying in unpack_first still same, but memory usage increased (clones whole lot in one go). Was required to fix borrow issue with refactor, but same issue otherwise prevents the next commit (CFGs)... --- hugr-passes/src/dataflow/datalog.rs | 30 ++++++++++++----------------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index d4d040af5..5217ab276 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -114,27 +114,28 @@ ascent::ascent! { // TailLoop - relation tail_loop_node(C, Node); - tail_loop_node(c,n) <-- node(c, n), if c.get_optype(*n).is_tail_loop(); // inputs of tail loop propagate to Input node of child region - out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- tail_loop_node(c, tl), - io_node(c,tl,i, IO::Input), in_wire_value(c, tl, p, v); + out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- node(c, tl), + if c.get_optype(*tl).is_tail_loop(), + io_node(c,tl,i, IO::Input), + in_wire_value(c, tl, p, v); // Output node of child region propagate to Input node of child region - out_wire_value(c, in_n, out_p, v) <-- tail_loop_node(c, tl_n), + out_wire_value(c, in_n, out_p, v) <-- node(c, tl_n), + if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), io_node(c,tl_n,in_n, IO::Input), io_node(c,tl_n,out_n, IO::Output), node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node - if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), + if let Some(fields) = out_in_row.unpack_first(0, tailloop.just_inputs.len()), // if it is possible for tag to be 0 for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields); // Output node of child region propagate to outputs of tail loop - out_wire_value(c, tl_n, out_p, v) <-- tail_loop_node(c, tl_n), + out_wire_value(c, tl_n, out_p, v) <-- node(c, tl_n), + if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), io_node(c,tl_n,out_n, IO::Output), node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node - if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields); @@ -232,18 +233,11 @@ impl ValueRow { r } - fn iter(&self) -> impl Iterator { - self.0.iter() - } - - fn unpack_first( - &self, - variant: usize, - len: usize, - ) -> Option + '_> { + fn unpack_first(&self, variant: usize, len: usize) -> Option> { + let rest: Vec<_> = self.0[1..].to_owned(); self[0] .variant_values(variant, len) - .map(|vals| vals.into_iter().chain(self.iter().skip(1).cloned())) + .map(|vals| vals.into_iter().chain(rest)) } } From 5cf5ff0aeb87a5d0aba63dc1fa28ac60b29ed8b7 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 15:12:56 +0100 Subject: [PATCH 102/203] Add datalog for CFG --- hugr-passes/src/dataflow/datalog.rs | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 5217ab276..69989d9b0 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -169,6 +169,34 @@ ascent::ascent! { in_wire_value(c, cond, IncomingPort::from(0), v), let reachable = v.supports_tag(*i); + // CFG + relation cfg_node(C, Node); + relation dfb_block(C, Node, Node); + cfg_node(c,n) <-- node(c, n), if c.get_optype(*n).is_cfg(); + dfb_block(c,cfg,blk) <-- cfg_node(c, cfg), for blk in c.children(*cfg), if c.get_optype(blk).is_dataflow_block(); + + // Where do the values "fed" along a control-flow edge come out? + relation _cfg_succ_dest(C, Node, Node, Node); + _cfg_succ_dest(c, cfg, blk, inp) <-- dfb_block(c, cfg, blk), io_node(c, blk, inp, IO::Input); + _cfg_succ_dest(c, cfg, exit, cfg) <-- cfg_node(c, cfg), if let Some(exit) = c.children(*cfg).skip(1).next(); + + // Inputs of CFG propagate to entry block + out_wire_value(c, i_node, OutgoingPort::from(p.index()), v) <-- + cfg_node(c, cfg), + if let Some(entry) = c.children(*cfg).next(), + io_node(c, entry, i_node, IO::Input), + in_wire_value(c, cfg, p, v); + + // Outputs of each block propagated to successor blocks or (if exit block) then CFG itself + out_wire_value(c, dest, out_p, v) <-- + dfb_block(c, cfg, pred), + if let Some(df_block) = c.get_optype(*pred).as_dataflow_block(), + for (succ_n, succ) in c.output_neighbours(*pred).enumerate(), + io_node(c, pred, out_n, IO::Output), + _cfg_succ_dest(c, cfg, succ, dest), + node_in_value_row(c, out_n, out_in_row), + if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), + for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields); } fn propagate_leaf_op( From 7381087bf4a04a66be104a2eed32f8a2ae7573d1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 15:45:32 +0100 Subject: [PATCH 103/203] refactor: follow unpack_first with enumerate --- hugr-passes/src/dataflow/datalog.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 69989d9b0..750822f14 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -122,22 +122,22 @@ ascent::ascent! { in_wire_value(c, tl, p, v); // Output node of child region propagate to Input node of child region - out_wire_value(c, in_n, out_p, v) <-- node(c, tl_n), + out_wire_value(c, in_n, OutgoingPort::from(out_p), v) <-- node(c, tl_n), if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), io_node(c,tl_n,in_n, IO::Input), io_node(c,tl_n,out_n, IO::Output), node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node if let Some(fields) = out_in_row.unpack_first(0, tailloop.just_inputs.len()), // if it is possible for tag to be 0 - for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields); + for (out_p, v) in fields.enumerate(); // Output node of child region propagate to outputs of tail loop - out_wire_value(c, tl_n, out_p, v) <-- node(c, tl_n), + out_wire_value(c, tl_n, OutgoingPort::from(out_p), v) <-- node(c, tl_n), if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), io_node(c,tl_n,out_n, IO::Output), node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 - for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields); + for (out_p, v) in fields.enumerate(); // Conditional relation conditional_node(C, Node); @@ -149,14 +149,14 @@ ascent::ascent! { if c.get_optype(case).is_case(); // inputs of conditional propagate into case nodes - out_wire_value(c, i_node, i_p, v) <-- + out_wire_value(c, i_node, OutgoingPort::from(out_p), v) <-- case_node(c, cond, case_index, case), io_node(c, case, i_node, IO::Input), node_in_value_row(c, cond, in_row), //in_wire_value(c, cond, cond_in_p, cond_in_v), if let Some(conditional) = c.get_optype(*cond).as_conditional(), if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), - for (i_p, v) in (0..).map(OutgoingPort::from).zip(fields); + for (out_p, v) in fields.enumerate(); // outputs of case nodes propagate to outputs of conditional out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <-- @@ -188,7 +188,7 @@ ascent::ascent! { in_wire_value(c, cfg, p, v); // Outputs of each block propagated to successor blocks or (if exit block) then CFG itself - out_wire_value(c, dest, out_p, v) <-- + out_wire_value(c, dest, OutgoingPort::from(out_p), v) <-- dfb_block(c, cfg, pred), if let Some(df_block) = c.get_optype(*pred).as_dataflow_block(), for (succ_n, succ) in c.output_neighbours(*pred).enumerate(), @@ -196,7 +196,7 @@ ascent::ascent! { _cfg_succ_dest(c, cfg, succ, dest), node_in_value_row(c, out_n, out_in_row), if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), - for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields); + for (out_p, v) in fields.enumerate(); } fn propagate_leaf_op( From 60e33dba2473535965de85533bc46d9f8ea01de0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 17:01:04 +0100 Subject: [PATCH 104/203] Remove comments from test_tail_loop_(iterates_twice->two_iters) --- hugr-passes/src/dataflow/test.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index cb131f8d4..0babf2945 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -202,20 +202,17 @@ fn test_tail_loop_always_iterates() { } #[test] -fn test_tail_loop_iterates_twice() { +fn test_tail_loop_two_iters() { let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); - // let var_type = Type::new_sum([type_row![BOOL_T,BOOL_T], type_row![BOOL_T,BOOL_T]]); let true_w = builder.add_load_value(Value::true_val()); let false_w = builder.add_load_value(Value::false_val()); - // let r_w = builder - // .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); let tlb = builder .tail_loop_builder_exts( [], [(BOOL_T, false_w), (BOOL_T, true_w)], - vec![].into(), + type_row![], ExtensionSet::new(), ) .unwrap(); @@ -227,8 +224,6 @@ fn test_tail_loop_iterates_twice() { let tail_loop = tlb.finish_with_outputs(in_w1, [in_w2, in_w1]).unwrap(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - // TODO once we can do conditionals put these wires inside `just_outputs` and - // we should be able to propagate their values...ALAN wtf? loop control type IS bool ATM let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::default(); @@ -236,7 +231,6 @@ fn test_tail_loop_iterates_twice() { machine.run(TestContext(Arc::new(&hugr))); let true_or_false = pv_true().join(pv_false()); - // TODO these should be the propagated values for now they will be join(true,false) - ALAN wtf? let o_r1 = machine.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, true_or_false); let o_r2 = machine.read_out_wire(o_w2).unwrap(); From ef4f4335d24c22ad8462fe082fe2af1c299e656f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 17:01:15 +0100 Subject: [PATCH 105/203] Add a test of tail loop around conditional --- hugr-passes/src/dataflow/test.rs | 61 ++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 0babf2945..21fa80547 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -242,6 +242,67 @@ fn test_tail_loop_two_iters() { assert_eq!(machine.tail_loop_terminates(&hugr, hugr.root()), None); } +#[test] +fn test_tail_loop_containing_conditional() { + let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); + let body_out_variants = vec![type_row![BOOL_T; 2]; 2]; + + let true_w = builder.add_load_value(Value::true_val()); + let false_w = builder.add_load_value(Value::false_val()); + + let mut tlb = builder + .tail_loop_builder_exts( + [(BOOL_T, false_w), (BOOL_T, true_w)], + [], + type_row![BOOL_T, BOOL_T], + ExtensionSet::new(), + ) + .unwrap(); + assert_eq!( + tlb.loop_signature().unwrap().signature(), + Signature::new_endo(type_row![BOOL_T, BOOL_T]) + ); + let [in_w1, in_w2] = tlb.input_wires_arr(); + + // Branch on in_w1, so first iter (false, true) uses false == tag 0 == continue with (true, true) + // second iter (true, true) uses true == tag 1 == break with (true, true) + let mut cond = tlb + .conditional_builder( + (vec![type_row![]; 2], in_w1), + [], + Type::new_sum(body_out_variants.clone()).into(), + ) + .unwrap(); + for (tag, second_output) in [(0, true_w), (1, false_w)] { + let mut case_b = cond.case_builder(tag).unwrap(); + let r = case_b + .add_dataflow_op(Tag::new(tag, body_out_variants), [in_w2, second_output]) + .unwrap() + .outputs(); + case_b.finish_with_outputs(r).unwrap(); + } + let [r] = cond.finish_sub_container().unwrap().outputs_arr(); + + let tail_loop = tlb.finish_with_outputs(r, []).unwrap(); + + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + let [o_w1, o_w2] = tail_loop.outputs_arr(); + + let mut machine = Machine::default(); + machine.propolutate_out_wires([(true_w, pv_true()), (false_w, pv_false())]); + machine.run(TestContext(Arc::new(&hugr))); + + let o_r1 = machine.read_out_wire(o_w1).unwrap(); + assert_eq!(o_r1, pv_true()); + let o_r2 = machine.read_out_wire(o_w2).unwrap(); + assert_eq!(o_r2, pv_false()); + assert_eq!( + Some(TailLoopTermination::BreaksAndContinues), + machine.tail_loop_terminates(&hugr, tail_loop.node()) + ); + assert_eq!(machine.tail_loop_terminates(&hugr, hugr.root()), None); +} + #[test] fn conditional() { let variants = vec![type_row![], type_row![], type_row![BOOL_T]]; From 59354894d06b309e5cd392c22bac0575e5a15fef Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 17:56:19 +0100 Subject: [PATCH 106/203] improve that test - loop input is a sum and the variants have different values --- hugr-passes/src/dataflow/test.rs | 62 ++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 27 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 21fa80547..125838010 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -2,6 +2,7 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use ascent::{lattice::BoundedLattice, Lattice}; + use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ @@ -245,42 +246,49 @@ fn test_tail_loop_two_iters() { #[test] fn test_tail_loop_containing_conditional() { let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); - let body_out_variants = vec![type_row![BOOL_T; 2]; 2]; - - let true_w = builder.add_load_value(Value::true_val()); - let false_w = builder.add_load_value(Value::false_val()); + let control_variants = vec![type_row![BOOL_T;2]; 2]; + let control_t = Type::new_sum(control_variants.clone()); + let body_out_variants = vec![control_t.clone().into(), type_row![BOOL_T; 2]]; + + let init = builder.add_load_value( + Value::sum( + 0, + [Value::false_val(), Value::true_val()], + SumType::new(control_variants.clone()), + ) + .unwrap(), + ); let mut tlb = builder - .tail_loop_builder_exts( - [(BOOL_T, false_w), (BOOL_T, true_w)], - [], - type_row![BOOL_T, BOOL_T], - ExtensionSet::new(), - ) + .tail_loop_builder([(control_t, init)], [], type_row![BOOL_T; 2]) .unwrap(); - assert_eq!( - tlb.loop_signature().unwrap().signature(), - Signature::new_endo(type_row![BOOL_T, BOOL_T]) - ); - let [in_w1, in_w2] = tlb.input_wires_arr(); + let [in_w] = tlb.input_wires_arr(); - // Branch on in_w1, so first iter (false, true) uses false == tag 0 == continue with (true, true) - // second iter (true, true) uses true == tag 1 == break with (true, true) + // Branch on in_wire, so first iter 0(false, true)... let mut cond = tlb .conditional_builder( - (vec![type_row![]; 2], in_w1), + (control_variants.clone(), in_w), [], Type::new_sum(body_out_variants.clone()).into(), ) .unwrap(); - for (tag, second_output) in [(0, true_w), (1, false_w)] { - let mut case_b = cond.case_builder(tag).unwrap(); - let r = case_b - .add_dataflow_op(Tag::new(tag, body_out_variants), [in_w2, second_output]) - .unwrap() - .outputs(); - case_b.finish_with_outputs(r).unwrap(); - } + let mut case0_b = cond.case_builder(0).unwrap(); + let [a, b] = case0_b.input_wires_arr(); + // Builds value for next iter as 1(true, false) by flipping arguments + let [next_input] = case0_b + .add_dataflow_op(Tag::new(1, control_variants), [b, a]) + .unwrap() + .outputs_arr(); + let cont = case0_b + .add_dataflow_op(Tag::new(0, body_out_variants.clone()), [next_input]) + .unwrap(); + case0_b.finish_with_outputs(cont.outputs()).unwrap(); + // Second iter 1(true, false) => exit with (true, false) + let mut case1_b = cond.case_builder(1).unwrap(); + let loop_res = case1_b + .add_dataflow_op(Tag::new(1, body_out_variants), case1_b.input_wires()) + .unwrap(); + case1_b.finish_with_outputs(loop_res.outputs()).unwrap(); let [r] = cond.finish_sub_container().unwrap().outputs_arr(); let tail_loop = tlb.finish_with_outputs(r, []).unwrap(); @@ -289,7 +297,7 @@ fn test_tail_loop_containing_conditional() { let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::default(); - machine.propolutate_out_wires([(true_w, pv_true()), (false_w, pv_false())]); + machine.propolutate_out_wires([(init, PartialValue::new_variant(0, [pv_false(), pv_true()]))]); machine.run(TestContext(Arc::new(&hugr))); let o_r1 = machine.read_out_wire(o_w1).unwrap(); From b19868133a3ed3e5dfabea55404cbcd8f5cb4216 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 10 Oct 2024 10:00:26 +0100 Subject: [PATCH 107/203] clippy/nth --- hugr-core/src/types.rs | 2 -- hugr-passes/src/dataflow/datalog.rs | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 5afab2294..39980d65f 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -16,9 +16,7 @@ use crate::types::type_param::check_type_arg; use crate::utils::display_list_with_separator; pub use check::SumTypeError; pub use custom::CustomType; -pub(crate) use poly_func::PolyFuncTypeBase; pub use poly_func::{PolyFuncType, PolyFuncTypeRV}; -pub(crate) use signature::FuncTypeBase; pub use signature::{FuncValueType, Signature}; use smol_str::SmolStr; pub use type_param::TypeArg; diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 750822f14..3b6fb05a4 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -178,7 +178,7 @@ ascent::ascent! { // Where do the values "fed" along a control-flow edge come out? relation _cfg_succ_dest(C, Node, Node, Node); _cfg_succ_dest(c, cfg, blk, inp) <-- dfb_block(c, cfg, blk), io_node(c, blk, inp, IO::Input); - _cfg_succ_dest(c, cfg, exit, cfg) <-- cfg_node(c, cfg), if let Some(exit) = c.children(*cfg).skip(1).next(); + _cfg_succ_dest(c, cfg, exit, cfg) <-- cfg_node(c, cfg), if let Some(exit) = c.children(*cfg).nth(1); // Inputs of CFG propagate to entry block out_wire_value(c, i_node, OutgoingPort::from(p.index()), v) <-- From ed30f808a51d503cc58996998de34ef04b8790d8 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 10 Oct 2024 12:22:14 +0100 Subject: [PATCH 108/203] revert accidental changes to hugr-core/src/types.rs (how?!) --- hugr-core/src/types.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 39980d65f..5afab2294 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -16,7 +16,9 @@ use crate::types::type_param::check_type_arg; use crate::utils::display_list_with_separator; pub use check::SumTypeError; pub use custom::CustomType; +pub(crate) use poly_func::PolyFuncTypeBase; pub use poly_func::{PolyFuncType, PolyFuncTypeRV}; +pub(crate) use signature::FuncTypeBase; pub use signature::{FuncValueType, Signature}; use smol_str::SmolStr; pub use type_param::TypeArg; From 3f7808ade8e149d6d253037f94d945d88bdceede Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 10 Oct 2024 15:18:03 +0100 Subject: [PATCH 109/203] Cleanup conditional, cfg, unpack_first --- hugr-passes/src/dataflow/datalog.rs | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 3b6fb05a4..b64de3aaf 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -153,8 +153,7 @@ ascent::ascent! { case_node(c, cond, case_index, case), io_node(c, case, i_node, IO::Input), node_in_value_row(c, cond, in_row), - //in_wire_value(c, cond, cond_in_p, cond_in_v), - if let Some(conditional) = c.get_optype(*cond).as_conditional(), + let conditional = c.get_optype(*cond).as_conditional().unwrap(), if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), for (out_p, v) in fields.enumerate(); @@ -190,7 +189,7 @@ ascent::ascent! { // Outputs of each block propagated to successor blocks or (if exit block) then CFG itself out_wire_value(c, dest, OutgoingPort::from(out_p), v) <-- dfb_block(c, cfg, pred), - if let Some(df_block) = c.get_optype(*pred).as_dataflow_block(), + let df_block = c.get_optype(*pred).as_dataflow_block().unwrap(), for (succ_n, succ) in c.output_neighbours(*pred).enumerate(), io_node(c, pred, out_n, IO::Output), _cfg_succ_dest(c, cfg, succ, dest), @@ -262,10 +261,8 @@ impl ValueRow { } fn unpack_first(&self, variant: usize, len: usize) -> Option> { - let rest: Vec<_> = self.0[1..].to_owned(); - self[0] - .variant_values(variant, len) - .map(|vals| vals.into_iter().chain(rest)) + let vals = self[0].variant_values(variant, len)?; + Some(vals.into_iter().chain(self.0[1..].to_owned())) } } From 436b63533d908ea9215073ce0c668aac14a3ad0d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Oct 2024 19:39:18 +0100 Subject: [PATCH 110/203] Complex CFG that does a not-XOR...but analysis generally says "true or false" --- hugr-passes/src/dataflow/test.rs | 122 ++++++++++++++++++++++++++++++- 1 file changed, 119 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 125838010..e3e8545ec 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -3,6 +3,9 @@ use std::sync::Arc; use ascent::{lattice::BoundedLattice, Lattice}; +use hugr_core::builder::CFGBuilder; +use hugr_core::types::TypeRow; +use hugr_core::Wire; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ @@ -14,6 +17,7 @@ use hugr_core::{ types::{Signature, SumType, Type}, HugrView, }; +use rstest::rstest; use super::{AbstractValue, BaseValue, DFContext, Machine, PartialValue, TailLoopTermination}; @@ -90,6 +94,10 @@ fn pv_true() -> PartialValue { PartialValue::new_variant(1, []) } +fn pv_true_or_false() -> PartialValue { + pv_true().join(pv_false()) +} + #[test] fn test_make_tuple() { let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap(); @@ -231,11 +239,10 @@ fn test_tail_loop_two_iters() { machine.propolutate_out_wires([(true_w, pv_true()), (false_w, pv_false())]); machine.run(TestContext(Arc::new(&hugr))); - let true_or_false = pv_true().join(pv_false()); let o_r1 = machine.read_out_wire(o_w1).unwrap(); - assert_eq!(o_r1, true_or_false); + assert_eq!(o_r1, pv_true_or_false()); let o_r2 = machine.read_out_wire(o_w2).unwrap(); - assert_eq!(o_r2, true_or_false); + assert_eq!(o_r2, pv_true_or_false()); assert_eq!( Some(TailLoopTermination::BreaksAndContinues), machine.tail_loop_terminates(&hugr, tail_loop.node()) @@ -371,3 +378,112 @@ fn conditional() { assert_eq!(machine.case_reachable(&hugr, case3.node()), Some(true)); assert_eq!(machine.case_reachable(&hugr, cond.node()), None); } + +#[rstest] +#[case(pv_true(), pv_true(), pv_true())] // OK +#[case(pv_true(), pv_false(), pv_true_or_false())] // Result should be false ?? +#[case(pv_false(), pv_true(), pv_true_or_false())] // Result should be false ?? +#[case(pv_false(), pv_false(), pv_true_or_false())] // Result should be true?? +#[case(PartialValue::top(), pv_true(), PartialValue::top())] // Result should be true_or_false? TOP means all inputs inside cases are TOP +#[case(PartialValue::top(), pv_false(), PartialValue::top())] // Result should be true_or_false? +fn cfg( + #[case] inp0: PartialValue, + #[case] inp1: PartialValue, + #[case] outp: PartialValue, +) { + // Entry + // /0 1\ + // A --1-> B + // \0 / + // > X < + let mut builder = CFGBuilder::new(Signature::new(type_row![BOOL_T;2], BOOL_T)).unwrap(); + + // entry (i, j) => if i {B(j)} else {A(j, i, true)}, note that (j, i, true) == (j, false, true) + let entry_outs = [type_row![BOOL_T;3], type_row![BOOL_T]]; + let mut entry = builder + .entry_builder(entry_outs.clone(), type_row![]) + .unwrap(); + let [in_i, in_j] = entry.input_wires_arr(); + let mut cond = entry + .conditional_builder( + (vec![type_row![]; 2], in_i), + [], + Type::new_sum(entry_outs.clone()).into(), + ) + .unwrap(); + let mut if_i_true = cond.case_builder(1).unwrap(); + let br_to_b = if_i_true + .add_dataflow_op(Tag::new(1, entry_outs.to_vec()), [in_j]) + .unwrap(); + if_i_true.finish_with_outputs(br_to_b.outputs()).unwrap(); + let mut if_i_false = cond.case_builder(0).unwrap(); + let true_w = if_i_false.add_load_value(Value::true_val()); + let br_to_a = if_i_false + .add_dataflow_op(Tag::new(0, entry_outs.into()), [in_j, in_i, true_w]) + .unwrap(); + if_i_false.finish_with_outputs(br_to_a.outputs()).unwrap(); + + let [res] = cond.finish_sub_container().unwrap().outputs_arr(); + let entry = entry.finish_with_outputs(res, []).unwrap(); + + // A(w, y, z) => if w {B(y)} else {X(z)} + let a_outs = vec![type_row![BOOL_T]; 2]; + let mut a = builder + .block_builder( + type_row![BOOL_T; 3], + vec![type_row![BOOL_T]; 2], + type_row![], + ) + .unwrap(); + let [in_w, in_y, in_z] = a.input_wires_arr(); + let mut cond = a + .conditional_builder( + (vec![type_row![]; 2], in_w), + [], + Type::new_sum(a_outs.clone()).into(), + ) + .unwrap(); + let mut if_w_true = cond.case_builder(1).unwrap(); + let br_to_b = if_w_true + .add_dataflow_op(Tag::new(1, a_outs.clone()), [in_y]) + .unwrap(); + if_w_true.finish_with_outputs(br_to_b.outputs()).unwrap(); + let mut if_w_false = cond.case_builder(0).unwrap(); + let br_to_x = if_w_false + .add_dataflow_op(Tag::new(0, a_outs), [in_z]) + .unwrap(); + if_w_false.finish_with_outputs(br_to_x.outputs()).unwrap(); + let [res] = cond.finish_sub_container().unwrap().outputs_arr(); + let a = a.finish_with_outputs(res, []).unwrap(); + + // B(v) => X(v) + let mut b = builder + .block_builder(type_row![BOOL_T], [type_row![BOOL_T]], type_row![]) + .unwrap(); + let [control] = b + .add_dataflow_op(Tag::new(0, vec![type_row![BOOL_T]]), b.input_wires()) + .unwrap() + .outputs_arr(); + let b = b.finish_with_outputs(control, []).unwrap(); + + let x = builder.exit_block(); + + builder.branch(&entry, 0, &a).unwrap(); + builder.branch(&entry, 1, &b).unwrap(); + builder.branch(&a, 0, &x).unwrap(); + builder.branch(&a, 1, &b).unwrap(); + builder.branch(&b, 0, &x).unwrap(); + let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); + + let [entry_input, _] = hugr.get_io(entry.node()).unwrap(); + let [in_w0, in_w1] = [0, 1].map(|i| Wire::new(entry_input, i)); + + let mut machine = Machine::default(); + machine.propolutate_out_wires([(in_w0, inp0), (in_w1, inp1), (true_w, pv_true())]); + machine.run(TestContext(Arc::new(&hugr))); + + assert_eq!( + machine.read_out_wire(Wire::new(hugr.root(), 0)).unwrap(), + outp + ); +} From 0374d130c87a41575848ba7a5661daf61974abfc Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 10 Oct 2024 16:13:25 +0100 Subject: [PATCH 111/203] Propagate case results to conditional output only if case reached; some test fix --- hugr-passes/src/dataflow/datalog.rs | 6 ++++-- hugr-passes/src/dataflow/test.rs | 8 ++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index b64de3aaf..258d76c0d 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -157,9 +157,11 @@ ascent::ascent! { if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), for (out_p, v) in fields.enumerate(); - // outputs of case nodes propagate to outputs of conditional + // outputs of case nodes propagate to outputs of conditional *if* case reachable out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <-- - case_node(c, cond, _, case), + case_node(c, cond, i, case), + in_wire_value(c, cond, IncomingPort::from(0), control), + if control.supports_tag(*i), io_node(c, case, o, IO::Output), in_wire_value(c, o, o_p, v); diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index e3e8545ec..40481855c 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -381,11 +381,11 @@ fn conditional() { #[rstest] #[case(pv_true(), pv_true(), pv_true())] // OK -#[case(pv_true(), pv_false(), pv_true_or_false())] // Result should be false ?? -#[case(pv_false(), pv_true(), pv_true_or_false())] // Result should be false ?? -#[case(pv_false(), pv_false(), pv_true_or_false())] // Result should be true?? +#[case(pv_true(), pv_false(), pv_false())] // OK +#[case(pv_false(), pv_true(), pv_false())] // OK +#[case(pv_false(), pv_false(), pv_true())] // OK #[case(PartialValue::top(), pv_true(), PartialValue::top())] // Result should be true_or_false? TOP means all inputs inside cases are TOP -#[case(PartialValue::top(), pv_false(), PartialValue::top())] // Result should be true_or_false? +#[case(PartialValue::top(), pv_false(), pv_true_or_false())] // OK fn cfg( #[case] inp0: PartialValue, #[case] inp1: PartialValue, From 6a2dd9e79ab83d92d7acb049d64cc8b1022b90ec Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 10 Oct 2024 16:15:17 +0100 Subject: [PATCH 112/203] More test cases --- hugr-passes/src/dataflow/test.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 40481855c..71b0425bf 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -380,12 +380,16 @@ fn conditional() { } #[rstest] -#[case(pv_true(), pv_true(), pv_true())] // OK -#[case(pv_true(), pv_false(), pv_false())] // OK -#[case(pv_false(), pv_true(), pv_false())] // OK -#[case(pv_false(), pv_false(), pv_true())] // OK -#[case(PartialValue::top(), pv_true(), PartialValue::top())] // Result should be true_or_false? TOP means all inputs inside cases are TOP -#[case(PartialValue::top(), pv_false(), pv_true_or_false())] // OK +#[case(pv_true(), pv_true(), pv_true())] +#[case(pv_true(), pv_false(), pv_false())] +#[case(pv_true(), pv_true_or_false(), pv_true_or_false())] +#[case(pv_false(), pv_true(), pv_false())] +#[case(pv_false(), pv_false(), pv_true())] +#[case(pv_false(), pv_true_or_false(), pv_true_or_false())] +#[case(PartialValue::top(), pv_true(), PartialValue::top())] // Ideally, result should be true_or_false +#[case(PartialValue::top(), pv_false(), pv_true_or_false())] +#[case(pv_true_or_false(), pv_true(), pv_true_or_false())] +#[case(pv_true_or_false(), pv_false(), pv_true_or_false())] fn cfg( #[case] inp0: PartialValue, #[case] inp1: PartialValue, From a75fee947bc64a0eea89f0d5a13768b44a96ebf1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 10 Oct 2024 16:23:40 +0100 Subject: [PATCH 113/203] refactor as fixture --- hugr-passes/src/dataflow/test.rs | 51 ++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 71b0425bf..44f2bcaba 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -5,7 +5,6 @@ use ascent::{lattice::BoundedLattice, Lattice}; use hugr_core::builder::CFGBuilder; use hugr_core::types::TypeRow; -use hugr_core::Wire; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ @@ -17,7 +16,8 @@ use hugr_core::{ types::{Signature, SumType, Type}, HugrView, }; -use rstest::rstest; +use hugr_core::{Hugr, Node, Wire}; +use rstest::{fixture, rstest}; use super::{AbstractValue, BaseValue, DFContext, Machine, PartialValue, TailLoopTermination}; @@ -379,22 +379,14 @@ fn conditional() { assert_eq!(machine.case_reachable(&hugr, cond.node()), None); } -#[rstest] -#[case(pv_true(), pv_true(), pv_true())] -#[case(pv_true(), pv_false(), pv_false())] -#[case(pv_true(), pv_true_or_false(), pv_true_or_false())] -#[case(pv_false(), pv_true(), pv_false())] -#[case(pv_false(), pv_false(), pv_true())] -#[case(pv_false(), pv_true_or_false(), pv_true_or_false())] -#[case(PartialValue::top(), pv_true(), PartialValue::top())] // Ideally, result should be true_or_false -#[case(PartialValue::top(), pv_false(), pv_true_or_false())] -#[case(pv_true_or_false(), pv_true(), pv_true_or_false())] -#[case(pv_true_or_false(), pv_false(), pv_true_or_false())] -fn cfg( - #[case] inp0: PartialValue, - #[case] inp1: PartialValue, - #[case] outp: PartialValue, -) { +// Tuple of +// 1. Hugr being a function on bools: (b,c) => !b XOR c +// 2. Input node of entry block +// 3. Wire out from "True" constant +// Result readable from root node outputs +// Inputs should be placed onto out-wires of the Node (2.) +#[fixture] +fn xnor_cfg() -> (Hugr, Node, Wire) { // Entry // /0 1\ // A --1-> B @@ -478,8 +470,29 @@ fn cfg( builder.branch(&a, 1, &b).unwrap(); builder.branch(&b, 0, &x).unwrap(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let [entry_input, _] = hugr.get_io(entry.node()).unwrap(); + (hugr, entry_input, true_w) +} + +#[rstest] +#[case(pv_true(), pv_true(), pv_true())] +#[case(pv_true(), pv_false(), pv_false())] +#[case(pv_true(), pv_true_or_false(), pv_true_or_false())] +#[case(pv_false(), pv_true(), pv_false())] +#[case(pv_false(), pv_false(), pv_true())] +#[case(pv_false(), pv_true_or_false(), pv_true_or_false())] +#[case(PartialValue::top(), pv_true(), PartialValue::top())] // Ideally, result should be true_or_false +#[case(PartialValue::top(), pv_false(), pv_true_or_false())] +#[case(pv_true_or_false(), pv_true(), pv_true_or_false())] +#[case(pv_true_or_false(), pv_false(), pv_true_or_false())] +fn test_cfg( + #[case] inp0: PartialValue, + #[case] inp1: PartialValue, + #[case] outp: PartialValue, + xnor_cfg: (Hugr, Node, Wire), +) { + let (hugr, entry_input, true_w) = xnor_cfg; + let [in_w0, in_w1] = [0, 1].map(|i| Wire::new(entry_input, i)); let mut machine = Machine::default(); From 151e571c82804ac927292e85f681d314657eb87c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 10 Oct 2024 18:59:51 +0100 Subject: [PATCH 114/203] clippy --- hugr-passes/src/dataflow/test.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 44f2bcaba..2d652c7a4 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -4,7 +4,6 @@ use std::sync::Arc; use ascent::{lattice::BoundedLattice, Lattice}; use hugr_core::builder::CFGBuilder; -use hugr_core::types::TypeRow; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ From e15b04d0806031f894dadb6501c669177811177e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 14 Oct 2024 10:33:30 +0100 Subject: [PATCH 115/203] Revert "Datalog works on any AbstractValue; impl'd by PartialValue for a BaseValue" This reverts commit 7f2a91a5fc5bc26143f5e82543c172e10ebea90d. --- hugr-passes/src/dataflow.rs | 18 ++++- hugr-passes/src/dataflow/datalog.rs | 91 +++++++++-------------- hugr-passes/src/dataflow/machine.rs | 22 +++--- hugr-passes/src/dataflow/partial_value.rs | 77 +++++++++---------- hugr-passes/src/dataflow/test.rs | 6 +- 5 files changed, 103 insertions(+), 111 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 6085c3e92..5cf5d91eb 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -2,13 +2,27 @@ //! Dataflow analysis of Hugrs. mod datalog; -pub use datalog::{AbstractValue, DFContext}; mod machine; pub use machine::{Machine, TailLoopTermination}; mod partial_value; -pub use partial_value::{BaseValue, PVEnum, PartialSum, PartialValue, Sum}; +pub use partial_value::{AbstractValue, PVEnum, PartialSum, PartialValue, Sum}; + +use hugr_core::{Hugr, Node}; +use std::hash::Hash; + +/// Clients of the dataflow framework (particular analyses, such as constant folding) +/// must implement this trait (including providing an appropriate domain type `V`). +pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { + /// Given lattice values for each input, produce lattice values for (what we know of) + /// the outputs. Returning `None` indicates nothing can be deduced. + fn interpret_leaf_op( + &self, + node: Node, + ins: &[PartialValue], + ) -> Option>>; +} #[cfg(test)] mod test; diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 258d76c0d..4ab12e380 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -16,7 +16,11 @@ use std::ops::{Index, IndexMut}; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use hugr_core::ops::OpType; use hugr_core::types::Signature; -use hugr_core::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _}; +use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _}; + +use super::{AbstractValue, DFContext, PartialValue}; + +type PV = PartialValue; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum IO { @@ -24,50 +28,19 @@ pub enum IO { Output, } -/// Clients of the dataflow framework (particular analyses, such as constant folding) -/// must implement this trait (including providing an appropriate domain type `PV`). -pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { - /// Given lattice values for each input, produce lattice values for (what we know of) - /// the outputs. Returning `None` indicates nothing can be deduced. - fn interpret_leaf_op(&self, node: Node, ins: &[PV]) -> Option>; -} - -/// Values which can be the domain for dataflow analysis. Must be able to deconstructed -/// into (and constructed from) Sums as these determine control flow. -pub trait AbstractValue: BoundedLattice + Clone + Eq + Hash + std::fmt::Debug { - /// Create a new instance representing a Sum with a single known tag - /// and (recursive) representations of the elements within that tag. - fn new_variant(tag: usize, values: impl IntoIterator) -> Self; - - /// New instance of unit type (i.e. the only possible value, with no contents) - fn new_unit() -> Self { - Self::new_variant(0, []) - } - - /// Test whether this value *might* be a Sum with the specified tag. - fn supports_tag(&self, tag: usize) -> bool; - - /// If this value might be a Sum with the specified tag, return values - /// describing the elements of the Sum, otherwise `None`. - /// - /// Implementations must hold the invariant that for all `x`, `tag` and `len`: - /// `x.variant_values(tag, len).is_some() == x.supports_tag(tag)` - fn variant_values(&self, tag: usize, len: usize) -> Option>; -} - ascent::ascent! { - pub(super) struct AscentProgram>; + pub(super) struct AscentProgram>; relation context(C); - relation out_wire_value_proto(Node, OutgoingPort, PV); + relation out_wire_value_proto(Node, OutgoingPort, PV); relation node(C, Node); relation in_wire(C, Node, IncomingPort); relation out_wire(C, Node, OutgoingPort); relation parent_of_node(C, Node, Node); relation io_node(C, Node, Node, IO); - lattice out_wire_value(C, Node, OutgoingPort, PV); - lattice node_in_value_row(C, Node, ValueRow); - lattice in_wire_value(C, Node, IncomingPort, PV); + lattice out_wire_value(C, Node, OutgoingPort, PV); + lattice node_in_value_row(C, Node, ValueRow); + lattice in_wire_value(C, Node, IncomingPort, PV); node(c, n) <-- context(c), for n in c.nodes(); @@ -200,11 +173,11 @@ ascent::ascent! { for (out_p, v) in fields.enumerate(); } -fn propagate_leaf_op( - c: &impl DFContext, +fn propagate_leaf_op( + c: &impl DFContext, n: Node, - ins: &[PV], -) -> Option> { + ins: &[PV], +) -> Option> { match c.get_optype(n) { // Handle basics here. I guess (given the current interface) we could allow // DFContext to handle these but at the least we'd want these impls to be @@ -248,33 +221,37 @@ fn value_outputs(h: &impl HugrView, n: Node) -> impl Iterator(Vec); +struct ValueRow(Vec>); -impl ValueRow { +impl ValueRow { fn new(len: usize) -> Self { - Self(vec![PV::bottom(); len]) + Self(vec![PartialValue::bottom(); len]) } - fn single_known(len: usize, idx: usize, v: PV) -> Self { + fn single_known(len: usize, idx: usize, v: PartialValue) -> Self { assert!(idx < len); let mut r = Self::new(len); r.0[idx] = v; r } - fn unpack_first(&self, variant: usize, len: usize) -> Option> { + pub fn unpack_first( + &self, + variant: usize, + len: usize, + ) -> Option>> { let vals = self[0].variant_values(variant, len)?; Some(vals.into_iter().chain(self.0[1..].to_owned())) } } -impl FromIterator for ValueRow { - fn from_iter>(iter: T) -> Self { +impl FromIterator> for ValueRow { + fn from_iter>>(iter: T) -> Self { Self(iter.into_iter().collect()) } } -impl PartialOrd for ValueRow { +impl PartialOrd for ValueRow { fn partial_cmp(&self, other: &Self) -> Option { self.0.partial_cmp(&other.0) } @@ -300,30 +277,30 @@ impl Lattice for ValueRow { } } -impl IntoIterator for ValueRow { - type Item = PV; +impl IntoIterator for ValueRow { + type Item = PartialValue; - type IntoIter = as IntoIterator>::IntoIter; + type IntoIter = > as IntoIterator>::IntoIter; fn into_iter(self) -> Self::IntoIter { self.0.into_iter() } } -impl Index for ValueRow +impl Index for ValueRow where - Vec: Index, + Vec>: Index, { - type Output = as Index>::Output; + type Output = > as Index>::Output; fn index(&self, index: Idx) -> &Self::Output { self.0.index(index) } } -impl IndexMut for ValueRow +impl IndexMut for ValueRow where - Vec: IndexMut, + Vec>: IndexMut, { fn index_mut(&mut self, index: Idx) -> &mut Self::Output { self.0.index_mut(index) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index aa9408cdb..15262d4db 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -2,16 +2,16 @@ use std::collections::HashMap; use hugr_core::{HugrView, Node, PortIndex, Wire}; -use super::{datalog::AscentProgram, AbstractValue, DFContext}; +use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue}; /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] /// 2. Zero or more [Self::propolutate_out_wires] with initial values /// 3. Exactly one [Self::run] to do the analysis /// 4. Results then available via [Self::read_out_wire] -pub struct Machine>( - AscentProgram, - Option>, +pub struct Machine>( + AscentProgram, + Option>>, ); /// derived-Default requires the context to be Defaultable, which is unnecessary @@ -21,10 +21,13 @@ impl> Default for Machine { } } -impl> Machine { +impl> Machine { /// Provide initial values for some wires. /// (For example, if some properties of the Hugr's inputs are known.) - pub fn propolutate_out_wires(&mut self, wires: impl IntoIterator) { + pub fn propolutate_out_wires( + &mut self, + wires: impl IntoIterator)>, + ) { assert!(self.1.is_none()); self.0 .out_wire_value_proto @@ -52,7 +55,7 @@ impl> Machine { } /// Gets the lattice value computed by [Self::run] for the given wire - pub fn read_out_wire(&self, w: Wire) -> Option { + pub fn read_out_wire(&self, w: Wire) -> Option> { self.1.as_ref().unwrap().get(&w).cloned() } @@ -113,10 +116,7 @@ pub enum TailLoopTermination { } impl TailLoopTermination { - /// Extracts the relevant information from a value that should represent - /// the value provided to the [Output](hugr_core::ops::Output) node child - /// of the [TailLoop](hugr_core::ops::TailLoop) - pub fn from_control_value(v: &impl AbstractValue) -> Self { + fn from_control_value(v: &PartialValue) -> Self { let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); if may_break { if may_continue { diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 2e67d0bb2..880a30241 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -8,11 +8,9 @@ use std::cmp::Ordering; use std::collections::HashMap; use std::hash::{Hash, Hasher}; -use super::AbstractValue; - -/// Trait for abstract values that can be wrapped by [PartialValue] for dataflow analysis. -/// (Allows the values to represent sums, but does not require this). -pub trait BaseValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { +/// Trait for an underlying domain of abstract values which can form the *elements* of a +/// [PartialValue] and thus be used in dataflow analysis. +pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { /// If the abstract value represents a [Sum] with a single known tag, deconstruct it /// into that tag plus the elements. The default just returns `None` which is /// appropriate if the abstract value never does (in which case [interpret_leaf_op] @@ -65,7 +63,7 @@ impl PartialSum { } } -impl PartialSum { +impl PartialSum { fn assert_invariants(&self) { assert_ne!(self.num_variants(), 0); for pv in self.0.values().flat_map(|x| x.iter()) { @@ -232,15 +230,15 @@ impl Hash for PartialSum { } } -/// Wraps some underlying representation of values (that `impl`s [BaseValue]) into -/// a lattice for use in dataflow analysis, including that an instance may be -/// a [PartialSum] of values of the underlying representation +/// Wraps some underlying representation (knowledge) of values into a lattice +/// for use in dataflow analysis, including that an instance may be a [PartialSum] +/// of values of the underlying representation #[derive(PartialEq, Clone, Eq, Hash, Debug)] pub struct PartialValue(PVEnum); impl PartialValue { /// Allows to read the enum, which guarantees that we never return [PVEnum::Value] - /// for a value whose [BaseValue::as_sum] is `Some` - any such value will be + /// for a value whose [AbstractValue::as_sum] is `Some` - any such value will be /// in the form of a [PVEnum::Sum] instead. pub fn as_enum(&self) -> &PVEnum { &self.0 @@ -260,7 +258,7 @@ pub enum PVEnum { Top, } -impl From for PartialValue { +impl From for PartialValue { fn from(v: V) -> Self { v.as_sum() .map(|(tag, values)| Self::new_variant(tag, values.map(Self::from))) @@ -274,7 +272,7 @@ impl From> for PartialValue { } } -impl PartialValue { +impl PartialValue { fn assert_invariants(&self) { match &self.0 { PVEnum::Sum(ps) => { @@ -287,30 +285,22 @@ impl PartialValue { } } - /// Extracts a value (in any representation supporting both leaf values and sums) - // ALAN is this too parametric? Should we fix V2 == Value? Is the error useful (should we have 'Self') or is it a smell? - pub fn try_into_value + TryFrom>>( - self, - typ: &Type, - ) -> Result>>::Error>> { - match self.0 { - PVEnum::Value(v) => Ok(V2::from(v.clone())), - PVEnum::Sum(ps) => { - let v = ps.try_into_value(typ).map_err(|_| None)?; - V2::try_from(v).map_err(Some) - } - _ => Err(None), - } + /// New instance of a sum with a single known tag. + pub fn new_variant(tag: usize, values: impl IntoIterator) -> Self { + PartialSum::new_variant(tag, values).into() + } + + /// New instance of unit type (i.e. the only possible value, with no contents) + pub fn new_unit() -> Self { + Self::new_variant(0, []) } -} -impl AbstractValue for PartialValue { /// If this value might be a Sum with the specified `tag`, get the elements inside that tag. /// /// # Panics /// /// if the value is believed, for that tag, to have a number of values other than `len` - fn variant_values(&self, tag: usize, len: usize) -> Option>> { + pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { let vals = match &self.0 { PVEnum::Bottom => return None, PVEnum::Value(v) => { @@ -325,7 +315,7 @@ impl AbstractValue for PartialValue { } /// Tells us whether this value might be a Sum with the specified `tag` - fn supports_tag(&self, tag: usize) -> bool { + pub fn supports_tag(&self, tag: usize) -> bool { match &self.0 { PVEnum::Bottom => false, PVEnum::Value(v) => { @@ -337,8 +327,20 @@ impl AbstractValue for PartialValue { } } - fn new_variant(tag: usize, values: impl IntoIterator) -> Self { - PartialSum::new_variant(tag, values).into() + /// Extracts a value (in any representation supporting both leaf values and sums) + // ALAN is this too parametric? Should we fix V2 == Value? Is the error useful (should we have 'Self') or is it a smell? + pub fn try_into_value + TryFrom>>( + self, + typ: &Type, + ) -> Result>>::Error>> { + match self.0 { + PVEnum::Value(v) => Ok(V2::from(v.clone())), + PVEnum::Sum(ps) => { + let v = ps.try_into_value(typ).map_err(|_| None)?; + V2::try_from(v).map_err(Some) + } + _ => Err(None), + } } } @@ -350,7 +352,7 @@ impl TryFrom> for Value { } } -impl PartialValue +impl PartialValue where Value: From, { @@ -377,7 +379,7 @@ where } } -impl Lattice for PartialValue { +impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); // println!("join {self:?}\n{:?}", &other); @@ -463,7 +465,7 @@ impl Lattice for PartialValue { } } -impl BoundedLattice for PartialValue { +impl BoundedLattice for PartialValue { fn top() -> Self { Self(PVEnum::Top) } @@ -501,8 +503,7 @@ mod test { use proptest_recurse::{StrategyExt, StrategySet}; - use super::{BaseValue, PVEnum, PartialSum, PartialValue}; - use crate::dataflow::AbstractValue; + use super::{AbstractValue, PVEnum, PartialSum, PartialValue}; #[derive(Debug, PartialEq, Eq, Clone)] enum TestSumType { @@ -514,7 +515,7 @@ mod test { #[derive(Clone, Debug, PartialEq, Eq, Hash)] struct TestValue(usize); - impl BaseValue for TestValue {} + impl AbstractValue for TestValue {} #[derive(Clone)] struct SumTypeParams { diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 2d652c7a4..66c1c80f5 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -18,13 +18,13 @@ use hugr_core::{ use hugr_core::{Hugr, Node, Wire}; use rstest::{fixture, rstest}; -use super::{AbstractValue, BaseValue, DFContext, Machine, PartialValue, TailLoopTermination}; +use super::{AbstractValue, DFContext, Machine, PartialValue, TailLoopTermination}; // ------- Minimal implementation of DFContext and BaseValue ------- #[derive(Debug, Clone, PartialEq, Eq, Hash)] enum Void {} -impl BaseValue for Void {} +impl AbstractValue for Void {} struct TestContext(Arc); @@ -68,7 +68,7 @@ impl PartialOrd for TestContext { } } -impl DFContext> for TestContext { +impl DFContext for TestContext { fn interpret_leaf_op( &self, _node: hugr_core::Node, From dc08f0d221cc67516731bfaa45f5db190a69f1d4 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 14 Oct 2024 11:03:33 +0100 Subject: [PATCH 116/203] (Re-)remove PVEnum --- hugr-passes/src/dataflow.rs | 2 +- hugr-passes/src/dataflow/partial_value.rs | 153 ++++++++++------------ 2 files changed, 72 insertions(+), 83 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 5cf5d91eb..f786d62c7 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -7,7 +7,7 @@ mod machine; pub use machine::{Machine, TailLoopTermination}; mod partial_value; -pub use partial_value::{AbstractValue, PVEnum, PartialSum, PartialValue, Sum}; +pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; use hugr_core::{Hugr, Node}; use std::hash::Hash; diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 880a30241..d0f7c7ef4 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -234,26 +234,13 @@ impl Hash for PartialSum { /// for use in dataflow analysis, including that an instance may be a [PartialSum] /// of values of the underlying representation #[derive(PartialEq, Clone, Eq, Hash, Debug)] -pub struct PartialValue(PVEnum); - -impl PartialValue { - /// Allows to read the enum, which guarantees that we never return [PVEnum::Value] - /// for a value whose [AbstractValue::as_sum] is `Some` - any such value will be - /// in the form of a [PVEnum::Sum] instead. - pub fn as_enum(&self) -> &PVEnum { - &self.0 - } -} - -/// The contents of a [PartialValue], i.e. used as a view. -#[derive(PartialEq, Clone, Eq, Hash, Debug)] -pub enum PVEnum { +pub enum PartialValue { /// No possibilities known (so far) Bottom, /// A single value (of the underlying representation) Value(V), - /// Sum (with perhaps several possible tags) of underlying values - Sum(PartialSum), + /// Sum (with at least one, perhaps several, possible tags) of underlying values + PartialSum(PartialSum), /// Might be more than one distinct value of the underlying type `V` Top, } @@ -262,23 +249,23 @@ impl From for PartialValue { fn from(v: V) -> Self { v.as_sum() .map(|(tag, values)| Self::new_variant(tag, values.map(Self::from))) - .unwrap_or(Self(PVEnum::Value(v))) + .unwrap_or(Self::Value(v)) } } impl From> for PartialValue { fn from(v: PartialSum) -> Self { - Self(PVEnum::Sum(v)) + Self::PartialSum(v) } } impl PartialValue { fn assert_invariants(&self) { - match &self.0 { - PVEnum::Sum(ps) => { + match self { + Self::PartialSum(ps) => { ps.assert_invariants(); } - PVEnum::Value(v) => { + Self::Value(v) => { assert!(v.as_sum().is_none()) } _ => {} @@ -301,14 +288,14 @@ impl PartialValue { /// /// if the value is believed, for that tag, to have a number of values other than `len` pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { - let vals = match &self.0 { - PVEnum::Bottom => return None, - PVEnum::Value(v) => { + let vals = match self { + PartialValue::Bottom => return None, + PartialValue::Value(v) => { assert!(v.as_sum().is_none()); return None; } - PVEnum::Sum(ps) => ps.variant_values(tag)?, - PVEnum::Top => vec![PartialValue(PVEnum::Top); len], + PartialValue::PartialSum(ps) => ps.variant_values(tag)?, + PartialValue::Top => vec![PartialValue::Top; len], }; assert_eq!(vals.len(), len); Some(vals) @@ -316,14 +303,14 @@ impl PartialValue { /// Tells us whether this value might be a Sum with the specified `tag` pub fn supports_tag(&self, tag: usize) -> bool { - match &self.0 { - PVEnum::Bottom => false, - PVEnum::Value(v) => { + match self { + PartialValue::Bottom => false, + PartialValue::Value(v) => { assert!(v.as_sum().is_none()); false } - PVEnum::Sum(ps) => ps.supports_tag(tag), - PVEnum::Top => true, + PartialValue::PartialSum(ps) => ps.supports_tag(tag), + PartialValue::Top => true, } } @@ -333,9 +320,9 @@ impl PartialValue { self, typ: &Type, ) -> Result>>::Error>> { - match self.0 { - PVEnum::Value(v) => Ok(V2::from(v.clone())), - PVEnum::Sum(ps) => { + match self { + Self::Value(v) => Ok(V2::from(v.clone())), + Self::PartialSum(ps) => { let v = ps.try_into_value(typ).map_err(|_| None)?; V2::try_from(v).map_err(Some) } @@ -356,8 +343,9 @@ impl PartialValue where Value: From, { - /// Turns this instance into a [Value], if it is either a single [value](PVEnum::Value) or - /// a [sum](PVEnum::Sum) with a single known tag, extracting the desired type from a HugrView and Wire. + /// Turns this instance into a [Value], if it is either a single [Value](Self::Value) or + /// a [Sum](PartialValue::PartialSum) with a single known tag, extracting the desired type + /// from a HugrView and Wire. /// /// # Errors /// `None` if the analysis did not result in a single value on that wire @@ -383,40 +371,41 @@ impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); // println!("join {self:?}\n{:?}", &other); - match (&self.0, other.0) { - (PVEnum::Top, _) => false, - (_, other @ PVEnum::Top) => { - self.0 = other; + match (&*self, other) { + (Self::Top, _) => false, + (_, other @ Self::Top) => { + *self = other; true } - (_, PVEnum::Bottom) => false, - (PVEnum::Bottom, other) => { - self.0 = other; + (_, Self::Bottom) => false, + (Self::Bottom, other) => { + *self = other; true } - (PVEnum::Value(h1), PVEnum::Value(h2)) => { + (Self::Value(h1), Self::Value(h2)) => { if h1 == &h2 { false } else { - self.0 = PVEnum::Top; + *self = Self::Top; true } } - (PVEnum::Sum(_), PVEnum::Sum(ps2)) => { - let Self(PVEnum::Sum(ps1)) = self else { + (Self::PartialSum(_), Self::PartialSum(ps2)) => { + let Self::PartialSum(ps1) = self else { unreachable!() }; match ps1.try_join_mut(ps2) { Ok(ch) => ch, Err(_) => { - self.0 = PVEnum::Top; + *self = Self::Top; true } } } - (PVEnum::Value(ref v), PVEnum::Sum(_)) | (PVEnum::Sum(_), PVEnum::Value(ref v)) => { + (Self::Value(ref v), Self::PartialSum(_)) + | (Self::PartialSum(_), Self::Value(ref v)) => { assert!(v.as_sum().is_none()); - self.0 = PVEnum::Top; + *self = Self::Top; true } } @@ -424,41 +413,41 @@ impl Lattice for PartialValue { fn meet_mut(&mut self, other: Self) -> bool { self.assert_invariants(); - match (&self.0, other.0) { - (PVEnum::Bottom, _) => false, - (_, other @ PVEnum::Bottom) => { - self.0 = other; + match (&*self, other) { + (Self::Bottom, _) => false, + (_, other @ Self::Bottom) => { + *self = other; true } - (_, PVEnum::Top) => false, - (PVEnum::Top, other) => { - self.0 = other; + (_, Self::Top) => false, + (Self::Top, other) => { + *self = other; true } - (PVEnum::Value(h1), PVEnum::Value(h2)) => { + (Self::Value(h1), Self::Value(h2)) => { if h1 == &h2 { false } else { - self.0 = PVEnum::Bottom; + *self = Self::Bottom; true } } - (PVEnum::Sum(_), PVEnum::Sum(ps2)) => { - let ps1 = match &mut self.0 { - PVEnum::Sum(ps1) => ps1, - _ => unreachable!(), + (Self::PartialSum(_), Self::PartialSum(ps2)) => { + let Self::PartialSum(ps1) = self else { + unreachable!() }; match ps1.try_meet_mut(ps2) { Ok(ch) => ch, Err(_) => { - self.0 = PVEnum::Bottom; + *self = Self::Bottom; true } } } - (PVEnum::Value(ref v), PVEnum::Sum(_)) | (PVEnum::Sum(_), PVEnum::Value(ref v)) => { + (Self::Value(ref v), Self::PartialSum(_)) + | (Self::PartialSum(_), Self::Value(ref v)) => { assert!(v.as_sum().is_none()); - self.0 = PVEnum::Bottom; + *self = Self::Bottom; true } } @@ -467,26 +456,26 @@ impl Lattice for PartialValue { impl BoundedLattice for PartialValue { fn top() -> Self { - Self(PVEnum::Top) + Self::Top } fn bottom() -> Self { - Self(PVEnum::Bottom) + Self::Bottom } } impl PartialOrd for PartialValue { fn partial_cmp(&self, other: &Self) -> Option { use std::cmp::Ordering; - match (&self.0, &other.0) { - (PVEnum::Bottom, PVEnum::Bottom) => Some(Ordering::Equal), - (PVEnum::Top, PVEnum::Top) => Some(Ordering::Equal), - (PVEnum::Bottom, _) => Some(Ordering::Less), - (_, PVEnum::Bottom) => Some(Ordering::Greater), - (PVEnum::Top, _) => Some(Ordering::Greater), - (_, PVEnum::Top) => Some(Ordering::Less), - (PVEnum::Value(v1), PVEnum::Value(v2)) => (v1 == v2).then_some(Ordering::Equal), - (PVEnum::Sum(ps1), PVEnum::Sum(ps2)) => ps1.partial_cmp(ps2), + match (self, other) { + (Self::Bottom, Self::Bottom) => Some(Ordering::Equal), + (Self::Top, Self::Top) => Some(Ordering::Equal), + (Self::Bottom, _) => Some(Ordering::Less), + (_, Self::Bottom) => Some(Ordering::Greater), + (Self::Top, _) => Some(Ordering::Greater), + (_, Self::Top) => Some(Ordering::Less), + (Self::Value(v1), Self::Value(v2)) => (v1 == v2).then_some(Ordering::Equal), + (Self::PartialSum(ps1), Self::PartialSum(ps2)) => ps1.partial_cmp(ps2), _ => None, } } @@ -503,7 +492,7 @@ mod test { use proptest_recurse::{StrategyExt, StrategySet}; - use super::{AbstractValue, PVEnum, PartialSum, PartialValue}; + use super::{AbstractValue, PartialSum, PartialValue}; #[derive(Debug, PartialEq, Eq, Clone)] enum TestSumType { @@ -536,11 +525,11 @@ mod test { impl TestSumType { fn check_value(&self, pv: &PartialValue) -> bool { - match (self, pv.as_enum()) { - (_, PVEnum::Bottom) | (_, PVEnum::Top) => true, + match (self, pv) { + (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, (Self::Leaf(None), _) => pv == &PartialValue::new_unit(), - (Self::Leaf(Some(max)), PVEnum::Value(TestValue(val))) => val <= max, - (Self::Branch(sop), PVEnum::Sum(ps)) => { + (Self::Leaf(Some(max)), PartialValue::Value(TestValue(val))) => val <= max, + (Self::Branch(sop), PartialValue::PartialSum(ps)) => { for (k, v) in &ps.0 { if *k >= sop.len() { return false; From 436dcd277cf367c0c918e39313414f0f44a9b2d7 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 14 Oct 2024 11:19:49 +0100 Subject: [PATCH 117/203] Remove as_sum. AbstractValues are elements not sums --- hugr-passes/src/dataflow/partial_value.rs | 83 ++++++++++------------- 1 file changed, 37 insertions(+), 46 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index d0f7c7ef4..0c2f80ca4 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -11,22 +11,24 @@ use std::hash::{Hash, Hasher}; /// Trait for an underlying domain of abstract values which can form the *elements* of a /// [PartialValue] and thus be used in dataflow analysis. pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { - /// If the abstract value represents a [Sum] with a single known tag, deconstruct it - /// into that tag plus the elements. The default just returns `None` which is - /// appropriate if the abstract value never does (in which case [interpret_leaf_op] - /// must produce a [PartialValue::new_variant] for any operation producing - /// a sum). + /// Computes the join of two values (i.e. towards `Top``), if this is representable + /// within the underlying domain. + /// Otherwise return `None` (i.e. an instruction to use [PartialValue::Top]). /// - /// The signature is this way to optimize query/inspection (is-it-a-sum), - /// at the cost of requiring more cloning during actual conversion - /// (inside the lazy Iterator, or for the error case, as Self remains) + /// The default checks equality between `self` and `other` and returns `self` if + /// the two are identical, otherwise `None`. + fn try_join(self, other: Self) -> Option { + (self == other).then_some(self) + } + + /// Computes the meet of two values (i.e. towards `Bottom`), if this is representable + /// within the underlying domain. + /// Otherwise return `None` (i.e. an instruction to use [PartialValue::Bottom]). /// - /// [interpret_leaf_op]: super::DFContext::interpret_leaf_op - /// [Sum]: TypeEnum::Sum - /// [Tag]: hugr_core::ops::Tag - fn as_sum(&self) -> Option<(usize, impl Iterator + '_)> { - let res: Option<(usize, as IntoIterator>::IntoIter)> = None; - res + /// The default checks equality between `self` and `other` and returns `self` if + /// the two are identical, otherwise `None`. + fn try_meet(self, other: Self) -> Option { + (self == other).then_some(self) } } @@ -247,9 +249,7 @@ pub enum PartialValue { impl From for PartialValue { fn from(v: V) -> Self { - v.as_sum() - .map(|(tag, values)| Self::new_variant(tag, values.map(Self::from))) - .unwrap_or(Self::Value(v)) + Self::Value(v) } } @@ -265,9 +265,6 @@ impl PartialValue { Self::PartialSum(ps) => { ps.assert_invariants(); } - Self::Value(v) => { - assert!(v.as_sum().is_none()) - } _ => {} } } @@ -289,11 +286,7 @@ impl PartialValue { /// if the value is believed, for that tag, to have a number of values other than `len` pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { let vals = match self { - PartialValue::Bottom => return None, - PartialValue::Value(v) => { - assert!(v.as_sum().is_none()); - return None; - } + PartialValue::Bottom | PartialValue::Value(_) => return None, PartialValue::PartialSum(ps) => ps.variant_values(tag)?, PartialValue::Top => vec![PartialValue::Top; len], }; @@ -304,11 +297,7 @@ impl PartialValue { /// Tells us whether this value might be a Sum with the specified `tag` pub fn supports_tag(&self, tag: usize) -> bool { match self { - PartialValue::Bottom => false, - PartialValue::Value(v) => { - assert!(v.as_sum().is_none()); - false - } + PartialValue::Bottom | PartialValue::Value(_) => false, PartialValue::PartialSum(ps) => ps.supports_tag(tag), PartialValue::Top => true, } @@ -382,14 +371,17 @@ impl Lattice for PartialValue { *self = other; true } - (Self::Value(h1), Self::Value(h2)) => { - if h1 == &h2 { - false - } else { + (Self::Value(h1), Self::Value(h2)) => match h1.clone().try_join(h2) { + Some(h3) => { + let ch = h3 != *h1; + *self = Self::Value(h3); + ch + } + None => { *self = Self::Top; true } - } + }, (Self::PartialSum(_), Self::PartialSum(ps2)) => { let Self::PartialSum(ps1) = self else { unreachable!() @@ -402,9 +394,7 @@ impl Lattice for PartialValue { } } } - (Self::Value(ref v), Self::PartialSum(_)) - | (Self::PartialSum(_), Self::Value(ref v)) => { - assert!(v.as_sum().is_none()); + (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { *self = Self::Top; true } @@ -424,14 +414,17 @@ impl Lattice for PartialValue { *self = other; true } - (Self::Value(h1), Self::Value(h2)) => { - if h1 == &h2 { - false - } else { + (Self::Value(h1), Self::Value(h2)) => match h1.clone().try_meet(h2) { + Some(h3) => { + let ch = h3 != *h1; + *self = Self::Value(h3); + ch + } + None => { *self = Self::Bottom; true } - } + }, (Self::PartialSum(_), Self::PartialSum(ps2)) => { let Self::PartialSum(ps1) = self else { unreachable!() @@ -444,9 +437,7 @@ impl Lattice for PartialValue { } } } - (Self::Value(ref v), Self::PartialSum(_)) - | (Self::PartialSum(_), Self::Value(ref v)) => { - assert!(v.as_sum().is_none()); + (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { *self = Self::Bottom; true } From e817bbea5c897e3878b38d88ada8fd1e5c125ffb Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 14 Oct 2024 14:44:04 +0100 Subject: [PATCH 118/203] clippy --- hugr-passes/src/dataflow/partial_value.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 0c2f80ca4..bdf774f2c 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -261,11 +261,8 @@ impl From> for PartialValue { impl PartialValue { fn assert_invariants(&self) { - match self { - Self::PartialSum(ps) => { - ps.assert_invariants(); - } - _ => {} + if let Self::PartialSum(ps) = self { + ps.assert_invariants(); } } From b06cfad2da7ec68efca1d03c66daacc4c1af7808 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 14 Oct 2024 14:00:10 +0100 Subject: [PATCH 119/203] Refactor: remove 'fn input_count' --- hugr-passes/src/dataflow/datalog.rs | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 4ab12e380..b1c649252 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -15,7 +15,6 @@ use std::ops::{Index, IndexMut}; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use hugr_core::ops::OpType; -use hugr_core::types::Signature; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _}; use super::{AbstractValue, DFContext, PartialValue}; @@ -65,8 +64,8 @@ ascent::ascent! { out_wire_value(c, m, op, v); - node_in_value_row(c, n, ValueRow::new(input_count(c.as_ref(), *n))) <-- node(c, n); - node_in_value_row(c, n, ValueRow::single_known(c.signature(*n).unwrap().input.len(), p.index(), v.clone())) <-- in_wire_value(c, n, p, v); + node_in_value_row(c, n, ValueRow::new(sig.input_count())) <-- node(c, n), if let Some(sig) = c.signature(*n); + node_in_value_row(c, n, ValueRow::single_known(c.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(c, n, p, v); out_wire_value(c, n, p, v) <-- node(c, n), @@ -203,13 +202,6 @@ fn propagate_leaf_op( } } -fn input_count(h: &impl HugrView, n: Node) -> usize { - h.signature(n) - .as_ref() - .map(Signature::input_count) - .unwrap_or(0) -} - fn value_inputs(h: &impl HugrView, n: Node) -> impl Iterator + '_ { h.in_value_types(n).map(|x| x.0) } From fd717be5a2f0f7d9ffcf7976316fd8476057cd5f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 14 Oct 2024 14:24:10 +0100 Subject: [PATCH 120/203] Try to fix interpret_leaf_op: cannot use Bottom for output! But ascent borrowing --- hugr-passes/src/dataflow.rs | 14 +++++++++----- hugr-passes/src/dataflow/datalog.rs | 28 +++++++++++++++++++++------- hugr-passes/src/dataflow/test.rs | 10 +--------- 3 files changed, 31 insertions(+), 21 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index f786d62c7..51874136c 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -15,13 +15,17 @@ use std::hash::Hash; /// Clients of the dataflow framework (particular analyses, such as constant folding) /// must implement this trait (including providing an appropriate domain type `V`). pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { - /// Given lattice values for each input, produce lattice values for (what we know of) - /// the outputs. Returning `None` indicates nothing can be deduced. + /// Given lattice values for each input, update lattice values for the (dataflow) outputs. + /// `_outs` is an array with one element per dataflow output, each initialized to [PartialValue::Top] + /// which is the correct value to leave if nothing can be deduced about that output. + /// (The default does nothing, i.e. leaves `Top` for all outputs.) fn interpret_leaf_op( &self, - node: Node, - ins: &[PartialValue], - ) -> Option>>; + _node: Node, + _ins: &[PartialValue], + _outs: &mut [PartialValue], + ) { + } } #[cfg(test)] diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index b1c649252..70a404018 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -14,7 +14,7 @@ use std::hash::Hash; use std::ops::{Index, IndexMut}; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; -use hugr_core::ops::OpType; +use hugr_core::ops::{OpTrait, OpType}; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _}; use super::{AbstractValue, DFContext, PartialValue}; @@ -69,9 +69,11 @@ ascent::ascent! { out_wire_value(c, n, p, v) <-- node(c, n), - if !c.get_optype(*n).is_container(), + let op_t = c.get_optype(*n), + if !op_t.is_container(), + if let Some(sig) = op_t.dataflow_signature(), node_in_value_row(c, n, vs), - if let Some(outs) = propagate_leaf_op(c, *n, &vs[..]), + if let Some(outs) = propagate_leaf_op(c, *n, &vs[..], sig.output_count(), &self.out_wire_value_proto[..]), for (p,v) in (0..).map(OutgoingPort::from).zip(outs); // DFG @@ -176,6 +178,8 @@ fn propagate_leaf_op( c: &impl DFContext, n: Node, ins: &[PV], + num_outs: usize, + out_wire_proto: &[(Node, OutgoingPort, PV)], ) -> Option> { match c.get_optype(n) { // Handle basics here. I guess (given the current interface) we could allow @@ -195,10 +199,20 @@ fn propagate_leaf_op( ins.iter().cloned(), )])), OpType::Input(_) | OpType::Output(_) => None, // handled by parent - // It'd be nice to convert these to [(IncomingPort, Value)] to pass to the context, - // thus keeping PartialValue hidden, but AbstractValues - // are not necessarily convertible to Value! - _ => c.interpret_leaf_op(n, ins).map(ValueRow::from_iter), + _ => { + // Interpret op. Default/worst-case is that we can't deduce anything about any + // output (just `Top`). + let mut outs = vec![PartialValue::Top; num_outs]; + // However, we may have been told better outcomes: + for (_, p, v) in out_wire_proto.iter().filter(|(n2, _, _)| n == n2) { + outs[p.index()] = v.clone() + } + // It'd be nice to convert these to [(IncomingPort, Value)] to pass to the context, + // thus keeping PartialValue hidden, but AbstractValues + // are not necessarily convertible to Value! + c.interpret_leaf_op(n, ins, &mut outs[..]); + Some(ValueRow::from_iter(outs)) + } } } diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 66c1c80f5..455554891 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -68,15 +68,7 @@ impl PartialOrd for TestContext { } } -impl DFContext for TestContext { - fn interpret_leaf_op( - &self, - _node: hugr_core::Node, - _ins: &[PartialValue], - ) -> Option>> { - None - } -} +impl DFContext for TestContext {} // This allows testing creation of tuple/sum Values (only) impl From for Value { From 9b174393ef81ea17929d0465b8a6b9b45f8974e3 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 14 Oct 2024 14:42:33 +0100 Subject: [PATCH 121/203] interpret_leaf_op for ExtensionOps only; LoadConstant via value_from_(custom_)const(_hugr) --- hugr-passes/src/dataflow.rs | 59 +++++++++++++++++++++++ hugr-passes/src/dataflow/datalog.rs | 33 ++++++++----- hugr-passes/src/dataflow/partial_value.rs | 10 ++-- hugr-passes/src/dataflow/test.rs | 16 ++---- 4 files changed, 91 insertions(+), 27 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 51874136c..370b4d643 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -4,11 +4,13 @@ mod datalog; mod machine; +use hugr_core::ops::constant::CustomConst; pub use machine::{Machine, TailLoopTermination}; mod partial_value; pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; +use hugr_core::ops::{ExtensionOp, Value}; use hugr_core::{Hugr, Node}; use std::hash::Hash; @@ -16,16 +18,73 @@ use std::hash::Hash; /// must implement this trait (including providing an appropriate domain type `V`). pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { /// Given lattice values for each input, update lattice values for the (dataflow) outputs. + /// For extension ops only, excluding [MakeTuple] and [UnpackTuple]. /// `_outs` is an array with one element per dataflow output, each initialized to [PartialValue::Top] /// which is the correct value to leave if nothing can be deduced about that output. /// (The default does nothing, i.e. leaves `Top` for all outputs.) + /// + /// [MakeTuple]: hugr_core::extension::prelude::MakeTuple + /// [UnpackTuple]: hugr_core::extension::prelude::UnpackTuple fn interpret_leaf_op( &self, _node: Node, + _e: &ExtensionOp, _ins: &[PartialValue], _outs: &mut [PartialValue], ) { } + + /// Produces an abstract value from a constant. The default impl + /// traverses the constant [Value] to its leaves ([Value::Extension] and [Value::Function]), + /// converts these using [Self::value_from_custom_const] and [Self::value_from_const_hugr], + /// and builds nested [PartialValue::new_variant] to represent the structure. + fn value_from_const(&self, n: Node, cst: &Value) -> PartialValue { + traverse_value(self, n, &mut Vec::new(), cst) + } + + /// Produces an abstract value from a [CustomConst], if possible. + /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. + fn value_from_custom_const( + &self, + _node: Node, + _fields: &[usize], + _cc: &dyn CustomConst, + ) -> Option { + None + } + + /// Produces an abstract value from a Hugr in a [Value::Function], if possible. + /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. + fn value_from_const_hugr(&self, _node: Node, _fields: &[usize], _h: &Hugr) -> Option { + None + } +} + +fn traverse_value( + s: &impl DFContext, + n: Node, + fields: &mut Vec, + cst: &Value, +) -> PartialValue { + match cst { + Value::Sum(hugr_core::ops::constant::Sum { tag, values, .. }) => { + let elems = values.iter().enumerate().map(|(idx, elem)| { + fields.push(idx); + let r = traverse_value(s, n, fields, elem); + fields.pop(); + r + }); + PartialValue::new_variant(*tag, elems) + } + Value::Extension { e } => s + .value_from_custom_const(n, fields, e.value()) + .map(PartialValue::from) + .unwrap_or(PartialValue::Top), + Value::Function { hugr } => s + .value_from_const_hugr(n, fields, &**hugr) + .map(PartialValue::from) + .unwrap_or(PartialValue::Top), + } } #[cfg(test)] diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 70a404018..d88e5cc98 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -73,7 +73,7 @@ ascent::ascent! { if !op_t.is_container(), if let Some(sig) = op_t.dataflow_signature(), node_in_value_row(c, n, vs), - if let Some(outs) = propagate_leaf_op(c, *n, &vs[..], sig.output_count(), &self.out_wire_value_proto[..]), + if let Some(outs) = propagate_leaf_op(c, *n, &vs[..], sig.output_count()), for (p,v) in (0..).map(OutgoingPort::from).zip(outs); // DFG @@ -179,7 +179,6 @@ fn propagate_leaf_op( n: Node, ins: &[PV], num_outs: usize, - out_wire_proto: &[(Node, OutgoingPort, PV)], ) -> Option> { match c.get_optype(n) { // Handle basics here. I guess (given the current interface) we could allow @@ -198,21 +197,31 @@ fn propagate_leaf_op( t.tag, ins.iter().cloned(), )])), - OpType::Input(_) | OpType::Output(_) => None, // handled by parent - _ => { - // Interpret op. Default/worst-case is that we can't deduce anything about any - // output (just `Top`). + OpType::Input(_) | OpType::Output(_) | OpType::ExitBlock(_) => None, // handled by parent + OpType::Const(_) => None, // handled by LoadConstant: + OpType::LoadConstant(load_op) => { + assert!(ins.is_empty()); // static edge, so need to find constant + let const_node = c + .single_linked_output(n, load_op.constant_port()) + .unwrap() + .0; + let const_val = c.get_optype(const_node).as_const().unwrap().value(); + Some(ValueRow::single_known( + 1, + 0, + c.value_from_const(n, const_val), + )) + } + OpType::ExtensionOp(e) => { + // Interpret op. Default is we know nothing about the outputs (they still happen!) let mut outs = vec![PartialValue::Top; num_outs]; - // However, we may have been told better outcomes: - for (_, p, v) in out_wire_proto.iter().filter(|(n2, _, _)| n == n2) { - outs[p.index()] = v.clone() - } // It'd be nice to convert these to [(IncomingPort, Value)] to pass to the context, // thus keeping PartialValue hidden, but AbstractValues - // are not necessarily convertible to Value! - c.interpret_leaf_op(n, ins, &mut outs[..]); + // are not necessarily convertible to Value. + c.interpret_leaf_op(n, e, ins, &mut outs[..]); Some(ValueRow::from_iter(outs)) } + o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive" } } diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index bdf774f2c..d3010ae45 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -63,16 +63,16 @@ impl PartialSum { pub fn num_variants(&self) -> usize { self.0.len() } -} -impl PartialSum { fn assert_invariants(&self) { assert_ne!(self.num_variants(), 0); for pv in self.0.values().flat_map(|x| x.iter()) { pv.assert_invariants(); } } +} +impl PartialSum { /// Joins (towards `Top`) self with another [PartialSum]. If successful, returns /// whether `self` has changed. /// @@ -247,7 +247,7 @@ pub enum PartialValue { Top, } -impl From for PartialValue { +impl From for PartialValue { fn from(v: V) -> Self { Self::Value(v) } @@ -259,7 +259,7 @@ impl From> for PartialValue { } } -impl PartialValue { +impl PartialValue { fn assert_invariants(&self) { if let Self::PartialSum(ps) = self { ps.assert_invariants(); @@ -275,7 +275,9 @@ impl PartialValue { pub fn new_unit() -> Self { Self::new_variant(0, []) } +} +impl PartialValue { /// If this value might be a Sum with the specified `tag`, get the elements inside that tag. /// /// # Panics diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 455554891..a72ade121 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -98,7 +98,6 @@ fn test_make_tuple() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.propolutate_out_wires([(v1, pv_false()), (v2, pv_true())]); machine.run(TestContext(Arc::new(&hugr))); let x = machine @@ -120,7 +119,6 @@ fn test_unpack_tuple_const() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.propolutate_out_wires([(v, PartialValue::new_variant(0, [pv_false(), pv_true()]))]); machine.run(TestContext(Arc::new(&hugr))); let o1_r = machine @@ -153,7 +151,6 @@ fn test_tail_loop_never_iterates() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.propolutate_out_wires([(r_w, PartialValue::new_variant(3, []))]); machine.run(TestContext(Arc::new(&hugr))); let o_r = machine @@ -227,7 +224,6 @@ fn test_tail_loop_two_iters() { let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::default(); - machine.propolutate_out_wires([(true_w, pv_true()), (false_w, pv_false())]); machine.run(TestContext(Arc::new(&hugr))); let o_r1 = machine.read_out_wire(o_w1).unwrap(); @@ -295,7 +291,6 @@ fn test_tail_loop_containing_conditional() { let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::default(); - machine.propolutate_out_wires([(init, PartialValue::new_variant(0, [pv_false(), pv_true()]))]); machine.run(TestContext(Arc::new(&hugr))); let o_r1 = machine.read_out_wire(o_w1).unwrap(); @@ -373,11 +368,10 @@ fn conditional() { // Tuple of // 1. Hugr being a function on bools: (b,c) => !b XOR c // 2. Input node of entry block -// 3. Wire out from "True" constant // Result readable from root node outputs // Inputs should be placed onto out-wires of the Node (2.) #[fixture] -fn xnor_cfg() -> (Hugr, Node, Wire) { +fn xnor_cfg() -> (Hugr, Node) { // Entry // /0 1\ // A --1-> B @@ -462,7 +456,7 @@ fn xnor_cfg() -> (Hugr, Node, Wire) { builder.branch(&b, 0, &x).unwrap(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let [entry_input, _] = hugr.get_io(entry.node()).unwrap(); - (hugr, entry_input, true_w) + (hugr, entry_input) } #[rstest] @@ -480,14 +474,14 @@ fn test_cfg( #[case] inp0: PartialValue, #[case] inp1: PartialValue, #[case] outp: PartialValue, - xnor_cfg: (Hugr, Node, Wire), + xnor_cfg: (Hugr, Node), ) { - let (hugr, entry_input, true_w) = xnor_cfg; + let (hugr, entry_input) = xnor_cfg; let [in_w0, in_w1] = [0, 1].map(|i| Wire::new(entry_input, i)); let mut machine = Machine::default(); - machine.propolutate_out_wires([(in_w0, inp0), (in_w1, inp1), (true_w, pv_true())]); + machine.propolutate_out_wires([(in_w0, inp0), (in_w1, inp1)]); machine.run(TestContext(Arc::new(&hugr))); assert_eq!( From 846d1ee57fd057a20e4f522dd7e265956cd89572 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 15 Oct 2024 09:39:51 +0100 Subject: [PATCH 122/203] Correct comment BaseValue -> AbstractValue --- hugr-passes/src/dataflow/test.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index a72ade121..56da25e5b 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -20,7 +20,7 @@ use rstest::{fixture, rstest}; use super::{AbstractValue, DFContext, Machine, PartialValue, TailLoopTermination}; -// ------- Minimal implementation of DFContext and BaseValue ------- +// ------- Minimal implementation of DFContext and AbstractValue ------- #[derive(Debug, Clone, PartialEq, Eq, Hash)] enum Void {} From 328e7f805e31f840fe90c89b13ba734c5c4a933d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 15 Oct 2024 09:25:20 +0100 Subject: [PATCH 123/203] test Hugr now returns (XOR, AND) of two inputs, one case wrongly producing T|F --- hugr-passes/src/dataflow/test.rs | 116 +++++++++++++++++-------------- 1 file changed, 62 insertions(+), 54 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 56da25e5b..9ac49e6d1 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use ascent::{lattice::BoundedLattice, Lattice}; -use hugr_core::builder::CFGBuilder; +use hugr_core::builder::{CFGBuilder, Container}; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ @@ -366,86 +366,88 @@ fn conditional() { } // Tuple of -// 1. Hugr being a function on bools: (b,c) => !b XOR c +// 1. Hugr being a function on bools: (x, y) => (x XOR y, x AND y) // 2. Input node of entry block // Result readable from root node outputs // Inputs should be placed onto out-wires of the Node (2.) #[fixture] -fn xnor_cfg() -> (Hugr, Node) { +fn xor_and_cfg() -> (Hugr, Node) { // Entry // /0 1\ - // A --1-> B - // \0 / + // A --1-> B A(x=true, y) => if y then X(false, true) else B(x=true) + // \0 / B(z) => X(z,false) // > X < - let mut builder = CFGBuilder::new(Signature::new(type_row![BOOL_T;2], BOOL_T)).unwrap(); - - // entry (i, j) => if i {B(j)} else {A(j, i, true)}, note that (j, i, true) == (j, false, true) - let entry_outs = [type_row![BOOL_T;3], type_row![BOOL_T]]; + let mut builder = + CFGBuilder::new(Signature::new(type_row![BOOL_T; 2], type_row![BOOL_T; 2])).unwrap(); + let false_c = builder.add_constant(Value::false_val()); + // entry (x, y) => if x {A(y, x=true)} else B(y)} + let entry_outs = [type_row![BOOL_T;2], type_row![BOOL_T]]; let mut entry = builder .entry_builder(entry_outs.clone(), type_row![]) .unwrap(); - let [in_i, in_j] = entry.input_wires_arr(); + let [in_x, in_y] = entry.input_wires_arr(); let mut cond = entry .conditional_builder( - (vec![type_row![]; 2], in_i), + (vec![type_row![]; 2], in_x), [], Type::new_sum(entry_outs.clone()).into(), ) .unwrap(); - let mut if_i_true = cond.case_builder(1).unwrap(); - let br_to_b = if_i_true - .add_dataflow_op(Tag::new(1, entry_outs.to_vec()), [in_j]) + let mut if_x_true = cond.case_builder(1).unwrap(); + let br_to_a = if_x_true + .add_dataflow_op(Tag::new(0, entry_outs.to_vec()), [in_y, in_x]) .unwrap(); - if_i_true.finish_with_outputs(br_to_b.outputs()).unwrap(); - let mut if_i_false = cond.case_builder(0).unwrap(); - let true_w = if_i_false.add_load_value(Value::true_val()); - let br_to_a = if_i_false - .add_dataflow_op(Tag::new(0, entry_outs.into()), [in_j, in_i, true_w]) + if_x_true.finish_with_outputs(br_to_a.outputs()).unwrap(); + let mut if_x_false = cond.case_builder(0).unwrap(); + let br_to_b = if_x_false + .add_dataflow_op(Tag::new(1, entry_outs.into()), [in_y]) .unwrap(); - if_i_false.finish_with_outputs(br_to_a.outputs()).unwrap(); + if_x_false.finish_with_outputs(br_to_b.outputs()).unwrap(); let [res] = cond.finish_sub_container().unwrap().outputs_arr(); let entry = entry.finish_with_outputs(res, []).unwrap(); - // A(w, y, z) => if w {B(y)} else {X(z)} - let a_outs = vec![type_row![BOOL_T]; 2]; + // A(y, z always true) => if y {X(false, z)} else {B(z)} + let a_outs = vec![type_row![BOOL_T], type_row![]]; let mut a = builder .block_builder( - type_row![BOOL_T; 3], - vec![type_row![BOOL_T]; 2], - type_row![], + type_row![BOOL_T; 2], + a_outs.clone(), + type_row![BOOL_T], // Trailing z common to both branches ) .unwrap(); - let [in_w, in_y, in_z] = a.input_wires_arr(); + let [in_y, in_z] = a.input_wires_arr(); + let mut cond = a .conditional_builder( - (vec![type_row![]; 2], in_w), + (vec![type_row![]; 2], in_y), [], Type::new_sum(a_outs.clone()).into(), ) .unwrap(); - let mut if_w_true = cond.case_builder(1).unwrap(); - let br_to_b = if_w_true - .add_dataflow_op(Tag::new(1, a_outs.clone()), [in_y]) + let mut if_y_true = cond.case_builder(1).unwrap(); + let false_w1 = if_y_true.load_const(&false_c); + let br_to_x = if_y_true + .add_dataflow_op(Tag::new(0, a_outs.clone()), [false_w1]) .unwrap(); - if_w_true.finish_with_outputs(br_to_b.outputs()).unwrap(); - let mut if_w_false = cond.case_builder(0).unwrap(); - let br_to_x = if_w_false - .add_dataflow_op(Tag::new(0, a_outs), [in_z]) - .unwrap(); - if_w_false.finish_with_outputs(br_to_x.outputs()).unwrap(); + if_y_true.finish_with_outputs(br_to_x.outputs()).unwrap(); + let mut if_y_false = cond.case_builder(0).unwrap(); + let br_to_b = if_y_false.add_dataflow_op(Tag::new(1, a_outs), []).unwrap(); + if_y_false.finish_with_outputs(br_to_b.outputs()).unwrap(); let [res] = cond.finish_sub_container().unwrap().outputs_arr(); - let a = a.finish_with_outputs(res, []).unwrap(); + let a = a.finish_with_outputs(res, [in_z]).unwrap(); - // B(v) => X(v) + // B(v) => X(v, false) let mut b = builder - .block_builder(type_row![BOOL_T], [type_row![BOOL_T]], type_row![]) + .block_builder(type_row![BOOL_T], [type_row![]], type_row![BOOL_T; 2]) .unwrap(); + let [in_v] = b.input_wires_arr(); + let false_w2 = b.load_const(&false_c); let [control] = b - .add_dataflow_op(Tag::new(0, vec![type_row![BOOL_T]]), b.input_wires()) + .add_dataflow_op(Tag::new(0, vec![type_row![]]), []) .unwrap() .outputs_arr(); - let b = b.finish_with_outputs(control, []).unwrap(); + let b = b.finish_with_outputs(control, [in_v, false_w2]).unwrap(); let x = builder.exit_block(); @@ -460,23 +462,25 @@ fn xnor_cfg() -> (Hugr, Node) { } #[rstest] -#[case(pv_true(), pv_true(), pv_true())] -#[case(pv_true(), pv_false(), pv_false())] -#[case(pv_true(), pv_true_or_false(), pv_true_or_false())] -#[case(pv_false(), pv_true(), pv_false())] -#[case(pv_false(), pv_false(), pv_true())] -#[case(pv_false(), pv_true_or_false(), pv_true_or_false())] -#[case(PartialValue::top(), pv_true(), PartialValue::top())] // Ideally, result should be true_or_false -#[case(PartialValue::top(), pv_false(), pv_true_or_false())] +#[should_panic] // first case failing +#[case(pv_true(), pv_true(), pv_false(), pv_true())] +#[case(pv_true(), pv_false(), pv_true(), pv_false())] +//#[case(pv_true(), pv_true_or_false(), pv_true_or_false())] +#[case(pv_false(), pv_true(), pv_true(), pv_false())] +#[case(pv_false(), pv_false(), pv_false(), pv_false())] +/*#[case(pv_false(), pv_true_or_false(), pv_true_or_false())] +#[case(PartialValue::top(), pv_true(), pv_true_or_false())] +#[case(PartialValue::top(), pv_false(), PartialValue::top())] // Ideally pv_true_or_false #[case(pv_true_or_false(), pv_true(), pv_true_or_false())] -#[case(pv_true_or_false(), pv_false(), pv_true_or_false())] +#[case(pv_true_or_false(), pv_false(), pv_true_or_false())]*/ fn test_cfg( #[case] inp0: PartialValue, #[case] inp1: PartialValue, - #[case] outp: PartialValue, - xnor_cfg: (Hugr, Node), + #[case] out0: PartialValue, + #[case] out1: PartialValue, + xor_and_cfg: (Hugr, Node), ) { - let (hugr, entry_input) = xnor_cfg; + let (hugr, entry_input) = xor_and_cfg; let [in_w0, in_w1] = [0, 1].map(|i| Wire::new(entry_input, i)); @@ -486,6 +490,10 @@ fn test_cfg( assert_eq!( machine.read_out_wire(Wire::new(hugr.root(), 0)).unwrap(), - outp + out0 + ); + assert_eq!( + machine.read_out_wire(Wire::new(hugr.root(), 1)).unwrap(), + out1 ); } From da3c05c0594c56893153697842ceae4d33b1b4ef Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 10 Oct 2024 16:00:18 +0100 Subject: [PATCH 124/203] BB reachability, fixes! --- hugr-passes/src/dataflow/datalog.rs | 13 ++++++++++++- hugr-passes/src/dataflow/machine.rs | 18 ++++++++++++++++++ hugr-passes/src/dataflow/test.rs | 1 - 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index d88e5cc98..f3c4868a8 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -150,6 +150,16 @@ ascent::ascent! { cfg_node(c,n) <-- node(c, n), if c.get_optype(*n).is_cfg(); dfb_block(c,cfg,blk) <-- cfg_node(c, cfg), for blk in c.children(*cfg), if c.get_optype(blk).is_dataflow_block(); + // Reachability + relation bb_reachable(C, Node, Node); + bb_reachable(c, cfg, entry) <-- cfg_node(c, cfg), if let Some(entry) = c.children(*cfg).next(); + bb_reachable(c, cfg, bb) <-- cfg_node(c, cfg), + bb_reachable(c, cfg, pred), + io_node(c, pred, pred_out, IO::Output), + in_wire_value(c, pred_out, IncomingPort::from(0), predicate), + for (tag, bb) in c.output_neighbours(*pred).enumerate(), + if predicate.supports_tag(tag); + // Where do the values "fed" along a control-flow edge come out? relation _cfg_succ_dest(C, Node, Node, Node); _cfg_succ_dest(c, cfg, blk, inp) <-- dfb_block(c, cfg, blk), io_node(c, blk, inp, IO::Input); @@ -162,9 +172,10 @@ ascent::ascent! { io_node(c, entry, i_node, IO::Input), in_wire_value(c, cfg, p, v); - // Outputs of each block propagated to successor blocks or (if exit block) then CFG itself + // Outputs of each reachable block propagated to successor block or (if exit block) then CFG itself out_wire_value(c, dest, OutgoingPort::from(out_p), v) <-- dfb_block(c, cfg, pred), + bb_reachable(c, cfg, pred), let df_block = c.get_optype(*pred).as_dataflow_block().unwrap(), for (succ_n, succ) in c.output_neighbours(*pred).enumerate(), io_node(c, pred, out_n, IO::Output), diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 15262d4db..44744b5ee 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -99,6 +99,24 @@ impl> Machine { .unwrap(), ) } + + /// Tells us if a block ([DataflowBlock] or [ExitBlock]) in a [CFG] is known + /// to be reachable. (Returns `None` if argument is not a child of a CFG.) + pub fn bb_reachable(&self, hugr: impl HugrView, bb: Node) -> Option { + let cfg = hugr.get_parent(bb)?; // Not really required...?? + hugr.get_optype(cfg).as_cfg()?; + let t = hugr.get_optype(bb); + if !t.is_dataflow_block() && !t.is_exit_block() { + return None; + }; + Some( + self.0 + .bb_reachable + .iter() + .find(|(_, cfg2, bb2)| *cfg2 == cfg && *bb2 == bb) + .is_some(), + ) + } } /// Tells whether a loop iterates (never, always, sometimes) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 9ac49e6d1..7777cd3f7 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -462,7 +462,6 @@ fn xor_and_cfg() -> (Hugr, Node) { } #[rstest] -#[should_panic] // first case failing #[case(pv_true(), pv_true(), pv_false(), pv_true())] #[case(pv_true(), pv_false(), pv_true(), pv_false())] //#[case(pv_true(), pv_true_or_false(), pv_true_or_false())] From 6930ad4358e3293320a62004a865694b39395ada Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 15 Oct 2024 09:37:32 +0100 Subject: [PATCH 125/203] Test cases with true_or_false/top, standardize naming (->test_)conditional --- hugr-passes/src/dataflow/test.rs | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 7777cd3f7..f3739062e 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -305,7 +305,7 @@ fn test_tail_loop_containing_conditional() { } #[test] -fn conditional() { +fn test_conditional() { let variants = vec![type_row![], type_row![], type_row![BOOL_T]]; let cond_t = Type::new_sum(variants.clone()); let mut builder = DFGBuilder::new(Signature::new(cond_t, type_row![])).unwrap(); @@ -464,14 +464,16 @@ fn xor_and_cfg() -> (Hugr, Node) { #[rstest] #[case(pv_true(), pv_true(), pv_false(), pv_true())] #[case(pv_true(), pv_false(), pv_true(), pv_false())] -//#[case(pv_true(), pv_true_or_false(), pv_true_or_false())] +#[case(pv_true(), pv_true_or_false(), pv_true_or_false(), pv_true_or_false())] +#[case(pv_true(), PartialValue::Top, pv_true_or_false(), pv_true_or_false())] #[case(pv_false(), pv_true(), pv_true(), pv_false())] #[case(pv_false(), pv_false(), pv_false(), pv_false())] -/*#[case(pv_false(), pv_true_or_false(), pv_true_or_false())] -#[case(PartialValue::top(), pv_true(), pv_true_or_false())] -#[case(PartialValue::top(), pv_false(), PartialValue::top())] // Ideally pv_true_or_false -#[case(pv_true_or_false(), pv_true(), pv_true_or_false())] -#[case(pv_true_or_false(), pv_false(), pv_true_or_false())]*/ +#[case(pv_false(), pv_true_or_false(), pv_true_or_false(), pv_false())] +#[case(pv_false(), PartialValue::Top, PartialValue::Top, pv_false())] // if !inp0 then out0=inp1 +#[case(pv_true_or_false(), pv_true(), pv_true_or_false(), pv_true_or_false())] +#[case(pv_true_or_false(), pv_false(), pv_true_or_false(), pv_false())] +#[case(PartialValue::Top, pv_true(), pv_true_or_false(), PartialValue::Top)] +#[case(PartialValue::Top, pv_false(), PartialValue::Top, pv_false())] fn test_cfg( #[case] inp0: PartialValue, #[case] inp1: PartialValue, From 0a3e2812e8815108fdc153dabc10605598cfb36f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 11 Oct 2024 19:58:13 +0100 Subject: [PATCH 126/203] Try to common up by using case_reachable in conditional outputs - 6 tests fail --- hugr-passes/src/dataflow/datalog.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index f3c4868a8..68b09ea0e 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -133,9 +133,8 @@ ascent::ascent! { // outputs of case nodes propagate to outputs of conditional *if* case reachable out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <-- - case_node(c, cond, i, case), - in_wire_value(c, cond, IncomingPort::from(0), control), - if control.supports_tag(*i), + case_node(c, cond, _, case), + case_reachable(c, cond, case, true), io_node(c, case, o, IO::Output), in_wire_value(c, o, o_p, v); From b1e0bfd4b434a4d4891726528c76f7a338c3bdda Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 11 Oct 2024 20:01:09 +0100 Subject: [PATCH 127/203] Make case_reachable a relation (dropping bool), not lattice - fixes tests --- hugr-passes/src/dataflow/datalog.rs | 8 ++++---- hugr-passes/src/dataflow/machine.rs | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 68b09ea0e..eaf33ade1 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -134,14 +134,14 @@ ascent::ascent! { // outputs of case nodes propagate to outputs of conditional *if* case reachable out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <-- case_node(c, cond, _, case), - case_reachable(c, cond, case, true), + case_reachable(c, cond, case), io_node(c, case, o, IO::Output), in_wire_value(c, o, o_p, v); - lattice case_reachable(C, Node, Node, bool); - case_reachable(c, cond, case, reachable) <-- case_node(c,cond,i,case), + relation case_reachable(C, Node, Node); + case_reachable(c, cond, case) <-- case_node(c,cond,i,case), in_wire_value(c, cond, IncomingPort::from(0), v), - let reachable = v.supports_tag(*i); + if v.supports_tag(*i); // CFG relation cfg_node(C, Node); diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 44744b5ee..b0439c1b2 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -95,8 +95,8 @@ impl> Machine { self.0 .case_reachable .iter() - .find_map(|(_, cond2, case2, i)| (&cond == cond2 && &case == case2).then_some(*i)) - .unwrap(), + .find(|(_, cond2, case2)| &cond == cond2 && &case == case2) + .is_some(), ) } From 355e814f6262acc41f00bf852c1b803382b6883e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 15 Oct 2024 09:59:44 +0100 Subject: [PATCH 128/203] Call (+test) --- hugr-passes/src/dataflow/datalog.rs | 18 +++++++++++++ hugr-passes/src/dataflow/test.rs | 42 ++++++++++++++++++++++++++++- 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index eaf33ade1..5c65d7bca 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -182,6 +182,23 @@ ascent::ascent! { node_in_value_row(c, out_n, out_in_row), if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), for (out_p, v) in fields.enumerate(); + + // Call + relation func_call(C, Node, Node); + func_call(c, call, func_defn) <-- + node(c, call), + if c.get_optype(*call).is_call(), + if let Some(func_defn) = c.static_source(*call); + + out_wire_value(c, inp, OutgoingPort::from(p.index()), v) <-- + func_call(c, call, func), + io_node(c, func, inp, IO::Input), + in_wire_value(c, call, p, v); + + out_wire_value(c, call, OutgoingPort::from(p.index()), v) <-- + func_call(c, call, func), + io_node(c, func, outp, IO::Output), + in_wire_value(c, outp, p, v); } fn propagate_leaf_op( @@ -208,6 +225,7 @@ fn propagate_leaf_op( ins.iter().cloned(), )])), OpType::Input(_) | OpType::Output(_) | OpType::ExitBlock(_) => None, // handled by parent + OpType::Call(_) => None, // handled via Input/Output of FuncDefn OpType::Const(_) => None, // handled by LoadConstant: OpType::LoadConstant(load_op) => { assert!(ins.is_empty()); // static edge, so need to find constant diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index f3739062e..4e314de49 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use ascent::{lattice::BoundedLattice, Lattice}; -use hugr_core::builder::{CFGBuilder, Container}; +use hugr_core::builder::{CFGBuilder, Container, DataflowHugr}; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ @@ -498,3 +498,43 @@ fn test_cfg( out1 ); } + +#[rstest] +#[case(pv_true(), pv_true(), pv_true())] +#[case(pv_false(), pv_false(), pv_false())] +#[case(pv_true(), pv_false(), pv_true_or_false())] // Two calls alias +fn test_call( + #[case] inp0: PartialValue, + #[case] inp1: PartialValue, + #[case] out: PartialValue, +) { + let mut builder = DFGBuilder::new(Signature::new_endo(type_row![BOOL_T; 2])).unwrap(); + let func_bldr = builder + .define_function("id", Signature::new_endo(BOOL_T)) + .unwrap(); + let [v] = func_bldr.input_wires_arr(); + let func_defn = func_bldr.finish_with_outputs([v]).unwrap(); + let [a, b] = builder.input_wires_arr(); + let [a2] = builder + .call(func_defn.handle(), &[], [a], &EMPTY_REG) + .unwrap() + .outputs_arr(); + let [b2] = builder + .call(func_defn.handle(), &[], [b], &EMPTY_REG) + .unwrap() + .outputs_arr(); + let hugr = builder + .finish_hugr_with_outputs([a2, b2], &EMPTY_REG) + .unwrap(); + + let [root_inp, _] = hugr.get_io(hugr.root()).unwrap(); + let [inw0, inw1] = [0, 1].map(|i| Wire::new(root_inp, i)); + let mut machine = Machine::default(); + machine.propolutate_out_wires([(inw0, inp0), (inw1, inp1)]); + machine.run(TestContext(Arc::new(&hugr))); + + let [res0, res1] = [0, 1].map(|i| machine.read_out_wire(Wire::new(hugr.root(), i)).unwrap()); + // The two calls alias so both results will be the same: + assert_eq!(res0, out); + assert_eq!(res1, out); +} From 22f3ce8ecf4888a1ba3de31ae8a1ccbaa16a6ac9 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 15 Oct 2024 17:27:32 +0100 Subject: [PATCH 129/203] propolutate_out_wires => prepopulate and set in wires in run --- hugr-passes/src/dataflow/datalog.rs | 16 +++++--- hugr-passes/src/dataflow/machine.rs | 31 ++++++++++------ hugr-passes/src/dataflow/test.rs | 57 +++++++++++++---------------- 3 files changed, 56 insertions(+), 48 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 5c65d7bca..e9e1b92a5 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -30,7 +30,6 @@ pub enum IO { ascent::ascent! { pub(super) struct AscentProgram>; relation context(C); - relation out_wire_value_proto(Node, OutgoingPort, PV); relation node(C, Node); relation in_wire(C, Node, IncomingPort); @@ -38,8 +37,8 @@ ascent::ascent! { relation parent_of_node(C, Node, Node); relation io_node(C, Node, Node, IO); lattice out_wire_value(C, Node, OutgoingPort, PV); - lattice node_in_value_row(C, Node, ValueRow); lattice in_wire_value(C, Node, IncomingPort, PV); + lattice node_in_value_row(C, Node, ValueRow); node(c, n) <-- context(c), for n in c.nodes(); @@ -53,16 +52,21 @@ ascent::ascent! { io_node(c, parent, child, io) <-- node(c, parent), if let Some([i,o]) = c.get_io(*parent), for (child,io) in [(i,IO::Input),(o,IO::Output)]; - // We support prepopulating out_wire_value via out_wire_value_proto. - // - // out wires that do not have prepopulation values are initialised to bottom. + + // Initialize all wires to bottom out_wire_value(c, n,p, PV::bottom()) <-- out_wire(c, n,p); - out_wire_value(c, n, p, v) <-- out_wire(c,n,p) , out_wire_value_proto(n, p, v); in_wire_value(c, n, ip, v) <-- in_wire(c, n, ip), if let Some((m,op)) = c.single_linked_output(*n, *ip), out_wire_value(c, m, op, v); + // We support prepopulating in_wire_value via in_wire_value_proto. + relation in_wire_value_proto(Node, IncomingPort, PV); + in_wire_value(c, n, p, PV::bottom()) <-- in_wire(c, n,p); + in_wire_value(c, n, p, v) <-- node(c,n), + if let Some(sig) = c.signature(*n), + for p in sig.input_ports(), + in_wire_value_proto(n, p, v); node_in_value_row(c, n, ValueRow::new(sig.input_count())) <-- node(c, n), if let Some(sig) = c.signature(*n); node_in_value_row(c, n, ValueRow::single_known(c.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(c, n, p, v); diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index b0439c1b2..8f5a8d5a7 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use hugr_core::{HugrView, Node, PortIndex, Wire}; +use hugr_core::{HugrView, IncomingPort, Node, PortIndex, Wire}; use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue}; @@ -23,26 +23,35 @@ impl> Default for Machine { impl> Machine { /// Provide initial values for some wires. - /// (For example, if some properties of the Hugr's inputs are known.) - pub fn propolutate_out_wires( - &mut self, - wires: impl IntoIterator)>, - ) { + // Likely for test purposes only - should we make non-pub or #[cfg(test)] ? + pub fn prepopulate_wire(&mut self, h: &impl HugrView, wire: Wire, value: PartialValue) { assert!(self.1.is_none()); - self.0 - .out_wire_value_proto - .extend(wires.into_iter().map(|(w, v)| (w.node(), w.source(), v))); + self.0.in_wire_value_proto.extend( + h.linked_inputs(wire.node(), wire.source()) + .map(|(n, inp)| (n, inp, value.clone())), + ); } - /// Run the analysis (iterate until a lattice fixpoint is reached). + /// Run the analysis (iterate until a lattice fixpoint is reached), + /// given initial values for some of the root node inputs. + /// (Note that `in_values` will not be useful for `Case` or `DFB`-rooted Hugrs, + /// but should handle other containers.) /// The context passed in allows interpretation of leaf operations. /// /// # Panics /// /// If this Machine has been run already. /// - pub fn run(&mut self, context: C) { + pub fn run( + &mut self, + context: C, + in_values: impl IntoIterator)>, + ) { assert!(self.1.is_none()); + let root = context.root(); + self.0 + .in_wire_value_proto + .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); self.0.context.push((context,)); self.0.run(); self.1 = Some( diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 4e314de49..23ea73231 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -15,7 +15,7 @@ use hugr_core::{ types::{Signature, SumType, Type}, HugrView, }; -use hugr_core::{Hugr, Node, Wire}; +use hugr_core::{Hugr, Wire}; use rstest::{fixture, rstest}; use super::{AbstractValue, DFContext, Machine, PartialValue, TailLoopTermination}; @@ -98,7 +98,7 @@ fn test_make_tuple() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr))); + machine.run(TestContext(Arc::new(&hugr)), []); let x = machine .read_out_wire(v3) @@ -119,7 +119,7 @@ fn test_unpack_tuple_const() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr))); + machine.run(TestContext(Arc::new(&hugr)), []); let o1_r = machine .read_out_wire(o1) @@ -151,7 +151,7 @@ fn test_tail_loop_never_iterates() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr))); + machine.run(TestContext(Arc::new(&hugr)), []); let o_r = machine .read_out_wire(tl_o) @@ -185,7 +185,7 @@ fn test_tail_loop_always_iterates() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr))); + machine.run(TestContext(Arc::new(&hugr)), []); let o_r1 = machine.read_out_wire(tl_o1).unwrap(); assert_eq!(o_r1, PartialValue::bottom()); @@ -224,7 +224,7 @@ fn test_tail_loop_two_iters() { let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr))); + machine.run(TestContext(Arc::new(&hugr)), []); let o_r1 = machine.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true_or_false()); @@ -291,7 +291,7 @@ fn test_tail_loop_containing_conditional() { let [o_w1, o_w2] = tail_loop.outputs_arr(); let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr))); + machine.run(TestContext(Arc::new(&hugr)), []); let o_r1 = machine.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true()); @@ -344,8 +344,7 @@ fn test_conditional() { 2, [PartialValue::new_variant(0, [])], )); - machine.propolutate_out_wires([(arg_w, arg_pv)]); - machine.run(TestContext(Arc::new(&hugr))); + machine.run(TestContext(Arc::new(&hugr)), [(0.into(), arg_pv)]); let cond_r1 = machine .read_out_wire(cond_o1) @@ -365,13 +364,9 @@ fn test_conditional() { assert_eq!(machine.case_reachable(&hugr, cond.node()), None); } -// Tuple of -// 1. Hugr being a function on bools: (x, y) => (x XOR y, x AND y) -// 2. Input node of entry block -// Result readable from root node outputs -// Inputs should be placed onto out-wires of the Node (2.) +// A Hugr being a function on bools: (x, y) => (x XOR y, x AND y) #[fixture] -fn xor_and_cfg() -> (Hugr, Node) { +fn xor_and_cfg() -> Hugr { // Entry // /0 1\ // A --1-> B A(x=true, y) => if y then X(false, true) else B(x=true) @@ -456,9 +451,7 @@ fn xor_and_cfg() -> (Hugr, Node) { builder.branch(&a, 0, &x).unwrap(); builder.branch(&a, 1, &b).unwrap(); builder.branch(&b, 0, &x).unwrap(); - let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let [entry_input, _] = hugr.get_io(entry.node()).unwrap(); - (hugr, entry_input) + builder.finish_hugr(&EMPTY_REG).unwrap() } #[rstest] @@ -479,22 +472,24 @@ fn test_cfg( #[case] inp1: PartialValue, #[case] out0: PartialValue, #[case] out1: PartialValue, - xor_and_cfg: (Hugr, Node), + xor_and_cfg: Hugr, ) { - let (hugr, entry_input) = xor_and_cfg; - - let [in_w0, in_w1] = [0, 1].map(|i| Wire::new(entry_input, i)); - let mut machine = Machine::default(); - machine.propolutate_out_wires([(in_w0, inp0), (in_w1, inp1)]); - machine.run(TestContext(Arc::new(&hugr))); + machine.run( + TestContext(Arc::new(&xor_and_cfg)), + [(0.into(), inp0), (1.into(), inp1)], + ); assert_eq!( - machine.read_out_wire(Wire::new(hugr.root(), 0)).unwrap(), + machine + .read_out_wire(Wire::new(xor_and_cfg.root(), 0)) + .unwrap(), out0 ); assert_eq!( - machine.read_out_wire(Wire::new(hugr.root(), 1)).unwrap(), + machine + .read_out_wire(Wire::new(xor_and_cfg.root(), 1)) + .unwrap(), out1 ); } @@ -527,11 +522,11 @@ fn test_call( .finish_hugr_with_outputs([a2, b2], &EMPTY_REG) .unwrap(); - let [root_inp, _] = hugr.get_io(hugr.root()).unwrap(); - let [inw0, inw1] = [0, 1].map(|i| Wire::new(root_inp, i)); let mut machine = Machine::default(); - machine.propolutate_out_wires([(inw0, inp0), (inw1, inp1)]); - machine.run(TestContext(Arc::new(&hugr))); + machine.run( + TestContext(Arc::new(&hugr)), + [(0.into(), inp0), (1.into(), inp1)], + ); let [res0, res1] = [0, 1].map(|i| machine.read_out_wire(Wire::new(hugr.root(), i)).unwrap()); // The two calls alias so both results will be the same: From 68b1d486efeb18f24515a23e043dd83de17ffe9e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 15 Oct 2024 17:36:01 +0100 Subject: [PATCH 130/203] Rm/inline value_inputs/value_outputs, use UnpackTuple, comments --- hugr-passes/src/dataflow/datalog.rs | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index e9e1b92a5..f664a84d4 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -42,9 +42,8 @@ ascent::ascent! { node(c, n) <-- context(c), for n in c.nodes(); - in_wire(c, n,p) <-- node(c, n), for p in value_inputs(c.as_ref(), *n); - - out_wire(c, n,p) <-- node(c, n), for p in value_outputs(c.as_ref(), *n); + in_wire(c, n,p) <-- node(c, n), for (p,_) in c.in_value_types(*n); // Note, gets connected inports only + out_wire(c, n,p) <-- node(c, n), for (p,_) in c.out_value_types(*n); // (and likewise) parent_of_node(c, parent, child) <-- node(c, child), if let Some(parent) = c.get_parent(*child); @@ -220,8 +219,9 @@ fn propagate_leaf_op( ins.iter().cloned(), )])), op if op.cast::().is_some() => { + let elem_tys = op.cast::().unwrap().0; let [tup] = ins.iter().collect::>().try_into().unwrap(); - tup.variant_values(0, value_outputs(c.as_ref(), n).count()) + tup.variant_values(0, elem_tys.len()) .map(ValueRow::from_iter) } OpType::Tag(t) => Some(ValueRow::from_iter([PV::new_variant( @@ -257,16 +257,7 @@ fn propagate_leaf_op( } } -fn value_inputs(h: &impl HugrView, n: Node) -> impl Iterator + '_ { - h.in_value_types(n).map(|x| x.0) -} - -fn value_outputs(h: &impl HugrView, n: Node) -> impl Iterator + '_ { - h.out_value_types(n).map(|x| x.0) -} - -// Wrap a (known-length) row of values into a lattice. Perhaps could be part of partial_value.rs? - +// Wrap a (known-length) row of values into a lattice. #[derive(PartialEq, Clone, Eq, Hash)] struct ValueRow(Vec>); From 2cc62f0d17999b4d84ae5b042bd932f592a43a00 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 15 Oct 2024 17:48:28 +0100 Subject: [PATCH 131/203] clippy --- hugr-passes/src/dataflow.rs | 2 +- hugr-passes/src/dataflow/machine.rs | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 370b4d643..a861d48b7 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -81,7 +81,7 @@ fn traverse_value( .map(PartialValue::from) .unwrap_or(PartialValue::Top), Value::Function { hugr } => s - .value_from_const_hugr(n, fields, &**hugr) + .value_from_const_hugr(n, fields, hugr) .map(PartialValue::from) .unwrap_or(PartialValue::Top), } diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 8f5a8d5a7..50d7b088b 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -104,8 +104,7 @@ impl> Machine { self.0 .case_reachable .iter() - .find(|(_, cond2, case2)| &cond == cond2 && &case == case2) - .is_some(), + .any(|(_, cond2, case2)| &cond == cond2 && &case == case2), ) } @@ -122,8 +121,7 @@ impl> Machine { self.0 .bb_reachable .iter() - .find(|(_, cfg2, bb2)| *cfg2 == cfg && *bb2 == bb) - .is_some(), + .any(|(_, cfg2, bb2)| *cfg2 == cfg && *bb2 == bb), ) } } From 8254771b22d626c0657722b01d8ab9ba62cce900 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 15 Oct 2024 17:52:01 +0100 Subject: [PATCH 132/203] docs --- hugr-passes/src/dataflow/machine.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 50d7b088b..888e34505 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -6,8 +6,8 @@ use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue}; /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] -/// 2. Zero or more [Self::propolutate_out_wires] with initial values -/// 3. Exactly one [Self::run] to do the analysis +/// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values +/// 3. Exactly one [Self::run], with initial values for root inputs, to do the analysis /// 4. Results then available via [Self::read_out_wire] pub struct Machine>( AscentProgram, @@ -110,6 +110,10 @@ impl> Machine { /// Tells us if a block ([DataflowBlock] or [ExitBlock]) in a [CFG] is known /// to be reachable. (Returns `None` if argument is not a child of a CFG.) + /// + /// [CFG]: hugr_core::ops::CFG + /// [DataflowBlock]: hugr_core::ops::DataflowBlock + /// [ExitBlock]: hugr_core::ops::ExitBlock pub fn bb_reachable(&self, hugr: impl HugrView, bb: Node) -> Option { let cfg = hugr.get_parent(bb)?; // Not really required...?? hugr.get_optype(cfg).as_cfg()?; From 7e81b153102f87be1c7ba44d7d008bc884be013c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 21 Oct 2024 09:38:20 +0100 Subject: [PATCH 133/203] Separate AnalysisResults from Machine, use context.exactly_one() not HugrView --- hugr-passes/src/dataflow.rs | 2 +- hugr-passes/src/dataflow/machine.rs | 65 ++++++++++---------- hugr-passes/src/dataflow/test.rs | 92 ++++++++++++----------------- 3 files changed, 72 insertions(+), 87 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index a861d48b7..e04b4c2a8 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -5,7 +5,7 @@ mod datalog; mod machine; use hugr_core::ops::constant::CustomConst; -pub use machine::{Machine, TailLoopTermination}; +pub use machine::{AnalysisResults, Machine, TailLoopTermination}; mod partial_value; pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 888e34505..4ac66686c 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -1,23 +1,27 @@ use std::collections::HashMap; use hugr_core::{HugrView, IncomingPort, Node, PortIndex, Wire}; +use itertools::Itertools; use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue}; /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] /// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values -/// 3. Exactly one [Self::run], with initial values for root inputs, to do the analysis -/// 4. Results then available via [Self::read_out_wire] -pub struct Machine>( - AscentProgram, - Option>>, +/// 3. Call [Self::run] to produce [AnalysisResults] which can be inspected via +/// [read_out_wire](AnalysisResults::read_out_wire) +pub struct Machine>(AscentProgram); + +/// Results of a dataflow analysis. +pub struct AnalysisResults>( + AscentProgram, // Already run - kept for tests/debug + HashMap>, ); /// derived-Default requires the context to be Defaultable, which is unnecessary impl> Default for Machine { fn default() -> Self { - Self(Default::default(), None) + Self(Default::default()) } } @@ -25,7 +29,6 @@ impl> Machine { /// Provide initial values for some wires. // Likely for test purposes only - should we make non-pub or #[cfg(test)] ? pub fn prepopulate_wire(&mut self, h: &impl HugrView, wire: Wire, value: PartialValue) { - assert!(self.1.is_none()); self.0.in_wire_value_proto.extend( h.linked_inputs(wire.node(), wire.source()) .map(|(n, inp)| (n, inp, value.clone())), @@ -37,35 +40,36 @@ impl> Machine { /// (Note that `in_values` will not be useful for `Case` or `DFB`-rooted Hugrs, /// but should handle other containers.) /// The context passed in allows interpretation of leaf operations. - /// - /// # Panics - /// - /// If this Machine has been run already. - /// pub fn run( - &mut self, + mut self, context: C, in_values: impl IntoIterator)>, - ) { - assert!(self.1.is_none()); + ) -> AnalysisResults { let root = context.root(); self.0 .in_wire_value_proto .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); self.0.context.push((context,)); self.0.run(); - self.1 = Some( - self.0 - .out_wire_value - .iter() - .map(|(_, n, p, v)| (Wire::new(*n, *p), v.clone())) - .collect(), - ) + let results = self + .0 + .out_wire_value + .iter() + .map(|(_, n, p, v)| (Wire::new(*n, *p), v.clone())) + .collect(); + AnalysisResults(self.0, results) + } +} + +impl> AnalysisResults { + fn context(&self) -> &C { + let (c,) = self.0.context.iter().exactly_one().ok().unwrap(); + c } - /// Gets the lattice value computed by [Self::run] for the given wire + /// Gets the lattice value computed for the given wire pub fn read_out_wire(&self, w: Wire) -> Option> { - self.1.as_ref().unwrap().get(&w).cloned() + self.1.get(&w).cloned() } /// Tells whether a [TailLoop] node can terminate, i.e. whether @@ -73,11 +77,8 @@ impl> Machine { /// Returns `None` if the specified `node` is not a [TailLoop]. /// /// [TailLoop]: hugr_core::ops::TailLoop - pub fn tail_loop_terminates( - &self, - hugr: impl HugrView, - node: Node, - ) -> Option { + pub fn tail_loop_terminates(&self, node: Node) -> Option { + let hugr = self.context(); hugr.get_optype(node).as_tail_loop()?; let [_, out] = hugr.get_io(node).unwrap(); Some(TailLoopTermination::from_control_value( @@ -96,7 +97,8 @@ impl> Machine { /// /// [Case]: hugr_core::ops::Case /// [Conditional]: hugr_core::ops::Conditional - pub fn case_reachable(&self, hugr: impl HugrView, case: Node) -> Option { + pub fn case_reachable(&self, case: Node) -> Option { + let hugr = self.context(); hugr.get_optype(case).as_case()?; let cond = hugr.get_parent(case)?; hugr.get_optype(cond).as_conditional()?; @@ -114,7 +116,8 @@ impl> Machine { /// [CFG]: hugr_core::ops::CFG /// [DataflowBlock]: hugr_core::ops::DataflowBlock /// [ExitBlock]: hugr_core::ops::ExitBlock - pub fn bb_reachable(&self, hugr: impl HugrView, bb: Node) -> Option { + pub fn bb_reachable(&self, bb: Node) -> Option { + let hugr = self.context(); let cfg = hugr.get_parent(bb)?; // Not really required...?? hugr.get_optype(cfg).as_cfg()?; let t = hugr.get_optype(bb); diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 23ea73231..e79a0024b 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -97,10 +97,9 @@ fn test_make_tuple() { let v3 = builder.make_tuple([v1, v2]).unwrap(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); - let x = machine + let x = results .read_out_wire(v3) .unwrap() .try_into_wire_value(&hugr, v3) @@ -118,16 +117,15 @@ fn test_unpack_tuple_const() { .outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); - let o1_r = machine + let o1_r = results .read_out_wire(o1) .unwrap() .try_into_wire_value(&hugr, o1) .unwrap(); assert_eq!(o1_r, Value::false_val()); - let o2_r = machine + let o2_r = results .read_out_wire(o2) .unwrap() .try_into_wire_value(&hugr, o2) @@ -150,10 +148,9 @@ fn test_tail_loop_never_iterates() { let [tl_o] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); - let o_r = machine + let o_r = results .read_out_wire(tl_o) .unwrap() .try_into_wire_value(&hugr, tl_o) @@ -161,7 +158,7 @@ fn test_tail_loop_never_iterates() { assert_eq!(o_r, r_v); assert_eq!( Some(TailLoopTermination::NeverContinues), - machine.tail_loop_terminates(&hugr, tail_loop.node()) + results.tail_loop_terminates(tail_loop.node()) ) } @@ -184,18 +181,17 @@ fn test_tail_loop_always_iterates() { let [tl_o1, tl_o2] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); - let o_r1 = machine.read_out_wire(tl_o1).unwrap(); + let o_r1 = results.read_out_wire(tl_o1).unwrap(); assert_eq!(o_r1, PartialValue::bottom()); - let o_r2 = machine.read_out_wire(tl_o2).unwrap(); + let o_r2 = results.read_out_wire(tl_o2).unwrap(); assert_eq!(o_r2, PartialValue::bottom()); assert_eq!( Some(TailLoopTermination::NeverBreaks), - machine.tail_loop_terminates(&hugr, tail_loop.node()) + results.tail_loop_terminates(tail_loop.node()) ); - assert_eq!(machine.tail_loop_terminates(&hugr, hugr.root()), None); + assert_eq!(results.tail_loop_terminates(hugr.root()), None); } #[test] @@ -223,18 +219,17 @@ fn test_tail_loop_two_iters() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let [o_w1, o_w2] = tail_loop.outputs_arr(); - let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); - let o_r1 = machine.read_out_wire(o_w1).unwrap(); + let o_r1 = results.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true_or_false()); - let o_r2 = machine.read_out_wire(o_w2).unwrap(); + let o_r2 = results.read_out_wire(o_w2).unwrap(); assert_eq!(o_r2, pv_true_or_false()); assert_eq!( Some(TailLoopTermination::BreaksAndContinues), - machine.tail_loop_terminates(&hugr, tail_loop.node()) + results.tail_loop_terminates(tail_loop.node()) ); - assert_eq!(machine.tail_loop_terminates(&hugr, hugr.root()), None); + assert_eq!(results.tail_loop_terminates(hugr.root()), None); } #[test] @@ -290,18 +285,17 @@ fn test_tail_loop_containing_conditional() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let [o_w1, o_w2] = tail_loop.outputs_arr(); - let mut machine = Machine::default(); - machine.run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); - let o_r1 = machine.read_out_wire(o_w1).unwrap(); + let o_r1 = results.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true()); - let o_r2 = machine.read_out_wire(o_w2).unwrap(); + let o_r2 = results.read_out_wire(o_w2).unwrap(); assert_eq!(o_r2, pv_false()); assert_eq!( Some(TailLoopTermination::BreaksAndContinues), - machine.tail_loop_terminates(&hugr, tail_loop.node()) + results.tail_loop_terminates(tail_loop.node()) ); - assert_eq!(machine.tail_loop_terminates(&hugr, hugr.root()), None); + assert_eq!(results.tail_loop_terminates(hugr.root()), None); } #[test] @@ -339,29 +333,28 @@ fn test_conditional() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let mut machine = Machine::default(); let arg_pv = PartialValue::new_variant(1, []).join(PartialValue::new_variant( 2, [PartialValue::new_variant(0, [])], )); - machine.run(TestContext(Arc::new(&hugr)), [(0.into(), arg_pv)]); + let results = Machine::default().run(TestContext(Arc::new(&hugr)), [(0.into(), arg_pv)]); - let cond_r1 = machine + let cond_r1 = results .read_out_wire(cond_o1) .unwrap() .try_into_wire_value(&hugr, cond_o1) .unwrap(); assert_eq!(cond_r1, Value::false_val()); - assert!(machine + assert!(results .read_out_wire(cond_o2) .unwrap() .try_into_wire_value(&hugr, cond_o2) .is_err()); - assert_eq!(machine.case_reachable(&hugr, case1.node()), Some(false)); // arg_pv is variant 1 or 2 only - assert_eq!(machine.case_reachable(&hugr, case2.node()), Some(true)); - assert_eq!(machine.case_reachable(&hugr, case3.node()), Some(true)); - assert_eq!(machine.case_reachable(&hugr, cond.node()), None); + assert_eq!(results.case_reachable(case1.node()), Some(false)); // arg_pv is variant 1 or 2 only + assert_eq!(results.case_reachable(case2.node()), Some(true)); + assert_eq!(results.case_reachable(case3.node()), Some(true)); + assert_eq!(results.case_reachable(cond.node()), None); } // A Hugr being a function on bools: (x, y) => (x XOR y, x AND y) @@ -474,24 +467,14 @@ fn test_cfg( #[case] out1: PartialValue, xor_and_cfg: Hugr, ) { - let mut machine = Machine::default(); - machine.run( - TestContext(Arc::new(&xor_and_cfg)), + let root = xor_and_cfg.root(); + let results = Machine::default().run( + TestContext(Arc::new(xor_and_cfg)), [(0.into(), inp0), (1.into(), inp1)], ); - assert_eq!( - machine - .read_out_wire(Wire::new(xor_and_cfg.root(), 0)) - .unwrap(), - out0 - ); - assert_eq!( - machine - .read_out_wire(Wire::new(xor_and_cfg.root(), 1)) - .unwrap(), - out1 - ); + assert_eq!(results.read_out_wire(Wire::new(root, 0)).unwrap(), out0); + assert_eq!(results.read_out_wire(Wire::new(root, 1)).unwrap(), out1); } #[rstest] @@ -522,13 +505,12 @@ fn test_call( .finish_hugr_with_outputs([a2, b2], &EMPTY_REG) .unwrap(); - let mut machine = Machine::default(); - machine.run( + let results = Machine::default().run( TestContext(Arc::new(&hugr)), [(0.into(), inp0), (1.into(), inp1)], ); - let [res0, res1] = [0, 1].map(|i| machine.read_out_wire(Wire::new(hugr.root(), i)).unwrap()); + let [res0, res1] = [0, 1].map(|i| results.read_out_wire(Wire::new(hugr.root(), i)).unwrap()); // The two calls alias so both results will be the same: assert_eq!(res0, out); assert_eq!(res1, out); From 34e82ede05e00d64a5e4cf8a5de7a310d72987d3 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 21 Oct 2024 09:51:53 +0100 Subject: [PATCH 134/203] Move try_into_wire_value => AnalysisResults.try_read_wire_value --- hugr-passes/src/dataflow/machine.rs | 28 +++++++++++++++++- hugr-passes/src/dataflow/partial_value.rs | 29 ------------------ hugr-passes/src/dataflow/test.rs | 36 ++++------------------- 3 files changed, 33 insertions(+), 60 deletions(-) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 4ac66686c..df2320ff1 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use hugr_core::{HugrView, IncomingPort, Node, PortIndex, Wire}; +use hugr_core::{ops::Value, types::ConstTypeError, HugrView, IncomingPort, Node, PortIndex, Wire}; use itertools::Itertools; use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue}; @@ -133,6 +133,32 @@ impl> AnalysisResults { } } +impl> AnalysisResults +where + Value: From, +{ + /// Reads a [Value] from an output wire, if the lattice value computed for it can be turned + /// into one. (The lattice value must be either a single [Value](Self::Value) or + /// a [Sum](PartialValue::PartialSum with a single known tag.) + /// + /// # Errors + /// `None` if the analysis did not result in a single value on that wire + /// `Some(e)` if conversion to a [Value] produced a [ConstTypeError] + /// + /// # Panics + /// + /// If a [Type] for the specified wire could not be extracted from the Hugr + pub fn try_read_wire_value(&self, w: Wire) -> Result> { + let v = self.read_out_wire(w).ok_or(None)?; + let (_, typ) = self + .context() + .out_value_types(w.node()) + .find(|(p, _)| *p == w.source()) + .unwrap(); + v.try_into_value(&typ) + } +} + /// Tells whether a loop iterates (never, always, sometimes) #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] pub enum TailLoopTermination { diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index d3010ae45..5b5695dd0 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -2,7 +2,6 @@ use ascent::lattice::BoundedLattice; use ascent::Lattice; use hugr_core::ops::Value; use hugr_core::types::{ConstTypeError, SumType, Type, TypeEnum, TypeRow}; -use hugr_core::{HugrView, Wire}; use itertools::{zip_eq, Itertools}; use std::cmp::Ordering; use std::collections::HashMap; @@ -327,34 +326,6 @@ impl TryFrom> for Value { } } -impl PartialValue -where - Value: From, -{ - /// Turns this instance into a [Value], if it is either a single [Value](Self::Value) or - /// a [Sum](PartialValue::PartialSum) with a single known tag, extracting the desired type - /// from a HugrView and Wire. - /// - /// # Errors - /// `None` if the analysis did not result in a single value on that wire - /// `Some(e)` if conversion to a [Value] produced a [ConstTypeError] - /// - /// # Panics - /// - /// If a [Type] for the specified wire could not be extracted from the Hugr - pub fn try_into_wire_value( - self, - hugr: &impl HugrView, - w: Wire, - ) -> Result> { - let (_, typ) = hugr - .out_value_types(w.node()) - .find(|(p, _)| *p == w.source()) - .unwrap(); - self.try_into_value(&typ) - } -} - impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index e79a0024b..74d504b74 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -99,11 +99,7 @@ fn test_make_tuple() { let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); - let x = results - .read_out_wire(v3) - .unwrap() - .try_into_wire_value(&hugr, v3) - .unwrap(); + let x = results.try_read_wire_value(v3).unwrap(); assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); } @@ -119,17 +115,9 @@ fn test_unpack_tuple_const() { let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); - let o1_r = results - .read_out_wire(o1) - .unwrap() - .try_into_wire_value(&hugr, o1) - .unwrap(); + let o1_r = results.try_read_wire_value(o1).unwrap(); assert_eq!(o1_r, Value::false_val()); - let o2_r = results - .read_out_wire(o2) - .unwrap() - .try_into_wire_value(&hugr, o2) - .unwrap(); + let o2_r = results.try_read_wire_value(o2).unwrap(); assert_eq!(o2_r, Value::true_val()); } @@ -150,11 +138,7 @@ fn test_tail_loop_never_iterates() { let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); - let o_r = results - .read_out_wire(tl_o) - .unwrap() - .try_into_wire_value(&hugr, tl_o) - .unwrap(); + let o_r = results.try_read_wire_value(tl_o).unwrap(); assert_eq!(o_r, r_v); assert_eq!( Some(TailLoopTermination::NeverContinues), @@ -339,17 +323,9 @@ fn test_conditional() { )); let results = Machine::default().run(TestContext(Arc::new(&hugr)), [(0.into(), arg_pv)]); - let cond_r1 = results - .read_out_wire(cond_o1) - .unwrap() - .try_into_wire_value(&hugr, cond_o1) - .unwrap(); + let cond_r1 = results.try_read_wire_value(cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); - assert!(results - .read_out_wire(cond_o2) - .unwrap() - .try_into_wire_value(&hugr, cond_o2) - .is_err()); + assert!(results.try_read_wire_value(cond_o2).is_err()); assert_eq!(results.case_reachable(case1.node()), Some(false)); // arg_pv is variant 1 or 2 only assert_eq!(results.case_reachable(case2.node()), Some(true)); From 015707f98a21d99cb0729e4be776604d2b16b4cf Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 21 Oct 2024 10:16:14 +0100 Subject: [PATCH 135/203] doc fixes and fix comment --- hugr-passes/src/dataflow/machine.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index df2320ff1..4ddde626e 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -14,7 +14,7 @@ pub struct Machine>(AscentProgram); /// Results of a dataflow analysis. pub struct AnalysisResults>( - AscentProgram, // Already run - kept for tests/debug + AscentProgram, // Already run HashMap>, ); @@ -138,8 +138,8 @@ where Value: From, { /// Reads a [Value] from an output wire, if the lattice value computed for it can be turned - /// into one. (The lattice value must be either a single [Value](Self::Value) or - /// a [Sum](PartialValue::PartialSum with a single known tag.) + /// into one. (The lattice value must be either a single [Value](PartialValue::Value) or + /// a [Sum](PartialValue::PartialSum) with a single known tag.) /// /// # Errors /// `None` if the analysis did not result in a single value on that wire @@ -147,7 +147,7 @@ where /// /// # Panics /// - /// If a [Type] for the specified wire could not be extracted from the Hugr + /// If a [Type](hugr_core::types::Type) for the specified wire could not be extracted from the Hugr pub fn try_read_wire_value(&self, w: Wire) -> Result> { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self From ff39f7d92831ddcc9babf6bf947414b1851ae15a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 21 Oct 2024 10:21:22 +0100 Subject: [PATCH 136/203] Try to make clippy happy --- hugr-passes/src/dataflow/machine.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index 4ddde626e..ccd13e517 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -9,7 +9,7 @@ use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue}; /// 1. Get a new instance via [Self::default()] /// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values /// 3. Call [Self::run] to produce [AnalysisResults] which can be inspected via -/// [read_out_wire](AnalysisResults::read_out_wire) +/// [read_out_wire](AnalysisResults::read_out_wire) pub struct Machine>(AscentProgram); /// Results of a dataflow analysis. From ada7ee1d1afc0fe627179f09ffe515f6889132b9 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 12:02:22 +0100 Subject: [PATCH 137/203] Use ascent_run to drop context from all the relations. Lots cleanup to follow --- hugr-passes/src/dataflow/datalog.rs | 366 +++++++++++++++------------- hugr-passes/src/dataflow/machine.rs | 57 ++--- 2 files changed, 219 insertions(+), 204 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index f664a84d4..5cb9afa32 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -8,7 +8,7 @@ )] use ascent::lattice::{BoundedLattice, Lattice}; -use itertools::zip_eq; +use itertools::{zip_eq, Itertools}; use std::cmp::Ordering; use std::hash::Hash; use std::ops::{Index, IndexMut}; @@ -27,183 +27,197 @@ pub enum IO { Output, } -ascent::ascent! { - pub(super) struct AscentProgram>; - relation context(C); - - relation node(C, Node); - relation in_wire(C, Node, IncomingPort); - relation out_wire(C, Node, OutgoingPort); - relation parent_of_node(C, Node, Node); - relation io_node(C, Node, Node, IO); - lattice out_wire_value(C, Node, OutgoingPort, PV); - lattice in_wire_value(C, Node, IncomingPort, PV); - lattice node_in_value_row(C, Node, ValueRow); - - node(c, n) <-- context(c), for n in c.nodes(); - - in_wire(c, n,p) <-- node(c, n), for (p,_) in c.in_value_types(*n); // Note, gets connected inports only - out_wire(c, n,p) <-- node(c, n), for (p,_) in c.out_value_types(*n); // (and likewise) - - parent_of_node(c, parent, child) <-- - node(c, child), if let Some(parent) = c.get_parent(*child); - - io_node(c, parent, child, io) <-- node(c, parent), - if let Some([i,o]) = c.get_io(*parent), - for (child,io) in [(i,IO::Input),(o,IO::Output)]; - - // Initialize all wires to bottom - out_wire_value(c, n,p, PV::bottom()) <-- out_wire(c, n,p); - - in_wire_value(c, n, ip, v) <-- in_wire(c, n, ip), - if let Some((m,op)) = c.single_linked_output(*n, *ip), - out_wire_value(c, m, op, v); - - // We support prepopulating in_wire_value via in_wire_value_proto. - relation in_wire_value_proto(Node, IncomingPort, PV); - in_wire_value(c, n, p, PV::bottom()) <-- in_wire(c, n,p); - in_wire_value(c, n, p, v) <-- node(c,n), - if let Some(sig) = c.signature(*n), - for p in sig.input_ports(), - in_wire_value_proto(n, p, v); - - node_in_value_row(c, n, ValueRow::new(sig.input_count())) <-- node(c, n), if let Some(sig) = c.signature(*n); - node_in_value_row(c, n, ValueRow::single_known(c.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(c, n, p, v); - - out_wire_value(c, n, p, v) <-- - node(c, n), - let op_t = c.get_optype(*n), - if !op_t.is_container(), - if let Some(sig) = op_t.dataflow_signature(), - node_in_value_row(c, n, vs), - if let Some(outs) = propagate_leaf_op(c, *n, &vs[..], sig.output_count()), - for (p,v) in (0..).map(OutgoingPort::from).zip(outs); - - // DFG - relation dfg_node(C, Node); - dfg_node(c,n) <-- node(c, n), if c.get_optype(*n).is_dfg(); - - out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- dfg_node(c,dfg), - io_node(c, dfg, i, IO::Input), in_wire_value(c, dfg, p, v); - - out_wire_value(c, dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(c,dfg), - io_node(c,dfg,o, IO::Output), in_wire_value(c, o, p, v); - - - // TailLoop - - // inputs of tail loop propagate to Input node of child region - out_wire_value(c, i, OutgoingPort::from(p.index()), v) <-- node(c, tl), - if c.get_optype(*tl).is_tail_loop(), - io_node(c,tl,i, IO::Input), - in_wire_value(c, tl, p, v); - - // Output node of child region propagate to Input node of child region - out_wire_value(c, in_n, OutgoingPort::from(out_p), v) <-- node(c, tl_n), - if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), - io_node(c,tl_n,in_n, IO::Input), - io_node(c,tl_n,out_n, IO::Output), - node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node - - if let Some(fields) = out_in_row.unpack_first(0, tailloop.just_inputs.len()), // if it is possible for tag to be 0 - for (out_p, v) in fields.enumerate(); - - // Output node of child region propagate to outputs of tail loop - out_wire_value(c, tl_n, OutgoingPort::from(out_p), v) <-- node(c, tl_n), - if let Some(tailloop) = c.get_optype(*tl_n).as_tail_loop(), - io_node(c,tl_n,out_n, IO::Output), - node_in_value_row(c, out_n, out_in_row), // get the whole input row for the output node - if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 - for (out_p, v) in fields.enumerate(); - - // Conditional - relation conditional_node(C, Node); - relation case_node(C,Node,usize, Node); - - conditional_node (c,n)<-- node(c, n), if c.get_optype(*n).is_conditional(); - case_node(c,cond,i, case) <-- conditional_node(c,cond), - for (i, case) in c.children(*cond).enumerate(), - if c.get_optype(case).is_case(); - - // inputs of conditional propagate into case nodes - out_wire_value(c, i_node, OutgoingPort::from(out_p), v) <-- - case_node(c, cond, case_index, case), - io_node(c, case, i_node, IO::Input), - node_in_value_row(c, cond, in_row), - let conditional = c.get_optype(*cond).as_conditional().unwrap(), - if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), - for (out_p, v) in fields.enumerate(); - - // outputs of case nodes propagate to outputs of conditional *if* case reachable - out_wire_value(c, cond, OutgoingPort::from(o_p.index()), v) <-- - case_node(c, cond, _, case), - case_reachable(c, cond, case), - io_node(c, case, o, IO::Output), - in_wire_value(c, o, o_p, v); - - relation case_reachable(C, Node, Node); - case_reachable(c, cond, case) <-- case_node(c,cond,i,case), - in_wire_value(c, cond, IncomingPort::from(0), v), - if v.supports_tag(*i); - - // CFG - relation cfg_node(C, Node); - relation dfb_block(C, Node, Node); - cfg_node(c,n) <-- node(c, n), if c.get_optype(*n).is_cfg(); - dfb_block(c,cfg,blk) <-- cfg_node(c, cfg), for blk in c.children(*cfg), if c.get_optype(blk).is_dataflow_block(); - - // Reachability - relation bb_reachable(C, Node, Node); - bb_reachable(c, cfg, entry) <-- cfg_node(c, cfg), if let Some(entry) = c.children(*cfg).next(); - bb_reachable(c, cfg, bb) <-- cfg_node(c, cfg), - bb_reachable(c, cfg, pred), - io_node(c, pred, pred_out, IO::Output), - in_wire_value(c, pred_out, IncomingPort::from(0), predicate), - for (tag, bb) in c.output_neighbours(*pred).enumerate(), - if predicate.supports_tag(tag); - - // Where do the values "fed" along a control-flow edge come out? - relation _cfg_succ_dest(C, Node, Node, Node); - _cfg_succ_dest(c, cfg, blk, inp) <-- dfb_block(c, cfg, blk), io_node(c, blk, inp, IO::Input); - _cfg_succ_dest(c, cfg, exit, cfg) <-- cfg_node(c, cfg), if let Some(exit) = c.children(*cfg).nth(1); - - // Inputs of CFG propagate to entry block - out_wire_value(c, i_node, OutgoingPort::from(p.index()), v) <-- - cfg_node(c, cfg), - if let Some(entry) = c.children(*cfg).next(), - io_node(c, entry, i_node, IO::Input), - in_wire_value(c, cfg, p, v); - - // Outputs of each reachable block propagated to successor block or (if exit block) then CFG itself - out_wire_value(c, dest, OutgoingPort::from(out_p), v) <-- - dfb_block(c, cfg, pred), - bb_reachable(c, cfg, pred), - let df_block = c.get_optype(*pred).as_dataflow_block().unwrap(), - for (succ_n, succ) in c.output_neighbours(*pred).enumerate(), - io_node(c, pred, out_n, IO::Output), - _cfg_succ_dest(c, cfg, succ, dest), - node_in_value_row(c, out_n, out_in_row), - if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), - for (out_p, v) in fields.enumerate(); - - // Call - relation func_call(C, Node, Node); - func_call(c, call, func_defn) <-- - node(c, call), - if c.get_optype(*call).is_call(), - if let Some(func_defn) = c.static_source(*call); - - out_wire_value(c, inp, OutgoingPort::from(p.index()), v) <-- - func_call(c, call, func), - io_node(c, func, inp, IO::Input), - in_wire_value(c, call, p, v); - - out_wire_value(c, call, OutgoingPort::from(p.index()), v) <-- - func_call(c, call, func), - io_node(c, func, outp, IO::Output), - in_wire_value(c, outp, p, v); +pub(super) struct DatalogResults { + pub in_wire_value: Vec<(Node, IncomingPort, PV)>, + pub out_wire_value: Vec<(Node, OutgoingPort, PV)>, + pub case_reachable: Vec<(Node, Node)>, + pub bb_reachable: Vec<(Node, Node)>, } +pub(super) fn run_datalog>( + in_wire_value_proto: Vec<(Node, IncomingPort, PV)>, + c: &C, +) -> DatalogResults { + let all_results = ascent::ascent_run! { + pub(super) struct AscentProgram; + relation node(Node); + relation in_wire(Node, IncomingPort); + relation out_wire(Node, OutgoingPort); + relation parent_of_node(Node, Node); + relation io_node(Node, Node, IO); + lattice out_wire_value(Node, OutgoingPort, PV); + lattice in_wire_value(Node, IncomingPort, PV); + lattice node_in_value_row(Node, ValueRow); + + node(n) <-- for n in c.nodes(); + + in_wire(n, p) <-- node(n), for (p,_) in c.in_value_types(*n); // Note, gets connected inports only + out_wire(n, p) <-- node(n), for (p,_) in c.out_value_types(*n); // (and likewise) + + parent_of_node(parent, child) <-- + node(child), if let Some(parent) = c.get_parent(*child); + + io_node(parent, child, io) <-- node(parent), + if let Some([i, o]) = c.get_io(*parent), + for (child,io) in [(i,IO::Input),(o,IO::Output)]; + + // Initialize all wires to bottom + out_wire_value(n, p, PV::bottom()) <-- out_wire(n, p); + + in_wire_value(n, ip, v) <-- in_wire(n, ip), + if let Some((m, op)) = c.single_linked_output(*n, *ip), + out_wire_value(m, op, v); + + // We support prepopulating in_wire_value via in_wire_value_proto. + in_wire_value(n, p, PV::bottom()) <-- in_wire(n, p); + in_wire_value(n, p, v) <-- for (n, p, v) in in_wire_value_proto.iter(), + node(n), + if let Some(sig) = c.signature(*n), + if sig.input_ports().contains(p); + + node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = c.signature(*n); + node_in_value_row(n, ValueRow::single_known(c.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(n, p, v); + + out_wire_value(n, p, v) <-- + node(n), + let op_t = c.get_optype(*n), + if !op_t.is_container(), + if let Some(sig) = op_t.dataflow_signature(), + node_in_value_row(n, vs), + if let Some(outs) = propagate_leaf_op(c, *n, &vs[..], sig.output_count()), + for (p, v) in (0..).map(OutgoingPort::from).zip(outs); + + // DFG + relation dfg_node(Node); + dfg_node(n) <-- node(n), if c.get_optype(*n).is_dfg(); + + out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), + io_node(dfg, i, IO::Input), in_wire_value(dfg, p, v); + + out_wire_value(dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), + io_node(dfg, o, IO::Output), in_wire_value(o, p, v); + + + // TailLoop + + // inputs of tail loop propagate to Input node of child region + out_wire_value(i, OutgoingPort::from(p.index()), v) <-- node(tl), + if c.get_optype(*tl).is_tail_loop(), + io_node(tl, i, IO::Input), + in_wire_value(tl, p, v); + + // Output node of child region propagate to Input node of child region + out_wire_value(in_n, OutgoingPort::from(out_p), v) <-- node(tl), + if let Some(tailloop) = c.get_optype(*tl).as_tail_loop(), + io_node(tl, in_n, IO::Input), + io_node(tl, out_n, IO::Output), + node_in_value_row(out_n, out_in_row), // get the whole input row for the output node + + if let Some(fields) = out_in_row.unpack_first(0, tailloop.just_inputs.len()), // if it is possible for tag to be 0 + for (out_p, v) in fields.enumerate(); + + // Output node of child region propagate to outputs of tail loop + out_wire_value(tl, OutgoingPort::from(out_p), v) <-- node(tl), + if let Some(tailloop) = c.get_optype(*tl).as_tail_loop(), + io_node(tl, out_n, IO::Output), + node_in_value_row(out_n, out_in_row), // get the whole input row for the output node + if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 + for (out_p, v) in fields.enumerate(); + + // Conditional + relation conditional_node(Node); + relation case_node(Node, usize, Node); + + conditional_node(n)<-- node(n), if c.get_optype(*n).is_conditional(); + case_node(cond, i, case) <-- conditional_node(cond), + for (i, case) in c.children(*cond).enumerate(), + if c.get_optype(case).is_case(); + + // inputs of conditional propagate into case nodes + out_wire_value(i_node, OutgoingPort::from(out_p), v) <-- + case_node(cond, case_index, case), + io_node(case, i_node, IO::Input), + node_in_value_row(cond, in_row), + let conditional = c.get_optype(*cond).as_conditional().unwrap(), + if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), + for (out_p, v) in fields.enumerate(); + + // outputs of case nodes propagate to outputs of conditional *if* case reachable + out_wire_value(cond, OutgoingPort::from(o_p.index()), v) <-- + case_node(cond, _, case), + case_reachable(cond, case), + io_node(case, o, IO::Output), + in_wire_value(o, o_p, v); + + relation case_reachable(Node, Node); + case_reachable(cond, case) <-- case_node(cond, i, case), + in_wire_value(cond, IncomingPort::from(0), v), + if v.supports_tag(*i); + + // CFG + relation cfg_node(Node); + relation dfb_block(Node, Node); + cfg_node(n) <-- node(n), if c.get_optype(*n).is_cfg(); + dfb_block(cfg, blk) <-- cfg_node(cfg), for blk in c.children(*cfg), if c.get_optype(blk).is_dataflow_block(); + + // Reachability + relation bb_reachable(Node, Node); + bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = c.children(*cfg).next(); + bb_reachable(cfg, bb) <-- cfg_node(cfg), + bb_reachable(cfg, pred), + io_node(pred, pred_out, IO::Output), + in_wire_value(pred_out, IncomingPort::from(0), predicate), + for (tag, bb) in c.output_neighbours(*pred).enumerate(), + if predicate.supports_tag(tag); + + // Where do the values "fed" along a control-flow edge come out? + relation _cfg_succ_dest(Node, Node, Node); + _cfg_succ_dest(cfg, blk, inp) <-- dfb_block(cfg, blk), io_node(blk, inp, IO::Input); + _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = c.children(*cfg).nth(1); + + // Inputs of CFG propagate to entry block + out_wire_value(i_node, OutgoingPort::from(p.index()), v) <-- + cfg_node(cfg), + if let Some(entry) = c.children(*cfg).next(), + io_node(entry, i_node, IO::Input), + in_wire_value(cfg, p, v); + + // Outputs of each reachable block propagated to successor block or (if exit block) then CFG itself + out_wire_value(dest, OutgoingPort::from(out_p), v) <-- + dfb_block(cfg, pred), + bb_reachable(cfg, pred), + let df_block = c.get_optype(*pred).as_dataflow_block().unwrap(), + for (succ_n, succ) in c.output_neighbours(*pred).enumerate(), + io_node(pred, out_n, IO::Output), + _cfg_succ_dest(cfg, succ, dest), + node_in_value_row(out_n, out_in_row), + if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), + for (out_p, v) in fields.enumerate(); + + // Call + relation func_call(Node, Node); + func_call(call, func_defn) <-- + node(call), + if c.get_optype(*call).is_call(), + if let Some(func_defn) = c.static_source(*call); + + out_wire_value(inp, OutgoingPort::from(p.index()), v) <-- + func_call(call, func), + io_node(func, inp, IO::Input), + in_wire_value(call, p, v); + + out_wire_value(call, OutgoingPort::from(p.index()), v) <-- + func_call(call, func), + io_node(func, outp, IO::Output), + in_wire_value(outp, p, v); + }; + DatalogResults { + in_wire_value: all_results.in_wire_value, + out_wire_value: all_results.out_wire_value, + case_reachable: all_results.case_reachable, + bb_reachable: all_results.bb_reachable, + } +} fn propagate_leaf_op( c: &impl DFContext, n: Node, diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index ccd13e517..f93384f7d 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -1,35 +1,36 @@ use std::collections::HashMap; use hugr_core::{ops::Value, types::ConstTypeError, HugrView, IncomingPort, Node, PortIndex, Wire}; -use itertools::Itertools; -use super::{datalog::AscentProgram, AbstractValue, DFContext, PartialValue}; +use super::datalog::{run_datalog, DatalogResults}; +use super::{AbstractValue, DFContext, PartialValue}; /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] /// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values /// 3. Call [Self::run] to produce [AnalysisResults] which can be inspected via /// [read_out_wire](AnalysisResults::read_out_wire) -pub struct Machine>(AscentProgram); +pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); -/// Results of a dataflow analysis. -pub struct AnalysisResults>( - AscentProgram, // Already run - HashMap>, -); +/// Results of a dataflow analysis, packaged with context for easy inspection +pub struct AnalysisResults> { + context: C, + results: DatalogResults, + out_wire_values: HashMap>, +} /// derived-Default requires the context to be Defaultable, which is unnecessary -impl> Default for Machine { +impl Default for Machine { fn default() -> Self { Self(Default::default()) } } -impl> Machine { +impl Machine { /// Provide initial values for some wires. // Likely for test purposes only - should we make non-pub or #[cfg(test)] ? pub fn prepopulate_wire(&mut self, h: &impl HugrView, wire: Wire, value: PartialValue) { - self.0.in_wire_value_proto.extend( + self.0.extend( h.linked_inputs(wire.node(), wire.source()) .map(|(n, inp)| (n, inp, value.clone())), ); @@ -40,36 +41,36 @@ impl> Machine { /// (Note that `in_values` will not be useful for `Case` or `DFB`-rooted Hugrs, /// but should handle other containers.) /// The context passed in allows interpretation of leaf operations. - pub fn run( + pub fn run>( mut self, context: C, in_values: impl IntoIterator)>, ) -> AnalysisResults { let root = context.root(); self.0 - .in_wire_value_proto .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); - self.0.context.push((context,)); - self.0.run(); - let results = self - .0 + let results = run_datalog(self.0, &context); + let out_wire_values = results .out_wire_value .iter() - .map(|(_, n, p, v)| (Wire::new(*n, *p), v.clone())) + .map(|(n, p, v)| (Wire::new(*n, *p), v.clone())) .collect(); - AnalysisResults(self.0, results) + AnalysisResults { + context, + results, + out_wire_values, + } } } impl> AnalysisResults { fn context(&self) -> &C { - let (c,) = self.0.context.iter().exactly_one().ok().unwrap(); - c + &self.context } /// Gets the lattice value computed for the given wire pub fn read_out_wire(&self, w: Wire) -> Option> { - self.1.get(&w).cloned() + self.out_wire_values.get(&w).cloned() } /// Tells whether a [TailLoop] node can terminate, i.e. whether @@ -82,10 +83,10 @@ impl> AnalysisResults { hugr.get_optype(node).as_tail_loop()?; let [_, out] = hugr.get_io(node).unwrap(); Some(TailLoopTermination::from_control_value( - self.0 + self.results .in_wire_value .iter() - .find_map(|(_, n, p, v)| (*n == out && p.index() == 0).then_some(v)) + .find_map(|(n, p, v)| (*n == out && p.index() == 0).then_some(v)) .unwrap(), )) } @@ -103,10 +104,10 @@ impl> AnalysisResults { let cond = hugr.get_parent(case)?; hugr.get_optype(cond).as_conditional()?; Some( - self.0 + self.results .case_reachable .iter() - .any(|(_, cond2, case2)| &cond == cond2 && &case == case2), + .any(|(cond2, case2)| &cond == cond2 && &case == case2), ) } @@ -125,10 +126,10 @@ impl> AnalysisResults { return None; }; Some( - self.0 + self.results .bb_reachable .iter() - .any(|(_, cfg2, bb2)| *cfg2 == cfg && *bb2 == bb), + .any(|(cfg2, bb2)| *cfg2 == cfg && *bb2 == bb), ) } } From 7c02d41ae6055cd6082f425f36d21270585cf88c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 15:40:50 +0100 Subject: [PATCH 138/203] DFContext does not Deref, pass Hugr separately --- hugr-passes/src/dataflow.rs | 5 +-- hugr-passes/src/dataflow/datalog.rs | 70 +++++++++++++++-------------- hugr-passes/src/dataflow/machine.rs | 46 +++++++++---------- hugr-passes/src/dataflow/test.rs | 69 +++++----------------------- 4 files changed, 70 insertions(+), 120 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index e04b4c2a8..e24960988 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -12,11 +12,10 @@ pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; use hugr_core::ops::{ExtensionOp, Value}; use hugr_core::{Hugr, Node}; -use std::hash::Hash; /// Clients of the dataflow framework (particular analyses, such as constant folding) /// must implement this trait (including providing an appropriate domain type `V`). -pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { +pub trait DFContext { /// Given lattice values for each input, update lattice values for the (dataflow) outputs. /// For extension ops only, excluding [MakeTuple] and [UnpackTuple]. /// `_outs` is an array with one element per dataflow output, each initialized to [PartialValue::Top] @@ -61,7 +60,7 @@ pub trait DFContext: Clone + Eq + Hash + std::ops::Deref { } fn traverse_value( - s: &impl DFContext, + s: &(impl DFContext + ?Sized), n: Node, fields: &mut Vec, cst: &Value, diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 5cb9afa32..eb00e1c6a 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -34,9 +34,10 @@ pub(super) struct DatalogResults { pub bb_reachable: Vec<(Node, Node)>, } -pub(super) fn run_datalog>( +pub(super) fn run_datalog( in_wire_value_proto: Vec<(Node, IncomingPort, PV)>, - c: &C, + c: &impl DFContext, + hugr: &impl HugrView, ) -> DatalogResults { let all_results = ascent::ascent_run! { pub(super) struct AscentProgram; @@ -49,47 +50,47 @@ pub(super) fn run_datalog>( lattice in_wire_value(Node, IncomingPort, PV); lattice node_in_value_row(Node, ValueRow); - node(n) <-- for n in c.nodes(); + node(n) <-- for n in hugr.nodes(); - in_wire(n, p) <-- node(n), for (p,_) in c.in_value_types(*n); // Note, gets connected inports only - out_wire(n, p) <-- node(n), for (p,_) in c.out_value_types(*n); // (and likewise) + in_wire(n, p) <-- node(n), for (p,_) in hugr.in_value_types(*n); // Note, gets connected inports only + out_wire(n, p) <-- node(n), for (p,_) in hugr.out_value_types(*n); // (and likewise) parent_of_node(parent, child) <-- - node(child), if let Some(parent) = c.get_parent(*child); + node(child), if let Some(parent) = hugr.get_parent(*child); io_node(parent, child, io) <-- node(parent), - if let Some([i, o]) = c.get_io(*parent), + if let Some([i, o]) = hugr.get_io(*parent), for (child,io) in [(i,IO::Input),(o,IO::Output)]; // Initialize all wires to bottom out_wire_value(n, p, PV::bottom()) <-- out_wire(n, p); in_wire_value(n, ip, v) <-- in_wire(n, ip), - if let Some((m, op)) = c.single_linked_output(*n, *ip), + if let Some((m, op)) = hugr.single_linked_output(*n, *ip), out_wire_value(m, op, v); // We support prepopulating in_wire_value via in_wire_value_proto. in_wire_value(n, p, PV::bottom()) <-- in_wire(n, p); in_wire_value(n, p, v) <-- for (n, p, v) in in_wire_value_proto.iter(), node(n), - if let Some(sig) = c.signature(*n), + if let Some(sig) = hugr.signature(*n), if sig.input_ports().contains(p); - node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = c.signature(*n); - node_in_value_row(n, ValueRow::single_known(c.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(n, p, v); + node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = hugr.signature(*n); + node_in_value_row(n, ValueRow::single_known(hugr.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(n, p, v); out_wire_value(n, p, v) <-- node(n), - let op_t = c.get_optype(*n), + let op_t = hugr.get_optype(*n), if !op_t.is_container(), if let Some(sig) = op_t.dataflow_signature(), node_in_value_row(n, vs), - if let Some(outs) = propagate_leaf_op(c, *n, &vs[..], sig.output_count()), + if let Some(outs) = propagate_leaf_op(c, hugr, *n, &vs[..], sig.output_count()), for (p, v) in (0..).map(OutgoingPort::from).zip(outs); // DFG relation dfg_node(Node); - dfg_node(n) <-- node(n), if c.get_optype(*n).is_dfg(); + dfg_node(n) <-- node(n), if hugr.get_optype(*n).is_dfg(); out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), io_node(dfg, i, IO::Input), in_wire_value(dfg, p, v); @@ -102,13 +103,13 @@ pub(super) fn run_datalog>( // inputs of tail loop propagate to Input node of child region out_wire_value(i, OutgoingPort::from(p.index()), v) <-- node(tl), - if c.get_optype(*tl).is_tail_loop(), + if hugr.get_optype(*tl).is_tail_loop(), io_node(tl, i, IO::Input), in_wire_value(tl, p, v); // Output node of child region propagate to Input node of child region out_wire_value(in_n, OutgoingPort::from(out_p), v) <-- node(tl), - if let Some(tailloop) = c.get_optype(*tl).as_tail_loop(), + if let Some(tailloop) = hugr.get_optype(*tl).as_tail_loop(), io_node(tl, in_n, IO::Input), io_node(tl, out_n, IO::Output), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node @@ -118,7 +119,7 @@ pub(super) fn run_datalog>( // Output node of child region propagate to outputs of tail loop out_wire_value(tl, OutgoingPort::from(out_p), v) <-- node(tl), - if let Some(tailloop) = c.get_optype(*tl).as_tail_loop(), + if let Some(tailloop) = hugr.get_optype(*tl).as_tail_loop(), io_node(tl, out_n, IO::Output), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 @@ -128,17 +129,17 @@ pub(super) fn run_datalog>( relation conditional_node(Node); relation case_node(Node, usize, Node); - conditional_node(n)<-- node(n), if c.get_optype(*n).is_conditional(); + conditional_node(n)<-- node(n), if hugr.get_optype(*n).is_conditional(); case_node(cond, i, case) <-- conditional_node(cond), - for (i, case) in c.children(*cond).enumerate(), - if c.get_optype(case).is_case(); + for (i, case) in hugr.children(*cond).enumerate(), + if hugr.get_optype(case).is_case(); // inputs of conditional propagate into case nodes out_wire_value(i_node, OutgoingPort::from(out_p), v) <-- case_node(cond, case_index, case), io_node(case, i_node, IO::Input), node_in_value_row(cond, in_row), - let conditional = c.get_optype(*cond).as_conditional().unwrap(), + let conditional = hugr.get_optype(*cond).as_conditional().unwrap(), if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), for (out_p, v) in fields.enumerate(); @@ -157,28 +158,28 @@ pub(super) fn run_datalog>( // CFG relation cfg_node(Node); relation dfb_block(Node, Node); - cfg_node(n) <-- node(n), if c.get_optype(*n).is_cfg(); - dfb_block(cfg, blk) <-- cfg_node(cfg), for blk in c.children(*cfg), if c.get_optype(blk).is_dataflow_block(); + cfg_node(n) <-- node(n), if hugr.get_optype(*n).is_cfg(); + dfb_block(cfg, blk) <-- cfg_node(cfg), for blk in hugr.children(*cfg), if hugr.get_optype(blk).is_dataflow_block(); // Reachability relation bb_reachable(Node, Node); - bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = c.children(*cfg).next(); + bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = hugr.children(*cfg).next(); bb_reachable(cfg, bb) <-- cfg_node(cfg), bb_reachable(cfg, pred), io_node(pred, pred_out, IO::Output), in_wire_value(pred_out, IncomingPort::from(0), predicate), - for (tag, bb) in c.output_neighbours(*pred).enumerate(), + for (tag, bb) in hugr.output_neighbours(*pred).enumerate(), if predicate.supports_tag(tag); // Where do the values "fed" along a control-flow edge come out? relation _cfg_succ_dest(Node, Node, Node); _cfg_succ_dest(cfg, blk, inp) <-- dfb_block(cfg, blk), io_node(blk, inp, IO::Input); - _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = c.children(*cfg).nth(1); + _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = hugr.children(*cfg).nth(1); // Inputs of CFG propagate to entry block out_wire_value(i_node, OutgoingPort::from(p.index()), v) <-- cfg_node(cfg), - if let Some(entry) = c.children(*cfg).next(), + if let Some(entry) = hugr.children(*cfg).next(), io_node(entry, i_node, IO::Input), in_wire_value(cfg, p, v); @@ -186,8 +187,8 @@ pub(super) fn run_datalog>( out_wire_value(dest, OutgoingPort::from(out_p), v) <-- dfb_block(cfg, pred), bb_reachable(cfg, pred), - let df_block = c.get_optype(*pred).as_dataflow_block().unwrap(), - for (succ_n, succ) in c.output_neighbours(*pred).enumerate(), + let df_block = hugr.get_optype(*pred).as_dataflow_block().unwrap(), + for (succ_n, succ) in hugr.output_neighbours(*pred).enumerate(), io_node(pred, out_n, IO::Output), _cfg_succ_dest(cfg, succ, dest), node_in_value_row(out_n, out_in_row), @@ -198,8 +199,8 @@ pub(super) fn run_datalog>( relation func_call(Node, Node); func_call(call, func_defn) <-- node(call), - if c.get_optype(*call).is_call(), - if let Some(func_defn) = c.static_source(*call); + if hugr.get_optype(*call).is_call(), + if let Some(func_defn) = hugr.static_source(*call); out_wire_value(inp, OutgoingPort::from(p.index()), v) <-- func_call(call, func), @@ -220,11 +221,12 @@ pub(super) fn run_datalog>( } fn propagate_leaf_op( c: &impl DFContext, + hugr: &impl HugrView, n: Node, ins: &[PV], num_outs: usize, ) -> Option> { - match c.get_optype(n) { + match hugr.get_optype(n) { // Handle basics here. I guess (given the current interface) we could allow // DFContext to handle these but at the least we'd want these impls to be // easily available for reuse. @@ -247,11 +249,11 @@ fn propagate_leaf_op( OpType::Const(_) => None, // handled by LoadConstant: OpType::LoadConstant(load_op) => { assert!(ins.is_empty()); // static edge, so need to find constant - let const_node = c + let const_node = hugr .single_linked_output(n, load_op.constant_port()) .unwrap() .0; - let const_val = c.get_optype(const_node).as_const().unwrap().value(); + let const_val = hugr.get_optype(const_node).as_const().unwrap().value(); Some(ValueRow::single_known( 1, 0, diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/machine.rs index f93384f7d..f7f97463d 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/machine.rs @@ -13,8 +13,8 @@ use super::{AbstractValue, DFContext, PartialValue}; pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); /// Results of a dataflow analysis, packaged with context for easy inspection -pub struct AnalysisResults> { - context: C, +pub struct AnalysisResults { + hugr: H, results: DatalogResults, out_wire_values: HashMap>, } @@ -41,33 +41,30 @@ impl Machine { /// (Note that `in_values` will not be useful for `Case` or `DFB`-rooted Hugrs, /// but should handle other containers.) /// The context passed in allows interpretation of leaf operations. - pub fn run>( + pub fn run( mut self, - context: C, + context: &impl DFContext, + hugr: H, in_values: impl IntoIterator)>, - ) -> AnalysisResults { - let root = context.root(); + ) -> AnalysisResults { + let root = hugr.root(); self.0 .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); - let results = run_datalog(self.0, &context); + let results = run_datalog(self.0, context, &hugr); let out_wire_values = results .out_wire_value .iter() .map(|(n, p, v)| (Wire::new(*n, *p), v.clone())) .collect(); AnalysisResults { - context, + hugr, results, out_wire_values, } } } -impl> AnalysisResults { - fn context(&self) -> &C { - &self.context - } - +impl AnalysisResults { /// Gets the lattice value computed for the given wire pub fn read_out_wire(&self, w: Wire) -> Option> { self.out_wire_values.get(&w).cloned() @@ -79,9 +76,8 @@ impl> AnalysisResults { /// /// [TailLoop]: hugr_core::ops::TailLoop pub fn tail_loop_terminates(&self, node: Node) -> Option { - let hugr = self.context(); - hugr.get_optype(node).as_tail_loop()?; - let [_, out] = hugr.get_io(node).unwrap(); + self.hugr.get_optype(node).as_tail_loop()?; + let [_, out] = self.hugr.get_io(node).unwrap(); Some(TailLoopTermination::from_control_value( self.results .in_wire_value @@ -99,10 +95,9 @@ impl> AnalysisResults { /// [Case]: hugr_core::ops::Case /// [Conditional]: hugr_core::ops::Conditional pub fn case_reachable(&self, case: Node) -> Option { - let hugr = self.context(); - hugr.get_optype(case).as_case()?; - let cond = hugr.get_parent(case)?; - hugr.get_optype(cond).as_conditional()?; + self.hugr.get_optype(case).as_case()?; + let cond = self.hugr.get_parent(case)?; + self.hugr.get_optype(cond).as_conditional()?; Some( self.results .case_reachable @@ -118,10 +113,9 @@ impl> AnalysisResults { /// [DataflowBlock]: hugr_core::ops::DataflowBlock /// [ExitBlock]: hugr_core::ops::ExitBlock pub fn bb_reachable(&self, bb: Node) -> Option { - let hugr = self.context(); - let cfg = hugr.get_parent(bb)?; // Not really required...?? - hugr.get_optype(cfg).as_cfg()?; - let t = hugr.get_optype(bb); + let cfg = self.hugr.get_parent(bb)?; // Not really required...?? + self.hugr.get_optype(cfg).as_cfg()?; + let t = self.hugr.get_optype(bb); if !t.is_dataflow_block() && !t.is_exit_block() { return None; }; @@ -134,7 +128,7 @@ impl> AnalysisResults { } } -impl> AnalysisResults +impl AnalysisResults where Value: From, { @@ -152,7 +146,7 @@ where pub fn try_read_wire_value(&self, w: Wire) -> Result> { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self - .context() + .hugr .out_value_types(w.node()) .find(|(p, _)| *p == w.source()) .unwrap(); diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 74d504b74..fe4af1c46 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -1,6 +1,3 @@ -use std::hash::{Hash, Hasher}; -use std::sync::Arc; - use ascent::{lattice::BoundedLattice, Lattice}; use hugr_core::builder::{CFGBuilder, Container, DataflowHugr}; @@ -26,49 +23,9 @@ enum Void {} impl AbstractValue for Void {} -struct TestContext(Arc); - -// Deriving Clone requires H:HugrView to implement Clone, -// but we don't need that as we only clone the Arc. -impl Clone for TestContext { - fn clone(&self) -> Self { - Self(self.0.clone()) - } -} - -impl std::ops::Deref for TestContext { - type Target = hugr_core::Hugr; - - fn deref(&self) -> &Self::Target { - self.0.base_hugr() - } -} - -// Any value used in an Ascent program must be hashable. -// However, there should only be one DFContext, so its hash is immaterial. -impl Hash for TestContext { - fn hash(&self, _state: &mut I) {} -} - -impl PartialEq for TestContext { - fn eq(&self, other: &Self) -> bool { - // Any AscentProgram should have only one DFContext (maybe cloned) - assert!(Arc::ptr_eq(&self.0, &other.0)); - true - } -} - -impl Eq for TestContext {} - -impl PartialOrd for TestContext { - fn partial_cmp(&self, other: &Self) -> Option { - // Any AscentProgram should have only one DFContext (maybe cloned) - assert!(Arc::ptr_eq(&self.0, &other.0)); - Some(std::cmp::Ordering::Equal) - } -} +struct TestContext; -impl DFContext for TestContext {} +impl DFContext for TestContext {} // This allows testing creation of tuple/sum Values (only) impl From for Value { @@ -97,7 +54,7 @@ fn test_make_tuple() { let v3 = builder.make_tuple([v1, v2]).unwrap(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(&TestContext, &hugr, []); let x = results.try_read_wire_value(v3).unwrap(); assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); @@ -113,7 +70,7 @@ fn test_unpack_tuple_const() { .outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(&TestContext, &hugr, []); let o1_r = results.try_read_wire_value(o1).unwrap(); assert_eq!(o1_r, Value::false_val()); @@ -136,7 +93,7 @@ fn test_tail_loop_never_iterates() { let [tl_o] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(&TestContext, &hugr, []); let o_r = results.try_read_wire_value(tl_o).unwrap(); assert_eq!(o_r, r_v); @@ -165,7 +122,7 @@ fn test_tail_loop_always_iterates() { let [tl_o1, tl_o2] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(&TestContext, &hugr, []); let o_r1 = results.read_out_wire(tl_o1).unwrap(); assert_eq!(o_r1, PartialValue::bottom()); @@ -203,7 +160,7 @@ fn test_tail_loop_two_iters() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let [o_w1, o_w2] = tail_loop.outputs_arr(); - let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(&TestContext, &hugr, []); let o_r1 = results.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true_or_false()); @@ -269,7 +226,7 @@ fn test_tail_loop_containing_conditional() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let [o_w1, o_w2] = tail_loop.outputs_arr(); - let results = Machine::default().run(TestContext(Arc::new(&hugr)), []); + let results = Machine::default().run(&TestContext, &hugr, []); let o_r1 = results.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true()); @@ -321,7 +278,7 @@ fn test_conditional() { 2, [PartialValue::new_variant(0, [])], )); - let results = Machine::default().run(TestContext(Arc::new(&hugr)), [(0.into(), arg_pv)]); + let results = Machine::default().run(&TestContext, &hugr, [(0.into(), arg_pv)]); let cond_r1 = results.try_read_wire_value(cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); @@ -445,7 +402,8 @@ fn test_cfg( ) { let root = xor_and_cfg.root(); let results = Machine::default().run( - TestContext(Arc::new(xor_and_cfg)), + &TestContext, + &xor_and_cfg, [(0.into(), inp0), (1.into(), inp1)], ); @@ -481,10 +439,7 @@ fn test_call( .finish_hugr_with_outputs([a2, b2], &EMPTY_REG) .unwrap(); - let results = Machine::default().run( - TestContext(Arc::new(&hugr)), - [(0.into(), inp0), (1.into(), inp1)], - ); + let results = Machine::default().run(&TestContext, &hugr, [(0.into(), inp0), (1.into(), inp1)]); let [res0, res1] = [0, 1].map(|i| results.read_out_wire(Wire::new(hugr.root(), i)).unwrap()); // The two calls alias so both results will be the same: From 3d4f01675013efa7361b325769b501329002194d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 15:44:59 +0100 Subject: [PATCH 139/203] Massively reduce scope of clippy-allow to inside run_datalog --- hugr-passes/src/dataflow/datalog.rs | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index eb00e1c6a..b1c341dc4 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -1,11 +1,4 @@ //! [ascent] datalog implementation of analysis. -//! Since ascent-(macro-)generated code generates a bunch of warnings, -//! keep code in here to a minimum. -#![allow( - clippy::clone_on_copy, - clippy::unused_enumerate_index, - clippy::collapsible_if -)] use ascent::lattice::{BoundedLattice, Lattice}; use itertools::{zip_eq, Itertools}; @@ -39,6 +32,13 @@ pub(super) fn run_datalog( c: &impl DFContext, hugr: &impl HugrView, ) -> DatalogResults { + // ascent-(macro-)generated code generates a bunch of warnings, + // keep code in here to a minimum. + #![allow( + clippy::clone_on_copy, + clippy::unused_enumerate_index, + clippy::collapsible_if + )] let all_results = ascent::ascent_run! { pub(super) struct AscentProgram; relation node(Node); @@ -145,7 +145,7 @@ pub(super) fn run_datalog( // outputs of case nodes propagate to outputs of conditional *if* case reachable out_wire_value(cond, OutgoingPort::from(o_p.index()), v) <-- - case_node(cond, _, case), + case_node(cond, _i, case), case_reachable(cond, case), io_node(case, o, IO::Output), in_wire_value(o, o_p, v); @@ -219,6 +219,7 @@ pub(super) fn run_datalog( bb_reachable: all_results.bb_reachable, } } + fn propagate_leaf_op( c: &impl DFContext, hugr: &impl HugrView, From 8bab4d5e099af0159caa0b5f5b1ed46dfa42d95f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 15:51:24 +0100 Subject: [PATCH 140/203] Move ValueRow into own file --- hugr-passes/src/dataflow.rs | 4 +- hugr-passes/src/dataflow/datalog.rs | 95 +----------------------- hugr-passes/src/dataflow/value_row.rs | 101 ++++++++++++++++++++++++++ 3 files changed, 106 insertions(+), 94 deletions(-) create mode 100644 hugr-passes/src/dataflow/value_row.rs diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index e24960988..f3e763feb 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -2,15 +2,15 @@ //! Dataflow analysis of Hugrs. mod datalog; +mod value_row; mod machine; -use hugr_core::ops::constant::CustomConst; pub use machine::{AnalysisResults, Machine, TailLoopTermination}; mod partial_value; pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; -use hugr_core::ops::{ExtensionOp, Value}; +use hugr_core::ops::{constant::CustomConst, ExtensionOp, Value}; use hugr_core::{Hugr, Node}; /// Clients of the dataflow framework (particular analyses, such as constant folding) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index b1c341dc4..143638ae2 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -1,15 +1,14 @@ //! [ascent] datalog implementation of analysis. -use ascent::lattice::{BoundedLattice, Lattice}; -use itertools::{zip_eq, Itertools}; -use std::cmp::Ordering; +use ascent::lattice::BoundedLattice; +use itertools::Itertools; use std::hash::Hash; -use std::ops::{Index, IndexMut}; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use hugr_core::ops::{OpTrait, OpType}; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _}; +use super::value_row::ValueRow; use super::{AbstractValue, DFContext, PartialValue}; type PV = PartialValue; @@ -273,91 +272,3 @@ fn propagate_leaf_op( o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive" } } - -// Wrap a (known-length) row of values into a lattice. -#[derive(PartialEq, Clone, Eq, Hash)] -struct ValueRow(Vec>); - -impl ValueRow { - fn new(len: usize) -> Self { - Self(vec![PartialValue::bottom(); len]) - } - - fn single_known(len: usize, idx: usize, v: PartialValue) -> Self { - assert!(idx < len); - let mut r = Self::new(len); - r.0[idx] = v; - r - } - - pub fn unpack_first( - &self, - variant: usize, - len: usize, - ) -> Option>> { - let vals = self[0].variant_values(variant, len)?; - Some(vals.into_iter().chain(self.0[1..].to_owned())) - } -} - -impl FromIterator> for ValueRow { - fn from_iter>>(iter: T) -> Self { - Self(iter.into_iter().collect()) - } -} - -impl PartialOrd for ValueRow { - fn partial_cmp(&self, other: &Self) -> Option { - self.0.partial_cmp(&other.0) - } -} - -impl Lattice for ValueRow { - fn join_mut(&mut self, other: Self) -> bool { - assert_eq!(self.0.len(), other.0.len()); - let mut changed = false; - for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { - changed |= v1.join_mut(v2); - } - changed - } - - fn meet_mut(&mut self, other: Self) -> bool { - assert_eq!(self.0.len(), other.0.len()); - let mut changed = false; - for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { - changed |= v1.meet_mut(v2); - } - changed - } -} - -impl IntoIterator for ValueRow { - type Item = PartialValue; - - type IntoIter = > as IntoIterator>::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() - } -} - -impl Index for ValueRow -where - Vec>: Index, -{ - type Output = > as Index>::Output; - - fn index(&self, index: Idx) -> &Self::Output { - self.0.index(index) - } -} - -impl IndexMut for ValueRow -where - Vec>: IndexMut, -{ - fn index_mut(&mut self, index: Idx) -> &mut Self::Output { - self.0.index_mut(index) - } -} diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs new file mode 100644 index 000000000..ebdbf1b75 --- /dev/null +++ b/hugr-passes/src/dataflow/value_row.rs @@ -0,0 +1,101 @@ +// Wrap a (known-length) row of values into a lattice. + +use std::{ + cmp::Ordering, + ops::{Index, IndexMut}, +}; + +use ascent::{lattice::BoundedLattice, Lattice}; +use itertools::zip_eq; + +use super::{AbstractValue, PartialValue}; + +#[derive(PartialEq, Clone, Eq, Hash)] +pub(super) struct ValueRow(Vec>); + +impl ValueRow { + pub fn new(len: usize) -> Self { + Self(vec![PartialValue::bottom(); len]) + } + + pub fn single_known(len: usize, idx: usize, v: PartialValue) -> Self { + assert!(idx < len); + let mut r = Self::new(len); + r.0[idx] = v; + r + } + + /// The first value in this ValueRow must be a sum; + /// returns a new ValueRow given by unpacking the elements of the specified variant of said first value, + /// then appending the rest of the values in this row. + pub fn unpack_first( + &self, + variant: usize, + len: usize, + ) -> Option>> { + let vals = self[0].variant_values(variant, len)?; + Some(vals.into_iter().chain(self.0[1..].to_owned())) + } +} + +impl FromIterator> for ValueRow { + fn from_iter>>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +impl PartialOrd for ValueRow { + fn partial_cmp(&self, other: &Self) -> Option { + self.0.partial_cmp(&other.0) + } +} + +impl Lattice for ValueRow { + fn join_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.join_mut(v2); + } + changed + } + + fn meet_mut(&mut self, other: Self) -> bool { + assert_eq!(self.0.len(), other.0.len()); + let mut changed = false; + for (v1, v2) in zip_eq(self.0.iter_mut(), other.0.into_iter()) { + changed |= v1.meet_mut(v2); + } + changed + } +} + +impl IntoIterator for ValueRow { + type Item = PartialValue; + + type IntoIter = > as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl Index for ValueRow +where + Vec>: Index, +{ + type Output = > as Index>::Output; + + fn index(&self, index: Idx) -> &Self::Output { + self.0.index(index) + } +} + +impl IndexMut for ValueRow +where + Vec>: IndexMut, +{ + fn index_mut(&mut self, index: Idx) -> &mut Self::Output { + self.0.index_mut(index) + } +} From 3eccadfbd733f10a7c1db0edb469964b10bbf0c4 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 15:59:17 +0100 Subject: [PATCH 141/203] Move Machine into datalog.rs, pub(super) fields in AnalysisResults, rm DatalogResults --- hugr-passes/src/dataflow.rs | 3 +- hugr-passes/src/dataflow/datalog.rs | 67 +++++++++++++---- .../src/dataflow/{machine.rs => results.rs} | 72 +++---------------- 3 files changed, 65 insertions(+), 77 deletions(-) rename hugr-passes/src/dataflow/{machine.rs => results.rs} (65%) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index f3e763feb..cce6ea97d 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -2,10 +2,11 @@ //! Dataflow analysis of Hugrs. mod datalog; +pub use datalog::Machine; mod value_row; mod machine; -pub use machine::{AnalysisResults, Machine, TailLoopTermination}; +pub use machine::{AnalysisResults, TailLoopTermination}; mod partial_value; pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 143638ae2..0421567d0 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -6,10 +6,10 @@ use std::hash::Hash; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use hugr_core::ops::{OpTrait, OpType}; -use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _}; +use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; -use super::{AbstractValue, DFContext, PartialValue}; +use super::{AbstractValue, AnalysisResults, DFContext, PartialValue}; type PV = PartialValue; @@ -19,18 +19,53 @@ pub enum IO { Output, } -pub(super) struct DatalogResults { - pub in_wire_value: Vec<(Node, IncomingPort, PV)>, - pub out_wire_value: Vec<(Node, OutgoingPort, PV)>, - pub case_reachable: Vec<(Node, Node)>, - pub bb_reachable: Vec<(Node, Node)>, +/// Basic structure for performing an analysis. Usage: +/// 1. Get a new instance via [Self::default()] +/// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values +/// 3. Call [Self::run] to produce [AnalysisResults] which can be inspected via +/// [read_out_wire](AnalysisResults::read_out_wire) +pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); + +/// derived-Default requires the context to be Defaultable, which is unnecessary +impl Default for Machine { + fn default() -> Self { + Self(Default::default()) + } } -pub(super) fn run_datalog( +impl Machine { + /// Provide initial values for some wires. + // Likely for test purposes only - should we make non-pub or #[cfg(test)] ? + pub fn prepopulate_wire(&mut self, h: &impl HugrView, wire: Wire, value: PartialValue) { + self.0.extend( + h.linked_inputs(wire.node(), wire.source()) + .map(|(n, inp)| (n, inp, value.clone())), + ); + } + + /// Run the analysis (iterate until a lattice fixpoint is reached), + /// given initial values for some of the root node inputs. + /// (Note that `in_values` will not be useful for `Case` or `DFB`-rooted Hugrs, + /// but should handle other containers.) + /// The context passed in allows interpretation of leaf operations. + pub fn run( + mut self, + context: &impl DFContext, + hugr: H, + in_values: impl IntoIterator)>, + ) -> AnalysisResults { + let root = hugr.root(); + self.0 + .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); + run_datalog(self.0, context, hugr) + } +} + +pub(super) fn run_datalog( in_wire_value_proto: Vec<(Node, IncomingPort, PV)>, c: &impl DFContext, - hugr: &impl HugrView, -) -> DatalogResults { + hugr: H, +) -> AnalysisResults { // ascent-(macro-)generated code generates a bunch of warnings, // keep code in here to a minimum. #![allow( @@ -84,7 +119,7 @@ pub(super) fn run_datalog( if !op_t.is_container(), if let Some(sig) = op_t.dataflow_signature(), node_in_value_row(n, vs), - if let Some(outs) = propagate_leaf_op(c, hugr, *n, &vs[..], sig.output_count()), + if let Some(outs) = propagate_leaf_op(c, &hugr, *n, &vs[..], sig.output_count()), for (p, v) in (0..).map(OutgoingPort::from).zip(outs); // DFG @@ -211,9 +246,15 @@ pub(super) fn run_datalog( io_node(func, outp, IO::Output), in_wire_value(outp, p, v); }; - DatalogResults { + let out_wire_values = all_results + .out_wire_value + .iter() + .map(|(n, p, v)| (Wire::new(*n, *p), v.clone())) + .collect(); + AnalysisResults { + hugr, + out_wire_values, in_wire_value: all_results.in_wire_value, - out_wire_value: all_results.out_wire_value, case_reachable: all_results.case_reachable, bb_reachable: all_results.bb_reachable, } diff --git a/hugr-passes/src/dataflow/machine.rs b/hugr-passes/src/dataflow/results.rs similarity index 65% rename from hugr-passes/src/dataflow/machine.rs rename to hugr-passes/src/dataflow/results.rs index f7f97463d..0be8072b0 100644 --- a/hugr-passes/src/dataflow/machine.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -2,66 +2,15 @@ use std::collections::HashMap; use hugr_core::{ops::Value, types::ConstTypeError, HugrView, IncomingPort, Node, PortIndex, Wire}; -use super::datalog::{run_datalog, DatalogResults}; -use super::{AbstractValue, DFContext, PartialValue}; - -/// Basic structure for performing an analysis. Usage: -/// 1. Get a new instance via [Self::default()] -/// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values -/// 3. Call [Self::run] to produce [AnalysisResults] which can be inspected via -/// [read_out_wire](AnalysisResults::read_out_wire) -pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); +use super::{AbstractValue, PartialValue}; /// Results of a dataflow analysis, packaged with context for easy inspection pub struct AnalysisResults { - hugr: H, - results: DatalogResults, - out_wire_values: HashMap>, -} - -/// derived-Default requires the context to be Defaultable, which is unnecessary -impl Default for Machine { - fn default() -> Self { - Self(Default::default()) - } -} - -impl Machine { - /// Provide initial values for some wires. - // Likely for test purposes only - should we make non-pub or #[cfg(test)] ? - pub fn prepopulate_wire(&mut self, h: &impl HugrView, wire: Wire, value: PartialValue) { - self.0.extend( - h.linked_inputs(wire.node(), wire.source()) - .map(|(n, inp)| (n, inp, value.clone())), - ); - } - - /// Run the analysis (iterate until a lattice fixpoint is reached), - /// given initial values for some of the root node inputs. - /// (Note that `in_values` will not be useful for `Case` or `DFB`-rooted Hugrs, - /// but should handle other containers.) - /// The context passed in allows interpretation of leaf operations. - pub fn run( - mut self, - context: &impl DFContext, - hugr: H, - in_values: impl IntoIterator)>, - ) -> AnalysisResults { - let root = hugr.root(); - self.0 - .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); - let results = run_datalog(self.0, context, &hugr); - let out_wire_values = results - .out_wire_value - .iter() - .map(|(n, p, v)| (Wire::new(*n, *p), v.clone())) - .collect(); - AnalysisResults { - hugr, - results, - out_wire_values, - } - } + pub(super) hugr: H, + pub(super) in_wire_value: Vec<(Node, IncomingPort, PartialValue)>, + pub(super) case_reachable: Vec<(Node, Node)>, + pub(super) bb_reachable: Vec<(Node, Node)>, + pub(super) out_wire_values: HashMap>, } impl AnalysisResults { @@ -79,8 +28,7 @@ impl AnalysisResults { self.hugr.get_optype(node).as_tail_loop()?; let [_, out] = self.hugr.get_io(node).unwrap(); Some(TailLoopTermination::from_control_value( - self.results - .in_wire_value + self.in_wire_value .iter() .find_map(|(n, p, v)| (*n == out && p.index() == 0).then_some(v)) .unwrap(), @@ -99,8 +47,7 @@ impl AnalysisResults { let cond = self.hugr.get_parent(case)?; self.hugr.get_optype(cond).as_conditional()?; Some( - self.results - .case_reachable + self.case_reachable .iter() .any(|(cond2, case2)| &cond == cond2 && &case == case2), ) @@ -120,8 +67,7 @@ impl AnalysisResults { return None; }; Some( - self.results - .bb_reachable + self.bb_reachable .iter() .any(|(cfg2, bb2)| *cfg2 == cfg && *bb2 == bb), ) From 0f4fa522643777fc65cfdfcd3bc0a253e7204272 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 16:02:03 +0100 Subject: [PATCH 142/203] Move machine.rs to results.rs --- hugr-passes/src/dataflow.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index cce6ea97d..5cd0fa7de 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -5,8 +5,8 @@ mod datalog; pub use datalog::Machine; mod value_row; -mod machine; -pub use machine::{AnalysisResults, TailLoopTermination}; +mod results; +pub use results::{AnalysisResults, TailLoopTermination}; mod partial_value; pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; From caa8acaa89bfece8bbf00751a03a3ae9f5140e29 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 16:08:41 +0100 Subject: [PATCH 143/203] Remove enum IO, replace io_node -> input_child/output_child --- hugr-passes/src/dataflow/datalog.rs | 44 ++++++++++++----------------- 1 file changed, 18 insertions(+), 26 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 0421567d0..b8b7f18bf 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -2,7 +2,6 @@ use ascent::lattice::BoundedLattice; use itertools::Itertools; -use std::hash::Hash; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; use hugr_core::ops::{OpTrait, OpType}; @@ -13,12 +12,6 @@ use super::{AbstractValue, AnalysisResults, DFContext, PartialValue}; type PV = PartialValue; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum IO { - Input, - Output, -} - /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] /// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values @@ -79,7 +72,8 @@ pub(super) fn run_datalog( relation in_wire(Node, IncomingPort); relation out_wire(Node, OutgoingPort); relation parent_of_node(Node, Node); - relation io_node(Node, Node, IO); + relation input_child(Node, Node); + relation output_child(Node, Node); lattice out_wire_value(Node, OutgoingPort, PV); lattice in_wire_value(Node, IncomingPort, PV); lattice node_in_value_row(Node, ValueRow); @@ -92,9 +86,8 @@ pub(super) fn run_datalog( parent_of_node(parent, child) <-- node(child), if let Some(parent) = hugr.get_parent(*child); - io_node(parent, child, io) <-- node(parent), - if let Some([i, o]) = hugr.get_io(*parent), - for (child,io) in [(i,IO::Input),(o,IO::Output)]; + input_child(parent, input) <-- node(parent), if let Some([input, _output]) = hugr.get_io(*parent); + output_child(parent, output) <-- node(parent), if let Some([_input, output]) = hugr.get_io(*parent); // Initialize all wires to bottom out_wire_value(n, p, PV::bottom()) <-- out_wire(n, p); @@ -127,10 +120,10 @@ pub(super) fn run_datalog( dfg_node(n) <-- node(n), if hugr.get_optype(*n).is_dfg(); out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), - io_node(dfg, i, IO::Input), in_wire_value(dfg, p, v); + input_child(dfg, i), in_wire_value(dfg, p, v); out_wire_value(dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), - io_node(dfg, o, IO::Output), in_wire_value(o, p, v); + output_child(dfg, o), in_wire_value(o, p, v); // TailLoop @@ -138,23 +131,22 @@ pub(super) fn run_datalog( // inputs of tail loop propagate to Input node of child region out_wire_value(i, OutgoingPort::from(p.index()), v) <-- node(tl), if hugr.get_optype(*tl).is_tail_loop(), - io_node(tl, i, IO::Input), + input_child(tl, i), in_wire_value(tl, p, v); // Output node of child region propagate to Input node of child region out_wire_value(in_n, OutgoingPort::from(out_p), v) <-- node(tl), if let Some(tailloop) = hugr.get_optype(*tl).as_tail_loop(), - io_node(tl, in_n, IO::Input), - io_node(tl, out_n, IO::Output), + input_child(tl, in_n), + output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node - if let Some(fields) = out_in_row.unpack_first(0, tailloop.just_inputs.len()), // if it is possible for tag to be 0 for (out_p, v) in fields.enumerate(); // Output node of child region propagate to outputs of tail loop out_wire_value(tl, OutgoingPort::from(out_p), v) <-- node(tl), if let Some(tailloop) = hugr.get_optype(*tl).as_tail_loop(), - io_node(tl, out_n, IO::Output), + output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 for (out_p, v) in fields.enumerate(); @@ -171,7 +163,7 @@ pub(super) fn run_datalog( // inputs of conditional propagate into case nodes out_wire_value(i_node, OutgoingPort::from(out_p), v) <-- case_node(cond, case_index, case), - io_node(case, i_node, IO::Input), + input_child(case, i_node), node_in_value_row(cond, in_row), let conditional = hugr.get_optype(*cond).as_conditional().unwrap(), if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), @@ -181,7 +173,7 @@ pub(super) fn run_datalog( out_wire_value(cond, OutgoingPort::from(o_p.index()), v) <-- case_node(cond, _i, case), case_reachable(cond, case), - io_node(case, o, IO::Output), + output_child(case, o), in_wire_value(o, o_p, v); relation case_reachable(Node, Node); @@ -200,21 +192,21 @@ pub(super) fn run_datalog( bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = hugr.children(*cfg).next(); bb_reachable(cfg, bb) <-- cfg_node(cfg), bb_reachable(cfg, pred), - io_node(pred, pred_out, IO::Output), + output_child(pred, pred_out), in_wire_value(pred_out, IncomingPort::from(0), predicate), for (tag, bb) in hugr.output_neighbours(*pred).enumerate(), if predicate.supports_tag(tag); // Where do the values "fed" along a control-flow edge come out? relation _cfg_succ_dest(Node, Node, Node); - _cfg_succ_dest(cfg, blk, inp) <-- dfb_block(cfg, blk), io_node(blk, inp, IO::Input); + _cfg_succ_dest(cfg, blk, inp) <-- dfb_block(cfg, blk), input_child(blk, inp); _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = hugr.children(*cfg).nth(1); // Inputs of CFG propagate to entry block out_wire_value(i_node, OutgoingPort::from(p.index()), v) <-- cfg_node(cfg), if let Some(entry) = hugr.children(*cfg).next(), - io_node(entry, i_node, IO::Input), + input_child(entry, i_node), in_wire_value(cfg, p, v); // Outputs of each reachable block propagated to successor block or (if exit block) then CFG itself @@ -223,7 +215,7 @@ pub(super) fn run_datalog( bb_reachable(cfg, pred), let df_block = hugr.get_optype(*pred).as_dataflow_block().unwrap(), for (succ_n, succ) in hugr.output_neighbours(*pred).enumerate(), - io_node(pred, out_n, IO::Output), + output_child(pred, out_n), _cfg_succ_dest(cfg, succ, dest), node_in_value_row(out_n, out_in_row), if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), @@ -238,12 +230,12 @@ pub(super) fn run_datalog( out_wire_value(inp, OutgoingPort::from(p.index()), v) <-- func_call(call, func), - io_node(func, inp, IO::Input), + input_child(func, inp), in_wire_value(call, p, v); out_wire_value(call, OutgoingPort::from(p.index()), v) <-- func_call(call, func), - io_node(func, outp, IO::Output), + output_child(func, outp), in_wire_value(outp, p, v); }; let out_wire_values = all_results From e7f61fc9016ff12bd0295d2923850a750bbe9ef7 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 16:10:27 +0100 Subject: [PATCH 144/203] move docs --- hugr-passes/src/dataflow/datalog.rs | 3 +-- hugr-passes/src/dataflow/results.rs | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index b8b7f18bf..0d4b6be25 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -15,8 +15,7 @@ type PV = PartialValue; /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] /// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values -/// 3. Call [Self::run] to produce [AnalysisResults] which can be inspected via -/// [read_out_wire](AnalysisResults::read_out_wire) +/// 3. Call [Self::run] to produce [AnalysisResults] pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); /// derived-Default requires the context to be Defaultable, which is unnecessary diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index 0be8072b0..f457ef68c 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -4,7 +4,8 @@ use hugr_core::{ops::Value, types::ConstTypeError, HugrView, IncomingPort, Node, use super::{AbstractValue, PartialValue}; -/// Results of a dataflow analysis, packaged with context for easy inspection +/// Results of a dataflow analysis, packaged with the Hugr for easy inspection. +/// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). pub struct AnalysisResults { pub(super) hugr: H, pub(super) in_wire_value: Vec<(Node, IncomingPort, PartialValue)>, From 811802ce94149aace3afa014f9061bcfe7faa40f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 16:33:39 +0100 Subject: [PATCH 145/203] Remove/inline/dedup dfb_block --- hugr-passes/src/dataflow/datalog.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 0d4b6be25..f3f94da3f 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -182,9 +182,7 @@ pub(super) fn run_datalog( // CFG relation cfg_node(Node); - relation dfb_block(Node, Node); cfg_node(n) <-- node(n), if hugr.get_optype(*n).is_cfg(); - dfb_block(cfg, blk) <-- cfg_node(cfg), for blk in hugr.children(*cfg), if hugr.get_optype(blk).is_dataflow_block(); // Reachability relation bb_reachable(Node, Node); @@ -198,7 +196,10 @@ pub(super) fn run_datalog( // Where do the values "fed" along a control-flow edge come out? relation _cfg_succ_dest(Node, Node, Node); - _cfg_succ_dest(cfg, blk, inp) <-- dfb_block(cfg, blk), input_child(blk, inp); + _cfg_succ_dest(cfg, blk, inp) <-- cfg_node(cfg), + for blk in hugr.children(*cfg), + if hugr.get_optype(blk).is_dataflow_block(), + input_child(blk, inp); _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = hugr.children(*cfg).nth(1); // Inputs of CFG propagate to entry block @@ -210,9 +211,8 @@ pub(super) fn run_datalog( // Outputs of each reachable block propagated to successor block or (if exit block) then CFG itself out_wire_value(dest, OutgoingPort::from(out_p), v) <-- - dfb_block(cfg, pred), bb_reachable(cfg, pred), - let df_block = hugr.get_optype(*pred).as_dataflow_block().unwrap(), + if let Some(df_block) = hugr.get_optype(*pred).as_dataflow_block(), for (succ_n, succ) in hugr.output_neighbours(*pred).enumerate(), output_child(pred, out_n), _cfg_succ_dest(cfg, succ, dest), From 3713ea7e7d5916bc99b71415ecb73c275d7a8469 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 16:35:27 +0100 Subject: [PATCH 146/203] relation doc --- hugr-passes/src/dataflow/datalog.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index f3f94da3f..d0620e634 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -194,7 +194,8 @@ pub(super) fn run_datalog( for (tag, bb) in hugr.output_neighbours(*pred).enumerate(), if predicate.supports_tag(tag); - // Where do the values "fed" along a control-flow edge come out? + // Relation: in `CFG` , values fed along a control-flow edge to + // come out of Value outports of . relation _cfg_succ_dest(Node, Node, Node); _cfg_succ_dest(cfg, blk, inp) <-- cfg_node(cfg), for blk in hugr.children(*cfg), From d3178097d76ff1954507f70da17ecd4aa02da146 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 16:46:18 +0100 Subject: [PATCH 147/203] datalog docs (each relation), move _cfg_succ_dest --- hugr-passes/src/dataflow/datalog.rs | 53 +++++++++++++++-------------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index d0620e634..a2f22dbd6 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -67,15 +67,15 @@ pub(super) fn run_datalog( )] let all_results = ascent::ascent_run! { pub(super) struct AscentProgram; - relation node(Node); - relation in_wire(Node, IncomingPort); - relation out_wire(Node, OutgoingPort); - relation parent_of_node(Node, Node); - relation input_child(Node, Node); - relation output_child(Node, Node); - lattice out_wire_value(Node, OutgoingPort, PV); - lattice in_wire_value(Node, IncomingPort, PV); - lattice node_in_value_row(Node, ValueRow); + relation node(Node); // exists in the hugr + relation in_wire(Node, IncomingPort); // has an of `EdgeKind::Value` + relation out_wire(Node, OutgoingPort); // has an of `EdgeKind::Value` + relation parent_of_node(Node, Node); // is parent of + relation input_child(Node, Node); // has 1st child that is its `Input` + relation output_child(Node, Node); // has 2nd child that is its `Output` + lattice out_wire_value(Node, OutgoingPort, PV); // produces, on , the value + lattice in_wire_value(Node, IncomingPort, PV); // receives, on , the value + lattice node_in_value_row(Node, ValueRow); // 's inputs are node(n) <-- for n in hugr.nodes(); @@ -95,13 +95,14 @@ pub(super) fn run_datalog( if let Some((m, op)) = hugr.single_linked_output(*n, *ip), out_wire_value(m, op, v); - // We support prepopulating in_wire_value via in_wire_value_proto. + // Prepopulate in_wire_value from in_wire_value_proto. in_wire_value(n, p, PV::bottom()) <-- in_wire(n, p); in_wire_value(n, p, v) <-- for (n, p, v) in in_wire_value_proto.iter(), node(n), if let Some(sig) = hugr.signature(*n), if sig.input_ports().contains(p); + // Assemble in_value_row from in_value's node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = hugr.signature(*n); node_in_value_row(n, ValueRow::single_known(hugr.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(n, p, v); @@ -115,7 +116,7 @@ pub(super) fn run_datalog( for (p, v) in (0..).map(OutgoingPort::from).zip(outs); // DFG - relation dfg_node(Node); + relation dfg_node(Node); // is a `DFG` dfg_node(n) <-- node(n), if hugr.get_optype(*n).is_dfg(); out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), @@ -151,7 +152,8 @@ pub(super) fn run_datalog( for (out_p, v) in fields.enumerate(); // Conditional - relation conditional_node(Node); + relation conditional_node(Node); // is a `Conditional` + // is a `Conditional` and its 'th child (a `Case`) is : relation case_node(Node, usize, Node); conditional_node(n)<-- node(n), if hugr.get_optype(*n).is_conditional(); @@ -175,16 +177,17 @@ pub(super) fn run_datalog( output_child(case, o), in_wire_value(o, o_p, v); + // In `Conditional` , child `Case` is reachable given our knowledge of predicate relation case_reachable(Node, Node); case_reachable(cond, case) <-- case_node(cond, i, case), in_wire_value(cond, IncomingPort::from(0), v), if v.supports_tag(*i); // CFG - relation cfg_node(Node); + relation cfg_node(Node); // is a `CFG` cfg_node(n) <-- node(n), if hugr.get_optype(*n).is_cfg(); - // Reachability + // In `CFG` , basic block is reachable given our knowledge of predicates relation bb_reachable(Node, Node); bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = hugr.children(*cfg).next(); bb_reachable(cfg, bb) <-- cfg_node(cfg), @@ -194,15 +197,6 @@ pub(super) fn run_datalog( for (tag, bb) in hugr.output_neighbours(*pred).enumerate(), if predicate.supports_tag(tag); - // Relation: in `CFG` , values fed along a control-flow edge to - // come out of Value outports of . - relation _cfg_succ_dest(Node, Node, Node); - _cfg_succ_dest(cfg, blk, inp) <-- cfg_node(cfg), - for blk in hugr.children(*cfg), - if hugr.get_optype(blk).is_dataflow_block(), - input_child(blk, inp); - _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = hugr.children(*cfg).nth(1); - // Inputs of CFG propagate to entry block out_wire_value(i_node, OutgoingPort::from(p.index()), v) <-- cfg_node(cfg), @@ -210,7 +204,16 @@ pub(super) fn run_datalog( input_child(entry, i_node), in_wire_value(cfg, p, v); - // Outputs of each reachable block propagated to successor block or (if exit block) then CFG itself + // In `CFG` , values fed along a control-flow edge to + // come out of Value outports of . + relation _cfg_succ_dest(Node, Node, Node); + _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = hugr.children(*cfg).nth(1); + _cfg_succ_dest(cfg, blk, inp) <-- cfg_node(cfg), + for blk in hugr.children(*cfg), + if hugr.get_optype(blk).is_dataflow_block(), + input_child(blk, inp); + + // Outputs of each reachable block propagated to successor block or CFG itself out_wire_value(dest, OutgoingPort::from(out_p), v) <-- bb_reachable(cfg, pred), if let Some(df_block) = hugr.get_optype(*pred).as_dataflow_block(), @@ -222,7 +225,7 @@ pub(super) fn run_datalog( for (out_p, v) in fields.enumerate(); // Call - relation func_call(Node, Node); + relation func_call(Node, Node); // is a `Call` to `FuncDefn` func_call(call, func_defn) <-- node(call), if hugr.get_optype(*call).is_call(), From 69c3270a52a49abc3ab9fe8a4117e4c0ebae9bf5 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Oct 2024 16:46:52 +0100 Subject: [PATCH 148/203] comment, use exactly_one --- hugr-passes/src/dataflow/datalog.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index a2f22dbd6..dc793d872 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -263,16 +263,15 @@ fn propagate_leaf_op( num_outs: usize, ) -> Option> { match hugr.get_optype(n) { - // Handle basics here. I guess (given the current interface) we could allow - // DFContext to handle these but at the least we'd want these impls to be - // easily available for reuse. + // Handle basics here. We could instead leave these to DFContext, + // but at least we'd want these impls to be easily reusable. op if op.cast::().is_some() => Some(ValueRow::from_iter([PV::new_variant( 0, ins.iter().cloned(), )])), op if op.cast::().is_some() => { let elem_tys = op.cast::().unwrap().0; - let [tup] = ins.iter().collect::>().try_into().unwrap(); + let tup = ins.iter().exactly_one().unwrap(); tup.variant_values(0, elem_tys.len()) .map(ValueRow::from_iter) } From dc56686b305f30e321100efd40fe903ef2766f4d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Oct 2024 12:02:00 +0100 Subject: [PATCH 149/203] Allow to handle LoadFunction --- hugr-passes/src/dataflow.rs | 7 +++++++ hugr-passes/src/dataflow/datalog.rs | 14 ++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 5cd0fa7de..d34c3454b 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -6,6 +6,7 @@ pub use datalog::Machine; mod value_row; mod results; +use hugr_core::types::TypeArg; pub use results::{AnalysisResults, TailLoopTermination}; mod partial_value; @@ -58,6 +59,12 @@ pub trait DFContext { fn value_from_const_hugr(&self, _node: Node, _fields: &[usize], _h: &Hugr) -> Option { None } + + /// Produces an abstract value from a [FuncDefn] or [FuncDecl] node, if possible. + /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. + fn value_from_function(&self, _node: Node, _type_args: &[TypeArg]) -> Option { + None + } } fn traverse_value( diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index dc793d872..da2aefb59 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -295,6 +295,20 @@ fn propagate_leaf_op( c.value_from_const(n, const_val), )) } + OpType::LoadFunction(load_op) => { + assert!(ins.is_empty()); // static edge + let func_node = hugr + .single_linked_output(n, load_op.function_port()) + .unwrap() + .0; + // Node could be a FuncDefn or a FuncDecl, so do not pass the node itself + Some(ValueRow::single_known( + 1, + 0, + c.value_from_function(func_node, &load_op.type_args) + .map_or(PV::Top, PV::Value), + )) + } OpType::ExtensionOp(e) => { // Interpret op. Default is we know nothing about the outputs (they still happen!) let mut outs = vec![PartialValue::Top; num_outs]; From 33a85923ae9f03bb43fdfe1b25ea8ead69ce9e92 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Oct 2024 14:07:05 +0100 Subject: [PATCH 150/203] doc --- hugr-passes/src/dataflow.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index d34c3454b..8510f8f9b 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -60,8 +60,13 @@ pub trait DFContext { None } - /// Produces an abstract value from a [FuncDefn] or [FuncDecl] node, if possible. + /// Produces an abstract value from a [FuncDefn] or [FuncDecl] node (that has been loaded + /// via a [LoadFunction]), if possible. /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. + /// + /// [FuncDefn]: hugr_core::ops::FuncDefn + /// [FuncDecl]: hugr_core::ops::FuncDecl + /// [LoadFunction]: hugr_core::ops::LoadFunction fn value_from_function(&self, _node: Node, _type_args: &[TypeArg]) -> Option { None } From b153ada208ee3a47fde5b2ebe3ba296c184e0238 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Oct 2024 14:16:33 +0100 Subject: [PATCH 151/203] Separate out ConstLoader --- hugr-passes/src/dataflow.rs | 9 +++++++-- hugr-passes/src/dataflow/test.rs | 3 ++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 8510f8f9b..c764a2ea0 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -17,7 +17,7 @@ use hugr_core::{Hugr, Node}; /// Clients of the dataflow framework (particular analyses, such as constant folding) /// must implement this trait (including providing an appropriate domain type `V`). -pub trait DFContext { +pub trait DFContext: ConstLoader { /// Given lattice values for each input, update lattice values for the (dataflow) outputs. /// For extension ops only, excluding [MakeTuple] and [UnpackTuple]. /// `_outs` is an array with one element per dataflow output, each initialized to [PartialValue::Top] @@ -34,7 +34,12 @@ pub trait DFContext { _outs: &mut [PartialValue], ) { } +} +/// Trait for loading [PartialValue]s from constants in a Hugr. The default +/// traverses [Sum](Value::Sum) constants to their non-Sum leaves but represents +/// each leaf as [PartialValue::Top]. +pub trait ConstLoader { /// Produces an abstract value from a constant. The default impl /// traverses the constant [Value] to its leaves ([Value::Extension] and [Value::Function]), /// converts these using [Self::value_from_custom_const] and [Self::value_from_const_hugr], @@ -73,7 +78,7 @@ pub trait DFContext { } fn traverse_value( - s: &(impl DFContext + ?Sized), + s: &(impl ConstLoader + ?Sized), n: Node, fields: &mut Vec, cst: &Value, diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index fe4af1c46..d6721d8ed 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -15,7 +15,7 @@ use hugr_core::{ use hugr_core::{Hugr, Wire}; use rstest::{fixture, rstest}; -use super::{AbstractValue, DFContext, Machine, PartialValue, TailLoopTermination}; +use super::{AbstractValue, ConstLoader, DFContext, Machine, PartialValue, TailLoopTermination}; // ------- Minimal implementation of DFContext and AbstractValue ------- #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -25,6 +25,7 @@ impl AbstractValue for Void {} struct TestContext; +impl ConstLoader for TestContext {} impl DFContext for TestContext {} // This allows testing creation of tuple/sum Values (only) From 5052ac0268cffcec9c0924edcf28b095314a67be Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Oct 2024 16:10:05 +0100 Subject: [PATCH 152/203] value_from_(custom_const=>opaque), taking &OpaqueValue --- hugr-passes/src/dataflow.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index c764a2ea0..23db00383 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -6,14 +6,15 @@ pub use datalog::Machine; mod value_row; mod results; -use hugr_core::types::TypeArg; pub use results::{AnalysisResults, TailLoopTermination}; mod partial_value; pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; -use hugr_core::ops::{constant::CustomConst, ExtensionOp, Value}; use hugr_core::{Hugr, Node}; +use hugr_core::ops::{ExtensionOp, Value}; +use hugr_core::ops::constant::OpaqueValue; +use hugr_core::types::TypeArg; /// Clients of the dataflow framework (particular analyses, such as constant folding) /// must implement this trait (including providing an appropriate domain type `V`). @@ -48,13 +49,13 @@ pub trait ConstLoader { traverse_value(self, n, &mut Vec::new(), cst) } - /// Produces an abstract value from a [CustomConst], if possible. + /// Produces an abstract value from an [OpaqueValue], if possible. /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. - fn value_from_custom_const( + fn value_from_opaque( &self, _node: Node, _fields: &[usize], - _cc: &dyn CustomConst, + _val: &OpaqueValue, ) -> Option { None } @@ -94,7 +95,7 @@ fn traverse_value( PartialValue::new_variant(*tag, elems) } Value::Extension { e } => s - .value_from_custom_const(n, fields, e.value()) + .value_from_opaque(n, fields, e) .map(PartialValue::from) .unwrap_or(PartialValue::Top), Value::Function { hugr } => s From ad0c6f24e31fbe9413c98251244a033025600002 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 28 Oct 2024 14:54:33 +0000 Subject: [PATCH 153/203] Recombine DFContext with Hugr i.e. reinstate Deref constraint --- hugr-passes/src/dataflow.rs | 13 ++-- hugr-passes/src/dataflow/datalog.rs | 99 ++++++++++++++--------------- hugr-passes/src/dataflow/results.rs | 8 +-- hugr-passes/src/dataflow/test.rs | 31 +++++---- 4 files changed, 74 insertions(+), 77 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 23db00383..2ac927552 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -11,14 +11,14 @@ pub use results::{AnalysisResults, TailLoopTermination}; mod partial_value; pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; -use hugr_core::{Hugr, Node}; -use hugr_core::ops::{ExtensionOp, Value}; use hugr_core::ops::constant::OpaqueValue; +use hugr_core::ops::{ExtensionOp, Value}; use hugr_core::types::TypeArg; +use hugr_core::{Hugr, Node}; /// Clients of the dataflow framework (particular analyses, such as constant folding) /// must implement this trait (including providing an appropriate domain type `V`). -pub trait DFContext: ConstLoader { +pub trait DFContext: ConstLoader + std::ops::Deref { /// Given lattice values for each input, update lattice values for the (dataflow) outputs. /// For extension ops only, excluding [MakeTuple] and [UnpackTuple]. /// `_outs` is an array with one element per dataflow output, each initialized to [PartialValue::Top] @@ -51,12 +51,7 @@ pub trait ConstLoader { /// Produces an abstract value from an [OpaqueValue], if possible. /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. - fn value_from_opaque( - &self, - _node: Node, - _fields: &[usize], - _val: &OpaqueValue, - ) -> Option { + fn value_from_opaque(&self, _node: Node, _fields: &[usize], _val: &OpaqueValue) -> Option { None } diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index da2aefb59..c684ac9cd 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -40,24 +40,22 @@ impl Machine { /// (Note that `in_values` will not be useful for `Case` or `DFB`-rooted Hugrs, /// but should handle other containers.) /// The context passed in allows interpretation of leaf operations. - pub fn run( + pub fn run>( mut self, - context: &impl DFContext, - hugr: H, + context: C, in_values: impl IntoIterator)>, - ) -> AnalysisResults { - let root = hugr.root(); + ) -> AnalysisResults { + let root = context.root(); self.0 .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); - run_datalog(self.0, context, hugr) + run_datalog(context, self.0) } } -pub(super) fn run_datalog( +pub(super) fn run_datalog>( + ctx: C, in_wire_value_proto: Vec<(Node, IncomingPort, PV)>, - c: &impl DFContext, - hugr: H, -) -> AnalysisResults { +) -> AnalysisResults { // ascent-(macro-)generated code generates a bunch of warnings, // keep code in here to a minimum. #![allow( @@ -77,47 +75,47 @@ pub(super) fn run_datalog( lattice in_wire_value(Node, IncomingPort, PV); // receives, on , the value lattice node_in_value_row(Node, ValueRow); // 's inputs are - node(n) <-- for n in hugr.nodes(); + node(n) <-- for n in ctx.nodes(); - in_wire(n, p) <-- node(n), for (p,_) in hugr.in_value_types(*n); // Note, gets connected inports only - out_wire(n, p) <-- node(n), for (p,_) in hugr.out_value_types(*n); // (and likewise) + in_wire(n, p) <-- node(n), for (p,_) in ctx.in_value_types(*n); // Note, gets connected inports only + out_wire(n, p) <-- node(n), for (p,_) in ctx.out_value_types(*n); // (and likewise) parent_of_node(parent, child) <-- - node(child), if let Some(parent) = hugr.get_parent(*child); + node(child), if let Some(parent) = ctx.get_parent(*child); - input_child(parent, input) <-- node(parent), if let Some([input, _output]) = hugr.get_io(*parent); - output_child(parent, output) <-- node(parent), if let Some([_input, output]) = hugr.get_io(*parent); + input_child(parent, input) <-- node(parent), if let Some([input, _output]) = ctx.get_io(*parent); + output_child(parent, output) <-- node(parent), if let Some([_input, output]) = ctx.get_io(*parent); // Initialize all wires to bottom out_wire_value(n, p, PV::bottom()) <-- out_wire(n, p); in_wire_value(n, ip, v) <-- in_wire(n, ip), - if let Some((m, op)) = hugr.single_linked_output(*n, *ip), + if let Some((m, op)) = ctx.single_linked_output(*n, *ip), out_wire_value(m, op, v); // Prepopulate in_wire_value from in_wire_value_proto. in_wire_value(n, p, PV::bottom()) <-- in_wire(n, p); in_wire_value(n, p, v) <-- for (n, p, v) in in_wire_value_proto.iter(), node(n), - if let Some(sig) = hugr.signature(*n), + if let Some(sig) = ctx.signature(*n), if sig.input_ports().contains(p); // Assemble in_value_row from in_value's - node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = hugr.signature(*n); - node_in_value_row(n, ValueRow::single_known(hugr.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(n, p, v); + node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = ctx.signature(*n); + node_in_value_row(n, ValueRow::single_known(ctx.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(n, p, v); out_wire_value(n, p, v) <-- node(n), - let op_t = hugr.get_optype(*n), + let op_t = ctx.get_optype(*n), if !op_t.is_container(), if let Some(sig) = op_t.dataflow_signature(), node_in_value_row(n, vs), - if let Some(outs) = propagate_leaf_op(c, &hugr, *n, &vs[..], sig.output_count()), + if let Some(outs) = propagate_leaf_op(&ctx, *n, &vs[..], sig.output_count()), for (p, v) in (0..).map(OutgoingPort::from).zip(outs); // DFG relation dfg_node(Node); // is a `DFG` - dfg_node(n) <-- node(n), if hugr.get_optype(*n).is_dfg(); + dfg_node(n) <-- node(n), if ctx.get_optype(*n).is_dfg(); out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), input_child(dfg, i), in_wire_value(dfg, p, v); @@ -130,13 +128,13 @@ pub(super) fn run_datalog( // inputs of tail loop propagate to Input node of child region out_wire_value(i, OutgoingPort::from(p.index()), v) <-- node(tl), - if hugr.get_optype(*tl).is_tail_loop(), + if ctx.get_optype(*tl).is_tail_loop(), input_child(tl, i), in_wire_value(tl, p, v); // Output node of child region propagate to Input node of child region out_wire_value(in_n, OutgoingPort::from(out_p), v) <-- node(tl), - if let Some(tailloop) = hugr.get_optype(*tl).as_tail_loop(), + if let Some(tailloop) = ctx.get_optype(*tl).as_tail_loop(), input_child(tl, in_n), output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node @@ -145,7 +143,7 @@ pub(super) fn run_datalog( // Output node of child region propagate to outputs of tail loop out_wire_value(tl, OutgoingPort::from(out_p), v) <-- node(tl), - if let Some(tailloop) = hugr.get_optype(*tl).as_tail_loop(), + if let Some(tailloop) = ctx.get_optype(*tl).as_tail_loop(), output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 @@ -156,17 +154,17 @@ pub(super) fn run_datalog( // is a `Conditional` and its 'th child (a `Case`) is : relation case_node(Node, usize, Node); - conditional_node(n)<-- node(n), if hugr.get_optype(*n).is_conditional(); + conditional_node(n)<-- node(n), if ctx.get_optype(*n).is_conditional(); case_node(cond, i, case) <-- conditional_node(cond), - for (i, case) in hugr.children(*cond).enumerate(), - if hugr.get_optype(case).is_case(); + for (i, case) in ctx.children(*cond).enumerate(), + if ctx.get_optype(case).is_case(); // inputs of conditional propagate into case nodes out_wire_value(i_node, OutgoingPort::from(out_p), v) <-- case_node(cond, case_index, case), input_child(case, i_node), node_in_value_row(cond, in_row), - let conditional = hugr.get_optype(*cond).as_conditional().unwrap(), + let conditional = ctx.get_optype(*cond).as_conditional().unwrap(), if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), for (out_p, v) in fields.enumerate(); @@ -185,39 +183,39 @@ pub(super) fn run_datalog( // CFG relation cfg_node(Node); // is a `CFG` - cfg_node(n) <-- node(n), if hugr.get_optype(*n).is_cfg(); + cfg_node(n) <-- node(n), if ctx.get_optype(*n).is_cfg(); // In `CFG` , basic block is reachable given our knowledge of predicates relation bb_reachable(Node, Node); - bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = hugr.children(*cfg).next(); + bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = ctx.children(*cfg).next(); bb_reachable(cfg, bb) <-- cfg_node(cfg), bb_reachable(cfg, pred), output_child(pred, pred_out), in_wire_value(pred_out, IncomingPort::from(0), predicate), - for (tag, bb) in hugr.output_neighbours(*pred).enumerate(), + for (tag, bb) in ctx.output_neighbours(*pred).enumerate(), if predicate.supports_tag(tag); // Inputs of CFG propagate to entry block out_wire_value(i_node, OutgoingPort::from(p.index()), v) <-- cfg_node(cfg), - if let Some(entry) = hugr.children(*cfg).next(), + if let Some(entry) = ctx.children(*cfg).next(), input_child(entry, i_node), in_wire_value(cfg, p, v); // In `CFG` , values fed along a control-flow edge to // come out of Value outports of . relation _cfg_succ_dest(Node, Node, Node); - _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = hugr.children(*cfg).nth(1); + _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = ctx.children(*cfg).nth(1); _cfg_succ_dest(cfg, blk, inp) <-- cfg_node(cfg), - for blk in hugr.children(*cfg), - if hugr.get_optype(blk).is_dataflow_block(), + for blk in ctx.children(*cfg), + if ctx.get_optype(blk).is_dataflow_block(), input_child(blk, inp); // Outputs of each reachable block propagated to successor block or CFG itself out_wire_value(dest, OutgoingPort::from(out_p), v) <-- bb_reachable(cfg, pred), - if let Some(df_block) = hugr.get_optype(*pred).as_dataflow_block(), - for (succ_n, succ) in hugr.output_neighbours(*pred).enumerate(), + if let Some(df_block) = ctx.get_optype(*pred).as_dataflow_block(), + for (succ_n, succ) in ctx.output_neighbours(*pred).enumerate(), output_child(pred, out_n), _cfg_succ_dest(cfg, succ, dest), node_in_value_row(out_n, out_in_row), @@ -228,8 +226,8 @@ pub(super) fn run_datalog( relation func_call(Node, Node); // is a `Call` to `FuncDefn` func_call(call, func_defn) <-- node(call), - if hugr.get_optype(*call).is_call(), - if let Some(func_defn) = hugr.static_source(*call); + if ctx.get_optype(*call).is_call(), + if let Some(func_defn) = ctx.static_source(*call); out_wire_value(inp, OutgoingPort::from(p.index()), v) <-- func_call(call, func), @@ -247,7 +245,7 @@ pub(super) fn run_datalog( .map(|(n, p, v)| (Wire::new(*n, *p), v.clone())) .collect(); AnalysisResults { - hugr, + hugr: ctx, out_wire_values, in_wire_value: all_results.in_wire_value, case_reachable: all_results.case_reachable, @@ -256,13 +254,12 @@ pub(super) fn run_datalog( } fn propagate_leaf_op( - c: &impl DFContext, - hugr: &impl HugrView, + ctx: &impl DFContext, n: Node, ins: &[PV], num_outs: usize, ) -> Option> { - match hugr.get_optype(n) { + match ctx.get_optype(n) { // Handle basics here. We could instead leave these to DFContext, // but at least we'd want these impls to be easily reusable. op if op.cast::().is_some() => Some(ValueRow::from_iter([PV::new_variant( @@ -284,20 +281,20 @@ fn propagate_leaf_op( OpType::Const(_) => None, // handled by LoadConstant: OpType::LoadConstant(load_op) => { assert!(ins.is_empty()); // static edge, so need to find constant - let const_node = hugr + let const_node = ctx .single_linked_output(n, load_op.constant_port()) .unwrap() .0; - let const_val = hugr.get_optype(const_node).as_const().unwrap().value(); + let const_val = ctx.get_optype(const_node).as_const().unwrap().value(); Some(ValueRow::single_known( 1, 0, - c.value_from_const(n, const_val), + ctx.value_from_const(n, const_val), )) } OpType::LoadFunction(load_op) => { assert!(ins.is_empty()); // static edge - let func_node = hugr + let func_node = ctx .single_linked_output(n, load_op.function_port()) .unwrap() .0; @@ -305,7 +302,7 @@ fn propagate_leaf_op( Some(ValueRow::single_known( 1, 0, - c.value_from_function(func_node, &load_op.type_args) + ctx.value_from_function(func_node, &load_op.type_args) .map_or(PV::Top, PV::Value), )) } @@ -315,7 +312,7 @@ fn propagate_leaf_op( // It'd be nice to convert these to [(IncomingPort, Value)] to pass to the context, // thus keeping PartialValue hidden, but AbstractValues // are not necessarily convertible to Value. - c.interpret_leaf_op(n, e, ins, &mut outs[..]); + ctx.interpret_leaf_op(n, e, ins, &mut outs[..]); Some(ValueRow::from_iter(outs)) } o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive" diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index f457ef68c..3e37e2ce9 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -2,11 +2,11 @@ use std::collections::HashMap; use hugr_core::{ops::Value, types::ConstTypeError, HugrView, IncomingPort, Node, PortIndex, Wire}; -use super::{AbstractValue, PartialValue}; +use super::{AbstractValue, DFContext, PartialValue}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). -pub struct AnalysisResults { +pub struct AnalysisResults { pub(super) hugr: H, pub(super) in_wire_value: Vec<(Node, IncomingPort, PartialValue)>, pub(super) case_reachable: Vec<(Node, Node)>, @@ -14,7 +14,7 @@ pub struct AnalysisResults { pub(super) out_wire_values: HashMap>, } -impl AnalysisResults { +impl> AnalysisResults { /// Gets the lattice value computed for the given wire pub fn read_out_wire(&self, w: Wire) -> Option> { self.out_wire_values.get(&w).cloned() @@ -75,7 +75,7 @@ impl AnalysisResults { } } -impl AnalysisResults +impl> AnalysisResults where Value: From, { diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index d6721d8ed..e00827fa8 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -23,10 +23,16 @@ enum Void {} impl AbstractValue for Void {} -struct TestContext; +struct TestContext(H); -impl ConstLoader for TestContext {} -impl DFContext for TestContext {} +impl std::ops::Deref for TestContext { + type Target = Hugr; + fn deref(&self) -> &Hugr { + self.0.base_hugr() + } +} +impl ConstLoader for TestContext {} +impl DFContext for TestContext {} // This allows testing creation of tuple/sum Values (only) impl From for Value { @@ -55,7 +61,7 @@ fn test_make_tuple() { let v3 = builder.make_tuple([v1, v2]).unwrap(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::default().run(TestContext(hugr), []); let x = results.try_read_wire_value(v3).unwrap(); assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); @@ -71,7 +77,7 @@ fn test_unpack_tuple_const() { .outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::default().run(TestContext(hugr), []); let o1_r = results.try_read_wire_value(o1).unwrap(); assert_eq!(o1_r, Value::false_val()); @@ -94,7 +100,7 @@ fn test_tail_loop_never_iterates() { let [tl_o] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::default().run(TestContext(hugr), []); let o_r = results.try_read_wire_value(tl_o).unwrap(); assert_eq!(o_r, r_v); @@ -123,7 +129,7 @@ fn test_tail_loop_always_iterates() { let [tl_o1, tl_o2] = tail_loop.outputs_arr(); let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::default().run(TestContext(&hugr), []); let o_r1 = results.read_out_wire(tl_o1).unwrap(); assert_eq!(o_r1, PartialValue::bottom()); @@ -161,7 +167,7 @@ fn test_tail_loop_two_iters() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let [o_w1, o_w2] = tail_loop.outputs_arr(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::default().run(TestContext(&hugr), []); let o_r1 = results.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true_or_false()); @@ -227,7 +233,7 @@ fn test_tail_loop_containing_conditional() { let hugr = builder.finish_hugr(&EMPTY_REG).unwrap(); let [o_w1, o_w2] = tail_loop.outputs_arr(); - let results = Machine::default().run(&TestContext, &hugr, []); + let results = Machine::default().run(TestContext(&hugr), []); let o_r1 = results.read_out_wire(o_w1).unwrap(); assert_eq!(o_r1, pv_true()); @@ -279,7 +285,7 @@ fn test_conditional() { 2, [PartialValue::new_variant(0, [])], )); - let results = Machine::default().run(&TestContext, &hugr, [(0.into(), arg_pv)]); + let results = Machine::default().run(TestContext(hugr), [(0.into(), arg_pv)]); let cond_r1 = results.try_read_wire_value(cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); @@ -403,8 +409,7 @@ fn test_cfg( ) { let root = xor_and_cfg.root(); let results = Machine::default().run( - &TestContext, - &xor_and_cfg, + TestContext(xor_and_cfg), [(0.into(), inp0), (1.into(), inp1)], ); @@ -440,7 +445,7 @@ fn test_call( .finish_hugr_with_outputs([a2, b2], &EMPTY_REG) .unwrap(); - let results = Machine::default().run(&TestContext, &hugr, [(0.into(), inp0), (1.into(), inp1)]); + let results = Machine::default().run(TestContext(&hugr), [(0.into(), inp0), (1.into(), inp1)]); let [res0, res1] = [0, 1].map(|i| results.read_out_wire(Wire::new(hugr.root(), i)).unwrap()); // The two calls alias so both results will be the same: From 16a18f4bfe048350116f0e5e5a003567faa4d072 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 28 Oct 2024 14:59:44 +0000 Subject: [PATCH 154/203] Replace Deref with HugrView, trivially obtainable by implementing AsRef --- hugr-passes/src/dataflow.rs | 4 ++-- hugr-passes/src/dataflow/results.rs | 8 ++++---- hugr-passes/src/dataflow/test.rs | 5 ++--- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 2ac927552..c960fb7de 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -14,11 +14,11 @@ pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::{ExtensionOp, Value}; use hugr_core::types::TypeArg; -use hugr_core::{Hugr, Node}; +use hugr_core::{Hugr, HugrView, Node}; /// Clients of the dataflow framework (particular analyses, such as constant folding) /// must implement this trait (including providing an appropriate domain type `V`). -pub trait DFContext: ConstLoader + std::ops::Deref { +pub trait DFContext: ConstLoader + HugrView { /// Given lattice values for each input, update lattice values for the (dataflow) outputs. /// For extension ops only, excluding [MakeTuple] and [UnpackTuple]. /// `_outs` is an array with one element per dataflow output, each initialized to [PartialValue::Top] diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index 3e37e2ce9..f457ef68c 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -2,11 +2,11 @@ use std::collections::HashMap; use hugr_core::{ops::Value, types::ConstTypeError, HugrView, IncomingPort, Node, PortIndex, Wire}; -use super::{AbstractValue, DFContext, PartialValue}; +use super::{AbstractValue, PartialValue}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). -pub struct AnalysisResults { +pub struct AnalysisResults { pub(super) hugr: H, pub(super) in_wire_value: Vec<(Node, IncomingPort, PartialValue)>, pub(super) case_reachable: Vec<(Node, Node)>, @@ -14,7 +14,7 @@ pub struct AnalysisResults { pub(super) out_wire_values: HashMap>, } -impl> AnalysisResults { +impl AnalysisResults { /// Gets the lattice value computed for the given wire pub fn read_out_wire(&self, w: Wire) -> Option> { self.out_wire_values.get(&w).cloned() @@ -75,7 +75,7 @@ impl> AnalysisResults { } } -impl> AnalysisResults +impl AnalysisResults where Value: From, { diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index e00827fa8..e73cc8e4f 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -25,9 +25,8 @@ impl AbstractValue for Void {} struct TestContext(H); -impl std::ops::Deref for TestContext { - type Target = Hugr; - fn deref(&self) -> &Hugr { +impl AsRef for TestContext { + fn as_ref(&self) -> &Hugr { self.0.base_hugr() } } From a0f2b2ca1a4418b8b66f2208bdbe09d7adcc9d95 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 28 Oct 2024 17:33:18 +0000 Subject: [PATCH 155/203] ValueRow Debug; ops default to PartialValue::Top less aggressively --- hugr-passes/src/dataflow/datalog.rs | 12 ++++++++++-- hugr-passes/src/dataflow/value_row.rs | 2 +- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index c684ac9cd..2cc535d89 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -307,8 +307,16 @@ fn propagate_leaf_op( )) } OpType::ExtensionOp(e) => { - // Interpret op. Default is we know nothing about the outputs (they still happen!) - let mut outs = vec![PartialValue::Top; num_outs]; + // Interpret op. + let init = if ins.iter().contains(&PartialValue::Bottom) { + // So far we think one or more inputs can't happen. + // So, don't pollute outputs with Top, and wait for better knowledge of inputs. + PartialValue::Bottom + } else { + // If we can't figure out anything about the outputs, assume nothing (they still happen!) + PartialValue::Top + }; + let mut outs = vec![init; num_outs]; // It'd be nice to convert these to [(IncomingPort, Value)] to pass to the context, // thus keeping PartialValue hidden, but AbstractValues // are not necessarily convertible to Value. diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs index ebdbf1b75..0d8bc15a6 100644 --- a/hugr-passes/src/dataflow/value_row.rs +++ b/hugr-passes/src/dataflow/value_row.rs @@ -10,7 +10,7 @@ use itertools::zip_eq; use super::{AbstractValue, PartialValue}; -#[derive(PartialEq, Clone, Eq, Hash)] +#[derive(PartialEq, Clone, Debug, Eq, Hash)] pub(super) struct ValueRow(Vec>); impl ValueRow { From 87eb700565aba68c29feb26aeef699c48586bbd9 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 29 Oct 2024 19:52:40 +0000 Subject: [PATCH 156/203] And back to Deref, should allow using a region view not the whole Hugr --- hugr-passes/src/dataflow.rs | 6 +++++- hugr-passes/src/dataflow/datalog.rs | 4 ++-- hugr-passes/src/dataflow/results.rs | 33 +++++++++++++++++------------ hugr-passes/src/dataflow/test.rs | 11 ++++++---- 4 files changed, 33 insertions(+), 21 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index c960fb7de..69942eafc 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -18,7 +18,11 @@ use hugr_core::{Hugr, HugrView, Node}; /// Clients of the dataflow framework (particular analyses, such as constant folding) /// must implement this trait (including providing an appropriate domain type `V`). -pub trait DFContext: ConstLoader + HugrView { +pub trait DFContext: ConstLoader + std::ops::Deref { + /// Type of view contained within this context. (Ideally we'd constrain + /// by `std::ops::Deref>( .map(|(n, p, v)| (Wire::new(*n, *p), v.clone())) .collect(); AnalysisResults { - hugr: ctx, + ctx, out_wire_values, in_wire_value: all_results.in_wire_value, case_reachable: all_results.case_reachable, @@ -308,7 +308,7 @@ fn propagate_leaf_op( } OpType::ExtensionOp(e) => { // Interpret op. - let init = if ins.iter().contains(&PartialValue::Bottom) { + let init = if ins.iter().contains(&PartialValue::Bottom) { // So far we think one or more inputs can't happen. // So, don't pollute outputs with Top, and wait for better knowledge of inputs. PartialValue::Bottom diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index f457ef68c..b18a3c704 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -2,19 +2,24 @@ use std::collections::HashMap; use hugr_core::{ops::Value, types::ConstTypeError, HugrView, IncomingPort, Node, PortIndex, Wire}; -use super::{AbstractValue, PartialValue}; +use super::{AbstractValue, DFContext, PartialValue}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). -pub struct AnalysisResults { - pub(super) hugr: H, +pub struct AnalysisResults> { + pub(super) ctx: C, pub(super) in_wire_value: Vec<(Node, IncomingPort, PartialValue)>, pub(super) case_reachable: Vec<(Node, Node)>, pub(super) bb_reachable: Vec<(Node, Node)>, pub(super) out_wire_values: HashMap>, } -impl AnalysisResults { +impl> AnalysisResults { + /// Allows to use the [HugrView] contained within + pub fn hugr(&self) -> &C::View { + &self.ctx + } + /// Gets the lattice value computed for the given wire pub fn read_out_wire(&self, w: Wire) -> Option> { self.out_wire_values.get(&w).cloned() @@ -26,8 +31,8 @@ impl AnalysisResults { /// /// [TailLoop]: hugr_core::ops::TailLoop pub fn tail_loop_terminates(&self, node: Node) -> Option { - self.hugr.get_optype(node).as_tail_loop()?; - let [_, out] = self.hugr.get_io(node).unwrap(); + self.hugr().get_optype(node).as_tail_loop()?; + let [_, out] = self.hugr().get_io(node).unwrap(); Some(TailLoopTermination::from_control_value( self.in_wire_value .iter() @@ -44,9 +49,9 @@ impl AnalysisResults { /// [Case]: hugr_core::ops::Case /// [Conditional]: hugr_core::ops::Conditional pub fn case_reachable(&self, case: Node) -> Option { - self.hugr.get_optype(case).as_case()?; - let cond = self.hugr.get_parent(case)?; - self.hugr.get_optype(cond).as_conditional()?; + self.hugr().get_optype(case).as_case()?; + let cond = self.hugr().get_parent(case)?; + self.hugr().get_optype(cond).as_conditional()?; Some( self.case_reachable .iter() @@ -61,9 +66,9 @@ impl AnalysisResults { /// [DataflowBlock]: hugr_core::ops::DataflowBlock /// [ExitBlock]: hugr_core::ops::ExitBlock pub fn bb_reachable(&self, bb: Node) -> Option { - let cfg = self.hugr.get_parent(bb)?; // Not really required...?? - self.hugr.get_optype(cfg).as_cfg()?; - let t = self.hugr.get_optype(bb); + let cfg = self.hugr().get_parent(bb)?; // Not really required...?? + self.hugr().get_optype(cfg).as_cfg()?; + let t = self.hugr().get_optype(bb); if !t.is_dataflow_block() && !t.is_exit_block() { return None; }; @@ -75,7 +80,7 @@ impl AnalysisResults { } } -impl AnalysisResults +impl> AnalysisResults where Value: From, { @@ -93,7 +98,7 @@ where pub fn try_read_wire_value(&self, w: Wire) -> Result> { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self - .hugr + .hugr() .out_value_types(w.node()) .find(|(p, _)| *p == w.source()) .unwrap(); diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index e73cc8e4f..c7d4e7b7e 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -25,13 +25,16 @@ impl AbstractValue for Void {} struct TestContext(H); -impl AsRef for TestContext { - fn as_ref(&self) -> &Hugr { - self.0.base_hugr() +impl std::ops::Deref for TestContext { + type Target = H; + fn deref(&self) -> &H { + &self.0 } } impl ConstLoader for TestContext {} -impl DFContext for TestContext {} +impl DFContext for TestContext { + type View = H; +} // This allows testing creation of tuple/sum Values (only) impl From for Value { From a49221b3b172ac3bb514a763d072c66d8e1747cd Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 10:49:15 +0000 Subject: [PATCH 157/203] fix doclink --- hugr-passes/src/dataflow.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 69942eafc..0b7bf03ed 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -47,7 +47,7 @@ pub trait DFContext: ConstLoader + std::ops::Deref { pub trait ConstLoader { /// Produces an abstract value from a constant. The default impl /// traverses the constant [Value] to its leaves ([Value::Extension] and [Value::Function]), - /// converts these using [Self::value_from_custom_const] and [Self::value_from_const_hugr], + /// converts these using [Self::value_from_opaque] and [Self::value_from_const_hugr], /// and builds nested [PartialValue::new_variant] to represent the structure. fn value_from_const(&self, n: Node, cst: &Value) -> PartialValue { traverse_value(self, n, &mut Vec::new(), cst) From 71ea55d1e5e7970d75d0ea57c34a41c64557806a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 10:40:21 +0000 Subject: [PATCH 158/203] PartialSum::try_into_value also uses Option<...> as error-type --- hugr-passes/src/dataflow/partial_value.rs | 59 ++++++++++++----------- 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 5b5695dd0..8cd5c12fb 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -141,41 +141,37 @@ impl PartialSum { self.0.contains_key(&tag) } - /// Turns this instance into a [Sum] if it has exactly one possible tag, - /// otherwise failing and returning itself back unmodified (also if there is another - /// error, e.g. this instance is not described by `typ`). - // ALAN is this too parametric? Should we fix V2 == Value? Is the 'Self' error useful (no?) + /// Turns this instance into a [Sum] of some target value type `V2`, + /// *if* this PartialSum has exactly one possible tag. + /// + /// # Errors + /// `None` if this PartialSum had multiple possible tags; or, if there was a single + /// tag, but `typ` was not a [TypeEnum::Sum] supporting that tag and containing no + /// row variables within that variant and of the correct number of variants + /// `Some(e)` if none of the error conditions above applied, but there was an error + /// `e` in converting one of the variant elements into `V2` via [PartialValue::try_into_value] pub fn try_into_value + TryFrom>>( self, typ: &Type, - ) -> Result, Self> { - let Ok((k, v)) = self.0.iter().exactly_one() else { - Err(self)? - }; - + ) -> Result, Option<>>::Error>> { + let (k, v) = self.0.iter().exactly_one().map_err(|_| None)?; let TypeEnum::Sum(st) = typ.as_type_enum() else { - Err(self)? + Err(None)? }; let Some(r) = st.get_variant(*k) else { - Err(self)? - }; - let Ok(r) = TypeRow::try_from(r.clone()) else { - Err(self)? + Err(None)? }; + let r: TypeRow = r.clone().try_into().map_err(|_| None)?; if v.len() != r.len() { - return Err(self); - } - match zip_eq(v, r.iter()) - .map(|(v, t)| v.clone().try_into_value(t)) - .collect::, _>>() - { - Ok(values) => Ok(Sum { - tag: *k, - values, - st: st.clone(), - }), - Err(_) => Err(self), + return Err(None); } + Ok(Sum { + tag: *k, + values: zip_eq(v, r.iter()) + .map(|(v, t)| v.clone().try_into_value(t)) + .collect::, _>>()?, + st: st.clone(), + }) } } @@ -301,8 +297,13 @@ impl PartialValue { } } - /// Extracts a value (in any representation supporting both leaf values and sums) - // ALAN is this too parametric? Should we fix V2 == Value? Is the error useful (should we have 'Self') or is it a smell? + /// Turns this instance into a target value type `V2` if it is a single value, + /// or a [PartialValue::PartialSum] convertible by [PartialSum::try_into_value]. + /// + /// # Errors + /// + /// `None` if this is [Bottom](PartialValue::Bottom) or [Top](PartialValue::Top), + /// otherwise as per [PartialSum::try_into_value] pub fn try_into_value + TryFrom>>( self, typ: &Type, @@ -310,7 +311,7 @@ impl PartialValue { match self { Self::Value(v) => Ok(V2::from(v.clone())), Self::PartialSum(ps) => { - let v = ps.try_into_value(typ).map_err(|_| None)?; + let v = ps.try_into_value(typ)?; V2::try_from(v).map_err(Some) } _ => Err(None), From ea9db2e964e6e9345a4fab747fd0121d5f6f292a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 11:30:02 +0000 Subject: [PATCH 159/203] Proper errors from try_into_value, Option from try_read_wire_value --- hugr-passes/src/dataflow/partial_value.rs | 66 ++++++++++++++++------- hugr-passes/src/dataflow/results.rs | 13 +++-- 2 files changed, 55 insertions(+), 24 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 8cd5c12fb..b03e2d842 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -6,6 +6,7 @@ use itertools::{zip_eq, Itertools}; use std::cmp::Ordering; use std::collections::HashMap; use std::hash::{Hash, Hasher}; +use thiserror::Error; /// Trait for an underlying domain of abstract values which can form the *elements* of a /// [PartialValue] and thus be used in dataflow analysis. @@ -150,31 +151,57 @@ impl PartialSum { /// row variables within that variant and of the correct number of variants /// `Some(e)` if none of the error conditions above applied, but there was an error /// `e` in converting one of the variant elements into `V2` via [PartialValue::try_into_value] - pub fn try_into_value + TryFrom>>( + pub fn try_into_value + TryFrom, Error=E>>( self, typ: &Type, - ) -> Result, Option<>>::Error>> { - let (k, v) = self.0.iter().exactly_one().map_err(|_| None)?; - let TypeEnum::Sum(st) = typ.as_type_enum() else { - Err(None)? - }; - let Some(r) = st.get_variant(*k) else { - Err(None)? + ) -> Result, ExtractValueError> { + let Ok((k, v)) = self.0.iter().exactly_one() else { + return Err(ExtractValueError::MultipleVariants(self)); }; - let r: TypeRow = r.clone().try_into().map_err(|_| None)?; - if v.len() != r.len() { - return Err(None); + if let TypeEnum::Sum(st) = typ.as_type_enum() { + if let Some(r) = st.get_variant(*k) { + if let Ok(r) = TypeRow::try_from(r.clone()) { + if v.len() == r.len() { + return Ok(Sum { + tag: *k, + values: zip_eq(v, r.iter()) + .map(|(v, t)| v.clone().try_into_value(t)) + .collect::, _>>()?, + st: st.clone(), + }); + } + } + } } - Ok(Sum { + Err(ExtractValueError::BadSumType { + typ: typ.clone(), tag: *k, - values: zip_eq(v, r.iter()) - .map(|(v, t)| v.clone().try_into_value(t)) - .collect::, _>>()?, - st: st.clone(), + num_elements: v.len(), }) } } +#[derive(Clone, Debug, PartialEq, Eq, Error)] +#[allow(missing_docs)] +pub enum ExtractValueError { + #[error("PartialSum value had multiple possible tags: {0}")] + MultipleVariants(PartialSum), + #[error("Value contained `Bottom`")] + ValueIsBottom, + #[error("Value contained `Top`")] + ValueIsTop, + #[error("Could not convert element from abstract value into concrete: {0}")] + CouldNotConvert(V, #[source] E), + #[error("Could not build Sum from concrete element values")] + CouldNotBuildSum(#[source] E), + #[error("Expected a SumType with tag {tag} having {num_elements} elements, found {typ}")] + BadSumType { + typ: Type, + tag: usize, + num_elements: usize, + }, +} + impl PartialSum { /// If this Sum might have the specified `tag`, get the elements inside that tag. pub fn variant_values(&self, variant: usize) -> Option>> { @@ -307,14 +334,15 @@ impl PartialValue { pub fn try_into_value + TryFrom>>( self, typ: &Type, - ) -> Result>>::Error>> { + ) -> Result>>::Error>> { match self { Self::Value(v) => Ok(V2::from(v.clone())), Self::PartialSum(ps) => { let v = ps.try_into_value(typ)?; - V2::try_from(v).map_err(Some) + V2::try_from(v).map_err(ExtractValueError::CouldNotBuildSum) } - _ => Err(None), + Self::Top => Err(ExtractValueError::ValueIsTop), + Self::Bottom => Err(ExtractValueError::ValueIsBottom), } } } diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index b18a3c704..f51c6353e 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use hugr_core::{ops::Value, types::ConstTypeError, HugrView, IncomingPort, Node, PortIndex, Wire}; -use super::{AbstractValue, DFContext, PartialValue}; +use super::{partial_value::ExtractValueError, AbstractValue, DFContext, PartialValue}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). @@ -89,20 +89,23 @@ where /// a [Sum](PartialValue::PartialSum) with a single known tag.) /// /// # Errors - /// `None` if the analysis did not result in a single value on that wire - /// `Some(e)` if conversion to a [Value] produced a [ConstTypeError] + /// `None` if the analysis did not produce a result for that wire + /// `Some(e)` if conversion to a [Value] failed with error `e` /// /// # Panics /// /// If a [Type](hugr_core::types::Type) for the specified wire could not be extracted from the Hugr - pub fn try_read_wire_value(&self, w: Wire) -> Result> { + pub fn try_read_wire_value( + &self, + w: Wire, + ) -> Result>> { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self .hugr() .out_value_types(w.node()) .find(|(p, _)| *p == w.source()) .unwrap(); - v.try_into_value(&typ) + v.try_into_value(&typ).map_err(Some) } } From df3152356bc7a5657ba5289c430dde72599e320c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 11:53:41 +0000 Subject: [PATCH 160/203] fmt --- hugr-passes/src/dataflow/partial_value.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index b03e2d842..648034400 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -151,7 +151,7 @@ impl PartialSum { /// row variables within that variant and of the correct number of variants /// `Some(e)` if none of the error conditions above applied, but there was an error /// `e` in converting one of the variant elements into `V2` via [PartialValue::try_into_value] - pub fn try_into_value + TryFrom, Error=E>>( + pub fn try_into_value + TryFrom, Error = E>>( self, typ: &Type, ) -> Result, ExtractValueError> { From a490874e98fdc683aa890388de0fb1b49b5a270f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 12:23:01 +0000 Subject: [PATCH 161/203] Add test running on region --- hugr-passes/src/dataflow/test.rs | 55 ++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index c7d4e7b7e..80bd484ef 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -1,6 +1,8 @@ use ascent::{lattice::BoundedLattice, Lattice}; use hugr_core::builder::{CFGBuilder, Container, DataflowHugr}; +use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; +use hugr_core::ops::handle::DfgID; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ @@ -454,3 +456,56 @@ fn test_call( assert_eq!(res0, out); assert_eq!(res1, out); } + +#[test] +fn test_region() { + let mut builder = + DFGBuilder::new(Signature::new(type_row![BOOL_T], type_row![BOOL_T;2])).unwrap(); + let [in_w] = builder.input_wires_arr(); + let cst_w = builder.add_load_const(Value::false_val()); + let nested = builder + .dfg_builder(Signature::new_endo(type_row![BOOL_T; 2]), [in_w, cst_w]) + .unwrap(); + let nested_ins = nested.input_wires(); + let nested = nested.finish_with_outputs(nested_ins).unwrap(); + let hugr = builder + .finish_prelude_hugr_with_outputs(nested.outputs()) + .unwrap(); + let [nested_input, _] = hugr.get_io(nested.node()).unwrap(); + let whole_hugr_results = Machine::default().run(TestContext(&hugr), [(0.into(), pv_true())]); + assert_eq!( + whole_hugr_results.read_out_wire(Wire::new(nested_input, 0)), + Some(pv_true()) + ); + assert_eq!( + whole_hugr_results.read_out_wire(Wire::new(nested_input, 1)), + Some(pv_false()) + ); + assert_eq!( + whole_hugr_results.read_out_wire(Wire::new(hugr.root(), 0)), + Some(pv_true()) + ); + assert_eq!( + whole_hugr_results.read_out_wire(Wire::new(hugr.root(), 1)), + Some(pv_false()) + ); + + let subview = DescendantsGraph::::try_new(&hugr, nested.node()).unwrap(); + // Do not provide a value on the second input (constant false in the whole hugr, above) + let sub_hugr_results = Machine::default().run(TestContext(subview), [(0.into(), pv_true())]); + assert_eq!( + sub_hugr_results.read_out_wire(Wire::new(nested_input, 0)), + Some(pv_true()) + ); + // TODO this should really be `Top` - safety says we have to assume it could be anything, not that it can't happen + assert_eq!( + sub_hugr_results.read_out_wire(Wire::new(nested_input, 1)), + Some(PartialValue::Bottom) + ); + for w in [0, 1] { + assert_eq!( + sub_hugr_results.read_out_wire(Wire::new(hugr.root(), w)), + None + ); + } +} From 3af39aa0a133f41a7c5bf38e027982722efb5b01 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 12:42:06 +0000 Subject: [PATCH 162/203] fix: provide PartialValue::Top for unspecified Hugr inputs --- hugr-passes/src/dataflow/datalog.rs | 20 ++++++++++++++++++++ hugr-passes/src/dataflow/test.rs | 3 +-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 4fc54b511..facab8595 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -48,6 +48,26 @@ impl Machine { let root = context.root(); self.0 .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); + // Any inputs we don't have values for, we must assume `Top` to ensure safety of analysis + // (Consider: for a conditional that selects *either* the unknown input *or* value V, + // analysis must produce Top == we-know-nothing, not `V` !) + let mut have_inputs = + vec![false; context.signature(root).unwrap_or_default().input_count()]; + self.0.iter().for_each(|(n, p, _)| { + if n == &root { + if let Some(e) = have_inputs.get_mut(p.index()) { + *e = true; + } + } + }); + for (i, b) in have_inputs.into_iter().enumerate() { + if !b { + self.0 + .push((root, IncomingPort::from(i), PartialValue::Top)); + } + } + // Note/TODO, if analysis is running on a subregion then we should do similar + // for any nonlocal edges providing values from outside the region. run_datalog(context, self.0) } } diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 80bd484ef..8ec0f9dee 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -497,10 +497,9 @@ fn test_region() { sub_hugr_results.read_out_wire(Wire::new(nested_input, 0)), Some(pv_true()) ); - // TODO this should really be `Top` - safety says we have to assume it could be anything, not that it can't happen assert_eq!( sub_hugr_results.read_out_wire(Wire::new(nested_input, 1)), - Some(PartialValue::Bottom) + Some(PartialValue::Top) ); for w in [0, 1] { assert_eq!( From dc159993ec4567a21708907f4d369bfb91679d31 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 13:31:11 +0000 Subject: [PATCH 163/203] try_into_value allows TryFrom by giving ExtractValueError *2* errortype params --- hugr-passes/src/dataflow/partial_value.rs | 17 +++++++++-------- hugr-passes/src/dataflow/results.rs | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 648034400..f601af86b 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -151,10 +151,10 @@ impl PartialSum { /// row variables within that variant and of the correct number of variants /// `Some(e)` if none of the error conditions above applied, but there was an error /// `e` in converting one of the variant elements into `V2` via [PartialValue::try_into_value] - pub fn try_into_value + TryFrom, Error = E>>( + pub fn try_into_value + TryFrom, Error = SE>>( self, typ: &Type, - ) -> Result, ExtractValueError> { + ) -> Result, ExtractValueError> { let Ok((k, v)) = self.0.iter().exactly_one() else { return Err(ExtractValueError::MultipleVariants(self)); }; @@ -183,7 +183,7 @@ impl PartialSum { #[derive(Clone, Debug, PartialEq, Eq, Error)] #[allow(missing_docs)] -pub enum ExtractValueError { +pub enum ExtractValueError { #[error("PartialSum value had multiple possible tags: {0}")] MultipleVariants(PartialSum), #[error("Value contained `Bottom`")] @@ -191,9 +191,9 @@ pub enum ExtractValueError { #[error("Value contained `Top`")] ValueIsTop, #[error("Could not convert element from abstract value into concrete: {0}")] - CouldNotConvert(V, #[source] E), + CouldNotConvert(V, #[source] VE), #[error("Could not build Sum from concrete element values")] - CouldNotBuildSum(#[source] E), + CouldNotBuildSum(#[source] SE), #[error("Expected a SumType with tag {tag} having {num_elements} elements, found {typ}")] BadSumType { typ: Type, @@ -331,12 +331,13 @@ impl PartialValue { /// /// `None` if this is [Bottom](PartialValue::Bottom) or [Top](PartialValue::Top), /// otherwise as per [PartialSum::try_into_value] - pub fn try_into_value + TryFrom>>( + pub fn try_into_value + TryFrom, Error = SE>>( self, typ: &Type, - ) -> Result>>::Error>> { + ) -> Result> { match self { - Self::Value(v) => Ok(V2::from(v.clone())), + Self::Value(v) => V2::try_from(v.clone()) + .map_err(|e| ExtractValueError::CouldNotConvert(v.clone(), e)), Self::PartialSum(ps) => { let v = ps.try_into_value(typ)?; V2::try_from(v).map_err(ExtractValueError::CouldNotBuildSum) diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index f51c6353e..713900acc 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -98,7 +98,7 @@ where pub fn try_read_wire_value( &self, w: Wire, - ) -> Result>> { + ) -> Result>> { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self .hugr() From 2d81264a3e5287230c975dc27ac2602149114ad9 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 13:38:49 +0000 Subject: [PATCH 164/203] improve docs --- hugr-passes/src/dataflow/partial_value.rs | 23 +++++++++++++---------- hugr-passes/src/dataflow/results.rs | 2 +- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index f601af86b..12f1733d3 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -142,15 +142,14 @@ impl PartialSum { self.0.contains_key(&tag) } - /// Turns this instance into a [Sum] of some target value type `V2`, + /// Turns this instance into a [Sum] of some "concrete" value type `V2`, /// *if* this PartialSum has exactly one possible tag. /// /// # Errors - /// `None` if this PartialSum had multiple possible tags; or, if there was a single - /// tag, but `typ` was not a [TypeEnum::Sum] supporting that tag and containing no - /// row variables within that variant and of the correct number of variants - /// `Some(e)` if none of the error conditions above applied, but there was an error - /// `e` in converting one of the variant elements into `V2` via [PartialValue::try_into_value] + /// + /// If this PartialSum had multiple possible tags; or if `typ` was not a [TypeEnum::Sum] + /// supporting the single possible tag with the correct number of elements and no row variables; + /// or if converting a child element failed via [PartialValue::try_into_value]. pub fn try_into_value + TryFrom, Error = SE>>( self, typ: &Type, @@ -181,6 +180,8 @@ impl PartialSum { } } +/// An error converting a [PartialValue] or [PartialSum] into a concrete value type +/// via [PartialValue::try_into_value] or [PartialSum::try_into_value] #[derive(Clone, Debug, PartialEq, Eq, Error)] #[allow(missing_docs)] pub enum ExtractValueError { @@ -324,13 +325,15 @@ impl PartialValue { } } - /// Turns this instance into a target value type `V2` if it is a single value, - /// or a [PartialValue::PartialSum] convertible by [PartialSum::try_into_value]. + /// Turns this instance into some "concrete" value type `V2`, *if* it is a single value, + /// or a [Sum](PartialValue::PartialSum) (of a single tag) convertible by + /// [PartialSum::try_into_value]. /// /// # Errors /// - /// `None` if this is [Bottom](PartialValue::Bottom) or [Top](PartialValue::Top), - /// otherwise as per [PartialSum::try_into_value] + /// If this PartialValue was `Top` or `Bottom`, or was a [PartialSum](PartialValue::PartialSum) + /// that could not be converted into a [Sum] by [PartialSum::try_into_value] (e.g. if `typ` is + /// incorrect), or if that [Sum] could not be converted into a `V2`. pub fn try_into_value + TryFrom, Error = SE>>( self, typ: &Type, diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index 713900acc..c1e554154 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -90,7 +90,7 @@ where /// /// # Errors /// `None` if the analysis did not produce a result for that wire - /// `Some(e)` if conversion to a [Value] failed with error `e` + /// `Some(e)` if conversion to a [Value] failed with error `e`, see [PartialValue::try_into_value] /// /// # Panics /// From b61d2520231fd29ce5fb23dd489abf64726c7401 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 14:10:36 +0000 Subject: [PATCH 165/203] Parametrize Machine::try_read_wire_value the same way --- hugr-passes/src/dataflow/results.rs | 24 +++++++++++------------- hugr-passes/src/dataflow/test.rs | 12 ++++++------ 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index c1e554154..6c90e33b3 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -1,8 +1,8 @@ use std::collections::HashMap; -use hugr_core::{ops::Value, types::ConstTypeError, HugrView, IncomingPort, Node, PortIndex, Wire}; +use hugr_core::{HugrView, IncomingPort, Node, PortIndex, Wire}; -use super::{partial_value::ExtractValueError, AbstractValue, DFContext, PartialValue}; +use super::{partial_value::ExtractValueError, AbstractValue, DFContext, PartialValue, Sum}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). @@ -78,27 +78,25 @@ impl> AnalysisResults { .any(|(cfg2, bb2)| *cfg2 == cfg && *bb2 == bb), ) } -} -impl> AnalysisResults -where - Value: From, -{ - /// Reads a [Value] from an output wire, if the lattice value computed for it can be turned - /// into one. (The lattice value must be either a single [Value](PartialValue::Value) or - /// a [Sum](PartialValue::PartialSum) with a single known tag.) + /// Reads a concrete representation of the value on an output wire, if the lattice value + /// computed for the wire can be turned into such. (The lattice value must be either a + /// [PartialValue::Value] or a [PartialValue::PartialSum] with a single possible tag.) /// /// # Errors /// `None` if the analysis did not produce a result for that wire - /// `Some(e)` if conversion to a [Value] failed with error `e`, see [PartialValue::try_into_value] + /// `Some(e)` if conversion to a concrete value failed with error `e`, see [PartialValue::try_into_value] /// /// # Panics /// /// If a [Type](hugr_core::types::Type) for the specified wire could not be extracted from the Hugr - pub fn try_read_wire_value( + pub fn try_read_wire_value( &self, w: Wire, - ) -> Result>> { + ) -> Result>> + where + V2: TryFrom + TryFrom, Error = SE>, + { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self .hugr() diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 8ec0f9dee..057242d03 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -67,7 +67,7 @@ fn test_make_tuple() { let results = Machine::default().run(TestContext(hugr), []); - let x = results.try_read_wire_value(v3).unwrap(); + let x: Value = results.try_read_wire_value(v3).unwrap(); assert_eq!(x, Value::tuple([Value::false_val(), Value::true_val()])); } @@ -83,9 +83,9 @@ fn test_unpack_tuple_const() { let results = Machine::default().run(TestContext(hugr), []); - let o1_r = results.try_read_wire_value(o1).unwrap(); + let o1_r: Value = results.try_read_wire_value(o1).unwrap(); assert_eq!(o1_r, Value::false_val()); - let o2_r = results.try_read_wire_value(o2).unwrap(); + let o2_r: Value = results.try_read_wire_value(o2).unwrap(); assert_eq!(o2_r, Value::true_val()); } @@ -106,7 +106,7 @@ fn test_tail_loop_never_iterates() { let results = Machine::default().run(TestContext(hugr), []); - let o_r = results.try_read_wire_value(tl_o).unwrap(); + let o_r: Value = results.try_read_wire_value(tl_o).unwrap(); assert_eq!(o_r, r_v); assert_eq!( Some(TailLoopTermination::NeverContinues), @@ -291,9 +291,9 @@ fn test_conditional() { )); let results = Machine::default().run(TestContext(hugr), [(0.into(), arg_pv)]); - let cond_r1 = results.try_read_wire_value(cond_o1).unwrap(); + let cond_r1: Value = results.try_read_wire_value(cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); - assert!(results.try_read_wire_value(cond_o2).is_err()); + assert!(results.try_read_wire_value::(cond_o2).is_err()); assert_eq!(results.case_reachable(case1.node()), Some(false)); // arg_pv is variant 1 or 2 only assert_eq!(results.case_reachable(case2.node()), Some(true)); From 8cac194e534df144a5ef7cb5a8a7ce5c651c5ce6 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 30 Oct 2024 14:45:41 +0000 Subject: [PATCH 166/203] tweaks --- hugr-passes/Cargo.toml | 2 +- hugr-passes/src/dataflow.rs | 18 ++++---- hugr-passes/src/dataflow/datalog.rs | 54 +++++++++++------------ hugr-passes/src/dataflow/partial_value.rs | 1 - hugr-passes/src/dataflow/test.rs | 1 + 5 files changed, 37 insertions(+), 39 deletions(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 88d7fc62b..818aa069c 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -30,4 +30,4 @@ extension_inference = ["hugr-core/extension_inference"] rstest = { workspace = true } proptest = { workspace = true } proptest-derive = { workspace = true } -proptest-recurse = { version = "0.5.0" } \ No newline at end of file +proptest-recurse = { version = "0.5.0" } diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 0b7bf03ed..dd3e6d2c0 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -20,7 +20,7 @@ use hugr_core::{Hugr, HugrView, Node}; /// must implement this trait (including providing an appropriate domain type `V`). pub trait DFContext: ConstLoader + std::ops::Deref { /// Type of view contained within this context. (Ideally we'd constrain - /// by `std::ops::Deref` but that's not stable yet.) type View: HugrView; /// Given lattice values for each input, update lattice values for the (dataflow) outputs. @@ -41,12 +41,14 @@ pub trait DFContext: ConstLoader + std::ops::Deref { } } -/// Trait for loading [PartialValue]s from constants in a Hugr. The default -/// traverses [Sum](Value::Sum) constants to their non-Sum leaves but represents -/// each leaf as [PartialValue::Top]. +/// Trait for loading [PartialValue]s from constant [Value]s in a Hugr. +/// Implementors will likely want to override some/all of [Self::value_from_opaque], +/// [Self::value_from_const_hugr], and [Self::value_from_function]: the defaults +/// are "correct" but maximally conservative (minimally informative). pub trait ConstLoader { - /// Produces an abstract value from a constant. The default impl - /// traverses the constant [Value] to its leaves ([Value::Extension] and [Value::Function]), + /// Produces a [PartialValue] from a constant. The default impl (expected + /// to be appropriate in most cases) traverses [Sum](Value::Sum) constants + /// to their leaves ([Value::Extension] and [Value::Function]), /// converts these using [Self::value_from_opaque] and [Self::value_from_const_hugr], /// and builds nested [PartialValue::new_variant] to represent the structure. fn value_from_const(&self, n: Node, cst: &Value) -> PartialValue { @@ -65,8 +67,8 @@ pub trait ConstLoader { None } - /// Produces an abstract value from a [FuncDefn] or [FuncDecl] node (that has been loaded - /// via a [LoadFunction]), if possible. + /// Produces an abstract value from a [FuncDefn] or [FuncDecl] node + /// (that has been loaded via a [LoadFunction]), if possible. /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. /// /// [FuncDefn]: hugr_core::ops::FuncDefn diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index facab8595..81f415fd0 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -1,5 +1,8 @@ //! [ascent] datalog implementation of analysis. +use std::collections::HashSet; +use std::hash::RandomState; + use ascent::lattice::BoundedLattice; use itertools::Itertools; @@ -51,20 +54,16 @@ impl Machine { // Any inputs we don't have values for, we must assume `Top` to ensure safety of analysis // (Consider: for a conditional that selects *either* the unknown input *or* value V, // analysis must produce Top == we-know-nothing, not `V` !) - let mut have_inputs = - vec![false; context.signature(root).unwrap_or_default().input_count()]; + let mut need_inputs: HashSet<_, RandomState> = HashSet::from_iter( + (0..context.signature(root).unwrap_or_default().input_count()).map(IncomingPort::from), + ); self.0.iter().for_each(|(n, p, _)| { if n == &root { - if let Some(e) = have_inputs.get_mut(p.index()) { - *e = true; - } + need_inputs.remove(p); } }); - for (i, b) in have_inputs.into_iter().enumerate() { - if !b { - self.0 - .push((root, IncomingPort::from(i), PartialValue::Top)); - } + for p in need_inputs { + self.0.push((root, p, PartialValue::Top)); } // Note/TODO, if analysis is running on a subregion then we should do similar // for any nonlocal edges providing values from outside the region. @@ -109,6 +108,7 @@ pub(super) fn run_datalog>( // Initialize all wires to bottom out_wire_value(n, p, PV::bottom()) <-- out_wire(n, p); + // Outputs to inputs in_wire_value(n, ip, v) <-- in_wire(n, ip), if let Some((m, op)) = ctx.single_linked_output(*n, *ip), out_wire_value(m, op, v); @@ -120,10 +120,11 @@ pub(super) fn run_datalog>( if let Some(sig) = ctx.signature(*n), if sig.input_ports().contains(p); - // Assemble in_value_row from in_value's + // Assemble node_in_value_row from in_wire_value's node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = ctx.signature(*n); node_in_value_row(n, ValueRow::single_known(ctx.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(n, p, v); + // Interpret leaf ops out_wire_value(n, p, v) <-- node(n), let op_t = ctx.get_optype(*n), @@ -133,7 +134,7 @@ pub(super) fn run_datalog>( if let Some(outs) = propagate_leaf_op(&ctx, *n, &vs[..], sig.output_count()), for (p, v) in (0..).map(OutgoingPort::from).zip(outs); - // DFG + // DFG -------------------- relation dfg_node(Node); // is a `DFG` dfg_node(n) <-- node(n), if ctx.get_optype(*n).is_dfg(); @@ -143,9 +144,7 @@ pub(super) fn run_datalog>( out_wire_value(dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), output_child(dfg, o), in_wire_value(o, p, v); - - // TailLoop - + // TailLoop -------------------- // inputs of tail loop propagate to Input node of child region out_wire_value(i, OutgoingPort::from(p.index()), v) <-- node(tl), if ctx.get_optype(*tl).is_tail_loop(), @@ -169,13 +168,11 @@ pub(super) fn run_datalog>( if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 for (out_p, v) in fields.enumerate(); - // Conditional - relation conditional_node(Node); // is a `Conditional` + // Conditional -------------------- // is a `Conditional` and its 'th child (a `Case`) is : relation case_node(Node, usize, Node); - - conditional_node(n)<-- node(n), if ctx.get_optype(*n).is_conditional(); - case_node(cond, i, case) <-- conditional_node(cond), + case_node(cond, i, case) <-- node(cond), + if ctx.get_optype(*cond).is_conditional(), for (i, case) in ctx.children(*cond).enumerate(), if ctx.get_optype(case).is_case(); @@ -195,17 +192,17 @@ pub(super) fn run_datalog>( output_child(case, o), in_wire_value(o, o_p, v); - // In `Conditional` , child `Case` is reachable given our knowledge of predicate + // In `Conditional` , child `Case` is reachable given our knowledge of predicate: relation case_reachable(Node, Node); case_reachable(cond, case) <-- case_node(cond, i, case), in_wire_value(cond, IncomingPort::from(0), v), if v.supports_tag(*i); - // CFG + // CFG -------------------- relation cfg_node(Node); // is a `CFG` cfg_node(n) <-- node(n), if ctx.get_optype(*n).is_cfg(); - // In `CFG` , basic block is reachable given our knowledge of predicates + // In `CFG` , basic block is reachable given our knowledge of predicates: relation bb_reachable(Node, Node); bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = ctx.children(*cfg).next(); bb_reachable(cfg, bb) <-- cfg_node(cfg), @@ -223,7 +220,7 @@ pub(super) fn run_datalog>( in_wire_value(cfg, p, v); // In `CFG` , values fed along a control-flow edge to - // come out of Value outports of . + // come out of Value outports of : relation _cfg_succ_dest(Node, Node, Node); _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = ctx.children(*cfg).nth(1); _cfg_succ_dest(cfg, blk, inp) <-- cfg_node(cfg), @@ -242,7 +239,7 @@ pub(super) fn run_datalog>( if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), for (out_p, v) in fields.enumerate(); - // Call + // Call -------------------- relation func_call(Node, Node); // is a `Call` to `FuncDefn` func_call(call, func_defn) <-- node(call), @@ -327,7 +324,7 @@ fn propagate_leaf_op( )) } OpType::ExtensionOp(e) => { - // Interpret op. + // Interpret op using DFContext let init = if ins.iter().contains(&PartialValue::Bottom) { // So far we think one or more inputs can't happen. // So, don't pollute outputs with Top, and wait for better knowledge of inputs. @@ -337,9 +334,8 @@ fn propagate_leaf_op( PartialValue::Top }; let mut outs = vec![init; num_outs]; - // It'd be nice to convert these to [(IncomingPort, Value)] to pass to the context, - // thus keeping PartialValue hidden, but AbstractValues - // are not necessarily convertible to Value. + // It might be nice to convert these to [(IncomingPort, Value)], or some concrete value, + // for the context, but PV contains more information, and try_into_value may fail. ctx.interpret_leaf_op(n, e, ins, &mut outs[..]); Some(ValueRow::from_iter(outs)) } diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 12f1733d3..0086629a1 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -362,7 +362,6 @@ impl TryFrom> for Value { impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); - // println!("join {self:?}\n{:?}", &other); match (&*self, other) { (Self::Top, _) => false, (_, other @ Self::Top) => { diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 057242d03..01b1474dd 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -312,6 +312,7 @@ fn xor_and_cfg() -> Hugr { let mut builder = CFGBuilder::new(Signature::new(type_row![BOOL_T; 2], type_row![BOOL_T; 2])).unwrap(); let false_c = builder.add_constant(Value::false_val()); + // entry (x, y) => if x {A(y, x=true)} else B(y)} let entry_outs = [type_row![BOOL_T;2], type_row![BOOL_T]]; let mut entry = builder From fb3816e4c6e14c2a2e8889b0a6695f1214969600 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 4 Nov 2024 12:41:49 +0000 Subject: [PATCH 167/203] Massively simplify xor_and_cfg, no need for conditionals --- hugr-passes/src/dataflow/test.rs | 94 ++++++++++---------------------- 1 file changed, 30 insertions(+), 64 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 01b1474dd..c8961ea2c 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -304,90 +304,56 @@ fn test_conditional() { // A Hugr being a function on bools: (x, y) => (x XOR y, x AND y) #[fixture] fn xor_and_cfg() -> Hugr { - // Entry - // /0 1\ - // A --1-> B A(x=true, y) => if y then X(false, true) else B(x=true) - // \0 / B(z) => X(z,false) + // Entry branch on first arg, passes arguments on unchanged + // /T F\ + // A --T-> B A(x=true, y) branch on second arg, passing (first arg == true, false) + // \F / B(w,v) => X(v,w) // > X < + // Inputs received: + // Entry A B X + // F,F - F,F F,F + // F,T - F,T T,F + // T,F T,F - T,F + // T,T T,T T,F F,T let mut builder = CFGBuilder::new(Signature::new(type_row![BOOL_T; 2], type_row![BOOL_T; 2])).unwrap(); - let false_c = builder.add_constant(Value::false_val()); - // entry (x, y) => if x {A(y, x=true)} else B(y)} - let entry_outs = [type_row![BOOL_T;2], type_row![BOOL_T]]; - let mut entry = builder - .entry_builder(entry_outs.clone(), type_row![]) + // entry (x, y) => (if x then A else B)(x=true, y) + let entry = builder + .entry_builder(vec![type_row![]; 2], type_row![BOOL_T;2]) .unwrap(); let [in_x, in_y] = entry.input_wires_arr(); - let mut cond = entry - .conditional_builder( - (vec![type_row![]; 2], in_x), - [], - Type::new_sum(entry_outs.clone()).into(), - ) - .unwrap(); - let mut if_x_true = cond.case_builder(1).unwrap(); - let br_to_a = if_x_true - .add_dataflow_op(Tag::new(0, entry_outs.to_vec()), [in_y, in_x]) - .unwrap(); - if_x_true.finish_with_outputs(br_to_a.outputs()).unwrap(); - let mut if_x_false = cond.case_builder(0).unwrap(); - let br_to_b = if_x_false - .add_dataflow_op(Tag::new(1, entry_outs.into()), [in_y]) - .unwrap(); - if_x_false.finish_with_outputs(br_to_b.outputs()).unwrap(); - - let [res] = cond.finish_sub_container().unwrap().outputs_arr(); - let entry = entry.finish_with_outputs(res, []).unwrap(); + let entry = entry.finish_with_outputs(in_x, [in_x, in_y]).unwrap(); - // A(y, z always true) => if y {X(false, z)} else {B(z)} - let a_outs = vec![type_row![BOOL_T], type_row![]]; + // A(x==true, y) => (if y then B else X)(x, false) let mut a = builder .block_builder( type_row![BOOL_T; 2], - a_outs.clone(), - type_row![BOOL_T], // Trailing z common to both branches + vec![type_row![]; 2], + type_row![BOOL_T; 2], ) .unwrap(); - let [in_y, in_z] = a.input_wires_arr(); + let [in_x, in_y] = a.input_wires_arr(); + let false_w1 = a.add_load_value(Value::false_val()); + let a = a.finish_with_outputs(in_y, [in_x, false_w1]).unwrap(); - let mut cond = a - .conditional_builder( - (vec![type_row![]; 2], in_y), - [], - Type::new_sum(a_outs.clone()).into(), - ) - .unwrap(); - let mut if_y_true = cond.case_builder(1).unwrap(); - let false_w1 = if_y_true.load_const(&false_c); - let br_to_x = if_y_true - .add_dataflow_op(Tag::new(0, a_outs.clone()), [false_w1]) - .unwrap(); - if_y_true.finish_with_outputs(br_to_x.outputs()).unwrap(); - let mut if_y_false = cond.case_builder(0).unwrap(); - let br_to_b = if_y_false.add_dataflow_op(Tag::new(1, a_outs), []).unwrap(); - if_y_false.finish_with_outputs(br_to_b.outputs()).unwrap(); - let [res] = cond.finish_sub_container().unwrap().outputs_arr(); - let a = a.finish_with_outputs(res, [in_z]).unwrap(); - - // B(v) => X(v, false) + // B(w, v) => X(v, w) let mut b = builder - .block_builder(type_row![BOOL_T], [type_row![]], type_row![BOOL_T; 2]) + .block_builder(type_row![BOOL_T; 2], [type_row![]], type_row![BOOL_T; 2]) .unwrap(); - let [in_v] = b.input_wires_arr(); - let false_w2 = b.load_const(&false_c); + let [in_w, in_v] = b.input_wires_arr(); let [control] = b .add_dataflow_op(Tag::new(0, vec![type_row![]]), []) .unwrap() .outputs_arr(); - let b = b.finish_with_outputs(control, [in_v, false_w2]).unwrap(); + let b = b.finish_with_outputs(control, [in_v, in_w]).unwrap(); let x = builder.exit_block(); - builder.branch(&entry, 0, &a).unwrap(); - builder.branch(&entry, 1, &b).unwrap(); - builder.branch(&a, 0, &x).unwrap(); - builder.branch(&a, 1, &b).unwrap(); + builder.branch(&entry, 1, &a).unwrap(); // if true + builder.branch(&entry, 0, &b).unwrap(); // if false + builder.branch(&a, 1, &b).unwrap(); // if true + builder.branch(&a, 0, &x).unwrap(); // if false builder.branch(&b, 0, &x).unwrap(); builder.finish_hugr(&EMPTY_REG).unwrap() } @@ -402,9 +368,9 @@ fn xor_and_cfg() -> Hugr { #[case(pv_false(), pv_true_or_false(), pv_true_or_false(), pv_false())] #[case(pv_false(), PartialValue::Top, PartialValue::Top, pv_false())] // if !inp0 then out0=inp1 #[case(pv_true_or_false(), pv_true(), pv_true_or_false(), pv_true_or_false())] -#[case(pv_true_or_false(), pv_false(), pv_true_or_false(), pv_false())] +#[case(pv_true_or_false(), pv_false(), pv_true_or_false(), pv_true_or_false())] #[case(PartialValue::Top, pv_true(), pv_true_or_false(), PartialValue::Top)] -#[case(PartialValue::Top, pv_false(), PartialValue::Top, pv_false())] +#[case(PartialValue::Top, pv_false(), PartialValue::Top, PartialValue::Top)] fn test_cfg( #[case] inp0: PartialValue, #[case] inp1: PartialValue, From 19571f6a9d2bcd94c1afe49b9b136f7505a4d493 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 4 Nov 2024 13:12:19 +0000 Subject: [PATCH 168/203] Use tru/fals constants --- hugr-passes/src/dataflow/test.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index c8961ea2c..a300965b1 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -350,10 +350,11 @@ fn xor_and_cfg() -> Hugr { let x = builder.exit_block(); - builder.branch(&entry, 1, &a).unwrap(); // if true - builder.branch(&entry, 0, &b).unwrap(); // if false - builder.branch(&a, 1, &b).unwrap(); // if true - builder.branch(&a, 0, &x).unwrap(); // if false + let [fals, tru]: [usize; 2] = [0, 1]; + builder.branch(&entry, tru, &a).unwrap(); // if true + builder.branch(&entry, fals, &b).unwrap(); // if false + builder.branch(&a, tru, &b).unwrap(); // if true + builder.branch(&a, fals, &x).unwrap(); // if false builder.branch(&b, 0, &x).unwrap(); builder.finish_hugr(&EMPTY_REG).unwrap() } From 69d0f5e59a5a787a6e222daaca622ca03114332a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 8 Nov 2024 10:00:15 +0000 Subject: [PATCH 169/203] try_into_value: reorder type params, separate out where clause --- hugr-passes/src/dataflow/partial_value.rs | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 0086629a1..992a72444 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -150,10 +150,13 @@ impl PartialSum { /// If this PartialSum had multiple possible tags; or if `typ` was not a [TypeEnum::Sum] /// supporting the single possible tag with the correct number of elements and no row variables; /// or if converting a child element failed via [PartialValue::try_into_value]. - pub fn try_into_value + TryFrom, Error = SE>>( + pub fn try_into_value( self, typ: &Type, - ) -> Result, ExtractValueError> { + ) -> Result, ExtractValueError> + where + V2: TryFrom + TryFrom, Error = SE>, + { let Ok((k, v)) = self.0.iter().exactly_one() else { return Err(ExtractValueError::MultipleVariants(self)); }; @@ -334,10 +337,10 @@ impl PartialValue { /// If this PartialValue was `Top` or `Bottom`, or was a [PartialSum](PartialValue::PartialSum) /// that could not be converted into a [Sum] by [PartialSum::try_into_value] (e.g. if `typ` is /// incorrect), or if that [Sum] could not be converted into a `V2`. - pub fn try_into_value + TryFrom, Error = SE>>( - self, - typ: &Type, - ) -> Result> { + pub fn try_into_value(self, typ: &Type) -> Result> + where + V2: TryFrom + TryFrom, Error = SE>, + { match self { Self::Value(v) => V2::try_from(v.clone()) .map_err(|e| ExtractValueError::CouldNotConvert(v.clone(), e)), From 2624ee85f7607726b1b82d00bb918f70942717af Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 8 Nov 2024 10:02:25 +0000 Subject: [PATCH 170/203] We don't actually use portgraph, nor downcast-rs --- hugr-passes/Cargo.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index caa7e0af0..311fe781c 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -14,9 +14,7 @@ categories = ["compilers"] [dependencies] hugr-core = { path = "../hugr-core", version = "0.13.3" } -portgraph = { workspace = true } ascent = { version = "0.7.0" } -downcast-rs = { workspace = true } itertools = { workspace = true } lazy_static = { workspace = true } paste = { workspace = true } From ec526e83d4c1cea0c2d45f2686672d6578b44d66 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 8 Nov 2024 10:23:13 +0000 Subject: [PATCH 171/203] Import RandomState from std::collections::hash_map for rust 1.75 --- hugr-passes/src/dataflow/datalog.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 81f415fd0..dbca253da 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -1,7 +1,7 @@ //! [ascent] datalog implementation of analysis. -use std::collections::HashSet; -use std::hash::RandomState; +use std::collections::hash_map::RandomState; +use std::collections::HashSet; // Moves to std::hash in Rust 1.76 use ascent::lattice::BoundedLattice; use itertools::Itertools; From 5650ee40b2f5aed54657a221274ecc97b01d771b Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 8 Nov 2024 13:48:36 +0000 Subject: [PATCH 172/203] Use BREAK_TAG/CONTINUE_TAG --- hugr-passes/src/dataflow/datalog.rs | 12 +++++++----- hugr-passes/src/dataflow/test.rs | 26 +++++++++++++++++++++----- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index dbca253da..33ab4524d 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -7,7 +7,7 @@ use ascent::lattice::BoundedLattice; use itertools::Itertools; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; -use hugr_core::ops::{OpTrait, OpType}; +use hugr_core::ops::{OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; @@ -156,16 +156,18 @@ pub(super) fn run_datalog>( if let Some(tailloop) = ctx.get_optype(*tl).as_tail_loop(), input_child(tl, in_n), output_child(tl, out_n), - node_in_value_row(out_n, out_in_row), // get the whole input row for the output node - if let Some(fields) = out_in_row.unpack_first(0, tailloop.just_inputs.len()), // if it is possible for tag to be 0 + node_in_value_row(out_n, out_in_row), // get the whole input row for the output node... + // ...and select just what's possible for CONTINUE_TAG, if anything + if let Some(fields) = out_in_row.unpack_first(TailLoop::CONTINUE_TAG, tailloop.just_inputs.len()), for (out_p, v) in fields.enumerate(); // Output node of child region propagate to outputs of tail loop out_wire_value(tl, OutgoingPort::from(out_p), v) <-- node(tl), if let Some(tailloop) = ctx.get_optype(*tl).as_tail_loop(), output_child(tl, out_n), - node_in_value_row(out_n, out_in_row), // get the whole input row for the output node - if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1 + node_in_value_row(out_n, out_in_row), // get the whole input row for the output node... + // ... and select just what's possible for BREAK_TAG, if anything + if let Some(fields) = out_in_row.unpack_first(TailLoop::BREAK_TAG, tailloop.just_outputs.len()), for (out_p, v) in fields.enumerate(); // Conditional -------------------- diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index a300965b1..44446760e 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -3,6 +3,7 @@ use ascent::{lattice::BoundedLattice, Lattice}; use hugr_core::builder::{CFGBuilder, Container, DataflowHugr}; use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; use hugr_core::ops::handle::DfgID; +use hugr_core::ops::TailLoop; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ @@ -94,7 +95,10 @@ fn test_tail_loop_never_iterates() { let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); let r_v = Value::unit_sum(3, 6).unwrap(); let r_w = builder.add_load_value(r_v.clone()); - let tag = Tag::new(1, vec![type_row![], r_v.get_type().into()]); + let tag = Tag::new( + TailLoop::BREAK_TAG, + vec![type_row![], r_v.get_type().into()], + ); let tagged = builder.add_dataflow_op(tag, [r_w]).unwrap(); let tlb = builder @@ -117,8 +121,14 @@ fn test_tail_loop_never_iterates() { #[test] fn test_tail_loop_always_iterates() { let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap(); - let r_w = builder - .add_load_value(Value::sum(0, [], SumType::new([type_row![], BOOL_T.into()])).unwrap()); + let r_w = builder.add_load_value( + Value::sum( + TailLoop::CONTINUE_TAG, + [], + SumType::new([type_row![], BOOL_T.into()]), + ) + .unwrap(), + ); let true_w = builder.add_load_value(Value::true_val()); let tlb = builder @@ -221,13 +231,19 @@ fn test_tail_loop_containing_conditional() { .unwrap() .outputs_arr(); let cont = case0_b - .add_dataflow_op(Tag::new(0, body_out_variants.clone()), [next_input]) + .add_dataflow_op( + Tag::new(TailLoop::CONTINUE_TAG, body_out_variants.clone()), + [next_input], + ) .unwrap(); case0_b.finish_with_outputs(cont.outputs()).unwrap(); // Second iter 1(true, false) => exit with (true, false) let mut case1_b = cond.case_builder(1).unwrap(); let loop_res = case1_b - .add_dataflow_op(Tag::new(1, body_out_variants), case1_b.input_wires()) + .add_dataflow_op( + Tag::new(TailLoop::BREAK_TAG, body_out_variants), + case1_b.input_wires(), + ) .unwrap(); case1_b.finish_with_outputs(loop_res.outputs()).unwrap(); let [r] = cond.finish_sub_container().unwrap().outputs_arr(); From da2981cf3d8e840d9e833b602a317ae5e0213623 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 8 Nov 2024 13:52:10 +0000 Subject: [PATCH 173/203] No, use make_break/make_continue for easy cases --- hugr-passes/src/dataflow/test.rs | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 44446760e..c0fbf395a 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -213,6 +213,7 @@ fn test_tail_loop_containing_conditional() { let mut tlb = builder .tail_loop_builder([(control_t, init)], [], type_row![BOOL_T; 2]) .unwrap(); + let tl = tlb.loop_signature().unwrap().clone(); let [in_w] = tlb.input_wires_arr(); // Branch on in_wire, so first iter 0(false, true)... @@ -230,22 +231,12 @@ fn test_tail_loop_containing_conditional() { .add_dataflow_op(Tag::new(1, control_variants), [b, a]) .unwrap() .outputs_arr(); - let cont = case0_b - .add_dataflow_op( - Tag::new(TailLoop::CONTINUE_TAG, body_out_variants.clone()), - [next_input], - ) - .unwrap(); - case0_b.finish_with_outputs(cont.outputs()).unwrap(); + let cont = case0_b.make_continue(tl.clone(), [next_input]).unwrap(); + case0_b.finish_with_outputs([cont]).unwrap(); // Second iter 1(true, false) => exit with (true, false) let mut case1_b = cond.case_builder(1).unwrap(); - let loop_res = case1_b - .add_dataflow_op( - Tag::new(TailLoop::BREAK_TAG, body_out_variants), - case1_b.input_wires(), - ) - .unwrap(); - case1_b.finish_with_outputs(loop_res.outputs()).unwrap(); + let loop_res = case1_b.make_break(tl, case1_b.input_wires()).unwrap(); + case1_b.finish_with_outputs([loop_res]).unwrap(); let [r] = cond.finish_sub_container().unwrap().outputs_arr(); let tail_loop = tlb.finish_with_outputs(r, []).unwrap(); From 0b684549442525d868ebba0f48a98f7927d60c2a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 14 Nov 2024 13:40:56 +0000 Subject: [PATCH 174/203] Refactor bb_reachable using then --- hugr-passes/src/dataflow/results.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index 6c90e33b3..21d6b13c0 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -69,14 +69,11 @@ impl> AnalysisResults { let cfg = self.hugr().get_parent(bb)?; // Not really required...?? self.hugr().get_optype(cfg).as_cfg()?; let t = self.hugr().get_optype(bb); - if !t.is_dataflow_block() && !t.is_exit_block() { - return None; - }; - Some( + (t.is_dataflow_block() || t.is_exit_block()).then(|| { self.bb_reachable .iter() - .any(|(cfg2, bb2)| *cfg2 == cfg && *bb2 == bb), - ) + .any(|(cfg2, bb2)| *cfg2 == cfg && *bb2 == bb) + }) } /// Reads a concrete representation of the value on an output wire, if the lattice value From 4916e9d6b7a099a6a1db871f1cddf3c677b3b51e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 14 Nov 2024 14:24:35 +0000 Subject: [PATCH 175/203] ConstLocation with Box - a lot of cloning --- hugr-passes/src/dataflow.rs | 58 ++++++++++++++++++++++++++++--------- 1 file changed, 45 insertions(+), 13 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index dd3e6d2c0..a1e84bad1 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -41,6 +41,42 @@ pub trait DFContext: ConstLoader + std::ops::Deref { } } +/// A location where a [Value] could be find in a Hugr. That is, +/// (perhaps deeply nested within [Value::Sum]s) within a [Node] +/// that is a [Const](hugr_core::ops::Const). +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum ConstLocation>> { + /// The specified-index'th field of the [Value::Sum] constant identified by the RHS + Field(usize, C), + /// The entire ([Const::value](hugr_core::ops::Const::value)) of the node. + Node(Node), +} + +struct SharedConstLocation<'a>(ConstLocation<&'a SharedConstLocation<'a>>); + +impl<'a> AsRef>> for SharedConstLocation<'a> { + fn as_ref(&self) -> &ConstLocation<&'a SharedConstLocation<'a>> { + &self.0 + } +} + +struct BoxedConstLocation(Box>); + +impl AsRef> for BoxedConstLocation { + fn as_ref(&self) -> &ConstLocation { + &self.0 + } +} + +impl<'a> From<&SharedConstLocation<'a>> for BoxedConstLocation { + fn from(value: &SharedConstLocation<'a>) -> Self { + BoxedConstLocation(Box::new(match value.0 { + ConstLocation::Node(n) => ConstLocation::Node(n), + ConstLocation::Field(idx, elem) => ConstLocation::Field(idx, elem.into()), + })) + } +} + /// Trait for loading [PartialValue]s from constant [Value]s in a Hugr. /// Implementors will likely want to override some/all of [Self::value_from_opaque], /// [Self::value_from_const_hugr], and [Self::value_from_function]: the defaults @@ -52,18 +88,18 @@ pub trait ConstLoader { /// converts these using [Self::value_from_opaque] and [Self::value_from_const_hugr], /// and builds nested [PartialValue::new_variant] to represent the structure. fn value_from_const(&self, n: Node, cst: &Value) -> PartialValue { - traverse_value(self, n, &mut Vec::new(), cst) + traverse_value(self, SharedConstLocation(ConstLocation::Node(n)), cst) } /// Produces an abstract value from an [OpaqueValue], if possible. /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. - fn value_from_opaque(&self, _node: Node, _fields: &[usize], _val: &OpaqueValue) -> Option { + fn value_from_opaque(&self, _loc: SharedConstLocation, _val: &OpaqueValue) -> Option { None } /// Produces an abstract value from a Hugr in a [Value::Function], if possible. /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. - fn value_from_const_hugr(&self, _node: Node, _fields: &[usize], _h: &Hugr) -> Option { + fn value_from_const_hugr(&self, _loc: SharedConstLocation, _h: &Hugr) -> Option { None } @@ -81,26 +117,22 @@ pub trait ConstLoader { fn traverse_value( s: &(impl ConstLoader + ?Sized), - n: Node, - fields: &mut Vec, + loc: SharedConstLocation, cst: &Value, ) -> PartialValue { match cst { Value::Sum(hugr_core::ops::constant::Sum { tag, values, .. }) => { - let elems = values.iter().enumerate().map(|(idx, elem)| { - fields.push(idx); - let r = traverse_value(s, n, fields, elem); - fields.pop(); - r - }); + let elems = values.iter().enumerate().map(|(idx, elem)| + traverse_value(s, SharedConstLocation(ConstLocation::Field(idx, &loc)), elem) + ); PartialValue::new_variant(*tag, elems) } Value::Extension { e } => s - .value_from_opaque(n, fields, e) + .value_from_opaque(loc, e) .map(PartialValue::from) .unwrap_or(PartialValue::Top), Value::Function { hugr } => s - .value_from_const_hugr(n, fields, hugr) + .value_from_const_hugr(loc, hugr) .map(PartialValue::from) .unwrap_or(PartialValue::Top), } From a4a64c011a7e669256d869d9f9bc9531d6e8f8d0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 14 Nov 2024 15:18:35 +0000 Subject: [PATCH 176/203] No - Revert - just make ConstLocation store a reference --- hugr-passes/src/dataflow.rs | 39 +++++++------------------------------ 1 file changed, 7 insertions(+), 32 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index a1e84bad1..802f00022 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -45,38 +45,13 @@ pub trait DFContext: ConstLoader + std::ops::Deref { /// (perhaps deeply nested within [Value::Sum]s) within a [Node] /// that is a [Const](hugr_core::ops::Const). #[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub enum ConstLocation>> { +pub enum ConstLocation<'a> { /// The specified-index'th field of the [Value::Sum] constant identified by the RHS - Field(usize, C), + Field(usize, &'a ConstLocation<'a>), /// The entire ([Const::value](hugr_core::ops::Const::value)) of the node. Node(Node), } -struct SharedConstLocation<'a>(ConstLocation<&'a SharedConstLocation<'a>>); - -impl<'a> AsRef>> for SharedConstLocation<'a> { - fn as_ref(&self) -> &ConstLocation<&'a SharedConstLocation<'a>> { - &self.0 - } -} - -struct BoxedConstLocation(Box>); - -impl AsRef> for BoxedConstLocation { - fn as_ref(&self) -> &ConstLocation { - &self.0 - } -} - -impl<'a> From<&SharedConstLocation<'a>> for BoxedConstLocation { - fn from(value: &SharedConstLocation<'a>) -> Self { - BoxedConstLocation(Box::new(match value.0 { - ConstLocation::Node(n) => ConstLocation::Node(n), - ConstLocation::Field(idx, elem) => ConstLocation::Field(idx, elem.into()), - })) - } -} - /// Trait for loading [PartialValue]s from constant [Value]s in a Hugr. /// Implementors will likely want to override some/all of [Self::value_from_opaque], /// [Self::value_from_const_hugr], and [Self::value_from_function]: the defaults @@ -88,18 +63,18 @@ pub trait ConstLoader { /// converts these using [Self::value_from_opaque] and [Self::value_from_const_hugr], /// and builds nested [PartialValue::new_variant] to represent the structure. fn value_from_const(&self, n: Node, cst: &Value) -> PartialValue { - traverse_value(self, SharedConstLocation(ConstLocation::Node(n)), cst) + traverse_value(self, ConstLocation::Node(n), cst) } /// Produces an abstract value from an [OpaqueValue], if possible. /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. - fn value_from_opaque(&self, _loc: SharedConstLocation, _val: &OpaqueValue) -> Option { + fn value_from_opaque(&self, _loc: ConstLocation, _val: &OpaqueValue) -> Option { None } /// Produces an abstract value from a Hugr in a [Value::Function], if possible. /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. - fn value_from_const_hugr(&self, _loc: SharedConstLocation, _h: &Hugr) -> Option { + fn value_from_const_hugr(&self, _loc: ConstLocation, _h: &Hugr) -> Option { None } @@ -117,13 +92,13 @@ pub trait ConstLoader { fn traverse_value( s: &(impl ConstLoader + ?Sized), - loc: SharedConstLocation, + loc: ConstLocation, cst: &Value, ) -> PartialValue { match cst { Value::Sum(hugr_core::ops::constant::Sum { tag, values, .. }) => { let elems = values.iter().enumerate().map(|(idx, elem)| - traverse_value(s, SharedConstLocation(ConstLocation::Field(idx, &loc)), elem) + traverse_value(s, ConstLocation::Field(idx, &loc), elem) ); PartialValue::new_variant(*tag, elems) } From bc39b76e9b3c6840db59071ecad352456beca48d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 14 Nov 2024 15:19:48 +0000 Subject: [PATCH 177/203] {value=>partial}_from_const, takes ConstLoc, inline traverse_value --- hugr-passes/src/dataflow.rs | 42 ++++++++++++----------------- hugr-passes/src/dataflow/datalog.rs | 4 +-- 2 files changed, 19 insertions(+), 27 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 802f00022..dca97fb77 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -62,8 +62,23 @@ pub trait ConstLoader { /// to their leaves ([Value::Extension] and [Value::Function]), /// converts these using [Self::value_from_opaque] and [Self::value_from_const_hugr], /// and builds nested [PartialValue::new_variant] to represent the structure. - fn value_from_const(&self, n: Node, cst: &Value) -> PartialValue { - traverse_value(self, ConstLocation::Node(n), cst) + fn partial_from_const(&self, loc: ConstLocation, cst: &Value) -> PartialValue { + match cst { + Value::Sum(hugr_core::ops::constant::Sum { tag, values, .. }) => { + let elems = values.iter().enumerate().map(|(idx, elem)| { + self.partial_from_const(ConstLocation::Field(idx, &loc), elem) + }); + PartialValue::new_variant(*tag, elems) + } + Value::Extension { e } => self + .value_from_opaque(loc, e) + .map(PartialValue::from) + .unwrap_or(PartialValue::Top), + Value::Function { hugr } => self + .value_from_const_hugr(loc, hugr) + .map(PartialValue::from) + .unwrap_or(PartialValue::Top), + } } /// Produces an abstract value from an [OpaqueValue], if possible. @@ -90,28 +105,5 @@ pub trait ConstLoader { } } -fn traverse_value( - s: &(impl ConstLoader + ?Sized), - loc: ConstLocation, - cst: &Value, -) -> PartialValue { - match cst { - Value::Sum(hugr_core::ops::constant::Sum { tag, values, .. }) => { - let elems = values.iter().enumerate().map(|(idx, elem)| - traverse_value(s, ConstLocation::Field(idx, &loc), elem) - ); - PartialValue::new_variant(*tag, elems) - } - Value::Extension { e } => s - .value_from_opaque(loc, e) - .map(PartialValue::from) - .unwrap_or(PartialValue::Top), - Value::Function { hugr } => s - .value_from_const_hugr(loc, hugr) - .map(PartialValue::from) - .unwrap_or(PartialValue::Top), - } -} - #[cfg(test)] mod test; diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 33ab4524d..3d0acf20b 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -11,7 +11,7 @@ use hugr_core::ops::{OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; -use super::{AbstractValue, AnalysisResults, DFContext, PartialValue}; +use super::{AbstractValue, AnalysisResults, ConstLocation, DFContext, PartialValue}; type PV = PartialValue; @@ -308,7 +308,7 @@ fn propagate_leaf_op( Some(ValueRow::single_known( 1, 0, - ctx.value_from_const(n, const_val), + ctx.partial_from_const(ConstLocation::Node(n), const_val), )) } OpType::LoadFunction(load_op) => { From 92669db2c3576bea591d4daa4718625d1f35af9d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 14 Nov 2024 15:25:48 +0000 Subject: [PATCH 178/203] Make ConstLocation Copy --- hugr-passes/src/dataflow.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index dca97fb77..5ebcc1afb 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -44,7 +44,7 @@ pub trait DFContext: ConstLoader + std::ops::Deref { /// A location where a [Value] could be find in a Hugr. That is, /// (perhaps deeply nested within [Value::Sum]s) within a [Node] /// that is a [Const](hugr_core::ops::Const). -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum ConstLocation<'a> { /// The specified-index'th field of the [Value::Sum] constant identified by the RHS Field(usize, &'a ConstLocation<'a>), From 33c860721c2b914ac74382f88e664759755958b4 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 14 Nov 2024 15:38:22 +0000 Subject: [PATCH 179/203] ConstLocation is From; move partial_from_const out to toplev, no value_from_const --- hugr-passes/src/dataflow.rs | 59 +++++++++++++++++------------ hugr-passes/src/dataflow/datalog.rs | 4 +- 2 files changed, 37 insertions(+), 26 deletions(-) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 5ebcc1afb..3769eaced 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -52,35 +52,17 @@ pub enum ConstLocation<'a> { Node(Node), } +impl<'a> From for ConstLocation<'a> { + fn from(value: Node) -> Self { + ConstLocation::Node(value) + } +} + /// Trait for loading [PartialValue]s from constant [Value]s in a Hugr. /// Implementors will likely want to override some/all of [Self::value_from_opaque], /// [Self::value_from_const_hugr], and [Self::value_from_function]: the defaults /// are "correct" but maximally conservative (minimally informative). pub trait ConstLoader { - /// Produces a [PartialValue] from a constant. The default impl (expected - /// to be appropriate in most cases) traverses [Sum](Value::Sum) constants - /// to their leaves ([Value::Extension] and [Value::Function]), - /// converts these using [Self::value_from_opaque] and [Self::value_from_const_hugr], - /// and builds nested [PartialValue::new_variant] to represent the structure. - fn partial_from_const(&self, loc: ConstLocation, cst: &Value) -> PartialValue { - match cst { - Value::Sum(hugr_core::ops::constant::Sum { tag, values, .. }) => { - let elems = values.iter().enumerate().map(|(idx, elem)| { - self.partial_from_const(ConstLocation::Field(idx, &loc), elem) - }); - PartialValue::new_variant(*tag, elems) - } - Value::Extension { e } => self - .value_from_opaque(loc, e) - .map(PartialValue::from) - .unwrap_or(PartialValue::Top), - Value::Function { hugr } => self - .value_from_const_hugr(loc, hugr) - .map(PartialValue::from) - .unwrap_or(PartialValue::Top), - } - } - /// Produces an abstract value from an [OpaqueValue], if possible. /// The default just returns `None`, which will be interpreted as [PartialValue::Top]. fn value_from_opaque(&self, _loc: ConstLocation, _val: &OpaqueValue) -> Option { @@ -105,5 +87,34 @@ pub trait ConstLoader { } } +/// Produces a [PartialValue] from a constant. Traverses [Sum](Value::Sum) constants +/// to their leaves ([Value::Extension] and [Value::Function]), +/// converts these using [ConstLoader::value_from_opaque] and [ConstLoader::value_from_const_hugr], +/// and builds nested [PartialValue::new_variant] to represent the structure. +fn partial_from_const<'a, V>( + cl: &impl ConstLoader, + loc: impl Into>, + cst: &Value, +) -> PartialValue { + let loc = loc.into(); + match cst { + Value::Sum(hugr_core::ops::constant::Sum { tag, values, .. }) => { + let elems = values + .iter() + .enumerate() + .map(|(idx, elem)| partial_from_const(cl, ConstLocation::Field(idx, &loc), elem)); + PartialValue::new_variant(*tag, elems) + } + Value::Extension { e } => cl + .value_from_opaque(loc, e) + .map(PartialValue::from) + .unwrap_or(PartialValue::Top), + Value::Function { hugr } => cl + .value_from_const_hugr(loc, hugr) + .map(PartialValue::from) + .unwrap_or(PartialValue::Top), + } +} + #[cfg(test)] mod test; diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 3d0acf20b..303b96acf 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -11,7 +11,7 @@ use hugr_core::ops::{OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; -use super::{AbstractValue, AnalysisResults, ConstLocation, DFContext, PartialValue}; +use super::{partial_from_const, AbstractValue, AnalysisResults, DFContext, PartialValue}; type PV = PartialValue; @@ -308,7 +308,7 @@ fn propagate_leaf_op( Some(ValueRow::single_known( 1, 0, - ctx.partial_from_const(ConstLocation::Node(n), const_val), + partial_from_const(ctx, n, const_val), )) } OpType::LoadFunction(load_op) => { From 8b76135795810faabaad8429a78b26aa32f84df0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 18 Nov 2024 21:53:02 +0000 Subject: [PATCH 180/203] Generalize run to deal with Module(use main), and others; add run_lib --- hugr-passes/src/dataflow/datalog.rs | 93 ++++++++++++++++++++++------- 1 file changed, 73 insertions(+), 20 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 303b96acf..4a1448cc5 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -7,7 +7,7 @@ use ascent::lattice::BoundedLattice; use itertools::Itertools; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; -use hugr_core::ops::{OpTrait, OpType, TailLoop}; +use hugr_core::ops::{DataflowParent, NamedOp, OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; @@ -29,46 +29,99 @@ impl Default for Machine { } impl Machine { - /// Provide initial values for some wires. - // Likely for test purposes only - should we make non-pub or #[cfg(test)] ? - pub fn prepopulate_wire(&mut self, h: &impl HugrView, wire: Wire, value: PartialValue) { + // Provide initial values for a wire - these will be `join`d with any computed. + // pub(crate) so can be used for tests. + pub(crate) fn prepopulate_wire(&mut self, h: &impl HugrView, w: Wire, v: PartialValue) { self.0.extend( - h.linked_inputs(wire.node(), wire.source()) - .map(|(n, inp)| (n, inp, value.clone())), + h.linked_inputs(w.node(), w.source()) + .map(|(n, inp)| (n, inp, v.clone())), ); } /// Run the analysis (iterate until a lattice fixpoint is reached), - /// given initial values for some of the root node inputs. - /// (Note that `in_values` will not be useful for `Case` or `DFB`-rooted Hugrs, - /// but should handle other containers.) + /// given initial values for some of the root node inputs. For a + /// [Module](OpType::Module)-rooted Hugr, these are input to the function `"main"`. /// The context passed in allows interpretation of leaf operations. + /// + /// [Module]: OpType::Module pub fn run>( mut self, context: C, in_values: impl IntoIterator)>, ) -> AnalysisResults { let root = context.root(); - self.0 - .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); + // Some nodes do not accept values as dataflow inputs - for these + // we must find the corresponding Output node. + let out_node_parent = match context.get_optype(root) { + OpType::Module(_) => Some( + context + .children(root) + .find(|n| { + context + .get_optype(*n) + .as_func_defn() + .is_some_and(|f| f.name() == "main") + }) + .expect("Module must contain a 'main' function to be analysed"), + ), + OpType::DataflowBlock(_) | OpType::Case(_) | OpType::FuncDefn(_) => Some(root), + // Could also do Dfg above, but ok here too: + _ => None, // Just feed into node inputs + }; + // Now write values onto Input node out-wires or Outputs. // Any inputs we don't have values for, we must assume `Top` to ensure safety of analysis // (Consider: for a conditional that selects *either* the unknown input *or* value V, // analysis must produce Top == we-know-nothing, not `V` !) - let mut need_inputs: HashSet<_, RandomState> = HashSet::from_iter( - (0..context.signature(root).unwrap_or_default().input_count()).map(IncomingPort::from), - ); - self.0.iter().for_each(|(n, p, _)| { - if n == &root { - need_inputs.remove(p); + if let Some(p) = out_node_parent { + let [inp, _] = context.get_io(p).unwrap(); + let mut vals = + vec![PartialValue::Top; context.signature(inp).unwrap().output_types().len()]; + for (ip, v) in in_values { + vals[ip.index()] = v; + } + for (i, v) in vals.into_iter().enumerate() { + self.prepopulate_wire(&*context, Wire::new(inp, i), v); + } + } else { + self.0 + .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); + let mut need_inputs: HashSet<_, RandomState> = HashSet::from_iter( + (0..context.signature(root).unwrap_or_default().input_count()) + .map(IncomingPort::from), + ); + self.0.iter().for_each(|(n, p, _)| { + if n == &root { + need_inputs.remove(p); + } + }); + for p in need_inputs { + self.0.push((root, p, PartialValue::Top)); } - }); - for p in need_inputs { - self.0.push((root, p, PartialValue::Top)); } // Note/TODO, if analysis is running on a subregion then we should do similar // for any nonlocal edges providing values from outside the region. run_datalog(context, self.0) } + + /// Run the analysis (iterate until a lattice fixpoint is reached), + /// for a [Module]-rooted Hugr where all functions are assumed callable + /// (from a client) with any arguments. + /// The context passed in allows interpretation of leaf operations. + pub fn run_lib>(mut self, context: C) -> AnalysisResults { + let root = context.root(); + if !context.get_optype(root).is_module() { + panic!("Hugr not Module-rooted") + } + for n in context.children(root) { + if let Some(fd) = context.get_optype(n).as_func_defn() { + let [inp, _] = context.get_io(n).unwrap(); + for p in 0..fd.inner_signature().input_count() { + self.prepopulate_wire(&*context, Wire::new(inp, p), PartialValue::Top); + } + } + } + run_datalog(context, self.0) + } } pub(super) fn run_datalog>( From c18cbea26b14d248a5a9a130d7fe82aff63b0c6f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 18 Nov 2024 22:03:15 +0000 Subject: [PATCH 181/203] Shorten the got-all-required-inputs check (build got_inputs) --- hugr-passes/src/dataflow/datalog.rs | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 4a1448cc5..9efbdc041 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -85,17 +85,15 @@ impl Machine { } else { self.0 .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); - let mut need_inputs: HashSet<_, RandomState> = HashSet::from_iter( - (0..context.signature(root).unwrap_or_default().input_count()) - .map(IncomingPort::from), - ); - self.0.iter().for_each(|(n, p, _)| { - if n == &root { - need_inputs.remove(p); + let got_inputs: HashSet<_, RandomState> = self + .0 + .iter() + .filter_map(|(n, p, _)| (n == &root).then_some(*p)) + .collect(); + for p in context.signature(root).unwrap_or_default().input_ports() { + if !got_inputs.contains(&p) { + self.0.push((root, p, PartialValue::Top)); } - }); - for p in need_inputs { - self.0.push((root, p, PartialValue::Top)); } } // Note/TODO, if analysis is running on a subregion then we should do similar From 1b64b4bbd598825d1313b175cb768094e99f830e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 18 Nov 2024 22:19:16 +0000 Subject: [PATCH 182/203] Shorten further...not as easy to follow --- hugr-passes/src/dataflow/datalog.rs | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 9efbdc041..0b0b57ce2 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -1,8 +1,5 @@ //! [ascent] datalog implementation of analysis. -use std::collections::hash_map::RandomState; -use std::collections::HashSet; // Moves to std::hash in Rust 1.76 - use ascent::lattice::BoundedLattice; use itertools::Itertools; @@ -85,15 +82,13 @@ impl Machine { } else { self.0 .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); - let got_inputs: HashSet<_, RandomState> = self - .0 - .iter() - .filter_map(|(n, p, _)| (n == &root).then_some(*p)) - .collect(); - for p in context.signature(root).unwrap_or_default().input_ports() { - if !got_inputs.contains(&p) { - self.0.push((root, p, PartialValue::Top)); - } + let mut need_inputs = + vec![true; context.signature(root).unwrap_or_default().input_count()]; + for (_, p, _) in self.0.iter().filter(|(n, _, _)| n == &root) { + need_inputs[p.index()] = false; + } + for (i, _) in need_inputs.into_iter().enumerate().filter(|(_, b)| *b) { + self.0.push((root, i.into(), PartialValue::Top)); } } // Note/TODO, if analysis is running on a subregion then we should do similar From 39b8df16e6c05297ad403b593095fcf3144afb72 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 09:14:26 +0000 Subject: [PATCH 183/203] Revert "Shorten further...not as easy to follow" This reverts commit 1b64b4bbd598825d1313b175cb768094e99f830e. --- hugr-passes/src/dataflow/datalog.rs | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 0b0b57ce2..9efbdc041 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -1,5 +1,8 @@ //! [ascent] datalog implementation of analysis. +use std::collections::hash_map::RandomState; +use std::collections::HashSet; // Moves to std::hash in Rust 1.76 + use ascent::lattice::BoundedLattice; use itertools::Itertools; @@ -82,13 +85,15 @@ impl Machine { } else { self.0 .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); - let mut need_inputs = - vec![true; context.signature(root).unwrap_or_default().input_count()]; - for (_, p, _) in self.0.iter().filter(|(n, _, _)| n == &root) { - need_inputs[p.index()] = false; - } - for (i, _) in need_inputs.into_iter().enumerate().filter(|(_, b)| *b) { - self.0.push((root, i.into(), PartialValue::Top)); + let got_inputs: HashSet<_, RandomState> = self + .0 + .iter() + .filter_map(|(n, p, _)| (n == &root).then_some(*p)) + .collect(); + for p in context.signature(root).unwrap_or_default().input_ports() { + if !got_inputs.contains(&p) { + self.0.push((root, p, PartialValue::Top)); + } } } // Note/TODO, if analysis is running on a subregion then we should do similar From a5d987c9b42f8f452ff69a68ec4627b10076e182 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 11:50:18 +0000 Subject: [PATCH 184/203] doc fixes, rename to run_library --- hugr-passes/src/dataflow/datalog.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 9efbdc041..74133a5bd 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -15,10 +15,11 @@ use super::{partial_from_const, AbstractValue, AnalysisResults, DFContext, Parti type PV = PartialValue; +#[allow(rustdoc::private_intra_doc_links)] /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] /// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values -/// 3. Call [Self::run] to produce [AnalysisResults] +/// 3. Call [Self::run] or [Self::run_library] to produce [AnalysisResults] pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); /// derived-Default requires the context to be Defaultable, which is unnecessary @@ -102,10 +103,10 @@ impl Machine { } /// Run the analysis (iterate until a lattice fixpoint is reached), - /// for a [Module]-rooted Hugr where all functions are assumed callable + /// for a [Module](OpType::Module)-rooted Hugr where all functions are assumed callable /// (from a client) with any arguments. /// The context passed in allows interpretation of leaf operations. - pub fn run_lib>(mut self, context: C) -> AnalysisResults { + pub fn run_library>(mut self, context: C) -> AnalysisResults { let root = context.root(); if !context.get_optype(root).is_module() { panic!("Hugr not Module-rooted") From 3e718fdcc95ee1865d8a6e7ed22ed731ac2beace Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 11:52:46 +0000 Subject: [PATCH 185/203] Add PartialValue::contains_bottom, also row_contains_bottom --- hugr-passes/src/dataflow.rs | 8 ++++++++ hugr-passes/src/dataflow/partial_value.rs | 21 +++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 3769eaced..f6e710f66 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -116,5 +116,13 @@ fn partial_from_const<'a, V>( } } +/// A row of inputs to a node contains bottom (can't happen, the node +/// can't execute) if any element [contains_bottom](PartialValue::contains_bottom). +pub fn row_contains_bottom<'a, V: AbstractValue + 'a>( + elements: impl IntoIterator>, +) -> bool { + elements.into_iter().any(PartialValue::contains_bottom) +} + #[cfg(test)] mod test; diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 992a72444..cd0b1fb29 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -8,6 +8,8 @@ use std::collections::HashMap; use std::hash::{Hash, Hasher}; use thiserror::Error; +use super::row_contains_bottom; + /// Trait for an underlying domain of abstract values which can form the *elements* of a /// [PartialValue] and thus be used in dataflow analysis. pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { @@ -181,6 +183,13 @@ impl PartialSum { num_elements: v.len(), }) } + + /// Can this ever occur at runtime? See [PartialValue::contains_bottom] + pub fn contains_bottom(&self) -> bool { + self.0 + .iter() + .all(|(_tag, elements)| row_contains_bottom(elements)) + } } /// An error converting a [PartialValue] or [PartialSum] into a concrete value type @@ -352,6 +361,18 @@ impl PartialValue { Self::Bottom => Err(ExtractValueError::ValueIsBottom), } } + + /// A value contains bottom means that it cannot occur during execution + /// - it may be an artefact during bootstrapping of the analysis, or else + /// the value depends upon a `panic` or a loop that + /// [never terminates](super::TailLoopTermination::NeverBreaks). + pub fn contains_bottom(&self) -> bool { + match self { + PartialValue::Bottom => true, + PartialValue::Top | PartialValue::Value(_) => false, + PartialValue::PartialSum(ps) => ps.contains_bottom(), + } + } } impl TryFrom> for Value { From 497686ae927a5ded7b4401fd36cc8a7081392e57 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 12:02:48 +0000 Subject: [PATCH 186/203] Don't call interpret_leaf_op if row_contains_bottom --- hugr-passes/src/dataflow/datalog.rs | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 74133a5bd..e33494ac4 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -11,7 +11,10 @@ use hugr_core::ops::{DataflowParent, NamedOp, OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; -use super::{partial_from_const, AbstractValue, AnalysisResults, DFContext, PartialValue}; +use super::{ + partial_from_const, row_contains_bottom, AbstractValue, AnalysisResults, DFContext, + PartialValue, +}; type PV = PartialValue; @@ -378,20 +381,19 @@ fn propagate_leaf_op( )) } OpType::ExtensionOp(e) => { - // Interpret op using DFContext - let init = if ins.iter().contains(&PartialValue::Bottom) { + Some(ValueRow::from_iter(if row_contains_bottom(ins) { // So far we think one or more inputs can't happen. // So, don't pollute outputs with Top, and wait for better knowledge of inputs. - PartialValue::Bottom + vec![PartialValue::Bottom; num_outs] } else { - // If we can't figure out anything about the outputs, assume nothing (they still happen!) - PartialValue::Top - }; - let mut outs = vec![init; num_outs]; - // It might be nice to convert these to [(IncomingPort, Value)], or some concrete value, - // for the context, but PV contains more information, and try_into_value may fail. - ctx.interpret_leaf_op(n, e, ins, &mut outs[..]); - Some(ValueRow::from_iter(outs)) + // Interpret op using DFContext + // Default to Top i.e. can't figure out anything about the outputs + let mut outs = vec![PartialValue::Top; num_outs]; + // It might be nice to convert `ins`` to [(IncomingPort, Value)], or some concrete value, + // for the context, but PV contains more information, and try_into_value may fail. + ctx.interpret_leaf_op(n, e, ins, &mut outs[..]); + outs + })) } o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive" } From e34c7bef059295be4bc439ec5db35cd4b92eeded Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 12:25:44 +0000 Subject: [PATCH 187/203] Use row_contains_bottom for CFG+DFG, and augment unpack_first(=>_no_bottom) --- hugr-passes/src/dataflow/datalog.rs | 17 +++++++++++------ hugr-passes/src/dataflow/value_row.rs | 15 +++++++++------ 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index e33494ac4..9a0426ae6 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -194,7 +194,10 @@ pub(super) fn run_datalog>( dfg_node(n) <-- node(n), if ctx.get_optype(*n).is_dfg(); out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), - input_child(dfg, i), in_wire_value(dfg, p, v); + input_child(dfg, i), + node_in_value_row(dfg, row), + if !row_contains_bottom(&row[..]), // Treat the DFG as a scheduling barrier + for (p, v) in row[..].iter().enumerate(); out_wire_value(dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), output_child(dfg, o), in_wire_value(o, p, v); @@ -213,7 +216,7 @@ pub(super) fn run_datalog>( output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node... // ...and select just what's possible for CONTINUE_TAG, if anything - if let Some(fields) = out_in_row.unpack_first(TailLoop::CONTINUE_TAG, tailloop.just_inputs.len()), + if let Some(fields) = out_in_row.unpack_first_no_bottom(TailLoop::CONTINUE_TAG, tailloop.just_inputs.len()), for (out_p, v) in fields.enumerate(); // Output node of child region propagate to outputs of tail loop @@ -222,7 +225,7 @@ pub(super) fn run_datalog>( output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node... // ... and select just what's possible for BREAK_TAG, if anything - if let Some(fields) = out_in_row.unpack_first(TailLoop::BREAK_TAG, tailloop.just_outputs.len()), + if let Some(fields) = out_in_row.unpack_first_no_bottom(TailLoop::BREAK_TAG, tailloop.just_outputs.len()), for (out_p, v) in fields.enumerate(); // Conditional -------------------- @@ -239,7 +242,7 @@ pub(super) fn run_datalog>( input_child(case, i_node), node_in_value_row(cond, in_row), let conditional = ctx.get_optype(*cond).as_conditional().unwrap(), - if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), + if let Some(fields) = in_row.unpack_first_no_bottom(*case_index, conditional.sum_rows[*case_index].len()), for (out_p, v) in fields.enumerate(); // outputs of case nodes propagate to outputs of conditional *if* case reachable @@ -274,7 +277,9 @@ pub(super) fn run_datalog>( cfg_node(cfg), if let Some(entry) = ctx.children(*cfg).next(), input_child(entry, i_node), - in_wire_value(cfg, p, v); + node_in_value_row(cfg, row), + if !row_contains_bottom(&row[..]), + for (p, v) in row[..].iter().enumerate(); // In `CFG` , values fed along a control-flow edge to // come out of Value outports of : @@ -293,7 +298,7 @@ pub(super) fn run_datalog>( output_child(pred, out_n), _cfg_succ_dest(cfg, succ, dest), node_in_value_row(out_n, out_in_row), - if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), + if let Some(fields) = out_in_row.unpack_first_no_bottom(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), for (out_p, v) in fields.enumerate(); // Call -------------------- diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs index 0d8bc15a6..fc1d66818 100644 --- a/hugr-passes/src/dataflow/value_row.rs +++ b/hugr-passes/src/dataflow/value_row.rs @@ -8,7 +8,7 @@ use std::{ use ascent::{lattice::BoundedLattice, Lattice}; use itertools::zip_eq; -use super::{AbstractValue, PartialValue}; +use super::{row_contains_bottom, AbstractValue, PartialValue}; #[derive(PartialEq, Clone, Debug, Eq, Hash)] pub(super) struct ValueRow(Vec>); @@ -25,16 +25,19 @@ impl ValueRow { r } - /// The first value in this ValueRow must be a sum; - /// returns a new ValueRow given by unpacking the elements of the specified variant of said first value, - /// then appending the rest of the values in this row. - pub fn unpack_first( + /// If the first value in this ValueRow is a sum, that might contain + /// the specified tag, then unpack the elements of that tag, append the rest + /// of this ValueRow, and if none of the elements of that row [contain bottom](PartialValue::contains_bottom), + /// return it. + /// Otherwise (if no such tag, or values contain bottom), return None. + pub fn unpack_first_no_bottom( &self, variant: usize, len: usize, ) -> Option>> { let vals = self[0].variant_values(variant, len)?; - Some(vals.into_iter().chain(self.0[1..].to_owned())) + (!row_contains_bottom(vals.iter().chain(self.0[1..].iter()))) + .then(|| vals.into_iter().chain(self.0[1..].to_owned())) } } From 69a69f3be88ebbfc947a5bbd4e13107fdb5da59c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 13:30:36 +0000 Subject: [PATCH 188/203] run_library => publish_function --- hugr-passes/src/dataflow/datalog.rs | 75 ++++++++++++++--------------- 1 file changed, 37 insertions(+), 38 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 9a0426ae6..7a1d866b6 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -7,7 +7,7 @@ use ascent::lattice::BoundedLattice; use itertools::Itertools; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; -use hugr_core::ops::{DataflowParent, NamedOp, OpTrait, OpType, TailLoop}; +use hugr_core::ops::{DataflowParent, FuncDefn, NamedOp, OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; @@ -18,11 +18,12 @@ use super::{ type PV = PartialValue; -#[allow(rustdoc::private_intra_doc_links)] /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] -/// 2. (Optionally / for tests) zero or more [Self::prepopulate_wire] with initial values -/// 3. Call [Self::run] or [Self::run_library] to produce [AnalysisResults] +/// 2. (Optionally) For [Module](OpType::Module)-rooted Hugrs, zero or more calls +/// to [Self::publish_function] +// or [Self::prepopulate_wire] with initial values +/// 3. Call [Self::run] to produce [AnalysisResults] pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); /// derived-Default requires the context to be Defaultable, which is unnecessary @@ -44,39 +45,35 @@ impl Machine { /// Run the analysis (iterate until a lattice fixpoint is reached), /// given initial values for some of the root node inputs. For a - /// [Module](OpType::Module)-rooted Hugr, these are input to the function `"main"`. + /// [Module](OpType::Module)-rooted Hugr, these are input to the function `"main"` + /// (it is an error if inputs are provided and there is no `"main"``). /// The context passed in allows interpretation of leaf operations. - /// - /// [Module]: OpType::Module pub fn run>( mut self, context: C, in_values: impl IntoIterator)>, ) -> AnalysisResults { + let mut in_values = in_values.into_iter(); let root = context.root(); // Some nodes do not accept values as dataflow inputs - for these // we must find the corresponding Output node. - let out_node_parent = match context.get_optype(root) { - OpType::Module(_) => Some( - context - .children(root) - .find(|n| { - context - .get_optype(*n) - .as_func_defn() - .is_some_and(|f| f.name() == "main") - }) - .expect("Module must contain a 'main' function to be analysed"), - ), + let input_node_parent = match context.get_optype(root) { + OpType::Module(_) => { + let main = find_func(&*context, "main"); + if main.is_none() && in_values.next().is_some() { + panic!("Cannot give inputs to module with no 'main'"); + } + main.map(|(n, _)| n) + } OpType::DataflowBlock(_) | OpType::Case(_) | OpType::FuncDefn(_) => Some(root), // Could also do Dfg above, but ok here too: _ => None, // Just feed into node inputs }; - // Now write values onto Input node out-wires or Outputs. // Any inputs we don't have values for, we must assume `Top` to ensure safety of analysis // (Consider: for a conditional that selects *either* the unknown input *or* value V, // analysis must produce Top == we-know-nothing, not `V` !) - if let Some(p) = out_node_parent { + if let Some(p) = input_node_parent { + // Put values onto out-wires of Input node let [inp, _] = context.get_io(p).unwrap(); let mut vals = vec![PartialValue::Top; context.signature(inp).unwrap().output_types().len()]; @@ -87,6 +84,7 @@ impl Machine { self.prepopulate_wire(&*context, Wire::new(inp, i), v); } } else { + // Put values onto in-wires of root node self.0 .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); let got_inputs: HashSet<_, RandomState> = self @@ -105,27 +103,28 @@ impl Machine { run_datalog(context, self.0) } - /// Run the analysis (iterate until a lattice fixpoint is reached), - /// for a [Module](OpType::Module)-rooted Hugr where all functions are assumed callable - /// (from a client) with any arguments. - /// The context passed in allows interpretation of leaf operations. - pub fn run_library>(mut self, context: C) -> AnalysisResults { - let root = context.root(); - if !context.get_optype(root).is_module() { - panic!("Hugr not Module-rooted") + /// For [Module](OpType::Module)-rooted Hugrs, mark a FuncDefn that is a child + /// of the root node as externally callable, i.e. with any arguments. + pub fn publish_function>(&mut self, context: C, name: &str) { + let (n, fd) = find_func(&*context, name).unwrap(); + let [inp, _] = context.get_io(n).unwrap(); + for p in 0..fd.inner_signature().input_count() { + self.prepopulate_wire(&*context, Wire::new(inp, p), PartialValue::Top); } - for n in context.children(root) { - if let Some(fd) = context.get_optype(n).as_func_defn() { - let [inp, _] = context.get_io(n).unwrap(); - for p in 0..fd.inner_signature().input_count() { - self.prepopulate_wire(&*context, Wire::new(inp, p), PartialValue::Top); - } - } - } - run_datalog(context, self.0) } } +fn find_func<'a>(h: &'a impl HugrView, name: &str) -> Option<(Node, &'a FuncDefn)> { + assert!(h.get_optype(h.root()).is_module()); + h.children(h.root()) + .filter_map(|n| { + h.get_optype(n) + .as_func_defn() + .and_then(|f| (f.name() == name).then_some((n, f))) + }) + .next() +} + pub(super) fn run_datalog>( ctx: C, in_wire_value_proto: Vec<(Node, IncomingPort, PV)>, From 57ac432c1e35a1f6c57844fc21ae5422c7adc564 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 14:58:47 +0000 Subject: [PATCH 189/203] Drop publish_function, pub prepopulate_wire --- hugr-passes/src/dataflow/datalog.rs | 55 +++++++++++------------------ 1 file changed, 21 insertions(+), 34 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 7a1d866b6..b29185827 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -7,7 +7,7 @@ use ascent::lattice::BoundedLattice; use itertools::Itertools; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; -use hugr_core::ops::{DataflowParent, FuncDefn, NamedOp, OpTrait, OpType, TailLoop}; +use hugr_core::ops::{NamedOp, OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; @@ -20,9 +20,10 @@ type PV = PartialValue; /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] -/// 2. (Optionally) For [Module](OpType::Module)-rooted Hugrs, zero or more calls -/// to [Self::publish_function] -// or [Self::prepopulate_wire] with initial values +/// 2. (Optionally) zero or more calls to [Self::prepopulate_wire] with initial values. +/// For example, for a [Module](OpType::Module)-rooted Hugr, each externally-callable +/// [FuncDefn](OpType::FuncDefn) should have the out-wires from its [Input](OpType::Input) +/// node prepopulated with [PartialValue::Top]. /// 3. Call [Self::run] to produce [AnalysisResults] pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); @@ -34,9 +35,8 @@ impl Default for Machine { } impl Machine { - // Provide initial values for a wire - these will be `join`d with any computed. - // pub(crate) so can be used for tests. - pub(crate) fn prepopulate_wire(&mut self, h: &impl HugrView, w: Wire, v: PartialValue) { + /// Provide initial values for a wire - these will be `join`d with any computed. + pub fn prepopulate_wire(&mut self, h: &impl HugrView, w: Wire, v: PartialValue) { self.0.extend( h.linked_inputs(w.node(), w.source()) .map(|(n, inp)| (n, inp, v.clone())), @@ -45,9 +45,12 @@ impl Machine { /// Run the analysis (iterate until a lattice fixpoint is reached), /// given initial values for some of the root node inputs. For a - /// [Module](OpType::Module)-rooted Hugr, these are input to the function `"main"` - /// (it is an error if inputs are provided and there is no `"main"``). + /// [Module](OpType::Module)-rooted Hugr, these are input to the function `"main"`. /// The context passed in allows interpretation of leaf operations. + /// + /// # Panics + /// May panic in various ways if the Hugr is invalid; + /// or if any `in_values` are provided for a module-rooted Hugr without a function `"main"`. pub fn run>( mut self, context: C, @@ -56,14 +59,19 @@ impl Machine { let mut in_values = in_values.into_iter(); let root = context.root(); // Some nodes do not accept values as dataflow inputs - for these - // we must find the corresponding Output node. + // we must find the corresponding Input node. let input_node_parent = match context.get_optype(root) { OpType::Module(_) => { - let main = find_func(&*context, "main"); + let main = context.children(root).find(|n| { + context + .get_optype(*n) + .as_func_defn() + .is_some_and(|f| f.name() == "main") + }); if main.is_none() && in_values.next().is_some() { panic!("Cannot give inputs to module with no 'main'"); } - main.map(|(n, _)| n) + main } OpType::DataflowBlock(_) | OpType::Case(_) | OpType::FuncDefn(_) => Some(root), // Could also do Dfg above, but ok here too: @@ -84,7 +92,7 @@ impl Machine { self.prepopulate_wire(&*context, Wire::new(inp, i), v); } } else { - // Put values onto in-wires of root node + // Put values onto in-wires of root node, datalog will do the rest self.0 .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); let got_inputs: HashSet<_, RandomState> = self @@ -102,27 +110,6 @@ impl Machine { // for any nonlocal edges providing values from outside the region. run_datalog(context, self.0) } - - /// For [Module](OpType::Module)-rooted Hugrs, mark a FuncDefn that is a child - /// of the root node as externally callable, i.e. with any arguments. - pub fn publish_function>(&mut self, context: C, name: &str) { - let (n, fd) = find_func(&*context, name).unwrap(); - let [inp, _] = context.get_io(n).unwrap(); - for p in 0..fd.inner_signature().input_count() { - self.prepopulate_wire(&*context, Wire::new(inp, p), PartialValue::Top); - } - } -} - -fn find_func<'a>(h: &'a impl HugrView, name: &str) -> Option<(Node, &'a FuncDefn)> { - assert!(h.get_optype(h.root()).is_module()); - h.children(h.root()) - .filter_map(|n| { - h.get_optype(n) - .as_func_defn() - .and_then(|f| (f.name() == name).then_some((n, f))) - }) - .next() } pub(super) fn run_datalog>( From f9a9f2446bd7fac321c02a707e452d5a06781edb Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 15:10:24 +0000 Subject: [PATCH 190/203] ValueRow::single_known => singleton, set --- hugr-passes/src/dataflow/datalog.rs | 12 +++--------- hugr-passes/src/dataflow/value_row.rs | 12 +++++++----- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index b29185827..69e26f029 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -163,7 +163,7 @@ pub(super) fn run_datalog>( // Assemble node_in_value_row from in_wire_value's node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = ctx.signature(*n); - node_in_value_row(n, ValueRow::single_known(ctx.signature(*n).unwrap().input_count(), p.index(), v.clone())) <-- in_wire_value(n, p, v); + node_in_value_row(n, ValueRow::new(ctx.signature(*n).unwrap().input_count()).set(p.index(), v.clone())) <-- in_wire_value(n, p, v); // Interpret leaf ops out_wire_value(n, p, v) <-- @@ -351,11 +351,7 @@ fn propagate_leaf_op( .unwrap() .0; let const_val = ctx.get_optype(const_node).as_const().unwrap().value(); - Some(ValueRow::single_known( - 1, - 0, - partial_from_const(ctx, n, const_val), - )) + Some(ValueRow::singleton(partial_from_const(ctx, n, const_val))) } OpType::LoadFunction(load_op) => { assert!(ins.is_empty()); // static edge @@ -364,9 +360,7 @@ fn propagate_leaf_op( .unwrap() .0; // Node could be a FuncDefn or a FuncDecl, so do not pass the node itself - Some(ValueRow::single_known( - 1, - 0, + Some(ValueRow::singleton( ctx.value_from_function(func_node, &load_op.type_args) .map_or(PV::Top, PV::Value), )) diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs index fc1d66818..9360f36e3 100644 --- a/hugr-passes/src/dataflow/value_row.rs +++ b/hugr-passes/src/dataflow/value_row.rs @@ -18,11 +18,13 @@ impl ValueRow { Self(vec![PartialValue::bottom(); len]) } - pub fn single_known(len: usize, idx: usize, v: PartialValue) -> Self { - assert!(idx < len); - let mut r = Self::new(len); - r.0[idx] = v; - r + pub fn set(mut self, idx: usize, v: PartialValue) -> Self { + *self.0.get_mut(idx).unwrap() = v; + self + } + + pub fn singleton(v: PartialValue) -> Self { + Self(vec![v]) } /// If the first value in this ValueRow is a sum, that might contain From 9cc368df844a1c4e7c65aee76fc9eb3906ea8044 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 15:23:30 +0000 Subject: [PATCH 191/203] try_join / try_meet return extra bool --- hugr-passes/src/dataflow/partial_value.rs | 29 ++++++++++++----------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index cd0b1fb29..8e394ec43 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -14,23 +14,26 @@ use super::row_contains_bottom; /// [PartialValue] and thus be used in dataflow analysis. pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash { /// Computes the join of two values (i.e. towards `Top``), if this is representable - /// within the underlying domain. - /// Otherwise return `None` (i.e. an instruction to use [PartialValue::Top]). + /// within the underlying domain. Return the new value, and whether this is different from + /// the old `self`. /// - /// The default checks equality between `self` and `other` and returns `self` if + /// If the join is not representable, return `None` - i.e., we should use [PartialValue::Top]. + /// + /// The default checks equality between `self` and `other` and returns `(self,false)` if /// the two are identical, otherwise `None`. - fn try_join(self, other: Self) -> Option { - (self == other).then_some(self) + fn try_join(self, other: Self) -> Option<(Self, bool)> { + (self == other).then_some((self, false)) } /// Computes the meet of two values (i.e. towards `Bottom`), if this is representable - /// within the underlying domain. - /// Otherwise return `None` (i.e. an instruction to use [PartialValue::Bottom]). + /// within the underlying domain. Return the new value, and whether this is different from + /// the old `self`. + /// If the meet is not representable, return `None` - i.e., we should use [PartialValue::Bottom]. /// - /// The default checks equality between `self` and `other` and returns `self` if + /// The default checks equality between `self` and `other` and returns `(self, false)` if /// the two are identical, otherwise `None`. - fn try_meet(self, other: Self) -> Option { - (self == other).then_some(self) + fn try_meet(self, other: Self) -> Option<(Self, bool)> { + (self == other).then_some((self, false)) } } @@ -398,8 +401,7 @@ impl Lattice for PartialValue { true } (Self::Value(h1), Self::Value(h2)) => match h1.clone().try_join(h2) { - Some(h3) => { - let ch = h3 != *h1; + Some((h3, ch)) => { *self = Self::Value(h3); ch } @@ -441,8 +443,7 @@ impl Lattice for PartialValue { true } (Self::Value(h1), Self::Value(h2)) => match h1.clone().try_meet(h2) { - Some(h3) => { - let ch = h3 != *h1; + Some((h3, ch)) => { *self = Self::Value(h3); ch } From a61fbdb73f8895856fd31724c90018e27c794e5f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 15:31:47 +0000 Subject: [PATCH 192/203] shorten/common-up meet_mut + join_mut --- hugr-passes/src/dataflow/partial_value.rs | 36 ++++++++++------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 8e394ec43..e6142a9f2 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -400,16 +400,14 @@ impl Lattice for PartialValue { *self = other; true } - (Self::Value(h1), Self::Value(h2)) => match h1.clone().try_join(h2) { - Some((h3, ch)) => { - *self = Self::Value(h3); - ch - } - None => { - *self = Self::Top; - true - } - }, + (Self::Value(h1), Self::Value(h2)) => { + let (nv, ch) = match h1.clone().try_join(h2) { + Some((h3, b)) => (Self::Value(h3), b), + None => (Self::Top, true), + }; + *self = nv; + ch + } (Self::PartialSum(_), Self::PartialSum(ps2)) => { let Self::PartialSum(ps1) = self else { unreachable!() @@ -442,16 +440,14 @@ impl Lattice for PartialValue { *self = other; true } - (Self::Value(h1), Self::Value(h2)) => match h1.clone().try_meet(h2) { - Some((h3, ch)) => { - *self = Self::Value(h3); - ch - } - None => { - *self = Self::Bottom; - true - } - }, + (Self::Value(h1), Self::Value(h2)) => { + let (h3, ch) = match h1.clone().try_meet(h2) { + Some((h3, ch)) => (Self::Value(h3), ch), + None => (Self::Bottom, true), + }; + *self = h3; + ch + } (Self::PartialSum(_), Self::PartialSum(ps2)) => { let Self::PartialSum(ps1) = self else { unreachable!() From 24cce0e01a656f8d1d49a086393bf30612070bb7 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 16:44:21 +0000 Subject: [PATCH 193/203] try_into_value: change bounds TryFrom -> TryInto; rename =>try_into_sum --- hugr-passes/src/dataflow/partial_value.rs | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index e6142a9f2..21c36668d 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -155,12 +155,13 @@ impl PartialSum { /// If this PartialSum had multiple possible tags; or if `typ` was not a [TypeEnum::Sum] /// supporting the single possible tag with the correct number of elements and no row variables; /// or if converting a child element failed via [PartialValue::try_into_value]. - pub fn try_into_value( + pub fn try_into_sum( self, typ: &Type, ) -> Result, ExtractValueError> where - V2: TryFrom + TryFrom, Error = SE>, + V: TryInto, + Sum: TryInto, { let Ok((k, v)) = self.0.iter().exactly_one() else { return Err(ExtractValueError::MultipleVariants(self)); @@ -351,15 +352,18 @@ impl PartialValue { /// incorrect), or if that [Sum] could not be converted into a `V2`. pub fn try_into_value(self, typ: &Type) -> Result> where - V2: TryFrom + TryFrom, Error = SE>, + V: TryInto, + Sum: TryInto, { match self { - Self::Value(v) => V2::try_from(v.clone()) + Self::Value(v) => v + .clone() + .try_into() .map_err(|e| ExtractValueError::CouldNotConvert(v.clone(), e)), - Self::PartialSum(ps) => { - let v = ps.try_into_value(typ)?; - V2::try_from(v).map_err(ExtractValueError::CouldNotBuildSum) - } + Self::PartialSum(ps) => ps + .try_into_sum(typ)? + .try_into() + .map_err(ExtractValueError::CouldNotBuildSum), Self::Top => Err(ExtractValueError::ValueIsTop), Self::Bottom => Err(ExtractValueError::ValueIsBottom), } From a59076629e08ed71127d010debd27191e661042d Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 16:56:12 +0000 Subject: [PATCH 194/203] Avoid a clone in try_into_sum --- hugr-passes/src/dataflow/partial_value.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 21c36668d..989d40640 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -163,17 +163,18 @@ impl PartialSum { V: TryInto, Sum: TryInto, { - let Ok((k, v)) = self.0.iter().exactly_one() else { + if self.0.len() != 1 { return Err(ExtractValueError::MultipleVariants(self)); - }; + } + let (tag, v) = self.0.into_iter().exactly_one().unwrap(); if let TypeEnum::Sum(st) = typ.as_type_enum() { - if let Some(r) = st.get_variant(*k) { + if let Some(r) = st.get_variant(tag) { if let Ok(r) = TypeRow::try_from(r.clone()) { if v.len() == r.len() { return Ok(Sum { - tag: *k, + tag, values: zip_eq(v, r.iter()) - .map(|(v, t)| v.clone().try_into_value(t)) + .map(|(v, t)| v.try_into_value(t)) .collect::, _>>()?, st: st.clone(), }); @@ -183,7 +184,7 @@ impl PartialSum { } Err(ExtractValueError::BadSumType { typ: typ.clone(), - tag: *k, + tag, num_elements: v.len(), }) } From 2b2c461397cdea43e1f8b14395aef5a8e648b512 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 15:45:47 +0000 Subject: [PATCH 195/203] Optimize+shorten join_mut / meet_mut via std::mem::swap --- hugr-passes/src/dataflow/partial_value.rs | 102 ++++++++++------------ 1 file changed, 44 insertions(+), 58 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 989d40640..a9933e586 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -394,81 +394,67 @@ impl TryFrom> for Value { impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); - match (&*self, other) { - (Self::Top, _) => false, - (_, other @ Self::Top) => { - *self = other; - true + let mut old_self = Self::Top; // Good default result + std::mem::swap(self, &mut old_self); + match (old_self, other) { + (Self::Top, _) => false, // result is Top + (_, Self::Top) => true, // result is Top + (old, Self::Bottom) => { + *self = old; // reinstate + false } - (_, Self::Bottom) => false, (Self::Bottom, other) => { *self = other; true } - (Self::Value(h1), Self::Value(h2)) => { - let (nv, ch) = match h1.clone().try_join(h2) { - Some((h3, b)) => (Self::Value(h3), b), - None => (Self::Top, true), - }; - *self = nv; - ch - } - (Self::PartialSum(_), Self::PartialSum(ps2)) => { - let Self::PartialSum(ps1) = self else { - unreachable!() - }; - match ps1.try_join_mut(ps2) { - Ok(ch) => ch, - Err(_) => { - *self = Self::Top; - true - } + (Self::Value(h1), Self::Value(h2)) => match h1.clone().try_join(h2) { + Some((h3, b)) => { + *self = Self::Value(h3); + b } - } - (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { - *self = Self::Top; - true - } + None => true, // result is Top + }, + (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_join_mut(ps2) { + Ok(ch) => { + *self = Self::PartialSum(ps1); + ch + } + Err(_) => true, // result is Top + }, + (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => true, // result is Top } } fn meet_mut(&mut self, other: Self) -> bool { self.assert_invariants(); - match (&*self, other) { - (Self::Bottom, _) => false, - (_, other @ Self::Bottom) => { - *self = other; - true + let mut old_self = Self::Bottom; // Good default result + std::mem::swap(self, &mut old_self); + match (old_self, other) { + (Self::Bottom, _) => false, // result is Bottom + (_, Self::Bottom) => true, // result is Bottom + (old, Self::Top) => { + *self = old; //reinstate + false } - (_, Self::Top) => false, (Self::Top, other) => { *self = other; true } - (Self::Value(h1), Self::Value(h2)) => { - let (h3, ch) = match h1.clone().try_meet(h2) { - Some((h3, ch)) => (Self::Value(h3), ch), - None => (Self::Bottom, true), - }; - *self = h3; - ch - } - (Self::PartialSum(_), Self::PartialSum(ps2)) => { - let Self::PartialSum(ps1) = self else { - unreachable!() - }; - match ps1.try_meet_mut(ps2) { - Ok(ch) => ch, - Err(_) => { - *self = Self::Bottom; - true - } + (Self::Value(h1), Self::Value(h2)) => match h1.try_meet(h2) { + Some((h3, ch)) => { + *self = Self::Value(h3); + ch } - } - (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { - *self = Self::Bottom; - true - } + None => true, //result is Bottom + }, + (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_meet_mut(ps2) { + Ok(ch) => { + *self = Self::PartialSum(ps1); + ch + } + Err(_) => true, + }, + (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => true, // result is Bottom } } } From 124718d08758b86ac0106a41cceb04e3917468a4 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 17:00:50 +0000 Subject: [PATCH 196/203] refactor join_mut / meet_mut again, common-up assignment --- hugr-passes/src/dataflow/partial_value.rs | 76 +++++++++-------------- 1 file changed, 28 insertions(+), 48 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index a9933e586..4bf5e927f 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -394,68 +394,48 @@ impl TryFrom> for Value { impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); - let mut old_self = Self::Top; // Good default result + let mut old_self = Self::Top; std::mem::swap(self, &mut old_self); - match (old_self, other) { - (Self::Top, _) => false, // result is Top - (_, Self::Top) => true, // result is Top - (old, Self::Bottom) => { - *self = old; // reinstate - false - } - (Self::Bottom, other) => { - *self = other; - true - } + let (res, ch) = match (old_self, other) { + (old @ Self::Top, _) | (old, Self::Bottom) => (old, false), + (_, other @ Self::Top) | (Self::Bottom, other) => (other, true), (Self::Value(h1), Self::Value(h2)) => match h1.clone().try_join(h2) { - Some((h3, b)) => { - *self = Self::Value(h3); - b - } - None => true, // result is Top + Some((h3, b)) => (Self::Value(h3), b), + None => (Self::Top, true), }, (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_join_mut(ps2) { - Ok(ch) => { - *self = Self::PartialSum(ps1); - ch - } - Err(_) => true, // result is Top + Ok(ch) => (Self::PartialSum(ps1), ch), + Err(_) => (Self::Top, true), }, - (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => true, // result is Top - } + (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { + (Self::Top, true) + } + }; + *self = res; + ch } fn meet_mut(&mut self, other: Self) -> bool { self.assert_invariants(); - let mut old_self = Self::Bottom; // Good default result + let mut old_self = Self::Bottom; std::mem::swap(self, &mut old_self); - match (old_self, other) { - (Self::Bottom, _) => false, // result is Bottom - (_, Self::Bottom) => true, // result is Bottom - (old, Self::Top) => { - *self = old; //reinstate - false - } - (Self::Top, other) => { - *self = other; - true - } + let (res, ch) = match (old_self, other) { + (old @ Self::Bottom, _) | (old, Self::Top) => (old, false), + (_, other @ Self::Bottom) | (Self::Top, other) => (other, true), (Self::Value(h1), Self::Value(h2)) => match h1.try_meet(h2) { - Some((h3, ch)) => { - *self = Self::Value(h3); - ch - } - None => true, //result is Bottom + Some((h3, ch)) => (Self::Value(h3), ch), + None => (Self::Bottom, true), }, (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_meet_mut(ps2) { - Ok(ch) => { - *self = Self::PartialSum(ps1); - ch - } - Err(_) => true, + Ok(ch) => (Self::PartialSum(ps1), ch), + Err(_) => (Self::Bottom, true), }, - (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => true, // result is Bottom - } + (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { + (Self::Bottom, true) + } + }; + *self = res; + ch } } From 93b1f4d64501f3182b40415939571e1662cbfbde Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 17:19:21 +0000 Subject: [PATCH 197/203] clippy --- hugr-passes/src/dataflow/datalog.rs | 3 +-- hugr-passes/src/dataflow/partial_value.rs | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 69e26f029..ab8c7c6f3 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -93,8 +93,7 @@ impl Machine { } } else { // Put values onto in-wires of root node, datalog will do the rest - self.0 - .extend(in_values.into_iter().map(|(p, v)| (root, p, v))); + self.0.extend(in_values.map(|(p, v)| (root, p, v))); let got_inputs: HashSet<_, RandomState> = self .0 .iter() diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 4bf5e927f..2a3507144 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -370,8 +370,8 @@ impl PartialValue { } } - /// A value contains bottom means that it cannot occur during execution - /// - it may be an artefact during bootstrapping of the analysis, or else + /// A value contains bottom means that it cannot occur during execution: + /// it may be an artefact during bootstrapping of the analysis, or else /// the value depends upon a `panic` or a loop that /// [never terminates](super::TailLoopTermination::NeverBreaks). pub fn contains_bottom(&self) -> bool { From 731a3b05bd7967af7bd5f06782c9e13b0d326588 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 17:28:39 +0000 Subject: [PATCH 198/203] doclinks --- hugr-passes/src/dataflow/partial_value.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 2a3507144..60a3ae514 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -344,12 +344,12 @@ impl PartialValue { /// Turns this instance into some "concrete" value type `V2`, *if* it is a single value, /// or a [Sum](PartialValue::PartialSum) (of a single tag) convertible by - /// [PartialSum::try_into_value]. + /// [PartialSum::try_into_sum]. /// /// # Errors /// /// If this PartialValue was `Top` or `Bottom`, or was a [PartialSum](PartialValue::PartialSum) - /// that could not be converted into a [Sum] by [PartialSum::try_into_value] (e.g. if `typ` is + /// that could not be converted into a [Sum] by [PartialSum::try_into_sum] (e.g. if `typ` is /// incorrect), or if that [Sum] could not be converted into a `V2`. pub fn try_into_value(self, typ: &Type) -> Result> where From 7040e83d039028fd424e46eb4cd5c7e1f0a7bdb3 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 19 Nov 2024 21:46:02 +0000 Subject: [PATCH 199/203] prepopulate_df_inputs --- hugr-passes/src/dataflow/datalog.rs | 44 ++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index ab8c7c6f3..e71e34896 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -20,10 +20,11 @@ type PV = PartialValue; /// Basic structure for performing an analysis. Usage: /// 1. Get a new instance via [Self::default()] -/// 2. (Optionally) zero or more calls to [Self::prepopulate_wire] with initial values. -/// For example, for a [Module](OpType::Module)-rooted Hugr, each externally-callable -/// [FuncDefn](OpType::FuncDefn) should have the out-wires from its [Input](OpType::Input) -/// node prepopulated with [PartialValue::Top]. +/// 2. (Optionally) zero or more calls to [Self::prepopulate_wire] and/or +/// [Self::prepopulate_df_inputs] with initial values. +/// For example, to analyse a [Module](OpType::Module)-rooted Hugr as a library, +/// [Self::prepopulate_df_inputs] can be used on each externally-callable +/// [FuncDefn](OpType::FuncDefn) to set all inputs to [PartialValue::Top]. /// 3. Call [Self::run] to produce [AnalysisResults] pub struct Machine(Vec<(Node, IncomingPort, PartialValue)>); @@ -43,6 +44,26 @@ impl Machine { ); } + /// Provide initial values for the inputs to a [DataflowParent](hugr_core::ops::OpTag::DataflowParent) + /// (that is, values on the wires leaving the [Input](OpType::Input) child thereof). + /// Any out-ports of said same `Input` node, not given values by `in_values`, are set to [PartialValue::Top]. + pub fn prepopulate_df_inputs( + &mut self, + h: &impl HugrView, + parent: Node, + in_values: impl IntoIterator)>, + ) { + // Put values onto out-wires of Input node + let [inp, _] = h.get_io(parent).unwrap(); + let mut vals = vec![PartialValue::Top; h.signature(inp).unwrap().output_types().len()]; + for (ip, v) in in_values { + vals[ip.index()] = v; + } + for (i, v) in vals.into_iter().enumerate() { + self.prepopulate_wire(h, Wire::new(inp, i), v); + } + } + /// Run the analysis (iterate until a lattice fixpoint is reached), /// given initial values for some of the root node inputs. For a /// [Module](OpType::Module)-rooted Hugr, these are input to the function `"main"`. @@ -81,16 +102,11 @@ impl Machine { // (Consider: for a conditional that selects *either* the unknown input *or* value V, // analysis must produce Top == we-know-nothing, not `V` !) if let Some(p) = input_node_parent { - // Put values onto out-wires of Input node - let [inp, _] = context.get_io(p).unwrap(); - let mut vals = - vec![PartialValue::Top; context.signature(inp).unwrap().output_types().len()]; - for (ip, v) in in_values { - vals[ip.index()] = v; - } - for (i, v) in vals.into_iter().enumerate() { - self.prepopulate_wire(&*context, Wire::new(inp, i), v); - } + self.prepopulate_df_inputs( + &*context, + p, + in_values.map(|(p, v)| (OutgoingPort::from(p.index()), v)), + ); } else { // Put values onto in-wires of root node, datalog will do the rest self.0.extend(in_values.map(|(p, v)| (root, p, v))); From 584327f36bfacf6e8e7c455aa84c84fb14a09f01 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 20 Nov 2024 13:52:42 +0000 Subject: [PATCH 200/203] Revert "Use row_contains_bottom for CFG+DFG, and augment unpack_first(=>_no_bottom)" This reverts commit e34c7bef059295be4bc439ec5db35cd4b92eeded. --- hugr-passes/src/dataflow/datalog.rs | 17 ++++++----------- hugr-passes/src/dataflow/value_row.rs | 15 ++++++--------- 2 files changed, 12 insertions(+), 20 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index e71e34896..b814b6440 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -195,10 +195,7 @@ pub(super) fn run_datalog>( dfg_node(n) <-- node(n), if ctx.get_optype(*n).is_dfg(); out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), - input_child(dfg, i), - node_in_value_row(dfg, row), - if !row_contains_bottom(&row[..]), // Treat the DFG as a scheduling barrier - for (p, v) in row[..].iter().enumerate(); + input_child(dfg, i), in_wire_value(dfg, p, v); out_wire_value(dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), output_child(dfg, o), in_wire_value(o, p, v); @@ -217,7 +214,7 @@ pub(super) fn run_datalog>( output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node... // ...and select just what's possible for CONTINUE_TAG, if anything - if let Some(fields) = out_in_row.unpack_first_no_bottom(TailLoop::CONTINUE_TAG, tailloop.just_inputs.len()), + if let Some(fields) = out_in_row.unpack_first(TailLoop::CONTINUE_TAG, tailloop.just_inputs.len()), for (out_p, v) in fields.enumerate(); // Output node of child region propagate to outputs of tail loop @@ -226,7 +223,7 @@ pub(super) fn run_datalog>( output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node... // ... and select just what's possible for BREAK_TAG, if anything - if let Some(fields) = out_in_row.unpack_first_no_bottom(TailLoop::BREAK_TAG, tailloop.just_outputs.len()), + if let Some(fields) = out_in_row.unpack_first(TailLoop::BREAK_TAG, tailloop.just_outputs.len()), for (out_p, v) in fields.enumerate(); // Conditional -------------------- @@ -243,7 +240,7 @@ pub(super) fn run_datalog>( input_child(case, i_node), node_in_value_row(cond, in_row), let conditional = ctx.get_optype(*cond).as_conditional().unwrap(), - if let Some(fields) = in_row.unpack_first_no_bottom(*case_index, conditional.sum_rows[*case_index].len()), + if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), for (out_p, v) in fields.enumerate(); // outputs of case nodes propagate to outputs of conditional *if* case reachable @@ -278,9 +275,7 @@ pub(super) fn run_datalog>( cfg_node(cfg), if let Some(entry) = ctx.children(*cfg).next(), input_child(entry, i_node), - node_in_value_row(cfg, row), - if !row_contains_bottom(&row[..]), - for (p, v) in row[..].iter().enumerate(); + in_wire_value(cfg, p, v); // In `CFG` , values fed along a control-flow edge to // come out of Value outports of : @@ -299,7 +294,7 @@ pub(super) fn run_datalog>( output_child(pred, out_n), _cfg_succ_dest(cfg, succ, dest), node_in_value_row(out_n, out_in_row), - if let Some(fields) = out_in_row.unpack_first_no_bottom(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), + if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), for (out_p, v) in fields.enumerate(); // Call -------------------- diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs index 9360f36e3..50cf10318 100644 --- a/hugr-passes/src/dataflow/value_row.rs +++ b/hugr-passes/src/dataflow/value_row.rs @@ -8,7 +8,7 @@ use std::{ use ascent::{lattice::BoundedLattice, Lattice}; use itertools::zip_eq; -use super::{row_contains_bottom, AbstractValue, PartialValue}; +use super::{AbstractValue, PartialValue}; #[derive(PartialEq, Clone, Debug, Eq, Hash)] pub(super) struct ValueRow(Vec>); @@ -27,19 +27,16 @@ impl ValueRow { Self(vec![v]) } - /// If the first value in this ValueRow is a sum, that might contain - /// the specified tag, then unpack the elements of that tag, append the rest - /// of this ValueRow, and if none of the elements of that row [contain bottom](PartialValue::contains_bottom), - /// return it. - /// Otherwise (if no such tag, or values contain bottom), return None. - pub fn unpack_first_no_bottom( + /// The first value in this ValueRow must be a sum; + /// returns a new ValueRow given by unpacking the elements of the specified variant of said first value, + /// then appending the rest of the values in this row. + pub fn unpack_first( &self, variant: usize, len: usize, ) -> Option>> { let vals = self[0].variant_values(variant, len)?; - (!row_contains_bottom(vals.iter().chain(self.0[1..].iter()))) - .then(|| vals.into_iter().chain(self.0[1..].to_owned())) + Some(vals.into_iter().chain(self.0[1..].to_owned())) } } From 5a9d8d65a13b4a7e042f5c2b70f33d20880cfad1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 20 Nov 2024 13:54:36 +0000 Subject: [PATCH 201/203] Redo: use row_contains_bottom for CFG+DFG, and augment unpack_first(=>_no_bottom) This reverts commit 584327f36bfacf6e8e7c455aa84c84fb14a09f01. --- hugr-passes/src/dataflow/datalog.rs | 17 +++++++++++------ hugr-passes/src/dataflow/value_row.rs | 15 +++++++++------ 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index b814b6440..e71e34896 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -195,7 +195,10 @@ pub(super) fn run_datalog>( dfg_node(n) <-- node(n), if ctx.get_optype(*n).is_dfg(); out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), - input_child(dfg, i), in_wire_value(dfg, p, v); + input_child(dfg, i), + node_in_value_row(dfg, row), + if !row_contains_bottom(&row[..]), // Treat the DFG as a scheduling barrier + for (p, v) in row[..].iter().enumerate(); out_wire_value(dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), output_child(dfg, o), in_wire_value(o, p, v); @@ -214,7 +217,7 @@ pub(super) fn run_datalog>( output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node... // ...and select just what's possible for CONTINUE_TAG, if anything - if let Some(fields) = out_in_row.unpack_first(TailLoop::CONTINUE_TAG, tailloop.just_inputs.len()), + if let Some(fields) = out_in_row.unpack_first_no_bottom(TailLoop::CONTINUE_TAG, tailloop.just_inputs.len()), for (out_p, v) in fields.enumerate(); // Output node of child region propagate to outputs of tail loop @@ -223,7 +226,7 @@ pub(super) fn run_datalog>( output_child(tl, out_n), node_in_value_row(out_n, out_in_row), // get the whole input row for the output node... // ... and select just what's possible for BREAK_TAG, if anything - if let Some(fields) = out_in_row.unpack_first(TailLoop::BREAK_TAG, tailloop.just_outputs.len()), + if let Some(fields) = out_in_row.unpack_first_no_bottom(TailLoop::BREAK_TAG, tailloop.just_outputs.len()), for (out_p, v) in fields.enumerate(); // Conditional -------------------- @@ -240,7 +243,7 @@ pub(super) fn run_datalog>( input_child(case, i_node), node_in_value_row(cond, in_row), let conditional = ctx.get_optype(*cond).as_conditional().unwrap(), - if let Some(fields) = in_row.unpack_first(*case_index, conditional.sum_rows[*case_index].len()), + if let Some(fields) = in_row.unpack_first_no_bottom(*case_index, conditional.sum_rows[*case_index].len()), for (out_p, v) in fields.enumerate(); // outputs of case nodes propagate to outputs of conditional *if* case reachable @@ -275,7 +278,9 @@ pub(super) fn run_datalog>( cfg_node(cfg), if let Some(entry) = ctx.children(*cfg).next(), input_child(entry, i_node), - in_wire_value(cfg, p, v); + node_in_value_row(cfg, row), + if !row_contains_bottom(&row[..]), + for (p, v) in row[..].iter().enumerate(); // In `CFG` , values fed along a control-flow edge to // come out of Value outports of : @@ -294,7 +299,7 @@ pub(super) fn run_datalog>( output_child(pred, out_n), _cfg_succ_dest(cfg, succ, dest), node_in_value_row(out_n, out_in_row), - if let Some(fields) = out_in_row.unpack_first(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), + if let Some(fields) = out_in_row.unpack_first_no_bottom(succ_n, df_block.sum_rows.get(succ_n).unwrap().len()), for (out_p, v) in fields.enumerate(); // Call -------------------- diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs index 50cf10318..9360f36e3 100644 --- a/hugr-passes/src/dataflow/value_row.rs +++ b/hugr-passes/src/dataflow/value_row.rs @@ -8,7 +8,7 @@ use std::{ use ascent::{lattice::BoundedLattice, Lattice}; use itertools::zip_eq; -use super::{AbstractValue, PartialValue}; +use super::{row_contains_bottom, AbstractValue, PartialValue}; #[derive(PartialEq, Clone, Debug, Eq, Hash)] pub(super) struct ValueRow(Vec>); @@ -27,16 +27,19 @@ impl ValueRow { Self(vec![v]) } - /// The first value in this ValueRow must be a sum; - /// returns a new ValueRow given by unpacking the elements of the specified variant of said first value, - /// then appending the rest of the values in this row. - pub fn unpack_first( + /// If the first value in this ValueRow is a sum, that might contain + /// the specified tag, then unpack the elements of that tag, append the rest + /// of this ValueRow, and if none of the elements of that row [contain bottom](PartialValue::contains_bottom), + /// return it. + /// Otherwise (if no such tag, or values contain bottom), return None. + pub fn unpack_first_no_bottom( &self, variant: usize, len: usize, ) -> Option>> { let vals = self[0].variant_values(variant, len)?; - Some(vals.into_iter().chain(self.0[1..].to_owned())) + (!row_contains_bottom(vals.iter().chain(self.0[1..].iter()))) + .then(|| vals.into_iter().chain(self.0[1..].to_owned())) } } From d196e4d1633a48d39e11c4ed0904b64dbf732ae1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 20 Nov 2024 10:02:03 +0000 Subject: [PATCH 202/203] Rename in_wire_value_proto => in_wire_values_given --- hugr-passes/src/dataflow/datalog.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index e71e34896..252e2b0c4 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -129,7 +129,7 @@ impl Machine { pub(super) fn run_datalog>( ctx: C, - in_wire_value_proto: Vec<(Node, IncomingPort, PV)>, + in_wire_values_given: Vec<(Node, IncomingPort, PV)>, ) -> AnalysisResults { // ascent-(macro-)generated code generates a bunch of warnings, // keep code in here to a minimum. @@ -169,9 +169,9 @@ pub(super) fn run_datalog>( if let Some((m, op)) = ctx.single_linked_output(*n, *ip), out_wire_value(m, op, v); - // Prepopulate in_wire_value from in_wire_value_proto. + // Prepopulate in_wire_value from in_wire_values_given. in_wire_value(n, p, PV::bottom()) <-- in_wire(n, p); - in_wire_value(n, p, v) <-- for (n, p, v) in in_wire_value_proto.iter(), + in_wire_value(n, p, v) <-- for (n, p, v) in in_wire_values_given.iter(), node(n), if let Some(sig) = ctx.signature(*n), if sig.input_ports().contains(p); From 697f6ed7ef44ce93ba8f1969e29a19f824c940e1 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 20 Nov 2024 13:49:42 +0000 Subject: [PATCH 203/203] reachability v2 --- hugr-passes/src/dataflow/datalog.rs | 119 +++++++++++++++------------- hugr-passes/src/dataflow/results.rs | 40 +++------- hugr-passes/src/dataflow/test.rs | 8 +- 3 files changed, 79 insertions(+), 88 deletions(-) diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 252e2b0c4..302034486 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -7,7 +7,7 @@ use ascent::lattice::BoundedLattice; use itertools::Itertools; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; -use hugr_core::ops::{NamedOp, OpTrait, OpType, TailLoop}; +use hugr_core::ops::{NamedOp, OpTag, OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; @@ -141,15 +141,21 @@ pub(super) fn run_datalog>( let all_results = ascent::ascent_run! { pub(super) struct AscentProgram; relation node(Node); // exists in the hugr + relation reachable(Node); // exists and is reachable relation in_wire(Node, IncomingPort); // has an of `EdgeKind::Value` relation out_wire(Node, OutgoingPort); // has an of `EdgeKind::Value` relation parent_of_node(Node, Node); // is parent of relation input_child(Node, Node); // has 1st child that is its `Input` relation output_child(Node, Node); // has 2nd child that is its `Output` - lattice out_wire_value(Node, OutgoingPort, PV); // produces, on , the value - lattice in_wire_value(Node, IncomingPort, PV); // receives, on , the value - lattice node_in_value_row(Node, ValueRow); // 's inputs are - + // produces, on , the value : + lattice out_wire_value(Node, OutgoingPort, PV); + // 's inputs would be , ignoring reachability: + lattice node_in_value_row_proto(Node, ValueRow); + // 's receives inputs, taking account of reachability: + lattice node_in_value_row(Node, ValueRow); + // receives, on , the value , accounting for reachability: + lattice in_wire_value(Node, IncomingPort, PV); + node(n) <-- for n in ctx.nodes(); in_wire(n, p) <-- node(n), for (p,_) in ctx.in_value_types(*n); // Note, gets connected inports only @@ -164,25 +170,36 @@ pub(super) fn run_datalog>( // Initialize all wires to bottom out_wire_value(n, p, PV::bottom()) <-- out_wire(n, p); - // Outputs to inputs - in_wire_value(n, ip, v) <-- in_wire(n, ip), - if let Some((m, op)) = ctx.single_linked_output(*n, *ip), + // out_wires -> node_in_value_row_proto, elements will be combined via join + node_in_value_row_proto(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = ctx.signature(*n); + node_in_value_row_proto(n, ValueRow::new(ctx.signature(*n).unwrap().input_count()).set(p.index(), v.clone())) <-- in_wire(n, p), + if let Some((m, op)) = ctx.single_linked_output(*n, *p), out_wire_value(m, op, v); - // Prepopulate in_wire_value from in_wire_values_given. - in_wire_value(n, p, PV::bottom()) <-- in_wire(n, p); - in_wire_value(n, p, v) <-- for (n, p, v) in in_wire_values_given.iter(), + // Also prepopulate in_wire_value_proto from in_wire_values_given. + node_in_value_row_proto(n, ValueRow::new(sig.input_count()).set(p.index(), v.clone())) <-- + for (n, p, v) in in_wire_values_given.iter(), node(n), if let Some(sig) = ctx.signature(*n), if sig.input_ports().contains(p); - // Assemble node_in_value_row from in_wire_value's - node_in_value_row(n, ValueRow::new(sig.input_count())) <-- node(n), if let Some(sig) = ctx.signature(*n); - node_in_value_row(n, ValueRow::new(ctx.signature(*n).unwrap().input_count()).set(p.index(), v.clone())) <-- in_wire_value(n, p, v); + // node_in_value_row from node_in_value_row_proto if parent reachable + node_in_value_row(n, r) <-- node_in_value_row_proto(n,r), reachable(n); + + // in_wire_value by decomposing node_in_value_row + in_wire_value(n, IncomingPort::from(p), v) <-- node_in_value_row(n, r), for (p,v) in r[..].iter().enumerate(); + + // Reachability for dataflow regions (Conditional->Case and CFG->BB handled separately) + reachable(ctx.root()); + reachable(n) <-- parent_of_node(p, n), + if OpTag::DataflowParent.is_superset(ctx.get_optype(*p).tag()), + reachable(p), + node_in_value_row_proto(n, r), + if !row_contains_bottom(&r[..]); // Interpret leaf ops out_wire_value(n, p, v) <-- - node(n), + node(n), reachable(n), let op_t = ctx.get_optype(*n), if !op_t.is_container(), if let Some(sig) = op_t.dataflow_signature(), @@ -197,7 +214,6 @@ pub(super) fn run_datalog>( out_wire_value(i, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), input_child(dfg, i), node_in_value_row(dfg, row), - if !row_contains_bottom(&row[..]), // Treat the DFG as a scheduling barrier for (p, v) in row[..].iter().enumerate(); out_wire_value(dfg, OutgoingPort::from(p.index()), v) <-- dfg_node(dfg), @@ -249,51 +265,49 @@ pub(super) fn run_datalog>( // outputs of case nodes propagate to outputs of conditional *if* case reachable out_wire_value(cond, OutgoingPort::from(o_p.index()), v) <-- case_node(cond, _i, case), - case_reachable(cond, case), + reachable(case), output_child(case, o), in_wire_value(o, o_p, v); // In `Conditional` , child `Case` is reachable given our knowledge of predicate: - relation case_reachable(Node, Node); - case_reachable(cond, case) <-- case_node(cond, i, case), + reachable(case) <-- case_node(cond, i, case), in_wire_value(cond, IncomingPort::from(0), v), - if v.supports_tag(*i); + if v.supports_tag(*i); // TODO better to check no Bottom within variant (or whole row) // CFG -------------------- - relation cfg_node(Node); // is a `CFG` - cfg_node(n) <-- node(n), if ctx.get_optype(*n).is_cfg(); + relation cfg_parent(Node, Node); // is a `CFG` and parent of + cfg_parent(p, n) <-- node(p), if ctx.get_optype(*p).is_cfg(), for n in ctx.children(*p); // In `CFG` , basic block is reachable given our knowledge of predicates: - relation bb_reachable(Node, Node); - bb_reachable(cfg, entry) <-- cfg_node(cfg), if let Some(entry) = ctx.children(*cfg).next(); - bb_reachable(cfg, bb) <-- cfg_node(cfg), - bb_reachable(cfg, pred), - output_child(pred, pred_out), - in_wire_value(pred_out, IncomingPort::from(0), predicate), - for (tag, bb) in ctx.output_neighbours(*pred).enumerate(), - if predicate.supports_tag(tag); - + reachable(entry) <-- cfg_parent(cfg, entry), if Some(*entry) == ctx.children(*cfg).next(); // Inputs of CFG propagate to entry block out_wire_value(i_node, OutgoingPort::from(p.index()), v) <-- - cfg_node(cfg), - if let Some(entry) = ctx.children(*cfg).next(), + cfg_parent(cfg, entry), + if Some(*entry) == ctx.children(*cfg).next(), input_child(entry, i_node), node_in_value_row(cfg, row), - if !row_contains_bottom(&row[..]), for (p, v) in row[..].iter().enumerate(); // In `CFG` , values fed along a control-flow edge to // come out of Value outports of : relation _cfg_succ_dest(Node, Node, Node); - _cfg_succ_dest(cfg, exit, cfg) <-- cfg_node(cfg), if let Some(exit) = ctx.children(*cfg).nth(1); - _cfg_succ_dest(cfg, blk, inp) <-- cfg_node(cfg), - for blk in ctx.children(*cfg), - if ctx.get_optype(blk).is_dataflow_block(), + _cfg_succ_dest(cfg, exit, cfg) <-- cfg_parent(cfg, exit), if Some(*exit) == ctx.children(*cfg).nth(1); + _cfg_succ_dest(cfg, blk, inp) <-- cfg_parent(cfg, blk), + if ctx.get_optype(*blk).is_dataflow_block(), input_child(blk, inp); // Outputs of each reachable block propagated to successor block or CFG itself + reachable(bb) <-- + cfg_parent(cfg, pred), + //reachable(pred), // Output will only get values if pred reachable + output_child(pred, pred_out), + in_wire_value(pred_out, IncomingPort::from(0), predicate), + for (tag, bb) in ctx.output_neighbours(*pred).enumerate(), + if predicate.supports_tag(tag); // TODO Better to check no Bottom in row or at least predicate-variant + out_wire_value(dest, OutgoingPort::from(out_p), v) <-- - bb_reachable(cfg, pred), + cfg_parent(cfg, pred), + //reachable(pred), // Output will only get values if pred reachable if let Some(df_block) = ctx.get_optype(*pred).as_dataflow_block(), for (succ_n, succ) in ctx.output_neighbours(*pred).enumerate(), output_child(pred, out_n), @@ -308,6 +322,7 @@ pub(super) fn run_datalog>( node(call), if ctx.get_optype(*call).is_call(), if let Some(func_defn) = ctx.static_source(*call); + reachable(func_defn) <-- func_call(call, func_defn), reachable(call); out_wire_value(inp, OutgoingPort::from(p.index()), v) <-- func_call(call, func), @@ -328,8 +343,7 @@ pub(super) fn run_datalog>( ctx, out_wire_values, in_wire_value: all_results.in_wire_value, - case_reachable: all_results.case_reachable, - bb_reachable: all_results.bb_reachable, + reachable: all_results.reachable.into_iter().map(|(x,)|x).collect() } } @@ -381,19 +395,16 @@ fn propagate_leaf_op( )) } OpType::ExtensionOp(e) => { - Some(ValueRow::from_iter(if row_contains_bottom(ins) { - // So far we think one or more inputs can't happen. - // So, don't pollute outputs with Top, and wait for better knowledge of inputs. - vec![PartialValue::Bottom; num_outs] - } else { - // Interpret op using DFContext - // Default to Top i.e. can't figure out anything about the outputs - let mut outs = vec![PartialValue::Top; num_outs]; - // It might be nice to convert `ins`` to [(IncomingPort, Value)], or some concrete value, - // for the context, but PV contains more information, and try_into_value may fail. - ctx.interpret_leaf_op(n, e, ins, &mut outs[..]); - outs - })) + // If the inputs contain bottom (can't happen), the node should be unreachable. + assert!(!row_contains_bottom(ins)); + + // Interpret op using DFContext + // Default to Top i.e. can't figure out anything about the outputs + let mut outs = vec![PartialValue::Top; num_outs]; + // It might be nice to convert `ins`` to [(IncomingPort, Value)], or some concrete value, + // for the context, but PV contains more information, and try_into_value may fail. + ctx.interpret_leaf_op(n, e, ins, &mut outs[..]); + Some(ValueRow::from_iter(outs)) } o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive" } diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index 21d6b13c0..0a615ad6f 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use hugr_core::{HugrView, IncomingPort, Node, PortIndex, Wire}; @@ -9,8 +9,7 @@ use super::{partial_value::ExtractValueError, AbstractValue, DFContext, PartialV pub struct AnalysisResults> { pub(super) ctx: C, pub(super) in_wire_value: Vec<(Node, IncomingPort, PartialValue)>, - pub(super) case_reachable: Vec<(Node, Node)>, - pub(super) bb_reachable: Vec<(Node, Node)>, + pub(super) reachable: HashSet, pub(super) out_wire_values: HashMap>, } @@ -41,39 +40,20 @@ impl> AnalysisResults { )) } - /// Tells whether a [Case] node is reachable, i.e. whether the predicate - /// to its parent [Conditional] may possibly have the tag corresponding to the [Case]. - /// Returns `None` if the specified `case` is not a [Case], or is not within a [Conditional] - /// (e.g. a [Case]-rooted Hugr). + /// Tells whether a node is reachable, i.e. can actually be evaluated if the Hugr was. + /// This includes + /// - Any dataflow node is only reachable if its parent is reachable *and* all the node's inputs are non-[PartialValue::Bottom] + /// - [Case] nodes being reachable only if their parent [Conditional] might possibly receive the corresponding tag + /// - [DataflowBlock]s and [ExitBlock]s only being reachable if some [CFG]-predecessor is reachable + /// and might (according to predicate) pass control flow to the block. /// /// [Case]: hugr_core::ops::Case /// [Conditional]: hugr_core::ops::Conditional - pub fn case_reachable(&self, case: Node) -> Option { - self.hugr().get_optype(case).as_case()?; - let cond = self.hugr().get_parent(case)?; - self.hugr().get_optype(cond).as_conditional()?; - Some( - self.case_reachable - .iter() - .any(|(cond2, case2)| &cond == cond2 && &case == case2), - ) - } - - /// Tells us if a block ([DataflowBlock] or [ExitBlock]) in a [CFG] is known - /// to be reachable. (Returns `None` if argument is not a child of a CFG.) - /// /// [CFG]: hugr_core::ops::CFG /// [DataflowBlock]: hugr_core::ops::DataflowBlock /// [ExitBlock]: hugr_core::ops::ExitBlock - pub fn bb_reachable(&self, bb: Node) -> Option { - let cfg = self.hugr().get_parent(bb)?; // Not really required...?? - self.hugr().get_optype(cfg).as_cfg()?; - let t = self.hugr().get_optype(bb); - (t.is_dataflow_block() || t.is_exit_block()).then(|| { - self.bb_reachable - .iter() - .any(|(cfg2, bb2)| *cfg2 == cfg && *bb2 == bb) - }) + pub fn reachable(&self, case: Node) -> bool { + self.reachable.contains(&case) } /// Reads a concrete representation of the value on an output wire, if the lattice value diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index c0fbf395a..226af8efa 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -302,10 +302,10 @@ fn test_conditional() { assert_eq!(cond_r1, Value::false_val()); assert!(results.try_read_wire_value::(cond_o2).is_err()); - assert_eq!(results.case_reachable(case1.node()), Some(false)); // arg_pv is variant 1 or 2 only - assert_eq!(results.case_reachable(case2.node()), Some(true)); - assert_eq!(results.case_reachable(case3.node()), Some(true)); - assert_eq!(results.case_reachable(cond.node()), None); + assert_eq!(results.reachable(case1.node()), false); // arg_pv is variant 1 or 2 only + assert!(results.reachable(case2.node())); + assert!(results.reachable(case3.node())); + assert!(results.reachable(cond.node())); } // A Hugr being a function on bools: (x, y) => (x XOR y, x AND y)