From 23103bc2d1e31fae8880cd9063d4948a1be89a92 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Wed, 12 Feb 2025 22:02:51 -0600
Subject: [PATCH] Implement fork_unroll for single-dimension fork-joins

---
 hercules_opt/src/fork_transforms.rs      | 81 +++++++++++++++++++++---
 juno_samples/fork_join_tests/src/cpu.sch |  6 +-
 2 files changed, 76 insertions(+), 11 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 94898b0d..2f7a91fa 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1186,7 +1186,8 @@ pub fn fork_unroll(
     join: NodeID,
     nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
 ) -> bool {
-    // We can only unroll forks with a compile time known factor list.
+    // We can only unroll fork-joins with a compile time known factor list. For
+    // simplicity, just unroll fork-joins that have a single dimension.
     let nodes = &editor.func().nodes;
     let Node::Fork {
         control,
@@ -1195,17 +1196,79 @@ pub fn fork_unroll(
     else {
         panic!()
     };
-    let mut cons_factors = vec![];
-    for factor in factors {
-        let DynamicConstant::Constant(cons) = *editor.get_dynamic_constant(*factor) else {
-            return false;
-        };
-        cons_factors.push(cons);
+    if factors.len() != 1 || editor.get_users(fork).count() != 2 {
+        return false;
     }
-    println!("{}: {:?}", editor.func().name, cons_factors);
+    let DynamicConstant::Constant(cons) = *editor.get_dynamic_constant(factors[0]) else {
+        return false;
+    };
+    let tid = editor
+        .get_users(fork)
+        .filter(|id| nodes[id.idx()].is_thread_id())
+        .next()
+        .unwrap();
+    let inits: HashMap<NodeID, NodeID> = editor
+        .get_users(join)
+        .filter_map(|id| nodes[id.idx()].try_reduce().map(|(_, init, _)| (id, init)))
+        .collect();
 
     editor.edit(|mut edit| {
-        ();
+        // Create a copy of the nodes in the fork join per unrolled iteration,
+        // excluding the fork itself, the join itself, the thread IDs of the fork,
+        // and the reduces on the join. Keep a running tally of the top control
+        // node and the current reduction value.
+        let mut top_control = control;
+        let mut current_reduces = inits;
+        for iter in 0..cons {
+            let iter_cons = edit.add_constant(Constant::UnsignedInteger64(iter as u64));
+            let iter_tid = edit.add_node(Node::Constant { id: iter_cons });
+
+            // First, add a copy of each node in the fork join unmodified.
+            // Record the mapping from old ID to new ID.
+            let mut added_ids = HashSet::new();
+            let mut old_to_new_ids = HashMap::new();
+            let mut new_control = None;
+            let mut new_reduces = HashMap::new();
+            for node in nodes_in_fork_joins[&fork].iter() {
+                if *node == fork {
+                    old_to_new_ids.insert(*node, top_control);
+                } else if *node == join {
+                    new_control = Some(edit.get_node(*node).try_join().unwrap());
+                } else if *node == tid {
+                    old_to_new_ids.insert(*node, iter_tid);
+                } else if let Some(current) = current_reduces.get(node) {
+                    old_to_new_ids.insert(*node, *current);
+                    new_reduces.insert(*node, edit.get_node(*node).try_reduce().unwrap().2);
+                } else {
+                    let new_node = edit.add_node(edit.get_node(*node).clone());
+                    old_to_new_ids.insert(*node, new_node);
+                    added_ids.insert(new_node);
+                }
+            }
+
+            // Second, replace all the uses in the just added nodes.
+            if let Some(new_control) = new_control {
+                top_control = old_to_new_ids[&new_control];
+            }
+            for (reduce, reduct) in new_reduces {
+                current_reduces.insert(reduce, old_to_new_ids[&reduct]);
+            }
+            for (old, new) in old_to_new_ids {
+                edit = edit.replace_all_uses_where(old, new, |id| added_ids.contains(id))?;
+            }
+        }
+
+        // Hook up the control and reduce outputs to the rest of the function.
+        edit = edit.replace_all_uses(join, top_control)?;
+        for (reduce, reduct) in current_reduces {
+            edit = edit.replace_all_uses(reduce, reduct)?;
+        }
+
+        // Delete the old fork-join.
+        for node in nodes_in_fork_joins[&fork].iter() {
+            edit = edit.delete_node(*node)?;
+        }
+
         Ok(edit)
     })
 }
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index 2c832d66..9e87d26a 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -39,8 +39,10 @@ dce(*);
 fixpoint panic after 20 {
   infer-schedules(*);
 }
-unroll(auto.test1);
-xdot[true](*);
+fork-split(auto.test1);
+fixpoint panic after 20 {
+  unroll(auto.test1);
+}
 
 fork-split(auto.test2, auto.test3, auto.test4, auto.test5);
 gvn(*);
-- 
GitLab