forkify.rs 18.72 KiB
use std::collections::HashMap;
use std::collections::HashSet;
use std::iter::zip;
use std::iter::FromIterator;
use itertools::Itertools;
use nestify::nest;
use hercules_ir::*;
use crate::*;
/*
* TODO: Forkify currently makes a bunch of small edits - this needs to be
* changed so that every loop that gets forkified corresponds to a single edit
* + sub-edits. This would allow us to run forkify on a subset of a function.
*/
/**
Tries to forkify the natural loops of the function held by `editor`.

Loops are visited bottom-up, and only loops whose header is a region node and
is mutable in `editor` are attempted. The function stops after the first loop
that `forkify_loop` converts and returns `true`; it returns `false` when no
loop was converted.
*/
pub fn forkify(
    editor: &mut FunctionEditor,
    control_subgraph: &Subgraph,
    fork_join_map: &HashMap<NodeID, NodeID>,
    loops: &LoopTree,
) -> bool {
    // Materialize the candidates first: the filter reads from `editor`, while
    // `forkify_loop` below needs it mutably.
    let candidates: Vec<_> = loops
        .bottom_up_loops()
        .into_iter()
        .filter(|(header, _)| editor.func().nodes[header.idx()].is_region())
        .collect();

    for (header, control) in candidates {
        // FIXME: Run on all-bottom level loops, as they can be independently optimized without recomputing analyses.
        if !editor.is_mutable(header) {
            continue;
        }
        let natural_loop = Loop {
            header,
            control: control.clone(),
        };
        if forkify_loop(editor, control_subgraph, fork_join_map, &natural_loop) {
            return true;
        }
    }
    false
}
/**
Given a node used as a loop bound, return a dynamic constant ID.

- A `Node::DynamicConstant` node yields the ID it already carries.
- A `Node::Constant` node holding a non-negative integer constant is turned
  into a `DynamicConstant::Constant`, which is added through an edit on
  `editor`; the ID produced by that edit is returned.

# Errors

Returns `Err` with a description when:
- the node is neither a constant nor a dynamic constant node,
- the constant is not one of the integer kinds,
- the constant is negative (it cannot be represented as a dynamic constant),
- the edit that adds the dynamic constant is not applied.
*/
pub fn get_node_as_dc(
    editor: &mut FunctionEditor,
    node: NodeID,
) -> Result<DynamicConstantID, String> {
    // Check for a constant used as loop bound.
    match editor.node(node) {
        Node::DynamicConstant {
            id: dynamic_constant_id,
        } => Ok(*dynamic_constant_id),
        Node::Constant { id: constant_id } => {
            // Widen to i128 first: every signed and unsigned integer kind below
            // fits exactly, so the sign can be checked before narrowing.
            let value = match *editor.get_constant(*constant_id) {
                Constant::Integer8(x) => i128::from(x),
                Constant::Integer16(x) => i128::from(x),
                Constant::Integer32(x) => i128::from(x),
                Constant::Integer64(x) => i128::from(x),
                Constant::UnsignedInteger8(x) => i128::from(x),
                Constant::UnsignedInteger16(x) => i128::from(x),
                Constant::UnsignedInteger32(x) => i128::from(x),
                Constant::UnsignedInteger64(x) => i128::from(x),
                _ => return Err("Invalid constant as loop bound".to_string()),
            };
            // A plain `as` cast would wrap a negative bound into a huge
            // unsigned value; reject it instead.
            if value < 0 {
                return Err("Negative constant as loop bound".to_string());
            }
            let dc = DynamicConstant::Constant(value as _);

            // Only hand back an ID that the edit actually produced.
            let mut added = None;
            let applied = editor.edit(|mut edit| {
                added = Some(edit.add_dynamic_constant(dc));
                Ok(edit)
            });
            match added {
                Some(id) if applied => Ok(id),
                _ => Err("Failed to add dynamic constant for loop bound".to_string()),
            }
        }
        _ => Err("Loop bound is neither a constant nor a dynamic constant".to_string()),
    }
}
/**
Top level function to convert a natural loop with a simple induction variable
into a fork-join.

Returns `true` only if the final rewriting edit was applied (the value returned
by `editor.edit`). Returns `false` as soon as one of these checks fails:
- `get_loop_exit_conditions` finds an exit and it is `LoopExit::Conditional`,
- `has_canonical_iv` finds an induction variable, it is
  `InductionVariable::Basic`, and its `final_value` is present and accepted by
  `get_node_as_dc`,
- the loop header has exactly one use outside the loop's control nodes,
- the continue projection of the exit `if` is a direct use of the header,
- at most one use of the exit `if` lies inside the loop's control nodes,
- every other phi on the header is classified `LoopPHI::Reductionable` by
  `analyze_phis`,
- none of the users reached from the induction variable phi by
  `walk_all_users_stop_on` lies outside the loop nodes, except the exit `if`.

`_fork_join_map` is not used.

Panics if the exit `if` lacks a user inside or outside the loop's control
nodes, or if a reduction phi has no input matching the loop predecessor
(`unwrap` calls below).

NOTE(review): the dynamic constants and the fork/join nodes are added in their
own edits before the final failable edit; presumably they remain in the
function if that last edit is rejected. Confirm against `FunctionEditor::edit`.
*/
pub fn forkify_loop(
editor: &mut FunctionEditor,
control_subgraph: &Subgraph,
_fork_join_map: &HashMap<NodeID, NodeID>,
l: &Loop,
) -> bool {
let function = editor.func();
// Only loops whose exit is a conditional branch are handled.
let Some(loop_condition) = get_loop_exit_conditions(function, l, control_subgraph) else {
return false;
};
let LoopExit::Conditional {
if_node: loop_if,
condition_node,
} = loop_condition.clone()
else {
return false;
};
// Compute loop variance, then the induction variables and their ranges.
let loop_variance = compute_loop_variance(editor, l);
let ivs = compute_induction_vars(editor.func(), l, &loop_variance);
let ivs = compute_iv_ranges(editor, l, ivs, &loop_condition);
let Some(canonical_iv) = has_canonical_iv(editor, l, &ivs) else {
return false;
};
// Get bound: the final value of a basic induction variable, converted to a
// dynamic constant. A missing final value or a failed conversion gives `None`.
// `update_expression` and `update_value` are bound but not used here.
let bound = match canonical_iv {
InductionVariable::Basic {
node: _,
initializer: _,
final_value,
update_expression,
update_value,
} => final_value
.map(|final_value| get_node_as_dc(editor, final_value))
.and_then(|r| r.ok()),
InductionVariable::SCEV(_) => return false,
};
let Some(bound_dc_id) = bound else {
return false;
};
let function = editor.func();
// Check if it is do-while loop.
// First, split the users of the exit `if` into the one that leaves the loop's
// control nodes and the one that stays inside them.
let loop_exit_projection = editor
.get_users(loop_if)
.filter(|id| !l.control[id.idx()])
.next()
.unwrap();
let loop_continue_projection = editor
.get_users(loop_if)
.filter(|id| l.control[id.idx()])
.next()
.unwrap();
// Uses of the header that are outside the loop: the loop's predecessors.
let loop_preds: Vec<_> = editor
.get_uses(l.header)
.filter(|id| !l.control[id.idx()])
.collect();
// FIXME: @xrouth
// Only a single predecessor is supported.
if loop_preds.len() != 1 {
return false;
}
let loop_pred = loop_preds[0];
// The continue projection has to feed the header directly.
if !editor
.get_uses(l.header)
.contains(&loop_continue_projection)
{
return false;
}
// Get all phis used outside of the loop, they need to be reductionable.
// For now just assume all phis will be phis used outside of the loop, except for the canonical iv.
// FIXME: We need a different definition of `loop_nodes` to check for phis used outside the loop than the one
// we currently have.
let loop_nodes = calculate_loop_nodes(editor, l);
// Check phis to see if they are reductionable, only PHIs depending on the loop are considered,
// Candidates are all phi users of the header other than the canonical IV's phi.
let candidate_phis: Vec<_> = editor
.get_users(l.header)
.filter(|id| function.nodes[id.idx()].is_phi())
.filter(|id| *id != canonical_iv.phi())
.collect();
// Despite the name, this holds the classification of every candidate phi;
// non-reductionable ones are rejected further down.
let reductionable_phis: Vec<_> = analyze_phis(&editor, &l, &candidate_phis, &loop_nodes)
.into_iter()
.collect();
// TODO: Handle multiple loop body lasts.
// If there are multiple candidates for loop body last, return false.
if editor
.get_uses(loop_if)
.filter(|id| l.control[id.idx()])
.count()
> 1
{
return false;
}
// The first use of the exit `if` is taken as the last control node of the body.
// NOTE(review): presumably the first use is the control input rather than the
// condition; confirm the ordering `get_uses` guarantees.
let loop_body_last = editor.get_uses(loop_if).next().unwrap();
// Give up if any candidate phi is not reductionable.
if reductionable_phis
.iter()
.any(|phi| !matches!(phi, LoopPHI::Reductionable { .. }))
{
return false;
}
// Back-edge values of the reductionable phis.
let phi_latches: Vec<_> = reductionable_phis
.iter()
.map(|phi| {
let LoopPHI::Reductionable {
phi: _,
data_cycle: _,
continue_latch,
is_associative: _,
} = phi
else {
unreachable!()
};
continue_latch
})
.collect();
// Nodes at which the walk over the IV's users stops: phis, reduces, control
// nodes and the phi latches collected above.
let stop_on: HashSet<_> = editor
.node_ids()
.filter(|node| {
if editor.node(node).is_phi() {
return true;
}
if editor.node(node).is_reduce() {
return true;
}
if editor.node(node).is_control() {
return true;
}
if phi_latches.contains(&node) {
return true;
}
false
})
.collect();
// Outside loop users of IV, then exit;
// Unless the outside user is through the loop latch of a reducing phi,
// then we know how to replace this edge, so its fine!
let iv_users: Vec<_> =
walk_all_users_stop_on(canonical_iv.phi(), editor, stop_on.clone()).collect();
if iv_users
.iter()
.any(|node| !loop_nodes.contains(&node) && *node != loop_if)
{
return false;
}
// Start Transformation:
// Graft everything between header and loop condition
// Attach join to right before header (after loop_body_last, unless loop body last *is* the header).
// Attach fork to right after loop_continue_projection.
// // Create fork and join nodes:
// Placeholder IDs, overwritten inside the edit closure below.
let mut join_id = NodeID::new(0);
let mut fork_id = NodeID::new(0);
// Turn dc bound into max (1, bound),
let bound_dc_id = {
let mut max_id = DynamicConstantID::new(0);
editor.edit(|mut edit| {
let one_id = edit.add_dynamic_constant(DynamicConstant::Constant(1));
max_id = edit.add_dynamic_constant(DynamicConstant::max(one_id, bound_dc_id));
Ok(edit)
});
max_id
};
// FIXME: (@xrouth) double check handling of control in loop body.
// The fork hangs off the loop predecessor with the bound as its only factor.
// The join hangs off the last body control node, or directly off the fork when
// the header itself is that node.
editor.edit(|mut edit| {
let fork = Node::Fork {
control: loop_pred,
factors: Box::new([bound_dc_id]),
};
fork_id = edit.add_node(fork);
let join = Node::Join {
control: if l.header == loop_body_last {
fork_id
} else {
loop_body_last
},
};
join_id = edit.add_node(join);
Ok(edit)
});
let function = editor.func();
// The thread ID reads the last dimension of the fork's factors.
let (_, factors) = function.nodes[fork_id.idx()].try_fork().unwrap();
let dimension = factors.len() - 1;
// Pair every reduction phi with its initial value: the phi input at the
// position where the header's use is the loop predecessor.
let redcutionable_phis_and_init: Vec<(_, NodeID)> = reductionable_phis
.iter()
.map(|reduction_phi| {
let LoopPHI::Reductionable {
phi,
data_cycle: _,
continue_latch: _,
is_associative: _,
} = reduction_phi
else {
panic!();
};
let function = editor.func();
let init = *zip(
editor.get_uses(l.header),
function.nodes[phi.idx()].try_phi().unwrap().1.iter(),
)
.filter(|(c, _)| *c == loop_pred)
.next()
.unwrap()
.1;
(reduction_phi, init)
})
.collect();
// Start failable edit:
let result = editor.edit(|mut edit| {
let thread_id = Node::ThreadID {
control: fork_id,
dimension: dimension,
};
let thread_id_id = edit.add_node(thread_id);
// Replace uses that are inside with the thread id
edit = edit.replace_all_uses_where(canonical_iv.phi(), thread_id_id, |node| {
loop_nodes.contains(node)
})?;
edit.sub_edit(canonical_iv.phi(), thread_id_id);
edit = edit.delete_node(canonical_iv.phi())?;
// Turn every reduction phi into a reduce on the join.
for (reduction_phi, init) in redcutionable_phis_and_init {
let LoopPHI::Reductionable {
phi,
data_cycle: _,
continue_latch,
is_associative: _,
} = *reduction_phi
else {
panic!();
};
let reduce = Node::Reduce {
control: join_id,
init,
reduct: continue_latch,
};
let reduce_id = edit.add_node(reduce);
// Carry a ParallelReduce schedule over to the new reduce when it is present
// on the init or latch node and that node is not itself a reduce.
if (!edit.get_node(init).is_reduce()
&& edit.get_schedule(init).contains(&Schedule::ParallelReduce))
|| (!edit.get_node(continue_latch).is_reduce()
&& edit
.get_schedule(continue_latch)
.contains(&Schedule::ParallelReduce))
{
edit = edit.add_schedule(reduce_id, Schedule::ParallelReduce)?;
}
// Same rule for the MonoidReduce schedule.
if (!edit.get_node(init).is_reduce()
&& edit.get_schedule(init).contains(&Schedule::MonoidReduce))
|| (!edit.get_node(continue_latch).is_reduce()
&& edit
.get_schedule(continue_latch)
.contains(&Schedule::MonoidReduce))
{
edit = edit.add_schedule(reduce_id, Schedule::MonoidReduce)?;
}
// Every use of the phi now reads the reduce, except the reduce itself.
edit = edit.replace_all_uses_where(phi, reduce_id, |usee| *usee != reduce_id)?;
// Uses of the latch outside the loop nodes read the reduce as well.
edit = edit.replace_all_uses_where(continue_latch, reduce_id, |usee| {
!loop_nodes.contains(usee) && *usee != reduce_id
})?;
edit.sub_edit(phi, reduce_id);
edit = edit.delete_node(phi)?
}
// Rewire control: the header and the continue projection become the fork, the
// exit projection becomes the join.
edit = edit.replace_all_uses(l.header, fork_id)?;
edit = edit.replace_all_uses(loop_continue_projection, fork_id)?;
edit = edit.replace_all_uses(loop_exit_projection, join_id)?;
edit.sub_edit(l.header, fork_id);
edit.sub_edit(loop_continue_projection, fork_id);
edit.sub_edit(loop_exit_projection, join_id);
// Remove the old loop control nodes and the exit condition.
edit = edit.delete_node(loop_continue_projection)?;
edit = edit.delete_node(condition_node)?; // Might have to get rid of other users of this.
edit = edit.delete_node(loop_exit_projection)?;
edit = edit.delete_node(loop_if)?;
edit = edit.delete_node(l.header)?;
Ok(edit)
});
return result;
}
// Classification of a loop-header phi, as produced by `analyze_phis`.
nest! {
#[derive(Debug)]
pub enum LoopPHI {
// The phi can be rewritten as a reduce.
Reductionable {
// The phi node itself.
phi: NodeID,
data_cycle: HashSet<NodeID>, // All nodes in a data cycle with this phi
// The phi's input along the loop's back edge.
continue_latch: NodeID,
// `analyze_phis` currently always sets this to `false`.
is_associative: bool,
},
// The phi cannot be turned into a reduce: `analyze_phis` returns this when the
// phi depends on another candidate phi, when a node of its cycle is used
// outside the loop, or when it is not in a cycle at all.
LoopDependant(NodeID),
ControlDependant(NodeID), // This phi is reductionable, but its cycle might depend on internal control within the loop.
// NOTE(review): not constructed anywhere in this file section; meaning to be
// confirmed with the author.
UsedByDependant(NodeID),
}
}
impl LoopPHI {
pub fn get_phi(&self) -> NodeID {
match self {
LoopPHI::Reductionable { phi, .. } => *phi,
LoopPHI::LoopDependant(node_id) => *node_id,
LoopPHI::UsedByDependant(node_id) => *node_id,
LoopPHI::ControlDependant(node_id) => *node_id,
}
}
}
/**
Classifies loop variables that would need to become reductions for the loop to
be forkified.

A phi is reported as `LoopPHI::Reductionable` when all of the following hold:
- The phi is in a data cycle *in the loop* with itself.
- The users reached from its back-edge value (not looking past another phi, a
  reduce, or a control node) do not include any other phi in `phis`.
- No node of that cycle, other than the back-edge value, has a user outside
  `loop_nodes`.

Every other phi is reported as `LoopPHI::LoopDependant`.

The returned iterator is lazy: each phi is analyzed when its item is requested.
Panics if the loop header has no use inside the loop's control nodes, or if an
entry of `phis` is not a phi node.
*/
pub fn analyze_phis<'a>(
    editor: &'a FunctionEditor,
    natural_loop: &'a Loop,
    phis: &'a [NodeID],
    loop_nodes: &'a HashSet<NodeID>,
) -> impl Iterator<Item = LoopPHI> + 'a {
    phis.iter().map(move |phi| {
        // Boundary for the cycle search: phis and reduces controlled from
        // outside the loop, this phi itself, and every control node. Reduces
        // controlled from inside the loop are walked through.
        let cycle_boundary: HashSet<NodeID> = editor
            .node_ids()
            .filter(|candidate| {
                let data = &editor.func().nodes[candidate.idx()];
                match data {
                    // External Phi
                    Node::Phi { control, .. } if !natural_loop.control[control.idx()] => true,
                    // This phi
                    _ if candidate == phi => true,
                    // External Reduce stops the walk, internal Reduce does not
                    Node::Reduce { control, .. } => !natural_loop.control[control.idx()],
                    // Data cycles only
                    _ => data.is_control(),
                }
            })
            .collect();

        // Position of the header's in-loop predecessor; the phi input at that
        // position is the value carried around the back edge.
        let back_edge_idx = editor
            .get_uses(natural_loop.header)
            .position(|pred| natural_loop.control[pred.idx()])
            .unwrap();
        let latch = editor.node(phi).try_phi().unwrap().1[back_edge_idx];

        // Nodes lying on a data cycle through the phi: everything the latch
        // (transitively) uses that is also a (transitive) user of the phi.
        let feeds_latch: HashSet<_> =
            HashSet::from_iter(walk_all_uses_stop_on(latch, editor, cycle_boundary.clone()));
        let fed_by_phi: HashSet<_> =
            HashSet::from_iter(walk_all_users_stop_on(*phi, editor, cycle_boundary));
        let data_cycle: HashSet<_> = feeds_latch.intersection(&fed_by_phi).cloned().collect();

        // Second walk, this time stopping at any phi, reduce or control node.
        let frontier_stops: HashSet<NodeID> = editor
            .node_ids()
            .filter(|candidate| {
                let data = &editor.func().nodes[candidate.idx()];
                data.is_phi() || data.is_reduce() || data.is_control()
            })
            .collect();
        let mut latch_users = walk_all_users_stop_on(latch, editor, frontier_stops);

        // If this phi uses any other phis the node is loop dependant,
        // we use `phis` because this phi can actually contain the loop iv and its fine.
        if latch_users.any(|user| user != *phi && phis.contains(&user)) {
            return LoopPHI::LoopDependant(*phi);
        }

        // No cycles exist, this isn't a reduction.
        if data_cycle.is_empty() {
            return LoopPHI::LoopDependant(*phi);
        }

        // PHIs on the frontier of the uses by the candidate phi need to have headers
        // that postdominate the loop continue latch. The value of the PHI used needs
        // to be defined by the time the reduce is triggered (at the end of the loop's
        // internal control).
        // No nodes in data cycles with this phi (in the loop) may be used outside the
        // loop, besides the latch. If some other node in the cycle is used, there is
        // not a valid node to assign it after making the cycle a reduce.
        let cycle_escapes = data_cycle
            .iter()
            .filter(|member| **member != latch)
            .any(|member| {
                editor
                    .get_users(*member)
                    .any(|user| !loop_nodes.contains(&user))
            });
        if cycle_escapes {
            // This phi can be made into a reduce in different ways, if the cycle is
            // associative (contains all the same kind of associative op):
            // 3) Split the cycle into two phis, add them or multiply them together at the end.
            // 4) Split the cycle into two reduces, add them or multiply them together at the end.
            // Somewhere else should handle this.
            return LoopPHI::LoopDependant(*phi);
        }

        // FIXME: Do we want to calculate associativity here, there might be a case
        // where this information is used in forkify, i.e as described above.
        let is_associative = false;

        // No nodes in the data cycle are used outside of the loop, besides the
        // latched value of the phi.
        LoopPHI::Reductionable {
            phi: *phi,
            data_cycle,
            continue_latch: latch,
            is_associative,
        }
    })
}