pred.rs 13.29 KiB
use std::cmp::{max, min};
use std::collections::{BTreeMap, BTreeSet};
use itertools::Itertools;
use hercules_ir::*;
use crate::*;
/*
 * Top level function to run predication on a function. Repeatedly looks for
 * acyclic control flow that can be converted into dataflow: an if branch whose
 * two projections both flow straight into the same region is removed, and
 * every phi on that region instead takes a select node (on the branch
 * condition) over the two values that arrived along those projections.
 * Branches that cannot be predicated are remembered so the search terminates.
 *
 * After predication, selects with a constant boolean arm are rewritten into
 * boolean and/or logic (see `rewrite_boolean_selects`).
 *
 * `typing` gives the type of every node that existed when the pass started;
 * the types of phis created by the pass itself are tracked locally.
 */
pub fn predication(editor: &mut FunctionEditor, typing: &Vec<TypeID>) {
    // Remove branches iteratively, since predicating an inside branch may cause
    // an outside branch to be available for predication.
    let mut bad_branches = BTreeSet::new();
    // Types of the phis created by this pass. `typing` was computed before any
    // edits and has no entries for them, so looking a new phi up in `typing`
    // alone would index out of bounds once a region produced by an earlier
    // iteration is itself predicated.
    let mut new_phi_types: BTreeMap<NodeID, TypeID> = BTreeMap::new();
    loop {
        // First, look for a branch whose projections all point to the same
        // region. These are branches with no internal control flow.
        let nodes = &editor.func().nodes;
        let Some((region, branch, false_proj, true_proj)) = editor
            .node_ids()
            .filter_map(|id| {
                if let Node::Region { ref preds } = nodes[id.idx()] {
                    // Look for two projections with the same branch.
                    let preds = preds.into_iter().filter_map(|id| {
                        nodes[id.idx()]
                            .try_control_proj()
                            .map(|(branch, selection)| (*id, branch, selection))
                    });
                    // Index projections by if branch.
                    let mut pred_map: BTreeMap<NodeID, Vec<(NodeID, usize)>> = BTreeMap::new();
                    for (proj, branch, selection) in preds {
                        if nodes[branch.idx()].is_if() && !bad_branches.contains(&branch) {
                            pred_map.entry(branch).or_default().push((proj, selection));
                        }
                    }
                    // Look for an if branch with two projections going into the
                    // same region.
                    for (branch, projs) in pred_map {
                        if projs.len() == 2 {
                            // `way` is the selection of the first projection, so
                            // `projs[way]` is always the projection with
                            // selection 0 (treated as the false projection) and
                            // `projs[1 - way]` the one with selection 1.
                            let way = projs[0].1;
                            assert_ne!(way, projs[1].1);
                            return Some((id, branch, projs[way].0, projs[1 - way].0));
                        }
                    }
                }
                None
            })
            .next()
        else {
            break;
        };
        let Node::Region { preds } = nodes[region.idx()].clone() else {
            panic!()
        };
        let Node::If {
            control: if_pred,
            cond,
        } = nodes[branch.idx()]
        else {
            panic!()
        };
        let phis: Vec<_> = editor
            .get_users(region)
            .filter(|id| nodes[id.idx()].is_phi())
            .collect();
        // Look up the type of every phi on the region: phis created earlier in
        // this pass are in `new_phi_types`, all others in `typing`. `None`
        // means the type of at least one phi is unknown.
        let phi_types: Option<Vec<TypeID>> = phis
            .iter()
            .map(|phi| {
                new_phi_types
                    .get(phi)
                    .copied()
                    .or_else(|| typing.get(phi.idx()).copied())
            })
            .collect();
        // Don't predicate branches where one of the phis is a collection.
        // Predicating this branch would result in a clone and probably wouldn't
        // result in good vector code. Phis of unknown type are conservatively
        // treated the same way.
        let Some(phi_types) = phi_types
            .filter(|types| types.iter().all(|ty| editor.get_type(*ty).is_primitive()))
        else {
            bad_branches.insert(branch);
            continue;
        };
        let false_pos = preds.iter().position(|id| *id == false_proj).unwrap();
        let true_pos = preds.iter().position(|id| *id == true_proj).unwrap();
        // Second, make all the modifications:
        // - Add the select nodes.
        // - Replace uses in phis with the select node.
        // - Remove the branch from the control flow.
        // This leaves a region at the join point - if it ends up with one
        // predecessor, it may be removed by CCP.
        let mut created_phis: Vec<(NodeID, TypeID)> = Vec::with_capacity(phis.len());
        let success = editor.edit(|mut edit| {
            // Replace the branch projection predecessors of the region with the
            // predecessor of the branch.
            let Node::Region { preds } = edit.get_node(region).clone() else {
                panic!()
            };
            let mut preds = Vec::from(preds);
            // Remove the later position first so the earlier one stays valid.
            preds.remove(max(true_pos, false_pos));
            preds.remove(min(true_pos, false_pos));
            preds.push(if_pred);
            let new_region = edit.add_node(Node::Region {
                preds: preds.into_boxed_slice(),
            });
            edit = edit.replace_all_uses(region, new_region)?;
            // Replace the corresponding inputs in the phi nodes with select
            // nodes selecting over the old inputs to the phis. The select is
            // appended last, lining up with `if_pred` in the new region.
            for (&phi, &ty) in phis.iter().zip(&phi_types) {
                let Node::Phi { control: _, data } = edit.get_node(phi).clone() else {
                    panic!()
                };
                let mut data = Vec::from(data);
                let select = edit.add_node(Node::Ternary {
                    op: TernaryOperator::Select,
                    first: cond,
                    second: data[true_pos],
                    third: data[false_pos],
                });
                data.remove(max(true_pos, false_pos));
                data.remove(min(true_pos, false_pos));
                data.push(select);
                let new_phi = edit.add_node(Node::Phi {
                    control: new_region,
                    data: data.into_boxed_slice(),
                });
                // The new phi carries the same value as the old one, so it
                // has the same type.
                created_phis.push((new_phi, ty));
                edit = edit.replace_all_uses(phi, new_phi)?;
                edit = edit.delete_node(phi)?;
            }
            // Delete the old control nodes.
            edit = edit.delete_node(region)?;
            edit = edit.delete_node(branch)?;
            edit = edit.delete_node(false_proj)?;
            edit = edit.delete_node(true_proj)?;
            Ok(edit)
        });
        if success {
            // Only record the new phis when the edit was actually applied.
            new_phi_types.extend(created_phis);
        } else {
            bad_branches.insert(branch);
        }
    }
    rewrite_boolean_selects(editor);
}

