// sroa.rs
use std::collections::{BTreeMap, HashMap, VecDeque};
use hercules_ir::ir::*;
use crate::*;
/*
* Top level function to run SROA, intraprocedurally. Product values can be used
* and created by a relatively small number of nodes. Here are *all* of them:
*
* - Phi: can merge SSA values of products - these get broken up into phis on
* the individual fields
*
* - Reduce: similarly to phis, reduce nodes can cycle product values through
* reduction loops - these get broken up into reduces on the fields
*
* - Return: can return a product - the product values will be constructed
* at the return site
*
* - Parameter: can introduce a product - reads will be introduced for each
* field
*
* - Constant: can introduce a product - these are broken up into constants for
* the individual fields
*
* - Ternary: the select ternary operator can select between products - these
* are broken up into ternary nodes for the individual fields
*
* - Call: the call node can use a product value as an argument to another
* function, argument values will be constructed at the call site
*
* - DataProjection: data projection nodes can produce a product value that was
* returned by a function, we will break the value into individual fields
*
* - Read: the read node reads primitive fields from product values - these get
* replaced by a direct use of the field value
* A read can also extract a product from an array or sum; the value read out
* will be broken into individual fields (by individual reads from the array)
*
* - Write: the write node writes primitive fields in product values - these get
* replaced by a direct def of the field value
*
* The allow_sroa_arrays variable controls whether products that contain arrays
* will be broken into pieces. This option is useful to have since breaking
* these products up can be expensive if it requires destructing and
* reconstructing the product at any point.
*
* TODO: Handle partial selections (i.e. immutable nodes). This will involve
* actually tracking each source and use of a product and verifying that all of
* the nodes involved are mutable.
*/
/// Runs SROA on a single function (see the module comment above for the full
/// list of product-interacting nodes this pass handles).
///
/// * `reverse_postorder` - nodes of the function in reverse postorder
/// * `types` - the type of each node, indexed by the node's index
/// * `allow_sroa_arrays` - whether products containing arrays may be split up
pub fn sroa(
    editor: &mut FunctionEditor,
    reverse_postorder: &Vec<NodeID>,
    types: &Vec<TypeID>,
    allow_sroa_arrays: bool,
) {
    // Shadow the type list with a map so that types of newly created nodes can
    // be recorded as we go (the Vec is indexed by the original node count only).
    let mut types: HashMap<NodeID, TypeID> = types
        .iter()
        .enumerate()
        .map(|(i, t)| (NodeID::new(i), *t))
        .collect();
    // A type is SROA-able if it is a product and (when arrays are disallowed)
    // does not transitively contain an array.
    let can_sroa_type = |editor: &FunctionEditor, typ: TypeID| {
        editor.get_type(typ).is_product()
            && (allow_sroa_arrays || !type_contains_array(editor, typ))
    };
    // This map stores a map from NodeID to an index tree which can be used to lookup the NodeID
    // that contains the corresponding fields of the original value
    let mut field_map: HashMap<NodeID, IndexTree<NodeID>> = HashMap::new();
    // First: determine all nodes which interact with products (as described above)
    let mut product_nodes: Vec<NodeID> = vec![];
    // We track call, data projection, and return nodes separately since they (may) require
    // constructing new products for the call's arguments, data projection's value, or a
    // returned value
    let mut call_return_nodes: Vec<NodeID> = vec![];
    for node in reverse_postorder {
        match &editor.func().nodes[node.idx()] {
            Node::Phi { .. }
            | Node::Reduce { .. }
            | Node::Parameter { .. }
            | Node::Constant { .. }
            | Node::Ternary {
                first: _,
                second: _,
                third: _,
                op: TernaryOperator::Select,
            } if can_sroa_type(editor, types[&node]) => product_nodes.push(*node),
            Node::Write {
                collect,
                data,
                indices,
            } => {
                let data = *data;
                let collect = *collect;
                // For a write, we may need to split it into two pieces if it contains a mix of
                // field and non-field indices
                let (fields_write, write_prod_into_non) = {
                    // Split the index list into a leading run of field indices
                    // and everything from the first non-field index onwards.
                    let mut fields = vec![];
                    let mut remainder = vec![];
                    if can_sroa_type(editor, types[&node]) {
                        let mut indices = indices.iter();
                        while let Some(idx) = indices.next() {
                            if idx.is_field() {
                                fields.push(idx.clone());
                            } else {
                                remainder.push(idx.clone());
                                remainder.extend(indices.cloned());
                                break;
                            }
                        }
                    } else {
                        remainder.extend_from_slice(indices);
                    }
                    if fields.is_empty() {
                        if can_sroa_type(editor, types[&data]) {
                            (None, Some((*node, collect, remainder)))
                        } else {
                            (None, None)
                        }
                    } else if remainder.is_empty() {
                        (Some(*node), None)
                    } else {
                        // Here we perform the split into two writes
                        // We need to find the type of the collection that will be extracted from
                        // the collection being modified when we read it at the fields index
                        let after_fields_type = type_at_index(editor, types[&collect], &fields);
                        let mut inner_collection = None;
                        let mut fields_write = None;
                        let mut remainder_write = None;
                        editor.edit(|mut edit| {
                            // Read the inner collection at the field indices...
                            let read_inner = edit.add_node(Node::Read {
                                collect,
                                indices: fields.clone().into(),
                            });
                            types.insert(read_inner, after_fields_type);
                            product_nodes.push(read_inner);
                            inner_collection = Some(read_inner);
                            // ...write the data into it at the remaining indices...
                            let rem_write = edit.add_node(Node::Write {
                                collect: read_inner,
                                data,
                                indices: remainder.clone().into(),
                            });
                            types.insert(rem_write, after_fields_type);
                            remainder_write = Some(rem_write);
                            // ...and write the result back at the field indices.
                            let complete_write = edit.add_node(Node::Write {
                                collect,
                                data: rem_write,
                                indices: fields.into(),
                            });
                            types.insert(complete_write, types[&collect]);
                            fields_write = Some(complete_write);
                            edit = edit.replace_all_uses(*node, complete_write)?;
                            edit.delete_node(*node)
                        });
                        let inner_collection = inner_collection.unwrap();
                        let fields_write = fields_write.unwrap();
                        let remainder_write = remainder_write.unwrap();
                        if editor.get_type(types[&data]).is_product() {
                            (
                                Some(fields_write),
                                Some((remainder_write, inner_collection, remainder)),
                            )
                        } else {
                            (Some(fields_write), None)
                        }
                    }
                };
                if let Some(node) = fields_write {
                    product_nodes.push(node);
                }
                if let Some((write_node, collection, index)) = write_prod_into_non {
                    let node = write_node;
                    // If we're writing a product into a non-product we need to replace the write
                    // by a sequence of writes that read each field of the product and write them
                    // into the collection, then those write nodes can be ignored for SROA but the
                    // reads will be handled by SROA
                    // The value being written must be the data and so must be a product
                    assert!(editor.get_type(types[&data]).is_product());
                    let fields = generate_reads(editor, types[&data], data);
                    let mut collection = collection;
                    let collection_type = types[&collection];
                    fields.for_each(|field: &Vec<Index>, val: &NodeID| {
                        product_nodes.push(*val);
                        editor.edit(|mut edit| {
                            collection = edit.add_node(Node::Write {
                                collect: collection,
                                data: *val,
                                indices: index
                                    .iter()
                                    .chain(field)
                                    .cloned()
                                    .collect::<Vec<_>>()
                                    .into(),
                            });
                            types.insert(collection, collection_type);
                            Ok(edit)
                        });
                    });
                    editor.edit(|mut edit| {
                        edit = edit.replace_all_uses(node, collection)?;
                        edit.delete_node(node)
                    });
                }
            }
            Node::Read { collect, indices } => {
                // For a read, we split the read into a series of reads where each piece has either
                // only field reads or no field reads. Those with fields are the only ones
                // considered during SROA but any read whose collection is not a product but
                // produces a product (i.e. if there's an array of products) then following the
                // read we replace the read that produces a product by reads of each field and add
                // that information to the node map for the rest of SROA (this produces some reads
                // that mix types of indices, since we only read leaves but that's okay since those
                // reads are not handled by SROA)
                let indices = if can_sroa_type(editor, types[collect]) {
                    indices
                        .chunk_by(|i, j| i.is_field() == j.is_field())
                        .collect::<Vec<_>>()
                } else {
                    vec![indices.as_ref()]
                };
                let (field_reads, non_fields_produce_prod) = {
                    if indices.len() == 0 {
                        // If there are no indices then there were no indices originally, this is
                        // only used with clones of arrays
                        (vec![], vec![])
                    } else if indices.len() == 1 {
                        // If once we perform chunking there's only one set of indices, we can just
                        // use the original node
                        if can_sroa_type(editor, types[collect]) {
                            (vec![*node], vec![])
                        } else if can_sroa_type(editor, types[node]) {
                            (vec![], vec![*node])
                        } else {
                            (vec![], vec![])
                        }
                    } else {
                        let mut field_reads = vec![];
                        let mut non_field = vec![];
                        // To construct the multiple reads we need to track the current collection
                        // and the type of that collection
                        let mut collect = *collect;
                        let mut typ = types[&collect];
                        // Clone the chunked index lists so we no longer borrow the function
                        let indices = indices
                            .into_iter()
                            .map(|i| i.into_iter().cloned().collect::<Vec<_>>())
                            .collect::<Vec<_>>();
                        for index in indices {
                            let is_field_read = index[0].is_field();
                            let field_type = type_at_index(editor, typ, &index);
                            editor.edit(|mut edit| {
                                collect = edit.add_node(Node::Read {
                                    collect,
                                    indices: index.into(),
                                });
                                types.insert(collect, field_type);
                                typ = field_type;
                                Ok(edit)
                            });
                            if is_field_read {
                                field_reads.push(collect);
                            } else if editor.get_type(typ).is_product() {
                                non_field.push(collect);
                            }
                        }
                        // Replace all uses of the original read (with mixed indices) with the
                        // newly constructed reads
                        editor.edit(|mut edit| {
                            edit = edit.replace_all_uses(*node, collect)?;
                            edit.delete_node(*node)
                        });
                        (field_reads, non_field)
                    }
                };
                product_nodes.extend(field_reads);
                for node in non_fields_produce_prod {
                    field_map.insert(node, generate_reads(editor, types[&node], node));
                }
            }
            // We add all calls and returns to the call/return list and check their
            // arguments/return values later
            Node::Call { .. } | Node::Return { .. } => call_return_nodes.push(*node),
            // We add DataProjection nodes that produce SROAable values
            Node::DataProjection { .. } if can_sroa_type(editor, types[&node]) => {
                call_return_nodes.push(*node);
            }
            _ => (),
        }
    }
    // Next, we handle calls and returns. For returns, for each returned value that is a product,
    // we will insert nodes that read each field of it and then write them into a new product.
    // The writes we create are not put into the list of product nodes since they must remain but
    // the reads are put in the list so that they will be replaced later on.
    // For calls, we do a similar process for each (product) argument.
    // For data projection that produce product values, we create reads for each field of that
    // product and store it into our field map
    for node in call_return_nodes {
        match &editor.func().nodes[node.idx()] {
            Node::Return { control, data } => {
                let control = *control;
                let data = data.to_vec();
                // Reconstruct each SROA-able returned value; track whether
                // anything changed so we only rebuild the return if needed.
                let (new_data, changed) =
                    data.into_iter()
                        .fold((vec![], false), |(mut vals, changed), val_id| {
                            if !can_sroa_type(editor, types[&val_id]) {
                                vals.push(val_id);
                                (vals, changed)
                            } else {
                                vals.push(reconstruct_product(
                                    editor,
                                    types[&val_id],
                                    val_id,
                                    &mut product_nodes,
                                ));
                                (vals, true)
                            }
                        });
                if changed {
                    editor.edit(|mut edit| {
                        let new_return = edit.add_node(Node::Return {
                            control,
                            data: new_data.into(),
                        });
                        edit.sub_edit(node, new_return);
                        edit.delete_node(node)
                    });
                }
            }
            Node::Call {
                control,
                function,
                dynamic_constants,
                args,
            } => {
                let control = *control;
                let function = *function;
                let dynamic_constants = dynamic_constants.clone();
                let args = args.to_vec();
                // Same reconstruction as for returns, but per call argument.
                let (new_args, changed) =
                    args.into_iter()
                        .fold((vec![], false), |(mut vals, changed), arg| {
                            if !can_sroa_type(editor, types[&arg]) {
                                vals.push(arg);
                                (vals, changed)
                            } else {
                                vals.push(reconstruct_product(
                                    editor,
                                    types[&arg],
                                    arg,
                                    &mut product_nodes,
                                ));
                                (vals, true)
                            }
                        });
                if changed {
                    editor.edit(|mut edit| {
                        let new_call = edit.add_node(Node::Call {
                            control,
                            function,
                            dynamic_constants,
                            args: new_args.into(),
                        });
                        edit.sub_edit(node, new_call);
                        let edit = edit.replace_all_uses(node, new_call)?;
                        let edit = edit.delete_node(node)?;
                        Ok(edit)
                    });
                }
            }
            Node::DataProjection { .. } => {
                assert!(can_sroa_type(editor, types[&node]));
                field_map.insert(node, generate_reads(editor, types[&node], node));
            }
            _ => panic!("Processing non-call or return node"),
        }
    }
    // Work items for the product-node worklist below: a node is either not yet
    // processed, or has had fresh NodeIDs allocated for its per-field pieces
    // and is waiting for its inputs' field maps to become available.
    #[derive(Debug)]
    enum WorkItem {
        Unhandled(NodeID),
        AllocatedPhi {
            control: NodeID,
            data: Vec<NodeID>,
            node: NodeID,
            fields: IndexTree<NodeID>,
        },
        AllocatedReduce {
            control: NodeID,
            init: NodeID,
            reduct: NodeID,
            node: NodeID,
            fields: IndexTree<NodeID>,
        },
        AllocatedTernary {
            cond: NodeID,
            thn: NodeID,
            els: NodeID,
            node: NodeID,
            fields: IndexTree<NodeID>,
        },
    }
    // Now, we process the other nodes that deal with products.
    // The first step is to assign new NodeIDs to the nodes that will be split into multiple: phi,
    // reduce, parameter, constant, and ternary.
    // We do this in several steps: first we break apart parameters and constants
    let mut to_delete = vec![];
    let mut worklist: VecDeque<WorkItem> = VecDeque::new();
    for node in product_nodes {
        match editor.func().nodes[node.idx()] {
            Node::Parameter { .. } => {
                field_map.insert(node, generate_reads(editor, types[&node], node));
            }
            Node::Constant { id } => {
                field_map.insert(node, generate_constant_fields(editor, id, node));
                to_delete.push(node);
            }
            _ => {
                worklist.push_back(WorkItem::Unhandled(node));
            }
        }
    }
    // Now, we process the remaining nodes, allocating NodeIDs for them and updating the field_map.
    // We track the current NodeID and add nodes to a set we maintain of nodes to add (since we
    // need to add nodes in a particular order we wait to do that until the end). If we don't have
    // enough information to process a particular node, we add it back to the worklist
    let mut next_id: usize = editor.func().nodes.len();
    let mut to_insert = BTreeMap::new();
    let mut to_replace: Vec<(NodeID, NodeID)> = vec![];
    while let Some(mut item) = worklist.pop_front() {
        if let WorkItem::Unhandled(node) = item {
            match &editor.func().nodes[node.idx()] {
                // For phi, reduce, and ternary, we break them apart into separate nodes for each field
                Node::Phi { control, data } => {
                    let control = *control;
                    let data = data.clone();
                    let fields = allocate_fields(editor, types[&node], &mut next_id);
                    field_map.insert(node, fields.clone());
                    item = WorkItem::AllocatedPhi {
                        control,
                        data: data.into(),
                        node,
                        fields,
                    };
                }
                Node::Reduce {
                    control,
                    init,
                    reduct,
                } => {
                    let control = *control;
                    let init = *init;
                    let reduct = *reduct;
                    let fields = allocate_fields(editor, types[&node], &mut next_id);
                    field_map.insert(node, fields.clone());
                    item = WorkItem::AllocatedReduce {
                        control,
                        init,
                        reduct,
                        node,
                        fields,
                    };
                }
                Node::Ternary {
                    first,
                    second,
                    third,
                    ..
                } => {
                    let first = *first;
                    let second = *second;
                    let third = *third;
                    let fields = allocate_fields(editor, types[&node], &mut next_id);
                    field_map.insert(node, fields.clone());
                    item = WorkItem::AllocatedTernary {
                        cond: first,
                        thn: second,
                        els: third,
                        node,
                        fields,
                    };
                }
                Node::Write {
                    collect,
                    data,
                    indices,
                } => {
                    // A write is resolved once its collection's fields are known;
                    // the result's field map is the collection's with the written
                    // field(s) replaced.
                    if let Some(index_map) = field_map.get(collect) {
                        if can_sroa_type(editor, types[&data]) {
                            if let Some(data_idx) = field_map.get(data) {
                                field_map.insert(
                                    node,
                                    index_map.clone().replace(indices, data_idx.clone()),
                                );
                                to_delete.push(node);
                            } else {
                                worklist.push_back(WorkItem::Unhandled(node));
                            }
                        } else {
                            field_map.insert(node, index_map.clone().set(indices, *data));
                            to_delete.push(node);
                        }
                    } else {
                        worklist.push_back(WorkItem::Unhandled(node));
                    }
                }
                Node::Read { collect, indices } => {
                    // A read is resolved by looking its indices up in the
                    // collection's field map; a leaf result means the read can
                    // be replaced by a direct use of the field's node.
                    if let Some(index_map) = field_map.get(collect) {
                        let read_info = index_map.lookup(indices);
                        match read_info {
                            IndexTree::Leaf(field) => {
                                to_replace.push((node, *field));
                            }
                            _ => {}
                        }
                        field_map.insert(node, read_info.clone());
                        to_delete.push(node);
                    } else {
                        worklist.push_back(WorkItem::Unhandled(node));
                    }
                }
                _ => panic!("Unexpected node type"),
            }
        }
        match item {
            WorkItem::Unhandled(_) => {}
            WorkItem::AllocatedPhi {
                control,
                data,
                node,
                fields,
            } => {
                // A phi is ready once every incoming value's fields are known.
                let mut data_fields = vec![];
                let mut ready = true;
                for val in data.iter() {
                    if let Some(val_fields) = field_map.get(val) {
                        data_fields.push(val_fields);
                    } else {
                        ready = false;
                        break;
                    }
                }
                if ready {
                    fields.zip_list(data_fields).for_each(|_, (res, data)| {
                        to_insert.insert(
                            res.idx(),
                            Node::Phi {
                                control,
                                data: data.into_iter().map(|n| **n).collect::<Vec<_>>().into(),
                            },
                        );
                    });
                    to_delete.push(node);
                } else {
                    worklist.push_back(WorkItem::AllocatedPhi {
                        control,
                        data,
                        node,
                        fields,
                    });
                }
            }
            WorkItem::AllocatedReduce {
                control,
                init,
                reduct,
                node,
                fields,
            } => {
                if let (Some(init_fields), Some(reduct_fields)) =
                    (field_map.get(&init), field_map.get(&reduct))
                {
                    fields.zip(init_fields).zip(reduct_fields).for_each(
                        |_, ((res, init), reduct)| {
                            to_insert.insert(
                                res.idx(),
                                Node::Reduce {
                                    control,
                                    init: **init,
                                    reduct: **reduct,
                                },
                            );
                        },
                    );
                    to_delete.push(node);
                } else {
                    worklist.push_back(WorkItem::AllocatedReduce {
                        control,
                        init,
                        reduct,
                        node,
                        fields,
                    });
                }
            }
            WorkItem::AllocatedTernary {
                cond,
                thn,
                els,
                node,
                fields,
            } => {
                if let (Some(thn_fields), Some(els_fields)) =
                    (field_map.get(&thn), field_map.get(&els))
                {
                    fields
                        .zip(thn_fields)
                        .zip(els_fields)
                        .for_each(|_, ((res, thn), els)| {
                            to_insert.insert(
                                res.idx(),
                                Node::Ternary {
                                    first: cond,
                                    second: **thn,
                                    third: **els,
                                    op: TernaryOperator::Select,
                                },
                            );
                        });
                    to_delete.push(node);
                } else {
                    worklist.push_back(WorkItem::AllocatedTernary {
                        cond,
                        thn,
                        els,
                        node,
                        fields,
                    });
                }
            }
        }
    }
    // Create the new nodes; the BTreeMap iterates in key order so each node is
    // added exactly at the index that was allocated for it above.
    editor.edit(|mut edit| {
        for (node_id, node) in to_insert {
            let id = edit.add_node(node);
            assert_eq!(node_id, id.idx());
        }
        Ok(edit)
    });
    // Replace uses of old reads
    // Because a read that is being replaced could also be the node some other read is being
    // replaced by (if the first read is then written into a product that is then read from again)
    // we need to track what nodes have already been replaced (and by what) so we can properly
    // replace uses without leaving users of nodes that should be deleted.
    // replaced_by tracks what a node has been replaced by while replaced_of tracks everything that
    // maps to a particular node (which is needed to maintain the data structure efficiently)
    let mut replaced_by: HashMap<NodeID, NodeID> = HashMap::new();
    let mut replaced_of: HashMap<NodeID, Vec<NodeID>> = HashMap::new();
    for (old, new) in to_replace {
        // Chase the replacement chain one step: if the target itself was
        // already replaced, redirect to its replacement.
        let new = match replaced_by.get(&new) {
            Some(res) => *res,
            None => new,
        };
        editor.edit(|mut edit| {
            edit.sub_edit(old, new);
            edit.replace_all_uses(old, new)
        });
        replaced_by.insert(old, new);
        // Everything that previously mapped to `old` must now map to `new`.
        let mut replaced = vec![];
        match replaced_of.get_mut(&old) {
            Some(res) => {
                std::mem::swap(res, &mut replaced);
            }
            None => {}
        }
        let new_of = match replaced_of.get_mut(&new) {
            Some(res) => res,
            None => {
                replaced_of.insert(new, vec![]);
                replaced_of.get_mut(&new).unwrap()
            }
        };
        new_of.push(old);
        for n in replaced {
            replaced_by.insert(n, new);
            new_of.push(n);
        }
    }
    // Remove nodes
    editor.edit(|mut edit| {
        for node in to_delete {
            edit = edit.delete_node(node)?
        }
        Ok(edit)
    });
}
// Returns true if the given type is an array or transitively contains one
// (through product or summation fields).
pub fn type_contains_array(editor: &FunctionEditor, typ: TypeID) -> bool {
    match &*editor.get_type(typ) {
        Type::Array(..) => true,
        Type::Product(fields) | Type::Summation(fields) => fields
            .iter()
            .copied()
            .any(|field| type_contains_array(editor, field)),
        _ => false,
    }
}
// An index tree is used to store results at many index lists. Its shape
// mirrors the product type it describes: a Node has one subtree per field,
// and a Leaf holds the value associated with a non-product position.
#[derive(Clone, Debug)]
pub enum IndexTree<T> {
    // Value for a non-product (leaf) field
    Leaf(T),
    // One subtree per field of a product
    Node(Vec<IndexTree<T>>),
}
impl<T: std::fmt::Debug> IndexTree<T> {
pub fn lookup(&self, idx: &[Index]) -> &IndexTree<T> {
self.lookup_idx(idx, 0)
}
fn lookup_idx(&self, idx: &[Index], n: usize) -> &IndexTree<T> {
if n < idx.len() {
if let Index::Field(i) = idx[n] {
match self {
IndexTree::Leaf(_) => panic!("Invalid field"),
IndexTree::Node(ts) => ts[i].lookup_idx(idx, n + 1),
}
} else {
panic!("Cannot process a mix of fields and other read indices; pre-processing should prevent this");
}
} else {
self
}
}
pub fn set(self, idx: &[Index], val: T) -> IndexTree<T> {
self.set_idx(idx, val, 0)
}
fn set_idx(self, idx: &[Index], val: T, n: usize) -> IndexTree<T> {
if n < idx.len() {
if let Index::Field(i) = idx[n] {
match self {
IndexTree::Leaf(_) => panic!("Invalid field"),
IndexTree::Node(mut ts) => {
if i + 1 == ts.len() {
let t = ts.pop().unwrap();
ts.push(t.set_idx(idx, val, n + 1));
} else {
let mut t = ts.pop().unwrap();
std::mem::swap(&mut ts[i], &mut t);
t = t.set_idx(idx, val, n + 1);
std::mem::swap(&mut ts[i], &mut t);
ts.push(t);
}
IndexTree::Node(ts)
}
}
} else {
panic!("Cannot process a mix of fields and other read indices; pre-processing should prevent this");
}
} else {
IndexTree::Leaf(val)
}
}
pub fn replace(self, idx: &[Index], val: IndexTree<T>) -> IndexTree<T> {
self.replace_idx(idx, val, 0)
}
fn replace_idx(self, idx: &[Index], val: IndexTree<T>, n: usize) -> IndexTree<T> {
if n < idx.len() {
if let Index::Field(i) = idx[n] {
match self {
IndexTree::Leaf(_) => panic!("Invalid field"),
IndexTree::Node(mut ts) => {
if i + 1 == ts.len() {
let t = ts.pop().unwrap();
ts.push(t.replace_idx(idx, val, n + 1));
} else {
let mut t = ts.pop().unwrap();
std::mem::swap(&mut ts[i], &mut t);
t = t.replace_idx(idx, val, n + 1);
std::mem::swap(&mut ts[i], &mut t);
ts.push(t);
}
IndexTree::Node(ts)
}
}
} else {
panic!("Cannot process a mix of fields and other read indices; pre-processing should prevent this");
}
} else {
val
}
}
pub fn zip<'a, A>(self, other: &'a IndexTree<A>) -> IndexTree<(T, &'a A)> {
match (self, other) {
(IndexTree::Leaf(t), IndexTree::Leaf(a)) => IndexTree::Leaf((t, a)),
(IndexTree::Node(t), IndexTree::Node(a)) => {
let mut fields = vec![];
for (t, a) in t.into_iter().zip(a.iter()) {
fields.push(t.zip(a));
}
IndexTree::Node(fields)
}
_ => panic!("IndexTrees do not have the same fields, cannot zip"),
}
}
pub fn zip_list<'a, A>(self, others: Vec<&'a IndexTree<A>>) -> IndexTree<(T, Vec<&'a A>)> {
match self {
IndexTree::Leaf(t) => {
let mut res = vec![];
for other in others {
match other {
IndexTree::Leaf(a) => res.push(a),
_ => panic!("IndexTrees do not have the same fields, cannot zip"),
}
}
IndexTree::Leaf((t, res))
}
IndexTree::Node(t) => {
let mut fields: Vec<Vec<&'a IndexTree<A>>> = vec![vec![]; t.len()];
for other in others {
match other {
IndexTree::Node(a) => {
for (i, a) in a.iter().enumerate() {
fields[i].push(a);
}
}
_ => panic!("IndexTrees do not have the same fields, cannot zip"),
}
}
IndexTree::Node(
t.into_iter()
.zip(fields.into_iter())
.map(|(t, f)| t.zip_list(f))
.collect(),
)
}
}
}
pub fn for_each<F>(&self, mut f: F)
where
F: FnMut(&Vec<Index>, &T),
{
self.for_each_idx(&mut vec![], &mut f);
}
fn for_each_idx<F>(&self, idx: &mut Vec<Index>, f: &mut F)
where
F: FnMut(&Vec<Index>, &T),
{
match self {
IndexTree::Leaf(t) => f(idx, t),
IndexTree::Node(ts) => {
for (i, t) in ts.iter().enumerate() {
idx.push(Index::Field(i));
t.for_each_idx(idx, f);
idx.pop();
}
}
}
}
}
// Given the editor, type of some collection, and a list of indices to access that type at, returns
// the TypeID of accessing the collection at the given indices
fn type_at_index(editor: &FunctionEditor, typ: TypeID, idx: &[Index]) -> TypeID {
    // Walk the index list, peeling one layer of the type per index.
    idx.iter().fold(typ, |cur, index| match index {
        Index::Field(i) => {
            let Type::Product(ref ts) = *editor.get_type(cur) else {
                panic!("Accessing a field of a non-product type; did typechecking succeed?");
            };
            ts[*i]
        }
        Index::Variant(i) => {
            let Type::Summation(ref ts) = *editor.get_type(cur) else {
                panic!("Accessing a variant of a non-summation type; did typechecking succeed?");
            };
            ts[*i]
        }
        Index::Position(pos) => {
            let Type::Array(elem, ref dims) = *editor.get_type(cur) else {
                panic!("Accessing an array position of a non-array type; did typechecking succeed?");
            };
            // Positions must index every dimension at once.
            assert!(pos.len() == dims.len(), "Read mismatch array dimensions");
            elem
        }
    })
}
// Given a product value val of type typ, constructs a copy of that value by extracting all fields
// from that value and then writing them into a new constant
// This process also adds all the read nodes that are generated into the read_list so that the
// reads can be eliminated by later parts of SROA
fn reconstruct_product(
    editor: &mut FunctionEditor,
    typ: TypeID,
    val: NodeID,
    read_list: &mut Vec<NodeID>,
) -> NodeID {
    // Read every leaf field out of the original value, then build a fresh
    // default constant of the same type to write those fields into.
    let field_reads = generate_reads(editor, typ, val);
    let default_const = generate_constant(editor, typ);
    // Materialize the constant node
    let mut const_node = None;
    editor.edit(|mut edit| {
        const_node = Some(edit.add_node(Node::Constant { id: default_const }));
        Ok(edit)
    });
    // Thread the collection through a chain of writes, one per leaf field.
    let mut product = const_node.expect("Add node cannot fail");
    field_reads.for_each(|index: &Vec<Index>, field_val: &NodeID| {
        read_list.push(*field_val);
        editor.edit(|mut edit| {
            product = edit.add_node(Node::Write {
                collect: product,
                data: *field_val,
                indices: index.clone().into(),
            });
            Ok(edit)
        });
    });
    product
}
// Given a node val of type typ, adds nodes to the function which read all (leaf) fields of val and
// returns an IndexTree that tracks the nodes reading each leaf field
pub fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> IndexTree<NodeID> {
    // Wrap the edit-based implementation in a single editor transaction.
    let mut reads: Option<IndexTree<NodeID>> = None;
    editor.edit(|mut edit| {
        reads = Some(generate_reads_edit(&mut edit, typ, val));
        Ok(edit)
    });
    reads.unwrap()
}
// The same as generate_reads but for if we have a FunctionEdit rather than a FunctionEditor
// Starts the recursion with an empty index path (i.e. at the root of the value).
pub fn generate_reads_edit(edit: &mut FunctionEdit, typ: TypeID, val: NodeID) -> IndexTree<NodeID> {
    generate_reads_at_index_edit(edit, typ, val, vec![])
}
// Given a node val of type which at the indices idx has type typ, construct reads of all (leaf)
// fields within this sub-value of val and return the correspondence list
fn generate_reads_at_index_edit(
    edit: &mut FunctionEdit,
    typ: TypeID,
    val: NodeID,
    idx: Vec<Index>,
) -> IndexTree<NodeID> {
    // Copy the field types out first so the type borrow ends before we mutate
    // the edit by adding nodes below.
    let field_types: Option<Vec<TypeID>> = edit.get_type(typ).try_product().map(|ts| ts.into());
    match field_types {
        Some(ts) => {
            // Product: recurse into each field with the index path extended by
            // that field's position.
            let subtrees = ts
                .into_iter()
                .enumerate()
                .map(|(i, t)| {
                    let mut new_idx = idx.clone();
                    new_idx.push(Index::Field(i));
                    generate_reads_at_index_edit(edit, t, val, new_idx)
                })
                .collect();
            IndexTree::Node(subtrees)
        }
        None => {
            // Non-product: this is a leaf, so emit the read and record it.
            IndexTree::Leaf(edit.add_node(Node::Read {
                collect: val,
                indices: idx.into(),
            }))
        }
    }
}
// Helper macro: adds a constant to the function through an editor transaction
// and yields the resulting ConstantID. Exists because editor.edit takes a
// closure, so the ID must be smuggled out through a captured Option.
macro_rules! add_const {
    ($editor:ident, $const:expr) => {{
        let mut res = None;
        $editor.edit(|mut edit| {
            res = Some(edit.add_constant($const));
            Ok(edit)
        });
        res.expect("Add constant cannot fail")
    }};
}
// Given a type, builds a default constant of that type: zero for numerics,
// false for booleans, field-wise defaults for products, the zeroth variant's
// default for summations, and the (implicitly zeroed) array constant for
// arrays. Panics for types that have no constant representation.
pub fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID {
    // Clone the type so the borrow of the editor's type table ends before we
    // mutate the editor via add_const!/recursion.
    let t = editor.get_type(typ).clone();
    match t {
        Type::Product(ts) => {
            // Build each field's default first, then the product constant.
            let mut cs = vec![];
            for t in ts {
                cs.push(generate_constant(editor, t));
            }
            add_const!(editor, Constant::Product(typ, cs.into()))
        }
        Type::Boolean => add_const!(editor, Constant::Boolean(false)),
        Type::Integer8 => add_const!(editor, Constant::Integer8(0)),
        Type::Integer16 => add_const!(editor, Constant::Integer16(0)),
        Type::Integer32 => add_const!(editor, Constant::Integer32(0)),
        Type::Integer64 => add_const!(editor, Constant::Integer64(0)),
        Type::UnsignedInteger8 => add_const!(editor, Constant::UnsignedInteger8(0)),
        Type::UnsignedInteger16 => add_const!(editor, Constant::UnsignedInteger16(0)),
        Type::UnsignedInteger32 => add_const!(editor, Constant::UnsignedInteger32(0)),
        Type::UnsignedInteger64 => add_const!(editor, Constant::UnsignedInteger64(0)),
        // No Constant variant exists for these float formats; give a real
        // diagnostic instead of a bare panic (consistent with the arms below).
        Type::Float8 | Type::BFloat16 => {
            panic!("Cannot create constant of Float8 or BFloat16 type")
        }
        Type::Float32 => add_const!(editor, Constant::Float32(ordered_float::OrderedFloat(0.0))),
        Type::Float64 => add_const!(editor, Constant::Float64(ordered_float::OrderedFloat(0.0))),
        Type::Summation(ts) => {
            // Default a summation to its zeroth variant's default value.
            let const_id = generate_constant(editor, ts[0]);
            add_const!(editor, Constant::Summation(typ, 0, const_id))
        }
        Type::Array(_, _) => {
            add_const!(editor, Constant::Array(typ))
        }
        Type::Control => panic!("Cannot create constant of control type"),
        Type::MultiReturn(_) => panic!("Cannot create constant of multi-return type"),
    }
}
// Given a constant cnst adds node to the function which are the constant values of each field and
// returns a list of pairs of indices and the node that holds that index
fn generate_constant_fields(
    editor: &mut FunctionEditor,
    cnst: ConstantID,
    old_node: NodeID,
) -> IndexTree<NodeID> {
    // Copy the field constants out first so the constant borrow ends before we
    // mutate the editor below.
    let field_consts: Option<Vec<ConstantID>> = editor
        .get_constant(cnst)
        .try_product_fields()
        .map(|cs| cs.into());
    match field_consts {
        Some(cs) => IndexTree::Node(
            cs.into_iter()
                .map(|c| generate_constant_fields(editor, c, old_node))
                .collect(),
        ),
        None => {
            // Leaf constant: materialize it as its own constant node and record
            // that it descends from old_node.
            let mut node = None;
            editor.edit(|mut edit| {
                let new_node = edit.add_node(Node::Constant { id: cnst });
                edit.sub_edit(old_node, new_node);
                node = Some(new_node);
                Ok(edit)
            });
            IndexTree::Leaf(node.expect("Add node cannot fail"))
        }
    }
}
// Given a type, return a list of the fields and new NodeIDs for them, with NodeIDs starting at the
// id provided
fn allocate_fields(editor: &FunctionEditor, typ: TypeID, id: &mut usize) -> IndexTree<NodeID> {
    let field_types: Option<Vec<TypeID>> = editor.get_type(typ).try_product().map(|ts| ts.into());
    match field_types {
        Some(ts) => IndexTree::Node(
            ts.into_iter()
                .map(|t| allocate_fields(editor, t, id))
                .collect(),
        ),
        None => {
            // Reserve the next NodeID for this leaf field.
            let node = NodeID::new(*id);
            *id += 1;
            IndexTree::Leaf(node)
        }
    }
}