use std::collections::{HashMap, HashSet};
use std::iter::zip;

use bimap::BiMap;
use itertools::Itertools;

use hercules_ir::*;

use crate::*;

type ForkID = usize;

/** Places each reduce node into its own fork. */
pub fn default_reduce_partition(
    editor: &FunctionEditor,
    _fork: NodeID,
    join: NodeID,
) -> SparseNodeMap<ForkID> {
    let mut map = SparseNodeMap::new();

    editor
        .get_users(join)
        .filter(|id| editor.func().nodes[id.idx()].is_reduce())
        .enumerate()
        .for_each(|(fork, reduce)| {
            map.insert(reduce, fork);
        });

    map
}

// TODO: Refine these conditions.
/** Collects all nodes that the given reduce transitively uses, stopping at the fork. */
pub fn find_reduce_dependencies<'a>(
    function: &'a Function,
    reduce: NodeID,
    fork: NodeID,
) -> impl IntoIterator<Item = NodeID> + 'a {
    let len = function.nodes.len();

    let mut visited: DenseNodeMap<bool> = vec![false; len];
    let mut dependent: DenseNodeMap<bool> = vec![false; len];

    // Does `fork` need to be a parameter here? It never changes. If this was a
    // closure, could it just capture it?
    fn recurse(
        function: &Function,
        node: NodeID,
        fork: NodeID,
        dependent_map: &mut DenseNodeMap<bool>,
        visited: &mut DenseNodeMap<bool>,
    ) {
        // Results are returned through `dependent_map`.
        if visited[node.idx()] {
            return;
        }
        visited[node.idx()] = true;

        if node == fork {
            dependent_map[node.idx()] = true;
            return;
        }

        let binding = get_uses(&function.nodes[node.idx()]);
        let uses = binding.as_ref();

        for used in uses {
            recurse(function, *used, fork, dependent_map, visited);
        }

        dependent_map[node.idx()] = uses.iter().any(|id| dependent_map[id.idx()]);
    }

    // Note: HACKY. The condition we want is "all nodes on any path from the fork
    // to the reduce (in the forward graph), or from the reduce to the fork (in
    // the reversed graph)". Cycles break this, but we assume for now that the
    // only cycles are ones that involve the reduce node.
    // NOTE: Control may break this (i.e. a loop inside the fork is a cycle that
    // doesn't involve the reduce). The current solution is just to mark the
    // reduce as dependent at the start of traversing the graph.
    dependent[reduce.idx()] = true;
    recurse(function, reduce, fork, &mut dependent, &mut visited);

    // Return the node IDs that are dependent.
    let ret_val: Vec<_> = dependent
        .iter()
        .enumerate()
        .filter_map(|(idx, dependent)| {
            if *dependent {
                Some(NodeID::new(idx))
            } else {
                None
            }
        })
        .collect();

    ret_val
}

pub fn copy_subgraph_in_edit<'a, 'b>(
    mut edit: FunctionEdit<'a, 'b>,
    subgraph: HashSet<NodeID>,
) -> Result<(FunctionEdit<'a, 'b>, HashMap<NodeID, NodeID>), FunctionEdit<'a, 'b>> {
    let mut map: HashMap<NodeID, NodeID> = HashMap::new();

    // Copy the nodes in the subgraph.
    for old_id in subgraph.iter().cloned() {
        let new_id = edit.copy_node(old_id);
        map.insert(old_id, new_id);
    }

    // Update edges within the copied nodes to point at the new nodes.
    for old_id in subgraph.iter() {
        edit = edit.replace_all_uses_where(*old_id, map[old_id], |node_id| {
            map.values().contains(node_id)
        })?;
    }

    Ok((edit, map))
}
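// A minimal usage sketch for `copy_subgraph_in_edit` above (with a hypothetical
// `subgraph` set), mirroring how `fork_bufferize_fission_helper` below calls it
// from inside an `editor.edit` closure:
//
//     let success = editor.edit(|edit| {
//         let (edit, map) = copy_subgraph_in_edit(edit, subgraph.clone())?;
//         // `map[&old_id]` now names the copy of each `old_id` in `subgraph`.
//         Ok(edit)
//     });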
// Returns all new nodes, a map from old nodes to new nodes, and a vec of pairs
// of nodes (old node, outside node) s.t. old node -> outside node, where
// "outside" means not part of the original subgraph.
pub fn copy_subgraph(
    editor: &mut FunctionEditor,
    subgraph: HashSet<NodeID>,
) -> (
    HashSet<NodeID>,
    HashMap<NodeID, NodeID>,
    Vec<(NodeID, NodeID)>,
) {
    let mut map: HashMap<NodeID, NodeID> = HashMap::new();
    let mut new_nodes: HashSet<NodeID> = HashSet::new();

    // Copy the nodes.
    for old_id in subgraph.iter() {
        editor.edit(|mut edit| {
            let new_id = edit.copy_node(*old_id);
            map.insert(*old_id, new_id);
            new_nodes.insert(new_id);
            Ok(edit)
        });
    }

    // Update edges to the new nodes.
    for old_id in subgraph.iter() {
        // Replace all uses of old_id w/ new_id, where the use is in new_nodes.
        editor.edit(|edit| {
            edit.replace_all_uses_where(*old_id, map[old_id], |node_id| {
                new_nodes.contains(node_id)
            })
        });
    }

    // Get all users that aren't in new_nodes.
    let mut outside_users = Vec::new();

    for node in new_nodes.iter() {
        for user in editor.get_users(*node) {
            if !new_nodes.contains(&user) {
                outside_users.push((*node, user));
            }
        }
    }

    (new_nodes, map, outside_users)
}

pub fn find_bufferize_edges(
    editor: &mut FunctionEditor,
    fork: NodeID,
    loop_tree: &LoopTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
    data_label: &LabelID,
) -> HashSet<(NodeID, NodeID)> {
    let mut edges: HashSet<_> = HashSet::new();

    for node in &nodes_in_fork_joins[&fork] {
        // An edge is bufferized when it runs from a node that *has* the data
        // label to a node that *doesn't* have it.
        let node_labels = &editor.func().labels[node.idx()];

        if !node_labels.contains(data_label) {
            continue;
        }

        // Don't draw bufferize edges from fork tids.
        if editor.get_users(fork).contains(node) {
            continue;
        }

        for user in editor.get_users(*node) {
            let user_labels = &editor.func().labels[user.idx()];

            if user_labels.contains(data_label) {
                continue;
            }

            if editor.node(user).is_control() || editor.node(node).is_control() {
                continue;
            }

            edges.insert((*node, user));
        }
    }

    edges
}

pub fn ff_bufferize_create_not_reduce_cycle_label_helper(
    editor: &mut FunctionEditor,
    fork: NodeID,
    fork_join_map: &HashMap<NodeID, NodeID>,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) -> LabelID {
    let join = fork_join_map[&fork];
    let mut nodes_not_in_a_reduce_cycle = nodes_in_fork_joins[&fork].clone();

    for (cycle, reduce) in editor
        .get_users(join)
        .filter_map(|id| reduce_cycles.get(&id).map(|cycle| (cycle, id)))
    {
        nodes_not_in_a_reduce_cycle.remove(&reduce);
        for id in cycle {
            nodes_not_in_a_reduce_cycle.remove(id);
        }
    }
    nodes_not_in_a_reduce_cycle.remove(&join);

    let mut label = LabelID::new(0);
    let success = editor.edit(|mut edit| {
        label = edit.fresh_label();
        for id in nodes_not_in_a_reduce_cycle {
            edit = edit.add_label(id, label)?;
        }
        Ok(edit)
    });
    assert!(success);
    label
}
pub fn ff_bufferize_any_fork<'a, 'b>(
    editor: &'b mut FunctionEditor<'a>,
    loop_tree: &'b LoopTree,
    fork_join_map: &'b HashMap<NodeID, NodeID>,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
    nodes_in_fork_joins: &'b HashMap<NodeID, HashSet<NodeID>>,
    typing: &'b Vec<TypeID>,
    fork_label: LabelID,
    data_label: Option<LabelID>,
) -> Option<(NodeID, NodeID)>
where
    'a: 'b,
{
    let mut forks: Vec<_> = loop_tree
        .bottom_up_loops()
        .into_iter()
        .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork())
        .collect();
    forks.reverse();

    for l in forks {
        let fork_info = Loop {
            header: l.0,
            control: l.1.clone(),
        };
        let fork = fork_info.header;
        let join = fork_join_map[&fork];

        if !editor.func().labels[fork.idx()].contains(&fork_label) {
            continue;
        }

        let data_label = data_label.unwrap_or_else(|| {
            ff_bufferize_create_not_reduce_cycle_label_helper(
                editor,
                fork,
                fork_join_map,
                reduce_cycles,
                nodes_in_fork_joins,
            )
        });

        let edges = find_bufferize_edges(
            editor,
            fork,
            &loop_tree,
            &fork_join_map,
            &nodes_in_fork_joins,
            &data_label,
        );
        let result = fork_bufferize_fission_helper(
            editor,
            &fork_info,
            &edges,
            nodes_in_fork_joins,
            typing,
            fork,
            join,
        );

        if result.is_none() {
            continue;
        } else {
            return result;
        }
    }
    None
}

pub fn fork_fission<'a>(
    editor: &'a mut FunctionEditor,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
    loop_tree: &LoopTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
    fork_label: LabelID,
) -> Vec<NodeID> {
    let forks: Vec<_> = loop_tree
        .bottom_up_loops()
        .into_iter()
        .filter(|(k, _)| editor.func().nodes[k.idx()].is_fork())
        .collect();

    let mut created_forks = Vec::new();

    // This does the reduction fission.
    for fork in forks {
        let join = fork_join_map[&fork.0];

        // FIXME: Don't make multiple forks for reduces that are in cycles with
        // each other.
        let reduce_partition = default_reduce_partition(editor, fork.0, join);

        if !editor.func().labels[fork.0.idx()].contains(&fork_label) {
            continue;
        }

        if editor.is_mutable(fork.0) {
            created_forks = fork_reduce_fission_helper(
                editor,
                fork_join_map,
                reduce_partition,
                nodes_in_fork_joins,
                fork.0,
            );
            if created_forks.is_empty() {
                continue;
            } else {
                return created_forks;
            }
        }
    }

    created_forks
}

/** Split a 1D fork into two forks, placing select intermediate data into buffers. */
pub fn fork_bufferize_fission_helper<'a, 'b>(
    editor: &'b mut FunctionEditor<'a>,
    l: &Loop,
    bufferized_edges: &HashSet<(NodeID, NodeID)>, // Describes what intermediate data should be bufferized.
    data_node_in_fork_joins: &'b HashMap<NodeID, HashSet<NodeID>>,
    types: &'b Vec<TypeID>,
    fork: NodeID,
    join: NodeID,
) -> Option<(NodeID, NodeID)>
where
    'a: 'b,
{
    if bufferized_edges.is_empty() {
        return None;
    }

    let all_loop_nodes = l.get_all_nodes();

    // FIXME: Cloning hell.
    let data_nodes = data_node_in_fork_joins[&fork].clone();
    let loop_nodes = editor
        .node_ids()
        .filter(|node_id| all_loop_nodes[node_id.idx()]);

    // Clone the subgraph that consists of this fork-join and all data and
    // control nodes in it.
    let subgraph = HashSet::from_iter(data_nodes.into_iter().chain(loop_nodes));

    let mut outside_users = Vec::new(); // (old node, outside user)

    for node in subgraph.iter() {
        for user in editor.get_users(*node) {
            if !subgraph.contains(&user) {
                outside_users.push((*node, user));
            }
        }
    }

    let factors: Vec<_> = editor.func().nodes[fork.idx()]
        .try_fork()
        .unwrap()
        .1
        .iter()
        .cloned()
        .collect();

    let thread_stuff_it = factors.into_iter().enumerate();

    // Control successors.
    let fork_pred = editor
        .get_uses(fork)
        .filter(|a| editor.node(a).is_control())
        .next()
        .unwrap();
    let join_successor = editor
        .get_users(join)
        .filter(|a| editor.node(a).is_control())
        .next()
        .unwrap();

    let mut new_fork_id = NodeID::new(0);
    let edit_result = editor.edit(|edit| {
        let (mut edit, map) = copy_subgraph_in_edit(edit, subgraph)?;

        edit = edit.replace_all_uses_where(fork_pred, join, |a| *a == map[&fork])?;
        edit = edit.replace_all_uses_where(join, map[&join], |a| *a == join_successor)?;

        // Replace outside uses of reduces in the old subgraph with the reduces
        // in the new subgraph.
        for (old_node, outside_user) in outside_users {
            edit = edit
                .replace_all_uses_where(old_node, map[&old_node], |node| *node == outside_user)?;
        }

        let new_fork = map[&fork];

        // FIXME: Do this as part of copy subgraph?
        // Add tids to the original subgraph for indexing.
        let mut old_tids = Vec::new();
        let mut new_tids = Vec::new();
        for (dim, _) in thread_stuff_it.clone() {
            let old_id = edit.add_node(Node::ThreadID {
                control: fork,
                dimension: dim,
            });
            let new_id = edit.add_node(Node::ThreadID {
                control: new_fork,
                dimension: dim,
            });

            old_tids.push(old_id);
            new_tids.push(new_id);
        }

        for (src, dst) in bufferized_edges {
            let array_dims = thread_stuff_it.clone().map(|(_, factor)| factor);
            let position_idx = Index::Position(old_tids.clone().into_boxed_slice());

            let write = edit.add_node(Node::Write {
                collect: NodeID::new(0),
                data: *src,
                indices: vec![position_idx.clone()].into(),
            });
            let ele_type = types[src.idx()];
            let empty_buffer = edit.add_type(hercules_ir::Type::Array(
                ele_type,
                array_dims.collect::<Vec<_>>().into_boxed_slice(),
            ));
            let empty_buffer = edit.add_zero_constant(empty_buffer);
            let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer });
            edit = edit.add_schedule(empty_buffer, Schedule::NoResetConstant)?;

            let reduce = Node::Reduce {
                control: join,
                init: empty_buffer,
                reduct: write,
            };
            let reduce = edit.add_node(reduce);
            edit = edit.add_schedule(reduce, Schedule::ParallelReduce)?;

            // Fix the write node.
            edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?;

            // Create reads from the buffer.
            let position_idx = Index::Position(new_tids.clone().into_boxed_slice());
            let read = edit.add_node(Node::Read {
                collect: reduce,
                indices: vec![position_idx].into(),
            });

            // Replace uses of the bufferized edge's src (its copy) with the
            // corresponding read in the new subgraph.
            edit = edit.replace_all_uses_where(map[src], read, |usee| *usee == map[dst])?;
        }

        new_fork_id = new_fork;

        Ok(edit)
    });

    if edit_result {
        Some((fork, new_fork_id))
    } else {
        None
    }
}
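// A rough before/after sketch of `fork_bufferize_fission_helper` above
// (hypothetical code, one bufferized edge (x, y)):
//
//     for i in 0..n { let x = f(i); let y = g(x); ... }
//
// becomes
//
//     for i in 0..n { buf[i] = f(i); }          // write reduce on the old join
//     for i in 0..n { let y = g(buf[i]); ... }  // read in the copied fork-join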
/** Split a fork into a separate fork for each reduction. */
pub fn fork_reduce_fission_helper<'a>(
    editor: &'a mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    reduce_partition: SparseNodeMap<ForkID>, // Describes how the reduces of the fork should be split.
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
    fork: NodeID,
) -> Vec<NodeID> {
    let join = fork_join_map[&fork];

    let mut new_forks = Vec::new();

    let mut new_control_pred: NodeID = editor
        .get_uses(fork)
        .filter(|n| editor.node(n).is_control())
        .next()
        .unwrap();

    let mut new_fork = NodeID::new(0);
    let mut new_join = NodeID::new(0);

    // Gets everything between the fork & join that this reduce needs (all control).
    let subgraph = &nodes_in_fork_joins[&fork];

    editor.edit(|mut edit| {
        for reduce in reduce_partition {
            let reduce = reduce.0;

            let a = copy_subgraph_in_edit(edit, subgraph.clone())?;
            edit = a.0;
            let mapping = a.1;

            new_fork = mapping[&fork];
            new_forks.push(new_fork);
            new_join = mapping[&join];

            // Attach new_fork after control_pred.
            let (old_control_pred, _) = edit.get_node(new_fork).try_fork().unwrap().clone();
            edit = edit.replace_all_uses_where(old_control_pred, new_control_pred, |usee| {
                *usee == new_fork
            })?;

            // Replace uses of the reduce.
            edit = edit.replace_all_uses(reduce, mapping[&reduce])?;
            new_control_pred = new_join;
        }

        // Replace the original join w/ the new final join.
        edit = edit.replace_all_uses_where(join, new_join, |_| true)?;

        // Delete the original join (all reduce users have been moved).
        edit = edit.delete_node(join)?;

        // Replace all users of the original fork, and then delete it; leftover
        // users will be DCE'd.
        edit = edit.replace_all_uses(fork, new_fork)?;
        edit = edit.delete_node(fork)?;

        Ok(edit)
    });

    new_forks
}
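// Sketch of `fork_reduce_fission_helper` above: a fork-join with two
// independent reduces r1 and r2 (hypothetical) becomes two sequential
// fork-joins, the first owning r1 and the second owning r2, with the second
// fork's control predecessor being the first fork's join.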
pub fn fork_coalesce(
    editor: &mut FunctionEditor,
    loops: &LoopTree,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> bool {
    let fork_joins = loops.bottom_up_loops().into_iter().filter_map(|(k, _)| {
        if editor.func().nodes[k.idx()].is_fork() {
            Some(k)
        } else {
            None
        }
    });

    let fork_joins: Vec<_> = fork_joins.collect();

    // FIXME: Add a postorder traversal to optimize this.
    // FIXME: This could give us two forks that aren't actually ancestors /
    // related, but then the helper will just return false early.
    // Something like `fork_joins.postorder_iter().windows(2)` is ideal here.
    for (inner, outer) in fork_joins.iter().cartesian_product(fork_joins.iter()) {
        if fork_coalesce_helper(editor, *outer, *inner, fork_join_map).is_some() {
            return true;
        }
    }
    false
}

/** Opposite of fork split: takes two fork-joins with no control between them
and merges them into a single fork-join. Returns None if the forks could not be
merged, and the NodeIDs of the resulting fork and join if it succeeds in
merging them. */
pub fn fork_coalesce_helper(
    editor: &mut FunctionEditor,
    outer_fork: NodeID,
    inner_fork: NodeID,
    fork_join_map: &HashMap<NodeID, NodeID>,
) -> Option<(NodeID, NodeID)> {
    // Check that all reduces in the outer fork are in *simple* cycles with a
    // unique reduce of the inner fork.
    let outer_join = fork_join_map[&outer_fork];
    let inner_join = fork_join_map[&inner_fork];

    let mut pairs: BiMap<NodeID, NodeID> = BiMap::new(); // Outer <-> Inner

    // FIXME: Iterate all control uses of joins to really collect all reduces
    // (reduces can be attached to inner control).
    for outer_reduce in editor
        .get_users(outer_join)
        .filter(|node| editor.func().nodes[node.idx()].is_reduce())
    {
        // Check that the inner reduce is of the inner join.
        let (_, _, outer_reduct) = editor.func().nodes[outer_reduce.idx()]
            .try_reduce()
            .unwrap();

        let inner_reduce = outer_reduct;
        let inner_reduce_node = &editor.func().nodes[outer_reduct.idx()];

        let Node::Reduce {
            control: inner_control,
            init: inner_init,
            reduct: _,
        } = inner_reduce_node
        else {
            return None;
        };

        // FIXME: Check this condition better (i.e. the reduce might not be
        // attached to the join).
        if *inner_control != inner_join {
            return None;
        };
        if *inner_init != outer_reduce {
            return None;
        };

        if pairs.contains_left(&outer_reduce) || pairs.contains_right(&inner_reduce) {
            return None;
        } else {
            pairs.insert(outer_reduce, inner_reduce);
        }
    }

    // Check for control between join-join and fork-fork.
    let (control, _) = editor.node(inner_fork).try_fork().unwrap();
    if control != outer_fork {
        return None;
    }

    let control = editor.node(outer_join).try_join().unwrap();
    if control != inner_join {
        return None;
    }

    // Checklist:
    // - Increment inner TIDs.
    // - Add the outer fork's dimensions to the front of the inner fork.
    // - Fuse reductions:
    //   - Initializer becomes the outer initializer.
    // - Replace uses of the outer fork w/ the inner fork.
    // - Replace uses of the outer join w/ the inner join.
    // - Delete the outer fork-join.

    let inner_tids: Vec<NodeID> = editor
        .get_users(inner_fork)
        .filter(|node| editor.func().nodes[node.idx()].is_thread_id())
        .collect();

    let (outer_pred, outer_dims) = editor.func().nodes[outer_fork.idx()].try_fork().unwrap();
    let (_, inner_dims) = editor.func().nodes[inner_fork.idx()].try_fork().unwrap();
    let num_outer_dims = outer_dims.len();
    let mut new_factors = outer_dims.to_vec();

    // CHECKME / FIXME: Might need to be added the other way.
    new_factors.append(&mut inner_dims.to_vec());

    let mut new_fork = NodeID::new(0);
    let new_join = inner_join; // We'll reuse the inner join as the join of the new fork.

    let success = editor.edit(|mut edit| {
        for tid in inner_tids {
            let (fork, dim) = edit.get_node(tid).try_thread_id().unwrap();
            let new_tid = Node::ThreadID {
                control: fork,
                dimension: dim + num_outer_dims,
            };
            let new_tid = edit.add_node(new_tid);
            edit = edit.replace_all_uses(tid, new_tid)?;
            edit.sub_edit(tid, new_tid);
        }

        // Fuse the reductions.
        for (outer_reduce, inner_reduce) in pairs {
            let (_, outer_init, _) = edit.get_node(outer_reduce).try_reduce().unwrap();
            let (_, inner_init, _) = edit.get_node(inner_reduce).try_reduce().unwrap();

            // Set the inner init to the outer init.
            edit =
                edit.replace_all_uses_where(inner_init, outer_init, |usee| *usee == inner_reduce)?;
            edit = edit.replace_all_uses(outer_reduce, inner_reduce)?;
            edit = edit.delete_node(outer_reduce)?;
        }

        let new_fork_node = Node::Fork {
            control: outer_pred,
            factors: new_factors.into(),
        };
        new_fork = edit.add_node(new_fork_node);

        if edit
            .get_schedule(outer_fork)
            .contains(&Schedule::ParallelFork)
            && edit
                .get_schedule(inner_fork)
                .contains(&Schedule::ParallelFork)
        {
            edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?;
        }

        edit = edit.replace_all_uses(inner_fork, new_fork)?;
        edit = edit.replace_all_uses(outer_fork, new_fork)?;
        edit = edit.replace_all_uses(outer_join, inner_join)?;
        edit = edit.delete_node(outer_join)?;
        edit = edit.delete_node(inner_fork)?;
        edit = edit.delete_node(outer_fork)?;

        Ok(edit)
    });

    if success {
        Some((new_fork, new_join))
    } else {
        None
    }
}
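// Sketch of `fork_coalesce_helper` above: a fork with factors [m] immediately
// containing a fork with factors [n] coalesces into a single fork with factors
// [m, n]; the inner thread IDs are re-created with their dimension shifted up
// by the number of outer dimensions, and each outer/inner reduce pair fuses
// into one reduce seeded with the outer initializer.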
pub fn split_any_fork(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) -> Option<(Vec<NodeID>, Vec<NodeID>)> {
    for (fork, join) in fork_join_map {
        if let Some((forks, joins)) = split_fork(editor, *fork, *join, reduce_cycles)
            && forks.len() > 1
        {
            return Some((forks, joins));
        }
    }

    None
}

/*
 * Split multi-dimensional fork-joins into separate one-dimensional fork-joins.
 * Useful for code generation. A single iteration of `fork_split` only splits
 * at most one fork-join; it must be called repeatedly to split all fork-joins.
 */
pub fn split_fork(
    editor: &mut FunctionEditor,
    fork: NodeID,
    join: NodeID,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) -> Option<(Vec<NodeID>, Vec<NodeID>)> {
    // A single multi-dimensional fork becomes multiple forks, a join becomes
    // multiple joins, a thread ID becomes a thread ID on the correct
    // fork, and a reduce becomes multiple reduces to shuffle the reduction
    // value through the fork-join nest.
    let nodes = &editor.func().nodes;
    let (fork_control, factors) = nodes[fork.idx()].try_fork().unwrap();
    if factors.len() < 2 {
        return Some((vec![fork], vec![join]));
    }
    let factors: Box<[DynamicConstantID]> = factors.into();
    let join_control = nodes[join.idx()].try_join().unwrap();
    let tids: Vec<_> = editor
        .get_users(fork)
        .filter(|id| nodes[id.idx()].is_thread_id())
        .collect();
    let reduces: Vec<_> = editor
        .get_users(join)
        .filter(|id| nodes[id.idx()].is_reduce())
        .collect();

    let data_in_reduce_cycle: HashSet<(NodeID, NodeID)> = reduces
        .iter()
        .map(|reduce| editor.get_users(*reduce).map(move |user| (user, *reduce)))
        .flatten()
        .filter(|(user, reduce)| reduce_cycles[&reduce].contains(&user))
        .collect();

    let mut new_forks = vec![];
    let mut new_joins = vec![];
    let success = editor.edit(|mut edit| {
        // Create the forks and a thread ID per fork.
        let mut acc_fork = fork_control;
        let mut new_tids = vec![];
        for factor in factors {
            acc_fork = edit.add_node(Node::Fork {
                control: acc_fork,
                factors: Box::new([factor]),
            });
            new_forks.push(acc_fork);
            edit.sub_edit(fork, acc_fork);
            new_tids.push(edit.add_node(Node::ThreadID {
                control: acc_fork,
                dimension: 0,
            }));
        }

        // Create the joins.
        let mut acc_join = if join_control == fork {
            acc_fork
        } else {
            join_control
        };
        for _ in new_tids.iter() {
            acc_join = edit.add_node(Node::Join { control: acc_join });
            edit.sub_edit(join, acc_join);
            new_joins.push(acc_join);
        }

        // Create the reduces.
        let mut new_reduces = vec![];
        for reduce in reduces.iter() {
            let (_, init, reduct) = edit.get_node(*reduce).try_reduce().unwrap();
            let num_nodes = edit.num_node_ids();
            let mut inner_reduce = NodeID::new(0);
            let mut outer_reduce = NodeID::new(0);
            for (join_idx, join) in new_joins.iter().enumerate() {
                let init = if join_idx == new_joins.len() - 1 {
                    init
                } else {
                    NodeID::new(num_nodes + join_idx + 1)
                };
                let reduct = if join_idx == 0 {
                    reduct
                } else {
                    NodeID::new(num_nodes + join_idx - 1)
                };
                let new_reduce = edit.add_node(Node::Reduce {
                    control: *join,
                    init,
                    reduct,
                });
                assert_eq!(new_reduce, NodeID::new(num_nodes + join_idx));
                edit.sub_edit(*reduce, new_reduce);
                if join_idx == 0 {
                    inner_reduce = new_reduce;
                }
                if join_idx == new_joins.len() - 1 {
                    outer_reduce = new_reduce;
                }
            }
            new_reduces.push((inner_reduce, outer_reduce));
        }

        // Replace everything.
        edit = edit.replace_all_uses(fork, acc_fork)?;
        edit = edit.replace_all_uses(join, acc_join)?;
        for tid in tids.iter() {
            let dim = edit.get_node(*tid).try_thread_id().unwrap().1;
            edit.sub_edit(*tid, new_tids[dim]);
            edit = edit.replace_all_uses(*tid, new_tids[dim])?;
        }
        for (reduce, (inner_reduce, outer_reduce)) in zip(reduces.iter(), new_reduces) {
            edit = edit.replace_all_uses_where(*reduce, inner_reduce, |id| {
                data_in_reduce_cycle.contains(&(*id, *reduce))
            })?;
            edit = edit.replace_all_uses_where(*reduce, outer_reduce, |id| {
                !data_in_reduce_cycle.contains(&(*id, *reduce))
            })?;
        }

        // Delete all the old stuff.
        edit = edit.delete_node(fork)?;
        edit = edit.delete_node(join)?;
        for tid in tids {
            edit = edit.delete_node(tid)?;
        }
        for reduce in reduces {
            edit = edit.delete_node(reduce)?;
        }

        Ok(edit)
    });
    if success {
        new_joins.reverse();
        Some((new_forks, new_joins))
    } else {
        None
    }
}
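// Sketch of `split_fork` above: splitting a fork with factors [m, n] produces
// fork[m] containing fork[n]; a thread ID for dimension 1 moves to the inner
// fork, and each reduce becomes a chain of two reduces, the inner one feeding
// the outer one through the nested joins.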
pub fn chunk_all_forks_unguarded(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    dim_idx: usize,
    tile_size: usize,
    order: bool,
) {
    // Add the tile size as a dynamic constant.
    let mut dc_id = DynamicConstantID::new(0);
    editor.edit(|mut edit| {
        dc_id = edit.add_dynamic_constant(DynamicConstant::Constant(tile_size));
        Ok(edit)
    });

    let order = match order {
        true => &TileOrder::TileInner,
        false => &TileOrder::TileOuter,
    };

    for (fork, _) in fork_join_map {
        if editor.is_mutable(*fork) {
            chunk_fork_unguarded(editor, *fork, dim_idx, dc_id, order);
        }
    }
}

// Splits a dimension of a single fork-join into multiple.
// Iterates an outer loop original_dim / tile_size times and
// adds a tile_size loop as the inner loop.
// Assumes that the tile size divides the original dimension evenly.
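// For example, tiling dimension 0 of a fork with factors [12] by tile_size = 4
// with TileOrder::TileInner yields factors [3, 4]; each old thread ID value i
// is rebuilt as tid_outer * 4 + tid_inner, so i = 7 corresponds to
// tid_outer = 1, tid_inner = 3.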
enum TileOrder {
    TileInner,
    TileOuter,
}

pub fn chunk_fork_unguarded(
    editor: &mut FunctionEditor,
    fork: NodeID,
    dim_idx: usize,
    tile_size: DynamicConstantID,
    order: &TileOrder,
) {
    // tid_dim_idx = tid_dim_idx * tile_size + tid_(dim_idx + 1)
    let Node::Fork {
        control: old_control,
        factors: ref old_factors,
    } = *editor.node(fork)
    else {
        return;
    };
    assert!(dim_idx < old_factors.len());
    let mut new_factors: Vec<_> = old_factors.to_vec();
    let fork_users: Vec<_> = editor
        .get_users(fork)
        .map(|f| (f, editor.node(f).clone()))
        .collect();

    match order {
        TileOrder::TileInner => {
            editor.edit(|mut edit| {
                let outer = DynamicConstant::div(new_factors[dim_idx], tile_size);
                new_factors.insert(dim_idx + 1, tile_size);
                new_factors[dim_idx] = edit.add_dynamic_constant(outer);

                let new_fork = Node::Fork {
                    control: old_control,
                    factors: new_factors.into(),
                };
                let new_fork = edit.add_node(new_fork);

                edit = edit.replace_all_uses(fork, new_fork)?;
                edit.sub_edit(fork, new_fork);

                for (tid, node) in fork_users {
                    let Node::ThreadID {
                        control: _,
                        dimension: tid_dim,
                    } = node
                    else {
                        continue;
                    };
                    if tid_dim > dim_idx {
                        let new_tid = Node::ThreadID {
                            control: new_fork,
                            dimension: tid_dim + 1,
                        };
                        let new_tid = edit.add_node(new_tid);
                        edit = edit.replace_all_uses(tid, new_tid)?;
                        edit.sub_edit(tid, new_tid);
                        edit = edit.delete_node(tid)?;
                    } else if tid_dim == dim_idx {
                        let tile_tid = Node::ThreadID {
                            control: new_fork,
                            dimension: tid_dim + 1,
                        };
                        let tile_tid = edit.add_node(tile_tid);

                        let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size });
                        let mul = edit.add_node(Node::Binary {
                            left: tid,
                            right: tile_size,
                            op: BinaryOperator::Mul,
                        });
                        let add = edit.add_node(Node::Binary {
                            left: mul,
                            right: tile_tid,
                            op: BinaryOperator::Add,
                        });
                        edit.sub_edit(tid, add);
                        edit.sub_edit(tid, tile_tid);
                        edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul)?;
                    }
                }
                edit = edit.delete_node(fork)?;
                Ok(edit)
            });
        }
        TileOrder::TileOuter => {
            editor.edit(|mut edit| {
                let inner = DynamicConstant::div(new_factors[dim_idx], tile_size);
                new_factors.insert(dim_idx, tile_size);
                let inner_dc_id = edit.add_dynamic_constant(inner);
                new_factors[dim_idx + 1] = inner_dc_id;

                let new_fork = Node::Fork {
                    control: old_control,
                    factors: new_factors.into(),
                };
                let new_fork = edit.add_node(new_fork);

                edit = edit.replace_all_uses(fork, new_fork)?;
                edit.sub_edit(fork, new_fork);

                for (tid, node) in fork_users {
                    let Node::ThreadID {
                        control: _,
                        dimension: tid_dim,
                    } = node
                    else {
                        continue;
                    };
                    if tid_dim > dim_idx {
                        let new_tid = Node::ThreadID {
                            control: new_fork,
                            dimension: tid_dim + 1,
                        };
                        let new_tid = edit.add_node(new_tid);
                        edit = edit.replace_all_uses(tid, new_tid)?;
                        edit.sub_edit(tid, new_tid);
                        edit = edit.delete_node(tid)?;
                    } else if tid_dim == dim_idx {
                        let tile_tid = Node::ThreadID {
                            control: new_fork,
                            dimension: tid_dim + 1,
                        };
                        let tile_tid = edit.add_node(tile_tid);

                        let inner_dc = edit.add_node(Node::DynamicConstant { id: inner_dc_id });
                        let mul = edit.add_node(Node::Binary {
                            left: tid,
                            right: inner_dc,
                            op: BinaryOperator::Mul,
                        });
                        let add = edit.add_node(Node::Binary {
                            left: mul,
                            right: tile_tid,
                            op: BinaryOperator::Add,
                        });
                        edit.sub_edit(tid, add);
                        edit.sub_edit(tid, tile_tid);
                        edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul)?;
                    }
                }
                edit = edit.delete_node(fork)?;
                Ok(edit)
            });
        }
    }
}
pub fn merge_all_fork_dims(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
    for (fork, _) in fork_join_map {
        let Node::Fork {
            control: _,
            factors: dims,
        } = editor.node(fork)
        else {
            unreachable!();
        };

        let mut fork = *fork;
        for _ in 0..dims.len() - 1 {
            let outer = 0;
            let inner = 1;
            fork = fork_dim_merge(editor, fork, outer, inner);
        }
    }
}

pub fn fork_dim_merge(
    editor: &mut FunctionEditor,
    fork: NodeID,
    dim_idx1: usize,
    dim_idx2: usize,
) -> NodeID {
    // tid_dim_idx1 (replaced w/) <- dim_idx1 / dim(dim_idx2)
    // tid_dim_idx2 (replaced w/) <- dim_idx1 % dim(dim_idx2)
    assert_ne!(dim_idx1, dim_idx2);

    // Outer is smaller, and also closer to the left of the factors array.
    let (outer_idx, inner_idx) = if dim_idx2 < dim_idx1 {
        (dim_idx2, dim_idx1)
    } else {
        (dim_idx1, dim_idx2)
    };
    let Node::Fork {
        control: old_control,
        factors: ref old_factors,
    } = *editor.node(fork)
    else {
        return fork;
    };
    let mut new_factors: Vec<_> = old_factors.to_vec();
    let fork_users: Vec<_> = editor
        .get_users(fork)
        .map(|f| (f, editor.node(f).clone()))
        .collect();

    let mut new_nodes = vec![];
    let outer_dc_id = new_factors[outer_idx];
    let inner_dc_id = new_factors[inner_idx];
    let mut new_fork = NodeID::new(0);

    editor.edit(|mut edit| {
        new_factors[outer_idx] = edit.add_dynamic_constant(DynamicConstant::mul(
            new_factors[outer_idx],
            new_factors[inner_idx],
        ));
        new_factors.remove(inner_idx);

        new_fork = edit.add_node(Node::Fork {
            control: old_control,
            factors: new_factors.into(),
        });
        edit.sub_edit(fork, new_fork);
        edit = edit.replace_all_uses(fork, new_fork)?;
        edit = edit.delete_node(fork)?;

        for (tid, node) in fork_users {
            let Node::ThreadID {
                control: _,
                dimension: tid_dim,
            } = node
            else {
                continue;
            };
            if tid_dim > inner_idx {
                let new_tid = Node::ThreadID {
                    control: new_fork,
                    dimension: tid_dim - 1,
                };
                let new_tid = edit.add_node(new_tid);
                edit = edit.replace_all_uses(tid, new_tid)?;
                edit.sub_edit(tid, new_tid);
            } else if tid_dim == outer_idx {
                let outer_tid = Node::ThreadID {
                    control: new_fork,
                    dimension: outer_idx,
                };
                let outer_tid = edit.add_node(outer_tid);
                let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id });
                new_nodes.push(outer_tid);

                // inner_idx % dim(outer_idx)
                let rem = edit.add_node(Node::Binary {
                    left: outer_tid,
                    right: outer_dc,
                    op: BinaryOperator::Rem,
                });
                edit.sub_edit(tid, rem);
                edit.sub_edit(tid, outer_tid);
                edit = edit.replace_all_uses(tid, rem)?;
            } else if tid_dim == inner_idx {
                let outer_tid = Node::ThreadID {
                    control: new_fork,
                    dimension: outer_idx,
                };
                let outer_tid = edit.add_node(outer_tid);
                let outer_dc = edit.add_node(Node::DynamicConstant { id: outer_dc_id });

                // inner_idx / dim(outer_idx)
                let div = edit.add_node(Node::Binary {
                    left: outer_tid,
                    right: outer_dc,
                    op: BinaryOperator::Div,
                });
                edit.sub_edit(tid, div);
                edit.sub_edit(tid, outer_tid);
                edit = edit.replace_all_uses(tid, div)?;
            }
        }
        Ok(edit)
    });

    new_fork
}
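// Worked example for `fork_dim_merge` above: merging dimensions 0 and 1 of a
// fork with factors [2, 3] produces the single factor [6]. The Rem/Div nodes
// built above recover the old thread IDs from the merged one using the old
// outer factor (2 here), e.g. a merged thread ID of 5 maps to 5 % 2 = 1 for
// the outer tid and 5 / 2 = 2 for the inner tid.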
/*
 * Run fork interchange on all fork-joins that are mutable in an editor.
 */
pub fn fork_interchange_all_forks(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    first_dim: usize,
    second_dim: usize,
) {
    for (fork, join) in fork_join_map {
        if editor.is_mutable(*fork) {
            fork_interchange(editor, *fork, *join, first_dim, second_dim);
        }
    }
}

pub fn fork_interchange(
    editor: &mut FunctionEditor,
    fork: NodeID,
    join: NodeID,
    first_dim: usize,
    second_dim: usize,
) -> Option<NodeID> {
    // Check that every reduce on the join is parallel or associative.
    let nodes = &editor.func().nodes;
    let schedules = &editor.func().schedules;
    if !editor
        .get_users(join)
        .filter(|id| nodes[id.idx()].is_reduce())
        .all(|id| {
            schedules[id.idx()].contains(&Schedule::ParallelReduce)
                || schedules[id.idx()].contains(&Schedule::MonoidReduce)
        })
    {
        // If not, we can't necessarily do the interchange.
        return None;
    }

    let Node::Fork {
        control,
        ref factors,
    } = nodes[fork.idx()]
    else {
        panic!()
    };
    let fix_tids: Vec<(NodeID, Node)> = editor
        .get_users(fork)
        .filter_map(|id| {
            nodes[id.idx()]
                .try_thread_id()
                .map(|(_, dim)| {
                    if dim == first_dim {
                        Some((
                            id,
                            Node::ThreadID {
                                control: fork,
                                dimension: second_dim,
                            },
                        ))
                    } else if dim == second_dim {
                        Some((
                            id,
                            Node::ThreadID {
                                control: fork,
                                dimension: first_dim,
                            },
                        ))
                    } else {
                        None
                    }
                })
                .flatten()
        })
        .collect();
    let mut factors = factors.clone();
    factors.swap(first_dim, second_dim);
    let new_fork = Node::Fork { control, factors };
    let mut new_fork_id = None;
    editor.edit(|mut edit| {
        for (old_id, new_tid) in fix_tids {
            let new_id = edit.add_node(new_tid);
            edit = edit.replace_all_uses(old_id, new_id)?;
            edit = edit.delete_node(old_id)?;
        }
        let new_fork = edit.add_node(new_fork);
        if edit.get_schedule(fork).contains(&Schedule::ParallelFork) {
            edit = edit.add_schedule(new_fork, Schedule::ParallelFork)?;
        }
        edit = edit.replace_all_uses(fork, new_fork)?;
        edit = edit.delete_node(fork)?;
        new_fork_id = Some(new_fork);
        Ok(edit)
    });
    new_fork_id
}
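// Sketch of `fork_interchange` above: interchanging dimensions 0 and 1 of a
// fork with factors [m, n] yields factors [n, m], with the thread IDs for the
// two dimensions swapped; this is only attempted when every reduce is marked
// ParallelReduce or MonoidReduce.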
/*
 * Run fork unrolling on all fork-joins that are mutable in an editor.
 */
pub fn fork_unroll_all_forks(
    editor: &mut FunctionEditor,
    fork_joins: &HashMap<NodeID, NodeID>,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) {
    for (fork, join) in fork_joins {
        if editor.is_mutable(*fork) && fork_unroll(editor, *fork, *join, nodes_in_fork_joins) {
            break;
        }
    }
}

pub fn fork_unroll(
    editor: &mut FunctionEditor,
    fork: NodeID,
    join: NodeID,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) -> bool {
    // We can only unroll fork-joins with a compile time known factor list. For
    // simplicity, just unroll fork-joins that have a single dimension.
    let nodes = &editor.func().nodes;
    let Node::Fork {
        control,
        ref factors,
    } = nodes[fork.idx()]
    else {
        panic!()
    };
    if factors.len() != 1 || editor.get_users(fork).count() != 2 {
        return false;
    }
    let DynamicConstant::Constant(cons) = *editor.get_dynamic_constant(factors[0]) else {
        return false;
    };
    let tid = editor
        .get_users(fork)
        .filter(|id| nodes[id.idx()].is_thread_id())
        .next()
        .unwrap();
    let inits: HashMap<NodeID, NodeID> = editor
        .get_users(join)
        .filter_map(|id| nodes[id.idx()].try_reduce().map(|(_, init, _)| (id, init)))
        .collect();

    editor.edit(|mut edit| {
        // Create a copy of the nodes in the fork-join per unrolled iteration,
        // excluding the fork itself, the join itself, the thread IDs of the
        // fork, and the reduces on the join. Keep a running tally of the top
        // control node and the current reduction value.
        let mut top_control = control;
        let mut current_reduces = inits;
        for iter in 0..cons {
            let iter_cons = edit.add_constant(Constant::UnsignedInteger64(iter as u64));
            let iter_tid = edit.add_node(Node::Constant { id: iter_cons });

            // First, add a copy of each node in the fork-join unmodified.
            // Record the mapping from old ID to new ID.
            let mut added_ids = HashSet::new();
            let mut old_to_new_ids = HashMap::new();
            let mut new_control = None;
            let mut new_reduces = HashMap::new();
            for node in nodes_in_fork_joins[&fork].iter() {
                if *node == fork {
                    old_to_new_ids.insert(*node, top_control);
                } else if *node == join {
                    new_control = Some(edit.get_node(*node).try_join().unwrap());
                } else if *node == tid {
                    old_to_new_ids.insert(*node, iter_tid);
                } else if let Some(current) = current_reduces.get(node) {
                    old_to_new_ids.insert(*node, *current);
                    new_reduces.insert(*node, edit.get_node(*node).try_reduce().unwrap().2);
                } else {
                    let new_node = edit.add_node(edit.get_node(*node).clone());
                    old_to_new_ids.insert(*node, new_node);
                    added_ids.insert(new_node);
                }
            }

            // Second, replace all the uses in the just added nodes.
            if let Some(new_control) = new_control {
                top_control = old_to_new_ids[&new_control];
            }
            for (reduce, reduct) in new_reduces {
                current_reduces.insert(reduce, old_to_new_ids[&reduct]);
            }
            for (old, new) in old_to_new_ids {
                edit = edit.replace_all_uses_where(old, new, |id| added_ids.contains(id))?;
            }
        }

        // Hook up the control and reduce outputs to the rest of the function.
        edit = edit.replace_all_uses(join, top_control)?;
        for (reduce, reduct) in current_reduces {
            edit = edit.replace_all_uses(reduce, reduct)?;
        }

        // Delete the old fork-join.
        for node in nodes_in_fork_joins[&fork].iter() {
            edit = edit.delete_node(*node)?;
        }

        Ok(edit)
    })
}
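// Sketch of `fork_unroll` above: unrolling a single-dimension fork with
// constant factor 2 stamps out two copies of the body, with the thread ID
// replaced by the constants 0 and 1, control chained through the copies, and
// each reduce replaced by a running value threaded from its init through both
// iterations.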
/*
 * Looks for fork-joins that are next to each other, not inter-dependent, and
 * have the same bounds. These fork-joins can be fused, pooling together all
 * their reductions.
 */
pub fn fork_fusion_all_forks(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) {
    for (fork, join) in fork_join_map {
        if editor.is_mutable(*fork)
            && fork_fusion(editor, *fork, *join, fork_join_map, nodes_in_fork_joins)
        {
            break;
        }
    }
}

/*
 * Tries to fuse a given fork-join with the immediately following fork-join, if
 * it exists.
 */
fn fork_fusion(
    editor: &mut FunctionEditor,
    top_fork: NodeID,
    top_join: NodeID,
    fork_join_map: &HashMap<NodeID, NodeID>,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) -> bool {
    let nodes = &editor.func().nodes;
    // Rust operator precedence is not such that these can be put in one big
    // let-else statement. Sad!
    let Some(bottom_fork) = editor
        .get_users(top_join)
        .filter(|id| nodes[id.idx()].is_control())
        .next()
    else {
        return false;
    };
    let Some(bottom_join) = fork_join_map.get(&bottom_fork) else {
        return false;
    };
    let (_, top_factors) = nodes[top_fork.idx()].try_fork().unwrap();
    let (bottom_fork_pred, bottom_factors) = nodes[bottom_fork.idx()].try_fork().unwrap();
    assert_eq!(bottom_fork_pred, top_join);
    let top_join_pred = nodes[top_join.idx()].try_join().unwrap();
    let bottom_join_pred = nodes[bottom_join.idx()].try_join().unwrap();

    // The fork factors must be identical.
    if top_factors != bottom_factors {
        return false;
    }

    // Check that no iterated users of the top's reduces are in the bottom fork-
    // join (iteration stops at a phi or reduce outside the bottom fork-join).
    for reduce in editor
        .get_users(top_join)
        .filter(|id| nodes[id.idx()].is_reduce())
    {
        let mut visited = HashSet::new();
        visited.insert(reduce);
        let mut workset = vec![reduce];
        while let Some(pop) = workset.pop() {
            for u in editor.get_users(pop) {
                if nodes_in_fork_joins[&bottom_fork].contains(&u) {
                    return false;
                } else if (nodes[u.idx()].is_phi() || nodes[u.idx()].is_reduce())
                    && !nodes_in_fork_joins[&top_fork].contains(&u)
                {
                    // Stop the traversal here.
                } else if !visited.contains(&u) && !nodes_in_fork_joins[&top_fork].contains(&u) {
                    visited.insert(u);
                    workset.push(u);
                }
            }
        }
    }

    // Perform the fusion.
    let bottom_tids: Vec<_> = editor
        .get_users(bottom_fork)
        .filter(|id| nodes[id.idx()].is_thread_id())
        .collect();
    editor.edit(|mut edit| {
        edit = edit.replace_all_uses_where(bottom_fork, top_fork, |id| bottom_tids.contains(id))?;
        if bottom_join_pred != bottom_fork {
            // If there is control flow in the bottom fork-join, stitch it into
            // the top fork-join.
            edit = edit.replace_all_uses_where(bottom_fork, top_join_pred, |id| {
                nodes_in_fork_joins[&bottom_fork].contains(id)
            })?;
            edit =
                edit.replace_all_uses_where(top_join_pred, bottom_join_pred, |id| *id == top_join)?;
        }
        // Replace the bottom fork and join with the top fork and join.
        edit = edit.replace_all_uses(bottom_fork, top_fork)?;
        edit = edit.replace_all_uses(*bottom_join, top_join)?;
        edit = edit.delete_node(bottom_fork)?;
        edit = edit.delete_node(*bottom_join)?;
        Ok(edit)
    })
}
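// Sketch of `fork_fusion` above: two immediately adjacent fork-joins with
// identical factors [n] fuse into one; the bottom fork's thread IDs are
// re-pointed at the top fork, any control inside the bottom fork-join is
// stitched into the top one, and the bottom fork and join are deleted.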
/*
 * Looks for monoid reductions where the initial input is not the identity
 * element, and converts them into a form whose initial input is an identity
 * element. This aids in parallelizing outer loops. Looks only at reduces with
 * the monoid reduce schedule, since that indicates a particular structure which
 * is annoying to check for again.
 *
 * Also looks for would-be monoid reduces that are blocked only by a gate on the
 * reduction, and partially predicates the gated reduction to allow for a proper
 * monoid reduction.
 */
pub fn clean_monoid_reduces(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
    for id in editor.node_ids() {
        if !editor.func().schedules[id.idx()].contains(&Schedule::MonoidReduce) {
            continue;
        }
        let nodes = &editor.func().nodes;
        let Some((_, init, reduct)) = nodes[id.idx()].try_reduce() else {
            continue;
        };
        let out_users: Vec<_> = editor.get_users(id).filter(|id| *id != reduct).collect();

        match nodes[reduct.idx()] {
            Node::Binary {
                op,
                left: _,
                right: _,
            } if (op == BinaryOperator::Add || op == BinaryOperator::Or)
                && !is_zero(editor, init)
                && !is_false(editor, init) =>
            {
                editor.edit(|mut edit| {
                    let zero = edit.add_zero_constant(typing[init.idx()]);
                    let zero = edit.add_node(Node::Constant { id: zero });
                    edit.sub_edit(id, zero);
                    edit = edit.replace_all_uses_where(init, zero, |u| *u == id)?;
                    let final_op = edit.add_node(Node::Binary {
                        op,
                        left: init,
                        right: id,
                    });
                    for u in out_users {
                        edit.sub_edit(u, final_op);
                    }
                    edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
                });
            }
            Node::Binary {
                op,
                left: _,
                right: _,
            } if (op == BinaryOperator::Mul || op == BinaryOperator::And)
                && !is_one(editor, init)
                && !is_true(editor, init) =>
            {
                editor.edit(|mut edit| {
                    let one = edit.add_one_constant(typing[init.idx()]);
                    let one = edit.add_node(Node::Constant { id: one });
                    edit.sub_edit(id, one);
                    edit = edit.replace_all_uses_where(init, one, |u| *u == id)?;
                    let final_op = edit.add_node(Node::Binary {
                        op,
                        left: init,
                        right: id,
                    });
                    for u in out_users {
                        edit.sub_edit(u, final_op);
                    }
                    edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
                });
            }
            Node::IntrinsicCall {
                intrinsic: Intrinsic::Max,
                args: _,
            } if !is_smallest(editor, init) => {
                editor.edit(|mut edit| {
                    let smallest = edit.add_smallest_constant(typing[init.idx()]);
                    let smallest = edit.add_node(Node::Constant { id: smallest });
                    edit.sub_edit(id, smallest);
                    edit = edit.replace_all_uses_where(init, smallest, |u| *u == id)?;
                    let final_op = edit.add_node(Node::IntrinsicCall {
                        intrinsic: Intrinsic::Max,
                        args: Box::new([init, id]),
                    });
                    for u in out_users {
                        edit.sub_edit(u, final_op);
                    }
                    edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
                });
            }
            Node::IntrinsicCall {
                intrinsic: Intrinsic::Min,
                args: _,
            } if !is_largest(editor, init) => {
                editor.edit(|mut edit| {
                    let largest = edit.add_largest_constant(typing[init.idx()]);
                    let largest = edit.add_node(Node::Constant { id: largest });
                    edit.sub_edit(id, largest);
                    edit = edit.replace_all_uses_where(init, largest, |u| *u == id)?;
                    let final_op = edit.add_node(Node::IntrinsicCall {
                        intrinsic: Intrinsic::Min,
                        args: Box::new([init, id]),
                    });
                    for u in out_users {
                        edit.sub_edit(u, final_op);
                    }
                    edit.replace_all_uses_where(id, final_op, |u| *u != reduct && *u != final_op)
                });
            }
            _ => {}
        }
    }

    for id in editor.node_ids() {
        if !editor.func().schedules[id.idx()].contains(&Schedule::MonoidReduce) {
            continue;
        }
        let nodes = &editor.func().nodes;
        let Some((control, init, reduct)) = nodes[id.idx()].try_reduce() else {
            continue;
        };

        if let Node::Phi {
            control: phi_control,
            ref data,
        } = nodes[reduct.idx()]
            && data.len() == 2
            && data.contains(&id)
            && let other = *data
                .into_iter()
                .filter(|other| **other != id)
                .next()
                .unwrap()
            && let Node::Binary {
                op: BinaryOperator::Add,
                left,
                right,
            } = nodes[other.idx()]
            && ((left == id) ^ (right == id))
        {
            let gated_input = if left == id { right } else { left };
            let data = data.clone();
            editor.edit(|mut edit| {
                let zero = edit.add_zero_constant(typing[id.idx()]);
                let zero = edit.add_node(Node::Constant { id: zero });
                let phi = edit.add_node(Node::Phi {
                    control: phi_control,
                    data: data
                        .iter()
                        .map(|phi_use| if *phi_use == id { zero } else { gated_input })
                        .collect(),
                });
                let new_reduce_id = NodeID::new(edit.num_node_ids());
                let new_reduct_id = NodeID::new(edit.num_node_ids() + 1);
                let new_reduce = Node::Reduce {
                    control,
                    init,
                    reduct: new_reduct_id,
                };
                let new_add = Node::Binary {
                    op: BinaryOperator::Add,
                    left: new_reduce_id,
                    right: phi,
                };
                let new_reduce = edit.add_node(new_reduce);
                edit.add_node(new_add);
                edit = edit.replace_all_uses(id, new_reduce)?;
                edit = edit.delete_node(id)?;
                Ok(edit)
            });
        }
    }
}
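// Worked example for `clean_monoid_reduces` above: a sum reduce seeded with
// init = 5 is rewritten so the reduce itself starts from 0, and a single
// `5 + result` binary node is applied to the users outside the loop; the
// in-loop reduction then operates over a proper monoid (addition with
// identity 0).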
/*
 * Extends the dimensions of a fork-join to be a multiple of a number and gates
 * the execution of the body.
 */
pub fn extend_all_forks(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    multiple: usize,
) {
    for (fork, join) in fork_join_map {
        if editor.is_mutable(*fork) {
            extend_fork(editor, *fork, *join, multiple);
        }
    }
}
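// Sketch of `extend_fork` below: extending a fork with factors [10] to a
// multiple of 4 rebuilds it with factors [12]; the body is gated by
// `tid < 10`, and each reduce is fed through a phi on the gating region so the
// padded iterations leave the reduction value unchanged.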
fn extend_fork(editor: &mut FunctionEditor, fork: NodeID, join: NodeID, multiple: usize) {
    let nodes = &editor.func().nodes;
    let (fork_pred, factors) = nodes[fork.idx()].try_fork().unwrap();
    let factors = factors.to_vec();
    let fork_succ = editor
        .get_users(fork)
        .filter(|id| nodes[id.idx()].is_control())
        .next()
        .unwrap();
    let join_pred = nodes[join.idx()].try_join().unwrap();
    let ctrl_between = fork != join_pred;
    let reduces: Vec<_> = editor
        .get_users(join)
        .filter_map(|id| nodes[id.idx()].try_reduce().map(|x| (id, x)))
        .collect();

    editor.edit(|mut edit| {
        // We can round up a dynamic constant A to a multiple of another dynamic
        // constant B via the following math:
        // ((A + B - 1) / B) * B
        let new_factors: Vec<_> = factors
            .iter()
            .map(|factor| {
                let b = edit.add_dynamic_constant(DynamicConstant::Constant(multiple));
                let apb = edit.add_dynamic_constant(DynamicConstant::add(*factor, b));
                let o = edit.add_dynamic_constant(DynamicConstant::Constant(1));
                let apbmo = edit.add_dynamic_constant(DynamicConstant::sub(apb, o));
                let apbmodb = edit.add_dynamic_constant(DynamicConstant::div(apbmo, b));
                edit.add_dynamic_constant(DynamicConstant::mul(apbmodb, b))
            })
            .collect();

        // Create the new control structure.
        let new_fork = edit.add_node(Node::Fork {
            control: fork_pred,
            factors: new_factors.into_boxed_slice(),
        });
        edit = edit.replace_all_uses_where(fork, new_fork, |id| *id != fork_succ)?;
        edit.sub_edit(fork, new_fork);
        let conds: Vec<_> = factors
            .iter()
            .enumerate()
            .map(|(idx, old_factor)| {
                let tid = edit.add_node(Node::ThreadID {
                    control: new_fork,
                    dimension: idx,
                });
                edit.sub_edit(fork, tid);
                let old_bound = edit.add_node(Node::DynamicConstant { id: *old_factor });
                edit.add_node(Node::Binary {
                    op: BinaryOperator::LT,
                    left: tid,
                    right: old_bound,
                })
            })
            .collect();
        let cond = conds
            .into_iter()
            .reduce(|left, right| {
                edit.add_node(Node::Binary {
                    op: BinaryOperator::And,
                    left,
                    right,
                })
            })
            .unwrap();
        let branch = edit.add_node(Node::If {
            control: new_fork,
            cond,
        });
        let false_proj = edit.add_node(Node::ControlProjection {
            control: branch,
            selection: 0,
        });
        let true_proj = edit.add_node(Node::ControlProjection {
            control: branch,
            selection: 1,
        });
        if ctrl_between {
            edit = edit.replace_all_uses_where(fork, true_proj, |id| *id == fork_succ)?;
        }
        let bottom_region = edit.add_node(Node::Region {
            preds: Box::new([false_proj, if ctrl_between { join_pred } else { true_proj }]),
        });
        let new_join = edit.add_node(Node::Join {
            control: bottom_region,
        });
        edit = edit.replace_all_uses(join, new_join)?;
        edit.sub_edit(join, new_join);
        edit = edit.delete_node(fork)?;
        edit = edit.delete_node(join)?;

        // Update the reduces to use phis on the region node to gate their
        // execution.
        for (reduce, (_, init, reduct)) in reduces {
            let phi = edit.add_node(Node::Phi {
                control: bottom_region,
                data: Box::new([reduce, reduct]),
            });
            let new_reduce = edit.add_node(Node::Reduce {
                control: new_join,
                init,
                reduct: phi,
            });
            edit = edit.replace_all_uses(reduce, new_reduce)?;
            edit.sub_edit(reduce, new_reduce);
            edit = edit.delete_node(reduce)?;
        }

        Ok(edit)
    });
}
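// A minimal sketch (not part of the pass) checking the round-up arithmetic
// used in `extend_fork`, on plain integers rather than dynamic constants:
#[cfg(test)]
mod extend_fork_math_sketch {
    // ((A + B - 1) / B) * B rounds A up to a multiple of B.
    fn round_up(a: usize, b: usize) -> usize {
        ((a + b - 1) / b) * b
    }

    #[test]
    fn round_up_examples() {
        assert_eq!(round_up(10, 4), 12);
        assert_eq!(round_up(12, 4), 12);
        assert_eq!(round_up(1, 4), 4);
    }
}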