Skip to content
Snippets Groups Projects

Fork reshape

Merged Aaron Councilman requested to merge fork-reshape into main
1 file
+ 28
19
Compare changes
  • Side-by-side
  • Inline
@@ -578,7 +578,7 @@ pub fn fork_coalesce(
// FIXME: This could give us two forks that aren't actually ancestors / related, but then the helper will just return false early.
// something like: `fork_joins.postorder_iter().windows(2)` is ideal here.
for (inner, outer) in fork_joins.iter().cartesian_product(fork_joins.iter()) {
if fork_coalesce_helper(editor, *outer, *inner, fork_join_map) {
if fork_coalesce_helper(editor, *outer, *inner, fork_join_map).is_some() {
return true;
}
}
@@ -587,13 +587,15 @@ pub fn fork_coalesce(
/** Opposite of fork split, takes two fork-joins
with no control between them, and merges them into a single fork-join.
Returns None if the forks could not be merged and the NodeIDs of the
resulting fork and join if it succeeds in merging them.
*/
pub fn fork_coalesce_helper(
editor: &mut FunctionEditor,
outer_fork: NodeID,
inner_fork: NodeID,
fork_join_map: &HashMap<NodeID, NodeID>,
) -> bool {
) -> Option<(NodeID, NodeID)> {
// Check that all reduces in the outer fork are in *simple* cycles with a unique reduce of the inner fork.
let outer_join = fork_join_map[&outer_fork];
@@ -621,19 +623,19 @@ pub fn fork_coalesce_helper(
reduct: _,
} = inner_reduce_node
else {
return false;
return None;
};
// FIXME: check this condition better (i.e reduce might not be attached to join)
if *inner_control != inner_join {
return false;
return None;
};
if *inner_init != outer_reduce {
return false;
return None;
};
if pairs.contains_left(&outer_reduce) || pairs.contains_right(&inner_reduce) {
return false;
return None;
} else {
pairs.insert(outer_reduce, inner_reduce);
}
@@ -645,11 +647,11 @@ pub fn fork_coalesce_helper(
.filter(|node| editor.func().nodes[node.idx()].is_control())
.next()
else {
return false;
return None;
};
if user != inner_fork {
return false;
return None;
}
let Some(user) = editor
@@ -657,11 +659,11 @@ pub fn fork_coalesce_helper(
.filter(|node| editor.func().nodes[node.idx()].is_control())
.next()
else {
return false;
return None;
};
if user != outer_join {
return false;
return None;
}
// Checklist:
@@ -709,10 +711,10 @@ pub fn fork_coalesce_helper(
let (_, inner_init, _) = editor.func().nodes[inner_reduce.idx()]
.try_reduce()
.unwrap();
editor.edit(|mut edit| {
let success = editor.edit(|mut edit| {
// Set inner init to outer init.
edit =
edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?;
edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce )?;
edit = edit.replace_all_uses(outer_reduce, inner_reduce)?;
edit = edit.delete_node(outer_reduce)?;
@@ -720,12 +722,15 @@ pub fn fork_coalesce_helper(
});
}
let mut new_fork = NodeID::new(0);
let new_join = inner_join; // We reuse the inner join as the join of the new fork
editor.edit(|mut edit| {
let new_fork = Node::Fork {
let new_fork_node = Node::Fork {
control: outer_pred,
factors: new_factors.into(),
};
let new_fork = edit.add_node(new_fork);
new_fork = edit.add_node(new_fork_node);
edit = edit.replace_all_uses(inner_fork, new_fork)?;
edit = edit.replace_all_uses(outer_fork, new_fork)?;
@@ -737,7 +742,7 @@ pub fn fork_coalesce_helper(
Ok(edit)
});
true
Some((new_fork, new_join))
}
pub fn split_any_fork(
@@ -760,7 +765,7 @@ pub fn split_any_fork(
* Useful for code generation. A single iteration of `fork_split` only splits
* at most one fork-join, it must be called repeatedly to split all fork-joins.
*/
pub(crate) fn split_fork(
pub fn split_fork(
editor: &mut FunctionEditor,
fork: NodeID,
join: NodeID,
@@ -1215,13 +1220,13 @@ pub fn fork_interchange_all_forks(
}
}
fn fork_interchange(
pub fn fork_interchange(
editor: &mut FunctionEditor,
fork: NodeID,
join: NodeID,
first_dim: usize,
second_dim: usize,
) {
) -> Option<NodeID> {
// Check that every reduce on the join is parallel or associative.
let nodes = &editor.func().nodes;
let schedules = &editor.func().schedules;
@@ -1234,7 +1239,7 @@ fn fork_interchange(
})
{
// If not, we can't necessarily do interchange.
return;
return None;
}
let Node::Fork {
@@ -1276,6 +1281,7 @@ fn fork_interchange(
let mut factors = factors.clone();
factors.swap(first_dim, second_dim);
let new_fork = Node::Fork { control, factors };
let mut new_fork_id = None;
editor.edit(|mut edit| {
for (old_id, new_tid) in fix_tids {
let new_id = edit.add_node(new_tid);
@@ -1283,9 +1289,12 @@ fn fork_interchange(
edit = edit.delete_node(old_id)?;
}
let new_fork = edit.add_node(new_fork);
new_fork_id = Some(new_fork);
edit = edit.replace_all_uses(fork, new_fork)?;
edit.delete_node(fork)
});
new_fork_id
}
/*
Loading