From f21b3da304fe6083b24f75049e2704b95e541e18 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Thu, 13 Feb 2025 13:26:18 -0600
Subject: [PATCH 1/2] New product test that fails

---
 Cargo.lock                                                 | 2 +-
 Cargo.toml                                                 | 2 +-
 juno_samples/{product_read => products}/Cargo.toml         | 4 ++--
 juno_samples/{product_read => products}/build.rs           | 4 ++--
 juno_samples/{product_read => products}/src/gpu.sch        | 0
 juno_samples/{product_read => products}/src/main.rs        | 7 ++++++-
 .../src/product_read.jn => products/src/products.jn}       | 5 +++++
 7 files changed, 17 insertions(+), 7 deletions(-)
 rename juno_samples/{product_read => products}/Cargo.toml (88%)
 rename juno_samples/{product_read => products}/build.rs (81%)
 rename juno_samples/{product_read => products}/src/gpu.sch (100%)
 rename juno_samples/{product_read => products}/src/main.rs (63%)
 rename juno_samples/{product_read/src/product_read.jn => products/src/products.jn} (72%)

diff --git a/Cargo.lock b/Cargo.lock
index ffb61f4d..f6ffbed9 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1210,7 +1210,7 @@ dependencies = [
 ]
 
 [[package]]