/*
 * Returns whether `node` is a constant node whose constant satisfies
 * `is_true()` (when `value` is true) or `is_false()` (when `value` is false).
 */
fn is_bool_constant(editor: &FunctionEditor, node: NodeID, value: bool) -> bool {
    let Some(cons) = editor.func().nodes[node.idx()].try_constant() else {
        return false;
    };
    let constant = editor.get_constant(cons);
    if value {
        constant.is_true()
    } else {
        constant.is_false()
    }
}

/*
 * Quick and dirty rewrite of selects with a constant boolean arm into boolean
 * logic. `select(c, t, f)` yields `t` when `c` holds and `f` otherwise, so:
 * - select(c, false, f) => !c && f
 * - select(c, t, false) => c && t
 * - select(c, true, f)  => c || f
 * - select(c, t, true)  => !c || t
 * When several patterns match, the first one in this list is used. The result
 * of each individual edit is ignored.
 */
fn rewrite_boolean_selects(editor: &mut FunctionEditor) {
    for id in editor.node_ids() {
        let Node::Ternary {
            op: TernaryOperator::Select,
            first,
            second,
            third,
        } = editor.func().nodes[id.idx()]
        else {
            continue;
        };
        // (negate the condition?, combining operator, non-constant arm)
        let rewrite = if is_bool_constant(editor, second, false) {
            Some((true, BinaryOperator::And, third))
        } else if is_bool_constant(editor, third, false) {
            Some((false, BinaryOperator::And, second))
        } else if is_bool_constant(editor, second, true) {
            Some((false, BinaryOperator::Or, third))
        } else if is_bool_constant(editor, third, true) {
            Some((true, BinaryOperator::Or, second))
        } else {
            None
        };
        let Some((negate, op, other)) = rewrite else {
            continue;
        };
        editor.edit(|mut edit| {
            let left = if negate {
                edit.add_node(Node::Unary {
                    op: UnaryOperator::Not,
                    input: first,
                })
            } else {
                first
            };
            let node = edit.add_node(Node::Binary {
                op,
                left,
                right: other,
            });
            edit = edit.replace_all_uses(id, node)?;
            edit.delete_node(id)
        });
    }
}
/*
 * Top level function to run write predication on a function. Repeatedly looks
 * for phi nodes where every data input is a write node used only by that phi,
 * and all of those writes have structurally equivalent indices. These writes
 * are coalesced into a single write whose `collect` input is a new phi over the
 * old writes' `collect` inputs. New phis are also added to select over the old
 * `data` inputs and position indices. Phis that cannot be rewritten are
 * remembered so the search terminates.
 */
