From 82ab7c076640225cf0ff4de2b8443a46dcc9b95f Mon Sep 17 00:00:00 2001
From: Xavier Routh <xrouth2@illinois.edu>
Date: Sun, 2 Feb 2025 22:39:36 -0600
Subject: [PATCH] fork dim merge

---
 hercules_opt/src/fork_transforms.rs | 161 ++++++++++++++++++++++++++++
 juno_scheduler/src/compile.rs       |   3 +-
 juno_scheduler/src/ir.rs            |   2 +
 juno_scheduler/src/lib.rs           |   3 +-
 juno_scheduler/src/pm.rs            |  30 ++++--
 5 files changed, 189 insertions(+), 10 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index e23f586f..58ace775 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1,5 +1,6 @@
 use std::collections::{HashMap, HashSet};
 use std::iter::zip;
+use std::thread::ThreadId;
 
 use bimap::BiMap;
 use itertools::Itertools;
@@ -693,3 +694,163 @@ pub(crate) fn split_fork(
         None
     }
 }
+
+// Splits a dimension of a single fork join into multiple. 
+// Iterates an outer loop original_dim / tile_size times 
+// adds a tile_size loop as the inner loop 
+// Assumes that tile size divides original dim evenly.
+pub fn chunk_fork_unguarded(
+    editor: &mut FunctionEditor,
+    fork: NodeID,
+    dim_idx: usize,
+    tile_size: DynamicConstantID,
+) -> () {
+    // tid_dim_idx = tid_dim_idx * tile_size + tid_(dim_idx + 1)
+
+    let Node::Fork { control: old_control, factors: ref old_factors} = *editor.node(fork) else {return};
+
+    let mut new_factors: Vec<_> = old_factors.to_vec();
+
+    let fork_users: Vec<_> = editor.get_users(fork).collect();
+
+
+    for tid in fork_users {
+        let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else { continue };
+        editor.edit(|mut edit| {
+            if tid_dim > dim_idx {
+                let new_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 };
+                let new_tid = edit.add_node(new_tid);
+                edit.replace_all_uses(tid, new_tid)
+            } else if tid_dim == dim_idx {
+                let tile_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 };
+                let tile_tid = edit.add_node(tile_tid);
+                
+                let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size });
+                let mul = edit.add_node(Node::Binary { left: tid, right: tile_size, op: BinaryOperator::Mul });
+                let add = edit.add_node(Node::Binary { left: mul, right: tile_tid, op: BinaryOperator::Add });
+                edit.replace_all_uses_where(tid, add, |usee| *usee != mul )
+            } else {
+                Ok(edit)
+            }
+        });
+    }
+
+    editor.edit(|mut edit| {
+        let outer = DynamicConstant::div(new_factors[dim_idx], tile_size);
+        new_factors.insert(dim_idx + 1, tile_size);
+        new_factors[dim_idx] = edit.add_dynamic_constant(outer);
+
+        let new_fork = Node::Fork { control: old_control, factors: new_factors.into() };
+        let new_fork = edit.add_node(new_fork);
+
+        edit.replace_all_uses(fork, new_fork)
+    });
+}
+
+
+pub fn merge_all_fork_dims(
+    editor: &mut FunctionEditor,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+) {
+    for (fork, _) in fork_join_map {
+        let Node::Fork { control: _, factors: dims } = editor.node(fork) else {
+            unreachable!();
+        };
+
+        let mut fork = *fork;
+        for _ in 0..dims.len() - 1 {
+            let outer = 0;
+            let inner = 1;
+            fork = fork_dim_merge(editor, fork, outer, inner);
+        }
+    }
+}
+
+// Splits a dimension of a single fork join into multiple. 
+// Iterates an outer loop original_dim / tile_size times 
+// adds a tile_size loop as the inner loop 
+// Assumes that tile size divides original dim evenly.
+pub fn fork_dim_merge(
+    editor: &mut FunctionEditor,
+    fork: NodeID,
+    dim_idx1: usize,
+    dim_idx2: usize,
+) -> NodeID {
+    // tid_dim_idx1 (replaced w/) <- dim_idx1 / dim(dim_idx2)
+    // tid_dim_idx2 (replaced w/) <- dim_idx1 % dim(dim_idx2)
+    assert_ne!(dim_idx1, dim_idx2);
+
+    // Outer is smaller, and also closer to the left of the factors array.
+    let (outer_idx, inner_idx) = if dim_idx2 < dim_idx1 {
+        (dim_idx2, dim_idx1)
+    } else {
+        (dim_idx1, dim_idx2)
+    };
+
+    let Node::Fork { control: old_control, factors: ref old_factors} = *editor.node(fork) else {return fork};
+
+    let mut new_factors: Vec<_> = old_factors.to_vec();
+
+    
+
+    let fork_users: Vec<_> = editor.get_users(fork).collect();
+
+    let mut new_nodes = vec![];
+
+    let outer_dc_id = new_factors[outer_idx];
+    let inner_dc_id = new_factors[inner_idx];
+
+    let mut new_fork_id = NodeID::new(0);
+    editor.edit(|mut edit| {
+        new_factors[outer_idx] = edit.add_dynamic_constant(DynamicConstant::mul(new_factors[outer_idx], new_factors[inner_idx]));
+        new_factors.remove(inner_idx);
+
+        let new_fork = Node::Fork { control: old_control, factors: new_factors.into() };
+        let new_fork = edit.add_node(new_fork);
+        new_fork_id = new_fork;
+
+
+        edit = edit.replace_all_uses(fork, new_fork)?;
+        edit.delete_node(fork)
+    });
+
+    
+
+    for tid in fork_users {
+        let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else { continue };
+
+        println!("tid: {:?}", tid);
+        editor.edit(|mut edit| {
+            if tid_dim > inner_idx {
+                let new_tid = Node::ThreadID { control: new_fork_id, dimension: tid_dim - 1 };
+                let new_tid = edit.add_node(new_tid);
+                edit.replace_all_uses(tid, new_tid)
+            } else if tid_dim == outer_idx {
+                let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx };
+                let outer_tid = edit.add_node(outer_tid);
+
+                let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id });
+                new_nodes.push(outer_tid);
+
+                // inner_idx % dim(outer_idx)
+                let rem = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Rem});
+                
+                edit.replace_all_uses(tid, rem)
+            } else if tid_dim == inner_idx {
+                let outer_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx };
+                let outer_tid = edit.add_node(outer_tid);
+
+                let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id });
+                // inner_idx / dim(outer_idx)
+                let div = edit.add_node(Node::Binary { left: outer_tid, right: outer_dc, op: BinaryOperator::Div});
+                
+                edit.replace_all_uses(tid, div)
+            } else {
+                Ok(edit)
+            }
+        });
+    };
+
+    return new_fork_id;
+
+}
\ No newline at end of file
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 11a8ec53..07ad5e7a 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -108,7 +108,8 @@ impl FromStr for Appliable {
             "inline" => Ok(Appliable::Pass(ir::Pass::Inline)),
             "ip-sroa" | "interprocedural-sroa" => {
                 Ok(Appliable::Pass(ir::Pass::InterproceduralSROA))
-            }
+            },
+            "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
             "outline" => Ok(Appliable::Pass(ir::Pass::Outline)),
             "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index d6a41baf..939ef925 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -12,6 +12,8 @@ pub enum Pass {
     ForkSplit,
     ForkCoalesce,
     Forkify,
