Skip to content
Snippets Groups Projects

Fork reshape

Merged Aaron Councilman requested to merge fork-reshape into main
1 file
+ 19
28
Compare changes
  • Side-by-side
  • Inline
@@ -578,7 +578,7 @@ pub fn fork_coalesce(
// FIXME: This could give us two forks that aren't actually ancestors / related, but then the helper will just return false early.
// something like: `fork_joins.postorder_iter().windows(2)` is ideal here.
for (inner, outer) in fork_joins.iter().cartesian_product(fork_joins.iter()) {
if fork_coalesce_helper(editor, *outer, *inner, fork_join_map) {
if fork_coalesce_helper(editor, *outer, *inner, fork_join_map).is_some() {
return true;
}
}
@@ -587,13 +587,15 @@ pub fn fork_coalesce(
/** Opposite of fork split, takes two fork-joins
with no control between them, and merges them into a single fork-join.
Returns None if the forks could not be merged and the NodeIDs of the
resulting fork and join if it succeeds in merging them.
*/
pub fn fork_coalesce_helper(
editor: &mut FunctionEditor,
outer_fork: NodeID,
inner_fork: NodeID,
fork_join_map: &HashMap<NodeID, NodeID>,
) -> bool {
) -> Option<(NodeID, NodeID)> {
// Check that all reduces in the outer fork are in *simple* cycles with a unique reduce of the inner fork.
let outer_join = fork_join_map[&outer_fork];
@@ -621,47 +623,35 @@ pub fn fork_coalesce_helper(
reduct: _,
} = inner_reduce_node
else {
return false;
return None;
};
// FIXME: check this condition better (i.e reduce might not be attached to join)
if *inner_control != inner_join {
return false;
return None;
};
if *inner_init != outer_reduce {
return false;
return None;
};
if pairs.contains_left(&outer_reduce) || pairs.contains_right(&inner_reduce) {
return false;
return None;
} else {
pairs.insert(outer_reduce, inner_reduce);
}
}
// Check for control between join-join and fork-fork
let Some(user) = editor
.get_users(outer_fork)
.filter(|node| editor.func().nodes[node.idx()].is_control())
.next()
else {
return false;
};
let (control, _) = editor.node(inner_fork).try_fork().unwrap();
if user != inner_fork {
return false;
if control != outer_fork {
return None;
}
let Some(user) = editor
.get_users(inner_join)
.filter(|node| editor.func().nodes[node.idx()].is_control())
.next()
else {
return false;
};
let control = editor.node(outer_join).try_join().unwrap();
if user != outer_join {
return false;
if control != inner_join {
return None;
}
// Checklist:
@@ -686,46 +676,47 @@ pub fn fork_coalesce_helper(
// CHECKME / FIXME: Might need to be added the other way.
new_factors.append(&mut inner_dims.to_vec());
for tid in inner_tids {
let (fork, dim) = editor.func().nodes[tid.idx()].try_thread_id().unwrap();
let new_tid = Node::ThreadID {
control: fork,
dimension: dim + num_outer_dims,
};
let mut new_fork = NodeID::new(0);
let new_join = inner_join; // We'll reuse the inner join as the join of the new fork
let success = editor.edit(|mut edit| {
for tid in inner_tids {
let (fork, dim) = edit.get_node(tid).try_thread_id().unwrap();
let new_tid = Node::ThreadID {
control: fork,
dimension: dim + num_outer_dims,
};
editor.edit(|mut edit| {
let new_tid = edit.add_node(new_tid);
let mut edit = edit.replace_all_uses(tid, new_tid)?;
edit = edit.replace_all_uses(tid, new_tid)?;
edit.sub_edit(tid, new_tid);
Ok(edit)
});
}
// Fuse Reductions
for (outer_reduce, inner_reduce) in pairs {
let (_, outer_init, _) = editor.func().nodes[outer_reduce.idx()]
.try_reduce()
.unwrap();
let (_, inner_init, _) = editor.func().nodes[inner_reduce.idx()]
.try_reduce()
.unwrap();
editor.edit(|mut edit| {
}
// Fuse Reductions
for (outer_reduce, inner_reduce) in pairs {
let (_, outer_init, _) = edit.get_node(outer_reduce).try_reduce().unwrap();
let (_, inner_init, _) = edit.get_node(inner_reduce).try_reduce().unwrap();
// Set inner init to outer init.
edit =
edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?;
edit = edit.replace_all_uses(outer_reduce, inner_reduce)?;
edit = edit.delete_node(outer_reduce)?;
}
Ok(edit)
});
}
editor.edit(|mut edit| {
let new_fork = Node::Fork {
let new_fork_node = Node::Fork {
control: outer_pred,
factors: new_factors.into(),
};
let new_fork = edit.add_node(new_fork);
new_fork = edit.add_node(new_fork_node);
if edit
.get_schedule(outer_fork)
.contains(&Schedule::ParallelFork)
&& edit
.get_schedule(inner_fork)
.contains(&Schedule::ParallelFork)
{
edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?;
}
edit = edit.replace_all_uses(inner_fork, new_fork)?;
edit = edit.replace_all_uses(outer_fork, new_fork)?;
@@ -737,7 +728,11 @@ pub fn fork_coalesce_helper(
Ok(edit)
});
true
if success {
Some((new_fork, new_join))
} else {
None
}
}
pub fn split_any_fork(
@@ -760,7 +755,7 @@ pub fn split_any_fork(
* Useful for code generation. A single iteration of `fork_split` only splits
* at most one fork-join, it must be called repeatedly to split all fork-joins.
*/
pub(crate) fn split_fork(
pub fn split_fork(
editor: &mut FunctionEditor,
fork: NodeID,
join: NodeID,
@@ -1215,13 +1210,13 @@ pub fn fork_interchange_all_forks(
}
}
fn fork_interchange(
pub fn fork_interchange(
editor: &mut FunctionEditor,
fork: NodeID,
join: NodeID,
first_dim: usize,
second_dim: usize,
) {
) -> Option<NodeID> {
// Check that every reduce on the join is parallel or associative.
let nodes = &editor.func().nodes;
let schedules = &editor.func().schedules;
@@ -1234,7 +1229,7 @@ fn fork_interchange(
})
{
// If not, we can't necessarily do interchange.
return;
return None;
}
let Node::Fork {
@@ -1276,6 +1271,7 @@ fn fork_interchange(
let mut factors = factors.clone();
factors.swap(first_dim, second_dim);
let new_fork = Node::Fork { control, factors };
let mut new_fork_id = None;
editor.edit(|mut edit| {
for (old_id, new_tid) in fix_tids {
let new_id = edit.add_node(new_tid);
@@ -1283,9 +1279,17 @@ fn fork_interchange(
edit = edit.delete_node(old_id)?;
}
let new_fork = edit.add_node(new_fork);
if edit.get_schedule(fork).contains(&Schedule::ParallelFork) {
edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?;
}
edit = edit.replace_all_uses(fork, new_fork)?;
edit.delete_node(fork)
edit = edit.delete_node(fork)?;
new_fork_id = Some(new_fork);
Ok(edit)
});
new_fork_id
}
/*
Loading