-
Aaron Councilman authoredAaron Councilman authored
fork_guard_elim.rs 11.76 KiB
use std::collections::{HashMap, HashSet};
use std::ops::Deref;
use hercules_ir::*;
use crate::*;
/*
* This is a Hercules IR transformation that:
* - Eliminates guards (directly) surrounding fork-joins when the guard's
* condition is of the form 0 < n for a fork with replication factor n,
* and when the initial inputs to the reduce nodes and phi nodes they feed
* into (for the region that joins control back between the guard's two
* branches) are the same
*
* This optimization is useful with code generated by the Juno frontend as it
* generates guarded loops which are eventually converted into forks but the
* guard remains and in these cases the guard is no longer needed.
*/
// Simplify factors through max
// A fork factor as seen by the guard matcher: either the factor's own dynamic
// constant, or the single interesting term of a max-shaped factor.
enum Factor {
// The fork factor at this index (first field) is a max whose only term other
// than the constant 1 is the given dynamic constant (second field)
Max(usize, DynamicConstantID),
// Any other factor; holds the factor's dynamic constant unchanged
Normal(DynamicConstantID),
}
impl Factor {
    // Returns the dynamic constant this factor should be compared against: the
    // simplified term for a `Max`, or the factor itself for a `Normal`.
    fn get_id(&self) -> DynamicConstantID {
        match *self {
            Factor::Max(_, id) | Factor::Normal(id) => id,
        }
    }
}
// Everything `guarded_fork` records about a fork whose surrounding guard can
// be eliminated; consumed by `fork_guard_elim` to perform the rewrite.
struct GuardedFork {
// The fork node
fork: NodeID,
// The join paired with the fork in the fork-join map
join: NodeID,
// The if node that forms the guard
guard_if: NodeID,
// The control projection of the guard's if that leads into the fork
fork_taken_proj: NodeID,
// The other control projection of the guard's if, which bypasses the fork
fork_skipped_proj: NodeID,
// The region whose two predecessors are the join and the skipping projection
guard_join_region: NodeID,
// Maps each phi on that region to the reduce node that replaces it
phi_reduce_map: HashMap<NodeID, NodeID>,
factor: Factor, // The factor that matches the guard
}
/*
 * Decides whether `node` is a fork whose directly surrounding guard can be
 * eliminated and, if so, gathers everything needed to remove that guard.
 *
 * Returns None unless all of the following hold:
 * - `node` is a fork whose control predecessor is a control projection of an
 *   if node (the guard).
 * - The guard's condition is a binary comparison between a zero (constant or
 *   dynamic constant) and one of the fork's factors, where a factor that is a
 *   max with a single term other than 1 is matched through that term:
 *     - fork on selection 1 (the true branch): 0 < n or n > 0
 *     - fork on selection 0 (the false branch): n < 0 or 0 > n
 * - `fork_join_map` has a join for the fork, that join has exactly one region
 *   user, and that region has exactly two predecessors: the join and the
 *   other control projection of the same if.
 * - Every phi using that region has two inputs, one being a reduce on the
 *   join and the other being that reduce's initial value.
 *
 * On success returns a GuardedFork holding
 * - The fork and its join
 * - The if of the guard and both of its projections
 * - The region that merges the guard's branches
 * - A map from each phi on that region to the reduce it should be replaced by
 * - The factor that matched the guard (which records whether the fork's
 *   factor was a max that can be simplified)
 */
