From 6b87d31c183f4eebae9bd50d93e516383ba17d57 Mon Sep 17 00:00:00 2001
From: rarbore2 <rarbore2@illinois.edu>
Date: Thu, 16 Jan 2025 14:51:05 -0600
Subject: [PATCH] Simple Store Load Forwarding

---
 hercules_ir/src/ir.rs     |   2 +-
 hercules_opt/src/lib.rs   |   2 +
 hercules_opt/src/pass.rs  |  33 ++++++++
 hercules_opt/src/pred.rs  |  24 ------
 hercules_opt/src/slf.rs   | 160 ++++++++++++++++++++++++++++++++++++++
 hercules_opt/src/utils.rs |  56 ++++++++++++-
 juno_frontend/src/lib.rs  |   2 +
 7 files changed, 251 insertions(+), 28 deletions(-)
 create mode 100644 hercules_opt/src/slf.rs

diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index ffa338b5..cef94a2d 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -133,7 +133,7 @@ pub enum DynamicConstant {
  * operate on an index list, composing indices at different levels in a type
  * tree. Each type that can be indexed has a unique variant in the index enum.
  */
-#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
 pub enum Index {
     Field(usize),
     Variant(usize),
diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs
index 08d183a7..9935703e 100644
--- a/hercules_opt/src/lib.rs
+++ b/hercules_opt/src/lib.rs
@@ -17,6 +17,7 @@ pub mod pass;
 pub mod phi_elim;
 pub mod pred;
 pub mod schedule;
+pub mod slf;
 pub mod sroa;
 pub mod unforkify;
 pub mod utils;
@@ -38,6 +39,7 @@ pub use crate::pass::*;
 pub use crate::phi_elim::*;
 pub use crate::pred::*;
 pub use crate::schedule::*;
+pub use crate::slf::*;
 pub use crate::sroa::*;
 pub use crate::unforkify::*;
 pub use crate::utils::*;
diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs
index d24b6563..1e7104ce 100644
--- a/hercules_opt/src/pass.rs
+++ b/hercules_opt/src/pass.rs
@@ -25,6 +25,7 @@ pub enum Pass {
     PhiElim,
     Forkify,
     ForkGuardElim,
+    SLF,
     WritePredication,
     Predication,
     SROA,
@@ -470,6 +471,38 @@ impl PassManager {
                     }
                     self.clear_analyses();
                 }
+                Pass::SLF => {
+                    self.make_def_uses();
+                    self.make_reverse_postorders();
+                    self.make_typing();
+                    let def_uses = self.def_uses.as_ref().unwrap();
+                    let reverse_postorders = self.reverse_postorders.as_ref().unwrap();
+                    let typing = self.typing.as_ref().unwrap();
+                    for idx in 0..self.module.functions.len() {
+                        let constants_ref =
+                            RefCell::new(std::mem::take(&mut self.module.constants));
+                        let dynamic_constants_ref =
+                            RefCell::new(std::mem::take(&mut self.module.dynamic_constants));
+                        let types_ref = RefCell::new(std::mem::take(&mut self.module.types));
+                        let mut editor = FunctionEditor::new(
+                            &mut self.module.functions[idx],
+                            FunctionID::new(idx),
+                            &constants_ref,
+                            &dynamic_constants_ref,
+                            &types_ref,
+                            &def_uses[idx],
+                        );
+                        slf(&mut editor, &reverse_postorders[idx], &typing[idx]);
+
+                        self.module.constants = constants_ref.take();
+                        self.module.dynamic_constants = dynamic_constants_ref.take();
+                        self.module.types = types_ref.take();
+
+                        println!("{}", self.module.functions[idx].name);
+                        self.module.functions[idx].delete_gravestones();
+                    }
+                    self.clear_analyses();
+                }
                 Pass::WritePredication => {
                     self.make_def_uses();
                     let def_uses = self.def_uses.as_ref().unwrap();
diff --git a/hercules_opt/src/pred.rs b/hercules_opt/src/pred.rs
index cfad7d1c..644c69d0 100644
--- a/hercules_opt/src/pred.rs
+++ b/hercules_opt/src/pred.rs
@@ -1,6 +1,5 @@
 use std::cmp::{max, min};
 use std::collections::{BTreeMap, BTreeSet};
-use std::iter::zip;
 
 use itertools::Itertools;
 
@@ -262,26 +261,3 @@ pub fn write_predication(editor: &mut FunctionEditor) {
         }
     }
 }
-
-/*
- * Helper function to tell if two lists of indices have the same structure.
- */
-fn indices_structurally_equivalent(indices1: &[Index], indices2: &[Index]) -> bool {
-    if indices1.len() == indices2.len() {
-        let mut equiv = true;
-        for pair in zip(indices1, indices2) {
-            equiv = equiv
-                && match pair {
-                    (Index::Field(idx1), Index::Field(idx2)) => idx1 == idx2,
-                    (Index::Variant(idx1), Index::Variant(idx2)) => idx1 == idx2,
-                    (Index::Position(ref pos1), Index::Position(ref pos2)) => {
-                        pos1.len() == pos2.len()
-                    }
-                    _ => false,
-                };
-        }
-        equiv
-    } else {
-        false
-    }
-}
diff --git a/hercules_opt/src/slf.rs b/hercules_opt/src/slf.rs
new file mode 100644
index 00000000..981a0cce
--- /dev/null
+++ b/hercules_opt/src/slf.rs
@@ -0,0 +1,160 @@
+use std::collections::BTreeMap;
+
+use hercules_ir::*;
+
+use crate::*;
+
+/*
+ * The SLF lattice tracks which sub-values of a collection are known. Each
+ * entry maps a list of indices that was written to the node ID of the value
+ * stored there. A write whose indices may overlap an existing entry resets that
+ * entry to `None`, since the write may have overwritten the old sub-value. The
+ * lattice top is the empty map, which corresponds to every value being zero:
+ * reads that overlap no entry are replaced by a zero constant. The lattice
+ * bottom maps the empty indices list, which overlaps every possible indices
+ * list, to `None`, meaning that nothing about the collection is known.
+ */
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub struct SLFLattice {
+    subvalues: BTreeMap<Box<[Index]>, Option<NodeID>>,
+}
+
+impl Semilattice for SLFLattice {
+    // Combine two lattice values: a sub-value survives only if both sides
+    // track the same indices set and agree on the value stored there.
+    fn meet(a: &Self, b: &Self) -> Self {
+        let mut merged: BTreeMap<Box<[Index]>, Option<NodeID>> = a
+            .subvalues
+            .iter()
+            .map(|(indices, a_subvalue)| {
+                let agreed = match b.subvalues.get(indices) {
+                    // Both sides store the same sub-value here, so keep it.
+                    Some(b_subvalue) if b_subvalue == a_subvalue => *a_subvalue,
+                    // `b` has no entry here or disagrees with `a`, so what's
+                    // stored at these indices is unknown.
+                    _ => None,
+                };
+                (indices.clone(), agreed)
+            })
+            .collect();
+
+        // Whatever `b` tracks that `a` doesn't is unknown as well. Entries
+        // already produced from `a` are left untouched.
+        for indices in b.subvalues.keys() {
+            merged.entry(indices.clone()).or_insert(None);
+        }
+        SLFLattice { subvalues: merged }
+    }
+
+    // The top value tracks no indices sets at all.
+    fn top() -> Self {
+        let subvalues = BTreeMap::new();
+        SLFLattice { subvalues }
+    }
+
+    // The bottom value knows nothing: the empty indices set, which overlaps
+    // with all possible indices sets, maps to `None`.
+    fn bottom() -> Self {
+        let unknown_everywhere = (Box::new([]) as Box<[Index]>, None);
+        let subvalues = BTreeMap::from([unknown_everywhere]);
+        SLFLattice { subvalues }
+    }
+}
+
+/*
+ * Top level function to run store-to-load forwarding on a function. Finds
+ * sub-values of collections that are known from earlier writes and replaces
+ * reads of those sub-values with the written values directly.
+ */
+pub fn slf(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, typing: &Vec<TypeID>) {
+    // First, run a dataflow analysis that looks at known values inside
+    // collections. Thanks to the value semantics of Hercules IR, this analysis
+    // is relatively simple and straightforward.
+    let func = editor.func();
+    let lattice = forward_dataflow(func, reverse_postorder, |inputs, id| {
+        match func.nodes[id.idx()] {
+            Node::Phi {
+                control: _,
+                data: _,
+            }
+            | Node::Reduce {
+                control: _,
+                init: _,
+                reduct: _,
+            }
+            | Node::Ternary {
+                op: TernaryOperator::Select,
+                first: _,
+                second: _,
+                third: _,
+            } => inputs.into_iter().fold(SLFLattice::top(), |acc, input| {
+                SLFLattice::meet(&acc, input)
+            }),
+            Node::Write {
+                collect: _,
+                data,
+                ref indices,
+            } => {
+                // Start with the sub-values tracked for the `collect` input.
+                let mut value = inputs[0].clone();
+
+                // Any indices sets that may overlap with `indices` become
+                // `None`, since we no longer know what's stored there.
+                for (other_indices, subvalue) in value.subvalues.iter_mut() {
+                    if indices_may_overlap(other_indices, indices) {
+                        *subvalue = None;
+                    }
+                }
+
+                // Track `data` as the known sub-value at `indices`.
+                value.subvalues.insert(indices.clone(), Some(data));
+
+                value
+            }
+            _ => SLFLattice::bottom(),
+        }
+    });
+
+    // Second, classify each read by its indices set:
+    // 1. If it equals the indices of a known sub-value, the read is replaced by
+    //    that sub-value.
+    // 2. Otherwise, if it doesn't overlap with any tracked indices set (known
+    //    or unknown), the read is replaced by a zero constant.
+    // 3. Otherwise, the read can't be replaced.
+    // Record each forwarded read and its replacement, since a known sub-value
+    // may itself be the ID of a read that was already replaced and deleted.
+    let mut replacements = BTreeMap::new();
+    for id in editor.node_ids() {
+        let Node::Read {
+            collect,
+            ref indices,
+        } = editor.func().nodes[id.idx()]
+        else {
+            continue;
+        };
+        let subvalues = &lattice[collect.idx()].subvalues;
+
+        if let Some(sub_value) = subvalues.get(indices)
+            && let Some(mut known) = *sub_value
+        {
+            while let Some(replacement) = replacements.get(&known) {
+                known = *replacement;
+            }
+            editor.edit(|mut edit| {
+                edit = edit.replace_all_uses(id, known)?;
+                edit.delete_node(id)
+            });
+            replacements.insert(id, known);
+        } else if !subvalues
+            .keys()
+            .any(|other_indices| indices_may_overlap(other_indices, indices))
+        {
+            editor.edit(|mut edit| {
+                let zero = edit.add_zero_constant(typing[id.idx()]);
+                let zero = edit.add_node(Node::Constant { id: zero });
+                edit = edit.replace_all_uses(id, zero)?;
+                edit.delete_node(id)
+            });
+        }
+    }
+}
diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs
index 2a4fd94c..6239a644 100644
--- a/hercules_opt/src/utils.rs
+++ b/hercules_opt/src/utils.rs
@@ -1,3 +1,5 @@
+use std::iter::zip;
+
 use hercules_ir::def_use::*;
 use hercules_ir::ir::*;
 
@@ -241,7 +243,7 @@ pub(crate) fn substitute_dynamic_constants_in_node(
 /*
  * Top level function to make a function have only a single return.
  */
-pub fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> {
+pub(crate) fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> {
     let returns: Vec<NodeID> = (0..editor.func().nodes.len())
         .filter(|idx| editor.func().nodes[*idx].is_return())
         .map(NodeID::new)
@@ -281,7 +283,7 @@ pub fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> {
     new_return
 }
 
-pub fn contains_between_control_flow(func: &Function) -> bool {
+pub(crate) fn contains_between_control_flow(func: &Function) -> bool {
     let num_control = func.nodes.iter().filter(|node| node.is_control()).count();
     assert!(num_control >= 2, "PANIC: A Hercules function must have at least two control nodes: a start node and at least one return node.");
     num_control > 2
@@ -291,7 +293,7 @@ pub fn contains_between_control_flow(func: &Function) -> bool {
  * Top level function to ensure a Hercules function contains at least one
  * control node that isn't the start or return nodes.
  */
-pub fn ensure_between_control_flow(editor: &mut FunctionEditor) -> Option<NodeID> {
+pub(crate) fn ensure_between_control_flow(editor: &mut FunctionEditor) -> Option<NodeID> {
     if !contains_between_control_flow(editor.func()) {
         let ret = editor
             .node_ids()
@@ -326,3 +328,51 @@ pub fn ensure_between_control_flow(editor: &mut FunctionEditor) -> Option<NodeID
         )
     }
 }
+
+/*
+ * Helper function to tell if two lists of indices have the same structure:
+ * equal lengths and, pairwise, the same kind of index with equal field or
+ * variant numbers. Position indices are only checked (by assertion) to hold
+ * position lists of equal length.
+ */
+pub(crate) fn indices_structurally_equivalent(indices1: &[Index], indices2: &[Index]) -> bool {
+    // Lists of different lengths can never match.
+    if indices1.len() != indices2.len() {
+        return false;
+    }
+
+    // `all` stops at the first mismatch, so later pairs are never inspected.
+    zip(indices1, indices2).all(|pair| match pair {
+        (Index::Field(lhs), Index::Field(rhs)) => lhs == rhs,
+        (Index::Variant(lhs), Index::Variant(rhs)) => lhs == rhs,
+        (Index::Position(lhs), Index::Position(rhs)) => {
+            assert_eq!(lhs.len(), rhs.len());
+            true
+        }
+        _ => false,
+    })
+}
+
+/*
+ * Helper function to determine if two lists of indices may overlap. Panics if
+ * the two lists use different kinds of index at the same depth.
+ */
+pub(crate) fn indices_may_overlap(indices1: &[Index], indices2: &[Index]) -> bool {
+    for pair in zip(indices1, indices2) {
+        match pair {
+            // Different field numbers can never overlap.
+            (Index::Field(idx1), Index::Field(idx2)) if idx1 != idx2 => return false,
+            // Equal fields may overlap. Variant indices always may overlap,
+            // since it's the same underlying memory. Position indices always
+            // may overlap, since the indexing nodes may be the same at runtime.
+            (Index::Field(_), Index::Field(_))
+            | (Index::Variant(_), Index::Variant(_))
+            | (Index::Position(_), Index::Position(_)) => {}
+            // Differing index kinds at the same depth can't be compared.
+            _ => panic!("PANIC: Can't compare indices of different kinds: {:?}.", pair),
+        }
+    }
+    // `zip` will exit as soon as either iterator is done - two sets of indices
+    // may overlap when one indexes a larger sub-value than the other.
+    true
+}
diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs
index cfaf7a26..46d34891 100644
--- a/juno_frontend/src/lib.rs
+++ b/juno_frontend/src/lib.rs
@@ -182,6 +182,8 @@ pub fn compile_ir(
     add_pass!(pm, verify, WritePredication);
     add_pass!(pm, verify, PhiElim);
     add_pass!(pm, verify, DCE);
+    add_pass!(pm, verify, SLF);
+    add_pass!(pm, verify, DCE);
     add_pass!(pm, verify, Predication);
     add_pass!(pm, verify, DCE);
     add_pass!(pm, verify, CCP);
-- 
GitLab