Skip to content
Snippets Groups Projects

Fork reshape

Merged Aaron Councilman requested to merge fork-reshape into main
1 file
+ 38
32
Compare changes
  • Side-by-side
  • Inline
@@ -676,50 +676,47 @@ pub fn fork_coalesce_helper(
@@ -676,50 +676,47 @@ pub fn fork_coalesce_helper(
// CHECKME / FIXME: Might need to be added the other way.
// CHECKME / FIXME: Might need to be added the other way.
new_factors.append(&mut inner_dims.to_vec());
new_factors.append(&mut inner_dims.to_vec());
for tid in inner_tids {
let mut new_fork = NodeID::new(0);
let (fork, dim) = editor.func().nodes[tid.idx()].try_thread_id().unwrap();
let new_join = inner_join; // We'll reuse the inner join as the join of the new fork
let new_tid = Node::ThreadID {
control: fork,
let success = editor.edit(|mut edit| {
dimension: dim + num_outer_dims,
for tid in inner_tids {
};
let (fork, dim) = edit.get_node(tid).try_thread_id().unwrap();
 
let new_tid = Node::ThreadID {
 
control: fork,
 
dimension: dim + num_outer_dims,
 
};
editor.edit(|mut edit| {
let new_tid = edit.add_node(new_tid);
let new_tid = edit.add_node(new_tid);
let mut edit = edit.replace_all_uses(tid, new_tid)?;
edit = edit.replace_all_uses(tid, new_tid)?;
edit.sub_edit(tid, new_tid);
edit.sub_edit(tid, new_tid);
Ok(edit)
}
});
// Fuse Reductions
}
for (outer_reduce, inner_reduce) in pairs {
let (_, outer_init, _) = edit.get_node(outer_reduce)
// Fuse Reductions
.try_reduce()
for (outer_reduce, inner_reduce) in pairs {
.unwrap();
let (_, outer_init, _) = editor.func().nodes[outer_reduce.idx()]
let (_, inner_init, _) = edit.get_node(inner_reduce)
.try_reduce()
.try_reduce()
.unwrap();
.unwrap();
let (_, inner_init, _) = editor.func().nodes[inner_reduce.idx()]
.try_reduce()
.unwrap();
let success = editor.edit(|mut edit| {
// Set inner init to outer init.
// Set inner init to outer init.
edit =
edit =
edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce )?;
edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce )?;
edit = edit.replace_all_uses(outer_reduce, inner_reduce)?;
edit = edit.replace_all_uses(outer_reduce, inner_reduce)?;
edit = edit.delete_node(outer_reduce)?;
edit = edit.delete_node(outer_reduce)?;
 
}
Ok(edit)
});
}
let mut new_fork = NodeID::new(0);
let new_join = inner_join; // We reuse the inner join as the join of the new fork
editor.edit(|mut edit| {
let new_fork_node = Node::Fork {
let new_fork_node = Node::Fork {
control: outer_pred,
control: outer_pred,
factors: new_factors.into(),
factors: new_factors.into(),
};
};
new_fork = edit.add_node(new_fork_node);
new_fork = edit.add_node(new_fork_node);
 
if edit.get_schedule(outer_fork).contains(&Schedule::ParallelFork)
 
&& edit.get_schedule(inner_fork).contains(&Schedule::ParallelFork) {
 
edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?;
 
}
 
edit = edit.replace_all_uses(inner_fork, new_fork)?;
edit = edit.replace_all_uses(inner_fork, new_fork)?;
edit = edit.replace_all_uses(outer_fork, new_fork)?;
edit = edit.replace_all_uses(outer_fork, new_fork)?;
edit = edit.replace_all_uses(outer_join, inner_join)?;
edit = edit.replace_all_uses(outer_join, inner_join)?;
@@ -730,7 +727,11 @@ pub fn fork_coalesce_helper(
@@ -730,7 +727,11 @@ pub fn fork_coalesce_helper(
Ok(edit)
Ok(edit)
});
});
Some((new_fork, new_join))
if success {
 
Some((new_fork, new_join))
 
} else {
 
None
 
}
}
}
pub fn split_any_fork(
pub fn split_any_fork(
@@ -1277,9 +1278,14 @@ pub fn fork_interchange(
@@ -1277,9 +1278,14 @@ pub fn fork_interchange(
edit = edit.delete_node(old_id)?;
edit = edit.delete_node(old_id)?;
}
}
let new_fork = edit.add_node(new_fork);
let new_fork = edit.add_node(new_fork);
new_fork_id = Some(new_fork);
if edit.get_schedule(fork).contains(&Schedule::ParallelFork) {
 
edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?;
 
}
edit = edit.replace_all_uses(fork, new_fork)?;
edit = edit.replace_all_uses(fork, new_fork)?;
edit.delete_node(fork)
edit = edit.delete_node(fork)?;
 
 
new_fork_id = Some(new_fork);
 
Ok(edit)
});
});
new_fork_id
new_fork_id
Loading