diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index e6db0345def31324243cdee2bdcb6b5cca5d9a7b..1df5338ddb3a4fa1a3ff65db27a6c9571dd21bb0 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -306,40 +306,33 @@ where
 pub fn fork_fission<'a>(
     editor: &'a mut FunctionEditor,
-    _control_subgraph: &Subgraph,
-    _types: &Vec<TypeID>,
-    _loop_tree: &LoopTree,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
+    loop_tree: &LoopTree,
     fork_join_map: &HashMap<NodeID, NodeID>,
-) -> () {
-    let forks: Vec<_> = editor
-        .func()
-        .nodes
-        .iter()
-        .enumerate()
-        .filter_map(|(idx, node)| {
-            if node.is_fork() {
-                Some(NodeID::new(idx))
-            } else {
-                None
-            }
-        })
+) -> Vec<NodeID> {
+    let mut forks: Vec<_> = loop_tree
+        .bottom_up_loops()
+        .into_iter()
+        .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork())
         .collect();
 
-    let control_pred = NodeID::new(0);
-
-    // This does the reduction fission:
+    let mut created_forks = Vec::new();
+
+    // This does the reduction fission
     for fork in forks.clone() {
-        // FIXME: If there is control in between fork and join, don't just give up.
-        let join = fork_join_map[&fork];
-        let join_pred = editor.func().nodes[join.idx()].try_join().unwrap();
-        if join_pred != fork {
-            todo!("Can't do fork fission on nodes with internal control")
-            // Inner control LOOPs are hard
-            // inner control in general *should* work right now without modifications.
+        let join = fork_join_map[&fork.0];
+        let reduce_partition = default_reduce_partition(editor, fork.0, join);
+
+        if editor.is_mutable(fork.0) {
+            created_forks = fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, nodes_in_fork_joins, fork.0);
+            if !created_forks.is_empty() {
+                break;
+            }
         }
-        let reduce_partition = default_reduce_partition(editor, fork, join);
-
-        fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, control_pred, fork);
     }
+
+    created_forks
 }
 
 /** Split a 1D fork into two forks, placing select intermediate data into buffers. */
@@ -488,48 +481,38 @@ where
     }
 }
 
-/** Split a 1D fork into a separate fork for each reduction. */
+/** Split a fork into a separate fork for each reduction. */
 pub fn fork_reduce_fission_helper<'a>(
     editor: &'a mut FunctionEditor,
     fork_join_map: &HashMap<NodeID, NodeID>,
     reduce_partition: SparseNodeMap<ForkID>, // Describes how the reduces of the fork should be split,
-    original_control_pred: NodeID, // What the new fork connects to.
-
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
     fork: NodeID,
