From 6a0c4a3410d40410dbdac7a92926583e9c6e640a Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sat, 1 Feb 2025 15:04:24 -0600
Subject: [PATCH] control in reduce cycle fixes

---
 Cargo.lock                                    | 11 ++++
 Cargo.toml                                    |  3 +-
 hercules_opt/src/forkify.rs                   | 12 +++++
 hercules_opt/src/unforkify.rs                 |  2 +-
 .../hercules_tests/tests/loop_tests.rs        |  4 +-
 juno_scheduler/src/pm.rs                      | 53 +++++++++++--------
 6 files changed, 59 insertions(+), 26 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 49630436..ad69bc72 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1181,6 +1181,17 @@ dependencies = [
  "with_builtin_macros",
 ]
 
+[[package]]
+name = "juno_test"
+version = "0.1.0"
+dependencies = [
+ "async-std",
+ "hercules_rt",
+ "juno_build",
+ "rand",
+ "with_builtin_macros",
+]
+
 [[package]]
 name = "juno_utils"
 version = "0.1.0"
diff --git a/Cargo.toml b/Cargo.toml
index ced011a9..46fc7eaa 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -21,7 +21,7 @@ members = [
 	"hercules_samples/ccp",
 
 	"juno_samples/simple3",
-  "juno_samples/patterns",
+	"juno_samples/patterns",
 	"juno_samples/matmul",
 	"juno_samples/casts_and_intrinsics",
 	"juno_samples/nested_ccp",
@@ -30,4 +30,5 @@ members = [
   	"juno_samples/cava",
 	"juno_samples/concat",
   	"juno_samples/schedule_test",
+	"juno_samples/test",
 ]
diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs
index ec4e9fbc..0f06627d 100644
--- a/hercules_opt/src/forkify.rs
+++ b/hercules_opt/src/forkify.rs
@@ -152,6 +152,7 @@ pub fn forkify_loop(
         .filter(|id| !l.control[id.idx()])
         .collect();
 
+    // FIXME: @xrouth
     if loop_preds.len() != 1 {
         return false;
     }
@@ -388,6 +389,7 @@ nest! {
             is_associative: bool,
         },
         LoopDependant(NodeID),
+        ControlDependant(NodeID), // This phi is redcutionable, but its cycle might depend on internal control within the loop.
         UsedByDependant(NodeID),
     }
 }
@@ -398,6 +400,7 @@ impl LoopPHI {
             LoopPHI::Reductionable { phi, .. } => *phi,
             LoopPHI::LoopDependant(node_id) => *node_id,
             LoopPHI::UsedByDependant(node_id) => *node_id,
+            LoopPHI::ControlDependant(node_id) => *node_id,
         }
     }
 }
@@ -415,6 +418,9 @@ pub fn analyze_phis<'a>(
     loop_nodes: &'a HashSet<NodeID>,
 ) -> impl Iterator<Item = LoopPHI> + 'a {
 
+    // We are also moving the phi from the top of the loop (the header),
+    // to the very end (the join). If there are uses of the phi somewhere in the loop,
+    // then they may try to use the phi (now a reduce) before it hits the join. 
     // Find data cycles within the loop of this phi, 
     // Start from the phis loop_continue_latch, and walk its uses until we find the original phi. 
 
@@ -509,6 +515,12 @@ pub fn analyze_phis<'a>(
             // to have headers that postdominate the loop continue latch. The value of the PHI used needs to be defined
             // by the time the reduce is triggered (at the end of the loop's internal control).
 
+            // If anything in the intersection is a phi (that isn't this own phi), then the reduction cycle depends on control.
+            // Which is not allowed.
+            if intersection.iter().any(|cycle_node| editor.node(cycle_node).is_phi() && *cycle_node != *phi) || editor.node(loop_continue_latch).is_phi() {
+                return LoopPHI::ControlDependant(*phi);
+            }
+
             // No nodes in data cycles with this phi (in the loop) are used outside the loop, besides the loop_continue_latch.
             // If some other node in the cycle is used, there is not a valid node to assign it after making the cycle a reduce.
             if intersection
