diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index ff0f0283767996914e8f9b2274ed9a6d538b1812..cb0e7de48cb0efe01a586f9ae633f66cb61703a7 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -578,7 +578,7 @@ pub fn fork_coalesce( // FIXME: This could give us two forks that aren't actually ancestors / related, but then the helper will just return false early. // something like: `fork_joins.postorder_iter().windows(2)` is ideal here. for (inner, outer) in fork_joins.iter().cartesian_product(fork_joins.iter()) { - if fork_coalesce_helper(editor, *outer, *inner, fork_join_map) { + if fork_coalesce_helper(editor, *outer, *inner, fork_join_map).is_some() { return true; } } @@ -587,13 +587,15 @@ pub fn fork_coalesce( /** Opposite of fork split, takes two fork-joins with no control between them, and merges them into a single fork-join. + Returns None if the forks could not be merged and the NodeIDs of the + resulting fork and join if it succeeds in merging them. */ pub fn fork_coalesce_helper( editor: &mut FunctionEditor, outer_fork: NodeID, inner_fork: NodeID, fork_join_map: &HashMap<NodeID, NodeID>, -) -> bool { +) -> Option<(NodeID, NodeID)> { // Check that all reduces in the outer fork are in *simple* cycles with a unique reduce of the inner fork. let outer_join = fork_join_map[&outer_fork]; @@ -621,19 +623,19 @@ pub fn fork_coalesce_helper( reduct: _, } = inner_reduce_node else { - return false; + return None; }; // FIXME: check this condition better (i.e reduce might not be attached to join) if *inner_control != inner_join { - return false; + return None; }; if *inner_init != outer_reduce { - return false; + return None; }; if pairs.contains_left(&outer_reduce) || pairs.contains_right(&inner_reduce) { - return false; + return None; } else { pairs.insert(outer_reduce, inner_reduce); } @@ -645,11 +647,11 @@ pub fn fork_coalesce_helper( .filter(|node| editor.func().nodes[node.idx()].is_control()) .next() else { - return false; + return None; }; if user != inner_fork { - return false; + return None; } let Some(user) = editor @@ -657,11 +659,11 @@ pub fn fork_coalesce_helper( .filter(|node| editor.func().nodes[node.idx()].is_control()) .next() else { - return false; + return None; }; if user != outer_join { - return false; + return None; } // Checklist: @@ -709,10 +711,10 @@ pub fn fork_coalesce_helper( let (_, inner_init, _) = editor.func().nodes[inner_reduce.idx()] .try_reduce() .unwrap(); - editor.edit(|mut edit| { + let success = editor.edit(|mut edit| { // Set inner init to outer init. edit = - edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?; + edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce )?; edit = edit.replace_all_uses(outer_reduce, inner_reduce)?; edit = edit.delete_node(outer_reduce)?; @@ -720,12 +722,15 @@ pub fn fork_coalesce_helper( }); } + let mut new_fork = NodeID::new(0); + let new_join = inner_join; // We reuse the inner join as the join of the new fork + editor.edit(|mut edit| { - let new_fork = Node::Fork { + let new_fork_node = Node::Fork { control: outer_pred, factors: new_factors.into(), }; - let new_fork = edit.add_node(new_fork); + new_fork = edit.add_node(new_fork_node); edit = edit.replace_all_uses(inner_fork, new_fork)?; edit = edit.replace_all_uses(outer_fork, new_fork)?; @@ -737,7 +742,7 @@ pub fn fork_coalesce_helper( Ok(edit) }); - true + Some((new_fork, new_join)) } pub fn split_any_fork( @@ -760,7 +765,7 @@ pub fn split_any_fork( * Useful for code generation. A single iteration of `fork_split` only splits * at most one fork-join, it must be called repeatedly to split all fork-joins. */ -pub(crate) fn split_fork( +pub fn split_fork( editor: &mut FunctionEditor, fork: NodeID, join: NodeID, @@ -1215,13 +1220,13 @@ pub fn fork_interchange_all_forks( } } -fn fork_interchange( +pub fn fork_interchange( editor: &mut FunctionEditor, fork: NodeID, join: NodeID, first_dim: usize, second_dim: usize, -) { +) -> Option<NodeID> { // Check that every reduce on the join is parallel or associative. let nodes = &editor.func().nodes; let schedules = &editor.func().schedules; @@ -1234,7 +1239,7 @@ fn fork_interchange( }) { // If not, we can't necessarily do interchange. - return; + return None; } let Node::Fork { @@ -1276,6 +1281,7 @@ fn fork_interchange( let mut factors = factors.clone(); factors.swap(first_dim, second_dim); let new_fork = Node::Fork { control, factors }; + let mut new_fork_id = None; editor.edit(|mut edit| { for (old_id, new_tid) in fix_tids { let new_id = edit.add_node(new_tid); @@ -1283,9 +1289,12 @@ fn fork_interchange( edit = edit.delete_node(old_id)?; } let new_fork = edit.add_node(new_fork); + new_fork_id = Some(new_fork); edit = edit.replace_all_uses(fork, new_fork)?; edit.delete_node(fork) }); + + new_fork_id } /*