diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index ffa338b5a4169a1646b0ca0a582dba44777ffd12..cef94a2d3e082a012a023cf47fdf6305942d986a 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -133,7 +133,7 @@ pub enum DynamicConstant {
  * operate on an index list, composing indices at different levels in a type
  * tree. Each type that can be indexed has a unique variant in the index enum.
  */
-#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
 pub enum Index {
     Field(usize),
     Variant(usize),
diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs
index 08d183a7f8c2ddfc650bb1ef3ce385761a137def..9935703e839b5a4dd705af03f40a456008fe0f12 100644
--- a/hercules_opt/src/lib.rs
+++ b/hercules_opt/src/lib.rs
@@ -17,6 +17,7 @@ pub mod pass;
 pub mod phi_elim;
 pub mod pred;
 pub mod schedule;
+pub mod slf;
 pub mod sroa;
 pub mod unforkify;
 pub mod utils;
@@ -38,6 +39,7 @@ pub use crate::pass::*;
 pub use crate::phi_elim::*;
 pub use crate::pred::*;
 pub use crate::schedule::*;
+pub use crate::slf::*;
 pub use crate::sroa::*;
 pub use crate::unforkify::*;
 pub use crate::utils::*;
diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs
index d24b6563f6e330df89a5206349ff1103a04e607b..1e7104ce6461b94d2404fb8f59c29b2b7c75a67e 100644
--- a/hercules_opt/src/pass.rs
+++ b/hercules_opt/src/pass.rs
@@ -25,6 +25,7 @@ pub enum Pass {
     PhiElim,
     Forkify,
     ForkGuardElim,
+    SLF,
     WritePredication,
     Predication,
     SROA,
@@ -470,6 +471,46 @@ impl PassManager {
                     }
                     self.clear_analyses();
                 }
+                // Store-to-load forwarding: replace reads of known collection
+                // sub-values with the values themselves.
+                Pass::SLF => {
+                    // Build the analyses consumed below: def-use for the
+                    // editor, reverse postorder and typing for `slf` itself.
+                    self.make_def_uses();
+                    self.make_reverse_postorders();
+                    self.make_typing();
+                    let def_uses = self.def_uses.as_ref().unwrap();
+                    let reverse_postorders = self.reverse_postorders.as_ref().unwrap();
+                    let typing = self.typing.as_ref().unwrap();
+                    for idx in 0..self.module.functions.len() {
+                        // Move the module-level tables into `RefCell`s for
+                        // the lifetime of the editor.
+                        let constants_ref =
+                            RefCell::new(std::mem::take(&mut self.module.constants));
+                        let dynamic_constants_ref =
+                            RefCell::new(std::mem::take(&mut self.module.dynamic_constants));
+                        let types_ref = RefCell::new(std::mem::take(&mut self.module.types));
+                        let mut editor = FunctionEditor::new(
+                            &mut self.module.functions[idx],
+                            FunctionID::new(idx),
+                            &constants_ref,
+                            &dynamic_constants_ref,
+                            &types_ref,
+                            &def_uses[idx],
+                        );
+                        slf(&mut editor, &reverse_postorders[idx], &typing[idx]);
+
+                        // Hand the tables back to the module.
+                        self.module.constants = constants_ref.take();
+                        self.module.dynamic_constants = dynamic_constants_ref.take();
+                        self.module.types = types_ref.take();
+
+                        self.module.functions[idx].delete_gravestones();
+                    }
+                    // The functions may have been edited, so drop the cached
+                    // analyses.
+                    self.clear_analyses();
+                }
                 Pass::WritePredication => {
                     self.make_def_uses();
                     let def_uses = self.def_uses.as_ref().unwrap();
diff --git a/hercules_opt/src/pred.rs b/hercules_opt/src/pred.rs
index cfad7d1c710e3295bef8f41426a68c24b4d61d10..644c69d0df34d327c2c2e34bf8e0a915ddd68233 100644
--- a/hercules_opt/src/pred.rs
+++ b/hercules_opt/src/pred.rs
@@ -1,6 +1,5 @@
 use std::cmp::{max, min};
 use std::collections::{BTreeMap, BTreeSet};
-use std::iter::zip;
 
 use itertools::Itertools;
 
@@ -262,26 +261,3 @@ pub fn write_predication(editor: &mut FunctionEditor) {
         }
     }
 }
-
-/*
- * Helper function to tell if two lists of indices have the same structure.
- */
-fn indices_structurally_equivalent(indices1: &[Index], indices2: &[Index]) -> bool {
-    if indices1.len() == indices2.len() {
-        let mut equiv = true;
-        for pair in zip(indices1, indices2) {
-            equiv = equiv
-                && match pair {
-                    (Index::Field(idx1), Index::Field(idx2)) => idx1 == idx2,
-                    (Index::Variant(idx1), Index::Variant(idx2)) => idx1 == idx2,
-                    (Index::Position(ref pos1), Index::Position(ref pos2)) => {
-                        pos1.len() == pos2.len()
-                    }
-                    _ => false,
-                };
-        }
-        equiv
-    } else {
-        false
-    }
-}
diff --git a/hercules_opt/src/slf.rs b/hercules_opt/src/slf.rs
new file mode 100644
index 0000000000000000000000000000000000000000..981a0cce2d4dd1ba72ba3525aa9e31467cd6ecc7
--- /dev/null
+++ b/hercules_opt/src/slf.rs
@@ -0,0 +1,165 @@
+use std::collections::BTreeMap;
+
+use hercules_ir::*;
+
+use crate::*;
+
+/*
+ * The SLF lattice tracks what sub-values of a collection are known. Each sub-
+ * value is a node ID at a set of indices that were written at. A write to a set
+ * of indices that may overlap a previously tracked set of indices marks the old
+ * sub-value as unknown, since that write may overwrite it. The lattice top
+ * corresponds to every value being 0. When the sub-value at a set of indices is
+ * not known, the `subvalues` map stores `None` for that set of indices.
+ */
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub struct SLFLattice {
+    subvalues: BTreeMap<Box<[Index]>, Option<NodeID>>,
+}
+
+impl Semilattice for SLFLattice {
+    fn meet(a: &Self, b: &Self) -> Self {
+        // Merge the two maps. An indices set keeps its sub-value only when it
+        // appears in both `a` and `b` with equal sub-values. All other indices
+        // sets in `a` or `b` map to `None`.
+        let mut ret: BTreeMap<_, _> = a
+            .subvalues
+            .iter()
+            .map(|(indices, a_subvalue)| {
+                let merged = if b.subvalues.get(indices) == Some(a_subvalue) {
+                    *a_subvalue
+                } else {
+                    None
+                };
+                (indices.clone(), merged)
+            })
+            .collect();
+        for indices in b.subvalues.keys() {
+            // Any indices sets in `b` that aren't in `ret` are indices sets
+            // that aren't in `a`, so the sub-value isn't known.
+            ret.entry(indices.clone()).or_insert(None);
+        }
+        SLFLattice { subvalues: ret }
+    }
+
+    fn top() -> Self {
+        SLFLattice {
+            subvalues: BTreeMap::new(),
+        }
+    }
+
+    fn bottom() -> Self {
+        // The empty indices set overlaps with all possible indices sets, so a
+        // single unknown entry at `[]` marks the whole collection as unknown.
+        let mut subvalues = BTreeMap::new();
+        subvalues.insert(Box::new([]) as Box<[Index]>, None);
+        SLFLattice { subvalues }
+    }
+}
+
+/*
+ * Top level function to run store-to-load forwarding on a function. Looks for
+ * known values inside collections and replaces reads of those values with the
+ * values directly.
+ */
+pub fn slf(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, typing: &Vec<TypeID>) {
+    // First, run a dataflow analysis that looks at known values inside
+    // collections. Thanks to the value semantics of Hercules IR, this analysis
+    // is relatively simple and straightforward.
+    let func = editor.func();
+    let lattice = forward_dataflow(func, reverse_postorder, |inputs, id| {
+        match func.nodes[id.idx()] {
+            // Phi, reduce, and select nodes meet the lattice values of all of
+            // their inputs. NOTE(review): the fold is seeded with `top()`,
+            // whose map is empty, so `meet` maps every incoming indices set
+            // to `None` - no known sub-value survives these nodes.
+            Node::Phi { .. }
+            | Node::Reduce { .. }
+            | Node::Ternary {
+                op: TernaryOperator::Select,
+                ..
+            } => inputs.into_iter().fold(SLFLattice::top(), |acc, input| {
+                SLFLattice::meet(&acc, input)
+            }),
+            Node::Write {
+                collect: _,
+                data,
+                ref indices,
+            } => {
+                // Start with the sub-values of the `collect` input.
+                // NOTE(review): assumes `inputs[0]` is the lattice value of
+                // `collect` - TODO confirm against `forward_dataflow`.
+                let mut value = inputs[0].clone();
+
+                // Any indices sets that overlap with `indices` become `None`,
+                // since we no longer know what's stored there.
+                for (other_indices, subvalue) in value.subvalues.iter_mut() {
+                    if indices_may_overlap(other_indices, indices) {
+                        *subvalue = None;
+                    }
+                }
+
+                // Track `data` at `indices`.
+                value.subvalues.insert(indices.clone(), Some(data));
+
+                value
+            }
+            // Nothing is known about the collection produced by any other
+            // node.
+            _ => SLFLattice::bottom(),
+        }
+    });
+
+    // Second, look for reads where the indices set either:
+    // 1. Equals the indices of a known sub-value. Then, the read can be
+    //    replaced by the known sub-value.
+    // 2. Otherwise, if the indices set doesn't overlap with any known or
+    //    unknown sub-value, then the read can be replaced by a zero constant.
+    // 3. Otherwise, the read can't be replaced.
+    // Keep track of which reads we've already replaced, in either case, since
+    // a sub-value we knew previously may be the ID of a read that has since
+    // been replaced and deleted.
+    let mut replacements = BTreeMap::new();
+    for id in editor.node_ids() {
+        let Node::Read {
+            collect,
+            ref indices,
+        } = editor.func().nodes[id.idx()]
+        else {
+            continue;
+        };
+        let subvalues = &lattice[collect.idx()].subvalues;
+
+        if let Some(&Some(mut known)) = subvalues.get(indices) {
+            // Follow the chain of earlier replacements, so that we never
+            // forward a node that has already been deleted.
+            while let Some(replacement) = replacements.get(&known) {
+                known = *replacement;
+            }
+            editor.edit(|mut edit| {
+                edit = edit.replace_all_uses(id, known)?;
+                edit.delete_node(id)
+            });
+            replacements.insert(id, known);
+        } else if !subvalues
+            .keys()
+            .any(|other_indices| indices_may_overlap(other_indices, indices))
+        {
+            // The closure hands the ID of the new constant node back through
+            // `zero_node`, so that the replacement can be recorded below.
+            let mut zero_node = None;
+            let success = editor.edit(|mut edit| {
+                let zero = edit.add_zero_constant(typing[id.idx()]);
+                let zero = edit.add_node(Node::Constant { id: zero });
+                zero_node = Some(zero);
+                edit = edit.replace_all_uses(id, zero)?;
+                edit.delete_node(id)
+            });
+            // Only record the replacement if the edit went through, since
+            // otherwise the read is still in place and must not be remapped.
+            if let (true, Some(zero)) = (success, zero_node) {
+                replacements.insert(id, zero);
+            }
+        }
+    }
+}
diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs
index 2a4fd94c7d527147336553b0ade231967b67d004..6239a644c5f63a37fe2e7b48450b40388288c088 100644
--- a/hercules_opt/src/utils.rs
+++ b/hercules_opt/src/utils.rs
@@ -1,3 +1,5 @@
+use std::iter::zip;
+
 use hercules_ir::def_use::*;
 use hercules_ir::ir::*;
 
@@ -241,7 +243,7 @@ pub(crate) fn substitute_dynamic_constants_in_node(
 /*
  * Top level function to make a function have only a single return.
*/ -pub fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> { +pub(crate) fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> { let returns: Vec<NodeID> = (0..editor.func().nodes.len()) .filter(|idx| editor.func().nodes[*idx].is_return()) .map(NodeID::new) @@ -281,7 +283,7 @@ pub fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> { new_return } -pub fn contains_between_control_flow(func: &Function) -> bool { +pub(crate) fn contains_between_control_flow(func: &Function) -> bool { let num_control = func.nodes.iter().filter(|node| node.is_control()).count(); assert!(num_control >= 2, "PANIC: A Hercules function must have at least two control nodes: a start node and at least one return node."); num_control > 2 @@ -291,7 +293,7 @@ pub fn contains_between_control_flow(func: &Function) -> bool { * Top level function to ensure a Hercules function contains at least one * control node that isn't the start or return nodes. */ -pub fn ensure_between_control_flow(editor: &mut FunctionEditor) -> Option<NodeID> { +pub(crate) fn ensure_between_control_flow(editor: &mut FunctionEditor) -> Option<NodeID> { if !contains_between_control_flow(editor.func()) { let ret = editor .node_ids() @@ -326,3 +328,51 @@ pub fn ensure_between_control_flow(editor: &mut FunctionEditor) -> Option<NodeID ) } } + +/* + * Helper function to tell if two lists of indices have the same structure. 
+ */ +pub(crate) fn indices_structurally_equivalent(indices1: &[Index], indices2: &[Index]) -> bool { + if indices1.len() == indices2.len() { + let mut equiv = true; + for pair in zip(indices1, indices2) { + equiv = equiv + && match pair { + (Index::Field(idx1), Index::Field(idx2)) => idx1 == idx2, + (Index::Variant(idx1), Index::Variant(idx2)) => idx1 == idx2, + (Index::Position(ref pos1), Index::Position(ref pos2)) => { + assert_eq!(pos1.len(), pos2.len()); + true + } + _ => false, + }; + } + equiv + } else { + false + } +} + +/* + * Helper function to determine if two lists of indices may overlap. + */ +pub(crate) fn indices_may_overlap(indices1: &[Index], indices2: &[Index]) -> bool { + for pair in zip(indices1, indices2) { + match pair { + // Check that the field numbers are the same. + (Index::Field(idx1), Index::Field(idx2)) => { + if idx1 != idx2 { + return false; + } + } + // Variant indices always may overlap, since it's the same + // underlying memory. Position indices always may overlap, since the + // indexing nodes may be the same at runtime. + (Index::Variant(_), Index::Variant(_)) | (Index::Position(_), Index::Position(_)) => {} + _ => panic!(), + } + } + // `zip` will exit as soon as either iterator is done - two sets of indices + // may overlap when one indexes a larger sub-value than the other. + true +} diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs index cfaf7a2687d062d2c4da3b2504a9097f753ca319..46d3489156d34f5152388d08e6e66a600600b4f8 100644 --- a/juno_frontend/src/lib.rs +++ b/juno_frontend/src/lib.rs @@ -182,6 +182,8 @@ pub fn compile_ir( add_pass!(pm, verify, WritePredication); add_pass!(pm, verify, PhiElim); add_pass!(pm, verify, DCE); + add_pass!(pm, verify, SLF); + add_pass!(pm, verify, DCE); add_pass!(pm, verify, Predication); add_pass!(pm, verify, DCE); add_pass!(pm, verify, CCP);