From 87445e17a4e5defc180b560468e5ac5879848cea Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 24 Feb 2025 19:58:03 -0600
Subject: [PATCH] Fork extend pass

---
 hercules_opt/src/fork_transforms.rs | 126 +++++++++++++++++++++++++++-
 juno_samples/cava/src/cpu.sch       |   1 +
 juno_scheduler/src/compile.rs       |   1 +
 juno_scheduler/src/ir.rs            |   3 +
 juno_scheduler/src/pm.rs            |  24 ++++++
 5 files changed, 153 insertions(+), 2 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index ae3dfe22..0e943973 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1601,6 +1601,128 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
 }
 
 /*
- * Looks for reads in fork-joins that are linear in the thread IDs for the fork-
- * join.
+ * Extends the dimensions of a fork-join to be a multiple of a number and gates
+ * the execution of the body.
  */
+pub fn extend_all_forks(
+    editor: &mut FunctionEditor,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    multiple: usize,
+) {
+    for (fork, join) in fork_join_map {
+        if editor.is_mutable(*fork) {
+            extend_fork(editor, *fork, *join, multiple);
+        }
+    }
+}
+
+fn extend_fork(editor: &mut FunctionEditor, fork: NodeID, join: NodeID, multiple: usize) {
+    let nodes = &editor.func().nodes;
+    let (fork_pred, factors) = nodes[fork.idx()].try_fork().unwrap();
+    let factors = factors.to_vec();
+    let fork_succ = editor
+        .get_users(fork)
+        .filter(|id| nodes[id.idx()].is_control())
+        .next()
+        .unwrap();
+    let join_pred = nodes[join.idx()].try_join().unwrap();
+    let ctrl_between = fork != join_pred;
+    let reduces: Vec<_> = editor
+        .get_users(join)
+        .filter_map(|id| nodes[id.idx()].try_reduce().map(|x| (id, x)))
+        .collect();
+
+    editor.edit(|mut edit| {
+        // We can round up a dynamic constant A to a multiple of another dynamic
+        // constant B via the following math:
+        // ((A + B - 1) / B) * B
+        let new_factors: Vec<_> = factors
+            .iter()
+            .map(|factor| {
+                let b = edit.add_dynamic_constant(DynamicConstant::Constant(multiple));
+                let apb = edit.add_dynamic_constant(DynamicConstant::add(*factor, b));
+                let o = edit.add_dynamic_constant(DynamicConstant::Constant(1));
+                let apbmo = edit.add_dynamic_constant(DynamicConstant::sub(apb, o));
+                let apbmodb = edit.add_dynamic_constant(DynamicConstant::div(apbmo, b));
+                edit.add_dynamic_constant(DynamicConstant::mul(apbmodb, b))
+            })
+            .collect();
+
+        // Create the new control structure.
+        let new_fork = edit.add_node(Node::Fork {
+            control: fork_pred,
+            factors: new_factors.into_boxed_slice(),
+        });
+        edit = edit.replace_all_uses_where(fork, new_fork, |id| *id != fork_succ)?;
+        edit.sub_edit(fork, new_fork);
+        let conds: Vec<_> = factors
+            .iter()
+            .enumerate()
+            .map(|(idx, old_factor)| {
+                let tid = edit.add_node(Node::ThreadID {
+                    control: new_fork,
+                    dimension: idx,
+                });
+                let old_bound = edit.add_node(Node::DynamicConstant { id: *old_factor });
+                edit.add_node(Node::Binary {
+                    op: BinaryOperator::LT,
+                    left: tid,
+                    right: old_bound,
+                })
+            })
+            .collect();
+        let cond = conds
+            .into_iter()
+            .reduce(|left, right| {
+                edit.add_node(Node::Binary {
+                    op: BinaryOperator::And,
+                    left,
+                    right,
+                })
+            })
+            .unwrap();
+        let branch = edit.add_node(Node::If {
+            control: new_fork,
+            cond,
+        });
+        let false_proj = edit.add_node(Node::ControlProjection {
+            control: branch,
+            selection: 0,
+        });
+        let true_proj = edit.add_node(Node::ControlProjection {
+            control: branch,
+            selection: 1,
+        });
+        if ctrl_between {
+            edit = edit.replace_all_uses_where(fork, true_proj, |id| *id == fork_succ)?;
+        }
+        let bottom_region = edit.add_node(Node::Region {
+            preds: Box::new([false_proj, if ctrl_between { join_pred } else { true_proj }]),
+        });
+        let new_join = edit.add_node(Node::Join {
+            control: bottom_region,
+        });
+        edit = edit.replace_all_uses(join, new_join)?;
+        edit.sub_edit(join, new_join);
+        edit = edit.delete_node(fork)?;
+        edit = edit.delete_node(join)?;
+
+        // Update the reduces to use phis on the region node to gate their execution.
+        for (reduce, (_, init, reduct)) in reduces {
+            let phi = edit.add_node(Node::Phi {
+                control: bottom_region,
+                data: Box::new([reduce, reduct]),
+            });
+            let new_reduce = edit.add_node(Node::Reduce {
+                control: new_join,
+                init,
+                reduct: phi,
+            });
+            edit = edit.replace_all_uses(reduce, new_reduce)?;
+            edit.sub_edit(reduce, new_reduce);
+            edit = edit.delete_node(reduce)?;
+        }
+
+        Ok(edit)
+    });
+}
diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch
index 3ac2f326..efa7302e 100644
--- a/juno_samples/cava/src/cpu.sch
+++ b/juno_samples/cava/src/cpu.sch
@@ -49,6 +49,7 @@ simpl!(fuse1);
 write-predication(fuse1);
 simpl!(fuse1);
 parallel-reduce(fuse1@loop);
+fork-extend[8](fuse1);
 
 inline(fuse2);
 no-memset(fuse2@res);
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 13990ef9..3c288ca7 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -131,6 +131,7 @@ impl FromStr for Appliable {
             "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)),
             "fork-interchange" => Ok(Appliable::Pass(ir::Pass::ForkInterchange)),
             "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)),
+            "fork-extend" => Ok(Appliable::Pass(ir::Pass::ForkExtend)),
             "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)),
             "fork-fusion" | "fusion" => Ok(Appliable::Pass(ir::Pass::ForkFusion)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index bbecc6ff..3a087c0d 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -15,6 +15,7 @@ pub enum Pass {
     ForkChunk,
     ForkCoalesce,
     ForkDimMerge,
+    ForkExtend,
     ForkFissionBufferize,
     ForkFusion,
     ForkGuardElim,
@@ -53,6 +54,7 @@ impl Pass {
         match self {
             Pass::ArrayToProduct => num == 0 || num == 1,
             Pass::ForkChunk => num == 4,
+            Pass::ForkExtend => num == 1,
             Pass::ForkFissionBufferize => num == 2 || num == 1,
             Pass::ForkInterchange => num == 2,
             Pass::Print => num == 1,
@@ -68,6 +70,7 @@ impl Pass {
         match self {
             Pass::ArrayToProduct => "0 or 1",
             Pass::ForkChunk => "4",
+            Pass::ForkExtend => "1",
             Pass::ForkFissionBufferize => "1 or 2",
             Pass::ForkInterchange => "2",
             Pass::Print => "1",
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index d5e280b4..4656d841 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2642,6 +2642,30 @@ fn run_pass(
             pm.delete_gravestones();
             pm.clear_analyses();
         }
+        Pass::ForkExtend => {
+            assert_eq!(args.len(), 1);
+            let Some(Value::Integer { val: multiple }) = args.get(0) else {
+                return Err(SchedulerError::PassError {
+                    pass: "forkExtend".to_string(),
+                    error: "expected integer argument".to_string(),
+                });
+            };
+
+            pm.make_fork_join_maps();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            for (func, fork_join_map) in build_selection(pm, selection, false)
+                .into_iter()
+                .zip(fork_join_maps.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                extend_all_forks(&mut func, fork_join_map, *multiple);
+                changed |= func.modified();
+            }
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::ForkFissionBufferize => {
             assert!(args.len() == 1 || args.len() == 2);
             let Some(Value::Label {
-- 
GitLab
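
For reference only (not part of the commit): a minimal standalone Rust sketch of the round-up arithmetic that extend_fork encodes as DynamicConstant nodes, ((A + B - 1) / B) * B, evaluated here on plain usize values. The helper name round_up_to_multiple is illustrative and does not exist in the patch.

    // Illustrative only: the same rounding the pass builds symbolically out of
    // dynamic constants, computed directly on concrete values.
    fn round_up_to_multiple(a: usize, b: usize) -> usize {
        assert!(b > 0);
        ((a + b - 1) / b) * b
    }

    fn main() {
        // e.g. fork-extend[8] extends a dimension of size 13 to 16, and leaves
        // a dimension that is already a multiple of 8 unchanged; thread IDs at
        // or beyond the old bound are gated off by the inserted if/region.
        assert_eq!(round_up_to_multiple(13, 8), 16);
        assert_eq!(round_up_to_multiple(16, 8), 16);
    }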