extern crate bitvec;
extern crate hercules_ir;
use std::collections::HashMap;
use std::iter::zip;
use self::bitvec::prelude::*;
use self::hercules_ir::dataflow::*;
use self::hercules_ir::def_use::*;
use self::hercules_ir::ir::*;
/*
* Top level function to run SROA, intraprocedurally. Product values can be used
* and created by a relatively small number of nodes. Here are *all* of them:
*
* - Phi: can merge SSA values of products - these get broken up into phis on
* the individual fields
*
* - Reduce: similarly to phis, reduce nodes can cycle product values through
* reduction loops - these get broken up into reduces on the fields
*
* + Return: can return a product - these are untouched, and are the sinks for
* unbroken product values
*
* + Parameter: can introduce a product - these are untouched, and are the
* sources for unbroken product values
*
* - Constant: can introduce a product - these are broken up into constants for
* the individual fields
*
* - Ternary: the select ternary operator can select between products - these
* are broken up into ternary nodes for the individual fields
*
* + Call: the call node can use a product value as an argument to another
* function, and can produce a product value as a result - these are
* untouched, and are the sink and source for unbroken product values
*
* - Read: the read node reads primitive fields from product values - these get
* replaced by a direct use of the field value from the broken product value,
* but are retained when the product value is unbroken
*
* - Write: the write node writes primitive fields in product values - these get
* replaced by a direct def of the field value from the broken product value,
* but are retained when the product value is unbroken
*
* The nodes above with the list marker "+" are retained for maintaining API/ABI
* compatibility with other Hercules functions and the host code. These are
* called "sink" or "source" nodes in comments below.
*/
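/*
* For intuition, here is a sketch (in illustrative pseudo-IR, not exact
* hercules_ir syntax) of the intended transformation for a two-field product
* constant that flows through a phi and is then read:
*
*   c = constant (a, b)            c0 = constant a
*   p = phi(region, c, ...)   =>   c1 = constant b
*   x = read(p, field 0)           p0 = phi(region, c0, ...)
*                                  p1 = phi(region, c1, ...)
*                                  uses of x become uses of p0
*
* A product constant whose value eventually reaches a sink (a return or a
* call) is left entirely alone.
*/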
pub fn sroa(
function: &mut Function,
def_use: &ImmutableDefUseMap,
reverse_postorder: &Vec<NodeID>,
typing: &Vec<TypeID>,
types: &Vec<Type>,
constants: &mut Vec<Constant>,
) {
// Determine which sources of product values we want to try breaking up. We
// can easily determine on the source side whether a node produces a product
// that shouldn't be broken up just by examining the node type. However, the
// way a product is used also matters for deciding whether it can be broken
// up, so we propagate this information backward, via dataflow, to the
// sources of product values.
#[derive(PartialEq, Eq, Clone, Debug)]
enum ProductUseLattice {
// The product value used by this node is eventually used by a sink.
UsedBySink,
// This node uses multiple product values - the stored node ID indicates
// which is eventually used by a sink. This lattice value is produced by
// read and write nodes implementing partial indexing.
SpecificUsedBySink(NodeID),
// This node doesn't use a product node, or the product node it does use
// is not in turn used by a sink.
UnusedBySink,
}
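// The lattice is ordered with UnusedBySink on top and UsedBySink on the
// bottom; SpecificUsedBySink sits in between, and meeting two different
// specific node IDs falls all the way to UsedBySink.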
impl Semilattice for ProductUseLattice {
fn meet(a: &Self, b: &Self) -> Self {
match (a, b) {
(Self::UsedBySink, _) | (_, Self::UsedBySink) => Self::UsedBySink,
(Self::SpecificUsedBySink(id1), Self::SpecificUsedBySink(id2)) => {
if id1 == id2 {
Self::SpecificUsedBySink(*id1)
} else {
Self::UsedBySink
}
}
(Self::SpecificUsedBySink(id), _) | (_, Self::SpecificUsedBySink(id)) => {
Self::SpecificUsedBySink(*id)
}
_ => Self::UnusedBySink,
}
}
fn bottom() -> Self {
Self::UsedBySink
}
fn top() -> Self {
Self::UnusedBySink
}
}
// Run dataflow analysis to find which product values are used by a sink.
let product_uses = backward_dataflow(function, def_use, reverse_postorder, |succ_outs, id| {
match function.nodes[id.idx()] {
Node::Return {
control: _,
data: _,
} => {
if types[typing[id.idx()].idx()].is_product() {
ProductUseLattice::UsedBySink
} else {
ProductUseLattice::UnusedBySink
}
}
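// Calls are also sinks for product values (see the comment above), but
// propagating product uses through them isn't handled yet.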
Node::Call {
function: _,
dynamic_constants: _,
args: _,
} => todo!(),
// For reads and writes, we only want to propagate the use of the
// product to the collect input of the node.
Node::Read {
collect,
indices: _,
}
| Node::Write {
collect,
data: _,
indices: _,
} => {
let meet = succ_outs
.iter()
.fold(ProductUseLattice::top(), |acc, latt| {
ProductUseLattice::meet(&acc, latt)
});
if meet == ProductUseLattice::UnusedBySink {
ProductUseLattice::UnusedBySink
} else {
ProductUseLattice::SpecificUsedBySink(collect)
}
}
// For non-sink nodes.
_ => {
if function.nodes[id.idx()].is_control() {
return ProductUseLattice::UnusedBySink;
}
let meet = succ_outs
.iter()
.fold(ProductUseLattice::top(), |acc, latt| {
ProductUseLattice::meet(&acc, latt)
});
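// A specific use that names this node means this node's own product
// output is what eventually reaches a sink. A specific use of some other
// node came through a different input (for example, the data input of a
// write), so it doesn't implicate this node's output.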
if let ProductUseLattice::SpecificUsedBySink(meet_id) = meet {
if meet_id == id {
ProductUseLattice::UsedBySink
} else {
ProductUseLattice::UnusedBySink
}
} else {
meet
}
}
}
});
// Only product values introduced as constants can be replaced by scalars.
let to_sroa: Vec<(NodeID, ConstantID)> = product_uses
.into_iter()
.enumerate()
.filter_map(|(node_idx, product_use)| {
if ProductUseLattice::UnusedBySink == product_use
&& types[typing[node_idx].idx()].is_product()
{
function.nodes[node_idx]
.try_constant()
.map(|cons_id| (NodeID::new(node_idx), cons_id))
} else {
None
}
})
.collect();
println!("{:?}", to_sroa);
// Perform SROA. TODO: repair def-use when there are multiple product
// constants to SROA away.
assert!(to_sroa.len() < 2);
for (constant_node_id, constant_id) in to_sroa {
// Get the field constants to replace the product constant with.
let product_constant = constants[constant_id.idx()].clone();
let constant_fields = product_constant
.try_product_fields(types, constants)
.unwrap();
println!("{:?}", constant_fields);
// DFS to find all data nodes that use the product constant.
let to_replace = sroa_dfs(constant_node_id, function, def_use);
println!("{:?}", to_replace);
// Assemble a mapping from old node IDs acting on the product constant to
// new node IDs operating on the field constants.
let old_to_new_id_map: HashMap<NodeID, Vec<NodeID>> = to_replace
.iter()
.map(|old_id| match function.nodes[old_id.idx()] {
Node::Phi {
control: _,
data: _,
}
| Node::Reduce {
control: _,
init: _,
reduct: _,
}
| Node::Constant { id: _ }
| Node::Ternary {
op: _,
first: _,
second: _,
third: _,
}
| Node::Write {
collect: _,
data: _,
indices: _,
} => {
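// Allocate fresh node IDs now, filling the slots with Node::Start
// placeholders; the real per-field nodes are constructed in the
// replacement loop below.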
let new_ids = (0..constant_fields.len())
.map(|_| {
let id = NodeID::new(function.nodes.len());
function.nodes.push(Node::Start);
id
})
.collect();
(*old_id, new_ids)
}
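// Reads of a broken-up product become direct uses of the per-field
// nodes, so no new node IDs are allocated for them.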
Node::Read {
collect: _,
indices: _,
} => (*old_id, vec![]),
_ => panic!("PANIC: Invalid node using a constant product found during SROA."),
})
.collect();
// Replace the old nodes with the new nodes. Since the new node IDs have
// already been allocated, the to-be-replaced nodes can be processed in an
// arbitrary order at this point.
for (old_id, new_ids) in &old_to_new_id_map {
// First, add the new nodes to the node list.
let node = function.nodes[old_id.idx()].clone();
match node {
// Replace the original constant with constants for each field.
Node::Constant { id: _ } => {
for (new_id, field_id) in zip(new_ids.iter(), constant_fields.iter()) {
function.nodes[new_id.idx()] = Node::Constant { id: *field_id };
}
}
// Replace writes that use the constant as their data input with a chain
// of writes writing the individual constant fields. TODO: handle the
// case where the constant is the collect use of the write node.
Node::Write {
collect,
data,
ref indices,
} => {
// Create the write chain.
assert!(old_to_new_id_map.contains_key(&data), "PANIC: Can't handle case where write node depends on constant to SROA in the collect use yet.");
let mut collect_def = collect;
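// Thread a chain of writes: each new write appends Field(idx) to the
// original indices and stores the idx-th new data def into the running
// collect value.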
for (idx, (new_id, new_data_def)) in
zip(new_ids.iter(), old_to_new_id_map[&data].iter()).enumerate()
{
let mut new_indices = indices.clone().into_vec();
new_indices.push(Index::Field(idx));
function.nodes[new_id.idx()] = Node::Write {
collect: collect_def,
data: *new_data_def,
indices: new_indices.into_boxed_slice(),
};
collect_def = *new_id;
}
// Replace uses of the old write with the new write.
for user in def_use.get_users(*old_id) {
get_uses_mut(&mut function.nodes[user.idx()]).map(*old_id, collect_def);
}
}
_ => todo!(),
}
// Delete the old node by overwriting it with a placeholder (Node::Start).
function.nodes[old_id.idx()] = Node::Start;
}
}
}
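/*
* DFS from a product-producing node to collect all of the data nodes that
* need to be replaced when that product is broken up. Traversal stops at
* reads of the product and at writes that consume the product through their
* data input, since the outputs of those nodes are not the product being
* broken up.
*/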
fn sroa_dfs(src: NodeID, function: &Function, def_uses: &ImmutableDefUseMap) -> Vec<NodeID> {
// Initialize order vector and bitset for tracking which nodes have been
// visited.
let order = Vec::with_capacity(def_uses.num_nodes());
let visited = bitvec![u8, Lsb0; 0; def_uses.num_nodes()];
// The order vector and visited bitset are threaded through the argument /
// return pair of sroa_dfs_helper for ownership reasons.
let (order, _) = sroa_dfs_helper(src, src, function, def_uses, order, visited);
order
}
fn sroa_dfs_helper(
node: NodeID,
def: NodeID,
function: &Function,
def_uses: &ImmutableDefUseMap,
mut order: Vec<NodeID>,
mut visited: BitVec<u8, Lsb0>,
) -> (Vec<NodeID>, BitVec<u8, Lsb0>) {
if visited[node.idx()] {
// If already visited, return early.
(order, visited)
} else {
// Set visited to true.
visited.set(node.idx(), true);
// Before iterating users, push this node.
order.push(node);
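// Decide whether to recurse into this node's users. Phis, reduces,
// constants, and ternaries forward the product unchanged, so traversal
// continues past them; reads and writes that consume the product as data
// are frontier nodes, so it stops.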
match function.nodes[node.idx()] {
Node::Phi {
control: _,
data: _,
}
| Node::Reduce {
control: _,
init: _,
reduct: _,
}
| Node::Constant { id: _ }
| Node::Ternary {
op: _,
first: _,
second: _,
third: _,
} => {}
Node::Read {
collect,
indices: _,
} => {
assert_eq!(def, collect);
return (order, visited);
}
Node::Write {
collect,
data,
indices: _,
} => {
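// If the product flows in through the data input, this write's output is
// the surrounding collection rather than the product itself, so don't
// traverse its users.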
if def == data {
return (order, visited);
}
assert_eq!(def, collect);
}
_ => panic!("PANIC: Invalid node using a constant product found during SROA."),
}
// Iterate over users, if we shouldn't stop here.
for user in def_uses.get_users(node) {
(order, visited) = sroa_dfs_helper(*user, node, function, def_uses, order, visited);
}
(order, visited)
}
}