diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index fd6747d7a15628034bc98f7ef3ebc00631f8abae..2f7a91faa990924358dc11778e7431a81577e702 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1164,3 +1164,111 @@ fn fork_interchange( edit.delete_node(fork) }); } + +/* + * Run fork unrolling on all fork-joins that are mutable in an editor. + */ +pub fn fork_unroll_all_forks( + editor: &mut FunctionEditor, + fork_joins: &HashMap<NodeID, NodeID>, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, +) { + for (fork, join) in fork_joins { + if editor.is_mutable(*fork) && fork_unroll(editor, *fork, *join, nodes_in_fork_joins) { + break; + } + } +} + +pub fn fork_unroll( + editor: &mut FunctionEditor, + fork: NodeID, + join: NodeID, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, +) -> bool { + // We can only unroll fork-joins with a compile time known factor list. For + // simplicity, just unroll fork-joins that have a single dimension. + let nodes = &editor.func().nodes; + let Node::Fork { + control, + ref factors, + } = nodes[fork.idx()] + else { + panic!() + }; + if factors.len() != 1 || editor.get_users(fork).count() != 2 { + return false; + } + let DynamicConstant::Constant(cons) = *editor.get_dynamic_constant(factors[0]) else { + return false; + }; + let tid = editor + .get_users(fork) + .filter(|id| nodes[id.idx()].is_thread_id()) + .next() + .unwrap(); + let inits: HashMap<NodeID, NodeID> = editor + .get_users(join) + .filter_map(|id| nodes[id.idx()].try_reduce().map(|(_, init, _)| (id, init))) + .collect(); + + editor.edit(|mut edit| { + // Create a copy of the nodes in the fork join per unrolled iteration, + // excluding the fork itself, the join itself, the thread IDs of the fork, + // and the reduces on the join. Keep a running tally of the top control + // node and the current reduction value. + let mut top_control = control; + let mut current_reduces = inits; + for iter in 0..cons { + let iter_cons = edit.add_constant(Constant::UnsignedInteger64(iter as u64)); + let iter_tid = edit.add_node(Node::Constant { id: iter_cons }); + + // First, add a copy of each node in the fork join unmodified. + // Record the mapping from old ID to new ID. + let mut added_ids = HashSet::new(); + let mut old_to_new_ids = HashMap::new(); + let mut new_control = None; + let mut new_reduces = HashMap::new(); + for node in nodes_in_fork_joins[&fork].iter() { + if *node == fork { + old_to_new_ids.insert(*node, top_control); + } else if *node == join { + new_control = Some(edit.get_node(*node).try_join().unwrap()); + } else if *node == tid { + old_to_new_ids.insert(*node, iter_tid); + } else if let Some(current) = current_reduces.get(node) { + old_to_new_ids.insert(*node, *current); + new_reduces.insert(*node, edit.get_node(*node).try_reduce().unwrap().2); + } else { + let new_node = edit.add_node(edit.get_node(*node).clone()); + old_to_new_ids.insert(*node, new_node); + added_ids.insert(new_node); + } + } + + // Second, replace all the uses in the just added nodes. + if let Some(new_control) = new_control { + top_control = old_to_new_ids[&new_control]; + } + for (reduce, reduct) in new_reduces { + current_reduces.insert(reduce, old_to_new_ids[&reduct]); + } + for (old, new) in old_to_new_ids { + edit = edit.replace_all_uses_where(old, new, |id| added_ids.contains(id))?; + } + } + + // Hook up the control and reduce outputs to the rest of the function. + edit = edit.replace_all_uses(join, top_control)?; + for (reduce, reduct) in current_reduces { + edit = edit.replace_all_uses(reduce, reduct)?; + } + + // Delete the old fork-join. + for node in nodes_in_fork_joins[&fork].iter() { + edit = edit.delete_node(*node)?; + } + + Ok(edit) + }) +} diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index fe0a8802e8b68f6e21cc8fe3586a03f0ce658fa5..9e87d26a899772c48a5cc534b6da15d153add675 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -39,12 +39,16 @@ dce(*); fixpoint panic after 20 { infer-schedules(*); } +fork-split(auto.test1); +fixpoint panic after 20 { + unroll(auto.test1); +} -fork-split(auto.test1, auto.test2, auto.test3, auto.test4, auto.test5); +fork-split(auto.test2, auto.test3, auto.test4, auto.test5); gvn(*); phi-elim(*); dce(*); -unforkify(auto.test1, auto.test2, auto.test3, auto.test4, auto.test5); +unforkify(auto.test2, auto.test3, auto.test4, auto.test5); ccp(*); gvn(*); phi-elim(*); diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 1aaa10cddc350b703c22f734dbdf1f75fb3ef46c..6b40001c2b3324176913b1e843934d422ed2e711 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -113,6 +113,7 @@ impl FromStr for Appliable { "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)), "fork-interchange" => Ok(Appliable::Pass(ir::Pass::ForkInterchange)), "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)), + "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)), "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), "outline" => Ok(Appliable::Pass(ir::Pass::Outline)), "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 0ecac39a19a5e364c0ab1185b162c0470f5b6a5b..840f25a6e9dc986ab064adecbeba822ca47016d8 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -16,6 +16,7 @@ pub enum Pass { ForkGuardElim, ForkInterchange, ForkSplit, + ForkUnroll, Forkify, GCM, GVN, diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 9c7391acd5686d95461ed562f61b3dbc9774651a..f59834eddad670cceeb4954812fff5926f39f862 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1665,6 +1665,28 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } + Pass::ForkUnroll => { + assert_eq!(args.len(), 0); + + pm.make_fork_join_maps(); + pm.make_nodes_in_fork_joins(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); + for ((func, fork_join_map), nodes_in_fork_joins) in + build_selection(pm, selection, false) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(nodes_in_fork_joins.iter()) + { + let Some(mut func) = func else { + continue; + }; + fork_unroll_all_forks(&mut func, fork_join_map, nodes_in_fork_joins); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::Forkify => { assert!(args.is_empty()); loop {