Skip to content
Snippets Groups Projects
Commit fde9f419 authored by Aaron Councilman's avatar Aaron Councilman
Browse files

Fork Guard Elimination

parent ccea5a50
No related branches found
No related tags found
1 merge request!21Fork Guard Elimination
extern crate hercules_ir;
use std::collections::{HashMap, HashSet};
use self::hercules_ir::ir::*;
use self::hercules_ir::ImmutableDefUseMap;
use self::hercules_ir::get_uses_mut;
/*
* This is a Hercules IR transformation that:
* - Eliminates guards (directly) surrounding fork-joins when the guard's
* condition is of the form 0 < n for a fork with replication factor n,
* and when the initial inputs to the reduce nodes and phi nodes they feed
* into (for the region that joins control back between the guard's two
* branches) are the same
*
* This optimization is useful with code generated by the Juno frontend as it
* generates guarded loops which are eventually converted into forks but the
* guard remains and in these cases the guard is no longer needed.
*/
/* Given a node index and the node itself, return None if the node is not
* a guarded fork where we can eliminate the guard.
* If the node is a fork with a guard we can eliminate returns a tuple of
* - This node's NodeID
* - The replication factor of the fork
* - The ID of the if of the guard
* - The ID of the projections of the if
* - The guard's predecessor
* - A map of NodeIDs for the phi nodes to the reduce they should be replaced
* with, and also the region that joins the guard's branches mapping to the
* fork's join NodeID
*/
fn guarded_fork(function: &Function,
constants: &Vec<Constant>,
fork_join_map: &HashMap<NodeID, NodeID>,
def_use: &ImmutableDefUseMap,
index: usize,
node: &Node)
-> Option<(NodeID, DynamicConstantID, NodeID, NodeID, NodeID,
NodeID, HashMap<NodeID, NodeID>)> {
// Identify fork nodes
let Node::Fork { control, factor } = node else { return None; };
// Whose predecessor is a read from an if
let Node::Read { collect : if_node, ref indices }
= function.nodes[control.idx()] else { return None; };
if indices.len() != 1 { return None; }
let Index::Control(branchIdx) = indices[0] else { return None; };
let Node::If { control : if_pred, cond } = function.nodes[if_node.idx()]
else { return None; };
// Whose condition is appropriate
let Node::Binary { left, right, op } = function.nodes[cond.idx()]
else { return None; };
// branchIdx == 1 means the true branch so we want the condition to be
// 0 < n or n > 0
if branchIdx == 1
&& !((op == BinaryOperator::LT && function.nodes[left.idx()].is_zero_constant(constants)
&& function.nodes[right.idx()].try_dynamic_constant() == Some(*factor))
|| (op == BinaryOperator::GT && function.nodes[right.idx()].is_zero_constant(constants)
&& function.nodes[left.idx()].try_dynamic_constant() == Some(*factor))) {
return None;
}
// branchIdx == 0 means the false branch so we want the condition to be
// n < 0 or 0 > n
if branchIdx == 0
&& !((op == BinaryOperator::LT && function.nodes[left.idx()].try_dynamic_constant() == Some(*factor)
&& function.nodes[right.idx()].is_zero_constant(constants))
|| (op == BinaryOperator::GT && function.nodes[right.idx()].try_dynamic_constant() == Some(*factor)
&& function.nodes[left.idx()].is_zero_constant(constants))) {
return None;
}
// Identify the join node and its users
let join_id = fork_join_map.get(&NodeID::new(index))?;
let join_users = def_use.get_users(*join_id);
// Find the unique control use of the join; if it's not a region we can't
// eliminate this guard
let join_control = join_users.iter().filter(|n| function.nodes[n.idx()].is_region())
.collect::<Vec<_>>();
if join_control.len() != 1 { return None; }
let join_control = join_control[0];
let Some(Node::Region { preds }) = function.nodes.get(join_control.idx())
else { return None; };
// The region after the join can only have two predecessors (for the guard
// and the fork-join)
if preds.len() != 2 { return None; }
let other_pred =
if preds[1] == *join_id {
preds[0]
} else if preds[0] == *join_id {
preds[1]
} else {
return None;
};
// Other predecessor needs to be the other read from the guard's if
let Node::Read { collect : read_control, ref indices }
= function.nodes[other_pred.idx()]
else { return None; };
if indices.len() != 1 { return None; }
let Index::Control(elseBranch) = indices[0] else { return None; };
if elseBranch == branchIdx { return None; }
if read_control != if_node { return None; }
// Finally, identify the phi nodes associated with the region and match
// them with the reduce nodes of the fork-join
let reduce_nodes = join_users.iter().filter(|n| function.nodes[n.idx()].is_reduce())
.collect::<HashSet<_>>();
// Construct a map from phi nodes indices to the reduce node index
let phi_nodes = def_use.get_users(*join_control).iter()
.filter_map(|n| {
let Node::Phi { control : _, ref data } = function.nodes[n.idx()]
else { return None; };
if data.len() != 2 { return Some((*n, None)); }
let (init_idx, reduce_node)
= if reduce_nodes.contains(&data[0]) {
(1, data[0])
} else if reduce_nodes.contains(&data[1]) {
(0, data[1])
} else { return Some((*n, None)); };
let Node::Reduce { control : _, init, .. }
= function.nodes[reduce_node.idx()]
else { return Some((*n, None)); };
if data[init_idx] != init { return Some((*n, None)); }
Some((*n, Some(reduce_node)))
})
.collect::<HashMap<_, _>>();
// If any of the phi nodes do not have an associated reduce node, we cannot
// remove the loop guard
if phi_nodes.iter().any(|(_, red)| red.is_none()) { return None; }
let mut phi_nodes = phi_nodes.into_iter()
.map(|(phi, red)| (phi, red.unwrap()))
.collect::<HashMap<_, _>>();
// We also add a map from the region to the join to this map so we only
// need one map to handle all node replacements in the elimination process
phi_nodes.insert(*join_control, *join_id);
// Finally, we return this node's index along with
// - The replication factor of the fork
// - The if node
// - The true and false reads of the if
// - The guard's predecessor
// - The map from phi nodes to reduce nodes and the region to the join
Some ((NodeID::new(index), *factor, if_node, *control, other_pred, if_pred, phi_nodes))
}
/*
* Top level function to run fork guard elimination, as described above.
* Deletes nodes by setting nodes to gravestones. Works with a function already
* containing gravestones.
*/
pub fn fork_guard_elim(function: &mut Function,
constants: &Vec<Constant>,
fork_join_map: &HashMap<NodeID, NodeID>,
def_use: &ImmutableDefUseMap) {
let guard_info = function.nodes.iter().enumerate()
.filter_map(|(i, n)| guarded_fork(function, constants,
fork_join_map,
def_use, i, n))
.collect::<Vec<_>>();
for (fork_node, factor, guard_node, guard_proj1, guard_proj2, guard_pred, map) in guard_info {
function.nodes[guard_node.idx()] = Node::Start;
function.nodes[guard_proj1.idx()] = Node::Start;
function.nodes[guard_proj2.idx()] = Node::Start;
function.nodes[fork_node.idx()]
= Node::Fork { control : guard_pred, factor : factor};
for (idx, node) in function.nodes.iter_mut().enumerate() {
let node_idx = NodeID::new(idx);
if map.contains_key(&node_idx) { *node = Node::Start; }
for u in get_uses_mut(node).as_mut() {
if let Some(replacement) = map.get(u) {
**u = *replacement;
}
}
}
}
}
......@@ -2,6 +2,7 @@
pub mod ccp;
pub mod dce;
pub mod fork_guard_elim;
pub mod forkify;
pub mod gvn;
pub mod phi_elim;
......@@ -13,5 +14,6 @@ pub use crate::dce::*;
pub use crate::forkify::*;
pub use crate::gvn::*;
pub use crate::phi_elim::*;
pub use crate::fork_guard_elim::*;
pub use crate::pass::*;
pub use crate::pred::*;
......@@ -27,6 +27,7 @@ pub enum Pass {
GVN,
Forkify,
PhiElim,
ForkGuardElim,
Predication,
Verify,
// Parameterized over whether analyses that aid visualization are necessary.
......@@ -327,6 +328,20 @@ impl PassManager {
phi_elim(function);
}
}
Pass::ForkGuardElim => {
self.make_def_uses();
self.make_fork_join_maps();
let def_uses = self.def_uses.as_ref().unwrap();
let fork_join_maps = self.fork_join_maps.as_ref().unwrap();
for idx in 0..self.module.functions.len() {
fork_guard_elim(
&mut self.module.functions[idx],
&self.module.constants,
&fork_join_maps[idx],
&def_uses[idx],
)
}
}
Pass::Predication => {
self.make_def_uses();
self.make_reverse_postorders();
......
......@@ -35,7 +35,9 @@ fn main() {
pm.add_pass(hercules_opt::pass::Pass::DCE);
pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
pm.add_pass(hercules_opt::pass::Pass::Forkify);
pm.add_pass(hercules_opt::pass::Pass::ForkGuardElim);
pm.add_pass(hercules_opt::pass::Pass::DCE);
pm.add_pass(hercules_opt::pass::Pass::Verify);
pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
let _ = pm.run_passes();
},
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment