diff --git a/hercules_opt/src/fork_guard_elim.rs b/hercules_opt/src/fork_guard_elim.rs new file mode 100644 index 0000000000000000000000000000000000000000..0ff3f785ed82da60316da94d2a63a3ffcf06dfd3 --- /dev/null +++ b/hercules_opt/src/fork_guard_elim.rs @@ -0,0 +1,185 @@ +extern crate hercules_ir; + +use std::collections::{HashMap, HashSet}; + +use self::hercules_ir::ir::*; +use self::hercules_ir::ImmutableDefUseMap; +use self::hercules_ir::get_uses_mut; + +/* + * This is a Hercules IR transformation that: + * - Eliminates guards (directly) surrounding fork-joins when the guard's + * condition is of the form 0 < n for a fork with replication factor n, + * and when the initial inputs to the reduce nodes and phi nodes they feed + * into (for the region that joins control back between the guard's two + * branches) are the same + * + * This optimization is useful with code generated by the Juno frontend as it + * generates guarded loops which are eventually converted into forks but the + * guard remains and in these cases the guard is no longer needed. + */ + +/* Given a node index and the node itself, return None if the node is not + * a guarded fork where we can eliminate the guard. + * If the node is a fork with a guard we can eliminate returns a tuple of + * - This node's NodeID + * - The replication factor of the fork + * - The ID of the if of the guard + * - The ID of the projections of the if + * - The guard's predecessor + * - A map of NodeIDs for the phi nodes to the reduce they should be replaced + * with, and also the region that joins the guard's branches mapping to the + * fork's join NodeID + */ +fn guarded_fork(function: &Function, + constants: &Vec<Constant>, + fork_join_map: &HashMap<NodeID, NodeID>, + def_use: &ImmutableDefUseMap, + index: usize, + node: &Node) + -> Option<(NodeID, DynamicConstantID, NodeID, NodeID, NodeID, + NodeID, HashMap<NodeID, NodeID>)> { + // Identify fork nodes + let Node::Fork { control, factor } = node else { return None; }; + // Whose predecessor is a read from an if + let Node::Read { collect : if_node, ref indices } + = function.nodes[control.idx()] else { return None; }; + if indices.len() != 1 { return None; } + let Index::Control(branchIdx) = indices[0] else { return None; }; + let Node::If { control : if_pred, cond } = function.nodes[if_node.idx()] + else { return None; }; + + // Whose condition is appropriate + let Node::Binary { left, right, op } = function.nodes[cond.idx()] + else { return None; }; + // branchIdx == 1 means the true branch so we want the condition to be + // 0 < n or n > 0 + if branchIdx == 1 + && !((op == BinaryOperator::LT && function.nodes[left.idx()].is_zero_constant(constants) + && function.nodes[right.idx()].try_dynamic_constant() == Some(*factor)) + || (op == BinaryOperator::GT && function.nodes[right.idx()].is_zero_constant(constants) + && function.nodes[left.idx()].try_dynamic_constant() == Some(*factor))) { + return None; + } + // branchIdx == 0 means the false branch so we want the condition to be + // n < 0 or 0 > n + if branchIdx == 0 + && !((op == BinaryOperator::LT && function.nodes[left.idx()].try_dynamic_constant() == Some(*factor) + && function.nodes[right.idx()].is_zero_constant(constants)) + || (op == BinaryOperator::GT && function.nodes[right.idx()].try_dynamic_constant() == Some(*factor) + && function.nodes[left.idx()].is_zero_constant(constants))) { + return None; + } + + // Identify the join node and its users + let join_id = fork_join_map.get(&NodeID::new(index))?; + let join_users = def_use.get_users(*join_id); + + // Find the unique control use of the join; if it's not a region we can't + // eliminate this guard + let join_control = join_users.iter().filter(|n| function.nodes[n.idx()].is_region()) + .collect::<Vec<_>>(); + if join_control.len() != 1 { return None; } + + let join_control = join_control[0]; + let Some(Node::Region { preds }) = function.nodes.get(join_control.idx()) + else { return None; }; + + // The region after the join can only have two predecessors (for the guard + // and the fork-join) + if preds.len() != 2 { return None; } + let other_pred = + if preds[1] == *join_id { + preds[0] + } else if preds[0] == *join_id { + preds[1] + } else { + return None; + }; + // Other predecessor needs to be the other read from the guard's if + let Node::Read { collect : read_control, ref indices } + = function.nodes[other_pred.idx()] + else { return None; }; + if indices.len() != 1 { return None; } + let Index::Control(elseBranch) = indices[0] else { return None; }; + if elseBranch == branchIdx { return None; } + if read_control != if_node { return None; } + + // Finally, identify the phi nodes associated with the region and match + // them with the reduce nodes of the fork-join + let reduce_nodes = join_users.iter().filter(|n| function.nodes[n.idx()].is_reduce()) + .collect::<HashSet<_>>(); + // Construct a map from phi nodes indices to the reduce node index + let phi_nodes = def_use.get_users(*join_control).iter() + .filter_map(|n| { + let Node::Phi { control : _, ref data } = function.nodes[n.idx()] + else { return None; }; + if data.len() != 2 { return Some((*n, None)); } + let (init_idx, reduce_node) + = if reduce_nodes.contains(&data[0]) { + (1, data[0]) + } else if reduce_nodes.contains(&data[1]) { + (0, data[1]) + } else { return Some((*n, None)); }; + let Node::Reduce { control : _, init, .. } + = function.nodes[reduce_node.idx()] + else { return Some((*n, None)); }; + if data[init_idx] != init { return Some((*n, None)); } + Some((*n, Some(reduce_node))) + }) + .collect::<HashMap<_, _>>(); + + // If any of the phi nodes do not have an associated reduce node, we cannot + // remove the loop guard + if phi_nodes.iter().any(|(_, red)| red.is_none()) { return None; } + + let mut phi_nodes = phi_nodes.into_iter() + .map(|(phi, red)| (phi, red.unwrap())) + .collect::<HashMap<_, _>>(); + + // We also add a map from the region to the join to this map so we only + // need one map to handle all node replacements in the elimination process + phi_nodes.insert(*join_control, *join_id); + + // Finally, we return this node's index along with + // - The replication factor of the fork + // - The if node + // - The true and false reads of the if + // - The guard's predecessor + // - The map from phi nodes to reduce nodes and the region to the join + Some ((NodeID::new(index), *factor, if_node, *control, other_pred, if_pred, phi_nodes)) +} + +/* + * Top level function to run fork guard elimination, as described above. + * Deletes nodes by setting nodes to gravestones. Works with a function already + * containing gravestones. + */ +pub fn fork_guard_elim(function: &mut Function, + constants: &Vec<Constant>, + fork_join_map: &HashMap<NodeID, NodeID>, + def_use: &ImmutableDefUseMap) { + let guard_info = function.nodes.iter().enumerate() + .filter_map(|(i, n)| guarded_fork(function, constants, + fork_join_map, + def_use, i, n)) + .collect::<Vec<_>>(); + + for (fork_node, factor, guard_node, guard_proj1, guard_proj2, guard_pred, map) in guard_info { + function.nodes[guard_node.idx()] = Node::Start; + function.nodes[guard_proj1.idx()] = Node::Start; + function.nodes[guard_proj2.idx()] = Node::Start; + function.nodes[fork_node.idx()] + = Node::Fork { control : guard_pred, factor : factor}; + + for (idx, node) in function.nodes.iter_mut().enumerate() { + let node_idx = NodeID::new(idx); + if map.contains_key(&node_idx) { *node = Node::Start; } + for u in get_uses_mut(node).as_mut() { + if let Some(replacement) = map.get(u) { + **u = *replacement; + } + } + } + } +} diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs index 27a9cfed113151a177312269e8824b3d3a483696..f336e8d9e9c6313159e2a9a18906bae968447b4e 100644 --- a/hercules_opt/src/lib.rs +++ b/hercules_opt/src/lib.rs @@ -2,6 +2,7 @@ pub mod ccp; pub mod dce; +pub mod fork_guard_elim; pub mod forkify; pub mod gvn; pub mod phi_elim; @@ -13,5 +14,6 @@ pub use crate::dce::*; pub use crate::forkify::*; pub use crate::gvn::*; pub use crate::phi_elim::*; +pub use crate::fork_guard_elim::*; pub use crate::pass::*; pub use crate::pred::*; diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs index ef79b36e6c667436e48bee1edc3da8950a55b7f1..bb6026069ee807fc4ac31cfede8d308042a4f1c3 100644 --- a/hercules_opt/src/pass.rs +++ b/hercules_opt/src/pass.rs @@ -27,6 +27,7 @@ pub enum Pass { GVN, Forkify, PhiElim, + ForkGuardElim, Predication, Verify, // Parameterized over whether analyses that aid visualization are necessary. @@ -327,6 +328,20 @@ impl PassManager { phi_elim(function); } } + Pass::ForkGuardElim => { + self.make_def_uses(); + self.make_fork_join_maps(); + let def_uses = self.def_uses.as_ref().unwrap(); + let fork_join_maps = self.fork_join_maps.as_ref().unwrap(); + for idx in 0..self.module.functions.len() { + fork_guard_elim( + &mut self.module.functions[idx], + &self.module.constants, + &fork_join_maps[idx], + &def_uses[idx], + ) + } + } Pass::Predication => { self.make_def_uses(); self.make_reverse_postorders(); diff --git a/juno_frontend/src/main.rs b/juno_frontend/src/main.rs index f8e267eb3d2f2eb9bc8e7555cd313af0cc4c857d..33dccb0d26bc8d54aa1e438d0b4989d97b46eb11 100644 --- a/juno_frontend/src/main.rs +++ b/juno_frontend/src/main.rs @@ -35,7 +35,9 @@ fn main() { pm.add_pass(hercules_opt::pass::Pass::DCE); pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); pm.add_pass(hercules_opt::pass::Pass::Forkify); + pm.add_pass(hercules_opt::pass::Pass::ForkGuardElim); pm.add_pass(hercules_opt::pass::Pass::DCE); + pm.add_pass(hercules_opt::pass::Pass::Verify); pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); let _ = pm.run_passes(); },