use std::collections::{HashMap, HashSet};
use std::iter::zip;

use bimap::BiMap;
use itertools::Itertools;

use hercules_ir::*;

use crate::*;

type ForkID = usize;

/** Places each reduce node into its own fork */
pub fn default_reduce_partition(
    editor: &FunctionEditor,
    _fork: NodeID,
    join: NodeID,
) -> SparseNodeMap<ForkID> {
    let mut map = SparseNodeMap::new();

    editor
        .get_users(join)
        .filter(|id| editor.func().nodes[id.idx()].is_reduce())
        .enumerate()
        .for_each(|(fork, reduce)| {
            map.insert(reduce, fork);
        });

    map
}

// TODO: Refine these conditions.
/** Returns the nodes on any use-def path from `reduce` back to `fork`, i.e.
everything `reduce` transitively uses that itself depends on `fork`. */
pub fn find_reduce_dependencies<'a>(
    function: &'a Function,
    reduce: NodeID,
    fork: NodeID,
) -> impl IntoIterator<Item = NodeID> + 'a {
    let len = function.nodes.len();

    let mut visited: DenseNodeMap<bool> = vec![false; len];
    let mut dependent: DenseNodeMap<bool> = vec![false; len];

    // Does `fork` need to be a parameter here? It never changes. If this was a
    // closure, could it just capture it?
    fn recurse(
        function: &Function,
        node: NodeID,
        fork: NodeID,
        dependent_map: &mut DenseNodeMap<bool>,
        visited: &mut DenseNodeMap<bool>,
    ) -> () // Returns through `dependent_map`.
    {
        if visited[node.idx()] {
            return;
        }

        visited[node.idx()] = true;

        if node == fork {
            dependent_map[node.idx()] = true;
            return;
        }

        let binding = get_uses(&function.nodes[node.idx()]);
        let uses = binding.as_ref();

        for used in uses {
            recurse(function, *used, fork, dependent_map, visited);
        }

        dependent_map[node.idx()] = uses.iter().map(|id| dependent_map[id.idx()]).any(|a| a);
        return;
    }

    // Note: HACKY. The condition we want is 'all nodes on any path from the
    // fork to the reduce' (following def-use edges, or equivalently from the
    // reduce to the fork following use-def edges). Cycles break this, but we
    // assume for now that the only cycles are ones that involve the reduce
    // node.
    // NOTE: Control may break this assumption (i.e. a loop inside the fork is
    // a cycle that doesn't involve the reduce).
    // The current solution is just to mark the reduce as dependent at the
    // start of traversing the graph.
    dependent[reduce.idx()] = true;
    recurse(function, reduce, fork, &mut dependent, &mut visited);

    // Return node IDs that are dependent.
    let ret_val: Vec<_> = dependent
        .iter()
        .enumerate()
        .filter_map(|(idx, dependent)| {
            if *dependent {
                Some(NodeID::new(idx))
            } else {
                None
            }
        })
        .collect();

    ret_val
}

pub fn copy_subgraph(
    editor: &mut FunctionEditor,
    subgraph: HashSet<NodeID>,
) -> (
    HashSet<NodeID>,
    HashMap<NodeID, NodeID>,
    Vec<(NodeID, NodeID)>,
)
// Returns all new nodes, a map from old nodes to new nodes, and a vec of
// (old node, outside node) pairs s.t. old node -> outside node, where
// "outside" means not part of the original subgraph.
{
    let mut map: HashMap<NodeID, NodeID> = HashMap::new();
    let mut new_nodes: HashSet<NodeID> = HashSet::new();

    // Copy nodes.
    for old_id in subgraph.iter() {
        editor.edit(|mut edit| {
            let new_id = edit.copy_node(*old_id);
            map.insert(*old_id, new_id);
            new_nodes.insert(new_id);
            Ok(edit)
        });
    }

    // Update edges to new nodes.
    for old_id in subgraph.iter() {
        // Replace all uses of old_id w/ new_id, where the use is in new_nodes.
        editor.edit(|edit| {
            edit.replace_all_uses_where(*old_id, map[old_id], |node_id| {
                new_nodes.contains(node_id)
            })
        });
    }

    // Get all users that aren't in new_nodes.
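    // The edges that leave the copied subgraph are reported back to the caller
    // as (new node, outside user) pairs rather than being rewired here; what
    // to do with them depends on the transform (fork fission, for instance,
    // ignores them and redirects uses via `replace_all_uses` instead).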
    let mut outside_users = Vec::new();

    for node in new_nodes.iter() {
        for user in editor.get_users(*node) {
            if !new_nodes.contains(&user) {
                outside_users.push((*node, user));
            }
        }
    }

    (new_nodes, map, outside_users)
}

pub fn fork_fission<'a>(
    editor: &'a mut FunctionEditor,
    _control_subgraph: &Subgraph,
    _types: &Vec<TypeID>,
    _loop_tree: &LoopTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> () {
    let forks: Vec<_> = editor
        .func()
        .nodes
        .iter()
        .enumerate()
        .filter_map(|(idx, node)| {
            if node.is_fork() {
                Some(NodeID::new(idx))
            } else {
                None
            }
        })
        .collect();

    let control_pred = NodeID::new(0);

    // This does the reduction fission:
    for fork in forks.clone() {
        // FIXME: If there is control in between the fork and join, don't just give up.
        let join = fork_join_map[&fork];
        let join_pred = editor.func().nodes[join.idx()].try_join().unwrap();
        if join_pred != fork {
            todo!("Can't do fork fission on nodes with internal control");
            // Inner control LOOPs are hard.
            // Inner control in general *should* work right now without modifications.
        }
        let reduce_partition = default_reduce_partition(editor, fork, join);

        fork_reduce_fission_helper(editor, fork_join_map, reduce_partition, control_pred, fork);
    }
}

/** Split a 1D fork into two forks, placing select intermediate data into buffers. */
pub fn fork_bufferize_fission_helper<'a>(
    editor: &'a mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    bufferized_edges: HashSet<(NodeID, NodeID)>, // Describes what intermediate data should be bufferized.
    _original_control_pred: NodeID,              // What the new fork connects to.
    types: &Vec<TypeID>,
    fork: NodeID,
) -> (NodeID, NodeID) {
    // Returns the two forks that it generates.

    // TODO: Check that the bufferized edges' srcs don't depend on anything that
    // comes after the fork.

    // Copy fork + control intermediates + join to new fork + join.
    // How does control get partitioned?
    // (Depending on how it affects the data nodes on each side of the
    // bufferized_edges, it may end up in each loop; fix me later.)
    // Place the new fork + join after the join of the first.

    // Only handle fork-joins with no inner control for now.

    // Create the fork + join + thread control.
    let join = fork_join_map[&fork];
    let mut new_fork_id = NodeID::new(0);
    let mut new_join_id = NodeID::new(0);

    editor.edit(|mut edit| {
        new_join_id = edit.add_node(Node::Join { control: fork });
        let factors = edit.get_node(fork).try_fork().unwrap().1;
        new_fork_id = edit.add_node(Node::Fork {
            control: new_join_id,
            factors: factors.into(),
        });
        edit.replace_all_uses_where(fork, new_fork_id, |usee| *usee == join)
    });

    for (src, dst) in bufferized_edges {
        // FIXME: Disgusting cloning and allocating and iterators.
        let factors: Vec<_> = editor.func().nodes[fork.idx()]
            .try_fork()
            .unwrap()
            .1
            .iter()
            .cloned()
            .collect();

        editor.edit(|mut edit| {
            // Create the write to the buffer.

            let thread_stuff_it = factors.into_iter().enumerate();

            // FIXME: Try to use unzip here? Idk why it wasn't working.
            let tids = thread_stuff_it.clone().map(|(dim, _)| {
                edit.add_node(Node::ThreadID {
                    control: fork,
                    dimension: dim,
                })
            });

            let array_dims = thread_stuff_it.clone().map(|(_, factor)| (factor));

            // Assume 1-d fork only for now.
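            // Per bufferized edge (src, dst), the code below builds: in the
            // first fork-join, a Write of `src` into buf[tid_0, .., tid_{n-1}]
            // threaded through a Reduce so the buffer survives the join; in
            // the second fork-join, a Read of the buffer at the new fork's
            // thread IDs that replaces `dst`'s use of `src`.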
            // let tid = edit.add_node(Node::ThreadID { control: fork, dimension: 0 });
            let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice());

            let write = edit.add_node(Node::Write {
                collect: NodeID::new(0),
                data: src,
                indices: vec![position_idx].into(),
            });

            let ele_type = types[src.idx()];
            let empty_buffer = edit.add_type(hercules_ir::Type::Array(
                ele_type,
                array_dims.collect::<Vec<_>>().into_boxed_slice(),
            ));
            let empty_buffer = edit.add_zero_constant(empty_buffer);
            let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer });

            let reduce = Node::Reduce {
                control: new_join_id,
                init: empty_buffer,
                reduct: write,
            };
            let reduce = edit.add_node(reduce);

            // Fix the write node: its placeholder `collect` input (node 0) is
            // redirected to the new reduce, closing the reduction cycle.
            edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?;

            // Create the read from the buffer.
            let tids = thread_stuff_it.clone().map(|(dim, _)| {
                edit.add_node(Node::ThreadID {
                    control: new_fork_id,
                    dimension: dim,
                })
            });

            let position_idx = Index::Position(tids.collect::<Vec<_>>().into_boxed_slice());

            let read = edit.add_node(Node::Read {
                collect: reduce,
                indices: vec![position_idx].into(),
            });

            edit = edit.replace_all_uses_where(src, read, |usee| *usee == dst)?;

            Ok(edit)
        });
    }

    (fork, new_fork_id)
}

/** Split a 1D fork into a separate fork for each reduction. */
pub fn fork_reduce_fission_helper<'a>(
    editor: &'a mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    reduce_partition: SparseNodeMap<ForkID>, // Describes how the reduces of the fork should be split.
    original_control_pred: NodeID,           // What the new fork connects to.
    fork: NodeID,
) -> (NodeID, NodeID) {
    let join = fork_join_map[&fork];

    let mut new_control_pred: NodeID = original_control_pred;

    // The important edges are the reduces.

    // NOTE:
    // Say two reduces are in a fork, s.t. reduce A depends on reduce B.
    // If the user wants A and B in separate forks:
    // - we can simply refuse, or
    // - we can duplicate B.

    let mut new_fork = NodeID::new(0);
    let mut new_join = NodeID::new(0);

    // Gets everything between the fork & join that this reduce needs (ALL CONTROL).
    for reduce in reduce_partition {
        let reduce = reduce.0;

        let function = editor.func();
        let subgraph = find_reduce_dependencies(function, reduce, fork);

        let mut subgraph: HashSet<NodeID> = subgraph.into_iter().collect();

        subgraph.insert(join);
        subgraph.insert(fork);
        subgraph.insert(reduce);

        let (_, mapping, _) = copy_subgraph(editor, subgraph);

        new_fork = mapping[&fork];
        new_join = mapping[&join];

        editor.edit(|mut edit| {
            // Attach new_fork after control_pred.
            let (old_control_pred, _) = edit.get_node(new_fork).try_fork().unwrap().clone();
            edit = edit.replace_all_uses_where(old_control_pred, new_control_pred, |usee| {
                *usee == new_fork
            })?;

            // Replace uses of the reduce.
            edit = edit.replace_all_uses(reduce, mapping[&reduce])?;
            Ok(edit)
        });

        new_control_pred = new_join;
    }

    editor.edit(|mut edit| {
        // Replace the original join w/ the new final join.
        edit = edit.replace_all_uses_where(join, new_join, |_| true)?;

        // Delete the original join (all reduce users have been moved).
        edit = edit.delete_node(join)?;

        // Replace all users of the original fork, and then delete it; leftover users will be DCE'd.
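        // At this point the copied fork-joins form a chain,
        //   pred -> fork_0 .. join_0 -> fork_1 .. join_1 -> ...,
        // one fork-join per entry of the reduce partition; `new_fork` and
        // `new_join` name the last pair, which absorbs the remaining uses of
        // the original fork and join below.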
        edit = edit.replace_all_uses(fork, new_fork)?;
        edit.delete_node(fork)
    });

    (new_fork, new_join)
}

pub fn fork_coalesce(
    editor: &mut FunctionEditor,
    loops: &LoopTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> bool {
    let fork_joins = loops.bottom_up_loops().into_iter().filter_map(|(k, _)| {
        if editor.func().nodes[k.idx()].is_fork() {
            Some(k)
        } else {
            None
        }
    });

    let fork_joins: Vec<_> = fork_joins.collect();
    // FIXME: Add a postorder traversal to optimize this.
    // FIXME: This could give us two forks that aren't actually ancestors /
    // related, but then the helper will just return false early.
    // Something like `fork_joins.postorder_iter().windows(2)` is ideal here.
    for (inner, outer) in fork_joins.iter().cartesian_product(fork_joins.iter()) {
        if fork_coalesce_helper(editor, *outer, *inner, fork_join_map) {
            return true;
        }
    }
    return false;
}

/** Opposite of fork split: takes two fork-joins with no control between them
and merges them into a single fork-join. */
pub fn fork_coalesce_helper(
    editor: &mut FunctionEditor,
    outer_fork: NodeID,
    inner_fork: NodeID,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> bool {
    // Check that all reduces in the outer fork are in *simple* cycles with a
    // unique reduce of the inner fork.

    let outer_join = fork_join_map[&outer_fork];
    let inner_join = fork_join_map[&inner_fork];

    let mut pairs: BiMap<NodeID, NodeID> = BiMap::new(); // Outer <-> Inner

    // FIXME: Iterate all control uses of the joins to really collect all
    // reduces (reduces can be attached to inner control).
    for outer_reduce in editor
        .get_users(outer_join)
        .filter(|node| editor.func().nodes[node.idx()].is_reduce())
    {
        // Check that the inner reduce is of the inner join.
        let (_, _, outer_reduct) = editor.func().nodes[outer_reduce.idx()]
            .try_reduce()
            .unwrap();

        let inner_reduce = outer_reduct;
        let inner_reduce_node = &editor.func().nodes[outer_reduct.idx()];

        let Node::Reduce {
            control: inner_control,
            init: inner_init,
            reduct: _,
        } = inner_reduce_node
        else {
            return false;
        };

        // FIXME: Check this condition better (i.e. the reduce might not be
        // attached to the join).
        if *inner_control != inner_join {
            return false;
        };
        if *inner_init != outer_reduce {
            return false;
        };

        if pairs.contains_left(&outer_reduce) || pairs.contains_right(&inner_reduce) {
            return false;
        } else {
            pairs.insert(outer_reduce, inner_reduce);
        }
    }

    // Check for control between join-join and fork-fork.
    let Some(user) = editor
        .get_users(outer_fork)
        .filter(|node| editor.func().nodes[node.idx()].is_control())
        .next()
    else {
        return false;
    };

    if user != inner_fork {
        return false;
    }

    let Some(user) = editor
        .get_users(inner_join)
        .filter(|node| editor.func().nodes[node.idx()].is_control())
        .next()
    else {
        return false;
    };

    if user != outer_join {
        return false;
    }

    // Checklist:
    // - Increment inner TIDs.
    // - Add the outer fork's dimension to the front of the inner fork.
    // - Fuse reductions:
    //   - The initializer becomes the outer initializer.
    // - Replace uses of the outer fork w/ the inner fork.
    // - Replace uses of the outer join w/ the inner join.
    // - Delete the outer fork-join.

    let inner_tids: Vec<NodeID> = editor
        .get_users(inner_fork)
        .filter(|node| editor.func().nodes[node.idx()].is_thread_id())
        .collect();

    let (outer_pred, outer_dims) = editor.func().nodes[outer_fork.idx()].try_fork().unwrap();
    let (_, inner_dims) = editor.func().nodes[inner_fork.idx()].try_fork().unwrap();
    let num_outer_dims = outer_dims.len();
    let mut new_factors = outer_dims.to_vec();

    // CHECKME / FIXME: Might need to be added the other way.
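    // With the factors laid out as [outer dims.., inner dims..], the outer
    // thread IDs keep their dimension indices, while every inner thread ID
    // must be shifted up by `num_outer_dims` (done in the loop below).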
    new_factors.append(&mut inner_dims.to_vec());

    for tid in inner_tids {
        let (fork, dim) = editor.func().nodes[tid.idx()].try_thread_id().unwrap();
        let new_tid = Node::ThreadID {
            control: fork,
            dimension: dim + num_outer_dims,
        };

        editor.edit(|mut edit| {
            let new_tid = edit.add_node(new_tid);
            let edit = edit.replace_all_uses(tid, new_tid)?;
            Ok(edit)
        });
    }

    // Fuse reductions.
    for (outer_reduce, inner_reduce) in pairs {
        let (_, outer_init, _) = editor.func().nodes[outer_reduce.idx()]
            .try_reduce()
            .unwrap();
        let (_, inner_init, _) = editor.func().nodes[inner_reduce.idx()]
            .try_reduce()
            .unwrap();

        editor.edit(|mut edit| {
            // Set the inner init to the outer init.
            edit =
                edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?;
            edit = edit.replace_all_uses(outer_reduce, inner_reduce)?;
            edit = edit.delete_node(outer_reduce)?;

            Ok(edit)
        });
    }

    editor.edit(|mut edit| {
        let new_fork = Node::Fork {
            control: outer_pred,
            factors: new_factors.into(),
        };
        let new_fork = edit.add_node(new_fork);

        edit = edit.replace_all_uses(inner_fork, new_fork)?;
        edit = edit.replace_all_uses(outer_fork, new_fork)?;
        edit = edit.replace_all_uses(outer_join, inner_join)?;
        edit = edit.delete_node(outer_join)?;
        edit = edit.delete_node(inner_fork)?;
        edit = edit.delete_node(outer_fork)?;

        Ok(edit)
    });

    true
}

pub fn split_all_forks(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) {
    for (fork, join) in fork_join_map {
        if let Some((forks, _)) = split_fork(editor, *fork, *join, reduce_cycles)
            && forks.len() > 1
        {
            // Stop after the first real split: the edit invalidates
            // `fork_join_map`, so the analyses must be recomputed before
            // splitting the next fork-join.
            break;
        }
    }
}

/*
 * Split multi-dimensional fork-joins into separate one-dimensional fork-joins.
 * Useful for code generation. A single iteration of `fork_split` only splits
 * at most one fork-join; it must be called repeatedly to split all fork-joins.
 */
pub(crate) fn split_fork(
    editor: &mut FunctionEditor,
    fork: NodeID,
    join: NodeID,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) -> Option<(Vec<NodeID>, Vec<NodeID>)> {
    // A single multi-dimensional fork becomes multiple forks, a join becomes
    // multiple joins, a thread ID becomes a thread ID on the correct
    // fork, and a reduce becomes multiple reduces to shuffle the reduction
    // value through the fork-join nest.
    let nodes = &editor.func().nodes;
    let (fork_control, factors) = nodes[fork.idx()].try_fork().unwrap();
    if factors.len() < 2 {
        return Some((vec![fork], vec![join]));
    }
    let factors: Box<[DynamicConstantID]> = factors.into();
    let join_control = nodes[join.idx()].try_join().unwrap();
    let tids: Vec<_> = editor
        .get_users(fork)
        .filter(|id| nodes[id.idx()].is_thread_id())
        .collect();
    let reduces: Vec<_> = editor
        .get_users(join)
        .filter(|id| nodes[id.idx()].is_reduce())
        .collect();

    let data_in_reduce_cycle: HashSet<(NodeID, NodeID)> = reduces
        .iter()
        .map(|reduce| editor.get_users(*reduce).map(move |user| (user, *reduce)))
        .flatten()
        .filter(|(user, reduce)| reduce_cycles[&reduce].contains(&user))
        .collect();

    let mut new_forks = vec![];
    let mut new_joins = vec![];
    let success = editor.edit(|mut edit| {
        // Create the forks and a thread ID per fork.
        let mut acc_fork = fork_control;
        let mut new_tids = vec![];
        for factor in factors {
            acc_fork = edit.add_node(Node::Fork {
                control: acc_fork,
                factors: Box::new([factor]),
            });
            new_forks.push(acc_fork);
            edit.sub_edit(fork, acc_fork);
            new_tids.push(edit.add_node(Node::ThreadID {
                control: acc_fork,
                dimension: 0,
            }));
        }

        // Create the joins.
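        // The joins mirror the forks: the innermost join is created first,
        // hanging off the old join's control predecessor (or off the innermost
        // new fork when the old join was attached directly to the fork), and
        // each later join closes the next fork outward.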
        let mut acc_join = if join_control == fork {
            acc_fork
        } else {
            join_control
        };
        for _ in new_tids.iter() {
            acc_join = edit.add_node(Node::Join { control: acc_join });
            edit.sub_edit(join, acc_join);
            new_joins.push(acc_join);
        }

        // Create the reduces. Each old reduce becomes a chain of reduces, one
        // per new join; the chain links are forward references built by
        // predicting node IDs (reduce `join_idx` gets ID
        // `num_nodes + join_idx`, checked by the assert below).
        let mut new_reduces = vec![];
        for reduce in reduces.iter() {
            let (_, init, reduct) = edit.get_node(*reduce).try_reduce().unwrap();
            let num_nodes = edit.num_node_ids();
            let mut inner_reduce = NodeID::new(0);
            let mut outer_reduce = NodeID::new(0);
            for (join_idx, join) in new_joins.iter().enumerate() {
                let init = if join_idx == new_joins.len() - 1 {
                    init
                } else {
                    NodeID::new(num_nodes + join_idx + 1)
                };
                let reduct = if join_idx == 0 {
                    reduct
                } else {
                    NodeID::new(num_nodes + join_idx - 1)
                };
                let new_reduce = edit.add_node(Node::Reduce {
                    control: *join,
                    init,
                    reduct,
                });
                assert_eq!(new_reduce, NodeID::new(num_nodes + join_idx));
                edit.sub_edit(*reduce, new_reduce);
                if join_idx == 0 {
                    inner_reduce = new_reduce;
                }
                if join_idx == new_joins.len() - 1 {
                    outer_reduce = new_reduce;
                }
            }
            new_reduces.push((inner_reduce, outer_reduce));
        }

        // Replace everything.
        edit = edit.replace_all_uses(fork, acc_fork)?;
        edit = edit.replace_all_uses(join, acc_join)?;
        for tid in tids.iter() {
            let dim = edit.get_node(*tid).try_thread_id().unwrap().1;
            edit.sub_edit(*tid, new_tids[dim]);
            edit = edit.replace_all_uses(*tid, new_tids[dim])?;
        }
        for (reduce, (inner_reduce, outer_reduce)) in zip(reduces.iter(), new_reduces) {
            edit = edit.replace_all_uses_where(*reduce, inner_reduce, |id| {
                data_in_reduce_cycle.contains(&(*id, *reduce))
            })?;
            edit = edit.replace_all_uses_where(*reduce, outer_reduce, |id| {
                !data_in_reduce_cycle.contains(&(*id, *reduce))
            })?;
        }

        // Delete all the old stuff.
        edit = edit.delete_node(fork)?;
        edit = edit.delete_node(join)?;
        for tid in tids {
            edit = edit.delete_node(tid)?;
        }
        for reduce in reduces {
            edit = edit.delete_node(reduce)?;
        }

        Ok(edit)
    });
    if success {
        Some((new_forks, new_joins))
    } else {
        None
    }
}
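
// A minimal usage sketch, not part of the pass itself: collect the forks that
// `split_fork` would actually split, i.e. those with more than one factor. It
// only reuses APIs already exercised above (`func`, `is_fork`, `try_fork`);
// the helper name is ours, purely for illustration.
#[allow(dead_code)]
fn multi_dimensional_forks(editor: &FunctionEditor) -> Vec<NodeID> {
    editor
        .func()
        .nodes
        .iter()
        .enumerate()
        .filter_map(|(idx, node)| {
            // `try_fork` yields `(control, factors)`; more than one factor
            // means this fork-join is multi-dimensional and would be split.
            if node.is_fork() && node.try_fork().unwrap().1.len() > 1 {
                Some(NodeID::new(idx))
            } else {
                None
            }
        })
        .collect()
}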