From 544024f2c6a919ff1122e81a6d7e98aea6c425c4 Mon Sep 17 00:00:00 2001
From: rarbore2 <rarbore2@illinois.edu>
Date: Sat, 15 Feb 2025 11:39:57 -0600
Subject: [PATCH] Fix reduce cycles and nodes in fork joins

---
 hercules_ir/src/fork_join_analysis.rs | 41 ++++++++++++++++++-----
 hercules_opt/src/editor.rs            |  6 +++-
 juno_samples/matmul/src/gpu.sch       | 37 ++++++++-------------
 juno_samples/matmul/src/main.rs       | 12 -------
 juno_samples/matmul/src/matmul.jn     | 47 ---------------------------
 juno_scheduler/src/pm.rs              | 24 +++++++-------
 6 files changed, 63 insertions(+), 104 deletions(-)

diff --git a/hercules_ir/src/fork_join_analysis.rs b/hercules_ir/src/fork_join_analysis.rs
index bba6ac42..3e89ae90 100644
--- a/hercules_ir/src/fork_join_analysis.rs
+++ b/hercules_ir/src/fork_join_analysis.rs
@@ -87,7 +87,7 @@ pub fn reduce_cycles(
     function: &Function,
     def_use: &ImmutableDefUseMap,
     fork_join_map: &HashMap<NodeID, NodeID>,
-    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+    fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
 ) -> HashMap<NodeID, HashSet<NodeID>> {
     let reduces = (0..function.nodes.len())
         .filter(|idx| function.nodes[*idx].is_reduce())
@@ -101,6 +101,24 @@ pub fn reduce_cycles(
     for reduce in reduces {
         let (join, _, reduct) = function.nodes[reduce.idx()].try_reduce().unwrap();
         let fork = join_fork_map[&join];
+        let might_be_in_fork_join = |id| {
+            fork_join_nest
+                .get(&id)
+                .map(|nest| nest.contains(&fork))
+                .unwrap_or(true)
+                && function.nodes[id.idx()]
+                    .try_phi()
+                    .map(|(control, _)| fork_join_nest[&control].contains(&fork))
+                    .unwrap_or(true)
+                && function.nodes[id.idx()]
+                    .try_reduce()
+                    .map(|(control, _, _)| fork_join_nest[&control].contains(&fork))
+                    .unwrap_or(true)
+                && !function.nodes[id.idx()].is_parameter()
+                && !function.nodes[id.idx()].is_constant()
+                && !function.nodes[id.idx()].is_dynamic_constant()
+                && !function.nodes[id.idx()].is_undef()
+        };
 
         // Find nodes in the fork-join that the reduce can reach through uses.
         let mut reachable_uses = HashSet::new();
@@ -109,10 +127,7 @@ pub fn reduce_cycles(
         workset.push(reduct);
         while let Some(pop) = workset.pop() {
             for u in get_uses(&function.nodes[pop.idx()]).as_ref() {
-                if !reachable_uses.contains(u)
-                    && nodes_in_fork_joins[&fork].contains(u)
-                    && *u != reduce
-                {
+                if !reachable_uses.contains(u) && might_be_in_fork_join(*u) && *u != reduce {
                     reachable_uses.insert(*u);
                     workset.push(*u);
                 }
@@ -126,10 +141,7 @@ pub fn reduce_cycles(
         workset.push(reduce);
         while let Some(pop) = workset.pop() {
             for u in def_use.get_users(pop) {
-                if !reachable_users.contains(u)
-                    && nodes_in_fork_joins[&fork].contains(u)
-                    && *u != reduce
-                {
+                if !reachable_users.contains(u) && might_be_in_fork_join(*u) && *u != reduce {
                     reachable_users.insert(*u);
                     workset.push(*u);
                 }
@@ -155,6 +167,7 @@ pub fn nodes_in_fork_joins(
     function: &Function,
     def_use: &ImmutableDefUseMap,
     fork_join_map: &HashMap<NodeID, NodeID>,
+    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
 ) -> HashMap<NodeID, HashSet<NodeID>> {
     let mut result = HashMap::new();
 
@@ -164,6 +177,7 @@ pub fn nodes_in_fork_joins(
         let mut set = HashSet::new();
         set.insert(*fork);
 
+        // Iterate users of the fork.
         while let Some(item) = worklist.pop() {
             for u in def_use.get_users(item) {
                 let terminate = *u == *join
@@ -177,6 +191,15 @@ pub fn nodes_in_fork_joins(
                 set.insert(*u);
             }
         }
+        assert!(set.contains(join));
+
+        // Add all the nodes in the reduce cycle. Some of these nodes may not
+        // use thread IDs of the fork, so do this explicitly.
+        for u in def_use.get_users(*join) {
+            if let Some(cycle) = reduce_cycles.get(u) {
+                set.extend(cycle);
+            }
+        }
 
         result.insert(*fork, set);
     }
diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs
index 8c339d72..16e5c326 100644
--- a/hercules_opt/src/editor.rs
+++ b/hercules_opt/src/editor.rs
@@ -292,7 +292,11 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
             // gravestone.
             for id in deleted_nodeids.iter() {
                 // Check that there are no users of deleted nodes.
-                assert!(editor.mut_def_use[id.idx()].is_empty());
+                assert!(
+                    editor.mut_def_use[id.idx()].is_empty(),
+                    "PANIC: Attempted to delete node {:?}, but there are still users of this node ({:?}).",
+                    id, editor.mut_def_use[id.idx()]
+                );
                 editor.function.nodes[id.idx()] = Node::Start;
             }
 
diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch
index e6eb3641..edb83d74 100644
--- a/juno_samples/matmul/src/gpu.sch
+++ b/juno_samples/matmul/src/gpu.sch
@@ -1,34 +1,23 @@
-gvn(*);
 phi-elim(*);
+
+forkify(*);
+fork-guard-elim(*);
 dce(*);
 
-let out = auto-outline(*);
-gpu(out.matmul, out.tiled_64_matmul);
+fixpoint {
+  reduce-slf(*);
+  slf(*);
+  infer-schedules(*);
+}
+fork-coalesce(*);
+infer-schedules(*);
+dce(*);
 
+let out = auto-outline(*);
+gpu(out.matmul);
 ip-sroa(*);
 sroa(*);
 dce(*);
-gvn(*);
-phi-elim(*);
-dce(*);
 
-fixpoint panic after 20 {
-  forkify(out.matmul);
-  fork-guard-elim(out.matmul);
-}
-gvn(out.matmul);
-phi-elim(out.matmul);
-dce(out.matmul);
-
-fixpoint panic after 20 {
-  reduce-slf(out.matmul);
-  slf(out.matmul);
-  infer-schedules(out.matmul);
-}
-fork-coalesce(out.matmul);
-dce(out.matmul);
-
-gcm(*);
 float-collections(*);
-dce(*);
 gcm(*);
diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs
index 98b6c777..2eb2804b 100644
--- a/juno_samples/matmul/src/main.rs
+++ b/juno_samples/matmul/src/main.rs
@@ -32,11 +32,6 @@ fn main() {
                 .run(I as u64, J as u64, K as u64, a.clone(), b.clone())
                 .await;
             assert_eq!(c.as_slice::<i32>(), &*correct_c);
-            let mut r = runner!(tiled_64_matmul);
-            let tiled_c = r
-                .run(I as u64, J as u64, K as u64, a.clone(), b.clone())
-                .await;
-            assert_eq!(tiled_c.as_slice::<i32>(), &*correct_c);
         }
         #[cfg(feature = "cuda")]
         {
@@ -49,13 +44,6 @@ fn main() {
             let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice();
             c.to_cpu_ref(&mut c_cpu);
             assert_eq!(&*c_cpu, &*correct_c);
-            let mut r = runner!(tiled_64_matmul);
-            let tiled_c = r
-                .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref())
-                .await;
-            let mut tiled_c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice();
-            tiled_c.to_cpu_ref(&mut tiled_c_cpu);
-            assert_eq!(&*tiled_c_cpu, &*correct_c);
         }
     });
 }
diff --git a/juno_samples/matmul/src/matmul.jn b/juno_samples/matmul/src/matmul.jn
index fb6de5bd..e36d94e2 100644
--- a/juno_samples/matmul/src/matmul.jn
+++ b/juno_samples/matmul/src/matmul.jn
@@ -10,52 +10,5 @@ fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[
     }
   }
 
-  @exit
-  return res;
-}
-
-#[entry]
-fn tiled_64_matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] {
-  let res : i32[n, l];
-  let atile : i32[64, 64];
-  let btile : i32[64, 64];
-  let ctile : i32[64, 64];
-
-  for bi = 0 to n / 64 {
-    for bk = 0 to l / 64 {
-      for ti = 0 to 64 {
-        for tk = 0 to 64 {
-	  atile[ti, tk] = 0;
-	  btile[ti, tk] = 0;
-	  ctile[ti, tk] = 0;
-	}
-      }
-
-      for tile_idx = 0 to m / 64 {
-        for ti = 0 to 64 {
-	  for tk = 0 to 64 {
-	    atile[ti, tk] = a[bi * 64 + ti, tile_idx * 64 + tk];
-	    btile[ti, tk] = b[tile_idx * 64 + ti, bk * 64 + tk];
-	  }
-	}
-        for ti = 0 to 64 {
-	  for tk = 0 to 64 {
-	    let c_acc = ctile[ti, tk];
-	    for inner_idx = 0 to 64 {
-	      c_acc += atile[ti, inner_idx] * btile[inner_idx, tk];
-	    }
-	    ctile[ti, tk] = c_acc;
-	  }
-	}
-      }
-
-      for ti = 0 to 64 {
-        for tk = 0 to 64 {
-	  res[bi * 64 + ti, bk * 64 + tk] = ctile[ti, tk];
-	}
-      }
-    }
-  }
-
   return res;
 }
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index a4783a93..8e152cfe 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -401,21 +401,19 @@ impl PassManager {
         if self.reduce_cycles.is_none() {
             self.make_def_uses();
             self.make_fork_join_maps();
-            self.make_nodes_in_fork_joins();
+            self.make_fork_join_nests();
             let def_uses = self.def_uses.as_ref().unwrap().iter();
             let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter();
-            let nodes_in_fork_joins = self.nodes_in_fork_joins.as_ref().unwrap().iter();
+            let fork_join_nests = self.fork_join_nests.as_ref().unwrap().iter();
             self.reduce_cycles = Some(
                 self.functions
                     .iter()
                     .zip(fork_join_maps)
-                    .zip(nodes_in_fork_joins)
+                    .zip(fork_join_nests)
                     .zip(def_uses)
-                    .map(
-                        |(((function, fork_join_map), nodes_in_fork_joins), def_use)| {
-                            reduce_cycles(function, def_use, fork_join_map, nodes_in_fork_joins)
-                        },
-                    )
+                    .map(|(((function, fork_join_map), fork_join_nests), def_use)| {
+                        reduce_cycles(function, def_use, fork_join_map, fork_join_nests)
+                    })
                     .collect(),
             );
         }
@@ -425,16 +423,20 @@ impl PassManager {
         if self.nodes_in_fork_joins.is_none() {
             self.make_def_uses();
             self.make_fork_join_maps();
+            self.make_reduce_cycles();
             self.nodes_in_fork_joins = Some(
                 zip(
                     self.functions.iter(),
                     zip(
                         self.def_uses.as_ref().unwrap().iter(),
-                        self.fork_join_maps.as_ref().unwrap().iter(),
+                        zip(
+                            self.fork_join_maps.as_ref().unwrap().iter(),
+                            self.reduce_cycles.as_ref().unwrap().iter(),
+                        ),
                     ),
                 )
-                .map(|(function, (def_use, fork_join_map))| {
-                    nodes_in_fork_joins(function, def_use, fork_join_map)
+                .map(|(function, (def_use, (fork_join_map, reduce_cycles)))| {
+                    nodes_in_fork_joins(function, def_use, fork_join_map, reduce_cycles)
                 })
                 .collect(),
             );
-- 
GitLab