-) -> (NodeID, NodeID) {
+) -> Vec<NodeID> {
     let join = fork_join_map[&fork];
 
-    let mut new_control_pred: NodeID = original_control_pred;
-    // Important edges are: Reduces,
+    let mut new_forks = Vec::new();
 
-    // NOTE:
-    // Say two reduce are in a fork, s.t reduce A depends on reduce B
-    // If user wants A and B in separate forks:
-    // - we can simply refuse
-    // - or we can duplicate B
+    let mut new_control_pred: NodeID = editor.get_uses(fork).filter(|n| editor.node(n).is_control()).next().unwrap();
 
     let mut new_fork = NodeID::new(0);
     let mut new_join = NodeID::new(0);
 
+    let subgraph = &nodes_in_fork_joins[&fork];
+
+    // Gets everything between fork & join that this reduce needs. (ALL CONTROL)
-    for reduce in reduce_partition {
-        let reduce = reduce.0;
-
-        let function = editor.func();
-        let subgraph = find_reduce_dependencies(function, reduce, fork);
-
-        let mut subgraph: HashSet<NodeID> = subgraph.into_iter().collect();
-
-        subgraph.insert(join);
-        subgraph.insert(fork);
-        subgraph.insert(reduce);
-
-        let (_, mapping, _) = copy_subgraph(editor, subgraph);
+    editor.edit(|mut edit| {
+        for reduce in reduce_partition {
+            let reduce = reduce.0;
 
-        new_fork = mapping[&fork];
-        new_join = mapping[&join];
+            let a = copy_subgraph_in_edit(edit, subgraph.clone())?;
+            edit = a.0;
+            let mapping = a.1;
 
-        editor.edit(|mut edit| {
+            new_fork = mapping[&fork];
+            new_forks.push(new_fork);
+            new_join = mapping[&join];
+
+            // Attach new_fork after control_pred
             let (old_control_pred, _) = edit.get_node(new_fork).try_fork().unwrap().clone();
             edit = edit.replace_all_uses_where(old_control_pred, new_control_pred, |usee| {
@@ -538,13 +521,9 @@ pub fn fork_reduce_fission_helper<'a>(
             // Replace uses of reduce
             edit = edit.replace_all_uses(reduce, mapping[&reduce])?;
 
-            Ok(edit)
-        });
-
-        new_control_pred = new_join;
-    }
+            new_control_pred = new_join;
+        };
 
-    editor.edit(|mut edit| {
         // Replace original join w/ new final join
         edit = edit.replace_all_uses_where(join, new_join, |_| true)?;
@@ -553,10 +532,12 @@ pub fn fork_reduce_fission_helper<'a>(
 
         // Replace all users of original fork, and then delete it, leftover users will be DCE'd.
         edit = edit.replace_all_uses(fork, new_fork)?;
-        edit.delete_node(fork)
+        edit = edit.delete_node(fork)?;
+
+        Ok(edit)
     });
 
-    (new_fork, new_join)
+    new_forks
 }
 
 pub fn fork_coalesce(
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index f46c91d6a84a08b2258332af1dc5d6a662d86639..185ea441e801a01a084185fc09cf652c8fa6fb1e 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -3,7 +3,7 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
-let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9);
+let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9, test10);
 cpu(auto.test1);
 cpu(auto.test2);
 cpu(auto.test3);
@@ -12,6 +12,7 @@ cpu(auto.test5);
 cpu(auto.test7);
 cpu(auto.test8);
 cpu(auto.test9);
+cpu(auto.test10);
 
 let test1_cpu = auto.test1;
 rename["test1_cpu"](test1_cpu);
@@ -96,4 +97,19 @@ dce(auto.test8);
 
 no-memset(test9@const);
 
+fork-split(auto.test10);
+xdot[true](auto.test10);
+fork-fission(auto.test10);
+dce(auto.test10);
+xdot[false](auto.test10);
+dce(auto.test10);
+simplify-cfg(auto.test10);
+dce(auto.test10);
+xdot[true](auto.test10);
+
+fork-split(auto.test10);
+
+unforkify(auto.test10);
+
+
 gcm(*);
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index 334fc2bfe4f745cec9004fdf9ebdf80d11818c0f..51284c29b159559a4c27dbe77521bf7bb725eaae 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -147,3 +147,19 @@ fn test9<r, c : usize>(input : i32[r, c]) -> i32[r, c] {
 
   return out;
 }
+
+#[entry]
+fn test10(i0 : u64, i1 : u64, input : i32) -> i32 {
+  let arr0 : i32[4, 5];
+  let arr1 : i32[4, 5];
+  let arr2 : i32[4, 5];
+  @loop2 for k = 0 to 5 {
+    @loop3 for j = 0 to 4 {
+      arr0[j, k] += input;
+      arr1[j, k] += 2;
+      arr2[j, k] += input * 2;
+    }
+  }
+
+  return arr0[i0, i1] + arr1[i0, i1] + arr2[i0, i1];
+}
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
index e66309b22b0650feaab315829dd412f6275e9a99..277484f6bc094dbee05155fb7fee63fb3dcf3319 100644
--- a/juno_samples/fork_join_tests/src/main.rs
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -74,6 +74,11 @@ fn main() {
             5 + 6 + 8 + 9,
         ];
         assert(&correct, output);
+
+        let mut r = runner!(test10);
+        let output = r.run(1, 2, 5).await;
+        let correct = 17i32;
+        assert_eq!(correct, output);
     });
 }
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 9d020c64ccef3b9c0a79694876c5b0ace606f938..39fb3469021939d98921444bff9f42c02a4d4fc2 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -126,7 +126,7 @@ impl FromStr for Appliable {
             "ip-sroa" | "interprocedural-sroa" => {
                 Ok(Appliable::Pass(ir::Pass::InterproceduralSROA))
             }
-            "fork-fission-bufferize" | "fork-fission" => {
+            "fork-fission-bufferize" => {
                 Ok(Appliable::Pass(ir::Pass::ForkFissionBufferize))
            }
             "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)),
@@ -134,6 +134,7 @@ impl FromStr for Appliable {
             "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)),
             "fork-extend" => Ok(Appliable::Pass(ir::Pass::ForkExtend)),
             "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)),
+            "fork-fission" | "fission" => Ok(Appliable::Pass(ir::Pass::ForkFission)),
             "fork-fusion" | "fusion" => Ok(Appliable::Pass(ir::Pass::ForkFusion)),
             "fork-reshape" => Ok(Appliable::Pass(ir::Pass::ForkReshape)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index ab1495b816c99452560d03c0addf77a5aec18974..d381e861c323f5d1326e26cccb2cf916a0bf735f 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -18,6 +18,7 @@ pub enum Pass {
     ForkDimMerge,
     ForkExtend,
     ForkFissionBufferize,
+    ForkFission,
     ForkFusion,
     ForkGuardElim,
     ForkInterchange,
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 456df2eda49b93a6c80327a090b6f6606ae711bb..9281925a96dda6fbad44c746bf824206f13c34d7 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2740,7 +2740,104 @@ fn run_pass(
             }
             pm.delete_gravestones();
             pm.clear_analyses();
-        }
+        },
+        Pass::ForkFission => {
+            assert!(args.len() == 0);
+
+            pm.make_fork_join_maps();
+            pm.make_loops();
+            pm.make_reduce_cycles();
+            pm.make_nodes_in_fork_joins();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            let loops = pm.loops.take().unwrap();
+            let reduce_cycles = pm.reduce_cycles.take().unwrap();
+            let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+
+            let mut created_fork_joins = vec![vec![]; pm.functions.len()];
+
+            for ((((func, fork_join_map), loop_tree), reduce_cycles), nodes_in_fork_joins) in
+                build_selection(pm, selection, false)
+                    .into_iter()
+                    .zip(fork_join_maps.iter())
+                    .zip(loops.iter())
+                    .zip(reduce_cycles.iter())
+                    .zip(nodes_in_fork_joins.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+
+                let new_forks = fork_fission(
+                    &mut func,
+                    nodes_in_fork_joins,
+                    reduce_cycles,
+                    loop_tree,
+                    fork_join_map,
+                );
+
+                let created_fork_joins = &mut created_fork_joins[func.func_id().idx()];
+
+                for f in new_forks {
+                    created_fork_joins.push(f);
+                }
+
+                changed |= func.modified();
+            }
+
+            pm.clear_analyses();
+            pm.make_nodes_in_fork_joins();
+            let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+            let mut new_fork_joins = HashMap::new();
+
+            for (mut func, created_fork_joins) in
+                build_editors(pm).into_iter().zip(created_fork_joins)
+            {
+                // For every function, create a label for every level of fork-
+                // joins resulting from the split.
+                let name = func.func().name.clone();
+                let func_id = func.func_id();
+                let labels = create_labels_for_node_sets(
+                    &mut func,
+                    created_fork_joins.into_iter().map(|fork| {
+                        nodes_in_fork_joins[func_id.idx()][&fork]
+                            .iter()
+                            .map(|id| *id)
+                    }),
+                );
+
+                // Assemble those labels into a record for this function. The
+                // fields of the record are fj_top and fj_bottom, alternating
+                // with the level of the split fork-joins being referred to.
+                let mut func_record = HashMap::new();
+                for (idx, label) in labels {
+                    let fmt = if idx % 2 == 0 { "fj_top" } else { "fj_bottom" };
+                    func_record.insert(
+                        fmt.to_string(),
+                        Value::Label {
+                            labels: vec![LabelInfo {
+                                func: func_id,
+                                label: label,
+                            }],
+                        },
+                    );
+                }
+
+                // Try to avoid creating unnecessary record entries.
+                if !func_record.is_empty() {
+                    new_fork_joins.entry(name).insert_entry(Value::Record {
+                        fields: func_record,
+                    });
+                }
+            }
+
+            pm.delete_gravestones();
+            pm.clear_analyses();
+
+            result = Value::Record {
+                fields: new_fork_joins,
+            };
+        },
         Pass::ForkFissionBufferize => {
             assert!(args.len() == 1 || args.len() == 2);
             let Some(Value::Label {
@@ -2777,6 +2874,14 @@ fn run_pass(
 
             let fork_label = fork_labels[0].label;
 
+            // Assert that only one function is in the selection.
+            let num_functions = build_selection(pm, selection.clone(), false)
+                .iter()
+                .filter(|func| func.is_some())
+                .count();
+
+            assert!(num_functions <= 1);
+
             for (
                 ((((func, fork_join_map), loop_tree), typing), reduce_cycles),
                 nodes_in_fork_joins,