Skip to content
Snippets Groups Projects

SROA read chains

Merged Aaron Councilman requested to merge sroa-read-chains into main
3 files
+ 262
12
Compare changes
  • Side-by-side
  • Inline
Files
3
+ 244
11
@@ -33,6 +33,8 @@ use crate::*;
*
* - Read: the read node reads primitive fields from product values - these get
* replaced by a direct use of the field value
* A read can also extract a product from an array or sum; the value read out
* will be broken into individual fields (by individual reads from the array)
*
* - Write: the write node writes primitive fields in product values - these get
* replaced by a direct def of the field value
@@ -54,15 +56,12 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
// for the call's arguments or the return's value
let mut call_return_nodes: Vec<NodeID> = vec![];
let func = editor.func();
for node in reverse_postorder {
match func.nodes[node.idx()] {
match &editor.func().nodes[node.idx()] {
Node::Phi { .. }
| Node::Reduce { .. }
| Node::Parameter { .. }
| Node::Constant { .. }
| Node::Write { .. }
| Node::Ternary {
first: _,
second: _,
@@ -70,8 +69,211 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
op: TernaryOperator::Select,
} if editor.get_type(types[&node]).is_product() => product_nodes.push(*node),
Node::Read { collect, .. } if editor.get_type(types[&collect]).is_product() => {
product_nodes.push(*node)
Node::Write {
collect,
data,
indices,
} => {
let data = *data;
let collect = *collect;
// For a write, we may need to split it into two pieces if it contains a mix of
// field and non-field indices
let (fields_write, write_prod_into_non) = {
let mut fields = vec![];
let mut remainder = vec![];
let mut indices = indices.iter();
while let Some(idx) = indices.next() {
if idx.is_field() {
fields.push(idx.clone());
} else {
remainder.push(idx.clone());
remainder.extend(indices.cloned());
break;
}
}
if fields.is_empty() {
if editor.get_type(types[&data]).is_product() {
(None, Some((*node, collect, remainder)))
} else {
(None, None)
}
} else if remainder.is_empty() {
(Some(*node), None)
} else {
// Here we perform the split into two writes
// We need to find the type of the collection that will be extracted from
// the collection being modified when we read it at the fields index
let after_fields_type = type_at_index(editor, types[&collect], &fields);
let mut inner_collection = None;
let mut fields_write = None;
let mut remainder_write = None;
editor.edit(|mut edit| {
let read_inner = edit.add_node(Node::Read {
collect,
indices: fields.clone().into(),
});
types.insert(read_inner, after_fields_type);
product_nodes.push(read_inner);
inner_collection = Some(read_inner);
let rem_write = edit.add_node(Node::Write {
collect: read_inner,
data,
indices: remainder.clone().into(),
});
types.insert(rem_write, after_fields_type);
remainder_write = Some(rem_write);
let complete_write = edit.add_node(Node::Write {
collect,
data: rem_write,
indices: fields.into(),
});
types.insert(complete_write, types[&collect]);
fields_write = Some(complete_write);
edit = edit.replace_all_uses(*node, complete_write)?;
edit.delete_node(*node)
});
let inner_collection = inner_collection.unwrap();
let fields_write = fields_write.unwrap();
let remainder_write = remainder_write.unwrap();
if editor.get_type(types[&data]).is_product() {
(
Some(fields_write),
Some((remainder_write, inner_collection, remainder)),
)
} else {
(Some(fields_write), None)
}
}
};
if let Some(node) = fields_write {
product_nodes.push(node);
}
if let Some((write_node, collection, index)) = write_prod_into_non {
let node = write_node;
// If we're writing a product into a non-product we need to replace the write
// by a sequence of writes that read each field of the product and write them
// into the collection, then those write nodes can be ignored for SROA but the
// reads will be handled by SROA
// The value being written must be the data and so must be a product
assert!(editor.get_type(types[&data]).is_product());
let fields = generate_reads(editor, types[&data], data);
let mut collection = collection;
let collection_type = types[&collection];
fields.for_each(|field: &Vec<Index>, val: &NodeID| {
product_nodes.push(*val);
editor.edit(|mut edit| {
collection = edit.add_node(Node::Write {
collect: collection,
data: *val,
indices: index
.iter()
.chain(field)
.cloned()
.collect::<Vec<_>>()
.into(),
});
types.insert(collection, collection_type);
Ok(edit)
});
});
editor.edit(|mut edit| {
edit = edit.replace_all_uses(node, collection)?;
edit.delete_node(node)
});
}
}
Node::Read { collect, indices } => {
// For a read, we split the read into a series of reads where each piece has either
// only field reads or no field reads. Those with fields are the only ones
// considered during SROA but any read whose collection is not a product but
// produces a product (i.e. if there's an array of products) then following the
// read we replace the read that produces a product by reads of each field and add
// that information to the node map for the rest of SROA (this produces some reads
// that mix types of indices; since we only read leaves, that's okay because those
// reads are not handled by SROA)
let indices = indices
.chunk_by(|i, j| i.is_field() && j.is_field())
.collect::<Vec<_>>();
let (field_reads, non_fields_produce_prod) = {
if indices.len() == 0 {
// If there are no indices then there were no indices originally, this is
// only used with clones of arrays
(vec![], vec![])
} else if indices.len() == 1 {
// If once we perform chunking there's only one set of indices, we can just
// use the original node
if indices[0][0].is_field() {
(vec![*node], vec![])
} else if editor.get_type(types[node]).is_product() {
(vec![], vec![*node])
} else {
(vec![], vec![])
}
} else {
let mut field_reads = vec![];
let mut non_field = vec![];
// To construct the multiple reads we need to track the current collection
// and the type of that collection
let mut collect = *collect;
let mut typ = types[&collect];
let indices = indices
.into_iter()
.map(|i| i.into_iter().cloned().collect::<Vec<_>>())
.collect::<Vec<_>>();
for index in indices {
let is_field_read = index[0].is_field();
let field_type = type_at_index(editor, typ, &index);
editor.edit(|mut edit| {
collect = edit.add_node(Node::Read {
collect,
indices: index.into(),
});
types.insert(collect, field_type);
typ = field_type;
Ok(edit)
});
if is_field_read {
field_reads.push(collect);
} else if editor.get_type(typ).is_product() {
non_field.push(collect);
}
}
// Replace all uses of the original read (with mixed indices) with the
// newly constructed reads
editor.edit(|mut edit| {
edit = edit.replace_all_uses(*node, collect)?;
edit.delete_node(*node)
});
(field_reads, non_field)
}
};
product_nodes.extend(field_reads);
for node in non_fields_produce_prod {
field_map.insert(node, generate_reads(editor, types[&node], node));
}
}
// We add all calls to the call/return list and check their arguments later
@@ -516,8 +718,7 @@ impl<T: std::fmt::Debug> IndexTree<T> {
IndexTree::Node(ts) => ts[i].lookup_idx(idx, n + 1),
}
} else {
// TODO: This could be hit because of an array inside of a product
panic!("Error handling lookup of field");
panic!("Cannot process a mix of fields and other read indices; pre-processing should prevent this");
}
} else {
self
@@ -548,7 +749,7 @@ impl<T: std::fmt::Debug> IndexTree<T> {
}
}
} else {
panic!("Error handling set of field");
panic!("Cannot process a mix of fields and other read indices; pre-processing should prevent this");
}
} else {
IndexTree::Leaf(val)
@@ -579,7 +780,7 @@ impl<T: std::fmt::Debug> IndexTree<T> {
}
}
} else {
panic!("Error handling set of field");
panic!("Cannot process a mix of fields and other read indices; pre-processing should prevent this");
}
} else {
val
@@ -658,6 +859,38 @@ impl<T: std::fmt::Debug> IndexTree<T> {
}
}
/// Given the editor, the type of some collection, and a list of indices used to access that
/// collection, returns the `TypeID` of the value produced by accessing the collection at those
/// indices.
///
/// # Panics
/// Panics if an index does not match the structure of the type at that point (a `Field` index
/// into a non-product, a `Variant` index into a non-summation, or a `Position` index into a
/// non-array), or if a `Position` index does not cover every array dimension — any of which
/// would indicate a typechecking failure.
fn type_at_index(editor: &FunctionEditor, typ: TypeID, idx: &[Index]) -> TypeID {
    let mut typ = typ;
    for index in idx {
        match index {
            Index::Field(i) => {
                let Type::Product(ref ts) = *editor.get_type(typ) else {
                    panic!("Accessing a field of a non-product type; did typechecking succeed?");
                };
                typ = ts[*i];
            }
            Index::Variant(i) => {
                let Type::Summation(ref ts) = *editor.get_type(typ) else {
                    panic!(
                        "Accessing a variant of a non-summation type; did typechecking succeed?"
                    );
                };
                typ = ts[*i];
            }
            Index::Position(pos) => {
                let Type::Array(elem, ref dims) = *editor.get_type(typ) else {
                    panic!("Accessing an array position of a non-array type; did typechecking succeed?");
                };
                // A position index must be fully applied: the element type is only reached
                // when every array dimension is indexed at once.
                assert!(pos.len() == dims.len(), "Read mismatch array dimensions");
                typ = elem;
            }
        }
    }
    // The type remaining after applying every index is the accessed type.
    typ
}
// Given a product value val of type typ, constructs a copy of that value by extracting all fields
// from that value and then writing them into a new constant
// This process also adds all the read nodes that are generated into the read_list so that the
@@ -696,7 +929,7 @@ fn reconstruct_product(
}
// Given a node val of type typ, adds nodes to the function which read all (leaf) fields of val and
// returns a list of pairs of the indices and the node that reads that index
// returns an IndexTree that tracks the nodes reading each leaf field
fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> IndexTree<NodeID> {
let res = generate_reads_at_index(editor, typ, val, vec![]);
res
Loading