Skip to content
Snippets Groups Projects

Fork unrolling

Merged rarbore2 requested to merge fork_unroll into main
Files
5
@@ -1164,3 +1164,111 @@ fn fork_interchange(
edit.delete_node(fork)
});
}
/*
* Run fork unrolling on all fork-joins that are mutable in an editor.
*/
pub fn fork_unroll_all_forks(
editor: &mut FunctionEditor,
fork_joins: &HashMap<NodeID, NodeID>,
nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) {
for (fork, join) in fork_joins {
if editor.is_mutable(*fork) && fork_unroll(editor, *fork, *join, nodes_in_fork_joins) {
break;
}
}
}
pub fn fork_unroll(
editor: &mut FunctionEditor,
fork: NodeID,
join: NodeID,
nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) -> bool {
// We can only unroll fork-joins with a compile time known factor list. For
// simplicity, just unroll fork-joins that have a single dimension.
let nodes = &editor.func().nodes;
let Node::Fork {
control,
ref factors,
} = nodes[fork.idx()]
else {
panic!()
};
if factors.len() != 1 || editor.get_users(fork).count() != 2 {
return false;
}
let DynamicConstant::Constant(cons) = *editor.get_dynamic_constant(factors[0]) else {
return false;
};
let tid = editor
.get_users(fork)
.filter(|id| nodes[id.idx()].is_thread_id())
.next()
.unwrap();
let inits: HashMap<NodeID, NodeID> = editor
.get_users(join)
.filter_map(|id| nodes[id.idx()].try_reduce().map(|(_, init, _)| (id, init)))
.collect();
editor.edit(|mut edit| {
// Create a copy of the nodes in the fork join per unrolled iteration,
// excluding the fork itself, the join itself, the thread IDs of the fork,
// and the reduces on the join. Keep a running tally of the top control
// node and the current reduction value.
let mut top_control = control;
let mut current_reduces = inits;
for iter in 0..cons {
let iter_cons = edit.add_constant(Constant::UnsignedInteger64(iter as u64));
let iter_tid = edit.add_node(Node::Constant { id: iter_cons });
// First, add a copy of each node in the fork join unmodified.
// Record the mapping from old ID to new ID.
let mut added_ids = HashSet::new();
let mut old_to_new_ids = HashMap::new();
let mut new_control = None;
let mut new_reduces = HashMap::new();
for node in nodes_in_fork_joins[&fork].iter() {
if *node == fork {
old_to_new_ids.insert(*node, top_control);
} else if *node == join {
new_control = Some(edit.get_node(*node).try_join().unwrap());
} else if *node == tid {
old_to_new_ids.insert(*node, iter_tid);
} else if let Some(current) = current_reduces.get(node) {
old_to_new_ids.insert(*node, *current);
new_reduces.insert(*node, edit.get_node(*node).try_reduce().unwrap().2);
} else {
let new_node = edit.add_node(edit.get_node(*node).clone());
old_to_new_ids.insert(*node, new_node);
added_ids.insert(new_node);
}
}
// Second, replace all the uses in the just added nodes.
if let Some(new_control) = new_control {
top_control = old_to_new_ids[&new_control];
}
for (reduce, reduct) in new_reduces {
current_reduces.insert(reduce, old_to_new_ids[&reduct]);
}
for (old, new) in old_to_new_ids {
edit = edit.replace_all_uses_where(old, new, |id| added_ids.contains(id))?;
}
}
// Hook up the control and reduce outputs to the rest of the function.
edit = edit.replace_all_uses(join, top_control)?;
for (reduce, reduct) in current_reduces {
edit = edit.replace_all_uses(reduce, reduct)?;
}
// Delete the old fork-join.
for node in nodes_in_fork_joins[&fork].iter() {
edit = edit.delete_node(*node)?;
}
Ok(edit)
})
}
Loading