From 9bc5101eeac8a2ac2393cf0aedd7ff5aa9bcc74f Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Wed, 12 Feb 2025 18:27:36 -0600 Subject: [PATCH] Get unrollable fork-joins --- hercules_opt/src/fork_transforms.rs | 38 +++++++++++++++++++++--- juno_samples/fork_join_tests/src/cpu.sch | 6 ++-- juno_scheduler/src/pm.rs | 12 +++++--- 3 files changed, 46 insertions(+), 10 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 539b7fd1..94898b0d 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1168,14 +1168,44 @@ fn fork_interchange( /* * Run fork unrolling on all fork-joins that are mutable in an editor. */ -pub fn fork_unroll_all_forks(editor: &mut FunctionEditor, fork_joins: &HashMap<NodeID, NodeID>) { +pub fn fork_unroll_all_forks( + editor: &mut FunctionEditor, + fork_joins: &HashMap<NodeID, NodeID>, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, +) { for (fork, join) in fork_joins { - if editor.is_mutable(*fork) && fork_unroll(editor, *fork, *join) { + if editor.is_mutable(*fork) && fork_unroll(editor, *fork, *join, nodes_in_fork_joins) { break; } } } -pub fn fork_unroll(editor: &mut FunctionEditor, fork: NodeID, join: NodeID) -> bool { - false +pub fn fork_unroll( + editor: &mut FunctionEditor, + fork: NodeID, + join: NodeID, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, +) -> bool { + // We can only unroll forks with a compile time known factor list. + let nodes = &editor.func().nodes; + let Node::Fork { + control, + ref factors, + } = nodes[fork.idx()] + else { + panic!() + }; + let mut cons_factors = vec![]; + for factor in factors { + let DynamicConstant::Constant(cons) = *editor.get_dynamic_constant(*factor) else { + return false; + }; + cons_factors.push(cons); + } + println!("{}: {:?}", editor.func().name, cons_factors); + + editor.edit(|mut edit| { + (); + Ok(edit) + }) } diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index fe0a8802..2c832d66 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -39,12 +39,14 @@ dce(*); fixpoint panic after 20 { infer-schedules(*); } +unroll(auto.test1); +xdot[true](*); -fork-split(auto.test1, auto.test2, auto.test3, auto.test4, auto.test5); +fork-split(auto.test2, auto.test3, auto.test4, auto.test5); gvn(*); phi-elim(*); dce(*); -unforkify(auto.test1, auto.test2, auto.test3, auto.test4, auto.test5); +unforkify(auto.test2, auto.test3, auto.test4, auto.test5); ccp(*); gvn(*); phi-elim(*); diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 951ba51d..f59834ed 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1669,15 +1669,19 @@ fn run_pass( assert_eq!(args.len(), 0); pm.make_fork_join_maps(); + pm.make_nodes_in_fork_joins(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); - for (func, fork_join_map) in build_selection(pm, selection, false) - .into_iter() - .zip(fork_join_maps.iter()) + let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); + for ((func, fork_join_map), nodes_in_fork_joins) in + build_selection(pm, selection, false) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(nodes_in_fork_joins.iter()) { let Some(mut func) = func else { continue; }; - fork_unroll_all_forks(&mut func, fork_join_map); + fork_unroll_all_forks(&mut func, fork_join_map, nodes_in_fork_joins); changed |= func.modified(); } pm.delete_gravestones(); -- GitLab