diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 8bd3f735fa60cf4e0404edde28eb858e0ed7e581..1c220b9930ba8c251e5ae87723d983a16a6365c9 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -306,40 +306,43 @@ where pub fn fork_fission<'a>( editor: &'a mut FunctionEditor, - _control_subgraph: &Subgraph, - _types: &Vec<TypeID>, - _loop_tree: &LoopTree, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, + reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, + loop_tree: &LoopTree, fork_join_map: &HashMap<NodeID, NodeID>, -) -> () { - let forks: Vec<_> = editor - .func() - .nodes - .iter() - .enumerate() - .filter_map(|(idx, node)| { - if node.is_fork() { - Some(NodeID::new(idx)) - } else { - None - } - }) + fork_label: LabelID, +) -> Vec<NodeID> { + let forks: Vec<_> = loop_tree + .bottom_up_loops() + .into_iter() + .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork()) .collect(); - let control_pred = NodeID::new(0); + let mut created_forks = Vec::new(); + + // This does the reduction fission + for fork in forks { + let join = fork_join_map[&fork.0]; - // This does the reduction fission: - for fork in forks.clone() { - // FIXME: If there is control in between fork and join, don't just give up. - let join = fork_join_map[&fork]; - let join_pred = editor.func().nodes[join.idx()].try_join().unwrap(); - if join_pred != fork { - todo!("Can't do fork fission on nodes with internal control") - // Inner control LOOPs are hard - // inner control in general *should* work right now without modifications. + // FIXME: Don't make multiple forks for reduces that are in cycles with each other. 
+ let reduce_partition = default_reduce_partition(editor, fork.0, join); + + if !editor.func().labels[fork.0.idx()].contains(&fork_label) { + continue; } - let reduce_partition = default_reduce_partition(editor, fork, join); - fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, control_pred, fork); + + if editor.is_mutable(fork.0) { + created_forks = fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, nodes_in_fork_joins, fork.0); + if created_forks.is_empty() { + continue; + } else { + return created_forks; + } + } + } + + created_forks } /** Split a 1D fork into two forks, placing select intermediate data into buffers. */ @@ -488,48 +491,38 @@ where } } -/** Split a 1D fork into a separate fork for each reduction. */ +/** Split a fork into a separate fork for each reduction. */ pub fn fork_reduce_fission_helper<'a>( editor: &'a mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>, reduce_partition: SparseNodeMap<ForkID>, // Describes how the reduces of the fork should be split, - original_control_pred: NodeID, // What the new fork connects to. - + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, fork: NodeID, -) -> (NodeID, NodeID) { +) -> Vec<NodeID> { let join = fork_join_map[&fork]; - let mut new_control_pred: NodeID = original_control_pred; - // Important edges are: Reduces, + let mut new_forks = Vec::new(); - // NOTE: - // Say two reduce are in a fork, s.t reduce A depends on reduce B - // If user wants A and B in separate forks: - // - we can simply refuse - // - or we can duplicate B + let mut new_control_pred: NodeID = editor.get_uses(fork).filter(|n| editor.node(n).is_control()).next().unwrap(); let mut new_fork = NodeID::new(0); let mut new_join = NodeID::new(0); + let subgraph = &nodes_in_fork_joins[&fork]; + // Gets everything between fork & join that this reduce needs. 
(ALL CONTROL) - for reduce in reduce_partition { - let reduce = reduce.0; - - let function = editor.func(); - let subgraph = find_reduce_dependencies(function, reduce, fork); - - let mut subgraph: HashSet<NodeID> = subgraph.into_iter().collect(); - - subgraph.insert(join); - subgraph.insert(fork); - subgraph.insert(reduce); - - let (_, mapping, _) = copy_subgraph(editor, subgraph); + editor.edit(|mut edit| { + for reduce in reduce_partition { + let reduce = reduce.0; - new_fork = mapping[&fork]; - new_join = mapping[&join]; + let a = copy_subgraph_in_edit(edit, subgraph.clone())?; + edit = a.0; + let mapping = a.1; - editor.edit(|mut edit| { + new_fork = mapping[&fork]; + new_forks.push(new_fork); + new_join = mapping[&join]; + // Atttach new_fork after control_pred let (old_control_pred, _) = edit.get_node(new_fork).try_fork().unwrap().clone(); edit = edit.replace_all_uses_where(old_control_pred, new_control_pred, |usee| { @@ -538,13 +531,9 @@ pub fn fork_reduce_fission_helper<'a>( // Replace uses of reduce edit = edit.replace_all_uses(reduce, mapping[&reduce])?; - Ok(edit) - }); - - new_control_pred = new_join; - } + new_control_pred = new_join; + }; - editor.edit(|mut edit| { // Replace original join w/ new final join edit = edit.replace_all_uses_where(join, new_join, |_| true)?; @@ -553,10 +542,12 @@ pub fn fork_reduce_fission_helper<'a>( // Replace all users of original fork, and then delete it, leftover users will be DCE'd. 
edit = edit.replace_all_uses(fork, new_fork)?; - edit.delete_node(fork) + edit = edit.delete_node(fork)?; + + Ok(edit) }); - (new_fork, new_join) + new_forks } pub fn fork_coalesce( diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index 5f3ff94e9a995b1ff2f47cb4fcb4d499896189f0..ff30b277499d5d672e45a59a21317807fba64a8a 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -3,7 +3,7 @@ gvn(*); phi-elim(*); dce(*); -let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9, test10); +let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9, test10, test11); cpu(auto.test1); cpu(auto.test2); cpu(auto.test3); @@ -13,6 +13,7 @@ cpu(auto.test7); cpu(auto.test8); cpu(auto.test9); cpu(auto.test10); +cpu(auto.test11); let test1_cpu = auto.test1; rename["test1_cpu"](test1_cpu); @@ -95,11 +96,20 @@ dce(auto.test8); simplify-cfg(auto.test8); dce(auto.test8); -array-slf(auto.test10); -ccp(auto.test10); +fork-split(auto.test10); +let fission_out = fork-fission-reduces[test10@loop3](auto.test10); +fork-fusion(fission_out.test10_8.fj0, fission_out.test10_8.fj1); dce(auto.test10); simplify-cfg(auto.test10); dce(auto.test10); +fork-split(auto.test10); unforkify(auto.test10); +array-slf(auto.test11); +ccp(auto.test11); +dce(auto.test11); +simplify-cfg(auto.test11); +dce(auto.test11); +unforkify(auto.test11); + gcm(*); diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index 2eab56b9400a2695548547b56d52a87eb0751771..dfc81cd2f3e9b169424e3619118de7426e0782ac 100644 --- a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -149,7 +149,23 @@ fn test9<r, c : usize>(input : i32[r, c]) -> i32[r, c] { } #[entry] -fn test10(k1 : i32[8], k2 : i32[8], v : i32[8]) -> i32 { +fn test10(i0 : u64, i1 : u64, input : i32) -> i32 { + let 
arr0 : i32[4, 5]; + let arr1 : i32[4, 5]; + let arr2 : i32[4, 5]; + @loop2 for k = 0 to 5 { + @loop3 for j = 0 to 4 { + arr0[j, k] += input; + arr1[j, k] += 2; + arr2[j, k] += input * 2; + } + } + + return arr0[i0, i1] + arr1[i0, i1] + arr2[i0, i1]; +} + +#[entry] +fn test11(k1 : i32[8], k2 : i32[8], v : i32[8]) -> i32 { @const let s : i32[8]; for i = 0 to 8 { s[i] = v[k1[i] as u64]; @@ -159,4 +175,4 @@ fn test10(k1 : i32[8], k2 : i32[8], v : i32[8]) -> i32 { sum += s[k2[i] as u64]; } return sum; -} \ No newline at end of file +} diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index 43b28e34e74f51da8d0df12901be065bbc8a9a5a..aa8861d5eda51dc95783d6affbc757c398204e5c 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -9,12 +9,13 @@ no-memset(test8@const1); no-memset(test8@const2); no-memset(test9@const); no-memset(test10@const); +no-memset(test11@const); gvn(*); phi-elim(*); dce(*); -let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9, test10); +let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9, test10, test11); gpu(auto.test1); gpu(auto.test2); gpu(auto.test3); @@ -24,6 +25,7 @@ gpu(auto.test7); gpu(auto.test8); gpu(auto.test9); gpu(auto.test10); +gpu(auto.test11); ip-sroa(*); sroa(*); @@ -77,6 +79,13 @@ dce(auto.test8); no-memset(test9@const); +fork-split(auto.test10); +fork-fission-reduces[test10@loop3](auto.test10); + +dce(auto.test10); +simplify-cfg(auto.test10); +dce(auto.test10); + ip-sroa(*); sroa(*); dce(*); @@ -86,5 +95,5 @@ phi-elim(*); dce(*); gcm(*); -float-collections(test2, auto.test2, test4, auto.test4, test5, auto.test5); +float-collections(test2, auto.test2, test4, auto.test4, test5, auto.test5, test10, auto.test10); gcm(*); diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs index 
0b37a99d6f14f17238391ca25dfccbfc9a753441..29b62beaf24645d68b95271438685073e01d63a6 100644 --- a/juno_samples/fork_join_tests/src/main.rs +++ b/juno_samples/fork_join_tests/src/main.rs @@ -76,6 +76,11 @@ fn main() { assert(&correct, output); let mut r = runner!(test10); + let output = r.run(1, 2, 5).await; + let correct = 17i32; + assert_eq!(correct, output); + + let mut r = runner!(test11); let k1 = vec![0, 4, 3, 7, 3, 4, 2, 1]; let k2 = vec![6, 4, 3, 2, 4, 1, 0, 5]; let v = vec![3, -499, 4, 32, -2, 55, -74, 10]; diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 8b68ed7111de9009458dbdc26b511e66ee13db58..81fc82cd8f72d35ebe5a8f3276fe5690b2e0af88 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -144,6 +144,7 @@ impl FromStr for Appliable { "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)), "fork-extend" => Ok(Appliable::Pass(ir::Pass::ForkExtend)), "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)), + "fork-fission-reduces" => Ok(Appliable::Pass(ir::Pass::ForkFission)), "fork-fusion" | "fusion" => Ok(Appliable::Pass(ir::Pass::ForkFusion)), "fork-reshape" => Ok(Appliable::Pass(ir::Pass::ForkReshape)), "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index bacb4142c62df5b0bf974bcb1703400dd3caa02c..287cd21a1ff64838942f8f2117937fd52020e75c 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -18,6 +18,7 @@ pub enum Pass { ForkDimMerge, ForkExtend, ForkFissionBufferize, + ForkFission, ForkFusion, ForkGuardElim, ForkInterchange, @@ -59,6 +60,7 @@ impl Pass { Pass::ForkChunk => num == 4, Pass::ForkExtend => num == 1, Pass::ForkFissionBufferize => num == 2 || num == 1, + Pass::ForkFission => num == 1, Pass::ForkInterchange => num == 2, Pass::ForkReshape => true, Pass::InterproceduralSROA => num == 0 || num == 1, @@ -77,6 +79,7 @@ impl Pass { Pass::ForkChunk => "4", 
Pass::ForkExtend => "1", Pass::ForkFissionBufferize => "1 or 2", + Pass::ForkFission => "1", Pass::ForkInterchange => "2", Pass::ForkReshape => "any", Pass::InterproceduralSROA => "0 or 1", diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 767386bb765b1badfbab302ada9452bff855ee64..97bec5445eeeb9e9f849ef129b8ce9cf8c242798 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2780,7 +2780,127 @@ fn run_pass( } pm.delete_gravestones(); pm.clear_analyses(); - } + }, + Pass::ForkFission => { + assert!(args.len() == 1); + + let Some(Value::Label { + labels: fork_labels, + }) = args.get(0) + else { + return Err(SchedulerError::PassError { + pass: "forkFission".to_string(), + error: "expected label argument".to_string(), + }); + }; + + + pm.make_fork_join_maps(); + pm.make_loops(); + pm.make_reduce_cycles(); + pm.make_nodes_in_fork_joins(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let loops = pm.loops.take().unwrap(); + let reduce_cycles = pm.reduce_cycles.take().unwrap(); + let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); + + let mut created_fork_joins = vec![vec![]; pm.functions.len()]; + + // assert only one function is in the selection. 
+ let num_functions = build_selection(pm, selection.clone(), false) + .iter() + .filter(|func| func.is_some()) + .count(); + + assert!(num_functions <= 1); + assert_eq!(fork_labels.len(), 1); + + let fork_label = fork_labels[0].label; + + for ((((func, fork_join_map), loop_tree), reduce_cycles), nodes_in_fork_joins, + ) in build_selection(pm, selection, false) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(loops.iter()) + .zip(reduce_cycles.iter()) + .zip(nodes_in_fork_joins.iter()) + { + let Some(mut func) = func else { + continue; + }; + + let new_forks = fork_fission( + &mut func, + nodes_in_fork_joins, + reduce_cycles, + loop_tree, + fork_join_map, + fork_label, + ); + + let created_fork_joins = &mut created_fork_joins[func.func_id().idx()]; + + for f in new_forks { + created_fork_joins.push(f); + }; + + changed |= func.modified(); + } + + pm.clear_analyses(); + pm.make_nodes_in_fork_joins(); + let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); + let mut new_fork_joins = HashMap::new(); + + for (mut func, created_fork_joins) in + build_editors(pm).into_iter().zip(created_fork_joins) + { + // For every function, create a label for every fork-join + // created by the fission. + let name = func.func().name.clone(); + let func_id = func.func_id(); + let labels = create_labels_for_node_sets( + &mut func, + created_fork_joins.into_iter().map(|fork| { + nodes_in_fork_joins[func_id.idx()][&fork] + .iter() + .map(|id| *id) + }), + ); + + // Assemble those labels into a record for this function. The + // format of the records is <function>.fj<N>, where N is the + // index of the created fork-join being referred to. + let mut func_record = HashMap::new(); + for (idx, label) in labels { + let fmt = format!("fj{idx}"); + func_record.insert( + fmt.to_string(), + Value::Label { + labels: vec![LabelInfo { + func: func_id, + label: label, + }], + }, + ); + } + + // Try to avoid creating unnecessary record entries.
+ if !func_record.is_empty() { + new_fork_joins.entry(name).insert_entry(Value::Record { + fields: func_record, + }); + } + } + + pm.delete_gravestones(); + pm.clear_analyses(); + + result = Value::Record { + fields: new_fork_joins, + }; + + }, Pass::ForkFissionBufferize => { assert!(args.len() == 1 || args.len() == 2); let Some(Value::Label {