From 595b9fbfd028045db3090e279dd2dbe27a883907 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Tue, 11 Feb 2025 13:52:43 -0600 Subject: [PATCH 1/7] add interchange pass --- hercules_opt/src/fork_transforms.rs | 56 ++++++++++++++++++- juno_samples/fork_join_tests/src/cpu.sch | 2 +- .../fork_join_tests/src/fork_join_tests.jn | 8 +-- juno_samples/fork_join_tests/src/gpu.sch | 4 +- juno_samples/fork_join_tests/src/main.rs | 2 +- juno_scheduler/src/ir.rs | 15 ++--- juno_scheduler/src/pm.rs | 34 ++++++++++- 7 files changed, 103 insertions(+), 18 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 7c423892..33c4e68c 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1,5 +1,4 @@ use std::collections::{HashMap, HashSet}; -use std::hash::Hash; use std::iter::zip; use bimap::BiMap; @@ -1076,3 +1075,58 @@ pub fn fork_dim_merge( new_fork } + +/* + * Run fork interchange on all fork-joins that are mutable in an editor. + */ +pub fn fork_interchange_all_forks( + editor: &mut FunctionEditor, + fork_join_map: &HashMap<NodeID, NodeID>, + first_dim: usize, + second_dim: usize, +) { + for (fork, join) in fork_join_map { + if editor.is_mutable(*fork) { + fork_interchange(editor, *fork, *join, first_dim, second_dim); + } + } +} + +fn fork_interchange( + editor: &mut FunctionEditor, + fork: NodeID, + join: NodeID, + first_dim: usize, + second_dim: usize, +) { + // Check that every reduce on the join is parallel or tight associative. + let nodes = &editor.func().nodes; + let schedules = &editor.func().schedules; + if !editor + .get_users(join) + .filter(|id| nodes[id.idx()].is_reduce()) + .all(|id| { + schedules[id.idx()].contains(&Schedule::ParallelReduce) + || schedules[id.idx()].contains(&Schedule::TightAssociative) + }) + { + // If not, we can't necessarily do interchange. 
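+        // Interchange re-orders the iterations feeding each reduce, so it is
+        // only known to be safe when every reduction is order-independent
+        // (ParallelReduce) or associative (TightAssociative).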
+ return; + } + + let Node::Fork { + control, + ref factors, + } = nodes[fork.idx()] + else { + panic!() + }; + let mut factors = factors.clone(); + factors.swap(first_dim, second_dim); + let new_fork = Node::Fork { control, factors }; + editor.edit(|mut edit| { + let new_fork = edit.add_node(new_fork); + edit = edit.replace_all_uses(fork, new_fork)?; + edit.delete_node(fork) + }); +} diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index 38a38c2b..fe0a8802 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -50,7 +50,7 @@ gvn(*); phi-elim(*); dce(*); -fork-tile[32, 0, true](test6@loop); +fork-tile[32, 0, false](test6@loop); let out = fork-split(test6@loop); let out = outline(out.test6.fj1); cpu(out); diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index 886ab13b..31752ec3 100644 --- a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -10,11 +10,11 @@ fn test1(input : i32) -> i32[4, 4] { } #[entry] -fn test2(input : i32) -> i32[4, 4] { - let arr : i32[4, 4]; +fn test2(input : i32) -> i32[4, 5] { + let arr : i32[4, 5]; for i = 0 to 8 { - for j = 0 to 4 { - for k = 0 to 4 { + for k = 0 to 5 { + for j = 0 to 4 { arr[j, k] += input; } } diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index 117cf37e..8e1cd687 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -42,7 +42,7 @@ fixpoint panic after 20 { infer-schedules(*); } -fork-tile[32, 0, true](test6@loop); +fork-tile[32, 0, false](test6@loop); let out = fork-split(test6@loop); let out = auto-outline(test6); gpu(out.test6); @@ -69,5 +69,7 @@ phi-elim(*); dce(*); gcm(*); +xdot[true](*); + float-collections(test2, auto.test2, test4, auto.test4, test5, auto.test5); gcm(*); diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs index 1013a1f0..21ccd7c4 100644 --- a/juno_samples/fork_join_tests/src/main.rs +++ b/juno_samples/fork_join_tests/src/main.rs @@ -25,7 +25,7 @@ fn main() { let mut r = runner!(test2); let output = r.run(3).await; - let correct = vec![24i32; 16]; + let correct = vec![24i32; 20]; assert(&correct, output); let mut r = runner!(test3); diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 0ec6f52c..881c63a3 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -9,13 +9,14 @@ pub enum Pass { DCE, DeleteUncalled, FloatCollections, + ForkChunk, + ForkCoalesce, + ForkDimMerge, + ForkFissionBufferize, ForkGuardElim, + ForkInterchange, ForkSplit, - ForkCoalesce, Forkify, - ForkFissionBufferize, - ForkDimMerge, - ForkChunk, GCM, GVN, InferSchedules, @@ -25,14 +26,14 @@ pub enum Pass { Outline, PhiElim, Predication, - SimplifyCFG, SLF, SROA, + Serialize, + SimplifyCFG, Unforkify, - WritePredication, Verify, + WritePredication, Xdot, - Serialize, } impl Pass { diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 20825c54..8581ee23 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1635,6 +1635,36 @@ fn run_pass( fields: new_fork_joins, }; } + Pass::ForkInterchange => { + assert_eq!(args.len(), 2); + let Some(Value::Integer { val: first_dim }) = args.get(0) else { + return Err(SchedulerError::PassError { + pass: "forkInterchange".to_string(), + error: "expected integer argument".to_string(), + 
}); + }; + let Some(Value::Integer { val: second_dim }) = args.get(1) else { + return Err(SchedulerError::PassError { + pass: "forkInterchange".to_string(), + error: "expected integer argument".to_string(), + }); + }; + + pm.make_fork_join_maps(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + for (func, fork_join_map) in build_selection(pm, selection, false) + .into_iter() + .zip(fork_join_maps.iter()) + { + let Some(mut func) = func else { + continue; + }; + fork_interchange_all_forks(&mut func, fork_join_map, *first_dim, *second_dim); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::Forkify => { assert!(args.is_empty()); loop { @@ -2042,7 +2072,7 @@ fn run_pass( }); }; - assert_eq!(*guarded_flag, true); + assert!(!*guarded_flag); pm.make_fork_join_maps(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); for (func, fork_join_map) in build_selection(pm, selection, false) @@ -2059,8 +2089,6 @@ fn run_pass( pm.clear_analyses(); } Pass::ForkFissionBufferize => { - pm.make_fork_join_maps(); - assert_eq!(args.len(), 2); let Some(Value::Label { labels: fork_labels, -- GitLab From 4bc008359a15444c63e0c2bd3e58c91e44e85d44 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Tue, 11 Feb 2025 14:09:12 -0600 Subject: [PATCH 2/7] setup to do reduce slf --- juno_samples/fork_join_tests/src/fork_join_tests.jn | 6 +++--- juno_samples/fork_join_tests/src/gpu.sch | 9 ++++++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index 31752ec3..8f569cfb 100644 --- a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -12,9 +12,9 @@ fn test1(input : i32) -> i32[4, 4] { #[entry] fn test2(input : i32) -> i32[4, 5] { let arr : i32[4, 5]; - for i = 0 to 8 { - for k = 0 to 5 { - for j = 0 to 4 { + @loop1 for i = 0 to 8 { + @loop2 for k = 0 to 5 { + @loop3 for j = 0 to 4 { arr[j, k] += input; } } diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index 8e1cd687..123313a0 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -31,7 +31,10 @@ dce(*); fixpoint panic after 20 { forkify(*); fork-guard-elim(*); - fork-coalesce(*); +} + +fixpoint panic after 20 { + fork-coalesce(auto.test1, auto.test3, auto.test4, auto.test5, auto.test7, auto.test8); } gvn(*); @@ -42,6 +45,8 @@ fixpoint panic after 20 { infer-schedules(*); } +fork-coalesce(auto.test2@loop2); + fork-tile[32, 0, false](test6@loop); let out = fork-split(test6@loop); let out = auto-outline(test6); @@ -69,7 +74,5 @@ phi-elim(*); dce(*); gcm(*); -xdot[true](*); - float-collections(test2, auto.test2, test4, auto.test4, test5, auto.test5); gcm(*); -- GitLab From 0e36e8b01951fa9ea1c13ca741fb6b10b1919d45 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Tue, 11 Feb 2025 14:47:42 -0600 Subject: [PATCH 3/7] Figure out how to do reduce SLF --- hercules_opt/src/slf.rs | 63 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/hercules_opt/src/slf.rs b/hercules_opt/src/slf.rs index 024a442e..f8c29e7b 100644 --- a/hercules_opt/src/slf.rs +++ b/hercules_opt/src/slf.rs @@ -229,3 +229,66 @@ pub fn array_slf( } } } + +/* + * Top level function to run reduce store-to-load forwarding on a function. + * There are two variants of reduce SLF. 
One optimizes parallel reductions and + * the other optimizes scalar reductions. This pass just runs one after the + * other - it's up to the user to potentially wrap this pass in a fixpoint. + * + * The parallel variant looks for reductions on collections with the following + * form: + * + * 1. A write `reduct` use. + * 2. A single read user in the reduction cycle. + * 3. The write use and read user have identical indices. + * 4. The indices set directly refers to at least every thread ID produced by + * the fork corresponding to the reduce. + * + * Condition #4 roughly corresponds to the same condition needed to infer the + * ParallelReduce schedule - however, in this scenario, it can't be inferred + * because the data being written is in the reduce cycle, since the data being + * written is derived from the read user of the reduce. However, since the + * indices changes each iteration, the read could equivalently read from the + * `init` input of the reduce, rather than the reduce itself. This optimization + * replaces the use of the reduce in the read with the `init` input of the + * reduce, nominally so that ParallelReduce can get inferred. + * + * The scalar variant looks for reductions on collections with the following + * form: + * + * 1. A write `reduct` use. + * 2. A single read user in the reduction cycle. + * 3. The write use and read user have identical indices. + * 4. The indices set doesn't reference any nodes in the fork-join of the reduce + * (doesn't depend on the thread IDs of the fork-join). + * + * Instead of repeatedly reading / writing the same collection item each + * iteration, the reduction can reduce over the scalar value directly, and do a + * single write into the collection after the fork-join (note technically, the + * "scalar" may itself be a collection in the case of nested collections, but + * this optimization most often handles scalars). + */ +pub fn reduce_slf( + editor: &mut FunctionEditor, + fork_join_map: &HashMap<NodeID, NodeID>, + reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, +) { + // Helper to get write use and read user of reduce with identical indices. 
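+    // Returns Some((read, write)), where the write is the reduce's `reduct`
+    // input and the read is the reduce's single user inside the reduction
+    // cycle.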
+ let read_write_helper = |reduce: NodeID, editor: &FunctionEditor| -> Option<(NodeID, NodeID)> { + let nodes = &editor.func().nodes; + let reduct = nodes[reduce.idx()].try_reduce().unwrap().2; + if !nodes[reduct.idx()].is_write() { + return None; + } + let mut users = editor + .get_users(reduce) + .filter(|id| reduce_cycles[&reduce].contains(id)); + let read = users.next()?; + if users.next().is_some() || !nodes[read.idx()].is_read() { + return None; + } + Some((read, reduct)) + }; +} -- GitLab From 6210141d55decb42df595ec87c45ce0f5d4f8f0e Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Tue, 11 Feb 2025 16:05:32 -0600 Subject: [PATCH 4/7] Parallel reduce slf --- hercules_opt/src/fork_transforms.rs | 34 +++++++++++++ hercules_opt/src/slf.rs | 34 ++++++++++++- hercules_opt/src/utils.rs | 62 ++++++++++++++++++++++-- juno_samples/fork_join_tests/src/gpu.sch | 5 ++ juno_scheduler/src/compile.rs | 3 +- juno_scheduler/src/ir.rs | 2 + juno_scheduler/src/pm.rs | 23 +++++++++ 7 files changed, 155 insertions(+), 8 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 33c4e68c..fd6747d7 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1121,10 +1121,44 @@ fn fork_interchange( else { panic!() }; + let fix_tids: Vec<(NodeID, Node)> = editor + .get_users(fork) + .filter_map(|id| { + nodes[id.idx()] + .try_thread_id() + .map(|(_, dim)| { + if dim == first_dim { + Some(( + id, + Node::ThreadID { + control: fork, + dimension: second_dim, + }, + )) + } else if dim == second_dim { + Some(( + id, + Node::ThreadID { + control: fork, + dimension: first_dim, + }, + )) + } else { + None + } + }) + .flatten() + }) + .collect(); let mut factors = factors.clone(); factors.swap(first_dim, second_dim); let new_fork = Node::Fork { control, factors }; editor.edit(|mut edit| { + for (old_id, new_tid) in fix_tids { + let new_id = edit.add_node(new_tid); + edit = edit.replace_all_uses(old_id, new_id)?; + edit = edit.delete_node(old_id)?; + } let new_fork = edit.add_node(new_fork); edit = edit.replace_all_uses(fork, new_fork)?; edit.delete_node(fork) diff --git a/hercules_opt/src/slf.rs b/hercules_opt/src/slf.rs index f8c29e7b..b10f58ce 100644 --- a/hercules_opt/src/slf.rs +++ b/hercules_opt/src/slf.rs @@ -1,4 +1,5 @@ use std::collections::{BTreeMap, HashMap, HashSet}; +use std::iter::once; use hercules_ir::*; @@ -284,11 +285,40 @@ pub fn reduce_slf( } let mut users = editor .get_users(reduce) - .filter(|id| reduce_cycles[&reduce].contains(id)); + .filter(|id| reduce_cycles[&reduce].contains(id) && *id != reduct); let read = users.next()?; - if users.next().is_some() || !nodes[read.idx()].is_read() { + if users.next().is_some() + || !nodes[read.idx()].is_read() + || nodes[read.idx()].try_read().unwrap().1 != nodes[reduct.idx()].try_write().unwrap().2 + { return None; } Some((read, reduct)) }; + + // First, optimize parallel reductions. 
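+    // If the read's indices are parallel over the fork (condition 4), the read
+    // can take the reduce's `init` input instead of the reduce itself.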
+ for (fork, join) in fork_join_map { + let reduces: Vec<_> = editor + .get_users(*join) + .filter(|id| editor.func().nodes[id.idx()].is_reduce()) + .collect(); + for reduce in reduces { + let Some((read, _)) = read_write_helper(reduce, editor) else { + continue; + }; + let indices = editor.func().nodes[read.idx()].try_read().unwrap().1; + if indices_parallel_over_forks(editor, indices, once(*fork)) { + let init = editor.func().nodes[reduce.idx()].try_reduce().unwrap().1; + let new_read = Node::Read { + collect: init, + indices: indices.to_vec().into_boxed_slice(), + }; + editor.edit(|mut edit| { + let new_read = edit.add_node(new_read); + edit = edit.replace_all_uses(read, new_read)?; + edit.delete_node(read) + }); + } + } + } } diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs index 022da629..1482a7f0 100644 --- a/hercules_opt/src/utils.rs +++ b/hercules_opt/src/utils.rs @@ -10,7 +10,7 @@ use crate::*; * Substitute all uses of dynamic constants in a type that are keys in the substs map with the * dynamic constant value for that key. Return the substituted version of the type, once memoized. */ -pub(crate) fn substitute_dynamic_constants_in_type( +pub fn substitute_dynamic_constants_in_type( substs: &HashMap<DynamicConstantID, DynamicConstantID>, ty: TypeID, edit: &mut FunctionEdit, @@ -62,7 +62,7 @@ pub(crate) fn substitute_dynamic_constants_in_type( * version of dc, once memoized. Takes a mutable edit instead of an editor since this may create * new dynamic constants, which can only be done inside an edit. */ -pub(crate) fn substitute_dynamic_constants( +pub fn substitute_dynamic_constants( substs: &HashMap<DynamicConstantID, DynamicConstantID>, dc: DynamicConstantID, edit: &mut FunctionEdit, @@ -154,7 +154,7 @@ pub(crate) fn substitute_dynamic_constants( * Substitute all uses of the dynamic constants specified by the subst map in a constant. Return * the substituted version of the constant, once memozied. */ -pub(crate) fn substitute_dynamic_constants_in_constant( +pub fn substitute_dynamic_constants_in_constant( substs: &HashMap<DynamicConstantID, DynamicConstantID>, cons: ConstantID, edit: &mut FunctionEdit, @@ -198,7 +198,7 @@ pub(crate) fn substitute_dynamic_constants_in_constant( /* * Substitute all uses of the dynamic constants specified by the subst map in a node. */ -pub(crate) fn substitute_dynamic_constants_in_node( +pub fn substitute_dynamic_constants_in_node( substs: &HashMap<DynamicConstantID, DynamicConstantID>, node: &mut Node, edit: &mut FunctionEdit, @@ -275,7 +275,7 @@ pub fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> { new_return } -pub(crate) fn contains_between_control_flow(func: &Function) -> bool { +pub fn contains_between_control_flow(func: &Function) -> bool { let num_control = func.nodes.iter().filter(|node| node.is_control()).count(); assert!(num_control >= 2, "PANIC: A Hercules function must have at least two control nodes: a start node and at least one return node."); num_control > 2 @@ -477,3 +477,55 @@ pub fn materialize_simple_einsum_expr( _ => panic!(), } } + +/* + * Checks if a set of indices is fully parallel over a set of forks - that is, + * every thread ID from every fork appears at least once in positions in the + * indices set. + */ +pub fn indices_parallel_over_forks<I>( + editor: &FunctionEditor, + indices: &[Index], + mut forks: I, +) -> bool +where + I: Iterator<Item = NodeID>, +{ + // Get the node ids in the indices set. 
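+    // Only position indices contain node IDs; other index kinds are skipped.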
+ let positions = indices + .iter() + .filter_map(|index| { + if let Index::Position(indices) = index { + Some(indices) + } else { + None + } + }) + .flat_map(|pos| pos.iter()); + + // Get the forks corresponding to uses of bare thread ids. + let nodes = &editor.func().nodes; + let fork_thread_id_pairs = positions.filter_map(|id| { + if let Node::ThreadID { control, dimension } = nodes[id.idx()] { + Some((control, dimension)) + } else { + None + } + }); + let mut rep_forks = HashMap::<NodeID, Vec<usize>>::new(); + for (fork, dim) in fork_thread_id_pairs { + rep_forks.entry(fork).or_default().push(dim); + } + + // If each fork the query is over is represented and each of its dimensions + // is represented, then the indices are parallel over the forks. + forks.all(|fork| { + let Some(mut rep_dims) = rep_forks.remove(&fork) else { + return false; + }; + + rep_dims.sort(); + rep_dims.dedup(); + nodes[fork.idx()].try_fork().unwrap().1.len() == rep_dims.len() + }) +} diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index 123313a0..8f4ec9ad 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -46,6 +46,11 @@ fixpoint panic after 20 { } fork-coalesce(auto.test2@loop2); +dce(auto.test2); +reduce-slf(auto.test2); +slf(auto.test2); +infer-schedules(auto.test2); +fork-interchange[0, 1](auto.test2); fork-tile[32, 0, false](test6@loop); let out = fork-split(test6@loop); diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 0ca36380..4ea8dfb5 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -8,7 +8,6 @@ use hercules_ir::ir::{Device, Schedule}; use lrlex::DefaultLexerTypes; use lrpar::NonStreamingLexer; -use lrpar::Span; use std::fmt; use std::str::FromStr; @@ -112,11 +111,13 @@ impl FromStr for Appliable { } "fork-fission-bufferize" => Ok(Appliable::Pass(ir::Pass::ForkFissionBufferize)), "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)), + "fork-interchange" => Ok(Appliable::Pass(ir::Pass::ForkInterchange)), "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)), "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), "outline" => Ok(Appliable::Pass(ir::Pass::Outline)), "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)), "predication" => Ok(Appliable::Pass(ir::Pass::Predication)), + "reduce-slf" => Ok(Appliable::Pass(ir::Pass::ReduceSLF)), "simplify-cfg" => Ok(Appliable::Pass(ir::Pass::SimplifyCFG)), "slf" | "store-load-forward" => Ok(Appliable::Pass(ir::Pass::SLF)), "sroa" => Ok(Appliable::Pass(ir::Pass::SROA)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 881c63a3..0ecac39a 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -26,6 +26,7 @@ pub enum Pass { Outline, PhiElim, Predication, + ReduceSLF, SLF, SROA, Serialize, @@ -42,6 +43,7 @@ impl Pass { Pass::Xdot => 1, Pass::ForkChunk => 3, Pass::ForkFissionBufferize => 2, + Pass::ForkInterchange => 2, _ => 0, } } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 8581ee23..e9c681cd 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1966,6 +1966,29 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } + Pass::ReduceSLF => { + assert!(args.is_empty()); + pm.make_fork_join_maps(); + pm.make_reduce_cycles(); + pm.make_nodes_in_fork_joins(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let reduce_cycles = pm.reduce_cycles.take().unwrap(); + let 
nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); + + for (((func, fork_join_map), reduce_cycles), nodes_in_fork_joins) in + build_selection(pm, selection, false) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(reduce_cycles.iter()) + .zip(nodes_in_fork_joins.iter()) + { + let Some(mut func) = func else { + continue; + }; + reduce_slf(&mut func, fork_join_map, reduce_cycles, nodes_in_fork_joins); + changed |= func.modified(); + } + } Pass::SLF => { assert!(args.is_empty()); pm.make_reverse_postorders(); -- GitLab From 33801197254c852ba8a0df65ca152f4977a97878 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Tue, 11 Feb 2025 16:28:59 -0600 Subject: [PATCH 5/7] Scalar reduce slf --- hercules_opt/src/slf.rs | 55 ++++++++++++++++++++++++++++++++++++++- hercules_opt/src/utils.rs | 33 +++++++++++++---------- 2 files changed, 73 insertions(+), 15 deletions(-) diff --git a/hercules_opt/src/slf.rs b/hercules_opt/src/slf.rs index b10f58ce..3133972f 100644 --- a/hercules_opt/src/slf.rs +++ b/hercules_opt/src/slf.rs @@ -277,10 +277,12 @@ pub fn reduce_slf( nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, ) { // Helper to get write use and read user of reduce with identical indices. + // This checks conditions 1, 2, and 3 of both parallel and scalar reduction + // optimization. let read_write_helper = |reduce: NodeID, editor: &FunctionEditor| -> Option<(NodeID, NodeID)> { let nodes = &editor.func().nodes; let reduct = nodes[reduce.idx()].try_reduce().unwrap().2; - if !nodes[reduct.idx()].is_write() { + if !nodes[reduct.idx()].is_write() || nodes[reduct.idx()].try_write().unwrap().0 != reduce { return None; } let mut users = editor @@ -306,6 +308,8 @@ pub fn reduce_slf( let Some((read, _)) = read_write_helper(reduce, editor) else { continue; }; + + // Check condition 4 of parallel reduction optimization. let indices = editor.func().nodes[read.idx()].try_read().unwrap().1; if indices_parallel_over_forks(editor, indices, once(*fork)) { let init = editor.func().nodes[reduce.idx()].try_reduce().unwrap().1; @@ -321,4 +325,53 @@ pub fn reduce_slf( } } } + + // Second, optimize scalar reductions. + for (fork, join) in fork_join_map { + let reduces: Vec<_> = editor + .get_users(*join) + .filter(|id| editor.func().nodes[id.idx()].is_reduce()) + .collect(); + for reduce in reduces { + let Some((read, write)) = read_write_helper(reduce, editor) else { + continue; + }; + + // Check condition 4 of scalar reduction optimization. 
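+            // The indices must not reference any node in the fork-join, so the
+            // same location is accessed on every iteration and the reduction
+            // can be carried in the read value itself.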
+ let indices = editor.func().nodes[read.idx()].try_read().unwrap().1; + if node_indices(indices).all(|id| !nodes_in_fork_joins[fork].contains(&id)) { + let indices = indices.to_vec().into_boxed_slice(); + let init = editor.func().nodes[reduce.idx()].try_reduce().unwrap().1; + let data = editor.func().nodes[write.idx()].try_write().unwrap().1; + let init_read = Node::Read { + collect: init, + indices: indices.clone(), + }; + editor.edit(|mut edit| { + let init_read = edit.add_node(init_read); + let new_reduce = Node::Reduce { + control: *join, + init: init_read, + reduct: data, + }; + let new_reduce = edit.add_node(new_reduce); + let post_write = Node::Write { + collect: init, + data: new_reduce, + indices, + }; + let post_write = edit.add_node(post_write); + + edit = edit.replace_all_uses(read, new_reduce)?; + edit = edit.replace_all_uses(reduce, post_write)?; + + edit = edit.delete_node(read)?; + edit = edit.delete_node(reduce)?; + edit = edit.delete_node(write)?; + + Ok(edit) + }); + } + } + } } diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs index 1482a7f0..c5e0d934 100644 --- a/hercules_opt/src/utils.rs +++ b/hercules_opt/src/utils.rs @@ -478,6 +478,23 @@ pub fn materialize_simple_einsum_expr( } } +/* + * Get the node IDs referred to in position indices in a indices set. + */ +pub fn node_indices(indices: &[Index]) -> impl Iterator<Item = NodeID> + '_ { + indices + .iter() + .filter_map(|index| { + if let Index::Position(indices) = index { + Some(indices) + } else { + None + } + }) + .flat_map(|pos| pos.iter()) + .map(|id| *id) +} + /* * Checks if a set of indices is fully parallel over a set of forks - that is, * every thread ID from every fork appears at least once in positions in the @@ -491,21 +508,9 @@ pub fn indices_parallel_over_forks<I>( where I: Iterator<Item = NodeID>, { - // Get the node ids in the indices set. - let positions = indices - .iter() - .filter_map(|index| { - if let Index::Position(indices) = index { - Some(indices) - } else { - None - } - }) - .flat_map(|pos| pos.iter()); - - // Get the forks corresponding to uses of bare thread ids. + // Get the forks corresponding to position uses of bare thread ids. let nodes = &editor.func().nodes; - let fork_thread_id_pairs = positions.filter_map(|id| { + let fork_thread_id_pairs = node_indices(indices).filter_map(|id| { if let Node::ThreadID { control, dimension } = nodes[id.idx()] { Some((control, dimension)) } else { -- GitLab From cc5f607f047ffb049e3c1eef36e0f75bd03e0e4d Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Tue, 11 Feb 2025 16:43:11 -0600 Subject: [PATCH 6/7] add extra case to parallel reduce slf --- hercules_opt/src/slf.rs | 44 +++++++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/hercules_opt/src/slf.rs b/hercules_opt/src/slf.rs index 3133972f..aeacec76 100644 --- a/hercules_opt/src/slf.rs +++ b/hercules_opt/src/slf.rs @@ -255,6 +255,10 @@ pub fn array_slf( * replaces the use of the reduce in the read with the `init` input of the * reduce, nominally so that ParallelReduce can get inferred. * + * If a reduction has the ParallelReduce schedule on it, then any read of it + * directly can be much more easily optimized to refer to the `init` input + * rather than the reduce itself. 
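+ * This case doesn't require the read/write pattern above - the schedule alone
+ * already guarantees the reduction is parallel.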
+ * * The scalar variant looks for reductions on collections with the following * form: * @@ -305,22 +309,32 @@ pub fn reduce_slf( .filter(|id| editor.func().nodes[id.idx()].is_reduce()) .collect(); for reduce in reduces { - let Some((read, _)) = read_write_helper(reduce, editor) else { - continue; - }; - - // Check condition 4 of parallel reduction optimization. - let indices = editor.func().nodes[read.idx()].try_read().unwrap().1; - if indices_parallel_over_forks(editor, indices, once(*fork)) { + if let Some((read, _)) = read_write_helper(reduce, editor) { + // Check condition 4 of parallel reduction optimization. + let indices = editor.func().nodes[read.idx()].try_read().unwrap().1; + if indices_parallel_over_forks(editor, indices, once(*fork)) { + let init = editor.func().nodes[reduce.idx()].try_reduce().unwrap().1; + let new_read = Node::Read { + collect: init, + indices: indices.to_vec().into_boxed_slice(), + }; + editor.edit(|mut edit| { + let new_read = edit.add_node(new_read); + edit = edit.replace_all_uses(read, new_read)?; + edit.delete_node(read) + }); + } + } else if editor.func().schedules[reduce.idx()].contains(&Schedule::ParallelReduce) { let init = editor.func().nodes[reduce.idx()].try_reduce().unwrap().1; - let new_read = Node::Read { - collect: init, - indices: indices.to_vec().into_boxed_slice(), - }; - editor.edit(|mut edit| { - let new_read = edit.add_node(new_read); - edit = edit.replace_all_uses(read, new_read)?; - edit.delete_node(read) + let read_users: Vec<_> = editor + .get_users(reduce) + .filter(|id| { + editor.func().nodes[id.idx()].is_read() + && nodes_in_fork_joins[fork].contains(&id) + }) + .collect(); + editor.edit(|edit| { + edit.replace_all_uses_where(reduce, init, |id| read_users.contains(id)) }); } } -- GitLab From ea6e12b43761fde64d31aff66c8ca22aa0a45dfc Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Tue, 11 Feb 2025 16:46:55 -0600 Subject: [PATCH 7/7] transform juno matmul sample to hercules matmul sample --- hercules_opt/src/slf.rs | 1 + juno_samples/matmul/src/gpu.sch | 16 +++++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/hercules_opt/src/slf.rs b/hercules_opt/src/slf.rs index aeacec76..1d70676a 100644 --- a/hercules_opt/src/slf.rs +++ b/hercules_opt/src/slf.rs @@ -112,6 +112,7 @@ pub fn slf(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, typing: value } + Node::Constant { id: _ } | Node::Undef { ty: _ } => SLFLattice::top(), _ => SLFLattice::bottom(), } }); diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch index 3d3f919c..e6eb3641 100644 --- a/juno_samples/matmul/src/gpu.sch +++ b/juno_samples/matmul/src/gpu.sch @@ -12,7 +12,21 @@ gvn(*); phi-elim(*); dce(*); -infer-schedules(*); +fixpoint panic after 20 { + forkify(out.matmul); + fork-guard-elim(out.matmul); +} +gvn(out.matmul); +phi-elim(out.matmul); +dce(out.matmul); + +fixpoint panic after 20 { + reduce-slf(out.matmul); + slf(out.matmul); + infer-schedules(out.matmul); +} +fork-coalesce(out.matmul); +dce(out.matmul); gcm(*); float-collections(*); -- GitLab