From 555cd43175109a332d966fb124977d3ad00a0334 Mon Sep 17 00:00:00 2001
From: rarbore2 <rarbore2@illinois.edu>
Date: Mon, 3 Feb 2025 20:12:25 -0600
Subject: [PATCH] Fix incorrect inference of no-memset

---
 hercules_ir/src/collections.rs                | 16 ++--------
 hercules_ir/src/einsum.rs                     | 29 +++++++++++++++++++
 hercules_samples/matmul/src/gpu.sch           |  2 ++
 .../fork_join_tests/src/fork_join_tests.jn    | 12 ++++----
 juno_samples/fork_join_tests/src/gpu.sch      |  8 +++--
 juno_scheduler/src/pm.rs                      |  2 +-
 6 files changed, 46 insertions(+), 23 deletions(-)

diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs
index d9d6c8f7..1bc650e9 100644
--- a/hercules_ir/src/collections.rs
+++ b/hercules_ir/src/collections.rs
@@ -1,4 +1,4 @@
-use std::collections::{BTreeMap, BTreeSet, HashMap};
+use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
 use std::iter::{once, repeat, zip};
 
 use either::Either;
@@ -432,20 +432,8 @@ pub fn no_reset_constant_collections(
                 init: _,
                 reduct: _,
             } => {
-                // If the einsum for this reduce node is a full array
-                // comprehension, then every array element is written to, and
-                // the empty indices set (the whole collection) is considered as
-                // written to.
-                let (env, exprs) = reduce_einsum;
-                if let Some(expr) = exprs.get(&id)
-                    && let MathExpr::Comprehension(_, _) = env[expr.idx()]
-                {
-                    ZeroLattice::top()
-                }
-                // Otherwise, meet the `init` and `reduct` inputs.
-                else {
-                    ZeroLattice::meet(&inputs[0], &inputs[1])
-                }
+                // Meet the `init` and `reduct` inputs.
+                ZeroLattice::meet(&inputs[0], &inputs[1])
             }
             Node::Write {
                 collect: _,
diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs
index d1c276e5..8d3bec3a 100644
--- a/hercules_ir/src/einsum.rs
+++ b/hercules_ir/src/einsum.rs
@@ -400,6 +400,35 @@ impl<'a> EinsumContext<'a> {
     }
 }
 
+/// Gathers the `NodeID` of every `MathExpr::OpaqueNode` found while walking the
+/// expression rooted at `id` in `env`. The walk uses an explicit worklist and
+/// follows the first field of `SumReduction` / `Comprehension`, both fields of
+/// `Read`, and both operands of the binary arithmetic variants. No visited set
+/// is kept, so an expression referenced more than once is walked again.
+pub fn opaque_nodes_in_expr(env: &MathEnv, id: MathID) -> HashSet<NodeID> {
+    let mut opaque = HashSet::new();
+    let mut pending = vec![id];
+    while let Some(expr) = pending.pop() {
+        match env[expr.idx()] {
+            MathExpr::Zero(_) | MathExpr::One(_) | MathExpr::ThreadID(_) => {}
+            MathExpr::OpaqueNode(node) => {
+                opaque.insert(node);
+            }
+            MathExpr::SumReduction(sub, _) | MathExpr::Comprehension(sub, _) => pending.push(sub),
+            MathExpr::Read(base, ref ids) => {
+                pending.push(base);
+                pending.extend(ids);
+            }
+            MathExpr::Add(lhs, rhs)
+            | MathExpr::Sub(lhs, rhs)
+            | MathExpr::Mul(lhs, rhs)
+            | MathExpr::Div(lhs, rhs)
+            | MathExpr::Rem(lhs, rhs) => pending.extend([lhs, rhs]),
+        }
+    }
+    opaque
+}
+
 fn representable(op: BinaryOperator) -> bool {
     match op {
         BinaryOperator::Add
diff --git a/hercules_samples/matmul/src/gpu.sch b/hercules_samples/matmul/src/gpu.sch
index 4303d376..e5d3bbde 100644
--- a/hercules_samples/matmul/src/gpu.sch
+++ b/hercules_samples/matmul/src/gpu.sch
@@ -1,3 +1,5 @@
+no-memset(matmul@c);
+
 gvn(*);
 phi-elim(*);
 dce(*);
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index fdfa51a8..6e5db4cb 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -1,6 +1,6 @@
 #[entry]
 fn test1(input : i32) -> i32[4, 4] {
-  let arr : i32[4, 4];
+  @const let arr : i32[4, 4];
   for i = 0 to 4 {
     for j = 0 to 4 {
       arr[i, j] = input;
@@ -24,19 +24,19 @@ fn test2(input : i32) -> i32[4, 4] {
 
 #[entry]
 fn test3(input : i32) -> i32[3, 3] {
-  let arr1 : i32[3, 3];
+  @const1 let arr1 : i32[3, 3];
   for i = 0 to 3 {
     for j = 0 to 3 {
       arr1[i, j] = (i + j) as i32 + input;
     }
   }
-  let arr2 : i32[3, 3];
+  @const2 let arr2 : i32[3, 3];
   for i = 0 to 3 {
     for j = 0 to 3 {
       arr2[i, j] = arr1[2 - i, 2 - j];
     }
   }
-  let arr3 : i32[3, 3];
+  @const3 let arr3 : i32[3, 3];
   for i = 0 to 3 {
     for j = 0 to 3 {
       arr3[i, j] = arr2[i, j] + 7;
@@ -54,7 +54,7 @@ fn test4(input : i32) -> i32[4, 4] {
       for k = 0 to 7 {
         acc += input;
       }
-      arr[i, j] = acc;
+      @reduce arr[i, j] = acc;
     }
   }
   return arr;
@@ -62,7 +62,7 @@ fn test4(input : i32) -> i32[4, 4] {
 
 #[entry]
 fn test5(input : i32) -> i32[4] {
-  @cons let arr1 : i32[4];
+  let arr1 : i32[4];
   for i = 0 to 4 {
     let red = arr1[i];
     for k = 0 to 3 {
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index 0987083e..f096ea50 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -1,5 +1,9 @@
-no-memset(test5@cons);
+parallel-reduce(test4@reduce);
 parallel-reduce(test5@reduce);
+no-memset(test1@const);
+no-memset(test3@const1);
+no-memset(test3@const2);
+no-memset(test3@const3);
 
 gvn(*);
 phi-elim(*);
@@ -33,5 +37,5 @@ fixpoint panic after 20 {
   infer-schedules(*);
 }
 
-float-collections(test2, out.test2, test4, out.test4);
+float-collections(test2, out.test2, test4, out.test4, test5, out.test5);
 gcm(*);
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 20a3dba8..9c445519 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1329,7 +1329,7 @@ fn run_pass(
     pm: &mut PassManager,
     pass: Pass,
     args: Vec<Value>,
-    mut selection: Option<Vec<CodeLocation>>,
+    selection: Option<Vec<CodeLocation>>,
 ) -> Result<(Value, bool), SchedulerError> {
     let mut result = Value::Record {
         fields: HashMap::new(),
-- 
GitLab