fn guarded_fork(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    node: NodeID,
) -> Option<GuardedFork> {
    let function = editor.func();

    // Identify fork nodes
    let Node::Fork { control, factors } = &function.nodes[node.idx()] else {
        return None;
    };

    // Lazily view each factor, looking through a max whose only term other
    // than 1 is some x so that the factor can be matched against x
    let mut candidate_factors = factors.iter().enumerate().map(|(idx, dc)| {
        let factor = editor.get_dynamic_constant(*dc);
        let DynamicConstant::Max(xs) = factor.deref() else {
            return Factor::Normal(*dc);
        };
        // Filter out any terms which are just 1s
        let non_ones = xs
            .iter()
            .filter(|i| {
                !matches!(
                    editor.get_dynamic_constant(**i).deref(),
                    DynamicConstant::Constant(1)
                )
            })
            .collect::<Vec<_>>();
        // If we're left with just one term x, we had max { 1, x }
        if non_ones.len() == 1 {
            Factor::Max(idx, *non_ones[0])
        } else {
            Factor::Normal(*dc)
        }
    });

    // Whose predecessor is a projection from an if
    let Node::ControlProjection {
        control: if_node,
        selection: branch_idx,
    } = function.nodes[control.idx()]
    else {
        return None;
    };
    let Node::If { control: _, cond } = function.nodes[if_node.idx()] else {
        return None;
    };

    // Whose condition is appropriate
    let Node::Binary { left, right, op } = function.nodes[cond.idx()] else {
        return None;
    };

    // Each pattern is (operand that must be zero, operator, operand that must
    // be a factor of the fork)
    let patterns = if branch_idx == 1 {
        // Fork on the true branch: the condition must be 0 < n or n > 0
        [
            (left, BinaryOperator::LT, right),
            (right, BinaryOperator::GT, left),
        ]
    } else if branch_idx == 0 {
        // Fork on the false branch: the condition must be n < 0 or 0 > n
        [
            (right, BinaryOperator::LT, left),
            (left, BinaryOperator::GT, right),
        ]
    } else {
        return None;
    };
    let fork_on_true_branch = branch_idx == 1;

    // At most one pattern can get past the operator check (LT and GT differ),
    // so the lazy factor iterator is searched at most once
    let factor = patterns
        .iter()
        .find_map(|(pattern_zero, pattern_op, pattern_factor)| {
            // Match Op
            if op != *pattern_op {
                return None;
            }
            // Match Zero
            let is_zero = function.nodes[pattern_zero.idx()]
                .is_zero_constant(&editor.get_constants())
                || editor
                    .node(pattern_zero)
                    .is_zero_dc(&editor.get_dynamic_constants());
            if !is_zero {
                return None;
            }
            // Match Factor
            candidate_factors.find(|factor| {
                let pattern_node = &function.nodes[pattern_factor.idx()];
                if !fork_on_true_branch {
                    // NOTE(review): unlike the true-branch case below, this
                    // does not accept a Node::Constant operand; kept as is
                    return pattern_node.try_dynamic_constant() == Some(factor.get_id());
                }
                match (
                    pattern_node,
                    &*editor.get_dynamic_constant(factor.get_id()),
                ) {
                    (Node::Constant { id }, DynamicConstant::Constant(v)) => {
                        let Constant::UnsignedInteger64(pattern_v) = *editor.get_constant(*id)
                        else {
                            return false;
                        };
                        pattern_v == (*v as u64)
                    }
                    (Node::DynamicConstant { id }, _) => *id == factor.get_id(),
                    _ => false,
                }
            })
        })?;

    // Identify the join node and its users
    let join = *fork_join_map.get(&node)?;

    // Find the unique region using the join; without exactly one we can't
    // eliminate this guard
    let mut region_users = editor
        .get_users(join)
        .filter(|n| function.nodes[n.idx()].is_region());
    let (Some(join_control), None) = (region_users.next(), region_users.next()) else {
        return None;
    };
    let Some(Node::Region { preds }) = function.nodes.get(join_control.idx()) else {
        return None;
    };

    // The region after the join can only have two predecessors (for the guard
    // and the fork-join); pick out the one that is not the join
    let other_pred = match &preds[..] {
        [first, second] if *second == join => *first,
        [first, second] if *first == join => *second,
        _ => return None,
    };

    // Other predecessor needs to be the other projection from the guard's if
    let Node::ControlProjection {
        control: other_if,
        selection: else_branch,
    } = function.nodes[other_pred.idx()]
    else {
        return None;
    };
    if else_branch == branch_idx || other_if != if_node {
        return None;
    }

    // Finally, identify the phi nodes associated with the region and match
    // them with the reduce nodes of the fork-join
    let reduce_nodes = editor
        .get_users(join)
        .filter(|n| function.nodes[n.idx()].is_reduce())
        .collect::<HashSet<_>>();

    // Construct a map from phi nodes to their reduce node. If any phi has no
    // associated reduce node, the fallible collect yields None and we cannot
    // remove the loop guard
    let phi_reduce_map = editor
        .get_users(join_control)
        .filter_map(|phi| {
            // Users of the region that are not phis are ignored
            let Node::Phi {
                control: _,
                ref data,
            } = function.nodes[phi.idx()]
            else {
                return None;
            };
            if data.len() != 2 {
                return Some((phi, None));
            }
            // One input must be a reduce of the fork-join; the other is then
            // the value the phi takes when the guard skips the fork
            let (init_idx, reduce_node) = if reduce_nodes.contains(&data[0]) {
                (1, data[0])
            } else if reduce_nodes.contains(&data[1]) {
                (0, data[1])
            } else {
                return Some((phi, None));
            };
            let Node::Reduce {
                control: _, init, ..
            } = function.nodes[reduce_node.idx()]
            else {
                return Some((phi, None));
            };
            // The skipped-guard value must equal the reduce's initial value
            Some((phi, (data[init_idx] == init).then_some(reduce_node)))
        })
        .map(|(phi, reduce)| Some((phi, reduce?)))
        .collect::<Option<HashMap<_, _>>>()?;

    Some(GuardedFork {
        fork: node,
        join,
        guard_if: if_node,
        fork_taken_proj: *control,
        fork_skipped_proj: other_pred,
        guard_join_region: join_control,
        phi_reduce_map,
        factor,
    })
}
/*
 * Top level function to run fork guard elimination, as described above.
 *
 * All eliminable guards are identified first (via guarded_fork); each one is
 * then removed in its own edit. Within an edit:
 * - The fork is re-attached to the guard's control predecessor
 * - The guard's if and both of its projections are deleted
 * - Uses of the region merging the guard's branches are redirected to the
 *   join, and the region is deleted
 * - Each phi on that region is replaced by its matching reduce and deleted
 * - If the matched factor was a simplifiable max, the fork is replaced by a
 *   new fork whose factor at that index is the simplified term
 */