diff --git a/hercules_opt/src/unforkify.rs b/hercules_opt/src/unforkify.rs
index 85ffd233..7d158d1a 100644
--- a/hercules_opt/src/unforkify.rs
+++ b/hercules_opt/src/unforkify.rs
@@ -133,7 +133,7 @@ pub fn unforkify(
         if factors.len() > 1 {
             // For now, don't convert multi-dimensional fork-joins. Rely on pass
             // that splits fork-joins.
-            continue;
+            break; // Because we have to unforkify top down, we can't unforkify forks that are contained 
         }
         let join_control = nodes[join.idx()].try_join().unwrap();
         let tids: Vec<_> = editor
diff --git a/hercules_test/hercules_tests/tests/loop_tests.rs b/hercules_test/hercules_tests/tests/loop_tests.rs
index 5832a161..192c1366 100644
--- a/hercules_test/hercules_tests/tests/loop_tests.rs
+++ b/hercules_test/hercules_tests/tests/loop_tests.rs
@@ -401,7 +401,7 @@ fn matmul_pipeline() {
     let dyn_consts = [I, J, K];
 
     // FIXME: This path should not leave the crate
-    let mut module = parse_module_from_hbin("../../juno_samples/matmul/out.hbin");
+    let mut module = parse_module_from_hbin("../../juno_samples/test/out.hbin");
     //
     let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
     for i in 0..I {
@@ -425,7 +425,7 @@ fn matmul_pipeline() {
     };
     assert_eq!(correct_c[0], value);
 
-    let schedule = Some(default_schedule![Xdot, ForkSplit, Unforkify, Xdot,]);
+    let schedule = Some(default_schedule![AutoOutline, InterproceduralSROA, SROA, InferSchedules, DCE, Xdot, GCM]);
 
     module = run_schedule_on_hercules(module, schedule).unwrap();
 
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 2371e0f2..d2772c71 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1471,29 +1471,38 @@ fn run_pass(
         }
         Pass::Forkify => {
             assert!(args.is_empty());
-            pm.make_fork_join_maps();
-            pm.make_control_subgraphs();
-            pm.make_loops();
-            let fork_join_maps = pm.fork_join_maps.take().unwrap();
-            let loops = pm.loops.take().unwrap();
-            let control_subgraphs = pm.control_subgraphs.take().unwrap();
-            for (((func, fork_join_map), loop_nest), control_subgraph) in
-                build_selection(pm, selection)
-                    .into_iter()
-                    .zip(fork_join_maps.iter())
-                    .zip(loops.iter())
-                    .zip(control_subgraphs.iter())
-            {
-                let Some(mut func) = func else {
-                    continue;
-                };
-                // TODO: uses direct return from forkify for now instead of
-                // func.modified, see comment on top of `forkify` for why. Fix
-                // this eventually.
-                changed |= forkify(&mut func, control_subgraph, fork_join_map, loop_nest);
+            loop {
+                let mut inner_changed = false;
+                pm.make_fork_join_maps();
+                pm.make_control_subgraphs();
+                pm.make_loops();
+                let fork_join_maps = pm.fork_join_maps.take().unwrap();
+                let loops = pm.loops.take().unwrap();
+                let control_subgraphs = pm.control_subgraphs.take().unwrap();
+                for (((func, fork_join_map), loop_nest), control_subgraph) in
+                    build_selection(pm, selection.clone())
+                        .into_iter()
+                        .zip(fork_join_maps.iter())
+                        .zip(loops.iter())
+                        .zip(control_subgraphs.iter())
+                {
+                    let Some(mut func) = func else {
+                        continue;
+                    };
+                    // TODO: uses direct return from forkify for now instead of
+                    // func.modified, see comment on top of `forkify` for why. Fix
+                    // this eventually.
+                    let c = forkify(&mut func, control_subgraph, fork_join_map, loop_nest);
+                    changed |= c;
+                    inner_changed |= c; 
+                }
+                pm.delete_gravestones();
+                pm.clear_analyses();
+
+                if !inner_changed {
+                    break;
+                }
             }
-            pm.delete_gravestones();
-            pm.clear_analyses();
         }
         Pass::GCM => {
             assert!(args.is_empty());
-- 
GitLab