+    ForkDimMerge,
+    ForkChunk,
     GCM,
     GVN,
     InferSchedules,
diff --git a/juno_scheduler/src/lib.rs b/juno_scheduler/src/lib.rs
index 2479af98..ad9195fb 100644
--- a/juno_scheduler/src/lib.rs
+++ b/juno_scheduler/src/lib.rs
@@ -60,7 +60,7 @@ fn build_schedule(sched_filename: String) -> Result<ScheduleStmt, String> {
     }
 }
 
-fn process_schedule(sched_filename: Option<String>) -> Result<ScheduleStmt, String> {
+pub fn process_schedule(sched_filename: Option<String>) -> Result<ScheduleStmt, String> {
     if let Some(name) = sched_filename {
         build_schedule(name)
     } else {
@@ -146,7 +146,6 @@ pub fn run_schedule_on_hercules(
     .map_err(|e| format!("Scheduling Error: {}", e))
 }
 
-
 pub fn run_schedule_from_file_on_hercules(
     module: Module,
     sched_filename: Option<String>,
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 43f355c3..8b71d24e 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1871,14 +1871,11 @@ fn run_pass(
         }
         Pass::Unforkify => {
             assert!(args.is_empty());
-            loop {
-                let mut inner_changed = false;
-
-                pm.make_fork_join_maps();
-                pm.make_loops();
+            pm.make_fork_join_maps();
+            pm.make_loops();
 
-                let fork_join_maps = pm.fork_join_maps.take().unwrap();
-                let loops = pm.loops.take().unwrap();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            let loops = pm.loops.take().unwrap();
 
             for ((func, fork_join_map), loop_tree) in build_selection(pm, selection)
                 .into_iter()
@@ -1894,6 +1891,24 @@ fn run_pass(
             pm.delete_gravestones();
             pm.clear_analyses();
         }
+        Pass::ForkDimMerge => {
+            assert!(args.is_empty());
+            pm.make_fork_join_maps();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            for (func, fork_join_map) in
+                build_selection(pm, selection)
+                    .into_iter()
+                    .zip(fork_join_maps.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                merge_all_fork_dims(&mut func, fork_join_map);
+                changed |= func.modified();
+            }
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::ForkCoalesce => {
             assert!(args.is_empty());
             pm.make_fork_join_maps();
@@ -1991,6 +2006,7 @@ fn run_pass(
             // Put BasicBlocks back, since it's needed for Codegen.
             pm.bbs = bbs;
         }
+        Pass::ForkChunk => todo!(),
     }
     println!("Ran Pass: {:?}", pass);
 
-- 
GitLab