use std::collections::{HashMap, HashSet};
use bimap::BiMap;
use itertools::Itertools;
use hercules_ir::*;
use crate::*;
type ForkID = usize;
/** Places each reduce node into its own fork */
pub fn default_reduce_partition(
editor: &FunctionEditor,
_fork: NodeID,
join: NodeID,
) -> SparseNodeMap<ForkID> {
let mut map = SparseNodeMap::new();
editor
.get_users(join)
.filter(|id| editor.func().nodes[id.idx()].is_reduce())
.enumerate()
.for_each(|(fork, reduce)| {
map.insert(reduce, fork);
});
map
}
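// For example (a sketch): if `join` has reduce users r0 and r1, the returned
// map is { r0 -> 0, r1 -> 1 }, so fission later materializes one fork-join
// per reduce.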
// TODO: Refine these conditions.
/** Returns the IDs of all nodes that `reduce` transitively uses, stopping at
(and including) `fork`. The reduce itself is always included. */
pub fn find_reduce_dependencies<'a>(
function: &'a Function,
reduce: NodeID,
fork: NodeID,
) -> impl IntoIterator<Item = NodeID> + 'a {
let len = function.nodes.len();
let mut visited: DenseNodeMap<bool> = vec![false; len];
let mut dependent: DenseNodeMap<bool> = vec![false; len];
// Does `fork` need to be a parameter here? It never changes; if this were a
// closure, it could simply capture it.
fn recurse(
function: &Function,
node: NodeID,
fork: NodeID,
dependent_map: &mut DenseNodeMap<bool>,
visited: &mut DenseNodeMap<bool>,
) {
// Results are returned through `dependent_map`.
if visited[node.idx()] {
return;
}
visited[node.idx()] = true;
if node == fork {
dependent_map[node.idx()] = true;
return;
}
let binding = get_uses(&function.nodes[node.idx()]);
let uses = binding.as_ref();
for used in uses {
recurse(function, *used, fork, dependent_map, visited);
}
dependent_map[node.idx()] = uses.iter().any(|id| dependent_map[id.idx()]);
}
// NOTE: HACKY. The condition we want is "all nodes on any path from the fork
// to the reduce (in the forward graph), or from the reduce to the fork (in the
// reverse graph)". Cycles break this, but we assume for now that the only
// cycles are ones that involve the reduce node. (Control may break this
// assumption: a loop inside the fork is a cycle that doesn't involve the
// reduce.) The current workaround is to mark the reduce as dependent before
// traversing the graph.
dependent[reduce.idx()] = true;
recurse(function, reduce, fork, &mut dependent, &mut visited);
// Return node IDs that are dependent
let ret_val: Vec<_> = dependent
.iter()
.enumerate()
.filter_map(|(idx, dependent)| {
if *dependent {
Some(NodeID::new(idx))
} else {
None
}
})
.collect();
ret_val
}
pub fn copy_subgraph(
editor: &mut FunctionEditor,
subgraph: HashSet<NodeID>,
) -> (
HashSet<NodeID>,
HashMap<NodeID, NodeID>,
Vec<(NodeID, NodeID)>,
) // Returns all new nodes, a map from old nodes to new nodes, and a vec of
  // pairs (node, outside user) such that the node is used by the outside user,
  // where "outside" means not part of the copied subgraph.
{
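// Example (a sketch): copying the subgraph {a, b}, where b uses a, yields new
// nodes {a', b'} with b' rewired to use a', and the map {a -> a', b -> b'}.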
let mut map: HashMap<NodeID, NodeID> = HashMap::new();
let mut new_nodes: HashSet<NodeID> = HashSet::new();
// Copy nodes
for old_id in subgraph.iter() {
editor.edit(|mut edit| {
let new_id = edit.copy_node(*old_id);
map.insert(*old_id, new_id);
new_nodes.insert(new_id);
Ok(edit)
});
}
// Update edges to new nodes
for old_id in subgraph.iter() {
// Replace all uses of old_id w/ new_id, where the user is in new_nodes.
editor.edit(|edit| {
edit.replace_all_uses_where(*old_id, map[old_id], |node_id| new_nodes.contains(node_id))
});
}
// Get all users that aren't in new_nodes.
let mut outside_users = Vec::new();
for node in new_nodes.iter() {
for user in editor.get_users(*node) {
if !new_nodes.contains(&user) {
outside_users.push((*node, user));
}
}
}
(new_nodes, map, outside_users)
}
pub fn fork_fission<'a>(
editor: &'a mut FunctionEditor,
_control_subgraph: &Subgraph,
_types: &Vec<TypeID>,
_loop_tree: &LoopTree,
fork_join_map: &HashMap<NodeID, NodeID>,
) {
let forks: Vec<_> = editor
.func()
.nodes
.iter()
.enumerate()
.filter_map(|(idx, node)| {
if node.is_fork() {
Some(NodeID::new(idx))
} else {
None
}
})
.collect();
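// NodeID 0 is assumed to be the function's start node; each chain of new
// fork-joins is attached after it.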
let control_pred = NodeID::new(0);
// This does the reduction fission:
for fork in forks.clone() {
// FIXME: If there is control in between fork and join, don't just give up.
let join = fork_join_map[&fork];
let join_pred = editor.func().nodes[join.idx()].try_join().unwrap();
if join_pred != fork {
todo!("Can't do fork fission on fork-joins with internal control")
// Inner control *loops* are hard; inner control in general *should*
// work right now without modifications.
}
let reduce_partition = default_reduce_partition(editor, fork, join);
fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, control_pred, fork);
}
}
/** Split a 1D fork into two forks, placing select intermediate data into buffers. */
pub fn fork_bufferize_fission_helper<'a>(
editor: &'a mut FunctionEditor,
fork_join_map: &HashMap<NodeID, NodeID>,
bufferized_edges: HashSet<(NodeID, NodeID)>, // Describes what intermediate data should be bufferized.
_original_control_pred: NodeID, // What the new fork connects to.
types: &Vec<TypeID>,
fork: NodeID,
) -> (NodeID, NodeID) {
// Returns the two forks that it generates.
// TODO: Check that each bufferized edge's src doesn't depend on anything that
// comes after the fork.
// Copy fork + control intermediates + join to a new fork + join. How does
// control get partitioned? (It depends on how it affects the data nodes on
// each side of the bufferized edges; control may end up in each loop. Fix me
// later.) Place the new fork + join after the join of the first.
// Only handle fork-joins with no inner control for now.
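// Sketch of the result (assuming a 1-D fork and one bufferized edge src -> dst):
//   fork -> new_join -> new_fork -> join
// src's value is written into a fresh array buffer reduced at new_join, and
// dst instead reads that buffer inside the second (new_fork) loop.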
// Create fork + join + Thread control
let join = fork_join_map[&fork];
let mut new_fork_id = NodeID::new(0);
let mut new_join_id = NodeID::new(0);
editor.edit(|mut edit| {
new_join_id = edit.add_node(Node::Join { control: fork });
let factors = edit.get_node(fork).try_fork().unwrap().1;
new_fork_id = edit.add_node(Node::Fork {
control: new_join_id,
factors: factors.into(),
});
edit.replace_all_uses_where(fork, new_fork_id, |usee| *usee == join)
});
for (src, dst) in bufferized_edges {
// FIXME: Disgusting cloning, allocating, and iterator juggling.
let factors: Vec<_> = editor.func().nodes[fork.idx()]
.try_fork()
.unwrap()
.1
.iter()
.cloned()
.collect();
editor.edit(|mut edit| {
// Create write to buffer
let thread_stuff_it = factors.into_iter().enumerate();
// FIXME: Try to use unzip here? Unclear why it wasn't working.
let tids = thread_stuff_it.clone().map(|(dim, _)| {
edit.add_node(Node::ThreadID {
control: fork,
dimension: dim,
})
});
let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor));
// Assume a 1-D fork only for now.
let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice());
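// The write's `collect` input is a placeholder (NodeID 0) until the
// reduce exists; it is patched below to close the write -> reduce cycle.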
let write = edit.add_node(Node::Write {
collect: NodeID::new(0),
data: src,
indices: vec![position_idx].into(),
});
let ele_type = types[src.idx()];
let empty_buffer = edit.add_type(hercules_ir::Type::Array(
ele_type,
array_dims.collect::<Vec<_>>().into_boxed_slice(),
));
let empty_buffer = edit.add_zero_constant(empty_buffer);
let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer });
let reduce = Node::Reduce {
control: new_join_id,
init: empty_buffer,
reduct: write,
};
let reduce = edit.add_node(reduce);
// Fix write node
edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?;
// Create read from buffer
let tids = thread_stuff_it.clone().map(|(dim, _)| {
edit.add_node(Node::ThreadID {
control: new_fork_id,
dimension: dim,
})
});
let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice());
let read = edit.add_node(Node::Read {
collect: reduce,
indices: vec![position_idx].into(),
});
edit = edit.replace_all_uses_where(src, read, |usee| *usee == dst)?;
Ok(edit)
});
}
(fork, new_fork_id)
}
/** Split a 1D fork into a separate fork for each reduction. */
pub fn fork_reduce_fission_helper<'a>(
editor: &'a mut FunctionEditor,
fork_join_map: &HashMap<NodeID, NodeID>,
reduce_partition: SparseNodeMap<ForkID>, // Describes how the reduces of the fork should be split,
original_control_pred: NodeID, // What the new fork connects to.
fork: NodeID,
) -> (NodeID, NodeID) {
let join = fork_join_map[&fork];
let mut new_control_pred: NodeID = original_control_pred;
// Important edges are the reduces.
// NOTE: Say two reduces are in a fork, such that reduce A depends on reduce B.
// If the user wants A and B in separate forks, we can either:
// - simply refuse, or
// - duplicate B.
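// Sketch (assuming reduces r0 and r1 mapped to separate forks): the original
//   pred -> fork -> join   (with reduces r0, r1)
// becomes a chain
//   pred -> fork0 -> join0 -> fork1 -> join1
// where each cloned fork-join carries one reduce and hangs off the previous join.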
let mut new_fork = NodeID::new(0);
let mut new_join = NodeID::new(0);
// Gets everything between fork & join that this reduce needs. (ALL CONTROL)
for reduce in reduce_partition {
let reduce = reduce.0;
let function = editor.func();
let subgraph = find_reduce_dependencies(function, reduce, fork);
let mut subgraph: HashSet<NodeID> = subgraph.into_iter().collect();
subgraph.insert(join);
subgraph.insert(fork);
subgraph.insert(reduce);
let (_, mapping, _) = copy_subgraph(editor, subgraph);
new_fork = mapping[&fork];
new_join = mapping[&join];
editor.edit(|mut edit| {
// Attach new_fork after new_control_pred.
let (old_control_pred, _) = edit.get_node(new_fork).try_fork().unwrap().clone();
edit = edit.replace_all_uses_where(old_control_pred, new_control_pred, |usee| {
*usee == new_fork
})?;
// Replace uses of reduce
edit = edit.replace_all_uses(reduce, mapping[&reduce])?;
Ok(edit)
});
new_control_pred = new_join;
}
editor.edit(|mut edit| {
// Replace original join w/ new final join
edit = edit.replace_all_uses_where(join, new_join, |_| true)?;
// Delete original join (all reduce users have been moved)
edit = edit.delete_node(join)?;
// Replace all users of original fork, and then delete it, leftover users will be DCE'd.
edit = edit.replace_all_uses(fork, new_fork)?;
edit.delete_node(fork)
});
(new_fork, new_join)
}
pub fn fork_coalesce(
editor: &mut FunctionEditor,
loops: &LoopTree,
fork_join_map: &HashMap<NodeID, NodeID>,
) -> bool {
let fork_joins = loops.bottom_up_loops().into_iter().filter_map(|(k, _)| {
if editor.func().nodes[k.idx()].is_fork() {
Some(k)
} else {
None
}
});
let fork_joins: Vec<_> = fork_joins.collect();
// FIXME: Add a postorder traversal to optimize this.
// FIXME: This could give us two forks that aren't actually ancestors / related, but then the helper will just return false early.
// something like: `fork_joins.postorder_iter().windows(2)` is ideal here.
for (inner, outer) in fork_joins.iter().cartesian_product(fork_joins.iter()) {
if fork_coalesce_helper(editor, *outer, *inner, fork_join_map) {
return true;
}
}
false
}
/** Opposite of fork split: takes two fork-joins with no control between them
and merges them into a single fork-join. */
pub fn fork_coalesce_helper(
editor: &mut FunctionEditor,
outer_fork: NodeID,
inner_fork: NodeID,
fork_join_map: &HashMap<NodeID, NodeID>,
) -> bool {
// Check that all reduces in the outer fork are in *simple* cycles with a unique reduce of the inner fork.
let outer_join = fork_join_map[&outer_fork];
let inner_join = fork_join_map[&inner_fork];
let mut pairs: BiMap<NodeID, NodeID> = BiMap::new(); // Outer <-> Inner
// FIXME: Iterate all control uses of joins to really collect all reduces
// (reduces can be attached to inner control)
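// A pair (outer_reduce, inner_reduce) is valid when the outer's reduct is the
// inner reduce and the inner's init is the outer reduce, i.e. the two reduces
// form a simple two-node cycle.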
for outer_reduce in editor
.get_users(outer_join)
.filter(|node| editor.func().nodes[node.idx()].is_reduce())
{
// Check that the inner reduce belongs to the inner join.
let (_, _, outer_reduct) = editor.func().nodes[outer_reduce.idx()]
.try_reduce()
.unwrap();
let inner_reduce = outer_reduct;
let inner_reduce_node = &editor.func().nodes[outer_reduct.idx()];
let Node::Reduce {
control: inner_control,
init: inner_init,
reduct: _,
} = inner_reduce_node
else {
return false;
};
// FIXME: Check this condition better (i.e. the reduce might not be
// attached to the join).
if *inner_control != inner_join {
return false;
};
if *inner_init != outer_reduce {
return false;
};
if pairs.contains_left(&outer_reduce) || pairs.contains_right(&inner_reduce) {
return false;
} else {
pairs.insert(outer_reduce, inner_reduce);
}
}
// Check for control between join-join and fork-fork
let Some(user) = editor
.get_users(outer_fork)
.filter(|node| editor.func().nodes[node.idx()].is_control())
.next()
else {
return false;
};
if user != inner_fork {
return false;
}
let Some(user) = editor
.get_users(inner_join)
.filter(|node| editor.func().nodes[node.idx()].is_control())
.next()
else {
return false;
};
if user != outer_join {
return false;
}
// Checklist:
// Increment inner TIDs
// Add outer fork's dimension to front of inner fork.
// Fuse reductions
// - Initializer becomes outer initializer
// Replace uses of outer fork w/ inner fork.
// Replace uses of outer join w/ inner join.
// Delete outer fork-join
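// Sketch (assuming 1-D outer and inner forks with factors M and N): the nest
//   outer_fork(M) -> inner_fork(N) -> inner_join -> outer_join
// becomes a single fork with factors [M, N]; inner ThreadIDs have their
// dimension index shifted up by the number of outer dimensions.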
let inner_tids: Vec<NodeID> = editor
.get_users(inner_fork)
.filter(|node| editor.func().nodes[node.idx()].is_thread_id())
.collect();
let (outer_pred, outer_dims) = editor.func().nodes[outer_fork.idx()].try_fork().unwrap();
let (_, inner_dims) = editor.func().nodes[inner_fork.idx()].try_fork().unwrap();
let num_outer_dims = outer_dims.len();
let mut new_factors = outer_dims.to_vec();
// CHECKME / FIXME: Might need to be added the other way.
new_factors.append(&mut inner_dims.to_vec());
for tid in inner_tids {
let (fork, dim) = editor.func().nodes[tid.idx()].try_thread_id().unwrap();
let new_tid = Node::ThreadID {
control: fork,
dimension: dim + num_outer_dims,
};
editor.edit(|mut edit| {
let new_tid = edit.add_node(new_tid);
let edit = edit.replace_all_uses(tid, new_tid)?;
Ok(edit)
});
}
// Fuse Reductions
for (outer_reduce, inner_reduce) in pairs {
let (_, outer_init, _) = editor.func().nodes[outer_reduce.idx()]
.try_reduce()
.unwrap();
let (_, inner_init, _) = editor.func().nodes[inner_reduce.idx()]
.try_reduce()
.unwrap();
editor.edit(|mut edit| {
// Set inner init to outer init.
edit =
edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?;
edit = edit.replace_all_uses(outer_reduce, inner_reduce)?;
edit = edit.delete_node(outer_reduce)?;
Ok(edit)
});
}
editor.edit(|mut edit| {
let new_fork = Node::Fork {
control: outer_pred,
factors: new_factors.into(),
};
let new_fork = edit.add_node(new_fork);
edit = edit.replace_all_uses(inner_fork, new_fork)?;
edit = edit.replace_all_uses(outer_fork, new_fork)?;
edit = edit.replace_all_uses(outer_join, inner_join)?;
edit = edit.delete_node(outer_join)?;
edit = edit.delete_node(inner_fork)?;
edit = edit.delete_node(outer_fork)?;
Ok(edit)
});
true
}