Skip to content
Snippets Groups Projects
Commit 4754a909 authored by Aaron Councilman's avatar Aaron Councilman
Browse files

Return new fork ID from fork coalesce and interchange

parent 3c8eaae2
No related branches found
No related tags found
1 merge request!205Fork reshape
......@@ -578,7 +578,7 @@ pub fn fork_coalesce(
// FIXME: This could give us two forks that aren't actually ancestors / related, but then the helper will just return false early.
// something like: `fork_joins.postorder_iter().windows(2)` is ideal here.
for (inner, outer) in fork_joins.iter().cartesian_product(fork_joins.iter()) {
if fork_coalesce_helper(editor, *outer, *inner, fork_join_map) {
if fork_coalesce_helper(editor, *outer, *inner, fork_join_map).is_some() {
return true;
}
}
......@@ -587,13 +587,15 @@ pub fn fork_coalesce(
/** Opposite of fork split, takes two fork-joins
with no control between them, and merges them into a single fork-join.
Returns None if the forks could not be merged and the NodeIDs of the
resulting fork and join if it succeeds in merging them.
*/
pub fn fork_coalesce_helper(
editor: &mut FunctionEditor,
outer_fork: NodeID,
inner_fork: NodeID,
fork_join_map: &HashMap<NodeID, NodeID>,
) -> bool {
) -> Option<(NodeID, NodeID)> {
// Check that all reduces in the outer fork are in *simple* cycles with a unique reduce of the inner fork.
let outer_join = fork_join_map[&outer_fork];
......@@ -621,19 +623,19 @@ pub fn fork_coalesce_helper(
reduct: _,
} = inner_reduce_node
else {
return false;
return None;
};
// FIXME: check this condition better (i.e reduce might not be attached to join)
if *inner_control != inner_join {
return false;
return None;
};
if *inner_init != outer_reduce {
return false;
return None;
};
if pairs.contains_left(&outer_reduce) || pairs.contains_right(&inner_reduce) {
return false;
return None;
} else {
pairs.insert(outer_reduce, inner_reduce);
}
......@@ -645,11 +647,11 @@ pub fn fork_coalesce_helper(
.filter(|node| editor.func().nodes[node.idx()].is_control())
.next()
else {
return false;
return None;
};
if user != inner_fork {
return false;
return None;
}
let Some(user) = editor
......@@ -657,11 +659,11 @@ pub fn fork_coalesce_helper(
.filter(|node| editor.func().nodes[node.idx()].is_control())
.next()
else {
return false;
return None;
};
if user != outer_join {
return false;
return None;
}
// Checklist:
......@@ -709,10 +711,10 @@ pub fn fork_coalesce_helper(
let (_, inner_init, _) = editor.func().nodes[inner_reduce.idx()]
.try_reduce()
.unwrap();
editor.edit(|mut edit| {
let success = editor.edit(|mut edit| {
// Set inner init to outer init.
edit =
edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?;
edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce )?;
edit = edit.replace_all_uses(outer_reduce, inner_reduce)?;
edit = edit.delete_node(outer_reduce)?;
......@@ -720,12 +722,15 @@ pub fn fork_coalesce_helper(
});
}
let mut new_fork = NodeID::new(0);
let new_join = inner_join; // We reuse the inner join as the join of the new fork
editor.edit(|mut edit| {
let new_fork = Node::Fork {
let new_fork_node = Node::Fork {
control: outer_pred,
factors: new_factors.into(),
};
let new_fork = edit.add_node(new_fork);
new_fork = edit.add_node(new_fork_node);
edit = edit.replace_all_uses(inner_fork, new_fork)?;
edit = edit.replace_all_uses(outer_fork, new_fork)?;
......@@ -737,7 +742,7 @@ pub fn fork_coalesce_helper(
Ok(edit)
});
true
Some((new_fork, new_join))
}
pub fn split_any_fork(
......@@ -760,7 +765,7 @@ pub fn split_any_fork(
* Useful for code generation. A single iteration of `fork_split` only splits
* at most one fork-join, it must be called repeatedly to split all fork-joins.
*/
pub(crate) fn split_fork(
pub fn split_fork(
editor: &mut FunctionEditor,
fork: NodeID,
join: NodeID,
......@@ -1215,13 +1220,13 @@ pub fn fork_interchange_all_forks(
}
}
fn fork_interchange(
pub fn fork_interchange(
editor: &mut FunctionEditor,
fork: NodeID,
join: NodeID,
first_dim: usize,
second_dim: usize,
) {
) -> Option<NodeID> {
// Check that every reduce on the join is parallel or associative.
let nodes = &editor.func().nodes;
let schedules = &editor.func().schedules;
......@@ -1234,7 +1239,7 @@ fn fork_interchange(
})
{
// If not, we can't necessarily do interchange.
return;
return None;
}
let Node::Fork {
......@@ -1276,6 +1281,7 @@ fn fork_interchange(
let mut factors = factors.clone();
factors.swap(first_dim, second_dim);
let new_fork = Node::Fork { control, factors };
let mut new_fork_id = None;
editor.edit(|mut edit| {
for (old_id, new_tid) in fix_tids {
let new_id = edit.add_node(new_tid);
......@@ -1283,9 +1289,12 @@ fn fork_interchange(
edit = edit.delete_node(old_id)?;
}
let new_fork = edit.add_node(new_fork);
new_fork_id = Some(new_fork);
edit = edit.replace_all_uses(fork, new_fork)?;
edit.delete_node(fork)
});
new_fork_id
}
/*
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment