From 9bc5101eeac8a2ac2393cf0aedd7ff5aa9bcc74f Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Wed, 12 Feb 2025 18:27:36 -0600
Subject: [PATCH] Get unrollable fork-joins

---
 hercules_opt/src/fork_transforms.rs      | 38 +++++++++++++++++++++---
 juno_samples/fork_join_tests/src/cpu.sch |  6 ++--
 juno_scheduler/src/pm.rs                 | 12 +++++---
 3 files changed, 46 insertions(+), 10 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 539b7fd1..94898b0d 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1168,14 +1168,44 @@ fn fork_interchange(
 /*
  * Run fork unrolling on all fork-joins that are mutable in an editor.
  */
-pub fn fork_unroll_all_forks(editor: &mut FunctionEditor, fork_joins: &HashMap<NodeID, NodeID>) {
+pub fn fork_unroll_all_forks(
+    editor: &mut FunctionEditor,
+    fork_joins: &HashMap<NodeID, NodeID>,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+) {
     for (fork, join) in fork_joins {
-        if editor.is_mutable(*fork) && fork_unroll(editor, *fork, *join) {
+        if editor.is_mutable(*fork) && fork_unroll(editor, *fork, *join, nodes_in_fork_joins) {
             break;
         }
     }
 }
 
-pub fn fork_unroll(editor: &mut FunctionEditor, fork: NodeID, join: NodeID) -> bool {
-    false
+pub fn fork_unroll(
+    editor: &mut FunctionEditor,
+    fork: NodeID,
+    join: NodeID,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+) -> bool {
+    // We can only unroll forks with a compile time known factor list.
+    let nodes = &editor.func().nodes;
+    let Node::Fork {
+        control,
+        ref factors,
+    } = nodes[fork.idx()]
+    else {
+        panic!()
+    };
+    let mut cons_factors = vec![];
+    for factor in factors {
+        let DynamicConstant::Constant(cons) = *editor.get_dynamic_constant(*factor) else {
+            return false;
+        };
+        cons_factors.push(cons);
+    }
+    println!("{}: {:?}", editor.func().name, cons_factors);
+
+    editor.edit(|mut edit| {
+        ();
+        Ok(edit)
+    })
 }
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index fe0a8802..2c832d66 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -39,12 +39,14 @@ dce(*);
 fixpoint panic after 20 {
   infer-schedules(*);
 }
+unroll(auto.test1);
+xdot[true](*);
 
-fork-split(auto.test1, auto.test2, auto.test3, auto.test4, auto.test5);
+fork-split(auto.test2, auto.test3, auto.test4, auto.test5);
 gvn(*);
 phi-elim(*);
 dce(*);
-unforkify(auto.test1, auto.test2, auto.test3, auto.test4, auto.test5);
+unforkify(auto.test2, auto.test3, auto.test4, auto.test5);
 ccp(*);
 gvn(*);
 phi-elim(*);
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 951ba51d..f59834ed 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1669,15 +1669,19 @@ fn run_pass(
             assert_eq!(args.len(), 0);
 
             pm.make_fork_join_maps();
+            pm.make_nodes_in_fork_joins();
             let fork_join_maps = pm.fork_join_maps.take().unwrap();
-            for (func, fork_join_map) in build_selection(pm, selection, false)
-                .into_iter()
-                .zip(fork_join_maps.iter())
+            let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+            for ((func, fork_join_map), nodes_in_fork_joins) in
+                build_selection(pm, selection, false)
+                    .into_iter()
+                    .zip(fork_join_maps.iter())
+                    .zip(nodes_in_fork_joins.iter())
             {
                 let Some(mut func) = func else {
                     continue;
                 };
-                fork_unroll_all_forks(&mut func, fork_join_map);
+                fork_unroll_all_forks(&mut func, fork_join_map, nodes_in_fork_joins);
                 changed |= func.modified();
             }
             pm.delete_gravestones();
-- 
GitLab