diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 7f6dd1bcb7ae6c90148f109efefda9a8b1962c83..283734a009ab3f619910b02465bf9d4d05856bef 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -1520,6 +1520,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { editor.edit(|mut edit| { let zero = edit.add_zero_constant(typing[init.idx()]); let zero = edit.add_node(Node::Constant { id: zero }); + edit.sub_edit(id, zero); edit = edit.replace_all_uses_where(init, zero, |u| *u == id)?; let final_op = edit.add_node(Node::Binary { op, @@ -1542,6 +1543,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { editor.edit(|mut edit| { let one = edit.add_one_constant(typing[init.idx()]); let one = edit.add_node(Node::Constant { id: one }); + edit.sub_edit(id, one); edit = edit.replace_all_uses_where(init, one, |u| *u == id)?; let final_op = edit.add_node(Node::Binary { op, diff --git a/hercules_opt/src/unforkify.rs b/hercules_opt/src/unforkify.rs index 7451b0adf7c4a1ac197760c4ff09f06984520fb7..b44ed8df82b494b7da0aff006587246c501f8e5d 100644 --- a/hercules_opt/src/unforkify.rs +++ b/hercules_opt/src/unforkify.rs @@ -117,7 +117,31 @@ pub fn unforkify_all( } } -pub fn unforkify(editor: &mut FunctionEditor, fork: NodeID, join: NodeID, loop_tree: &LoopTree) { +pub fn unforkify_one( + editor: &mut FunctionEditor, + fork_join_map: &HashMap<NodeID, NodeID>, + loop_tree: &LoopTree, +) { + for l in loop_tree.bottom_up_loops().into_iter().rev() { + if !editor.node(l.0).is_fork() { + continue; + } + + let fork = l.0; + let join = fork_join_map[&fork]; + + if unforkify(editor, fork, join, loop_tree) { + break; + } + } +} + +pub fn unforkify( + editor: &mut FunctionEditor, + fork: NodeID, + join: NodeID, + loop_tree: &LoopTree, +) -> bool { let mut zero_cons_id = ConstantID::new(0); let mut one_cons_id = ConstantID::new(0); assert!(editor.edit(|mut edit| { @@ -138,7 +162,7 @@ pub fn unforkify(editor: &mut FunctionEditor, fork: NodeID, join: NodeID, loop_t if factors.len() > 1 { // For now, don't convert multi-dimensional fork-joins. Rely on pass // that splits fork-joins. - return; + return false; } let join_control = nodes[join.idx()].try_join().unwrap(); let tids: Vec<_> = editor @@ -296,5 +320,5 @@ pub fn unforkify(editor: &mut FunctionEditor, fork: NodeID, join: NodeID, loop_t } Ok(edit) - }); + }) } diff --git a/juno_samples/dot/src/cpu.sch b/juno_samples/dot/src/cpu.sch index 734054abf35115263103457e3a24f1a2d1598eda..aa87972e36d43a58e9a831248fbdcd92acb9b707 100644 --- a/juno_samples/dot/src/cpu.sch +++ b/juno_samples/dot/src/cpu.sch @@ -17,17 +17,22 @@ clean-monoid-reduces(*); let out = outline(split_out.dot.fj1); ip-sroa(*); -sroa(dot); -gvn(dot); -dce(dot); +sroa(*); +gvn(*); +dce(*); let fission_out = fork-fission[out@loop](dot); simplify-cfg(dot); dce(dot); unforkify(fission_out.dot.fj_loop_bottom); ccp(dot); +simplify-cfg(dot); gvn(dot); dce(dot); -unforkify(out); +unforkify-one(out); +ccp(out); +simplify-cfg(out); +gvn(out); +dce(out); gcm(*); diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 88816562924c5bb90b991224a4b943f80c3721af..fc2a729ec40db410f2beb7c148b50534b4d25312 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -144,6 +144,7 @@ impl FromStr for Appliable { "slf" | "store-load-forward" => Ok(Appliable::Pass(ir::Pass::SLF)), "sroa" => Ok(Appliable::Pass(ir::Pass::SROA)), "unforkify" => Ok(Appliable::Pass(ir::Pass::Unforkify)), + "unforkify-one" => Ok(Appliable::Pass(ir::Pass::UnforkifyOne)), "fork-coalesce" => Ok(Appliable::Pass(ir::Pass::ForkCoalesce)), "verify" => Ok(Appliable::Pass(ir::Pass::Verify)), "xdot" => Ok(Appliable::Pass(ir::Pass::Xdot)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 11cf6b1359abd34b9689d98b7bfc61c81e51baae..bf3fe03739c9159202d60e8a5908c8bb0da4cb28 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -38,6 +38,7 @@ pub enum Pass { Serialize, SimplifyCFG, Unforkify, + UnforkifyOne, Verify, WritePredication, Xdot, diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 34c2474bfb35cc02f0870b78c684f4a762d0dee6..8db79b46199c4d9ab54e139590c1cc34a539f96e 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2364,6 +2364,28 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } + Pass::UnforkifyOne => { + assert!(args.is_empty()); + pm.make_fork_join_maps(); + pm.make_loops(); + + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let loops = pm.loops.take().unwrap(); + + for ((func, fork_join_map), loop_tree) in build_selection(pm, selection, false) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(loops.iter()) + { + let Some(mut func) = func else { + continue; + }; + unforkify_one(&mut func, fork_join_map, loop_tree); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::ForkChunk => { assert_eq!(args.len(), 4); let Some(Value::Integer { val: tile_size }) = args.get(0) else {