Skip to content
Snippets Groups Projects
Commit cc1726f2 authored by Aaron Councilman's avatar Aaron Councilman
Browse files

Intraprocedural SROA

parent 73012260
No related branches found
No related tags found
1 merge request!57Intraprocedural SROA
......@@ -728,6 +728,14 @@ impl Type {
None
}
}
/// If this type is a product, returns a view of its field types;
/// otherwise returns `None`.
pub fn try_product(&self) -> Option<&[TypeID]> {
    match self {
        Type::Product(ts) => Some(ts),
        _ => None,
    }
}
}
impl Constant {
......
......@@ -555,7 +555,11 @@ pub fn repair_plan(plan: &mut Plan, new_function: &Function, edits: &[Edit]) {
// Step 2: drop schedules for deleted nodes and create empty schedule lists
// for added nodes.
for deleted in total_edit.0.iter() {
plan.schedules[deleted.idx()] = vec![];
// Nodes that were created and deleted using the same editor don't have
// an existing schedule, so ignore them
if deleted.idx() < plan.schedules.len() {
plan.schedules[deleted.idx()] = vec![];
}
}
if !total_edit.1.is_empty() {
assert_eq!(
......
......@@ -543,16 +543,33 @@ impl PassManager {
let reverse_postorders = self.reverse_postorders.as_ref().unwrap();
let typing = self.typing.as_ref().unwrap();
for idx in 0..self.module.functions.len() {
sroa(
let constants_ref =
RefCell::new(std::mem::take(&mut self.module.constants));
let dynamic_constants_ref =
RefCell::new(std::mem::take(&mut self.module.dynamic_constants));
let types_ref = RefCell::new(std::mem::take(&mut self.module.types));
let mut editor = FunctionEditor::new(
&mut self.module.functions[idx],
&constants_ref,
&dynamic_constants_ref,
&types_ref,
&def_uses[idx],
&reverse_postorders[idx],
&typing[idx],
&self.module.types,
&mut self.module.constants,
);
sroa(&mut editor, &reverse_postorders[idx], &typing[idx]);
self.module.constants = constants_ref.take();
self.module.dynamic_constants = dynamic_constants_ref.take();
self.module.types = types_ref.take();
let edits = &editor.edits();
if let Some(plans) = self.plans.as_mut() {
repair_plan(&mut plans[idx], &self.module.functions[idx], edits);
}
let grave_mapping = self.module.functions[idx].delete_gravestones();
if let Some(plans) = self.plans.as_mut() {
plans[idx].fix_gravestones(&grave_mapping);
}
}
self.legacy_repair_plan();
self.clear_analyses();
}
Pass::Inline => {
......
extern crate bitvec;
extern crate hercules_ir;
use std::collections::HashMap;
use std::iter::zip;
use self::bitvec::prelude::*;
use std::collections::{BTreeMap, HashMap, LinkedList, VecDeque};
use self::hercules_ir::dataflow::*;
use self::hercules_ir::def_use::*;
use self::hercules_ir::ir::*;
use crate::*;
/*
* Top level function to run SROA, intraprocedurally. Product values can be used
* and created by a relatively small number of nodes. Here are *all* of them:
......@@ -20,11 +17,11 @@ use self::hercules_ir::ir::*;
* - Reduce: similarly to phis, reduce nodes can cycle product values through
* reduction loops - these get broken up into reduces on the fields
*
* + Return: can return a product - these are untouched, and are the sinks for
* unbroken product values
* - Return: can return a product - the product values will be constructed
* at the return site
*
* + Parameter: can introduce a product - these are untouched, and are the
* sources for unbroken product values
* - Parameter: can introduce a product - reads will be introduced for each
* field
*
* - Constant: can introduce a product - these are broken up into constants for
* the individual fields
......@@ -32,334 +29,796 @@ use self::hercules_ir::ir::*;
* - Ternary: the select ternary operator can select between products - these
* are broken up into ternary nodes for the individual fields
*
* + Call: the call node can use a product value as an argument to another
* function, and can produce a product value as a result - these are
* untouched, and are the sink and source for unbroken product values
* - Call: the call node can use a product value as an argument to another
* function, and can produce a product value as a result. Argument values
* will be constructed at the call site and the return value will be broken
* into individual fields
*
* - Read: the read node reads primitive fields from product values - these get
* replaced by a direct use of the field value from the broken product value,
* but are retained when the product value is unbroken
* replaced by a direct use of the field value
*
* - Write: the write node writes primitive fields in product values - these get
* replaced by a direct def of the field value from the broken product value,
* but are retained when the product value is unbroken
*
* The nodes above with the list marker "+" are retained for maintaining API/ABI
* compatibility with other Hercules functions and the host code. These are
* called "sink" or "source" nodes in comments below.
* replaced by a direct def of the field value
*/
pub fn sroa(
function: &mut Function,
def_use: &ImmutableDefUseMap,
reverse_postorder: &Vec<NodeID>,
typing: &Vec<TypeID>,
types: &Vec<Type>,
constants: &mut Vec<Constant>,
) {
// Determine which sources of product values we want to try breaking up. We
// can determine easily on the source side if a node produces a product that
// shouldn't be broken up by just examining the node type. However, the way
// that products are used is also important for determining if the product
// can be broken up. We backward dataflow this info to the sources of
// product values.
#[derive(PartialEq, Eq, Clone, Debug)]
enum ProductUseLattice {
// The product value used by this node is eventually used by a sink.
UsedBySink,
// This node uses multiple product values - the stored node ID indicates
// which is eventually used by a sink. This lattice value is produced by
// read and write nodes implementing partial indexing.
SpecificUsedBySink(NodeID),
// This node doesn't use a product node, or the product node it does use
// is not in turn used by a sink.
UnusedBySink,
pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: &Vec<TypeID>) {
// This map stores a map from NodeID to an index tree which can be used to lookup the NodeID
// that contains the corresponding fields of the original value
let mut field_map: HashMap<NodeID, IndexTree<NodeID>> = HashMap::new();
// First: determine all nodes which interact with products (as described above)
let mut product_nodes: Vec<NodeID> = vec![];
// We track call and return nodes separately since they (may) require constructing new products
// for the call's arguments or the return's value
let mut call_return_nodes: Vec<NodeID> = vec![];
let func = editor.func();
for node in reverse_postorder {
match func.nodes[node.idx()] {
Node::Phi { .. }
| Node::Reduce { .. }
| Node::Parameter { .. }
| Node::Constant { .. }
| Node::Write { .. }
| Node::Ternary {
first: _,
second: _,
third: _,
op: TernaryOperator::Select,
} if editor.get_type(types[node.idx()]).is_product() => product_nodes.push(*node),
Node::Read { collect, .. } if editor.get_type(types[collect.idx()]).is_product() => {
product_nodes.push(*node)
}
// We add all calls to the call/return list and check their arguments later
Node::Call { .. } => call_return_nodes.push(*node),
Node::Return { control: _, data }
if editor.get_type(types[data.idx()]).is_product() =>
{
call_return_nodes.push(*node)
}
_ => (),
}
}
impl Semilattice for ProductUseLattice {
fn meet(a: &Self, b: &Self) -> Self {
match (a, b) {
(Self::UsedBySink, _) | (_, Self::UsedBySink) => Self::UsedBySink,
(Self::SpecificUsedBySink(id1), Self::SpecificUsedBySink(id2)) => {
if id1 == id2 {
Self::SpecificUsedBySink(*id1)
// Next, we handle calls and returns. For returns, we will insert nodes that read each field of
// the returned product and then write them into a new product. These writes are not put into
// the list of product nodes since they must remain but the reads are so that they will be
// replaced later on.
// For calls, we do a similar process for each (product) argument. Additionally, if the call
// returns a product, we create reads for each field in that product and store it into our
// field map
for node in call_return_nodes {
match &editor.func().nodes[node.idx()] {
Node::Return { control, data } => {
assert!(editor.get_type(types[data.idx()]).is_product());
let control = *control;
let new_data =
reconstruct_product(editor, types[data.idx()], *data, &mut product_nodes);
editor.edit(|mut edit| {
edit.add_node(Node::Return {
control,
data: new_data,
});
edit.delete_node(node)
});
}
Node::Call {
control,
function,
dynamic_constants,
args,
} => {
let control = *control;
let function = *function;
let dynamic_constants = dynamic_constants.clone();
let args = args.clone();
// If the call returns a product, we generate reads for each field
let fields = if editor.get_type(types[node.idx()]).is_product() {
Some(generate_reads(editor, types[node.idx()], node))
} else {
None
};
let mut new_args = vec![];
for arg in args {
if editor.get_type(types[arg.idx()]).is_product() {
new_args.push(reconstruct_product(
editor,
types[arg.idx()],
arg,
&mut product_nodes,
));
} else {
Self::UsedBySink
new_args.push(arg);
}
}
(Self::SpecificUsedBySink(id), _) | (_, Self::SpecificUsedBySink(id)) => {
Self::SpecificUsedBySink(*id)
}
_ => Self::UnusedBySink,
editor.edit(|mut edit| {
let new_call = edit.add_node(Node::Call {
control,
function,
dynamic_constants,
args: new_args.into(),
});
let edit = edit.replace_all_uses(node, new_call)?;
let edit = edit.delete_node(node)?;
match fields {
None => {}
Some(fields) => {
field_map.insert(new_call, fields);
}
}
Ok(edit)
});
}
_ => panic!("Processing non-call or return node"),
}
}
fn bottom() -> Self {
Self::UsedBySink
}
#[derive(Debug)]
enum WorkItem {
Unhandled(NodeID),
AllocatedPhi {
control: NodeID,
data: Vec<NodeID>,
node: NodeID,
fields: IndexTree<NodeID>,
},
AllocatedReduce {
control: NodeID,
init: NodeID,
reduct: NodeID,
node: NodeID,
fields: IndexTree<NodeID>,
},
AllocatedTernary {
cond: NodeID,
thn: NodeID,
els: NodeID,
node: NodeID,
fields: IndexTree<NodeID>,
},
}
fn top() -> Self {
Self::UnusedBySink
// Now, we process the other nodes that deal with products.
// The first step is to assign new NodeIDs to the nodes that will be split into multiple: phi,
// reduce, parameter, constant, and ternary.
// We do this in several steps: first we break apart parameters and constants
let mut to_delete = vec![];
let mut worklist: VecDeque<WorkItem> = VecDeque::new();
for node in product_nodes {
match editor.func().nodes[node.idx()] {
Node::Parameter { .. } => {
field_map.insert(node, generate_reads(editor, types[node.idx()], node));
}
Node::Constant { id } => {
field_map.insert(node, generate_constant_fields(editor, id));
to_delete.push(node);
}
_ => {
worklist.push_back(WorkItem::Unhandled(node));
}
}
}
// Run dataflow analysis to find which product values are used by a sink.
let product_uses = backward_dataflow(function, def_use, reverse_postorder, |succ_outs, id| {
match function.nodes[id.idx()] {
Node::Return {
control: _,
data: _,
} => {
if types[typing[id.idx()].idx()].is_product() {
ProductUseLattice::UsedBySink
} else {
ProductUseLattice::UnusedBySink
// Now, we process the remaining nodes, allocating NodeIDs for them and updating the field_map.
// We track the current NodeID and add nodes to a set we maintain of nodes to add (since we
// need to add nodes in a particular order we wait to do that until the end). If we don't have
// enough information to process a particular node, we add it back to the worklist
let mut next_id: usize = editor.func().nodes.len();
let mut to_insert = BTreeMap::new();
let mut to_replace: Vec<(NodeID, NodeID)> = vec![];
while let Some(mut item) = worklist.pop_front() {
if let WorkItem::Unhandled(node) = item {
match &editor.func().nodes[node.idx()] {
// For phi, reduce, and ternary, we break them apart into separate nodes for each field
Node::Phi { control, data } => {
let control = *control;
let data = data.clone();
let fields = allocate_fields(editor, types[node.idx()], &mut next_id);
field_map.insert(node, fields.clone());
item = WorkItem::AllocatedPhi {
control,
data: data.into(),
node,
fields,
};
}
Node::Reduce {
control,
init,
reduct,
} => {
let control = *control;
let init = *init;
let reduct = *reduct;
let fields = allocate_fields(editor, types[node.idx()], &mut next_id);
field_map.insert(node, fields.clone());
item = WorkItem::AllocatedReduce {
control,
init,
reduct,
node,
fields,
};
}
Node::Ternary {
first,
second,
third,
..
} => {
let first = *first;
let second = *second;
let third = *third;
let fields = allocate_fields(editor, types[node.idx()], &mut next_id);
field_map.insert(node, fields.clone());
item = WorkItem::AllocatedTernary {
cond: first,
thn: second,
els: third,
node,
fields,
};
}
Node::Write {
collect,
data,
indices,
} => {
if let Some(index_map) = field_map.get(collect) {
if editor.get_type(types[data.idx()]).is_product() {
if let Some(data_idx) = field_map.get(data) {
field_map.insert(
node,
index_map.clone().replace(indices, data_idx.clone()),
);
to_delete.push(node);
} else {
worklist.push_back(WorkItem::Unhandled(node));
}
} else {
field_map.insert(node, index_map.clone().set(indices, *data));
to_delete.push(node);
}
} else {
worklist.push_back(WorkItem::Unhandled(node));
}
}
Node::Read { collect, indices } => {
if let Some(index_map) = field_map.get(collect) {
let read_info = index_map.lookup(indices);
match read_info {
IndexTree::Leaf(field) => {
to_replace.push((node, *field));
}
_ => {}
}
field_map.insert(node, read_info.clone());
to_delete.push(node);
} else {
worklist.push_back(WorkItem::Unhandled(node));
}
}
_ => panic!("Unexpected node type"),
}
Node::Call {
control: _,
function: _,
dynamic_constants: _,
args: _,
} => todo!(),
// For reads and writes, we only want to propagate the use of the
// product to the collect input of the node.
Node::Read {
collect,
indices: _,
}
| Node::Write {
collect,
data: _,
indices: _,
}
match item {
WorkItem::Unhandled(_) => {}
WorkItem::AllocatedPhi {
control,
data,
node,
fields,
} => {
let meet = succ_outs
.iter()
.fold(ProductUseLattice::top(), |acc, latt| {
ProductUseLattice::meet(&acc, latt)
let mut data_fields = vec![];
let mut ready = true;
for val in data.iter() {
if let Some(val_fields) = field_map.get(val) {
data_fields.push(val_fields);
} else {
ready = false;
break;
}
}
if ready {
fields.zip_list(data_fields).for_each(|idx, (res, data)| {
to_insert.insert(
res.idx(),
Node::Phi {
control,
data: data.into_iter().map(|n| **n).collect::<Vec<_>>().into(),
},
);
});
if meet == ProductUseLattice::UnusedBySink {
ProductUseLattice::UnusedBySink
to_delete.push(node);
} else {
ProductUseLattice::SpecificUsedBySink(collect)
worklist.push_back(WorkItem::AllocatedPhi {
control,
data,
node,
fields,
});
}
}
// For non-sink nodes.
_ => {
if function.nodes[id.idx()].is_control() {
return ProductUseLattice::UnusedBySink;
}
let meet = succ_outs
.iter()
.fold(ProductUseLattice::top(), |acc, latt| {
ProductUseLattice::meet(&acc, latt)
WorkItem::AllocatedReduce {
control,
init,
reduct,
node,
fields,
} => {
if let (Some(init_fields), Some(reduct_fields)) =
(field_map.get(&init), field_map.get(&reduct))
{
fields.zip(init_fields).zip(reduct_fields).for_each(
|idx, ((res, init), reduct)| {
to_insert.insert(
res.idx(),
Node::Reduce {
control,
init: **init,
reduct: **reduct,
},
);
},
);
to_delete.push(node);
} else {
worklist.push_back(WorkItem::AllocatedReduce {
control,
init,
reduct,
node,
fields,
});
if let ProductUseLattice::SpecificUsedBySink(meet_id) = meet {
if meet_id == id {
ProductUseLattice::UsedBySink
} else {
ProductUseLattice::UnusedBySink
}
}
}
WorkItem::AllocatedTernary {
cond,
thn,
els,
node,
fields,
} => {
if let (Some(thn_fields), Some(els_fields)) =
(field_map.get(&thn), field_map.get(&els))
{
fields
.zip(thn_fields)
.zip(els_fields)
.for_each(|idx, ((res, thn), els)| {
to_insert.insert(
res.idx(),
Node::Ternary {
first: cond,
second: **thn,
third: **els,
op: TernaryOperator::Select,
},
);
});
to_delete.push(node);
} else {
meet
worklist.push_back(WorkItem::AllocatedTernary {
cond,
thn,
els,
node,
fields,
});
}
}
}
}
// Create new nodes nodes
editor.edit(|mut edit| {
for (node_id, node) in to_insert {
let id = edit.add_node(node);
assert_eq!(node_id, id.idx());
}
Ok(edit)
});
// Only product values introduced as constants can be replaced by scalars.
let to_sroa: Vec<(NodeID, ConstantID)> = product_uses
.into_iter()
.enumerate()
.filter_map(|(node_idx, product_use)| {
if ProductUseLattice::UnusedBySink == product_use
&& types[typing[node_idx].idx()].is_product()
{
function.nodes[node_idx]
.try_constant()
.map(|cons_id| (NodeID::new(node_idx), cons_id))
} else {
None
// Replace uses of old reads
// Because a read that is being replaced could also be the node some other read is being
// replaced by (if the first read is then written into a product that is then read from again)
// we need to track what nodes have already been replaced (and by what) so we can properly
// replace uses without leaving users of nodes that should be deleted.
// replaced_by tracks what a node has been replaced by while replaced_of tracks everything that
// maps to a particular node (which is needed to maintain the data structure efficiently)
let mut replaced_by: HashMap<NodeID, NodeID> = HashMap::new();
let mut replaced_of: HashMap<NodeID, Vec<NodeID>> = HashMap::new();
for (old, new) in to_replace {
let new = match replaced_by.get(&new) {
Some(res) => *res,
None => new,
};
editor.edit(|edit| edit.replace_all_uses(old, new));
replaced_by.insert(old, new);
let mut replaced = vec![];
match replaced_of.get_mut(&old) {
Some(res) => {
std::mem::swap(res, &mut replaced);
}
})
.collect();
// Perform SROA. TODO: repair def-use when there are multiple product
// constants to SROA away.
assert!(to_sroa.len() < 2);
for (constant_node_id, constant_id) in to_sroa {
// Get the field constants to replace the product constant with.
let product_constant = constants[constant_id.idx()].clone();
let constant_fields = product_constant.try_product_fields().unwrap();
// DFS to find all data nodes that use the product constant.
let to_replace = sroa_dfs(constant_node_id, function, def_use);
// Assemble a mapping from old nodes IDs acting on the product constant
// to new nodes IDs operating on the field constants.
let old_to_new_id_map: HashMap<NodeID, Vec<NodeID>> = to_replace
.iter()
.map(|old_id| match function.nodes[old_id.idx()] {
Node::Phi {
control: _,
data: _,
None => {}
}
let new_of = match replaced_of.get_mut(&new) {
Some(res) => res,
None => {
replaced_of.insert(new, vec![]);
replaced_of.get_mut(&new).unwrap()
}
};
new_of.push(old);
for n in replaced {
replaced_by.insert(n, new);
new_of.push(n);
}
}
// Remove nodes
editor.edit(|mut edit| {
for node in to_delete {
edit = edit.delete_node(node)?
}
Ok(edit)
});
}
// An index tree is used to store results at many index lists: each path of
// `Index::Field` steps from the root selects one entry.
#[derive(Clone, Debug)]
enum IndexTree<T> {
// A single non-product value (e.g. the node holding one scalar field).
Leaf(T),
// One subtree per field of a product type, in field order.
Node(Vec<IndexTree<T>>),
}
impl<T: std::fmt::Debug> IndexTree<T> {
fn lookup(&self, idx: &[Index]) -> &IndexTree<T> {
self.lookup_idx(idx, 0)
}
fn lookup_idx(&self, idx: &[Index], n: usize) -> &IndexTree<T> {
if n < idx.len() {
if let Index::Field(i) = idx[n] {
match self {
IndexTree::Leaf(_) => panic!("Invalid field"),
IndexTree::Node(ts) => ts[i].lookup_idx(idx, n + 1),
}
| Node::Reduce {
control: _,
init: _,
reduct: _,
} else {
// TODO: This could be hit because of an array inside of a product
panic!("Error handling lookup of field");
}
} else {
self
}
}
fn set(self, idx: &[Index], val: T) -> IndexTree<T> {
self.set_idx(idx, val, 0)
}
fn set_idx(self, idx: &[Index], val: T, n: usize) -> IndexTree<T> {
if n < idx.len() {
if let Index::Field(i) = idx[n] {
match self {
IndexTree::Leaf(_) => panic!("Invalid field"),
IndexTree::Node(mut ts) => {
if i + 1 == ts.len() {
let t = ts.pop().unwrap();
ts.push(t.set_idx(idx, val, n + 1));
} else {
let mut t = ts.pop().unwrap();
std::mem::swap(&mut ts[i], &mut t);
t = t.set_idx(idx, val, n + 1);
std::mem::swap(&mut ts[i], &mut t);
ts.push(t);
}
IndexTree::Node(ts)
}
}
| Node::Constant { id: _ }
| Node::Ternary {
op: _,
first: _,
second: _,
third: _,
} else {
panic!("Error handling set of field");
}
} else {
IndexTree::Leaf(val)
}
}
fn replace(self, idx: &[Index], val: IndexTree<T>) -> IndexTree<T> {
self.replace_idx(idx, val, 0)
}
fn replace_idx(self, idx: &[Index], val: IndexTree<T>, n: usize) -> IndexTree<T> {
if n < idx.len() {
if let Index::Field(i) = idx[n] {
match self {
IndexTree::Leaf(_) => panic!("Invalid field"),
IndexTree::Node(mut ts) => {
if i + 1 == ts.len() {
let t = ts.pop().unwrap();
ts.push(t.replace_idx(idx, val, n + 1));
} else {
let mut t = ts.pop().unwrap();
std::mem::swap(&mut ts[i], &mut t);
t = t.replace_idx(idx, val, n + 1);
std::mem::swap(&mut ts[i], &mut t);
ts.push(t);
}
IndexTree::Node(ts)
}
}
| Node::Write {
collect: _,
data: _,
indices: _,
} => {
let new_ids = (0..constant_fields.len())
.map(|_| {
let id = NodeID::new(function.nodes.len());
function.nodes.push(Node::Start);
id
})
.collect();
(*old_id, new_ids)
} else {
panic!("Error handling set of field");
}
} else {
val
}
}
fn zip<'a, A>(self, other: &'a IndexTree<A>) -> IndexTree<(T, &'a A)> {
match (self, other) {
(IndexTree::Leaf(t), IndexTree::Leaf(a)) => IndexTree::Leaf((t, a)),
(IndexTree::Node(t), IndexTree::Node(a)) => {
let mut fields = vec![];
for (t, a) in t.into_iter().zip(a.iter()) {
fields.push(t.zip(a));
}
Node::Read {
collect: _,
indices: _,
} => (*old_id, vec![]),
_ => panic!("PANIC: Invalid node using a constant product found during SROA."),
})
.collect();
// Replace the old nodes with the new nodes. Since we've already
// allocated the node IDs, at this point we can iterate through the to-
// replace nodes in an arbitrary order.
for (old_id, new_ids) in &old_to_new_id_map {
// First, add the new nodes to the node list.
let node = function.nodes[old_id.idx()].clone();
match node {
// Replace the original constant with constants for each field.
Node::Constant { id: _ } => {
for (new_id, field_id) in zip(new_ids.iter(), constant_fields.iter()) {
function.nodes[new_id.idx()] = Node::Constant { id: *field_id };
IndexTree::Node(fields)
}
_ => panic!("IndexTrees do not have the same fields, cannot zip"),
}
}
fn zip_list<'a, A>(self, others: Vec<&'a IndexTree<A>>) -> IndexTree<(T, Vec<&'a A>)> {
match self {
IndexTree::Leaf(t) => {
let mut res = vec![];
for other in others {
match other {
IndexTree::Leaf(a) => res.push(a),
_ => panic!("IndexTrees do not have the same fields, cannot zip"),
}
}
// Replace writes using the constant as the data use with a
// series of writes writing the invidiual constant fields. TODO:
// handle the case where the constant is the collect use of the
// write node.
Node::Write {
collect,
data,
ref indices,
} => {
// Create the write chain.
assert!(old_to_new_id_map.contains_key(&data), "PANIC: Can't handle case where write node depends on constant to SROA in the collect use yet.");
let mut collect_def = collect;
for (idx, (new_id, new_data_def)) in
zip(new_ids.iter(), old_to_new_id_map[&data].iter()).enumerate()
{
let mut new_indices = indices.clone().into_vec();
new_indices.push(Index::Field(idx));
function.nodes[new_id.idx()] = Node::Write {
collect: collect_def,
data: *new_data_def,
indices: new_indices.into_boxed_slice(),
};
collect_def = *new_id;
}
// Replace uses of the old write with the new write.
for user in def_use.get_users(*old_id) {
get_uses_mut(&mut function.nodes[user.idx()]).map(*old_id, collect_def);
IndexTree::Leaf((t, res))
}
IndexTree::Node(t) => {
let mut fields: Vec<Vec<&'a IndexTree<A>>> = vec![vec![]; t.len()];
for other in others {
match other {
IndexTree::Node(a) => {
for (i, a) in a.iter().enumerate() {
fields[i].push(a);
}
}
_ => panic!("IndexTrees do not have the same fields, cannot zip"),
}
}
_ => todo!(),
IndexTree::Node(
t.into_iter()
.zip(fields.into_iter())
.map(|(t, f)| t.zip_list(f))
.collect(),
)
}
}
}
// Calls `f` once per leaf, passing the index path (the list of
// `Index::Field`s leading to that leaf) along with the leaf's value.
fn for_each<F>(&self, mut f: F)
where
F: FnMut(&Vec<Index>, &T),
{
self.for_each_idx(&mut vec![], &mut f);
}
// Delete the old node.
function.nodes[old_id.idx()] = Node::Start;
// Recursive worker for `for_each`: `idx` is the index path accumulated so
// far; it is pushed/popped around each recursion so only one Vec is ever
// allocated.
fn for_each_idx<F>(&self, idx: &mut Vec<Index>, f: &mut F)
where
F: FnMut(&Vec<Index>, &T),
{
match self {
IndexTree::Leaf(t) => f(idx, t),
IndexTree::Node(ts) => {
for (i, t) in ts.iter().enumerate() {
idx.push(Index::Field(i));
t.for_each_idx(idx, f);
idx.pop();
}
}
}
}
}
fn sroa_dfs(src: NodeID, function: &Function, def_uses: &ImmutableDefUseMap) -> Vec<NodeID> {
// Initialize order vector and bitset for tracking which nodes have been
// visited.
let order = Vec::with_capacity(def_uses.num_nodes());
let visited = bitvec![u8, Lsb0; 0; def_uses.num_nodes()];
// Given a product value val of type typ, constructs a copy of that value by extracting all fields
// from that value and then writing them into a new constant
// This process also adds all the read nodes that are generated into the read_list so that the
// reads can be eliminated by later parts of SROA
fn reconstruct_product(
    editor: &mut FunctionEditor,
    typ: TypeID,
    val: NodeID,
    read_list: &mut Vec<NodeID>,
) -> NodeID {
    // Read every leaf field out of the existing value, then build a fresh
    // default constant of the same type to write those fields into.
    let field_reads = generate_reads(editor, typ, val);
    let const_id = generate_constant(editor, typ);
    // Materialize the constant as a node.
    let mut const_node = None;
    editor.edit(|mut edit| {
        const_node = Some(edit.add_node(Node::Constant { id: const_id }));
        Ok(edit)
    });
    // Chain one write per leaf field on top of the constant; record each read
    // so later phases of SROA can eliminate it.
    let mut collect = const_node.expect("Add node cannot fail");
    field_reads.for_each(|index: &Vec<Index>, field: &NodeID| {
        read_list.push(*field);
        editor.edit(|mut edit| {
            collect = edit.add_node(Node::Write {
                collect,
                data: *field,
                indices: index.clone().into(),
            });
            Ok(edit)
        });
    });
    collect
}
// Order and visited are threaded through arguments / return pair of
// sroa_dfs_helper for ownership reasons.
let (order, _) = sroa_dfs_helper(src, src, function, def_uses, order, visited);
order
// Given a node val of type typ, adds nodes to the function which read all (leaf) fields of val and
// returns an index tree mapping each index list to the node that reads that field
fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> IndexTree<NodeID> {
    // Start the recursion with an empty index path; no need for a temporary.
    generate_reads_at_index(editor, typ, val, vec![])
}
fn sroa_dfs_helper(
node: NodeID,
def: NodeID,
function: &Function,
def_uses: &ImmutableDefUseMap,
mut order: Vec<NodeID>,
mut visited: BitVec<u8, Lsb0>,
) -> (Vec<NodeID>, BitVec<u8, Lsb0>) {
if visited[node.idx()] {
// If already visited, return early.
(order, visited)
// Given a node val of type which at the indices idx has type typ, construct reads of all (leaf)
// fields within this sub-value of val and return the correspondence list
fn generate_reads_at_index(
editor: &mut FunctionEditor,
typ: TypeID,
val: NodeID,
idx: Vec<Index>,
) -> IndexTree<NodeID> {
let ts: Option<Vec<TypeID>> = if let Some(ts) = editor.get_type(typ).try_product() {
Some(ts.into())
} else {
// Set visited to true.
visited.set(node.idx(), true);
// Before iterating users, push this node.
order.push(node);
match function.nodes[node.idx()] {
Node::Phi {
control: _,
data: _,
}
| Node::Reduce {
control: _,
init: _,
reduct: _,
}
| Node::Constant { id: _ }
| Node::Ternary {
op: _,
first: _,
second: _,
third: _,
} => {}
Node::Read {
collect,
indices: _,
} => {
assert_eq!(def, collect);
return (order, visited);
}
Node::Write {
collect,
data,
indices: _,
} => {
if def == data {
return (order, visited);
}
assert_eq!(def, collect);
None
};
if let Some(ts) = ts {
// For product values, we will recurse down each of its fields with an extended index
// and the appropriate type of that field
let mut fields = vec![];
for (i, t) in ts.into_iter().enumerate() {
let mut new_idx = idx.clone();
new_idx.push(Index::Field(i));
fields.push(generate_reads_at_index(editor, t, val, new_idx));
}
IndexTree::Node(fields)
} else {
// For non-product types, we've reached a leaf so we generate the read and return it's
// information
let mut read_id = None;
editor.edit(|mut edit| {
read_id = Some(edit.add_node(Node::Read {
collect: val,
indices: idx.clone().into(),
}));
Ok(edit)
});
IndexTree::Leaf(read_id.expect("Add node canont fail"))
}
}
// Registers a constant with the editor inside a single `edit` and yields its
// ConstantID; used by `generate_constant` below to avoid repeating the edit
// boilerplate in every match arm.
macro_rules! add_const {
($editor:ident, $const:expr) => {{
let mut res = None;
$editor.edit(|mut edit| {
res = Some(edit.add_constant($const));
Ok(edit)
});
res.expect("Add constant cannot fail")
}};
}
// Given a type, builds a default ("zero") constant of that type, registers it
// with the editor, and returns its id. Recurses through products/summations.
fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID {
    let t = editor.get_type(typ).clone();
    match t {
        Type::Product(ts) => {
            // Build a default constant for every field, then the product itself.
            let mut cs = vec![];
            for t in ts {
                cs.push(generate_constant(editor, t));
            }
            add_const!(editor, Constant::Product(typ, cs.into()))
        }
        Type::Boolean => add_const!(editor, Constant::Boolean(false)),
        Type::Integer8 => add_const!(editor, Constant::Integer8(0)),
        Type::Integer16 => add_const!(editor, Constant::Integer16(0)),
        Type::Integer32 => add_const!(editor, Constant::Integer32(0)),
        Type::Integer64 => add_const!(editor, Constant::Integer64(0)),
        Type::UnsignedInteger8 => add_const!(editor, Constant::UnsignedInteger8(0)),
        Type::UnsignedInteger16 => add_const!(editor, Constant::UnsignedInteger16(0)),
        Type::UnsignedInteger32 => add_const!(editor, Constant::UnsignedInteger32(0)),
        Type::UnsignedInteger64 => add_const!(editor, Constant::UnsignedInteger64(0)),
        Type::Float32 => add_const!(editor, Constant::Float32(ordered_float::OrderedFloat(0.0))),
        Type::Float64 => add_const!(editor, Constant::Float64(ordered_float::OrderedFloat(0.0))),
        Type::Summation(ts) => {
            // Default summation: variant 0 holding that variant's default value.
            let const_id = generate_constant(editor, ts[0]);
            add_const!(editor, Constant::Summation(typ, 0, const_id))
        }
        // `_` instead of binding the element type: it isn't needed here.
        Type::Array(_, _) => add_const!(editor, Constant::Array(typ)),
        Type::Control => panic!("Cannot create constant of control type"),
    }
}
// Iterate over users, if we shouldn't stop here.
for user in def_uses.get_users(node) {
(order, visited) = sroa_dfs_helper(*user, node, function, def_uses, order, visited);
// Given a constant cnst adds node to the function which are the constant values of each field and
// returns a list of pairs of indices and the node that holds that index
fn generate_constant_fields(editor: &mut FunctionEditor, cnst: ConstantID) -> IndexTree<NodeID> {
    // Copy the field constants out first so the constant-table borrow ends
    // before we mutate the editor below.
    let field_constants: Option<Vec<ConstantID>> = editor
        .get_constant(cnst)
        .try_product_fields()
        .map(|fields| fields.into());
    match field_constants {
        Some(fields) => {
            // Product constant: recurse into every field.
            let mut subtrees = vec![];
            for field in fields {
                subtrees.push(generate_constant_fields(editor, field));
            }
            IndexTree::Node(subtrees)
        }
        None => {
            // Leaf constant: materialize it as a constant node.
            let mut node = None;
            editor.edit(|mut edit| {
                node = Some(edit.add_node(Node::Constant { id: cnst }));
                Ok(edit)
            });
            IndexTree::Leaf(node.expect("Add node cannot fail"))
        }
    }
}
// Given a type, return a list of the fields and new NodeIDs for them, with NodeIDs starting at the
// id provided
// Given a type, return an index tree of fresh NodeIDs for its leaf fields,
// numbered sequentially starting from `*id`; `*id` is advanced past the last
// id handed out so successive calls never collide.
fn allocate_fields(editor: &FunctionEditor, typ: TypeID, id: &mut usize) -> IndexTree<NodeID> {
    // Copy the field types out first so the type borrow is short-lived.
    let ts: Option<Vec<TypeID>> = if let Some(ts) = editor.get_type(typ).try_product() {
        Some(ts.into())
    } else {
        None
    };
    if let Some(ts) = ts {
        // Product: allocate ids for each field recursively.
        let mut fields = vec![];
        for t in ts {
            fields.push(allocate_fields(editor, t, id));
        }
        IndexTree::Node(fields)
    } else {
        // Leaf: hand out the next id.
        let node = *id;
        *id += 1;
        IndexTree::Leaf(NodeID::new(node))
    }
}
......@@ -69,11 +69,11 @@ impl fmt::Display for ErrorMessage {
match self {
ErrorMessage::SemanticError(errs) => {
for err in errs {
write!(f, "{}", err)?;
write!(f, "{}\n", err)?;
}
}
ErrorMessage::SchedulingError(msg) => {
write!(f, "{}", msg)?;
write!(f, "{}\n", msg)?;
}
}
Ok(())
......@@ -152,11 +152,20 @@ pub fn compile_ir(
pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
}
add_pass!(pm, verify, Inline);
// Run SROA pretty early (though after inlining which can make SROA more effective) so that
// CCP, GVN, etc. can work on the result of SROA
add_pass!(pm, verify, InterproceduralSROA);
add_pass!(pm, verify, SROA);
// We run phi-elim again because SROA can introduce new phis that might be able to be
// simplified
add_verified_pass!(pm, verify, PhiElim);
if x_dot {
pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
}
add_pass!(pm, verify, CCP);
add_pass!(pm, verify, DCE);
add_pass!(pm, verify, GVN);
add_pass!(pm, verify, DCE);
add_pass!(pm, verify, InterproceduralSROA);
if x_dot {
pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
}
......
fn test_call(x : i32, y : f32) -> (i32, f32) {
let res = (x, y);
for i = 0 to 10 {
res.0 += 1;
}
return res;
}
fn test(x : i32, y : f32) -> (f32, i32) {
let res = test_call(x, y);
return (res.1, res.0);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment