From a1baea161e0114374de969618cec803a56e42ea0 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Wed, 12 Feb 2025 17:08:36 -0600
Subject: [PATCH 1/4] Loop unroll skeleton

---
 hercules_opt/src/lib.rs       |  2 ++
 hercules_opt/src/unroll.rs    | 18 ++++++++++++++++++
 juno_scheduler/src/compile.rs |  1 +
 juno_scheduler/src/ir.rs      |  1 +
 juno_scheduler/src/pm.rs      | 18 ++++++++++++++++++
 5 files changed, 40 insertions(+)
 create mode 100644 hercules_opt/src/unroll.rs

diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs
index 7187508a..a810dfbf 100644
--- a/hercules_opt/src/lib.rs
+++ b/hercules_opt/src/lib.rs
@@ -23,6 +23,7 @@ pub mod simplify_cfg;
 pub mod slf;
 pub mod sroa;
 pub mod unforkify;
+pub mod unroll;
 pub mod utils;
 
 pub use crate::ccp::*;
@@ -48,4 +49,5 @@ pub use crate::simplify_cfg::*;
 pub use crate::slf::*;
 pub use crate::sroa::*;
 pub use crate::unforkify::*;
+pub use crate::unroll::*;
 pub use crate::utils::*;
diff --git a/hercules_opt/src/unroll.rs b/hercules_opt/src/unroll.rs
new file mode 100644
index 00000000..f3c795ca
--- /dev/null
+++ b/hercules_opt/src/unroll.rs
@@ -0,0 +1,18 @@
+use bitvec::prelude::*;
+
+use hercules_ir::*;
+
+use crate::*;
+
+/*
+ * Run loop unrolling on all loops that are mutable in an editor.
+ */
+pub fn loop_unroll_all_loops(editor: &mut FunctionEditor, loops: &LoopTree) {
+    for (header, contents) in loops.bottom_up_loops() {
+        if editor.is_mutable(header) {
+            loop_unroll(editor, header, contents);
+        }
+    }
+}
+
+pub fn loop_unroll(editor: &mut FunctionEditor, header: NodeID, contents: &BitVec<u8, Lsb0>) {}
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 1aaa10cd..9d5a86cc 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -114,6 +114,7 @@ impl FromStr for Appliable {
             "fork-interchange" => Ok(Appliable::Pass(ir::Pass::ForkInterchange)),
             "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
+            "loop-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::LoopUnroll)),
             "outline" => Ok(Appliable::Pass(ir::Pass::Outline)),
             "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)),
             "predication" => Ok(Appliable::Pass(ir::Pass::Predication)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 0ecac39a..1bb6cf13 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -23,6 +23,7 @@ pub enum Pass {
     Inline,
     InterproceduralSROA,
     LiftDCMath,
+    LoopUnroll,
     Outline,
     PhiElim,
     Predication,
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 9c7391ac..5c6aec5e 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1665,6 +1665,24 @@ fn run_pass(
             pm.delete_gravestones();
             pm.clear_analyses();
         }
+        Pass::LoopUnroll => {
+            assert_eq!(args.len(), 0);
+
+            pm.make_loops();
+            let loops = pm.loops.take().unwrap();
+            for (func, loops) in build_selection(pm, selection, false)
+                .into_iter()
+                .zip(loops.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                loop_unroll_all_loops(&mut func, loops);
+                changed |= func.modified();
+            }
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::Forkify => {
             assert!(args.is_empty());
             loop {
-- 
GitLab


From a8a2fc3b61e010278a1ca622cd6aaab6f77ff140 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Wed, 12 Feb 2025 17:13:04 -0600
Subject: [PATCH 2/4] Just do fork unrolling

---
 hercules_opt/src/fork_transforms.rs | 15 +++++++++++++++
 hercules_opt/src/lib.rs             |  2 --
 hercules_opt/src/unroll.rs          | 18 ------------------
 juno_scheduler/src/compile.rs       |  2 +-
 juno_scheduler/src/ir.rs            |  2 +-
 juno_scheduler/src/pm.rs            | 12 ++++++------
 6 files changed, 23 insertions(+), 28 deletions(-)
 delete mode 100644 hercules_opt/src/unroll.rs

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index fd6747d7..539b7fd1 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1164,3 +1164,18 @@ fn fork_interchange(
         edit.delete_node(fork)
     });
 }
+
+/*
+ * Run fork unrolling on all fork-joins that are mutable in an editor.
+ */
+pub fn fork_unroll_all_forks(editor: &mut FunctionEditor, fork_joins: &HashMap<NodeID, NodeID>) {
+    for (fork, join) in fork_joins {
+        if editor.is_mutable(*fork) && fork_unroll(editor, *fork, *join) {
+            break;
+        }
+    }
+}
+
+pub fn fork_unroll(editor: &mut FunctionEditor, fork: NodeID, join: NodeID) -> bool {
+    false
+}
diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs
index a810dfbf..7187508a 100644
--- a/hercules_opt/src/lib.rs
+++ b/hercules_opt/src/lib.rs
@@ -23,7 +23,6 @@ pub mod simplify_cfg;
 pub mod slf;
 pub mod sroa;
 pub mod unforkify;
-pub mod unroll;
 pub mod utils;
 
 pub use crate::ccp::*;
@@ -49,5 +48,4 @@ pub use crate::simplify_cfg::*;
 pub use crate::slf::*;
 pub use crate::sroa::*;
 pub use crate::unforkify::*;
-pub use crate::unroll::*;
 pub use crate::utils::*;
diff --git a/hercules_opt/src/unroll.rs b/hercules_opt/src/unroll.rs
deleted file mode 100644
index f3c795ca..00000000
--- a/hercules_opt/src/unroll.rs
+++ /dev/null
@@ -1,18 +0,0 @@
-use bitvec::prelude::*;
-
-use hercules_ir::*;
-
-use crate::*;
-
-/*
- * Run loop unrolling on all loops that are mutable in an editor.
- */
-pub fn loop_unroll_all_loops(editor: &mut FunctionEditor, loops: &LoopTree) {
-    for (header, contents) in loops.bottom_up_loops() {
-        if editor.is_mutable(header) {
-            loop_unroll(editor, header, contents);
-        }
-    }
-}
-
-pub fn loop_unroll(editor: &mut FunctionEditor, header: NodeID, contents: &BitVec<u8, Lsb0>) {}
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 9d5a86cc..6b40001c 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -113,8 +113,8 @@ impl FromStr for Appliable {
             "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)),
             "fork-interchange" => Ok(Appliable::Pass(ir::Pass::ForkInterchange)),
             "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)),
+            "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
-            "loop-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::LoopUnroll)),
             "outline" => Ok(Appliable::Pass(ir::Pass::Outline)),
             "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)),
             "predication" => Ok(Appliable::Pass(ir::Pass::Predication)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 1bb6cf13..840f25a6 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -16,6 +16,7 @@ pub enum Pass {
     ForkGuardElim,
     ForkInterchange,
     ForkSplit,
+    ForkUnroll,
     Forkify,
     GCM,
     GVN,
@@ -23,7 +24,6 @@ pub enum Pass {
     Inline,
     InterproceduralSROA,
     LiftDCMath,
-    LoopUnroll,
     Outline,
     PhiElim,
     Predication,
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 5c6aec5e..951ba51d 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1665,19 +1665,19 @@ fn run_pass(
             pm.delete_gravestones();
             pm.clear_analyses();
         }
-        Pass::LoopUnroll => {
+        Pass::ForkUnroll => {
             assert_eq!(args.len(), 0);
 
-            pm.make_loops();
-            let loops = pm.loops.take().unwrap();
-            for (func, loops) in build_selection(pm, selection, false)
+            pm.make_fork_join_maps();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            for (func, fork_join_map) in build_selection(pm, selection, false)
                 .into_iter()
-                .zip(loops.iter())
+                .zip(fork_join_maps.iter())
             {
                 let Some(mut func) = func else {
                     continue;
                 };
-                loop_unroll_all_loops(&mut func, loops);
+                fork_unroll_all_forks(&mut func, fork_join_map);
                 changed |= func.modified();
             }
             pm.delete_gravestones();
-- 
GitLab


From 9bc5101eeac8a2ac2393cf0aedd7ff5aa9bcc74f Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Wed, 12 Feb 2025 18:27:36 -0600
Subject: [PATCH 3/4] Get unrollable fork-joins

---
 hercules_opt/src/fork_transforms.rs      | 38 +++++++++++++++++++++---
 juno_samples/fork_join_tests/src/cpu.sch |  6 ++--
 juno_scheduler/src/pm.rs                 | 12 +++++---
 3 files changed, 46 insertions(+), 10 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 539b7fd1..94898b0d 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1168,14 +1168,44 @@ fn fork_interchange(
 /*
  * Run fork unrolling on all fork-joins that are mutable in an editor.
  */
-pub fn fork_unroll_all_forks(editor: &mut FunctionEditor, fork_joins: &HashMap<NodeID, NodeID>) {
+pub fn fork_unroll_all_forks(
+    editor: &mut FunctionEditor,
+    fork_joins: &HashMap<NodeID, NodeID>,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+) {
     for (fork, join) in fork_joins {
-        if editor.is_mutable(*fork) && fork_unroll(editor, *fork, *join) {
+        if editor.is_mutable(*fork) && fork_unroll(editor, *fork, *join, nodes_in_fork_joins) {
             break;
         }
     }
 }
 
-pub fn fork_unroll(editor: &mut FunctionEditor, fork: NodeID, join: NodeID) -> bool {
-    false
+pub fn fork_unroll(
+    editor: &mut FunctionEditor,
+    fork: NodeID,
+    join: NodeID,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+) -> bool {
+    // We can only unroll forks with a compile time known factor list.
+    let nodes = &editor.func().nodes;
+    let Node::Fork {
+        control,
+        ref factors,
+    } = nodes[fork.idx()]
+    else {
+        panic!()
+    };
+    let mut cons_factors = vec![];
+    for factor in factors {
+        let DynamicConstant::Constant(cons) = *editor.get_dynamic_constant(*factor) else {
+            return false;
+        };
+        cons_factors.push(cons);
+    }
+    println!("{}: {:?}", editor.func().name, cons_factors);
+
+    editor.edit(|mut edit| {
+        ();
+        Ok(edit)
+    })
 }
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index fe0a8802..2c832d66 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -39,12 +39,14 @@ dce(*);
 fixpoint panic after 20 {
   infer-schedules(*);
 }
+unroll(auto.test1);
+xdot[true](*);
 
-fork-split(auto.test1, auto.test2, auto.test3, auto.test4, auto.test5);
+fork-split(auto.test2, auto.test3, auto.test4, auto.test5);
 gvn(*);
 phi-elim(*);
 dce(*);
-unforkify(auto.test1, auto.test2, auto.test3, auto.test4, auto.test5);
+unforkify(auto.test2, auto.test3, auto.test4, auto.test5);
 ccp(*);
 gvn(*);
 phi-elim(*);
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 951ba51d..f59834ed 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1669,15 +1669,19 @@ fn run_pass(
             assert_eq!(args.len(), 0);
 
             pm.make_fork_join_maps();
+            pm.make_nodes_in_fork_joins();
             let fork_join_maps = pm.fork_join_maps.take().unwrap();
-            for (func, fork_join_map) in build_selection(pm, selection, false)
-                .into_iter()
-                .zip(fork_join_maps.iter())
+            let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+            for ((func, fork_join_map), nodes_in_fork_joins) in
+                build_selection(pm, selection, false)
+                    .into_iter()
+                    .zip(fork_join_maps.iter())
+                    .zip(nodes_in_fork_joins.iter())
             {
                 let Some(mut func) = func else {
                     continue;
                 };
-                fork_unroll_all_forks(&mut func, fork_join_map);
+                fork_unroll_all_forks(&mut func, fork_join_map, nodes_in_fork_joins);
                 changed |= func.modified();
             }
             pm.delete_gravestones();
-- 
GitLab


From 23103bc2d1e31fae8880cd9063d4948a1be89a92 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Wed, 12 Feb 2025 22:02:51 -0600
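Subject: [PATCH 4/4] Implement unrolling of single-dimension fork-joins

Unroll fork-joins that have a single dimension with a compile-time
constant factor. The body of the fork-join is copied once per iteration,
with the thread ID replaced by a constant and the control and reduce
values threaded from one copy to the next; the original fork-join is
then deleted. Update the fork_join_tests CPU schedule to fork-split and
unroll auto.test1 inside a fixpoint.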

---
 hercules_opt/src/fork_transforms.rs      | 81 +++++++++++++++++++++---
 juno_samples/fork_join_tests/src/cpu.sch |  6 +-
 2 files changed, 76 insertions(+), 11 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 94898b0d..2f7a91fa 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1186,7 +1186,8 @@ pub fn fork_unroll(
     join: NodeID,
     nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
 ) -> bool {
-    // We can only unroll forks with a compile time known factor list.
+    // We can only unroll fork-joins with a compile time known factor list. For
+    // simplicity, just unroll fork-joins that have a single dimension.
     let nodes = &editor.func().nodes;
     let Node::Fork {
         control,
@@ -1195,17 +1196,79 @@ pub fn fork_unroll(
     else {
         panic!()
     };
-    let mut cons_factors = vec![];
-    for factor in factors {
-        let DynamicConstant::Constant(cons) = *editor.get_dynamic_constant(*factor) else {
-            return false;
-        };
-        cons_factors.push(cons);
+    if factors.len() != 1 || editor.get_users(fork).count() != 2 {
+        return false;
     }
-    println!("{}: {:?}", editor.func().name, cons_factors);
+    let DynamicConstant::Constant(cons) = *editor.get_dynamic_constant(factors[0]) else {
+        return false;
+    };
+    let tid = editor
+        .get_users(fork)
+        .filter(|id| nodes[id.idx()].is_thread_id())
+        .next()
+        .unwrap();
+    let inits: HashMap<NodeID, NodeID> = editor
+        .get_users(join)
+        .filter_map(|id| nodes[id.idx()].try_reduce().map(|(_, init, _)| (id, init)))
+        .collect();
 
     editor.edit(|mut edit| {
-        ();
+        // Create a copy of the nodes in the fork join per unrolled iteration,
+        // excluding the fork itself, the join itself, the thread IDs of the fork,
+        // and the reduces on the join. Keep a running tally of the top control
+        // node and the current reduction value.
+        let mut top_control = control;
+        let mut current_reduces = inits;
+        for iter in 0..cons {
+            let iter_cons = edit.add_constant(Constant::UnsignedInteger64(iter as u64));
+            let iter_tid = edit.add_node(Node::Constant { id: iter_cons });
+
+            // First, add a copy of each node in the fork join unmodified.
+            // Record the mapping from old ID to new ID.
+            let mut added_ids = HashSet::new();
+            let mut old_to_new_ids = HashMap::new();
+            let mut new_control = None;
+            let mut new_reduces = HashMap::new();
+            for node in nodes_in_fork_joins[&fork].iter() {
+                if *node == fork {
+                    old_to_new_ids.insert(*node, top_control);
+                } else if *node == join {
+                    new_control = Some(edit.get_node(*node).try_join().unwrap());
+                } else if *node == tid {
+                    old_to_new_ids.insert(*node, iter_tid);
+                } else if let Some(current) = current_reduces.get(node) {
+                    old_to_new_ids.insert(*node, *current);
+                    new_reduces.insert(*node, edit.get_node(*node).try_reduce().unwrap().2);
+                } else {
+                    let new_node = edit.add_node(edit.get_node(*node).clone());
+                    old_to_new_ids.insert(*node, new_node);
+                    added_ids.insert(new_node);
+                }
+            }
+
+            // Second, replace all the uses in the just added nodes.
+            if let Some(new_control) = new_control {
+                top_control = old_to_new_ids[&new_control];
+            }
+            for (reduce, reduct) in new_reduces {
+                current_reduces.insert(reduce, old_to_new_ids[&reduct]);
+            }
+            for (old, new) in old_to_new_ids {
+                edit = edit.replace_all_uses_where(old, new, |id| added_ids.contains(id))?;
+            }
+        }
+
+        // Hook up the control and reduce outputs to the rest of the function.
+        edit = edit.replace_all_uses(join, top_control)?;
+        for (reduce, reduct) in current_reduces {
+            edit = edit.replace_all_uses(reduce, reduct)?;
+        }
+
+        // Delete the old fork-join.
+        for node in nodes_in_fork_joins[&fork].iter() {
+            edit = edit.delete_node(*node)?;
+        }
+
         Ok(edit)
     })
 }
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index 2c832d66..9e87d26a 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -39,8 +39,10 @@ dce(*);
 fixpoint panic after 20 {
   infer-schedules(*);
 }
-unroll(auto.test1);
-xdot[true](*);
+fork-split(auto.test1);
+fixpoint panic after 20 {
+  unroll(auto.test1);
+}
 
 fork-split(auto.test2, auto.test3, auto.test4, auto.test5);
 gvn(*);
-- 
GitLab