pub fn fork_guard_elim(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
    let candidates = editor
        .node_ids()
        .filter_map(|node| guarded_fork(editor, fork_join_map, node))
        .collect::<Vec<_>>();

    for candidate in candidates {
        let GuardedFork {
            fork,
            join,
            guard_if,
            fork_taken_proj,
            fork_skipped_proj,
            guard_join_region,
            phi_reduce_map,
            factor,
        } = candidate;

        // The first use of the guard's if is taken as the guard's predecessor
        let guard_pred = match editor.get_uses(guard_if).next() {
            Some(pred) => pred,
            None => unreachable!(),
        };

        // When the guard matched through a max, build the fork that uses the
        // simplified factor in place of the max
        let replacement_fork = match factor {
            Factor::Max(idx, dc) => {
                let Node::Fork {
                    control: _,
                    mut factors,
                } = editor.func().nodes[fork.idx()].clone()
                else {
                    unreachable!()
                };
                factors[idx] = dc;
                Some(Node::Fork {
                    control: guard_pred,
                    factors,
                })
            }
            Factor::Normal(_) => None,
        };

        editor.edit(|edit| {
            // Bypass the guard on the way into the fork, remove the guard, and
            // let the join stand in for the region that merged the branches
            let mut edit = edit
                .replace_all_uses_where(fork_taken_proj, guard_pred, |user| *user == fork)?
                .delete_node(guard_if)?
                .delete_node(fork_taken_proj)?
                .delete_node(fork_skipped_proj)?
                .replace_all_uses(guard_join_region, join)?
                .delete_node(guard_join_region)?;

            // Each phi on the removed region collapses to its reduce
            for (phi, reduce) in &phi_reduce_map {
                edit = edit.replace_all_uses(*phi, *reduce)?.delete_node(*phi)?;
            }

            // Swap in the fork with the simplified factor, if there is one
            if let Some(fork_node) = replacement_fork {
                let new_fork = edit.add_node(fork_node);
                edit = edit.replace_all_uses(fork, new_fork)?.delete_node(fork)?;
            }
            Ok(edit)
        });
    }
}