diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index e832e559e1bc5a5c796f924fed6ea5407c4d26a1..7f6dd1bcb7ae6c90148f109efefda9a8b1962c83 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -250,11 +250,12 @@ pub fn ff_bufferize_any_fork<'a, 'b>( where 'a: 'b, { - let forks: Vec<_> = loop_tree + let mut forks: Vec<_> = loop_tree .bottom_up_loops() .into_iter() .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork()) .collect(); + forks.reverse(); for l in forks { let fork_info = Loop { @@ -1506,6 +1507,7 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { let Some((_, init, reduct)) = nodes[id.idx()].try_reduce() else { continue; }; + let out_uses: Vec<_> = editor.get_users(id).filter(|id| *id != reduct).collect(); match nodes[reduct.idx()] { Node::Binary { @@ -1519,12 +1521,15 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { let zero = edit.add_zero_constant(typing[init.idx()]); let zero = edit.add_node(Node::Constant { id: zero }); edit = edit.replace_all_uses_where(init, zero, |u| *u == id)?; - let final_add = edit.add_node(Node::Binary { + let final_op = edit.add_node(Node::Binary { op, left: init, right: id, }); - edit.replace_all_uses_where(id, final_add, |u| *u != reduct && *u != final_add) + for u in out_uses { + edit.sub_edit(u, final_op); + } + edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op) }); } Node::Binary { @@ -1538,12 +1543,15 @@ pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { let one = edit.add_one_constant(typing[init.idx()]); let one = edit.add_node(Node::Constant { id: one }); edit = edit.replace_all_uses_where(init, one, |u| *u == id)?; - let final_add = edit.add_node(Node::Binary { + let final_op = edit.add_node(Node::Binary { op, left: init, right: id, }); - edit.replace_all_uses_where(id, final_add, |u| *u != reduct && *u != final_add) + for u in out_uses { + edit.sub_edit(u, final_op); + } + edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op) }); } _ => {} diff --git a/juno_samples/dot/src/cpu.sch b/juno_samples/dot/src/cpu.sch index 4e40e351bd86d15cb8819f97067201690a43720c..734054abf35115263103457e3a24f1a2d1598eda 100644 --- a/juno_samples/dot/src/cpu.sch +++ b/juno_samples/dot/src/cpu.sch @@ -1,24 +1,33 @@ -phi-elim(*); +phi-elim(dot); +ip-sroa(*); +sroa(dot); +dce(dot); -forkify(*); -fork-guard-elim(*); -dce(*); +forkify(dot); +fork-guard-elim(dot); +dce(dot); -fork-tile[8, 0, false, true](*); -fork-tile[32, 0, false, false](*); -fork-split(*); +fork-tile[8, 0, false, true](dot); +fork-tile[32, 0, false, false](dot); +let split_out = fork-split(dot); infer-schedules(*); clean-monoid-reduces(*); infer-schedules(*); clean-monoid-reduces(*); -let out = auto-outline(*); -cpu(out.dot); +let out = outline(split_out.dot.fj1); ip-sroa(*); -sroa(*); -dce(*); +sroa(dot); +gvn(dot); +dce(dot); -xdot[true](*); +let fission_out = fork-fission[out@loop](dot); +simplify-cfg(dot); +dce(dot); +unforkify(fission_out.dot.fj_loop_bottom); +ccp(dot); +gvn(dot); +dce(dot); -unforkify(*); +unforkify(out); gcm(*); diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 2e930639bc2c488291fc8a61ffa8c868d851789a..88816562924c5bb90b991224a4b943f80c3721af 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -125,7 +125,9 @@ impl FromStr for Appliable { "ip-sroa" | "interprocedural-sroa" => { Ok(Appliable::Pass(ir::Pass::InterproceduralSROA)) } - "fork-fission-bufferize" => Ok(Appliable::Pass(ir::Pass::ForkFissionBufferize)), + "fork-fission-bufferize" | "fork-fission" => { + Ok(Appliable::Pass(ir::Pass::ForkFissionBufferize)) + } "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)), "fork-interchange" => Ok(Appliable::Pass(ir::Pass::ForkInterchange)), "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)),