use std::collections::{HashMap, HashSet}; use std::ops::Deref; use hercules_ir::*; use crate::*; /* * This is a Hercules IR transformation that: * - Eliminates guards (directly) surrounding fork-joins when the guard's * condition is of the form 0 < n for a fork with replication factor n, * and when the initial inputs to the reduce nodes and phi nodes they feed * into (for the region that joins control back between the guard's two * branches) are the same * * This optimization is useful with code generated by the Juno frontend as it * generates guarded loops which are eventually converted into forks but the * guard remains and in these cases the guard is no longer needed. */ // Simplify factors through max enum Factor { Max(usize, DynamicConstantID), Normal(DynamicConstantID), } impl Factor { fn get_id(&self) -> DynamicConstantID { match self { Factor::Max(_, dynamic_constant_id) => *dynamic_constant_id, Factor::Normal(dynamic_constant_id) => *dynamic_constant_id, } } } struct GuardedFork { fork: NodeID, join: NodeID, guard_if: NodeID, fork_taken_proj: NodeID, fork_skipped_proj: NodeID, guard_join_region: NodeID, phi_reduce_map: HashMap<NodeID, NodeID>, factor: Factor, // The factor that matches the guard } /* Given a node index and the node itself, return None if the node is not * a guarded fork where we can eliminate the guard. * If the node is a fork with a guard we can eliminate returns a tuple of * - This node's NodeID * - The replication factor of the fork * - The ID of the if of the guard * - The ID of the projections of the if * - The guard's predecessor * - A map of NodeIDs for the phi nodes to the reduce they should be replaced * with, and also the region that joins the guard's branches mapping to the * fork's join NodeID * - If the replication factor is a max that can be eliminated. */ fn guarded_fork( editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>, node: NodeID, ) -> Option<GuardedFork> { let function = editor.func(); // Identify fork nodes let Node::Fork { control, factors } = &function.nodes[node.idx()] else { return None; }; let mut factors = factors.iter().enumerate().map(|(idx, dc)| { let factor = editor.get_dynamic_constant(*dc); let DynamicConstant::Max(xs) = factor.deref() else { return Factor::Normal(*dc); }; // Filter out any terms which are just 1s let non_ones = xs .iter() .filter(|i| { if let DynamicConstant::Constant(1) = editor.get_dynamic_constant(**i).deref() { false } else { true } }) .collect::<Vec<_>>(); // If we're left with just one term x, we had max { 1, x } if non_ones.len() == 1 { Factor::Max(idx, *non_ones[0]) } else { Factor::Normal(*dc) } }); // Whose predecessor is a read from an if let Node::ControlProjection { control: if_node, ref selection, } = function.nodes[control.idx()] else { return None; }; let Node::If { control: if_pred, cond, } = function.nodes[if_node.idx()] else { return None; }; // Whose condition is appropriate let Node::Binary { left, right, op } = function.nodes[cond.idx()] else { return None; }; let branch_idx = *selection; let factor = { // branchIdx == 1 means the true branch so we want the condition to be // 0 < n or n > 0 if branch_idx == 1 { [ (left, BinaryOperator::LT, right), (right, BinaryOperator::GT, left), ] .iter() .find_map(|(pattern_zero, pattern_op, pattern_factor)| { // Match Op if op != *pattern_op { return None; } // Match Zero if !(function.nodes[pattern_zero.idx()].is_zero_constant(&editor.get_constants()) || editor .node(pattern_zero) .is_zero_dc(&editor.get_dynamic_constants())) { return None; } // Match Factor let factor = factors.find(|factor| { match ( &function.nodes[pattern_factor.idx()], &*editor.get_dynamic_constant(factor.get_id()), ) { (Node::Constant { id }, DynamicConstant::Constant(v)) => { let Constant::UnsignedInteger64(pattern_v) = *editor.get_constant(*id) else { return false; }; pattern_v == (*v as u64) } (Node::DynamicConstant { id }, _) => *id == factor.get_id(), _ => false, } }); factor }) } // branchIdx == 0 means the false branch so we want the condition to be // n < 0 or 0 > n else if branch_idx == 0 { [ (right, BinaryOperator::LT, left), (left, BinaryOperator::GT, right), ] .iter() .find_map(|(pattern_zero, pattern_op, pattern_factor)| { // Match Op if op != *pattern_op { return None; } // Match Zero if !(function.nodes[pattern_zero.idx()].is_zero_constant(&editor.get_constants()) || editor .node(pattern_zero) .is_zero_dc(&editor.get_dynamic_constants())) { return None; } // Match Factor let factor = factors.find(|factor| { function.nodes[pattern_factor.idx()].try_dynamic_constant() == Some(factor.get_id()) }); factor }) } else { None } }; let Some(factor) = factor else { return None }; // Identify the join node and its users let join_id = fork_join_map.get(&node)?; // Find the unique control use of the join; if it's not a region we can't // eliminate this guard let join_control = editor .get_users(*join_id) .filter(|n| function.nodes[n.idx()].is_region()) .collect::<Vec<_>>(); if join_control.len() != 1 { return None; } let join_control = join_control[0]; let Some(Node::Region { preds }) = function.nodes.get(join_control.idx()) else { return None; }; // The region after the join can only have two predecessors (for the guard // and the fork-join) if preds.len() != 2 { return None; } let other_pred = if preds[1] == *join_id { preds[0] } else if preds[0] == *join_id { preds[1] } else { return None; }; // Other predecessor needs to be the other projection from the guard's if let Node::ControlProjection { control: if_node2, ref selection, } = function.nodes[other_pred.idx()] else { return None; }; let else_branch = *selection; if else_branch == branch_idx { return None; } if if_node2 != if_node { return None; } // Finally, identify the phi nodes associated with the region and match // them with the reduce nodes of the fork-join let reduce_nodes = editor .get_users(*join_id) .filter(|n| function.nodes[n.idx()].is_reduce()) .collect::<HashSet<_>>(); // Construct a map from phi nodes indices to the reduce node index let phi_nodes = editor .get_users(join_control) .filter_map(|n| { let Node::Phi { control: _, ref data, } = function.nodes[n.idx()] else { return None; }; if data.len() != 2 { return Some((n, None)); } let (init_idx, reduce_node) = if reduce_nodes.contains(&data[0]) { (1, data[0]) } else if reduce_nodes.contains(&data[1]) { (0, data[1]) } else { return Some((n, None)); }; let Node::Reduce { control: _, init, .. } = function.nodes[reduce_node.idx()] else { return Some((n, None)); }; if data[init_idx] != init { return Some((n, None)); } Some((n, Some(reduce_node))) }) .collect::<HashMap<_, _>>(); // If any of the phi nodes do not have an associated reduce node, we cannot // remove the loop guard if phi_nodes.iter().any(|(_, red)| red.is_none()) { return None; } let phi_nodes = phi_nodes .into_iter() .map(|(phi, red)| (phi, red.unwrap())) .collect::<HashMap<_, _>>(); // Finally, we return this node's index along with // - The replication factor of the fork // - The if node // - The true and false reads of the if // - The guard's predecessor // - The map from phi nodes to reduce nodes and the region to the join Some(GuardedFork { fork: node, join: *join_id, guard_if: if_node, fork_taken_proj: *control, fork_skipped_proj: other_pred, guard_join_region: join_control, phi_reduce_map: phi_nodes, factor, }) } /* * Top level function to run fork guard elimination, as described above. */ pub fn fork_guard_elim(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) { let guard_info = editor .node_ids() .filter_map(|node| guarded_fork(editor, fork_join_map, node)) .collect::<Vec<_>>(); for GuardedFork { fork, join, fork_taken_proj, fork_skipped_proj, phi_reduce_map, factor, guard_if, guard_join_region, } in guard_info { let Some(guard_pred) = editor.get_uses(guard_if).next() else { unreachable!() }; let new_fork_info = if let Factor::Max(idx, dc) = factor { let Node::Fork { control: _, mut factors, } = editor.func().nodes[fork.idx()].clone() else { unreachable!() }; factors[idx] = dc; let new_fork = Node::Fork { control: guard_pred, factors, }; Some(new_fork) } else { None }; editor.edit(|mut edit| { edit = edit.replace_all_uses_where(fork_taken_proj, guard_pred, |usee| *usee == fork)?; edit = edit.delete_node(guard_if)?; edit = edit.delete_node(fork_taken_proj)?; edit = edit.delete_node(fork_skipped_proj)?; edit = edit.replace_all_uses(guard_join_region, join)?; edit = edit.delete_node(guard_join_region)?; // Delete region node for (phi, reduce) in phi_reduce_map.iter() { edit = edit.replace_all_uses(*phi, *reduce)?; edit = edit.delete_node(*phi)?; } if let Some(new_fork_info) = new_fork_info { let new_fork = edit.add_node(new_fork_info); edit = edit.replace_all_uses(fork, new_fork)?; edit = edit.delete_node(fork)?; } Ok(edit) }); } }