From a1baea161e0114374de969618cec803a56e42ea0 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Wed, 12 Feb 2025 17:08:36 -0600 Subject: [PATCH 1/4] Loop unroll skeleton --- hercules_opt/src/lib.rs | 2 ++ hercules_opt/src/unroll.rs | 18 ++++++++++++++++++ juno_scheduler/src/compile.rs | 1 + juno_scheduler/src/ir.rs | 1 + juno_scheduler/src/pm.rs | 18 ++++++++++++++++++ 5 files changed, 40 insertions(+) create mode 100644 hercules_opt/src/unroll.rs diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs index 7187508a..a810dfbf 100644 --- a/hercules_opt/src/lib.rs +++ b/hercules_opt/src/lib.rs @@ -23,6 +23,7 @@ pub mod simplify_cfg; pub mod slf; pub mod sroa; pub mod unforkify; +pub mod unroll; pub mod utils; pub use crate::ccp::*; @@ -48,4 +49,5 @@ pub use crate::simplify_cfg::*; pub use crate::slf::*; pub use crate::sroa::*; pub use crate::unforkify::*; +pub use crate::unroll::*; pub use crate::utils::*; diff --git a/hercules_opt/src/unroll.rs b/hercules_opt/src/unroll.rs new file mode 100644 index 00000000..f3c795ca --- /dev/null +++ b/hercules_opt/src/unroll.rs @@ -0,0 +1,18 @@ +use bitvec::prelude::*; + +use hercules_ir::*; + +use crate::*; + +/* + * Run loop unrolling on all loops that are mutable in an editor. + */ +pub fn loop_unroll_all_loops(editor: &mut FunctionEditor, loops: &LoopTree) { + for (header, contents) in loops.bottom_up_loops() { + if editor.is_mutable(header) { + loop_unroll(editor, header, contents); + } + } +} + +pub fn loop_unroll(editor: &mut FunctionEditor, header: NodeID, contents: &BitVec<u8, Lsb0>) {} diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 1aaa10cd..9d5a86cc 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -114,6 +114,7 @@ impl FromStr for Appliable { "fork-interchange" => Ok(Appliable::Pass(ir::Pass::ForkInterchange)), "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)), "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), + "loop-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::LoopUnroll)), "outline" => Ok(Appliable::Pass(ir::Pass::Outline)), "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)), "predication" => Ok(Appliable::Pass(ir::Pass::Predication)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 0ecac39a..1bb6cf13 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -23,6 +23,7 @@ pub enum Pass { Inline, InterproceduralSROA, LiftDCMath, + LoopUnroll, Outline, PhiElim, Predication, diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 9c7391ac..5c6aec5e 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1665,6 +1665,24 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } + Pass::LoopUnroll => { + assert_eq!(args.len(), 0); + + pm.make_loops(); + let loops = pm.loops.take().unwrap(); + for (func, loops) in build_selection(pm, selection, false) + .into_iter() + .zip(loops.iter()) + { + let Some(mut func) = func else { + continue; + }; + loop_unroll_all_loops(&mut func, loops); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::Forkify => { assert!(args.is_empty()); loop { -- GitLab From a8a2fc3b61e010278a1ca622cd6aaab6f77ff140 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Wed, 12 Feb 2025 17:13:04 -0600 Subject: [PATCH 2/4] Just do fork unrolling --- hercules_opt/src/fork_transforms.rs | 15 +++++++++++++++ hercules_opt/src/lib.rs | 2 -- hercules_opt/src/unroll.rs | 18 ------------------ juno_scheduler/src/compile.rs | 2 +- juno_scheduler/src/ir.rs | 2 +- juno_scheduler/src/pm.rs | 12 ++++++------ 6 files changed, 23 insertions(+), 28 deletions(-) delete mode 100644 hercules_opt/src/unroll.rs diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index fd6747d7..539b7fd1 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1164,3 +1164,18 @@ fn fork_interchange( edit.delete_node(fork) }); } + +/* + * Run fork unrolling on all fork-joins that are mutable in an editor. + */ +pub fn fork_unroll_all_forks(editor: &mut FunctionEditor, fork_joins: &HashMap<NodeID, NodeID>) { + for (fork, join) in fork_joins { + if editor.is_mutable(*fork) && fork_unroll(editor, *fork, *join) { + break; + } + } +} + +pub fn fork_unroll(editor: &mut FunctionEditor, fork: NodeID, join: NodeID) -> bool { + false +} diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs index a810dfbf..7187508a 100644 --- a/hercules_opt/src/lib.rs +++ b/hercules_opt/src/lib.rs @@ -23,7 +23,6 @@ pub mod simplify_cfg; pub mod slf; pub mod sroa; pub mod unforkify; -pub mod unroll; pub mod utils; pub use crate::ccp::*; @@ -49,5 +48,4 @@ pub use crate::simplify_cfg::*; pub use crate::slf::*; pub use crate::sroa::*; pub use crate::unforkify::*; -pub use crate::unroll::*; pub use crate::utils::*; diff --git a/hercules_opt/src/unroll.rs b/hercules_opt/src/unroll.rs deleted file mode 100644 index f3c795ca..00000000 --- a/hercules_opt/src/unroll.rs +++ /dev/null @@ -1,18 +0,0 @@ -use bitvec::prelude::*; - -use hercules_ir::*; - -use crate::*; - -/* - * Run loop unrolling on all loops that are mutable in an editor. - */ -pub fn loop_unroll_all_loops(editor: &mut FunctionEditor, loops: &LoopTree) { - for (header, contents) in loops.bottom_up_loops() { - if editor.is_mutable(header) { - loop_unroll(editor, header, contents); - } - } -} - -pub fn loop_unroll(editor: &mut FunctionEditor, header: NodeID, contents: &BitVec<u8, Lsb0>) {} diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 9d5a86cc..6b40001c 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -113,8 +113,8 @@ impl FromStr for Appliable { "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)), "fork-interchange" => Ok(Appliable::Pass(ir::Pass::ForkInterchange)), "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)), + "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)), "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), - "loop-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::LoopUnroll)), "outline" => Ok(Appliable::Pass(ir::Pass::Outline)), "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)), "predication" => Ok(Appliable::Pass(ir::Pass::Predication)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 1bb6cf13..840f25a6 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -16,6 +16,7 @@ pub enum Pass { ForkGuardElim, ForkInterchange, ForkSplit, + ForkUnroll, Forkify, GCM, GVN, @@ -23,7 +24,6 @@ pub enum Pass { Inline, InterproceduralSROA, LiftDCMath, - LoopUnroll, Outline, PhiElim, Predication, diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 5c6aec5e..951ba51d 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1665,19 +1665,19 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } - Pass::LoopUnroll => { + Pass::ForkUnroll => { assert_eq!(args.len(), 0); - pm.make_loops(); - let loops = pm.loops.take().unwrap(); - for (func, loops) in build_selection(pm, selection, false) + pm.make_fork_join_maps(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + for (func, fork_join_map) in build_selection(pm, selection, false) .into_iter() - .zip(loops.iter()) + .zip(fork_join_maps.iter()) { let Some(mut func) = func else { continue; }; - loop_unroll_all_loops(&mut func, loops); + fork_unroll_all_forks(&mut func, fork_join_map); changed |= func.modified(); } pm.delete_gravestones(); -- GitLab From 9bc5101eeac8a2ac2393cf0aedd7ff5aa9bcc74f Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Wed, 12 Feb 2025 18:27:36 -0600 Subject: [PATCH 3/4] Get unrollable fork-joins --- hercules_opt/src/fork_transforms.rs | 38 +++++++++++++++++++++--- juno_samples/fork_join_tests/src/cpu.sch | 6 ++-- juno_scheduler/src/pm.rs | 12 +++++--- 3 files changed, 46 insertions(+), 10 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 539b7fd1..94898b0d 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1168,14 +1168,44 @@ fn fork_interchange( /* * Run fork unrolling on all fork-joins that are mutable in an editor. */ -pub fn fork_unroll_all_forks(editor: &mut FunctionEditor, fork_joins: &HashMap<NodeID, NodeID>) { +pub fn fork_unroll_all_forks( + editor: &mut FunctionEditor, + fork_joins: &HashMap<NodeID, NodeID>, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, +) { for (fork, join) in fork_joins { - if editor.is_mutable(*fork) && fork_unroll(editor, *fork, *join) { + if editor.is_mutable(*fork) && fork_unroll(editor, *fork, *join, nodes_in_fork_joins) { break; } } } -pub fn fork_unroll(editor: &mut FunctionEditor, fork: NodeID, join: NodeID) -> bool { - false +pub fn fork_unroll( + editor: &mut FunctionEditor, + fork: NodeID, + join: NodeID, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, +) -> bool { + // We can only unroll forks with a compile time known factor list. + let nodes = &editor.func().nodes; + let Node::Fork { + control, + ref factors, + } = nodes[fork.idx()] + else { + panic!() + }; + let mut cons_factors = vec![]; + for factor in factors { + let DynamicConstant::Constant(cons) = *editor.get_dynamic_constant(*factor) else { + return false; + }; + cons_factors.push(cons); + } + println!("{}: {:?}", editor.func().name, cons_factors); + + editor.edit(|mut edit| { + (); + Ok(edit) + }) } diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index fe0a8802..2c832d66 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -39,12 +39,14 @@ dce(*); fixpoint panic after 20 { infer-schedules(*); } +unroll(auto.test1); +xdot[true](*); -fork-split(auto.test1, auto.test2, auto.test3, auto.test4, auto.test5); +fork-split(auto.test2, auto.test3, auto.test4, auto.test5); gvn(*); phi-elim(*); dce(*); -unforkify(auto.test1, auto.test2, auto.test3, auto.test4, auto.test5); +unforkify(auto.test2, auto.test3, auto.test4, auto.test5); ccp(*); gvn(*); phi-elim(*); diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 951ba51d..f59834ed 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1669,15 +1669,19 @@ fn run_pass( assert_eq!(args.len(), 0); pm.make_fork_join_maps(); + pm.make_nodes_in_fork_joins(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); - for (func, fork_join_map) in build_selection(pm, selection, false) - .into_iter() - .zip(fork_join_maps.iter()) + let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); + for ((func, fork_join_map), nodes_in_fork_joins) in + build_selection(pm, selection, false) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(nodes_in_fork_joins.iter()) { let Some(mut func) = func else { continue; }; - fork_unroll_all_forks(&mut func, fork_join_map); + fork_unroll_all_forks(&mut func, fork_join_map, nodes_in_fork_joins); changed |= func.modified(); } pm.delete_gravestones(); -- GitLab From 23103bc2d1e31fae8880cd9063d4948a1be89a92 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Wed, 12 Feb 2025 22:02:51 -0600 Subject: [PATCH 4/4] holy shit that just worked --- hercules_opt/src/fork_transforms.rs | 81 +++++++++++++++++++++--- juno_samples/fork_join_tests/src/cpu.sch | 6 +- 2 files changed, 76 insertions(+), 11 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 94898b0d..2f7a91fa 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1186,7 +1186,8 @@ pub fn fork_unroll( join: NodeID, nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, ) -> bool { - // We can only unroll forks with a compile time known factor list. + // We can only unroll fork-joins with a compile time known factor list. For + // simplicity, just unroll fork-joins that have a single dimension. let nodes = &editor.func().nodes; let Node::Fork { control, @@ -1195,17 +1196,79 @@ pub fn fork_unroll( else { panic!() }; - let mut cons_factors = vec![]; - for factor in factors { - let DynamicConstant::Constant(cons) = *editor.get_dynamic_constant(*factor) else { - return false; - }; - cons_factors.push(cons); + if factors.len() != 1 || editor.get_users(fork).count() != 2 { + return false; } - println!("{}: {:?}", editor.func().name, cons_factors); + let DynamicConstant::Constant(cons) = *editor.get_dynamic_constant(factors[0]) else { + return false; + }; + let tid = editor + .get_users(fork) + .filter(|id| nodes[id.idx()].is_thread_id()) + .next() + .unwrap(); + let inits: HashMap<NodeID, NodeID> = editor + .get_users(join) + .filter_map(|id| nodes[id.idx()].try_reduce().map(|(_, init, _)| (id, init))) + .collect(); editor.edit(|mut edit| { - (); + // Create a copy of the nodes in the fork join per unrolled iteration, + // excluding the fork itself, the join itself, the thread IDs of the fork, + // and the reduces on the join. Keep a running tally of the top control + // node and the current reduction value. + let mut top_control = control; + let mut current_reduces = inits; + for iter in 0..cons { + let iter_cons = edit.add_constant(Constant::UnsignedInteger64(iter as u64)); + let iter_tid = edit.add_node(Node::Constant { id: iter_cons }); + + // First, add a copy of each node in the fork join unmodified. + // Record the mapping from old ID to new ID. + let mut added_ids = HashSet::new(); + let mut old_to_new_ids = HashMap::new(); + let mut new_control = None; + let mut new_reduces = HashMap::new(); + for node in nodes_in_fork_joins[&fork].iter() { + if *node == fork { + old_to_new_ids.insert(*node, top_control); + } else if *node == join { + new_control = Some(edit.get_node(*node).try_join().unwrap()); + } else if *node == tid { + old_to_new_ids.insert(*node, iter_tid); + } else if let Some(current) = current_reduces.get(node) { + old_to_new_ids.insert(*node, *current); + new_reduces.insert(*node, edit.get_node(*node).try_reduce().unwrap().2); + } else { + let new_node = edit.add_node(edit.get_node(*node).clone()); + old_to_new_ids.insert(*node, new_node); + added_ids.insert(new_node); + } + } + + // Second, replace all the uses in the just added nodes. + if let Some(new_control) = new_control { + top_control = old_to_new_ids[&new_control]; + } + for (reduce, reduct) in new_reduces { + current_reduces.insert(reduce, old_to_new_ids[&reduct]); + } + for (old, new) in old_to_new_ids { + edit = edit.replace_all_uses_where(old, new, |id| added_ids.contains(id))?; + } + } + + // Hook up the control and reduce outputs to the rest of the function. + edit = edit.replace_all_uses(join, top_control)?; + for (reduce, reduct) in current_reduces { + edit = edit.replace_all_uses(reduce, reduct)?; + } + + // Delete the old fork-join. + for node in nodes_in_fork_joins[&fork].iter() { + edit = edit.delete_node(*node)?; + } + Ok(edit) }) } diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index 2c832d66..9e87d26a 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -39,8 +39,10 @@ dce(*); fixpoint panic after 20 { infer-schedules(*); } -unroll(auto.test1); -xdot[true](*); +fork-split(auto.test1); +fixpoint panic after 20 { + unroll(auto.test1); +} fork-split(auto.test2, auto.test3, auto.test4, auto.test5); gvn(*); -- GitLab