From 82ab7c076640225cf0ff4de2b8443a46dcc9b95f Mon Sep 17 00:00:00 2001 From: Xavier Routh <xrouth2@illinois.edu> Date: Sun, 2 Feb 2025 22:39:36 -0600 Subject: [PATCH] fork dim merge --- hercules_opt/src/fork_transforms.rs | 161 ++++++++++++++++++++++++++++ juno_scheduler/src/compile.rs | 3 +- juno_scheduler/src/ir.rs | 2 + juno_scheduler/src/lib.rs | 3 +- juno_scheduler/src/pm.rs | 30 ++++-- 5 files changed, 189 insertions(+), 10 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index e23f586f..58ace775 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1,5 +1,6 @@ use std::collections::{HashMap, HashSet}; use std::iter::zip; +use std::thread::ThreadId; use bimap::BiMap; use itertools::Itertools; @@ -693,3 +694,163 @@ pub(crate) fn split_fork( None } } + +// Splits a dimension of a single fork join into multiple. +// Iterates an outer loop original_dim / tile_size times +// adds a tile_size loop as the inner loop +// Assumes that tile size divides original dim evenly. 
+pub fn chunk_fork_unguarded(
+    editor: &mut FunctionEditor,
+    fork: NodeID,
+    dim_idx: usize,
+    tile_size: DynamicConstantID,
+) {
+    // After chunking, the old thread ID of dimension `dim_idx` is rebuilt as
+    //   tid(dim_idx) * tile_size + tid(dim_idx + 1)
+    // and every thread ID of a later dimension moves up by one dimension.
+    let Node::Fork { control: old_control, factors: ref old_factors } = *editor.node(fork) else {
+        return;
+    };
+    let mut new_factors: Vec<_> = old_factors.to_vec();
+    let fork_users: Vec<_> = editor.get_users(fork).collect();
+
+    for tid in fork_users {
+        let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else {
+            continue;
+        };
+        editor.edit(|mut edit| {
+            if tid_dim > dim_idx {
+                let shifted = Node::ThreadID { control: fork, dimension: tid_dim + 1 };
+                let shifted = edit.add_node(shifted);
+                edit.replace_all_uses(tid, shifted)
+            } else if tid_dim == dim_idx {
+                // `tid` keeps dimension `dim_idx` (now the tile index); the dimension
+                // right after it is the position inside the tile.
+                let tile_tid = Node::ThreadID { control: fork, dimension: tid_dim + 1 };
+                let tile_tid = edit.add_node(tile_tid);
+                let tile_dc = edit.add_node(Node::DynamicConstant { id: tile_size });
+                let mul = Node::Binary { left: tid, right: tile_dc, op: BinaryOperator::Mul };
+                let mul = edit.add_node(mul);
+                let add = Node::Binary { left: mul, right: tile_tid, op: BinaryOperator::Add };
+                let add = edit.add_node(add);
+                // `mul` itself must keep reading the original thread ID.
+                edit.replace_all_uses_where(tid, add, |usee| *usee != mul)
+            } else {
+                Ok(edit)
+            }
+        });
+    }
+
+    editor.edit(|mut edit| {
+        let outer = DynamicConstant::div(new_factors[dim_idx], tile_size);
+        new_factors.insert(dim_idx + 1, tile_size);
+        new_factors[dim_idx] = edit.add_dynamic_constant(outer);
+
+        let new_fork = Node::Fork { control: old_control, factors: new_factors.into() };
+        let new_fork = edit.add_node(new_fork);
+        // NOTE(review): unlike `fork_dim_merge`, the old fork is not deleted here;
+        // confirm that a later cleanup pass is expected to remove it.
+        edit.replace_all_uses(fork, new_fork)
+    });
+}
+
+/// Collapses every fork in `fork_join_map` to a single dimension by repeatedly
+/// merging its dimensions 0 and 1 with `fork_dim_merge`.
+pub fn merge_all_fork_dims(
+    editor: &mut FunctionEditor,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+) {
+    for fork in fork_join_map.keys() {
+        let Node::Fork { control: _, factors: dims } = editor.node(fork) else {
+            unreachable!();
+        };
+
+        // N dimensions need N - 1 merges; `saturating_sub` keeps an empty factor
+        // list from underflowing the count.
+        let num_merges = dims.len().saturating_sub(1);
+        let mut fork = *fork;
+        for _ in 0..num_merges {
+            fork = fork_dim_merge(editor, fork, 0, 1);
+        }
+    }
+}
+
+/// Merges two dimensions of a single fork into one.
+///
+/// The smaller index is treated as the outer dimension: its factor becomes the
+/// product of both factors and the inner dimension is removed. With `m` the
+/// thread ID of the merged dimension, old thread IDs are rewritten as
+///   old outer tid <- m % factor(outer)
+///   old inner tid <- m / factor(outer)
+/// and thread IDs of dimensions after the inner one move down by one.
+///
+/// Returns the replacement fork, or `fork` itself when it is not a fork node or
+/// the replacement edit was not committed.
+///
+/// Panics if `dim_idx1 == dim_idx2`.
+pub fn fork_dim_merge(
+    editor: &mut FunctionEditor,
+    fork: NodeID,
+    dim_idx1: usize,
+    dim_idx2: usize,
+) -> NodeID {
+    assert_ne!(dim_idx1, dim_idx2);
+
+    // Outer is smaller, and also closer to the left of the factors array.
+    let outer_idx = dim_idx1.min(dim_idx2);
+    let inner_idx = dim_idx1.max(dim_idx2);
+
+    let Node::Fork { control: old_control, factors: ref old_factors } = *editor.node(fork) else {
+        return fork;
+    };
+    let mut new_factors: Vec<_> = old_factors.to_vec();
+    // Snapshot the users first: the edit below re-points them at the new fork.
+    let fork_users: Vec<_> = editor.get_users(fork).collect();
+    let outer_dc_id = new_factors[outer_idx];
+
+    let mut new_fork_id = None;
+    let committed = editor.edit(|mut edit| {
+        let merged = DynamicConstant::mul(new_factors[outer_idx], new_factors[inner_idx]);
+        new_factors[outer_idx] = edit.add_dynamic_constant(merged);
+        new_factors.remove(inner_idx);
+
+        let new_fork = Node::Fork { control: old_control, factors: new_factors.into() };
+        let new_fork = edit.add_node(new_fork);
+        new_fork_id = Some(new_fork);
+
+        edit = edit.replace_all_uses(fork, new_fork)?;
+        edit.delete_node(fork)
+    });
+    // Without a committed replacement fork there is nothing valid to rewrite the
+    // thread IDs against, so hand back the original fork untouched.
+    let (true, Some(new_fork_id)) = (committed, new_fork_id) else {
+        return fork;
+    };
+
+    for tid in fork_users {
+        let Node::ThreadID { control: _, dimension: tid_dim } = editor.node(tid).clone() else {
+            continue;
+        };
+        editor.edit(|mut edit| {
+            if tid_dim > inner_idx {
+                let shifted = Node::ThreadID { control: new_fork_id, dimension: tid_dim - 1 };
+                let shifted = edit.add_node(shifted);
+                edit.replace_all_uses(tid, shifted)
+            } else if tid_dim == outer_idx || tid_dim == inner_idx {
+                let merged_tid = Node::ThreadID { control: new_fork_id, dimension: outer_idx };
+                let merged_tid = edit.add_node(merged_tid);
+                let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id });
+                // Old outer tid = merged % factor(outer); old inner = merged / factor(outer).
+                let is_outer = tid_dim == outer_idx;
+                let op = if is_outer { BinaryOperator::Rem } else { BinaryOperator::Div };
+                let split = Node::Binary { left: merged_tid, right: outer_dc, op };
+                let split = edit.add_node(split);
+                edit.replace_all_uses(tid, split)
+            } else {
+                Ok(edit)
+            }
+        });
+    }
+
+    new_fork_id
+}
\ No newline at end of file
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 11a8ec53..07ad5e7a 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -108,7 +108,8 @@ impl FromStr for Appliable {
             "inline" => Ok(Appliable::Pass(ir::Pass::Inline)),
             "ip-sroa" | "interprocedural-sroa" => {
                 Ok(Appliable::Pass(ir::Pass::InterproceduralSROA))
-            }
+            },
+            "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
             "outline" => Ok(Appliable::Pass(ir::Pass::Outline)),
             "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index d6a41baf..939ef925 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -12,6 +12,8 @@ pub enum Pass {
     ForkSplit,
     ForkCoalesce,
     Forkify,
