Skip to content
Snippets Groups Projects

Fork reshape

Merged Aaron Councilman requested to merge fork-reshape into main
1 file
+ 19
28
Compare changes
  • Side-by-side
  • Inline
@@ -686,40 +686,33 @@ pub fn fork_coalesce_helper(
// CHECKME / FIXME: Might need to be added the other way.
new_factors.append(&mut inner_dims.to_vec());
for tid in inner_tids {
let (fork, dim) = editor.func().nodes[tid.idx()].try_thread_id().unwrap();
let new_tid = Node::ThreadID {
control: fork,
dimension: dim + num_outer_dims,
};
editor.edit(|mut edit| {
for tid in inner_tids {
let (fork, dim) = edit.get_node(tid).try_thread_id().unwrap();
let new_tid = Node::ThreadID {
control: fork,
dimension: dim + num_outer_dims,
};
editor.edit(|mut edit| {
let new_tid = edit.add_node(new_tid);
let edit = edit.replace_all_uses(tid, new_tid)?;
Ok(edit)
});
}
edit = edit.replace_all_uses(tid, new_tid)?;
}
// Fuse Reductions
for (outer_reduce, inner_reduce) in pairs {
let (_, outer_init, _) = editor.func().nodes[outer_reduce.idx()]
.try_reduce()
.unwrap();
let (_, inner_init, _) = editor.func().nodes[inner_reduce.idx()]
.try_reduce()
.unwrap();
editor.edit(|mut edit| {
// Fuse Reductions
for (outer_reduce, inner_reduce) in pairs {
let (_, outer_init, _) = edit.get_node(outer_reduce)
.try_reduce()
.unwrap();
let (_, inner_init, _) = edit.get_node(inner_reduce)
.try_reduce()
.unwrap();
// Set inner init to outer init.
edit =
edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?;
edit = edit.replace_all_uses(outer_reduce, inner_reduce)?;
edit = edit.delete_node(outer_reduce)?;
}
Ok(edit)
});
}
editor.edit(|mut edit| {
let new_fork = Node::Fork {
control: outer_pred,
factors: new_factors.into(),
@@ -734,9 +727,7 @@ pub fn fork_coalesce_helper(
edit = edit.delete_node(outer_fork)?;
Ok(edit)
});
true
})
}
pub fn split_any_fork(
Loading