diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 7c4238921a6d7c96dc4645f4fdb123aa1bf1a594..fd6747d7a15628034bc98f7ef3ebc00631f8abae 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1,5 +1,4 @@
 use std::collections::{HashMap, HashSet};
-use std::hash::Hash;
 use std::iter::zip;
 
 use bimap::BiMap;
@@ -1076,3 +1075,92 @@ pub fn fork_dim_merge(
 
     new_fork
 }
+
+/*
+ * Run fork interchange on all fork-joins that are mutable in an editor.
+ */
+pub fn fork_interchange_all_forks(
+    editor: &mut FunctionEditor,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    first_dim: usize,
+    second_dim: usize,
+) {
+    for (fork, join) in fork_join_map {
+        if editor.is_mutable(*fork) {
+            fork_interchange(editor, *fork, *join, first_dim, second_dim);
+        }
+    }
+}
+
+fn fork_interchange(
+    editor: &mut FunctionEditor,
+    fork: NodeID,
+    join: NodeID,
+    first_dim: usize,
+    second_dim: usize,
+) {
+    // Check that every reduce on the join has the ParallelReduce or
+    // TightAssociative schedule.
+    let nodes = &editor.func().nodes;
+    let schedules = &editor.func().schedules;
+    if !editor
+        .get_users(join)
+        .filter(|id| nodes[id.idx()].is_reduce())
+        .all(|id| {
+            schedules[id.idx()].contains(&Schedule::ParallelReduce)
+                || schedules[id.idx()].contains(&Schedule::TightAssociative)
+        })
+    {
+        // If not, we can't necessarily do interchange.
+        return;
+    }
+
+    let Node::Fork {
+        control,
+        ref factors,
+    } = nodes[fork.idx()]
+    else {
+        panic!()
+    };
+    let fix_tids: Vec<(NodeID, Node)> = editor
+        .get_users(fork)
+        .filter_map(|id| {
+            nodes[id.idx()]
+                .try_thread_id()
+                .map(|(_, dim)| {
+                    if dim == first_dim {
+                        Some((
+                            id,
+                            Node::ThreadID {
+                                control: fork,
+                                dimension: second_dim,
+                            },
+                        ))
+                    } else if dim == second_dim {
+                        Some((
+                            id,
+                            Node::ThreadID {
+                                control: fork,
+                                dimension: first_dim,
+                            },
+                        ))
+                    } else {
+                        None
+                    }
+                })
+                .flatten()
+        })
+        .collect();
+    let mut factors = factors.clone();
+    factors.swap(first_dim, second_dim);
+    let new_fork = Node::Fork { control, factors };
+    editor.edit(|mut edit| {
+        for (old_id, new_tid) in fix_tids {
+            let new_id = edit.add_node(new_tid);
+            edit = edit.replace_all_uses(old_id, new_id)?;
+            edit = edit.delete_node(old_id)?;
+        }
+        let new_fork = edit.add_node(new_fork);
+        edit = edit.replace_all_uses(fork, new_fork)?;
+        edit.delete_node(fork)
+    });
+}
diff --git a/hercules_opt/src/slf.rs b/hercules_opt/src/slf.rs
index 024a442e006d23a9b9bc445a702734b828998177..1d70676ae5d42ec9cdce17163e3417b1daedecfc 100644
--- a/hercules_opt/src/slf.rs
+++ b/hercules_opt/src/slf.rs
@@ -1,4 +1,5 @@
 use std::collections::{BTreeMap, HashMap, HashSet};
+use std::iter::once;
 
 use hercules_ir::*;
@@ -111,6 +112,7 @@ pub fn slf(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, typing:
                 value
             }
+            Node::Constant { id: _ } | Node::Undef { ty: _ } => SLFLattice::top(),
             _ => SLFLattice::bottom(),
         }
     });
@@ -229,3 +231,162 @@ pub fn array_slf(
         }
     }
 }
+
+/*
+ * Top level function to run reduce store-to-load forwarding on a function.
+ * There are two variants of reduce SLF. One optimizes parallel reductions and
+ * the other optimizes scalar reductions. This pass just runs one after the
+ * other - it's up to the user to potentially wrap this pass in a fixpoint.
+ *
+ * The parallel variant looks for reductions on collections with the following
+ * form:
+ *
+ * 1. A write `reduct` use.
+ * 2. A single read user in the reduction cycle.
+ * 3. The write use and read user have identical indices.
+ * 4. The indices set directly refers to at least every thread ID produced by
+ *    the fork corresponding to the reduce.
+ *
+ * Condition #4 roughly corresponds to the same condition needed to infer the
+ * ParallelReduce schedule - however, in this scenario, it can't be inferred,
+ * because the data being written is in the reduce cycle (it is derived from
+ * the read user of the reduce). However, since the indices change each
+ * iteration, the read could equivalently read from the `init` input of the
+ * reduce, rather than the reduce itself. This optimization replaces the use
+ * of the reduce in the read with the `init` input of the reduce, nominally so
+ * that ParallelReduce can get inferred.
+ *
+ * If a reduction has the ParallelReduce schedule on it, then any read of it
+ * directly can be much more easily optimized to refer to the `init` input
+ * rather than the reduce itself.
+ *
+ * The scalar variant looks for reductions on collections with the following
+ * form:
+ *
+ * 1. A write `reduct` use.
+ * 2. A single read user in the reduction cycle.
+ * 3. The write use and read user have identical indices.
+ * 4. The indices set doesn't reference any nodes in the fork-join of the
+ *    reduce (doesn't depend on the thread IDs of the fork-join).
+ *
+ * Instead of repeatedly reading / writing the same collection item each
+ * iteration, the reduction can reduce over the scalar value directly, and do
+ * a single write into the collection after the fork-join (note that,
+ * technically, the "scalar" may itself be a collection in the case of nested
+ * collections, but this optimization most often handles scalars).
+ */
+pub fn reduce_slf(
+    editor: &mut FunctionEditor,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+) {
+    // Helper to get the write use and read user of a reduce with identical
+    // indices. This checks conditions 1, 2, and 3 of both the parallel and
+    // scalar reduction optimizations.
+    let read_write_helper = |reduce: NodeID, editor: &FunctionEditor| -> Option<(NodeID, NodeID)> {
+        let nodes = &editor.func().nodes;
+        let reduct = nodes[reduce.idx()].try_reduce().unwrap().2;
+        if !nodes[reduct.idx()].is_write() || nodes[reduct.idx()].try_write().unwrap().0 != reduce {
+            return None;
+        }
+        let mut users = editor
+            .get_users(reduce)
+            .filter(|id| reduce_cycles[&reduce].contains(id) && *id != reduct);
+        let read = users.next()?;
+        if users.next().is_some()
+            || !nodes[read.idx()].is_read()
+            || nodes[read.idx()].try_read().unwrap().1 != nodes[reduct.idx()].try_write().unwrap().2
+        {
+            return None;
+        }
+        Some((read, reduct))
+    };
+
+    // First, optimize parallel reductions.
+    for (fork, join) in fork_join_map {
+        let reduces: Vec<_> = editor
+            .get_users(*join)
+            .filter(|id| editor.func().nodes[id.idx()].is_reduce())
+            .collect();
+        for reduce in reduces {
+            if let Some((read, _)) = read_write_helper(reduce, editor) {
+                // Check condition 4 of the parallel reduction optimization.
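+                // E.g., for a two-dimensional fork producing thread IDs `i`
+                // and `j`, a read/write pair at indices [i, j] is parallel
+                // over the fork, while a pair at indices [i, 0] is not, since
+                // dimension 1 never appears in the indices.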
+                let indices = editor.func().nodes[read.idx()].try_read().unwrap().1;
+                if indices_parallel_over_forks(editor, indices, once(*fork)) {
+                    let init = editor.func().nodes[reduce.idx()].try_reduce().unwrap().1;
+                    let new_read = Node::Read {
+                        collect: init,
+                        indices: indices.to_vec().into_boxed_slice(),
+                    };
+                    editor.edit(|mut edit| {
+                        let new_read = edit.add_node(new_read);
+                        edit = edit.replace_all_uses(read, new_read)?;
+                        edit.delete_node(read)
+                    });
+                }
+            } else if editor.func().schedules[reduce.idx()].contains(&Schedule::ParallelReduce) {
+                let init = editor.func().nodes[reduce.idx()].try_reduce().unwrap().1;
+                let read_users: Vec<_> = editor
+                    .get_users(reduce)
+                    .filter(|id| {
+                        editor.func().nodes[id.idx()].is_read()
+                            && nodes_in_fork_joins[fork].contains(&id)
+                    })
+                    .collect();
+                editor.edit(|edit| {
+                    edit.replace_all_uses_where(reduce, init, |id| read_users.contains(id))
+                });
+            }
+        }
+    }
+
+    // Second, optimize scalar reductions.
+    for (fork, join) in fork_join_map {
+        let reduces: Vec<_> = editor
+            .get_users(*join)
+            .filter(|id| editor.func().nodes[id.idx()].is_reduce())
+            .collect();
+        for reduce in reduces {
+            let Some((read, write)) = read_write_helper(reduce, editor) else {
+                continue;
+            };
+
+            // Check condition 4 of the scalar reduction optimization.
+            let indices = editor.func().nodes[read.idx()].try_read().unwrap().1;
+            if node_indices(indices).all(|id| !nodes_in_fork_joins[fork].contains(&id)) {
+                let indices = indices.to_vec().into_boxed_slice();
+                let init = editor.func().nodes[reduce.idx()].try_reduce().unwrap().1;
+                let data = editor.func().nodes[write.idx()].try_write().unwrap().1;
+                let init_read = Node::Read {
+                    collect: init,
+                    indices: indices.clone(),
+                };
+                editor.edit(|mut edit| {
+                    let init_read = edit.add_node(init_read);
+                    let new_reduce = Node::Reduce {
+                        control: *join,
+                        init: init_read,
+                        reduct: data,
+                    };
+                    let new_reduce = edit.add_node(new_reduce);
+                    let post_write = Node::Write {
+                        collect: init,
+                        data: new_reduce,
+                        indices,
+                    };
+                    let post_write = edit.add_node(post_write);
+
+                    edit = edit.replace_all_uses(read, new_reduce)?;
+                    edit = edit.replace_all_uses(reduce, post_write)?;
+
+                    edit = edit.delete_node(read)?;
+                    edit = edit.delete_node(reduce)?;
+                    edit = edit.delete_node(write)?;
+
+                    Ok(edit)
+                });
+            }
+        }
+    }
+}
diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs
index 022da629cf7267316af1c4a44007f8ed424c5e31..c5e0d934fb2b4673890e3e7f1d209db2e028d513 100644
--- a/hercules_opt/src/utils.rs
+++ b/hercules_opt/src/utils.rs
@@ -10,7 +10,7 @@ use crate::*;
  * Substitute all uses of dynamic constants in a type that are keys in the substs map with the
  * dynamic constant value for that key. Return the substituted version of the type, once memoized.
  */
-pub(crate) fn substitute_dynamic_constants_in_type(
+pub fn substitute_dynamic_constants_in_type(
     substs: &HashMap<DynamicConstantID, DynamicConstantID>,
     ty: TypeID,
     edit: &mut FunctionEdit,
@@ -62,7 +62,7 @@ pub(crate) fn substitute_dynamic_constants_in_type(
  * version of dc, once memoized. Takes a mutable edit instead of an editor since this may create
  * new dynamic constants, which can only be done inside an edit.
  */
-pub(crate) fn substitute_dynamic_constants(
+pub fn substitute_dynamic_constants(
     substs: &HashMap<DynamicConstantID, DynamicConstantID>,
     dc: DynamicConstantID,
     edit: &mut FunctionEdit,
@@ -154,7 +154,7 @@ pub(crate) fn substitute_dynamic_constants(
  * Substitute all uses of the dynamic constants specified by the subst map in a constant. Return
  * the substituted version of the constant, once memoized.
  */
-pub(crate) fn substitute_dynamic_constants_in_constant(
+pub fn substitute_dynamic_constants_in_constant(
     substs: &HashMap<DynamicConstantID, DynamicConstantID>,
     cons: ConstantID,
     edit: &mut FunctionEdit,
@@ -198,7 +198,7 @@ pub(crate) fn substitute_dynamic_constants_in_constant(
 /*
  * Substitute all uses of the dynamic constants specified by the subst map in a node.
  */
-pub(crate) fn substitute_dynamic_constants_in_node(
+pub fn substitute_dynamic_constants_in_node(
     substs: &HashMap<DynamicConstantID, DynamicConstantID>,
     node: &mut Node,
     edit: &mut FunctionEdit,
@@ -275,7 +275,7 @@ pub fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> {
     new_return
 }
 
-pub(crate) fn contains_between_control_flow(func: &Function) -> bool {
+pub fn contains_between_control_flow(func: &Function) -> bool {
     let num_control = func.nodes.iter().filter(|node| node.is_control()).count();
     assert!(num_control >= 2, "PANIC: A Hercules function must have at least two control nodes: a start node and at least one return node.");
     num_control > 2
@@ -477,3 +477,60 @@ pub fn materialize_simple_einsum_expr(
         _ => panic!(),
     }
 }
+
+/*
+ * Get the node IDs referred to in position indices in an indices set.
+ */
+pub fn node_indices(indices: &[Index]) -> impl Iterator<Item = NodeID> + '_ {
+    indices
+        .iter()
+        .filter_map(|index| {
+            if let Index::Position(indices) = index {
+                Some(indices)
+            } else {
+                None
+            }
+        })
+        .flat_map(|pos| pos.iter())
+        .map(|id| *id)
+}
+
+/*
+ * Checks if a set of indices is fully parallel over a set of forks - that is,
+ * every thread ID from every fork appears at least once in positions in the
+ * indices set.
+ */
+pub fn indices_parallel_over_forks<I>(
+    editor: &FunctionEditor,
+    indices: &[Index],
+    mut forks: I,
+) -> bool
+where
+    I: Iterator<Item = NodeID>,
+{
+    // Get the forks corresponding to position uses of bare thread IDs.
+    let nodes = &editor.func().nodes;
+    let fork_thread_id_pairs = node_indices(indices).filter_map(|id| {
+        if let Node::ThreadID { control, dimension } = nodes[id.idx()] {
+            Some((control, dimension))
+        } else {
+            None
+        }
+    });
+    let mut rep_forks = HashMap::<NodeID, Vec<usize>>::new();
+    for (fork, dim) in fork_thread_id_pairs {
+        rep_forks.entry(fork).or_default().push(dim);
+    }
+
+    // If each fork the query is over is represented and each of its dimensions
+    // is represented, then the indices are parallel over the forks.
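+    // E.g., a single 2D fork queried against indices that use thread IDs for
+    // dimensions 0 and 1 yields rep_forks = {fork: [0, 1]}, covering both
+    // dimensions, so the check below passes; indices that only use the
+    // dimension 0 thread ID leave dimension 1 unrepresented, so it fails.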
+    forks.all(|fork| {
+        let Some(mut rep_dims) = rep_forks.remove(&fork) else {
+            return false;
+        };
+
+        rep_dims.sort();
+        rep_dims.dedup();
+        nodes[fork.idx()].try_fork().unwrap().1.len() == rep_dims.len()
+    })
+}
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index 38a38c2bacfe480710b128ff44a24abfd0059df6..fe0a8802e8b68f6e21cc8fe3586a03f0ce658fa5 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -50,7 +50,7 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
-fork-tile[32, 0, true](test6@loop);
+fork-tile[32, 0, false](test6@loop);
 let out = fork-split(test6@loop);
 let out = outline(out.test6.fj1);
 cpu(out);
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index 886ab13b238402e10ad1480b853bf1a237dd41b4..8f569cfbe75105054e60c5d9bbb84b8ebbac6041 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -10,11 +10,11 @@ fn test1(input : i32) -> i32[4, 4] {
 }
 
 #[entry]
-fn test2(input : i32) -> i32[4, 4] {
-  let arr : i32[4, 4];
-  for i = 0 to 8 {
-    for j = 0 to 4 {
-      for k = 0 to 4 {
+fn test2(input : i32) -> i32[4, 5] {
+  let arr : i32[4, 5];
+  @loop1 for i = 0 to 8 {
+    @loop2 for k = 0 to 5 {
+      @loop3 for j = 0 to 4 {
         arr[j, k] += input;
       }
     }
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index 117cf37e2e2ac0c5ebf6a908253dc85b0e31029f..8f4ec9ad06b2525e2e872f7c629880a7c8b7f213 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -31,7 +31,10 @@ dce(*);
 fixpoint panic after 20 {
   forkify(*);
   fork-guard-elim(*);
-  fork-coalesce(*);
+}
+
+fixpoint panic after 20 {
+  fork-coalesce(auto.test1, auto.test3, auto.test4, auto.test5, auto.test7, auto.test8);
 }
 
 gvn(*);
@@ -42,7 +45,14 @@ fixpoint panic after 20 {
   infer-schedules(*);
 }
 
-fork-tile[32, 0, true](test6@loop);
+fork-coalesce(auto.test2@loop2);
+dce(auto.test2);
+reduce-slf(auto.test2);
+slf(auto.test2);
+infer-schedules(auto.test2);
+fork-interchange[0, 1](auto.test2);
+
+fork-tile[32, 0, false](test6@loop);
 let out = fork-split(test6@loop);
 let out = auto-outline(test6);
 gpu(out.test6);
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
index 1013a1f028b9d204a442be21bc0a1966d025d25e..21ccd7c49b5627c4e07839cc7583c101e85c5013 100644
--- a/juno_samples/fork_join_tests/src/main.rs
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -25,7 +25,7 @@ fn main() {
 
         let mut r = runner!(test2);
         let output = r.run(3).await;
-        let correct = vec![24i32; 16];
+        let correct = vec![24i32; 20];
         assert(&correct, output);
 
         let mut r = runner!(test3);
diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch
index 3d3f919cd26e4eba480540df06f839a8b86976b0..e6eb3641fef1496e0b82534f1bc449b1027374a6 100644
--- a/juno_samples/matmul/src/gpu.sch
+++ b/juno_samples/matmul/src/gpu.sch
@@ -12,7 +12,21 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
-infer-schedules(*);
+fixpoint panic after 20 {
+  forkify(out.matmul);
+  fork-guard-elim(out.matmul);
+}
+gvn(out.matmul);
+phi-elim(out.matmul);
+dce(out.matmul);
+
+fixpoint panic after 20 {
+  reduce-slf(out.matmul);
+  slf(out.matmul);
+  infer-schedules(out.matmul);
+}
+fork-coalesce(out.matmul);
+dce(out.matmul);
 
 gcm(*);
 float-collections(*);
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 0ca3638010bc1a026b44d56d5d49696633f7d9fa..4ea8dfb50a4fc6e6fadcb0d7b40c2a83d7223ef3 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -8,7 +8,6 @@ use hercules_ir::ir::{Device, Schedule};
 
 use lrlex::DefaultLexerTypes;
 use lrpar::NonStreamingLexer;
-use lrpar::Span;
 
 use std::fmt;
 use std::str::FromStr;
@@ -112,11 +111,13 @@ impl FromStr for Appliable {
             }
             "fork-fission-bufferize" => Ok(Appliable::Pass(ir::Pass::ForkFissionBufferize)),
             "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)),
+            "fork-interchange" => Ok(Appliable::Pass(ir::Pass::ForkInterchange)),
             "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
             "outline" => Ok(Appliable::Pass(ir::Pass::Outline)),
             "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)),
             "predication" => Ok(Appliable::Pass(ir::Pass::Predication)),
+            "reduce-slf" => Ok(Appliable::Pass(ir::Pass::ReduceSLF)),
             "simplify-cfg" => Ok(Appliable::Pass(ir::Pass::SimplifyCFG)),
             "slf" | "store-load-forward" => Ok(Appliable::Pass(ir::Pass::SLF)),
             "sroa" => Ok(Appliable::Pass(ir::Pass::SROA)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 0ec6f52cd678c6e42324f77ec2f0799deaca1535..0ecac39a19a5e364c0ab1185b162c0470f5b6a5b 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -9,13 +9,14 @@ pub enum Pass {
     DCE,
     DeleteUncalled,
     FloatCollections,
+    ForkChunk,
+    ForkCoalesce,
+    ForkDimMerge,
+    ForkFissionBufferize,
     ForkGuardElim,
+    ForkInterchange,
     ForkSplit,
-    ForkCoalesce,
     Forkify,
-    ForkFissionBufferize,
-    ForkDimMerge,
-    ForkChunk,
     GCM,
     GVN,
     InferSchedules,
@@ -25,14 +26,15 @@ pub enum Pass {
     Outline,
     PhiElim,
     Predication,
-    SimplifyCFG,
+    ReduceSLF,
     SLF,
     SROA,
+    Serialize,
+    SimplifyCFG,
     Unforkify,
-    WritePredication,
     Verify,
+    WritePredication,
     Xdot,
-    Serialize,
 }
 
 impl Pass {
@@ -41,6 +43,7 @@ impl Pass {
             Pass::Xdot => 1,
             Pass::ForkChunk => 3,
             Pass::ForkFissionBufferize => 2,
+            Pass::ForkInterchange => 2,
             _ => 0,
         }
     }
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 20825c549825d7abac60abb82fa508dd9f3ae803..e9c681cd1239fca6cc1946c55032837eff7c2f00 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1635,6 +1635,36 @@ fn run_pass(
                 fields: new_fork_joins,
             };
         }
+        Pass::ForkInterchange => {
+            assert_eq!(args.len(), 2);
+            let Some(Value::Integer { val: first_dim }) = args.get(0) else {
+                return Err(SchedulerError::PassError {
+                    pass: "forkInterchange".to_string(),
+                    error: "expected integer argument".to_string(),
+                });
+            };
+            let Some(Value::Integer { val: second_dim }) = args.get(1) else {
+                return Err(SchedulerError::PassError {
+                    pass: "forkInterchange".to_string(),
+                    error: "expected integer argument".to_string(),
+                });
+            };
+
+            pm.make_fork_join_maps();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            for (func, fork_join_map) in build_selection(pm, selection, false)
+                .into_iter()
+                .zip(fork_join_maps.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                fork_interchange_all_forks(&mut func, fork_join_map, *first_dim, *second_dim);
+                changed |= func.modified();
+            }
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::Forkify => {
             assert!(args.is_empty());
             loop {
@@ -1936,6 +1966,29 @@ fn run_pass(
             pm.delete_gravestones();
             pm.clear_analyses();
         }
+        Pass::ReduceSLF => {
+            assert!(args.is_empty());
+            pm.make_fork_join_maps();
+            pm.make_reduce_cycles();
+            pm.make_nodes_in_fork_joins();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            let reduce_cycles = pm.reduce_cycles.take().unwrap();
+            let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+
+            for (((func, fork_join_map), reduce_cycles), nodes_in_fork_joins) in
+                build_selection(pm, selection, false)
+                    .into_iter()
+                    .zip(fork_join_maps.iter())
+                    .zip(reduce_cycles.iter())
+                    .zip(nodes_in_fork_joins.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                reduce_slf(&mut func, fork_join_map, reduce_cycles, nodes_in_fork_joins);
+                changed |= func.modified();
+            }
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::SLF => {
             assert!(args.is_empty());
             pm.make_reverse_postorders();
@@ -2042,7 +2095,7 @@ fn run_pass(
                 });
             };
 
-            assert_eq!(*guarded_flag, true);
+            assert!(!*guarded_flag);
             pm.make_fork_join_maps();
             let fork_join_maps = pm.fork_join_maps.take().unwrap();
             for (func, fork_join_map) in build_selection(pm, selection, false)
@@ -2059,8 +2112,6 @@ fn run_pass(
             pm.clear_analyses();
         }
         Pass::ForkFissionBufferize => {
-            pm.make_fork_join_maps();
-
             assert_eq!(args.len(), 2);
             let Some(Value::Label {
                 labels: fork_labels,