From e29474df4bb1d23f327b24d858e0dbf6b1e1614b Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Fri, 14 Feb 2025 09:54:43 -0600
Subject: [PATCH] Forkify fixes x2

---
 hercules_opt/src/forkify.rs | 21 ++++++---------------
 hercules_opt/src/ivar.rs    | 30 ++++++++++++++++++++++++------
 juno_scheduler/src/pm.rs    |  6 +++---
 3 files changed, 33 insertions(+), 24 deletions(-)

diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs
index 2adfddd8..774220df 100644
--- a/hercules_opt/src/forkify.rs
+++ b/hercules_opt/src/forkify.rs
@@ -30,7 +30,7 @@ pub fn forkify(
 
     for l in natural_loops {
         // FIXME: Run on all-bottom level loops, as they can be independently optimized without recomputing analyses.
-        if forkify_loop(
+        if editor.is_mutable(l.0) && forkify_loop(
             editor,
             control_subgraph,
             fork_join_map,
@@ -166,6 +166,7 @@ pub fn forkify_loop(
         return false;
     }
 
+
     // Get all phis used outside of the loop, they need to be reductionable.
     // For now just assume all phis will be phis used outside of the loop, except for the canonical iv.
     // FIXME: We need a different definiton of `loop_nodes` to check for phis used outside hte loop than the one
@@ -182,7 +183,6 @@ pub fn forkify_loop(
     let reductionable_phis: Vec<_> = analyze_phis(&editor, &l, &candidate_phis, &loop_nodes)
         .into_iter()
         .collect();
-
     // TODO: Handle multiple loop body lasts.
     // If there are multiple candidates for loop body last, return false.
     if editor
@@ -327,7 +327,7 @@ pub fn forkify_loop(
         .collect();
 
     // Start failable edit:
-    editor.edit(|mut edit| {
+    let result = editor.edit(|mut edit| {
         let thread_id = Node::ThreadID {
             control: fork_id,
             dimension: dimension,
@@ -405,7 +405,7 @@ pub fn forkify_loop(
         Ok(edit)
     });
 
-    return true;
+    return result;
 }
 
 nest! {
@@ -457,7 +457,7 @@ pub fn analyze_phis<'a>(
 
                 // External Phi
                 if let Node::Phi { control, data: _ } = data {
-                    if *control != natural_loop.header {
+                    if !natural_loop.control[control.idx()] {
                         return true;
                     }
                 }
@@ -539,16 +539,7 @@ pub fn analyze_phis<'a>(
             // PHIs on the frontier of the uses by the candidate phi, i.e in uses_for_dependance need
             // to have headers that postdominate the loop continue latch. The value of the PHI used needs to be defined
             // by the time the reduce is triggered (at the end of the loop's internal control).
-            // If anything in the intersection is a phi (that isn't this own phi), then the reduction cycle depends on control.
-            // Which is not allowed.
-            if intersection
-                .iter()
-                .any(|cycle_node| editor.node(cycle_node).is_phi() && *cycle_node != *phi)
-                || editor.node(loop_continue_latch).is_phi()
-            {
-                return LoopPHI::ControlDependant(*phi);
-            }
-
+    
             // No nodes in data cycles with this phi (in the loop) are used outside the loop, besides the loop_continue_latch.
             // If some other node in the cycle is used, there is not a valid node to assign it after making the cycle a reduce.
             if intersection
diff --git a/hercules_opt/src/ivar.rs b/hercules_opt/src/ivar.rs
index f7252d29..edadd717 100644
--- a/hercules_opt/src/ivar.rs
+++ b/hercules_opt/src/ivar.rs
@@ -1,3 +1,4 @@
+use core::panic;
 use std::collections::HashSet;
 
 use bitvec::prelude::*;
@@ -73,8 +74,13 @@ pub fn calculate_loop_nodes(editor: &FunctionEditor, natural_loop: &Loop) -> Has
 
             // External Phi
             if let Node::Phi { control, data: _ } = data {
-                if !natural_loop.control[control.idx()] {
-                    return true;
+                match natural_loop.control.get(control.idx()) {
+                    Some(v) => if !*v {
+                        return true;
+                    },
+                    None => {
+                        panic!("unexpceted index: {:?} for loop {:?}", control, natural_loop.header);
+                    },
                 }
             }
             // External Reduce
@@ -84,14 +90,26 @@ pub fn calculate_loop_nodes(editor: &FunctionEditor, natural_loop: &Loop) -> Has
                 reduct: _,
             } = data
             {
-                if !natural_loop.control[control.idx()] {
-                    return true;
+                match natural_loop.control.get(control.idx()) {
+                    Some(v) => if !*v {
+                        return true;
+                    },
+                    None => {
+                        panic!("unexpceted index: {:?} for loop {:?}", control, natural_loop.header);
+                    },
                 }
             }
 
             // External Control
-            if data.is_control() && !natural_loop.control[node.idx()] {
-                return true;
+            if data.is_control() {
+                match natural_loop.control.get(node.idx()) {
+                    Some(v) => if !*v {
+                        return true;
+                    },
+                    None => {
+                        panic!("unexpceted index: {:?} for loop {:?}", node, natural_loop.header);
+                    },
+                }
             }
 
             return false;
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index aa09868f..b496a80d 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1859,9 +1859,9 @@ fn run_pass(
                     let Some(mut func) = func else {
                         continue;
                     };
-                    forkify(&mut func, control_subgraph, fork_join_map, loop_nest);
-                    changed |= func.modified();
-                    inner_changed |= func.modified();
+                    let c = forkify(&mut func, control_subgraph, fork_join_map, loop_nest);
+                    changed |= c;
+                    inner_changed |= c;
                 }
                 pm.delete_gravestones();
                 pm.clear_analyses();
-- 
GitLab