fork_transforms.rs 58.36 KiB
use std::collections::{HashMap, HashSet};
use std::iter::zip;
use bimap::BiMap;
use itertools::Itertools;
use hercules_ir::*;
use crate::*;
/// Identifier of the fork that a reduce is assigned to by a reduce partition
/// (see `default_reduce_partition` and `fork_reduce_fission_helper`).
type ForkID = usize;
/// Builds the trivial reduce partition for a fork-join: every reduce user of
/// `join` gets its own distinct `ForkID`, numbered from 0 in the order the
/// editor yields the join's users. `_fork` is accepted for interface symmetry
/// and is not inspected.
pub fn default_reduce_partition(
    editor: &FunctionEditor,
    _fork: NodeID,
    join: NodeID,
) -> SparseNodeMap<ForkID> {
    let nodes = &editor.func().nodes;
    let mut partition = SparseNodeMap::new();
    let reduces = editor
        .get_users(join)
        .filter(|user| nodes[user.idx()].is_reduce());
    for (next_fork, reduce) in reduces.enumerate() {
        partition.insert(reduce, next_fork);
    }
    partition
}
// TODO: Refine these conditions.
/// Collects the nodes that `reduce` transitively uses *and* that themselves
/// reach `fork` through their uses, i.e. the part of the fork-join body the
/// reduce needs. Node IDs are returned in ascending order; `fork` is included
/// if it is reached.
///
/// NOTE: this is a memoized DFS over use edges. A node that is revisited while
/// it is still being explored counts as "not dependent". To make the reduce's
/// own cycle come out right, `reduce` is pre-marked as dependent before the
/// search (its mark is recomputed from its uses when the search finishes).
/// Other cycles, e.g. control loops inside the fork, can be mis-classified.
pub fn find_reduce_dependencies<'a>(
    function: &'a Function,
    reduce: NodeID,
    fork: NodeID,
) -> impl IntoIterator<Item = NodeID> + 'a {
    let num_nodes = function.nodes.len();
    let mut seen: DenseNodeMap<bool> = vec![false; num_nodes];
    let mut reaches_fork: DenseNodeMap<bool> = vec![false; num_nodes];

    /// Post-order DFS: when it returns for a node seen for the first time,
    /// `reaches_fork[node]` says whether `node` is `target` or any of its
    /// uses reaches `target`.
    fn visit(
        function: &Function,
        node: NodeID,
        target: NodeID,
        reaches_fork: &mut DenseNodeMap<bool>,
        seen: &mut DenseNodeMap<bool>,
    ) {
        if seen[node.idx()] {
            return;
        }
        seen[node.idx()] = true;
        if node == target {
            reaches_fork[node.idx()] = true;
            return;
        }
        let node_uses = get_uses(&function.nodes[node.idx()]);
        let uses = node_uses.as_ref();
        for used in uses.iter() {
            visit(function, *used, target, reaches_fork, seen);
        }
        reaches_fork[node.idx()] = uses.iter().any(|used| reaches_fork[used.idx()]);
    }

    // HACK: pre-mark the reduce so that nodes on its cycle, which revisit it
    // mid-search, see it as dependent.
    reaches_fork[reduce.idx()] = true;
    visit(function, reduce, fork, &mut reaches_fork, &mut seen);

    (0..num_nodes)
        .filter(|&idx| reaches_fork[idx])
        .map(NodeID::new)
        .collect::<Vec<_>>()
}
/// Copies every node of `subgraph` inside an in-progress edit and rewires the
/// copies so that wherever a copy used a node of `subgraph`, it now uses that
/// node's copy instead. Uses of nodes outside `subgraph` keep pointing at the
/// originals, and the original nodes themselves are not modified.
///
/// Returns the edit together with a map from each original node to its copy,
/// or the edit as `Err` if any replacement is rejected.
pub fn copy_subgraph_in_edit<'a, 'b>(
    mut edit: FunctionEdit<'a, 'b>,
    subgraph: HashSet<NodeID>,
) -> Result<(FunctionEdit<'a, 'b>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'b>> {
    let mut map: HashMap<NodeID, NodeID> = HashMap::with_capacity(subgraph.len());
    // Copy nodes in subgraph.
    for old_id in subgraph.iter().copied() {
        let new_id = edit.copy_node(old_id);
        map.insert(old_id, new_id);
    }
    // The set of copies, built once so the per-user membership test below is
    // O(1) instead of a linear scan over `map.values()` for every user.
    let new_ids: HashSet<NodeID> = map.values().copied().collect();
    // Update edges between copies to point at the new nodes.
    for old_id in subgraph.iter() {
        edit = edit.replace_all_uses_where(*old_id, map[old_id], |user| new_ids.contains(user))?;
    }
    Ok((edit, map))
}
/// Copies every node of `subgraph` and rewires the copies to use each other,
/// so the copies form an independent duplicate of the subgraph. Uses of nodes
/// outside `subgraph` still point at the originals.
///
/// Returns:
/// - the set of newly created nodes,
/// - a map from each original node to its copy,
/// - every pair `(old node, outside user)` such that `outside user` uses
///   `old node` and is neither part of the original subgraph nor a copy.
pub fn copy_subgraph(
    editor: &mut FunctionEditor,
    subgraph: HashSet<NodeID>,
) -> (
    HashSet<NodeID>,
    HashMap<NodeID, NodeID>,
    Vec<(NodeID, NodeID)>,
) {
    let mut map: HashMap<NodeID, NodeID> = HashMap::new();
    let mut new_nodes: HashSet<NodeID> = HashSet::new();
    // Copy nodes.
    for old_id in subgraph.iter() {
        editor.edit(|mut edit| {
            let new_id = edit.copy_node(*old_id);
            map.insert(*old_id, new_id);
            new_nodes.insert(new_id);
            Ok(edit)
        });
    }
    // Replace all uses of old_id with its copy, but only where the user is a copy.
    for old_id in subgraph.iter() {
        editor.edit(|edit| {
            edit.replace_all_uses_where(*old_id, map[old_id], |node_id| new_nodes.contains(node_id))
        });
    }
    // Collect users of the *original* nodes that lie outside the subgraph.
    // Scanning the copies instead would always yield nothing: at this point a
    // copy is only ever used by other copies.
    let mut outside_users = Vec::new();
    for old_node in subgraph.iter() {
        for user in editor.get_users(*old_node) {
            if !subgraph.contains(&user) && !new_nodes.contains(&user) {
                outside_users.push((*old_node, user));
            }
        }
    }
    (new_nodes, map, outside_users)
}
/// Finds the data edges of `fork`'s fork-join that leave the region labelled
/// `data_label`. These are all `(src, dst)` pairs where:
/// - `src` is a node of the fork-join carrying the label,
/// - `dst` is a user of `src` without the label,
/// - neither endpoint is a control node, and
/// - `src` is not itself a user of `fork` (thread IDs are never bufferized).
///
/// `loop_tree` and `fork_join_map` are currently unused.
pub fn find_bufferize_edges(
    editor: &mut FunctionEditor,
    fork: NodeID,
    loop_tree: &LoopTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
    data_label: &LabelID,
) -> HashSet<(NodeID, NodeID)> {
    let mut crossing_edges = HashSet::new();
    for &src in nodes_in_fork_joins[&fork].iter() {
        // Only labelled nodes may start an edge, and fork users (tids) never do.
        if !editor.func().labels[src.idx()].contains(data_label)
            || editor.get_users(fork).contains(&src)
        {
            continue;
        }
        for dst in editor.get_users(src) {
            let dst_labelled = editor.func().labels[dst.idx()].contains(data_label);
            let involves_control = editor.node(dst).is_control() || editor.node(src).is_control();
            if dst_labelled || involves_control {
                continue;
            }
            crossing_edges.insert((src, dst));
        }
    }
    crossing_edges
}
/// Creates a fresh label and attaches it to every node of `fork`'s fork-join
/// except the join itself, the users of the join that have an entry in
/// `reduce_cycles`, and the nodes of those entries' cycles. Returns the new
/// label.
///
/// # Panics
/// Panics if the labelling edit is rejected.
pub fn ff_bufferize_create_not_reduce_cycle_label_helper(
    editor: &mut FunctionEditor,
    fork: NodeID,
    fork_join_map: &HashMap<NodeID, NodeID>,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) -> LabelID {
    let join = fork_join_map[&fork];

    // Nodes that must not receive the label.
    let mut excluded: HashSet<NodeID> = HashSet::new();
    excluded.insert(join);
    for user in editor.get_users(join) {
        if let Some(cycle) = reduce_cycles.get(&user) {
            excluded.insert(user);
            excluded.extend(cycle.iter().copied());
        }
    }
    let to_label: HashSet<NodeID> = nodes_in_fork_joins[&fork]
        .iter()
        .copied()
        .filter(|id| !excluded.contains(id))
        .collect();

    let mut label = LabelID::new(0);
    let success = editor.edit(|mut edit| {
        label = edit.fresh_label();
        for id in to_label {
            edit = edit.add_label(id, label)?;
        }
        Ok(edit)
    });
    assert!(success);
    label
}
/// Tries to bufferize-fission one fork-join. Fork headers from `loop_tree`
/// are visited in the reverse of `bottom_up_loops` order, and forks without
/// `fork_label` are skipped.
///
/// For each candidate, the data region is `data_label` if given. Otherwise a
/// fresh label is created for that fork-join via
/// `ff_bufferize_create_not_reduce_cycle_label_helper`. Edges leaving the
/// region are bufferized with `fork_bufferize_fission_helper`. Returns the
/// first `Some((old_fork, new_fork))` produced, or `None` if nothing was split.
pub fn ff_bufferize_any_fork<'a, 'b>(
    editor: &'b mut FunctionEditor<'a>,
    loop_tree: &'b LoopTree,
    fork_join_map: &'b HashMap<NodeID, NodeID>,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
    nodes_in_fork_joins: &'b HashMap<NodeID, HashSet<NodeID>>,
    typing: &'b Vec<TypeID>,
    fork_label: LabelID,
    data_label: Option<LabelID>,
) -> Option<(NodeID, NodeID)>
where
    'a: 'b,
{
    let bottom_up_forks: Vec<_> = loop_tree
        .bottom_up_loops()
        .into_iter()
        .filter(|(header, _)| editor.func().nodes[header.idx()].is_fork())
        .collect();

    for (header, control) in bottom_up_forks.into_iter().rev() {
        let fork_info = Loop {
            header,
            control: control.clone(),
        };
        let fork = fork_info.header;
        let join = fork_join_map[&fork];
        if !editor.func().labels[fork.idx()].contains(&fork_label) {
            continue;
        }

        let region_label = match data_label {
            Some(label) => label,
            None => ff_bufferize_create_not_reduce_cycle_label_helper(
                editor,
                fork,
                fork_join_map,
                reduce_cycles,
                nodes_in_fork_joins,
            ),
        };
        let edges = find_bufferize_edges(
            editor,
            fork,
            loop_tree,
            fork_join_map,
            nodes_in_fork_joins,
            &region_label,
        );
        if let Some(split) = fork_bufferize_fission_helper(
            editor,
            &fork_info,
            &edges,
            nodes_in_fork_joins,
            typing,
            fork,
            join,
        ) {
            return Some(split);
        }
    }
    None
}
/// Applies reduce fission to every fork in the function. Each fork-join whose
/// join is directly controlled by its fork is split into one fork-join per
/// reduce (see `fork_reduce_fission_helper`). The new fork-joins are chained
/// one after another, starting at the original fork's control predecessor.
///
/// # Panics
/// Panics (`todo!`) on a fork-join with control between its fork and join,
/// and if a fork is missing from `fork_join_map`.
pub fn fork_fission<'a>(
    editor: &'a mut FunctionEditor,
    _control_subgraph: &Subgraph,
    _types: &Vec<TypeID>,
    _loop_tree: &LoopTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> () {
    let forks: Vec<_> = editor
        .func()
        .nodes
        .iter()
        .enumerate()
        .filter_map(|(idx, node)| node.is_fork().then(|| NodeID::new(idx)))
        .collect();
    // This does the reduction fission:
    for fork in forks {
        // FIXME: If there is control in between fork and join, don't just give up.
        let join = fork_join_map[&fork];
        let join_pred = editor.func().nodes[join.idx()].try_join().unwrap();
        if join_pred != fork {
            todo!("Can't do fork fission on nodes with internal control")
            // Inner control LOOPs are hard
            // inner control in general *should* work right now without modifications.
        }
        // The first new fork-join must hang off this fork's own control
        // predecessor. It is read here, after any earlier fissions, so it
        // reflects their edits. (The start node, used previously, would detach
        // the fissioned forks whenever `fork` is not directly after start.)
        let (control_pred, _) = editor.func().nodes[fork.idx()].try_fork().unwrap();
        let reduce_partition = default_reduce_partition(editor, fork, join);
        fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, control_pred, fork);
    }
}
/** Split a fork-join into two fork-joins, placing select intermediate data into buffers.

The subgraph made of `l`'s nodes plus `data_node_in_fork_joins[&fork]` is
duplicated. The copy is sequenced after the original join, and uses from
outside the subgraph are redirected to the copy. For every `(src, dst)` in
`bufferized_edges`:
- the original fork-join writes `src` into an array indexed by the original
  fork's thread IDs (a new `ParallelReduce` on `join` whose init is a
  `NoResetConstant` zero constant);
- the copy of `dst` reads that array at the copy's thread IDs instead of
  using the copy of `src`.
`types` is indexed by node ID to get each buffered element type; one thread
ID and one array extent are created per factor of the fork.

Returns `Some((fork, new_fork))` on success, or `None` if `bufferized_edges`
is empty or the edit is rejected.
*/
pub fn fork_bufferize_fission_helper<'a, 'b>(
    editor: &'b mut FunctionEditor<'a>,
    l: &Loop,
    bufferized_edges: &HashSet<(NodeID, NodeID)>, // Describes what intermediate data should be bufferized.
    data_node_in_fork_joins: &'b HashMap<NodeID, HashSet<NodeID>>,
    types: &'b Vec<TypeID>,
    fork: NodeID,
    join: NodeID,
) -> Option<(NodeID, NodeID)>
where
    'a: 'b,
{
    if bufferized_edges.is_empty() {
        return None;
    }
    let all_loop_nodes = l.get_all_nodes();
    // FIXME: Cloning hell.
    let data_nodes = data_node_in_fork_joins[&fork].clone();
    let loop_nodes = editor
        .node_ids()
        .filter(|node_id| all_loop_nodes[node_id.idx()]);
    // Clone the subgraph that consists of this fork-join and all data and control nodes in it.
    let subgraph = HashSet::from_iter(data_nodes.into_iter().chain(loop_nodes));
    // Every (old node, user) edge that leaves the subgraph; these uses are
    // redirected to the copy inside the edit below.
    // NOTE(review): `subgraph.iter().contains` is a linear scan per user;
    // `subgraph.contains(&user)` would be O(1).
    let mut outside_users = Vec::new(); // old_node, outside_user
    for node in subgraph.iter() {
        for user in editor.get_users(*node) {
            if !subgraph.iter().contains(&user) {
                outside_users.push((*node, user));
            }
        }
    }
    let factors: Vec<_> = editor.func().nodes[fork.idx()]
        .try_fork()
        .unwrap()
        .1
        .iter()
        .cloned()
        .collect();
    // (dimension, factor) pairs of the fork.
    let thread_stuff_it = factors.into_iter().enumerate();
    // Control predecessor of the fork and control successor of the join.
    let fork_pred = editor
        .get_uses(fork)
        .filter(|a| editor.node(a).is_control())
        .next()
        .unwrap();
    let join_successor = editor
        .get_users(join)
        .filter(|a| editor.node(a).is_control())
        .next()
        .unwrap();
    let mut new_fork_id = NodeID::new(0);
    let edit_result = editor.edit(|edit| {
        let (mut edit, map) = copy_subgraph_in_edit(edit, subgraph)?;
        // Sequence the copy after the original: the copied fork now follows
        // the original join...
        edit = edit.replace_all_uses_where(fork_pred, join, |a| *a == map[&fork])?;
        // ...and the original join's control successor now follows the copied join.
        edit = edit.replace_all_uses_where(join, map[&join], |a| *a == join_successor)?;
        // Redirect every outside use of an old-subgraph node to its copy in the new subgraph.
        for (old_node, outside_user) in outside_users {
            edit = edit
                .replace_all_uses_where(old_node, map[&old_node], |node| *node == outside_user)?;
        }
        let new_fork = map[&fork];
        // FIXME: Do this as part of copy subgraph?
        // Add a thread ID per dimension to both the original and the copied
        // fork; they index the buffers below.
        let mut old_tids = Vec::new();
        let mut new_tids = Vec::new();
        for (dim, _) in thread_stuff_it.clone() {
            let old_id = edit.add_node(Node::ThreadID {
                control: fork,
                dimension: dim,
            });
            let new_id = edit.add_node(Node::ThreadID {
                control: new_fork,
                dimension: dim,
            });
            old_tids.push(old_id);
            new_tids.push(new_id);
        }
        // NOTE(review): each edge gets its own buffer, even when several edges
        // share the same `src`.
        for (src, dst) in bufferized_edges {
            let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor));
            let position_idx = Index::Position(old_tids.clone().into_boxed_slice());
            // `collect` is a placeholder (node 0) until the reduce below exists;
            // it is patched by the "Fix write node" replacement.
            let write = edit.add_node(Node::Write {
                collect: NodeID::new(0),
                data: *src,
                indices: vec![position_idx.clone()].into(),
            });
            // Buffer type: an array of `src`'s type with one extent per fork factor.
            let ele_type = types[src.idx()];
            let empty_buffer = edit.add_type(hercules_ir::Type::Array(
                ele_type,
                array_dims.collect::<Vec<_>>().into_boxed_slice(),
            ));
            let empty_buffer = edit.add_zero_constant(empty_buffer);
            let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer });
            edit = edit.add_schedule(empty_buffer, Schedule::NoResetConstant)?;
            // Carry the buffer through the original fork-join's iterations.
            let reduce = Node::Reduce {
                control: join,
                init: empty_buffer,
                reduct: write,
            };
            let reduce = edit.add_node(reduce);
            edit = edit.add_schedule(reduce, Schedule::ParallelReduce)?;
            // Fix write node
            edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?;
            // Create reads from buffer, indexed by the copied fork's thread IDs.
            let position_idx = Index::Position(new_tids.clone().into_boxed_slice());
            let read = edit.add_node(Node::Read {
                collect: reduce,
                indices: vec![position_idx].into(),
            });
            // In the new subgraph, make the copy of `dst` use the value read from
            // the buffer instead of the copy of `src`.
            edit = edit.replace_all_uses_where(map[src], read, |usee| *usee == map[dst])?;
        }
        new_fork_id = new_fork;
        Ok(edit)
    });
    if edit_result {
        Some((fork, new_fork_id))
    } else {
        None
    }
}
/** Split a fork-join into a separate, sequential fork-join for each reduction.

For every reduce in `reduce_partition`, the nodes it depends on (see
`find_reduce_dependencies`) plus the fork, join and reduce themselves are
copied. The copied fork is attached after the previous copy's join (the
first one after `original_control_pred`), and all uses of the original
reduce are redirected to its copy. Finally the original join and fork are
replaced by the last copy's join and fork and deleted.

NOTE: the `ForkID` values in `reduce_partition` are currently ignored; every
reduce gets its own fork-join.

Returns the last new `(fork, join)` pair, or the original `(fork, join)`
unchanged if `reduce_partition` is empty.
*/
pub fn fork_reduce_fission_helper<'a>(
    editor: &'a mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    reduce_partition: SparseNodeMap<ForkID>, // Describes how the reduces of the fork should be split,
    original_control_pred: NodeID,           // What the new fork connects to.
    fork: NodeID,
) -> (NodeID, NodeID) {
    let join = fork_join_map[&fork];
    // With no reduces there is nothing to split off. Without this guard the
    // final edit below would redirect the join and fork to the placeholder
    // `NodeID::new(0)` and delete them.
    if reduce_partition.is_empty() {
        return (fork, join);
    }
    let mut new_control_pred: NodeID = original_control_pred;
    // Important edges are: Reduces,
    // NOTE:
    // Say two reduce are in a fork, s.t reduce A depends on reduce B
    // If user wants A and B in separate forks:
    // - we can simply refuse
    // - or we can duplicate B
    let mut new_fork = NodeID::new(0);
    let mut new_join = NodeID::new(0);
    for (reduce, _) in reduce_partition {
        // Gets everything between fork & join that this reduce needs. (ALL CONTROL)
        let function = editor.func();
        let mut subgraph: HashSet<NodeID> = find_reduce_dependencies(function, reduce, fork)
            .into_iter()
            .collect();
        subgraph.insert(join);
        subgraph.insert(fork);
        subgraph.insert(reduce);
        let (_, mapping, _) = copy_subgraph(editor, subgraph);
        new_fork = mapping[&fork];
        new_join = mapping[&join];
        editor.edit(|mut edit| {
            // Attach new_fork after new_control_pred.
            let (old_control_pred, _) = edit.get_node(new_fork).try_fork().unwrap().clone();
            edit = edit.replace_all_uses_where(old_control_pred, new_control_pred, |usee| {
                *usee == new_fork
            })?;
            // Replace uses of reduce
            edit = edit.replace_all_uses(reduce, mapping[&reduce])?;
            Ok(edit)
        });
        // The next fork-join is chained after this one.
        new_control_pred = new_join;
    }
    editor.edit(|mut edit| {
        // Replace original join w/ new final join
        edit = edit.replace_all_uses_where(join, new_join, |_| true)?;
        // Delete original join (all reduce users have been moved)
        edit = edit.delete_node(join)?;
        // Replace all users of original fork, and then delete it, leftover users will be DCE'd.
        edit = edit.replace_all_uses(fork, new_fork)?;
        edit.delete_node(fork)
    });
    (new_fork, new_join)
}
/// Attempts a single fork coalesce. Every ordered pair `(inner, outer)` of
/// fork headers from `loops` is tried with `fork_coalesce_helper`, stopping at
/// the first pair that merges. Returns whether any pair was merged.
pub fn fork_coalesce(
    editor: &mut FunctionEditor,
    loops: &LoopTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> bool {
    let fork_headers: Vec<NodeID> = loops
        .bottom_up_loops()
        .into_iter()
        .map(|(header, _)| header)
        .filter(|header| editor.func().nodes[header.idx()].is_fork())
        .collect();
    // FIXME: Add a postorder traversal to optimize this.
    // FIXME: This could give us two forks that aren't actually ancestors / related, but then the helper will just return false early.
    // something like: `fork_joins.postorder_iter().windows(2)` is ideal here.
    fork_headers
        .iter()
        .cartesian_product(fork_headers.iter())
        .any(|(inner, outer)| fork_coalesce_helper(editor, *outer, *inner, fork_join_map).is_some())
}
/** Opposite of fork split, takes two fork-joins
with no control between them, and merges them into a single fork-join.
Returns None if the forks could not be merged and the NodeIDs of the
resulting fork and join if it succeeds in merging them.

Merging requires that:
- every reduce on `outer_join` has as its reduct a reduce on `inner_join`
  whose init is that outer reduce, and these pairs are one-to-one;
- every reduce on `inner_join` belongs to such a pair;
- `inner_fork` is directly controlled by `outer_fork`, and `outer_join`
  directly by `inner_join`.

The merged fork has the outer factors followed by the inner factors (inner
thread IDs are shifted up by the number of outer dimensions). It reuses
`inner_join`, and gets `ParallelFork` only if both forks had it.
*/
pub fn fork_coalesce_helper(
    editor: &mut FunctionEditor,
    outer_fork: NodeID,
    inner_fork: NodeID,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> Option<(NodeID, NodeID)> {
    // Check that all reduces in the outer fork are in *simple* cycles with a unique reduce of the inner fork.
    let outer_join = fork_join_map[&outer_fork];
    let inner_join = fork_join_map[&inner_fork];
    let mut pairs: BiMap<NodeID, NodeID> = BiMap::new(); // Outer <-> Inner
    // FIXME: Iterate all control uses of joins to really collect all reduces
    // (reduces can be attached to inner control)
    for outer_reduce in editor
        .get_users(outer_join)
        .filter(|node| editor.func().nodes[node.idx()].is_reduce())
    {
        // check that inner reduce is of the inner join
        let (_, _, outer_reduct) = editor.func().nodes[outer_reduce.idx()]
            .try_reduce()
            .unwrap();
        let inner_reduce = outer_reduct;
        let inner_reduce_node = &editor.func().nodes[outer_reduct.idx()];
        let Node::Reduce {
            control: inner_control,
            init: inner_init,
            reduct: _,
        } = inner_reduce_node
        else {
            return None;
        };
        // FIXME: check this condition better (i.e reduce might not be attached to join)
        if *inner_control != inner_join {
            return None;
        };
        if *inner_init != outer_reduce {
            return None;
        };
        if pairs.contains_left(&outer_reduce) || pairs.contains_right(&inner_reduce) {
            return None;
        } else {
            pairs.insert(outer_reduce, inner_reduce);
        }
    }
    // Every reduce on the inner join must be fused with an outer reduce. An
    // unpaired inner reduce restarts from its init each time the inner
    // fork-join runs; after coalescing it would instead carry its value across
    // what used to be separate outer iterations, changing its result.
    if editor
        .get_users(inner_join)
        .filter(|node| editor.func().nodes[node.idx()].is_reduce())
        .any(|inner_reduce| !pairs.contains_right(&inner_reduce))
    {
        return None;
    }
    // Check for control between join-join and fork-fork
    let (control, _) = editor.node(inner_fork).try_fork().unwrap();
    if control != outer_fork {
        return None;
    }
    let control = editor.node(outer_join).try_join().unwrap();
    if control != inner_join {
        return None;
    }
    // Checklist:
    // Increment inner TIDs
    // Add outer fork's dimension to front of inner fork.
    // Fuse reductions
    // - Initializer becomes outer initializer
    // Replace uses of outer fork w/ inner fork.
    // Replace uses of outer join w/ inner join.
    // Delete outer fork-join
    let inner_tids: Vec<NodeID> = editor
        .get_users(inner_fork)
        .filter(|node| editor.func().nodes[node.idx()].is_thread_id())
        .collect();
    let (outer_pred, outer_dims) = editor.func().nodes[outer_fork.idx()].try_fork().unwrap();
    let (_, inner_dims) = editor.func().nodes[inner_fork.idx()].try_fork().unwrap();
    let num_outer_dims = outer_dims.len();
    let mut new_factors = outer_dims.to_vec();
    // CHECKME / FIXME: Might need to be added the other way.
    new_factors.append(&mut inner_dims.to_vec());
    let mut new_fork = NodeID::new(0);
    let new_join = inner_join; // We'll reuse the inner join as the join of the new fork
    let success = editor.edit(|mut edit| {
        // Inner dimensions now come after the outer ones.
        for tid in inner_tids {
            let (fork, dim) = edit.get_node(tid).try_thread_id().unwrap();
            let new_tid = Node::ThreadID {
                control: fork,
                dimension: dim + num_outer_dims,
            };
            let new_tid = edit.add_node(new_tid);
            edit = edit.replace_all_uses(tid, new_tid)?;
            edit.sub_edit(tid, new_tid);
        }
        // Fuse Reductions
        for (outer_reduce, inner_reduce) in pairs {
            let (_, outer_init, _) = edit.get_node(outer_reduce).try_reduce().unwrap();
            let (_, inner_init, _) = edit.get_node(inner_reduce).try_reduce().unwrap();
            // Set inner init to outer init.
            edit =
                edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?;
            edit = edit.replace_all_uses(outer_reduce, inner_reduce)?;
            edit = edit.delete_node(outer_reduce)?;
        }
        let new_fork_node = Node::Fork {
            control: outer_pred,
            factors: new_factors.into(),
        };
        new_fork = edit.add_node(new_fork_node);
        if edit
            .get_schedule(outer_fork)
            .contains(&Schedule::ParallelFork)
            && edit
                .get_schedule(inner_fork)
                .contains(&Schedule::ParallelFork)
        {
            edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?;
        }
        edit = edit.replace_all_uses(inner_fork, new_fork)?;
        edit = edit.replace_all_uses(outer_fork, new_fork)?;
        edit = edit.replace_all_uses(outer_join, inner_join)?;
        edit = edit.delete_node(outer_join)?;
        edit = edit.delete_node(inner_fork)?;
        edit = edit.delete_node(outer_fork)?;
        Ok(edit)
    });
    if success {
        Some((new_fork, new_join))
    } else {
        None
    }
}
/// Splits the first multi-dimensional fork-join, in `fork_join_map` iteration
/// order, that `split_fork` manages to split. Returns the resulting forks and
/// joins, or `None` if none was split. One-dimensional fork-joins and failed
/// splits are skipped.
pub fn split_any_fork(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) -> Option<(Vec<NodeID>, Vec<NodeID>)> {
    fork_join_map.iter().find_map(|(&fork, &join)| {
        split_fork(editor, fork, join, reduce_cycles).filter(|(forks, _)| forks.len() > 1)
    })
}
/*
 * Split multi-dimensional fork-joins into separate one-dimensional fork-joins.
 * Useful for code generation. A single call to `split_fork` only splits
 * at most one fork-join, it must be called repeatedly to split all fork-joins.
 *
 * A fork with factors [f0, f1, ..., fn] becomes a nest of one-dimensional
 * forks with f0 outermost. Returns `Some((forks, joins))`, both ordered
 * outermost first (`joins[i]` closes `forks[i]`). A fork with fewer than two
 * factors is returned unchanged as `Some((vec![fork], vec![join]))`. Returns
 * `None` if the edit is rejected.
 *
 * Panics if a reduce on `join` has no entry in `reduce_cycles`.
 */
pub fn split_fork(
    editor: &mut FunctionEditor,
    fork: NodeID,
    join: NodeID,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) -> Option<(Vec<NodeID>, Vec<NodeID>)> {
    // A single multi-dimensional fork becomes multiple forks, a join becomes
    // multiple joins, a thread ID becomes a thread ID on the correct
    // fork, and a reduce becomes multiple reduces to shuffle the reduction
    // value through the fork-join nest.
    let nodes = &editor.func().nodes;
    let (fork_control, factors) = nodes[fork.idx()].try_fork().unwrap();
    if factors.len() < 2 {
        return Some((vec![fork], vec![join]));
    }
    let factors: Box<[DynamicConstantID]> = factors.into();
    let join_control = nodes[join.idx()].try_join().unwrap();
    let tids: Vec<_> = editor
        .get_users(fork)
        .filter(|id| nodes[id.idx()].is_thread_id())
        .collect();
    let reduces: Vec<_> = editor
        .get_users(join)
        .filter(|id| nodes[id.idx()].is_reduce())
        .collect();
    // (user, reduce) pairs where `user` uses `reduce` from inside the reduce's
    // cycle. Those uses will see the innermost new reduce; every other use
    // sees the outermost one.
    let data_in_reduce_cycle: HashSet<(NodeID, NodeID)> = reduces
        .iter()
        .map(|reduce| editor.get_users(*reduce).map(move |user| (user, *reduce)))
        .flatten()
        .filter(|(user, reduce)| reduce_cycles[&reduce].contains(&user))
        .collect();
    let mut new_forks = vec![];
    let mut new_joins = vec![];
    let success = editor.edit(|mut edit| {
        // Create the forks and a thread ID per fork.
        // The first fork created is the outermost (controlled by `fork_control`).
        let mut acc_fork = fork_control;
        let mut new_tids = vec![];
        for factor in factors {
            acc_fork = edit.add_node(Node::Fork {
                control: acc_fork,
                factors: Box::new([factor]),
            });
            new_forks.push(acc_fork);
            edit.sub_edit(fork, acc_fork);
            new_tids.push(edit.add_node(Node::ThreadID {
                control: acc_fork,
                dimension: 0,
            }));
        }
        // Create the joins, innermost first. If there is control inside the
        // fork-join, the innermost join keeps the original join's control.
        let mut acc_join = if join_control == fork {
            acc_fork
        } else {
            join_control
        };
        for _ in new_tids.iter() {
            acc_join = edit.add_node(Node::Join { control: acc_join });
            edit.sub_edit(join, acc_join);
            new_joins.push(acc_join);
        }
        // Create the reduces. Each original reduce becomes a chain with one
        // reduce per new join (`join_idx` 0 = innermost). The innermost reduce
        // takes the original reduct, and each outer reduce takes the next inner
        // reduce as its reduct. Each inner reduce is initialized by the next
        // outer one, and the outermost gets the original init. The forward
        // reference to the not-yet-created outer reduce relies on consecutive
        // node IDs, which the assert checks.
        let mut new_reduces = vec![];
        for reduce in reduces.iter() {
            let (_, init, reduct) = edit.get_node(*reduce).try_reduce().unwrap();
            let num_nodes = edit.num_node_ids();
            let mut inner_reduce = NodeID::new(0);
            let mut outer_reduce = NodeID::new(0);
            for (join_idx, join) in new_joins.iter().enumerate() {
                let init = if join_idx == new_joins.len() - 1 {
                    init
                } else {
                    NodeID::new(num_nodes + join_idx + 1)
                };
                let reduct = if join_idx == 0 {
                    reduct
                } else {
                    NodeID::new(num_nodes + join_idx - 1)
                };
                let new_reduce = edit.add_node(Node::Reduce {
                    control: *join,
                    init,
                    reduct,
                });
                assert_eq!(new_reduce, NodeID::new(num_nodes + join_idx));
                edit.sub_edit(*reduce, new_reduce);
                if join_idx == 0 {
                    inner_reduce = new_reduce;
                }
                if join_idx == new_joins.len() - 1 {
                    outer_reduce = new_reduce;
                }
            }
            new_reduces.push((inner_reduce, outer_reduce));
        }
        // Replace everything.
        // Users of the old fork now hang off the innermost new fork, and users
        // of the old join off the outermost new join.
        edit = edit.replace_all_uses(fork, acc_fork)?;
        edit = edit.replace_all_uses(join, acc_join)?;
        // A thread ID of dimension `dim` becomes the thread ID of the dim-th new fork.
        for tid in tids.iter() {
            let dim = edit.get_node(*tid).try_thread_id().unwrap().1;
            edit.sub_edit(*tid, new_tids[dim]);
            edit = edit.replace_all_uses(*tid, new_tids[dim])?;
        }
        for (reduce, (inner_reduce, outer_reduce)) in zip(reduces.iter(), new_reduces) {
            edit = edit.replace_all_uses_where(*reduce, inner_reduce, |id| {
                data_in_reduce_cycle.contains(&(*id, *reduce))
            })?;
            edit = edit.replace_all_uses_where(*reduce, outer_reduce, |id| {
                !data_in_reduce_cycle.contains(&(*id, *reduce))
            })?;
        }
        // Delete all the old stuff.
        edit = edit.delete_node(fork)?;
        edit = edit.delete_node(join)?;
        for tid in tids {
            edit = edit.delete_node(tid)?;
        }
        for reduce in reduces {
            edit = edit.delete_node(reduce)?;
        }
        Ok(edit)
    });
    if success {
        // Joins were created innermost first; report them outermost first to match `new_forks`.
        new_joins.reverse();
        Some((new_forks, new_joins))
    } else {
        None
    }
}
/// Chunks dimension `dim_idx` of every mutable fork in `fork_join_map` by the
/// constant `tile_size` (see `chunk_fork_unguarded`). `order == true` selects
/// `TileOrder::TileInner` and `false` selects `TileOrder::TileOuter`. No check
/// is made that `tile_size` divides the dimension.
pub fn chunk_all_forks_unguarded(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    dim_idx: usize,
    tile_size: usize,
    order: bool,
) -> () {
    // Intern the tile size as a dynamic constant once; all forks share it.
    let mut tile_dc = DynamicConstantID::new(0);
    editor.edit(|mut edit| {
        tile_dc = edit.add_dynamic_constant(DynamicConstant::Constant(tile_size));
        Ok(edit)
    });
    let order = if order {
        &TileOrder::TileInner
    } else {
        &TileOrder::TileOuter
    };
    for fork in fork_join_map.keys() {
        if !editor.is_mutable(*fork) {
            continue;
        }
        chunk_fork_unguarded(editor, *fork, dim_idx, tile_dc, order);
    }
}
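// Tiling orders for `chunk_fork_unguarded`, which splits one dimension of a
// single fork-join into two adjacent dimensions. With `TileInner`, the outer
// new dimension iterates original_dim / tile_size times and the inner one
// iterates tile_size times; `TileOuter` swaps the two extents.
// Assumes that tile size divides original dim evenly.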
/// Which of the two dimensions produced by `chunk_fork_unguarded` has extent
/// `tile_size`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TileOrder {
    /// The tile (extent `tile_size`) becomes the inner of the two new dimensions.
    TileInner,
    /// The tile (extent `tile_size`) becomes the outer of the two new dimensions.
    TileOuter,
}
/// Chunks dimension `dim_idx` of `fork` into two adjacent dimensions,
/// `dim_idx` (outer) and `dim_idx + 1` (inner). One has extent `tile_size` and
/// the other `original / tile_size`:
/// - `TileOrder::TileInner`: `[.., original / tile_size, tile_size, ..]`
/// - `TileOrder::TileOuter`: `[.., tile_size, original / tile_size, ..]`
///
/// Thread IDs of later dimensions shift up by one. Every use of the old thread
/// ID of `dim_idx` is replaced by `outer_tid * inner_extent + inner_tid`. No
/// check is made that `tile_size` divides the dimension ("unguarded"). Does
/// nothing if `fork` is not a fork node.
///
/// # Panics
/// Panics if `dim_idx` is out of range for the fork's factors.
pub fn chunk_fork_unguarded(
    editor: &mut FunctionEditor,
    fork: NodeID,
    dim_idx: usize,
    tile_size: DynamicConstantID,
    order: &TileOrder,
) -> () {
    // tid_dim_idx = tid_dim_idx * inner_extent + tid_(dim_idx + 1)
    let Node::Fork {
        control: old_control,
        factors: ref old_factors,
    } = *editor.node(fork)
    else {
        return;
    };
    assert!(dim_idx < old_factors.len());
    let mut new_factors: Vec<_> = old_factors.to_vec();
    let fork_users: Vec<_> = editor
        .get_users(fork)
        .map(|f| (f, editor.node(f).clone()))
        .collect();
    editor.edit(|mut edit| {
        // The two tiling orders differ only in which new dimension gets which
        // extent. The inner extent is also the multiplier for the outer tid.
        let quotient = DynamicConstant::div(new_factors[dim_idx], tile_size);
        let quotient = edit.add_dynamic_constant(quotient);
        let (outer_extent, inner_extent) = match order {
            TileOrder::TileInner => (quotient, tile_size),
            TileOrder::TileOuter => (tile_size, quotient),
        };
        new_factors[dim_idx] = outer_extent;
        new_factors.insert(dim_idx + 1, inner_extent);
        let new_fork = Node::Fork {
            control: old_control,
            factors: new_factors.into(),
        };
        let new_fork = edit.add_node(new_fork);
        edit = edit.replace_all_uses(fork, new_fork)?;
        edit.sub_edit(fork, new_fork);
        for (tid, node) in fork_users {
            let Node::ThreadID {
                control: _,
                dimension: tid_dim,
            } = node
            else {
                continue;
            };
            if tid_dim > dim_idx {
                // Later dimensions move one slot to the right.
                let new_tid = Node::ThreadID {
                    control: new_fork,
                    dimension: tid_dim + 1,
                };
                let new_tid = edit.add_node(new_tid);
                edit = edit.replace_all_uses(tid, new_tid)?;
                edit.sub_edit(tid, new_tid);
                edit = edit.delete_node(tid)?;
            } else if tid_dim == dim_idx {
                // The old tid (now controlled by `new_fork`, same dimension)
                // serves as the outer tid; rebuild the original index from it.
                let tile_tid = Node::ThreadID {
                    control: new_fork,
                    dimension: tid_dim + 1,
                };
                let tile_tid = edit.add_node(tile_tid);
                let inner_dc = edit.add_node(Node::DynamicConstant { id: inner_extent });
                let mul = edit.add_node(Node::Binary {
                    left: tid,
                    right: inner_dc,
                    op: BinaryOperator::Mul,
                });
                let add = edit.add_node(Node::Binary {
                    left: mul,
                    right: tile_tid,
                    op: BinaryOperator::Add,
                });
                edit.sub_edit(tid, add);
                edit.sub_edit(tid, tile_tid);
                edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul)?;
            }
        }
        edit = edit.delete_node(fork)?;
        Ok(edit)
    });
}
/// Collapses every fork in `fork_join_map` into a single dimension by
/// repeatedly merging its first two dimensions with `fork_dim_merge`. Forks
/// with zero or one dimension are left untouched.
///
/// # Panics
/// Panics if a key of `fork_join_map` is not a fork node.
pub fn merge_all_fork_dims(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
    for (fork, _) in fork_join_map {
        let Node::Fork {
            control: _,
            factors: dims,
        } = editor.node(fork)
        else {
            unreachable!();
        };
        // A fork with `n` dimensions needs `n - 1` merges. `saturating_sub`
        // avoids an underflow for a fork with no dimensions.
        let num_merges = dims.len().saturating_sub(1);
        let mut fork = *fork;
        for _ in 0..num_merges {
            // Always merge the current outermost two dimensions.
            fork = fork_dim_merge(editor, fork, 0, 1);
        }
    }
}
/// Merges dimensions `dim_idx1` and `dim_idx2` of `fork` into one dimension,
/// placed at the smaller index, whose extent is the product of the two. Let
/// `outer`/`inner` be the smaller/larger index and `t` the merged thread ID.
/// The old thread IDs are rebuilt as `tid_outer = t % dim(outer)` and
/// `tid_inner = t / dim(outer)`. Thread IDs of dimensions after `inner` shift
/// down by one.
///
/// Returns the new fork, or `fork` itself if it is not a fork node or the
/// edit is rejected.
///
/// # Panics
/// Panics if `dim_idx1 == dim_idx2` or either index is out of range.
pub fn fork_dim_merge(
    editor: &mut FunctionEditor,
    fork: NodeID,
    dim_idx1: usize,
    dim_idx2: usize,
) -> NodeID {
    assert_ne!(dim_idx1, dim_idx2);
    // Outer is smaller, and also closer to the left of the factors array.
    let (outer_idx, inner_idx) = if dim_idx2 < dim_idx1 {
        (dim_idx2, dim_idx1)
    } else {
        (dim_idx1, dim_idx2)
    };
    let Node::Fork {
        control: old_control,
        factors: ref old_factors,
    } = *editor.node(fork)
    else {
        return fork;
    };
    let mut new_factors: Vec<_> = old_factors.to_vec();
    let fork_users: Vec<_> = editor
        .get_users(fork)
        .map(|f| (f, editor.node(f).clone()))
        .collect();
    let outer_dc_id = new_factors[outer_idx];
    let mut new_fork = NodeID::new(0);
    let success = editor.edit(|mut edit| {
        new_factors[outer_idx] = edit.add_dynamic_constant(DynamicConstant::mul(
            new_factors[outer_idx],
            new_factors[inner_idx],
        ));
        new_factors.remove(inner_idx);
        new_fork = edit.add_node(Node::Fork {
            control: old_control,
            factors: new_factors.into(),
        });
        edit.sub_edit(fork, new_fork);
        edit = edit.replace_all_uses(fork, new_fork)?;
        edit = edit.delete_node(fork)?;
        for (tid, node) in fork_users {
            let Node::ThreadID {
                control: _,
                dimension: tid_dim,
            } = node
            else {
                continue;
            };
            if tid_dim > inner_idx {
                // Dimensions after the removed one move one slot to the left.
                let new_tid = Node::ThreadID {
                    control: new_fork,
                    dimension: tid_dim - 1,
                };
                let new_tid = edit.add_node(new_tid);
                edit = edit.replace_all_uses(tid, new_tid)?;
                edit.sub_edit(tid, new_tid);
            } else if tid_dim == outer_idx {
                let outer_tid = Node::ThreadID {
                    control: new_fork,
                    dimension: outer_idx,
                };
                let outer_tid = edit.add_node(outer_tid);
                let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id });
                // old outer tid = merged tid % dim(outer_idx)
                let rem = edit.add_node(Node::Binary {
                    left: outer_tid,
                    right: outer_dc,
                    op: BinaryOperator::Rem,
                });
                edit.sub_edit(tid, rem);
                edit.sub_edit(tid, outer_tid);
                edit = edit.replace_all_uses(tid, rem)?;
            } else if tid_dim == inner_idx {
                let outer_tid = Node::ThreadID {
                    control: new_fork,
                    dimension: outer_idx,
                };
                let outer_tid = edit.add_node(outer_tid);
                let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id });
                // old inner tid = merged tid / dim(outer_idx)
                let div = edit.add_node(Node::Binary {
                    left: outer_tid,
                    right: outer_dc,
                    op: BinaryOperator::Div,
                });
                edit.sub_edit(tid, div);
                edit.sub_edit(tid, outer_tid);
                edit = edit.replace_all_uses(tid, div)?;
            }
        }
        Ok(edit)
    });
    // `new_fork` is assigned before the fallible steps of the edit, so after a
    // rejected edit it would name a node that was never committed.
    if success {
        new_fork
    } else {
        fork
    }
}
/*
 * Run fork interchange on all fork-joins that are mutable in an editor.
 * The result of each `fork_interchange` call is ignored.
 */
pub fn fork_interchange_all_forks(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    first_dim: usize,
    second_dim: usize,
) {
    for (&fork, &join) in fork_join_map.iter() {
        if !editor.is_mutable(fork) {
            continue;
        }
        fork_interchange(editor, fork, join, first_dim, second_dim);
    }
}
/// Swaps dimensions `first_dim` and `second_dim` of `fork`. The thread IDs of
/// those two dimensions are rewritten accordingly, and a `ParallelFork`
/// schedule is carried over. This is only done when every reduce on `join` is
/// scheduled `ParallelReduce` or `MonoidReduce`.
///
/// Returns the new fork, or `None` if a reduce is not so scheduled, either
/// dimension is out of range, or the edit is rejected.
///
/// # Panics
/// Panics if `fork` is not a fork node.
pub fn fork_interchange(
    editor: &mut FunctionEditor,
    fork: NodeID,
    join: NodeID,
    first_dim: usize,
    second_dim: usize,
) -> Option<NodeID> {
    // Check that every reduce on the join is parallel or associative.
    let nodes = &editor.func().nodes;
    let schedules = &editor.func().schedules;
    if !editor
        .get_users(join)
        .filter(|id| nodes[id.idx()].is_reduce())
        .all(|id| {
            schedules[id.idx()].contains(&Schedule::ParallelReduce)
                || schedules[id.idx()].contains(&Schedule::MonoidReduce)
        })
    {
        // If not, we can't necessarily do interchange.
        return None;
    }
    let Node::Fork {
        control,
        ref factors,
    } = nodes[fork.idx()]
    else {
        panic!()
    };
    // Dimensions that don't exist can't be interchanged (and would make the
    // `swap` below panic, e.g. when the all-forks driver hits a smaller fork).
    if first_dim >= factors.len() || second_dim >= factors.len() {
        return None;
    }
    // Thread IDs of the two swapped dimensions, paired with their replacements.
    let fix_tids: Vec<(NodeID, Node)> = editor
        .get_users(fork)
        .filter_map(|id| {
            nodes[id.idx()].try_thread_id().and_then(|(_, dim)| {
                let swapped = if dim == first_dim {
                    second_dim
                } else if dim == second_dim {
                    first_dim
                } else {
                    return None;
                };
                Some((
                    id,
                    Node::ThreadID {
                        control: fork,
                        dimension: swapped,
                    },
                ))
            })
        })
        .collect();
    let mut factors = factors.clone();
    factors.swap(first_dim, second_dim);
    let new_fork = Node::Fork { control, factors };
    let mut new_fork_id = None;
    editor.edit(|mut edit| {
        for (old_id, new_tid) in fix_tids {
            let new_id = edit.add_node(new_tid);
            edit = edit.replace_all_uses(old_id, new_id)?;
            edit = edit.delete_node(old_id)?;
        }
        let new_fork = edit.add_node(new_fork);
        if edit.get_schedule(fork).contains(&Schedule::ParallelFork) {
            edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?;
        }
        edit = edit.replace_all_uses(fork, new_fork)?;
        edit = edit.delete_node(fork)?;
        new_fork_id = Some(new_fork);
        Ok(edit)
    });
    new_fork_id
}
/*
 * Run fork unrolling on all fork-joins that are mutable in an editor.
 * Stops after the first fork-join that `fork_unroll` reports as unrolled.
 */
pub fn fork_unroll_all_forks(
    editor: &mut FunctionEditor,
    fork_joins: &HashMap<NodeID, NodeID>,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) {
    for (&fork, &join) in fork_joins.iter() {
        if !editor.is_mutable(fork) {
            continue;
        }
        if fork_unroll(editor, fork, join, nodes_in_fork_joins) {
            // Presumably `nodes_in_fork_joins` is stale after a successful
            // unroll, hence only one per call. TODO confirm.
            break;
        }
    }
}
/*
 * Fully unrolls a one-dimensional fork-join whose single factor is a
 * compile-time constant `N`. The body (every node in `nodes_in_fork_joins[&fork]`
 * except the fork, the join, the thread ID and the reduces on the join) is
 * copied `N` times in sequence:
 * - in copy `i`, the thread ID is replaced by the constant `i` (as a u64);
 * - each reduce is threaded through the copies, starting from its initial value.
 * Afterwards, the users of the join are reconnected to the last copy's control,
 * the users of each reduce are reconnected to its final value, and every node
 * of the original fork-join is deleted.
 *
 * Returns `false` without editing when the fork-join is not eligible. That is
 * the case when there is more than one dimension, when the factor is not a
 * constant, or when the fork does not have exactly two users, one of them a
 * thread ID. Otherwise, returns the result of `editor.edit`.
 *
 * Panics if `fork` is not a fork node.
 */
pub fn fork_unroll(
    editor: &mut FunctionEditor,
    fork: NodeID,
    join: NodeID,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) -> bool {
    let nodes = &editor.func().nodes;
    let Node::Fork {
        control,
        ref factors,
    } = nodes[fork.idx()]
    else {
        panic!()
    };
    // We can only unroll fork-joins with a compile time known factor list. For
    // simplicity, just unroll fork-joins that have a single dimension and whose
    // fork has exactly two users (expected: its control successor and one
    // thread ID).
    if factors.len() != 1 || editor.get_users(fork).count() != 2 {
        return false;
    }
    let DynamicConstant::Constant(cons) = *editor.get_dynamic_constant(factors[0]) else {
        return false;
    };
    // Decline (rather than panic) if neither of the two users is a thread ID.
    let Some(tid) = editor
        .get_users(fork)
        .find(|id| nodes[id.idx()].is_thread_id())
    else {
        return false;
    };
    // Reduce -> initial value, for every reduce on the join.
    let inits: HashMap<NodeID, NodeID> = editor
        .get_users(join)
        .filter_map(|id| nodes[id.idx()].try_reduce().map(|(_, init, _)| (id, init)))
        .collect();
    editor.edit(|mut edit| {
        // Create a copy of the nodes in the fork join per unrolled iteration,
        // excluding the fork itself, the join itself, the thread IDs of the fork,
        // and the reduces on the join. Keep a running tally of the top control
        // node and the current reduction value.
        let mut top_control = control;
        let mut current_reduces = inits;
        for iter in 0..cons {
            let iter_cons = edit.add_constant(Constant::UnsignedInteger64(iter as u64));
            let iter_tid = edit.add_node(Node::Constant { id: iter_cons });
            // First, add a copy of each node in the fork join unmodified.
            // Record the mapping from old ID to new ID.
            let mut added_ids = HashSet::new();
            let mut old_to_new_ids = HashMap::new();
            let mut new_control = None;
            let mut new_reduces = HashMap::new();
            for node in nodes_in_fork_joins[&fork].iter() {
                if *node == fork {
                    // The copy hangs off the previous copy's last control node.
                    old_to_new_ids.insert(*node, top_control);
                } else if *node == join {
                    // Remember the join's control predecessor; its copy becomes
                    // the new top control after this iteration.
                    new_control = Some(edit.get_node(*node).try_join().unwrap());
                } else if *node == tid {
                    old_to_new_ids.insert(*node, iter_tid);
                } else if let Some(current) = current_reduces.get(node) {
                    // Inside this copy, a reduce reads its value from the
                    // previous iteration.
                    old_to_new_ids.insert(*node, *current);
                    new_reduces.insert(*node, edit.get_node(*node).try_reduce().unwrap().2);
                } else {
                    let new_node = edit.add_node(edit.get_node(*node).clone());
                    old_to_new_ids.insert(*node, new_node);
                    added_ids.insert(new_node);
                }
            }
            // Second, replace all the uses in the just added nodes.
            if let Some(new_control) = new_control {
                top_control = old_to_new_ids[&new_control];
            }
            for (reduce, reduct) in new_reduces {
                // A reduct that is not part of `nodes_in_fork_joins[&fork]` was
                // never copied, so it keeps its original ID (indexing here would
                // panic).
                let next = old_to_new_ids.get(&reduct).copied().unwrap_or(reduct);
                current_reduces.insert(reduce, next);
            }
            for (old, new) in old_to_new_ids {
                edit = edit.replace_all_uses_where(old, new, |id| added_ids.contains(id))?;
            }
        }
        // Hook up the control and reduce outputs to the rest of the function.
        edit = edit.replace_all_uses(join, top_control)?;
        for (reduce, reduct) in current_reduces {
            edit = edit.replace_all_uses(reduce, reduct)?;
        }
        // Delete the old fork-join.
        for node in nodes_in_fork_joins[&fork].iter() {
            edit = edit.delete_node(*node)?;
        }
        Ok(edit)
    })
}
/*
 * Looks for fork-joins that are adjacent, not inter-dependent, and iterate over
 * the same bounds, and fuses them so that their reductions share one
 * fork-join. Forks the editor does not allow to be mutated are skipped. At most
 * one fusion is performed per call: the loop stops at the first success,
 * presumably because `fork_join_map` / `nodes_in_fork_joins` are stale after an
 * edit (TODO confirm).
 */
pub fn fork_fusion_all_forks(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) {
    // `any` stops iterating as soon as one fusion succeeds.
    let _ = fork_join_map.iter().any(|(&fork, &join)| {
        editor.is_mutable(fork)
            && fork_fusion(editor, fork, join, fork_join_map, nodes_in_fork_joins)
    });
}
/*
 * Attempts to fuse the fork-join `top_fork`/`top_join` with the fork-join that
 * directly follows it in control flow (the first control user of `top_join`,
 * if that node is a fork in `fork_join_map`).
 *
 * Fusion requires two conditions:
 * - the two forks have identical factor lists;
 * - no value reached by walking users of a top reduce lies inside the bottom
 *   fork-join. The walk does not descend into the top fork-join itself, and it
 *   stops at phis and reduces outside of it.
 *
 * When fusing:
 * - the bottom thread IDs are re-pointed at the top fork;
 * - any control flow inside the bottom fork-join is spliced between the top
 *   body and the top join;
 * - the bottom fork and join are replaced by the top ones and then deleted.
 *
 * Returns `false` if no fusion was attempted, otherwise the result of
 * `editor.edit`.
 */
fn fork_fusion(
    editor: &mut FunctionEditor,
    top_fork: NodeID,
    top_join: NodeID,
    fork_join_map: &HashMap<NodeID, NodeID>,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) -> bool {
    let nodes = &editor.func().nodes;

    // The candidate is the first control user of the top join, and it must be
    // the fork of a known fork-join.
    let Some(bottom_fork) = editor
        .get_users(top_join)
        .find(|id| nodes[id.idx()].is_control())
    else {
        return false;
    };
    let Some(&bottom_join) = fork_join_map.get(&bottom_fork) else {
        return false;
    };

    let (_, top_factors) = nodes[top_fork.idx()].try_fork().unwrap();
    let (bottom_fork_pred, bottom_factors) = nodes[bottom_fork.idx()].try_fork().unwrap();
    assert_eq!(bottom_fork_pred, top_join);
    let top_join_pred = nodes[top_join.idx()].try_join().unwrap();
    let bottom_join_pred = nodes[bottom_join.idx()].try_join().unwrap();

    // Both fork-joins must iterate over exactly the same space.
    if top_factors != bottom_factors {
        return false;
    }

    let in_top = |id: &NodeID| nodes_in_fork_joins[&top_fork].contains(id);
    let in_bottom = |id: &NodeID| nodes_in_fork_joins[&bottom_fork].contains(id);

    // Reject the fusion if anything derived from a top reduce feeds into the
    // bottom fork-join.
    let top_reduces = editor
        .get_users(top_join)
        .filter(|id| nodes[id.idx()].is_reduce());
    for reduce in top_reduces {
        let mut seen = HashSet::from([reduce]);
        let mut stack = vec![reduce];
        while let Some(current) = stack.pop() {
            for user in editor.get_users(current) {
                if in_bottom(&user) {
                    return false;
                }
                let is_phi_or_reduce =
                    nodes[user.idx()].is_phi() || nodes[user.idx()].is_reduce();
                if is_phi_or_reduce && !in_top(&user) {
                    // Phis and reduces outside the top fork-join end the search.
                    continue;
                }
                if seen.contains(&user) || in_top(&user) {
                    continue;
                }
                seen.insert(user);
                stack.push(user);
            }
        }
    }

    // Perform the fusion.
    let bottom_tids: Vec<NodeID> = editor
        .get_users(bottom_fork)
        .filter(|id| nodes[id.idx()].is_thread_id())
        .collect();
    editor.edit(|mut edit| {
        edit = edit.replace_all_uses_where(bottom_fork, top_fork, |id| bottom_tids.contains(id))?;
        if bottom_join_pred != bottom_fork {
            // The bottom body has its own control flow: hang it off the end of
            // the top body, and make the top join wait for the bottom body.
            edit = edit.replace_all_uses_where(bottom_fork, top_join_pred, |id| in_bottom(id))?;
            edit =
                edit.replace_all_uses_where(top_join_pred, bottom_join_pred, |id| *id == top_join)?;
        }
        // Whatever still refers to the bottom fork/join now refers to the top ones.
        edit = edit.replace_all_uses(bottom_fork, top_fork)?;
        edit = edit.replace_all_uses(bottom_join, top_join)?;
        edit = edit.delete_node(bottom_fork)?;
        edit = edit.delete_node(bottom_join)?;
        Ok(edit)
    })
}
/*
 * Looks for monoid reductions whose initial input is not the identity element
 * of their operator, and rewrites them so that:
 * - the reduce starts from the identity element;
 * - the original initial value is combined into the reduction's result
 *   afterwards, by a new node of the same operator.
 * This aids in parallelizing outer loops.
 *
 * Handled operators and identities:
 * - Add / Or: zero;
 * - Mul / And: one;
 * - Max: the smallest value of the type;
 * - Min: the largest value of the type.
 *
 * Only reduces carrying the `MonoidReduce` schedule are considered, since that
 * schedule indicates a particular structure which is annoying to check for
 * again.
 */
pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
    // Which identity constant a reduce should be re-initialized with.
    enum Identity {
        Zero,
        One,
        Smallest,
        Largest,
    }

    for id in editor.node_ids() {
        if !editor.func().schedules[id.idx()].contains(&Schedule::MonoidReduce) {
            continue;
        }
        let nodes = &editor.func().nodes;
        let Some((_, init, reduct)) = nodes[id.idx()].try_reduce() else {
            continue;
        };
        // Users of the reduce other than its own reduct, i.e. consumers of the
        // reduction's result.
        let out_uses: Vec<_> = editor.get_users(id).filter(|u| *u != reduct).collect();

        // Choose the identity for the reduction's operator, together with the
        // node that later combines the original init with the reduction result.
        // Reduces whose init is already the identity are left untouched.
        let (identity, final_node) = match nodes[reduct.idx()] {
            Node::Binary { op, .. }
                if (op == BinaryOperator::Add || op == BinaryOperator::Or)
                    && !is_zero(editor, init) =>
            {
                (
                    Identity::Zero,
                    Node::Binary {
                        op,
                        left: init,
                        right: id,
                    },
                )
            }
            Node::Binary { op, .. }
                if (op == BinaryOperator::Mul || op == BinaryOperator::And)
                    && !is_one(editor, init) =>
            {
                (
                    Identity::One,
                    Node::Binary {
                        op,
                        left: init,
                        right: id,
                    },
                )
            }
            Node::IntrinsicCall {
                intrinsic: Intrinsic::Max,
                ..
            } if !is_smallest(editor, init) => (
                Identity::Smallest,
                Node::IntrinsicCall {
                    intrinsic: Intrinsic::Max,
                    args: Box::new([init, id]),
                },
            ),
            Node::IntrinsicCall {
                intrinsic: Intrinsic::Min,
                ..
            } if !is_largest(editor, init) => (
                Identity::Largest,
                Node::IntrinsicCall {
                    intrinsic: Intrinsic::Min,
                    args: Box::new([init, id]),
                },
            ),
            _ => continue,
        };

        editor.edit(|mut edit| {
            let ty = typing[init.idx()];
            let identity = match identity {
                Identity::Zero => edit.add_zero_constant(ty),
                Identity::One => edit.add_one_constant(ty),
                Identity::Smallest => edit.add_smallest_constant(ty),
                Identity::Largest => edit.add_largest_constant(ty),
            };
            let identity = edit.add_node(Node::Constant { id: identity });
            edit.sub_edit(id, identity);
            // Re-initialize only this reduce with the identity element.
            edit = edit.replace_all_uses_where(init, identity, |u| *u == id)?;
            // Combine the original init into the result, and redirect every
            // consumer of the result (but not the reduct, nor the new node
            // itself) to the combined value.
            let final_op = edit.add_node(final_node);
            for u in out_uses {
                edit.sub_edit(u, final_op);
            }
            edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
        });
    }
}
/*
 * Applies `extend_fork` to every fork-join in `fork_join_map` whose fork the
 * editor allows to be mutated. `extend_fork` rounds the fork's dimensions up to
 * a multiple of `multiple` and gates the execution of the body. Unlike
 * unrolling and fusion, every eligible fork-join is processed, not just the
 * first one.
 */
pub fn extend_all_forks(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    multiple: usize,
) {
    for (&fork, &join) in fork_join_map.iter() {
        if !editor.is_mutable(fork) {
            continue;
        }
        extend_fork(editor, fork, join, multiple);
    }
}
/*
 * Rounds every dimension of the fork-join `fork`/`join` up to the next multiple
 * of `multiple` and gates the body on every thread ID being below its original
 * bound. On the added (out-of-bounds) iterations each reduce passes its current
 * value through unchanged.
 *
 * Does nothing when `multiple` is zero (the round-up would divide by zero) or
 * when the fork has no dimensions (there would be no bounds check to build).
 *
 * Panics if `fork` is not a fork or `join` is not a join.
 */
fn extend_fork(editor: &mut FunctionEditor, fork: NodeID, join: NodeID, multiple: usize) {
    if multiple == 0 {
        return;
    }
    let nodes = &editor.func().nodes;
    let (fork_pred, factors) = nodes[fork.idx()].try_fork().unwrap();
    if factors.is_empty() {
        return;
    }
    let factors = factors.to_vec();
    let fork_succ = editor
        .get_users(fork)
        .find(|id| nodes[id.idx()].is_control())
        .unwrap();
    let join_pred = nodes[join.idx()].try_join().unwrap();
    // Whether any control flow sits between the fork and the join. If not,
    // `fork_succ` is the join itself.
    let ctrl_between = fork != join_pred;
    let reduces: Vec<_> = editor
        .get_users(join)
        .filter_map(|id| nodes[id.idx()].try_reduce().map(|x| (id, x)))
        .collect();
    editor.edit(|mut edit| {
        // We can round up a dynamic constant A to a multiple of another dynamic
        // constant B via the following math:
        // ((A + B - 1) / B) * B
        // B and 1 do not depend on the factor, so build them once.
        let b = edit.add_dynamic_constant(DynamicConstant::Constant(multiple));
        let one = edit.add_dynamic_constant(DynamicConstant::Constant(1));
        let new_factors: Vec<_> = factors
            .iter()
            .map(|factor| {
                let apb = edit.add_dynamic_constant(DynamicConstant::add(*factor, b));
                let apbmo = edit.add_dynamic_constant(DynamicConstant::sub(apb, one));
                let apbmodb = edit.add_dynamic_constant(DynamicConstant::div(apbmo, b));
                edit.add_dynamic_constant(DynamicConstant::mul(apbmodb, b))
            })
            .collect();
        // Create the new control structure. Everything except the body's first
        // control node (e.g. thread IDs) moves to the new fork.
        let new_fork = edit.add_node(Node::Fork {
            control: fork_pred,
            factors: new_factors.into_boxed_slice(),
        });
        edit = edit.replace_all_uses_where(fork, new_fork, |id| *id != fork_succ)?;
        edit.sub_edit(fork, new_fork);
        // One `tid < old_bound` check per dimension, combined with And.
        let conds: Vec<_> = factors
            .iter()
            .enumerate()
            .map(|(idx, old_factor)| {
                let tid = edit.add_node(Node::ThreadID {
                    control: new_fork,
                    dimension: idx,
                });
                let old_bound = edit.add_node(Node::DynamicConstant { id: *old_factor });
                edit.add_node(Node::Binary {
                    op: BinaryOperator::LT,
                    left: tid,
                    right: old_bound,
                })
            })
            .collect();
        // Non-empty because `factors` is non-empty (checked above).
        let cond = conds
            .into_iter()
            .reduce(|left, right| {
                edit.add_node(Node::Binary {
                    op: BinaryOperator::And,
                    left,
                    right,
                })
            })
            .unwrap();
        let branch = edit.add_node(Node::If {
            control: new_fork,
            cond,
        });
        let false_proj = edit.add_node(Node::ControlProjection {
            control: branch,
            selection: 0,
        });
        let true_proj = edit.add_node(Node::ControlProjection {
            control: branch,
            selection: 1,
        });
        // The original body only runs on the in-bounds (true) path.
        if ctrl_between {
            edit = edit.replace_all_uses_where(fork, true_proj, |id| *id == fork_succ)?;
        }
        let bottom_region = edit.add_node(Node::Region {
            preds: Box::new([false_proj, if ctrl_between { join_pred } else { true_proj }]),
        });
        let new_join = edit.add_node(Node::Join {
            control: bottom_region,
        });
        edit = edit.replace_all_uses(join, new_join)?;
        edit.sub_edit(join, new_join);
        edit = edit.delete_node(fork)?;
        edit = edit.delete_node(join)?;
        // Update the reduces to use phis on the region node to gate their
        // execution: the false path keeps the reduce's current value, and the
        // true path takes the body's reduct.
        for (reduce, (_, init, reduct)) in reduces {
            let phi = edit.add_node(Node::Phi {
                control: bottom_region,
                data: Box::new([reduce, reduct]),
            });
            let new_reduce = edit.add_node(Node::Reduce {
                control: new_join,
                init,
                reduct: phi,
            });
            edit = edit.replace_all_uses(reduce, new_reduce)?;
            edit.sub_edit(reduce, new_reduce);
            edit = edit.delete_node(reduce)?;
        }
        Ok(edit)
    });
}