pub fn write_predication(editor: &mut FunctionEditor) {
    let mut bad_phis = BTreeSet::new();
    loop {
        // First, look for phis where every input is a write with the same
        // indexing structure.
        let nodes = &editor.func().nodes;
        let Some((phi, control, writes)) = editor
            .node_ids()
            .filter_map(|id| {
                if let Node::Phi {
                    control,
                    ref data,
                } = nodes[id.idx()]
                    && !bad_phis.contains(&id)
                    // The rewrite below models the coalesced write on the first
                    // write's indices, so a phi with no inputs can't be handled
                    // (it would panic indexing `indices[0]`).
                    && !data.is_empty()
                    // Check that every input is a write - if this weren't true,
                    // we'd have to insert dummy writes that write something that
                    // was just read from the array. We could handle this case, but
                    // it's probably not needed for now. Also check that the phi is
                    // the only user of the write.
                    && data.into_iter().all(|id| nodes[id.idx()].is_write() && editor.get_users(*id).count() == 1)
                    // Check that every write input has equivalent indexing
                    // structure.
                    && data
                        .into_iter()
                        .filter_map(|id| nodes[id.idx()].try_write())
                        .tuple_windows()
                        .all(|(w1, w2)| indices_structurally_equivalent(w1.2, w2.2))
                {
                    Some((id, control, data.clone()))
                } else {
                    None
                }
            })
            .next()
        else {
            break;
        };
        let (collects, datas, indices): (Vec<_>, Vec<_>, Vec<_>) = writes
            .iter()
            .filter_map(|id| nodes[id.idx()].try_write())
            .map(|(collect, data, indices)| (collect, data, indices.to_owned()))
            .multiunzip();
        // Second, make all the modifications:
        // - Replace the old phi with a phi selecting over the `collect` inputs
        //   of the old write inputs.
        // - Add phis for the `data` and `indices` inputs to the old writes.
        // - Add a write that uses the phi-ed data and indices.
        // Be a little careful over how the old phi and writes get replaced,
        // since the old phi itself may be used by the old writes.
        let success = editor.edit(|mut edit| {
            // Create all the phis.
            let collect_phi = edit.add_node(Node::Phi {
                control,
                data: collects.into_boxed_slice(),
            });
            let data_phi = edit.add_node(Node::Phi {
                control,
                data: datas.into_boxed_slice(),
            });
            let mut phied_indices = Vec::with_capacity(indices[0].len());
            for (index_idx, index) in indices[0].iter().enumerate() {
                match index {
                    // Position indices may use different index nodes in each
                    // write, so each index node is phi-ed across the writes.
                    // This is messy due to three layers of arrays: for every
                    // position component, collect the node at that component
                    // from every write's indices.
                    Index::Position(old_pos) => {
                        let mut pos = Vec::with_capacity(old_pos.len());
                        for pos_idx in 0..old_pos.len() {
                            pos.push(edit.add_node(Node::Phi {
                                control,
                                data: indices
                                    .iter()
                                    .map(|write_indices| {
                                        write_indices[index_idx].try_position().unwrap()[pos_idx]
                                    })
                                    .collect(),
                            }));
                        }
                        phied_indices.push(Index::Position(pos.into_boxed_slice()));
                    }
                    // For field and variant indices, the index is assumed to
                    // be the same across all writes (relying on the structural
                    // equivalence check above), so just take the one from the
                    // first set of indices.
                    other => phied_indices.push(other.clone()),
                }
            }
            // Create the write.
            let new_write = edit.add_node(Node::Write {
                collect: collect_phi,
                data: data_phi,
                indices: phied_indices.into_boxed_slice(),
            });
            // Replace the old phi with the new write. This also redirects any
            // use of the old phi inside the new `collect` phi to the new write.
            edit = edit.replace_all_uses(phi, new_write)?;
            // Delete the old phi and writes.
            edit = edit.delete_node(phi)?;
            for write in writes {
                edit = edit.delete_node(write)?;
            }
            Ok(edit)
        });
        if !success {
            bad_phis.insert(phi);
        }
    }
}