From 86b0e53182e7d1d5fe8daa9afa18e020c1043b77 Mon Sep 17 00:00:00 2001
From: rarbore2 <rarbore2@illinois.edu>
Date: Tue, 28 Jan 2025 14:08:02 -0600
Subject: [PATCH] Possible fix for reduce cycles

---
 hercules_ir/src/loops.rs | 17 +++++++++++++++--
 juno_scheduler/src/pm.rs | 14 ++++++++++++--
 2 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/hercules_ir/src/loops.rs b/hercules_ir/src/loops.rs
index 1d706c78..f7bd5468 100644
--- a/hercules_ir/src/loops.rs
+++ b/hercules_ir/src/loops.rs
@@ -233,14 +233,21 @@ fn loop_reachability_helper(
 pub fn reduce_cycles(
     function: &Function,
     def_use: &ImmutableDefUseMap,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
 ) -> HashMap<NodeID, HashSet<NodeID>> {
     let reduces = (0..function.nodes.len())
         .filter(|idx| function.nodes[*idx].is_reduce())
         .map(NodeID::new);
     let mut result = HashMap::new();
+    let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
+        .into_iter()
+        .map(|(fork, join)| (*join, *fork))
+        .collect();
 
     for reduce in reduces {
-        let (_, _, reduct) = function.nodes[reduce.idx()].try_reduce().unwrap();
+        let (join, _, reduct) = function.nodes[reduce.idx()].try_reduce().unwrap();
+        let fork = join_fork_map[&join];
 
         // First, find all data nodes that are used by the `reduct` input of the
         // reduce, including the `reduct` itself.
@@ -249,7 +256,13 @@ pub fn reduce_cycles(
         let mut worklist = vec![reduct];
         while let Some(item) = worklist.pop() {
             for u in get_uses(&function.nodes[item.idx()]).as_ref() {
-                if !function.nodes[u.idx()].is_control() && !use_reachable.contains(u) {
+                if !function.nodes[u.idx()].is_control()
+                    && !use_reachable.contains(u)
+                    && function.nodes[u.idx()]
+                        .try_phi()
+                        .map(|(control, _)| fork_join_nest[&fork].contains(&control))
+                        .unwrap_or(true)
+                {
                     use_reachable.insert(*u);
                     worklist.push(*u);
                 }
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index aa540064..570aa3f1 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -354,10 +354,20 @@ impl PassManager {
     pub fn make_reduce_cycles(&mut self) {
         if self.reduce_cycles.is_none() {
             self.make_def_uses();
+            self.make_fork_join_maps();
+            self.make_fork_join_nests();
             let def_uses = self.def_uses.as_ref().unwrap().iter();
+            let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter();
+            let fork_join_nests = self.fork_join_nests.as_ref().unwrap().iter();
             self.reduce_cycles = Some(
-                zip(self.functions.iter(), def_uses)
-                    .map(|(function, def_use)| reduce_cycles(function, def_use))
+                self.functions
+                    .iter()
+                    .zip(def_uses)
+                    .zip(fork_join_maps)
+                    .zip(fork_join_nests)
+                    .map(|(((function, def_use), fork_join_map), fork_join_nest)| {
+                        reduce_cycles(function, def_use, fork_join_map, fork_join_nest)
+                    })
                     .collect(),
             );
         }
-- 
GitLab