-name = "juno_product_read"
+name = "juno_products"
 version = "0.1.0"
 dependencies = [
  "async-std",
diff --git a/Cargo.toml b/Cargo.toml
index 3e86bad0..0ed8f64b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -33,5 +33,5 @@ members = [
 	"juno_samples/edge_detection",
 	"juno_samples/fork_join_tests",
 	"juno_samples/multi_device",
-	"juno_samples/product_read",
+	"juno_samples/products",
 ]
diff --git a/juno_samples/product_read/Cargo.toml b/juno_samples/products/Cargo.toml
similarity index 88%
rename from juno_samples/product_read/Cargo.toml
rename to juno_samples/products/Cargo.toml
index d466f555..34878a07 100644
--- a/juno_samples/product_read/Cargo.toml
+++ b/juno_samples/products/Cargo.toml
@@ -1,11 +1,11 @@
 [package]
-name = "juno_product_read"
+name = "juno_products"
 version = "0.1.0"
 authors = ["Aaron Councilman <aaronjc4@illinois.edu>"]
 edition = "2021"
 
 [[bin]]
-name = "juno_product_read"
+name = "juno_products"
 path = "src/main.rs"
 
 [features]
diff --git a/juno_samples/product_read/build.rs b/juno_samples/products/build.rs
similarity index 81%
rename from juno_samples/product_read/build.rs
rename to juno_samples/products/build.rs
index 2bd5172e..6d621961 100644
--- a/juno_samples/product_read/build.rs
+++ b/juno_samples/products/build.rs
@@ -4,7 +4,7 @@ fn main() {
     #[cfg(not(feature = "cuda"))]
     {
         JunoCompiler::new()
-            .file_in_src("product_read.jn")
+            .file_in_src("products.jn")
             .unwrap()
             .build()
             .unwrap();
@@ -12,7 +12,7 @@ fn main() {
     #[cfg(feature = "cuda")]
     {
         JunoCompiler::new()
-            .file_in_src("product_read.jn")
+            .file_in_src("products.jn")
             .unwrap()
             .schedule_in_src("gpu.sch")
             .unwrap()
diff --git a/juno_samples/product_read/src/gpu.sch b/juno_samples/products/src/gpu.sch
similarity index 100%
rename from juno_samples/product_read/src/gpu.sch
rename to juno_samples/products/src/gpu.sch
diff --git a/juno_samples/product_read/src/main.rs b/juno_samples/products/src/main.rs
similarity index 63%
rename from juno_samples/product_read/src/main.rs
rename to juno_samples/products/src/main.rs
index 5211098c..b8abb59d 100644
--- a/juno_samples/product_read/src/main.rs
+++ b/juno_samples/products/src/main.rs
@@ -2,7 +2,7 @@
 
 use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo, HerculesMutBox};
 
-juno_build::juno!("product_read");
+juno_build::juno!("products");
 
 fn main() {
     async_std::task::block_on(async {
@@ -11,6 +11,11 @@ fn main() {
         let mut r = runner!(product_read);
         let res : Vec<i32> = HerculesMutBox::from(r.run(input.to()).await).as_slice().to_vec();
         assert_eq!(res, vec![0, 1, 2, 3]);
+
+        // Technically this returns a product of two i32s, but we can interpret that as an array
+        let mut r = runner!(product_return);
+        let res : Vec<i32> = HerculesMutBox::from(r.run(42, 17).await).as_slice().to_vec();
+        assert_eq!(res, vec![42, 17]);
     });
 }
 
diff --git a/juno_samples/product_read/src/product_read.jn b/juno_samples/products/src/products.jn
similarity index 72%
rename from juno_samples/product_read/src/product_read.jn
rename to juno_samples/products/src/products.jn
index 7bf74a10..4f56368e 100644
--- a/juno_samples/product_read/src/product_read.jn
+++ b/juno_samples/products/src/products.jn
@@ -7,3 +7,8 @@ fn product_read(input: (i32, i32)[2]) -> i32[4] {
   result[3] = input[1].1;
   return result;
 }
+
+#[entry]
+fn product_return(x: i32, y: i32) -> (i32, i32) {
+  return (x, y);
+}
-- 
GitLab


From e169b4f547ceeadc940f0f9812efdbc6700c56fc Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Thu, 13 Feb 2025 13:31:00 -0600
Subject: [PATCH 2/2] Reuse products optimization

---
 hercules_opt/src/lib.rs            |   2 +
 hercules_opt/src/reuse_products.rs | 212 +++++++++++++++++++++++++++++
 hercules_opt/src/sroa.rs           |  14 +-
 juno_samples/products/src/gpu.sch  |   1 +
 juno_scheduler/src/compile.rs      |   1 +
 juno_scheduler/src/default.rs      |   3 +
 juno_scheduler/src/ir.rs           |   1 +
 juno_scheduler/src/pm.rs           |  21 +++
 8 files changed, 248 insertions(+), 7 deletions(-)
 create mode 100644 hercules_opt/src/reuse_products.rs

diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs
index 7187508a..17c55bbe 100644
--- a/hercules_opt/src/lib.rs
+++ b/hercules_opt/src/lib.rs
@@ -18,6 +18,7 @@ pub mod lift_dc_math;
 pub mod outline;
 pub mod phi_elim;
 pub mod pred;
+pub mod reuse_products;
 pub mod schedule;
 pub mod simplify_cfg;
 pub mod slf;
@@ -43,6 +44,7 @@ pub use crate::lift_dc_math::*;
 pub use crate::outline::*;
 pub use crate::phi_elim::*;
 pub use crate::pred::*;
+pub use crate::reuse_products::*;
 pub use crate::schedule::*;
 pub use crate::simplify_cfg::*;
 pub use crate::slf::*;
diff --git a/hercules_opt/src/reuse_products.rs b/hercules_opt/src/reuse_products.rs
new file mode 100644
index 00000000..eb0b4a65
--- /dev/null
+++ b/hercules_opt/src/reuse_products.rs
@@ -0,0 +1,212 @@
+use std::collections::HashMap;
+
+use hercules_ir::ir::*;
+
+use crate::*;
+
+/*
+ * Reuse Products is an optimization pass which identifies when two product
+ * values are identical because each field of the "source" product is read and
+ * then written into the "destination" product and then replaces the destination
+ * product by the source product.
+ *
+ * This pattern can occur in our code because SROA and IP SROA are both
+ * aggressive about breaking products into their fields and reconstructing
+ * products right where needed, so if a function returns a product that is
+ * produced by a call node, these optimizations will produce code that reads the
+ * fields out of the call node and then writes them into the product that is
+ * returned.
+ *
+ * This optimization does not delete any nodes other than the destination nodes,
+ * if other nodes become dead as a result the clean up is left to DCE.
+ *
+ * The analysis for this starts by labeling each product source node (arguments,
+ * constants, and call nodes) with themselves as the source of all of their
+ * fields. Then, these field sources are propagated along read and write nodes.
+ * At the end all nodes with product values are labeled by the source (node and
+ * index) of each of its fields. We then check if any node's fields are exactly
+ * the fields of some other node (i.e. is exactly the same value as some other
+ * node) we replace it with that other node.
+ */
+pub fn reuse_products(
+    editor: &mut FunctionEditor,
+    reverse_postorder: &Vec<NodeID>,
+    types: &Vec<TypeID>,
+) {
+    let mut source_nodes = vec![];
+    let mut read_write_nodes = vec![];
+
+    for node in reverse_postorder {
+        match &editor.node(node) {
+            Node::Parameter { .. } | Node::Constant { .. } | Node::Call { .. }
+                if editor.get_type(types[node.idx()]).is_product() =>
+            {
+                source_nodes.push(*node)
+            }
+            Node::Write { .. } if editor.get_type(types[node.idx()]).is_product() => {
+                read_write_nodes.push(*node)
+            }
+            Node::Read { collect, .. } if editor.get_type(types[collect.idx()]).is_product() => {
+                read_write_nodes.push(*node)
+            }
+            _ => (),
+        }
+    }
+
+    let mut product_nodes: HashMap<NodeID, IndexTree<(NodeID, Vec<Index>)>> = HashMap::new();
+
+    for source in source_nodes {
+        product_nodes.insert(
+            source,
+            generate_source_info(editor, source, types[source.idx()]),
+        );
+    }
+
+    for node in read_write_nodes {
+        match editor.node(node) {
+            Node::Read { collect, indices } => {
+                let Some(collect) = product_nodes.get(collect) else {
+                    continue;
+                };
+                let result = collect.lookup(indices);
+                product_nodes.insert(node, result.clone());
+            }
+            Node::Write {
+                collect,
+                data,
+                indices,
+            } => {
+                let Some(collect) = product_nodes.get(collect) else {
+                    continue;
+                };
+                let Some(data) = product_nodes.get(data) else {
+                    continue;
+                };
+                let result = collect.clone().replace(indices, data.clone());
+                product_nodes.insert(node, result);
+            }
+            _ => panic!("Non read/write node"),
+        }
+    }
+
+    // Note that we don't have to worry about some node A being equivalent to node B but node B
+    // being equivalent to node C and being replaced first causing an issue when we try to replace
+    // node A with B.
+    // This cannot occur since the only nodes something can be equivalent with are the source nodes
+    // and they are all equivalent to precisely themselves which we ignore.
+    for (node, data) in product_nodes {
+        let Some(replace_with) = is_other_product(editor, types, data) else {
+            continue;
+        };
+
+        if replace_with != node {
+            editor.edit(|edit| {
+                let edit = edit.replace_all_uses(node, replace_with)?;
+                edit.delete_node(node)
+            });
+        }
+    }
+}
+
+fn generate_source_info(
+    editor: &FunctionEditor,
+    source: NodeID,
+    typ: TypeID,
+) -> IndexTree<(NodeID, Vec<Index>)> {
+    generate_source_info_at_index(editor, source, typ, vec![])
+}
+
+fn generate_source_info_at_index(
+    editor: &FunctionEditor,
+    source: NodeID,
+    typ: TypeID,
+    idx: Vec<Index>,
+) -> IndexTree<(NodeID, Vec<Index>)> {
+    let ts: Option<Vec<TypeID>> = if let Some(ts) = editor.get_type(typ).try_product() {
+        Some(ts.into())
+    } else {
+        None
+    };
+
+    if let Some(ts) = ts {
+        // Recurse on each field with an extended index and appropriate type
+        let mut fields = vec![];
+        for (i, t) in ts.into_iter().enumerate() {
+            let mut new_idx = idx.clone();
+            new_idx.push(Index::Field(i));
+            fields.push(generate_source_info_at_index(editor, source, t, new_idx));
+        }
+        IndexTree::Node(fields)
+    } else {
+        // We've reached the leaf
+        IndexTree::Leaf((source, idx))
+    }
+}
+
+fn is_other_product(
+    editor: &FunctionEditor,
+    types: &Vec<TypeID>,
+    node: IndexTree<(NodeID, Vec<Index>)>,
+) -> Option<NodeID> {
+    let Some(other_node) = find_only_node(&node) else {
+        return None;
+    };
+
+    if matches_fields_index(editor, types[other_node.idx()], &node, vec![]) {
+        Some(other_node)
+    } else {
+        None
+    }
+}
+
+fn find_only_node(tree: &IndexTree<(NodeID, Vec<Index>)>) -> Option<NodeID> {
+    match tree {
+        IndexTree::Leaf((node, _)) => Some(*node),
+        IndexTree::Node(fields) => fields
+            .iter()
+            .map(|t| find_only_node(t))
+            .reduce(|n, m| match (n, m) {
+                (Some(n), Some(m)) if n == m => Some(n),
+                (_, _) => None,
+            })
+            .flatten(),
+    }
+}
+
+fn matches_fields_index(
+    editor: &FunctionEditor,
+    typ: TypeID,
+    tree: &IndexTree<(NodeID, Vec<Index>)>,
+    index: Vec<Index>,
+) -> bool {
+    match tree {
+        IndexTree::Leaf((_, idx)) => {
+            // If in the original value we still have a product, these can't match
+            if editor.get_type(typ).is_product() {
+                false
+            } else {
+                *idx == index
+            }
+        }
+        IndexTree::Node(fields) => {
+            let ts: Vec<TypeID> = if let Some(ts) = editor.get_type(typ).try_product() {
+                ts.into()
+            } else {
+                return false;
+            };
+
+            if fields.len() != ts.len() {
+                return false;
+            }
+
+            ts.into_iter()
+                .zip(fields.iter())
+                .enumerate()
+                .all(|(i, (ty, field))| {
+                    let mut new_index = index.clone();
+                    new_index.push(Index::Field(i));
+                    matches_fields_index(editor, ty, field, new_index)
+                })
+        }
+    }
+}
diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs
index 3210094d..dbb2f8ce 100644
--- a/hercules_opt/src/sroa.rs
+++ b/hercules_opt/src/sroa.rs
@@ -700,13 +700,13 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
 
 // An index tree is used to store results at many index lists
 #[derive(Clone, Debug)]
-enum IndexTree<T> {
+pub enum IndexTree<T> {
     Leaf(T),
     Node(Vec<IndexTree<T>>),
 }
 
 impl<T: std::fmt::Debug> IndexTree<T> {
-    fn lookup(&self, idx: &[Index]) -> &IndexTree<T> {
+    pub fn lookup(&self, idx: &[Index]) -> &IndexTree<T> {
         self.lookup_idx(idx, 0)
     }
 
@@ -725,7 +725,7 @@ impl<T: std::fmt::Debug> IndexTree<T> {
         }
     }
 
-    fn set(self, idx: &[Index], val: T) -> IndexTree<T> {
+    pub fn set(self, idx: &[Index], val: T) -> IndexTree<T> {
         self.set_idx(idx, val, 0)
     }
 
@@ -756,7 +756,7 @@ impl<T: std::fmt::Debug> IndexTree<T> {
         }
     }
 
-    fn replace(self, idx: &[Index], val: IndexTree<T>) -> IndexTree<T> {
+    pub fn replace(self, idx: &[Index], val: IndexTree<T>) -> IndexTree<T> {
         self.replace_idx(idx, val, 0)
     }
 
@@ -787,7 +787,7 @@ impl<T: std::fmt::Debug> IndexTree<T> {
         }
     }
 
-    fn zip<'a, A>(self, other: &'a IndexTree<A>) -> IndexTree<(T, &'a A)> {
+    pub fn zip<'a, A>(self, other: &'a IndexTree<A>) -> IndexTree<(T, &'a A)> {
         match (self, other) {
             (IndexTree::Leaf(t), IndexTree::Leaf(a)) => IndexTree::Leaf((t, a)),
             (IndexTree::Node(t), IndexTree::Node(a)) => {
@@ -801,7 +801,7 @@ impl<T: std::fmt::Debug> IndexTree<T> {
         }
     }
 
-    fn zip_list<'a, A>(self, others: Vec<&'a IndexTree<A>>) -> IndexTree<(T, Vec<&'a A>)> {
+    pub fn zip_list<'a, A>(self, others: Vec<&'a IndexTree<A>>) -> IndexTree<(T, Vec<&'a A>)> {
         match self {
             IndexTree::Leaf(t) => {
                 let mut res = vec![];
@@ -835,7 +835,7 @@ impl<T: std::fmt::Debug> IndexTree<T> {
         }
     }
 
-    fn for_each<F>(&self, mut f: F)
+    pub fn for_each<F>(&self, mut f: F)
     where
         F: FnMut(&Vec<Index>, &T),
     {
diff --git a/juno_samples/products/src/gpu.sch b/juno_samples/products/src/gpu.sch
index 549b4215..5ef4c479 100644
--- a/juno_samples/products/src/gpu.sch
+++ b/juno_samples/products/src/gpu.sch
@@ -7,6 +7,7 @@ gpu(out.product_read);
 
 ip-sroa(*);
 sroa(*);
+reuse-products(*);
 crc(*);
 dce(*);
 gvn(*);
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 6b40001c..7887b9b3 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -119,6 +119,7 @@ impl FromStr for Appliable {
             "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)),
             "predication" => Ok(Appliable::Pass(ir::Pass::Predication)),
             "reduce-slf" => Ok(Appliable::Pass(ir::Pass::ReduceSLF)),
+            "reuse-products" => Ok(Appliable::Pass(ir::Pass::ReuseProducts)),
             "simplify-cfg" => Ok(Appliable::Pass(ir::Pass::SimplifyCFG)),
             "slf" | "store-load-forward" => Ok(Appliable::Pass(ir::Pass::SLF)),
             "sroa" => Ok(Appliable::Pass(ir::Pass::SROA)),
diff --git a/juno_scheduler/src/default.rs b/juno_scheduler/src/default.rs
index 3f4af107..0621f8de 100644
--- a/juno_scheduler/src/default.rs
+++ b/juno_scheduler/src/default.rs
@@ -45,6 +45,8 @@ pub fn default_schedule() -> ScheduleStmt {
         SROA,
         PhiElim,
         DCE,
+        ReuseProducts,
+        DCE,
         CCP,
         SimplifyCFG,
         DCE,
@@ -88,6 +90,7 @@ pub fn default_schedule() -> ScheduleStmt {
         AutoOutline,
         InterproceduralSROA,
         SROA,
+        ReuseProducts,
         SimplifyCFG,
         InferSchedules,
         DCE,
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 840f25a6..5bfb4e21 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -28,6 +28,7 @@ pub enum Pass {
     PhiElim,
     Predication,
     ReduceSLF,
+    ReuseProducts,
     SLF,
     SROA,
     Serialize,
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index f59834ed..342f875b 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2011,6 +2011,27 @@ fn run_pass(
                 changed |= func.modified();
             }
         }
+        Pass::ReuseProducts => {
+            assert!(args.is_empty());
+            pm.make_reverse_postorders();
+            pm.make_typing();
+            let reverse_postorders = pm.reverse_postorders.take().unwrap();
+            let typing = pm.typing.take().unwrap();
+
+            for ((func, reverse_postorder), types) in build_selection(pm, selection, false)
+                .into_iter()
+                .zip(reverse_postorders.iter())
+                .zip(typing.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                reuse_products(&mut func, reverse_postorder, types);
+                changed |= func.modified();
+            }
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::SLF => {
             assert!(args.is_empty());
             pm.make_reverse_postorders();
-- 
GitLab