
Fork fission

Merged: Xavier Routh requested to merge fork-fission into main

Files changed: 8
@@ -306,40 +306,43 @@ where
 pub fn fork_fission<'a>(
     editor: &'a mut FunctionEditor,
-    _control_subgraph: &Subgraph,
-    _types: &Vec<TypeID>,
-    _loop_tree: &LoopTree,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
+    loop_tree: &LoopTree,
     fork_join_map: &HashMap<NodeID, NodeID>,
-) -> () {
-    let forks: Vec<_> = editor
-        .func()
-        .nodes
-        .iter()
-        .enumerate()
-        .filter_map(|(idx, node)| {
-            if node.is_fork() {
-                Some(NodeID::new(idx))
-            } else {
-                None
-            }
-        })
+    fork_label: LabelID,
+) -> Vec<NodeID> {
+    let forks: Vec<_> = loop_tree
+        .bottom_up_loops()
+        .into_iter()
+        .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork())
         .collect();
-    let control_pred = NodeID::new(0);
+    let mut created_forks = Vec::new();
 
-    // This does the reduction fission
-    for fork in forks {
-        // FIXME: If there is control in between fork and join, don't just give up.
-        let join = fork_join_map[&fork];
-        let join_pred = editor.func().nodes[join.idx()].try_join().unwrap();
-        if join_pred != fork {
-            todo!("Can't do fork fission on nodes with internal control")
-            // Inner control LOOPs are hard
-            // inner control in general *should* work right now without modifications.
-        }
-        let reduce_partition = default_reduce_partition(editor, fork, join);
-        fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, control_pred, fork);
-    }
+    // This does the reduction fission:
+    for fork in forks.clone() {
+        let join = fork_join_map[&fork.0];
+
+        // FIXME: Don't make multiple forks for reduces that are in cycles with each other.
+        let reduce_partition = default_reduce_partition(editor, fork.0, join);
+
+        if !editor.func().labels[fork.0.idx()].contains(&fork_label) {
+            continue;
+        }
+
+        if editor.is_mutable(fork.0) {
+            created_forks = fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, nodes_in_fork_joins, fork.0);
+            if created_forks.is_empty() {
+                continue;
+            } else {
+                return created_forks;
+            }
+        }
+    }
+
+    created_forks
 }
 
 /** Split a 1D fork into two forks, placing select intermediate data into buffers. */
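
Note for readers skimming the diff: the rewritten driver walks forks innermost-first via the loop tree, skips any fork not carrying the user-supplied label, and returns as soon as one fission succeeds so the caller can recompute analyses before attempting the next fork. Below is a minimal standalone sketch of that control shape. The types and helpers in it (NodeId, Func, bottom_up_loops, try_split) are invented for illustration and are not the Hercules API.

    use std::collections::HashSet;

    #[derive(Clone, Copy, Debug)]
    struct NodeId(usize);

    struct Func {
        is_fork: Vec<bool>,        // per-node: is this node a fork?
        labels: Vec<HashSet<u32>>, // per-node label sets
    }

    // Loop headers ordered innermost-first; a stand-in for a loop-tree analysis.
    fn bottom_up_loops(func: &Func) -> Vec<NodeId> {
        (0..func.is_fork.len()).map(NodeId).collect()
    }

    // Try to split one fork; returns the forks created, empty if nothing happened.
    fn try_split(_func: &mut Func, _fork: NodeId) -> Vec<NodeId> {
        Vec::new() // placeholder transform
    }

    fn fission_driver(func: &mut Func, fork_label: u32) -> Vec<NodeId> {
        // Collect forks bottom-up, mirroring the filter over bottom_up_loops().
        let forks: Vec<NodeId> = bottom_up_loops(func)
            .into_iter()
            .filter(|k| func.is_fork[k.0])
            .collect();

        for fork in forks {
            // Only touch forks the user labeled for fission.
            if !func.labels[fork.0].contains(&fork_label) {
                continue;
            }
            let created = try_split(func, fork);
            if !created.is_empty() {
                // Stop after the first successful split; the caller reruns the
                // pass with fresh analyses rather than editing stale state.
                return created;
            }
        }
        Vec::new()
    }

    fn main() {
        let mut func = Func {
            is_fork: vec![false, true, false],
            labels: vec![HashSet::new(), HashSet::from([7]), HashSet::new()],
        };
        println!("{:?}", fission_driver(&mut func, 7)); // -> [] with the stub transform
    }

The early return in the real pass presumably exists because fork_join_map and nodes_in_fork_joins are computed before the pass runs, so continuing after a successful split would consult stale analyses.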
@@ -488,48 +491,38 @@ where
 }
 }
-/** Split a 1D fork into a separate fork for each reduction. */
+/** Split a fork into a separate fork for each reduction. */
 pub fn fork_reduce_fission_helper<'a>(
     editor: &'a mut FunctionEditor,
     fork_join_map: &HashMap<NodeID, NodeID>,
     reduce_partition: SparseNodeMap<ForkID>, // Describes how the reduces of the fork should be split,
-    original_control_pred: NodeID, // What the new fork connects to.
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
     fork: NodeID,
-) -> (NodeID, NodeID) {
+) -> Vec<NodeID> {
     let join = fork_join_map[&fork];
 
-    let mut new_control_pred: NodeID = original_control_pred;
-    // Important edges are: Reduces,
-
-    // NOTE:
-    // Say two reduces are in a fork, s.t. reduce A depends on reduce B.
-    // If the user wants A and B in separate forks:
-    // - we can simply refuse
-    // - or we can duplicate B
+    let mut new_forks = Vec::new();
+    let mut new_control_pred: NodeID = editor.get_uses(fork).filter(|n| editor.node(n).is_control()).next().unwrap();
 
     let mut new_fork = NodeID::new(0);
     let mut new_join = NodeID::new(0);
 
+    let subgraph = &nodes_in_fork_joins[&fork];
+
     // Gets everything between fork & join that this reduce needs. (ALL CONTROL)
-    for reduce in reduce_partition {
-        let reduce = reduce.0;
-
-        let function = editor.func();
-        let subgraph = find_reduce_dependencies(function, reduce, fork);
-
-        let mut subgraph: HashSet<NodeID> = subgraph.into_iter().collect();
-        subgraph.insert(join);
-        subgraph.insert(fork);
-        subgraph.insert(reduce);
-
-        let (_, mapping, _) = copy_subgraph(editor, subgraph);
-
-        new_fork = mapping[&fork];
-        new_join = mapping[&join];
-
-        editor.edit(|mut edit| {
+    editor.edit(|mut edit| {
+        for reduce in reduce_partition {
+            let reduce = reduce.0;
+
+            let a = copy_subgraph_in_edit(edit, subgraph.clone())?;
+            edit = a.0;
+            let mapping = a.1;
+
+            new_fork = mapping[&fork];
+            new_forks.push(new_fork);
+            new_join = mapping[&join];
+
             // Attach new_fork after control_pred
             let (old_control_pred, _) = edit.get_node(new_fork).try_fork().unwrap().clone();
             edit = edit.replace_all_uses_where(old_control_pred, new_control_pred, |usee| {
@@ -538,13 +531,9 @@ pub fn fork_reduce_fission_helper<'a>(
             // Replace uses of reduce
             edit = edit.replace_all_uses(reduce, mapping[&reduce])?;
-            Ok(edit)
-        });
 
-        new_control_pred = new_join;
-    }
+            new_control_pred = new_join;
+        }
 
-    editor.edit(|mut edit| {
         // Replace original join w/ new final join
         edit = edit.replace_all_uses_where(join, new_join, |_| true)?;
@@ -553,10 +542,12 @@ pub fn fork_reduce_fission_helper<'a>(
         // Replace all users of original fork, and then delete it, leftover users will be DCE'd.
         edit = edit.replace_all_uses(fork, new_fork)?;
-        edit.delete_node(fork)
+        edit = edit.delete_node(fork)?;
+
+        Ok(edit)
     });
 
-    (new_fork, new_join)
+    new_forks
 }
 
 pub fn fork_coalesce(
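
The helper's rewiring scheme is easy to lose in the diff noise: for each reduce group it clones the whole fork-join region, attaches the clone's fork after the current control predecessor, then makes the clone's join the predecessor for the next clone, so the new fork-joins execute in sequence and the last join takes over the original join's uses. Here is a standalone sketch of that chaining, under the same caveat as above: Graph, copy_region, and fission_chain are hypothetical stand-ins, not Hercules code.

    use std::collections::HashMap;

    #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
    struct NodeId(usize);

    #[derive(Default)]
    struct Graph {
        next_id: usize,
        control_pred: HashMap<NodeId, NodeId>, // fork -> its control predecessor
    }

    impl Graph {
        fn fresh(&mut self) -> NodeId {
            self.next_id += 1;
            NodeId(self.next_id - 1)
        }
        // Clone a fork-join region wholesale; returns (new_fork, new_join).
        fn copy_region(&mut self, _fork: NodeId) -> (NodeId, NodeId) {
            (self.fresh(), self.fresh())
        }
    }

    // One copy of the region per reduce group, chained in sequence:
    // pred -> fork1..join1 -> fork2..join2 -> ...; the last join would then
    // take over the uses of the original join.
    fn fission_chain(g: &mut Graph, fork: NodeId, pred: NodeId, groups: usize) -> Vec<NodeId> {
        let mut new_forks = Vec::new();
        let mut control_pred = pred;
        for _ in 0..groups {
            let (new_fork, new_join) = g.copy_region(fork);
            g.control_pred.insert(new_fork, control_pred); // attach after current pred
            control_pred = new_join; // the next copy runs after this one
            new_forks.push(new_fork);
        }
        new_forks
    }

    fn main() {
        let mut g = Graph::default();
        let fork = g.fresh();
        let pred = g.fresh();
        let forks = fission_chain(&mut g, fork, pred, 3);
        println!("created {} forks: {:?}", forks.len(), forks);
    }

Moving the whole loop inside a single editor.edit call (via copy_subgraph_in_edit) also means the sequence of copies is one speculative edit, presumably so the rewrite commits or rolls back as a unit instead of leaving a half-split fork behind.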