+    ForkDimMerge,
+    ForkChunk,
     GCM,
     GVN,
     InferSchedules,
diff --git a/juno_scheduler/src/lib.rs b/juno_scheduler/src/lib.rs
index 2479af98..ad9195fb 100644
--- a/juno_scheduler/src/lib.rs
+++ b/juno_scheduler/src/lib.rs
@@ -60,7 +60,7 @@ fn
build_schedule(sched_filename: String) -> Result<ScheduleStmt, String> { } } -fn process_schedule(sched_filename: Option<String>) -> Result<ScheduleStmt, String> { +pub fn process_schedule(sched_filename: Option<String>) -> Result<ScheduleStmt, String> { if let Some(name) = sched_filename { build_schedule(name) } else { @@ -146,7 +146,6 @@ pub fn run_schedule_on_hercules( .map_err(|e| format!("Scheduling Error: {}", e)) } - pub fn run_schedule_from_file_on_hercules( module: Module, sched_filename: Option<String>, diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 43f355c3..8b71d24e 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1871,14 +1871,11 @@ fn run_pass( } Pass::Unforkify => { assert!(args.is_empty()); - loop { - let mut inner_changed = false; - - pm.make_fork_join_maps(); - pm.make_loops(); + pm.make_fork_join_maps(); + pm.make_loops(); - let fork_join_maps = pm.fork_join_maps.take().unwrap(); - let loops = pm.loops.take().unwrap(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let loops = pm.loops.take().unwrap(); for ((func, fork_join_map), loop_tree) in build_selection(pm, selection) .into_iter() @@ -1894,6 +1891,24 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } + Pass::ForkDimMerge => { + assert!(args.is_empty()); + pm.make_fork_join_maps(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + for (func, fork_join_map) in + build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) + { + let Some(mut func) = func else { + continue; + }; + merge_all_fork_dims(&mut func, fork_join_map); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::ForkCoalesce => { assert!(args.is_empty()); pm.make_fork_join_maps(); @@ -1991,6 +2006,7 @@ fn run_pass( // Put BasicBlocks back, since it's needed for Codegen. pm.bbs = bbs; } + Pass::ForkChunk => todo!(), } println!("Ran Pass: {:?}", pass); -- GitLab