use std::collections::{BTreeMap, HashMap, VecDeque}; use hercules_ir::ir::*; use crate::*; /* * Top level function to run SROA, intraprocedurally. Product values can be used * and created by a relatively small number of nodes. Here are *all* of them: * * - Phi: can merge SSA values of products - these get broken up into phis on * the individual fields * * - Reduce: similarly to phis, reduce nodes can cycle product values through * reduction loops - these get broken up into reduces on the fields * * - Return: can return a product - the product values will be constructed * at the return site * * - Parameter: can introduce a product - reads will be introduced for each * field * * - Constant: can introduce a product - these are broken up into constants for * the individual fields * * - Ternary: the select ternary operator can select between products - these * are broken up into ternary nodes for the individual fields * * - Call: the call node can use a product value as an argument to another * function, and can produce a product value as a result. Argument values * will be constructed at the call site and the return value will be broken * into individual fields * * - Read: the read node reads primitive fields from product values - these get * replaced by a direct use of the field value * A read can also extract a product from an array or sum; the value read out * will be broken into individual fields (by individual reads from the array) * * - Write: the write node writes primitive fields in product values - these get * replaced by a direct def of the field value * * The allow_sroa_arrays variable controls whether products that contain arrays * will be broken into pieces. This option is useful to have since breaking * these products up can be expensive if it requires destructing and * reconstructing the product at any point. * * TODO: Handle partial selections (i.e. immutable nodes). This will involve * actually tracking each source and use of a product and verifying that all of * the nodes involved are mutable. */ pub fn sroa( editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: &Vec<TypeID>, allow_sroa_arrays: bool, ) { let mut types: HashMap<NodeID, TypeID> = types .iter() .enumerate() .map(|(i, t)| (NodeID::new(i), *t)) .collect(); let can_sroa_type = |editor: &FunctionEditor, typ: TypeID| { editor.get_type(typ).is_product() && (allow_sroa_arrays || !type_contains_array(editor, typ)) }; // This map stores a map from NodeID to an index tree which can be used to lookup the NodeID // that contains the corresponding fields of the original value let mut field_map: HashMap<NodeID, IndexTree<NodeID>> = HashMap::new(); // First: determine all nodes which interact with products (as described above) let mut product_nodes: Vec<NodeID> = vec![]; // We track call and return nodes separately since they (may) require constructing new products // for the call's arguments or the return's value let mut call_return_nodes: Vec<NodeID> = vec![]; for node in reverse_postorder { match &editor.func().nodes[node.idx()] { Node::Phi { .. } | Node::Reduce { .. } | Node::Parameter { .. } | Node::Constant { .. } | Node::Ternary { first: _, second: _, third: _, op: TernaryOperator::Select, } if can_sroa_type(editor, types[&node]) => product_nodes.push(*node), Node::Write { collect, data, indices, } => { let data = *data; let collect = *collect; // For a write, we may need to split it into two pieces if the it contains a mix of // field and non-field indices let (fields_write, write_prod_into_non) = { let mut fields = vec![]; let mut remainder = vec![]; if can_sroa_type(editor, types[&node]) { let mut indices = indices.iter(); while let Some(idx) = indices.next() { if idx.is_field() { fields.push(idx.clone()); } else { remainder.push(idx.clone()); remainder.extend(indices.cloned()); break; } } } else { remainder.extend_from_slice(indices); } if fields.is_empty() { if can_sroa_type(editor, types[&data]) { (None, Some((*node, collect, remainder))) } else { (None, None) } } else if remainder.is_empty() { (Some(*node), None) } else { // Here we perform the split into two writes // We need to find the type of the collection that will be extracted from // the collection being modified when we read it at the fields index let after_fields_type = type_at_index(editor, types[&collect], &fields); let mut inner_collection = None; let mut fields_write = None; let mut remainder_write = None; editor.edit(|mut edit| { let read_inner = edit.add_node(Node::Read { collect, indices: fields.clone().into(), }); types.insert(read_inner, after_fields_type); product_nodes.push(read_inner); inner_collection = Some(read_inner); let rem_write = edit.add_node(Node::Write { collect: read_inner, data, indices: remainder.clone().into(), }); types.insert(rem_write, after_fields_type); remainder_write = Some(rem_write); let complete_write = edit.add_node(Node::Write { collect, data: rem_write, indices: fields.into(), }); types.insert(complete_write, types[&collect]); fields_write = Some(complete_write); edit = edit.replace_all_uses(*node, complete_write)?; edit.delete_node(*node) }); let inner_collection = inner_collection.unwrap(); let fields_write = fields_write.unwrap(); let remainder_write = remainder_write.unwrap(); if editor.get_type(types[&data]).is_product() { ( Some(fields_write), Some((remainder_write, inner_collection, remainder)), ) } else { (Some(fields_write), None) } } }; if let Some(node) = fields_write { product_nodes.push(node); } if let Some((write_node, collection, index)) = write_prod_into_non { let node = write_node; // If we're writing a product into a non-product we need to replace the write // by a sequence of writes that read each field of the product and write them // into the collection, then those write nodes can be ignored for SROA but the // reads will be handled by SROA // The value being written must be the data and so must be a product assert!(editor.get_type(types[&data]).is_product()); let fields = generate_reads(editor, types[&data], data); let mut collection = collection; let collection_type = types[&collection]; fields.for_each(|field: &Vec<Index>, val: &NodeID| { product_nodes.push(*val); editor.edit(|mut edit| { collection = edit.add_node(Node::Write { collect: collection, data: *val, indices: index .iter() .chain(field) .cloned() .collect::<Vec<_>>() .into(), }); types.insert(collection, collection_type); Ok(edit) }); }); editor.edit(|mut edit| { edit = edit.replace_all_uses(node, collection)?; edit.delete_node(node) }); } } Node::Read { collect, indices } => { // For a read, we split the read into a series of reads where each piece has either // only field reads or no field reads. Those with fields are the only ones // considered during SROA but any read whose collection is not a product but // produces a product (i.e. if there's an array of products) then following the // read we replace the read that produces a product by reads of each field and add // that information to the node map for the rest of SROA (this produces some reads // that mix types of indices, since we only read leaves but that's okay since those // reads are not handled by SROA) let indices = if can_sroa_type(editor, types[collect]) { indices .chunk_by(|i, j| i.is_field() == j.is_field()) .collect::<Vec<_>>() } else { vec![indices.as_ref()] }; let (field_reads, non_fields_produce_prod) = { if indices.len() == 0 { // If there are no indices then there were no indices originally, this is // only used with clones of arrays (vec![], vec![]) } else if indices.len() == 1 { // If once we perform chunking there's only one set of indices, we can just // use the original node if can_sroa_type(editor, types[collect]) { (vec![*node], vec![]) } else if can_sroa_type(editor, types[node]) { (vec![], vec![*node]) } else { (vec![], vec![]) } } else { let mut field_reads = vec![]; let mut non_field = vec![]; // To construct the multiple reads we need to track the current collection // and the type of that collection let mut collect = *collect; let mut typ = types[&collect]; let indices = indices .into_iter() .map(|i| i.into_iter().cloned().collect::<Vec<_>>()) .collect::<Vec<_>>(); for index in indices { let is_field_read = index[0].is_field(); let field_type = type_at_index(editor, typ, &index); editor.edit(|mut edit| { collect = edit.add_node(Node::Read { collect, indices: index.into(), }); types.insert(collect, field_type); typ = field_type; Ok(edit) }); if is_field_read { field_reads.push(collect); } else if editor.get_type(typ).is_product() { non_field.push(collect); } } // Replace all uses of the original read (with mixed indices) with the // newly constructed reads editor.edit(|mut edit| { edit = edit.replace_all_uses(*node, collect)?; edit.delete_node(*node) }); (field_reads, non_field) } }; product_nodes.extend(field_reads); for node in non_fields_produce_prod { field_map.insert(node, generate_reads(editor, types[&node], node)); } } // We add all calls to the call/return list and check their arguments later Node::Call { .. } => call_return_nodes.push(*node), Node::Return { control: _, data } if can_sroa_type(editor, types[&data]) => { call_return_nodes.push(*node) } _ => (), } } // Next, we handle calls and returns. For returns, we will insert nodes that read each field of // the returned product and then write them into a new product. These writes are not put into // the list of product nodes since they must remain but the reads are so that they will be // replaced later on. // For calls, we do a similar process for each (product) argument. Additionally, if the call // returns a product, we create reads for each field in that product and store it into our // field map for node in call_return_nodes { match &editor.func().nodes[node.idx()] { Node::Return { control, data } => { assert!(can_sroa_type(editor, types[&data])); let control = *control; let new_data = reconstruct_product(editor, types[&data], *data, &mut product_nodes); editor.edit(|mut edit| { let new_return = edit.add_node(Node::Return { control, data: new_data, }); edit.sub_edit(node, new_return); edit.delete_node(node) }); } Node::Call { control, function, dynamic_constants, args, } => { let control = *control; let function = *function; let dynamic_constants = dynamic_constants.clone(); let args = args.clone(); // If the call returns a product that we can sroa, we generate reads for each field let fields = if can_sroa_type(editor, types[&node]) { Some(generate_reads(editor, types[&node], node)) } else { None }; let mut new_args = vec![]; for arg in args { if can_sroa_type(editor, types[&arg]) { new_args.push(reconstruct_product( editor, types[&arg], arg, &mut product_nodes, )); } else { new_args.push(arg); } } editor.edit(|mut edit| { let new_call = edit.add_node(Node::Call { control, function, dynamic_constants, args: new_args.into(), }); edit.sub_edit(node, new_call); let edit = edit.replace_all_uses(node, new_call)?; let edit = edit.delete_node(node)?; // Since we've replaced uses of calls with the new node, we update the type // information so that we can retrieve the type of the new call if needed // Because the other nodes we've created so far are only used in very // particular ways (i.e. are not used by arbitrary nodes) we don't need their // type information but do for the new calls types.insert(new_call, types[&node]); match fields { None => {} Some(fields) => { field_map.insert(new_call, fields); } } Ok(edit) }); } _ => panic!("Processing non-call or return node"), } } #[derive(Debug)] enum WorkItem { Unhandled(NodeID), AllocatedPhi { control: NodeID, data: Vec<NodeID>, node: NodeID, fields: IndexTree<NodeID>, }, AllocatedReduce { control: NodeID, init: NodeID, reduct: NodeID, node: NodeID, fields: IndexTree<NodeID>, }, AllocatedTernary { cond: NodeID, thn: NodeID, els: NodeID, node: NodeID, fields: IndexTree<NodeID>, }, } // Now, we process the other nodes that deal with products. // The first step is to assign new NodeIDs to the nodes that will be split into multiple: phi, // reduce, parameter, constant, and ternary. // We do this in several steps: first we break apart parameters and constants let mut to_delete = vec![]; let mut worklist: VecDeque<WorkItem> = VecDeque::new(); for node in product_nodes { match editor.func().nodes[node.idx()] { Node::Parameter { .. } => { field_map.insert(node, generate_reads(editor, types[&node], node)); } Node::Constant { id } => { field_map.insert(node, generate_constant_fields(editor, id)); to_delete.push(node); } _ => { worklist.push_back(WorkItem::Unhandled(node)); } } } // Now, we process the remaining nodes, allocating NodeIDs for them and updating the field_map. // We track the current NodeID and add nodes to a set we maintain of nodes to add (since we // need to add nodes in a particular order we wait to do that until the end). If we don't have // enough information to process a particular node, we add it back to the worklist let mut next_id: usize = editor.func().nodes.len(); let mut to_insert = BTreeMap::new(); let mut to_replace: Vec<(NodeID, NodeID)> = vec![]; while let Some(mut item) = worklist.pop_front() { if let WorkItem::Unhandled(node) = item { match &editor.func().nodes[node.idx()] { // For phi, reduce, and ternary, we break them apart into separate nodes for each field Node::Phi { control, data } => { let control = *control; let data = data.clone(); let fields = allocate_fields(editor, types[&node], &mut next_id); field_map.insert(node, fields.clone()); item = WorkItem::AllocatedPhi { control, data: data.into(), node, fields, }; } Node::Reduce { control, init, reduct, } => { let control = *control; let init = *init; let reduct = *reduct; let fields = allocate_fields(editor, types[&node], &mut next_id); field_map.insert(node, fields.clone()); item = WorkItem::AllocatedReduce { control, init, reduct, node, fields, }; } Node::Ternary { first, second, third, .. } => { let first = *first; let second = *second; let third = *third; let fields = allocate_fields(editor, types[&node], &mut next_id); field_map.insert(node, fields.clone()); item = WorkItem::AllocatedTernary { cond: first, thn: second, els: third, node, fields, }; } Node::Write { collect, data, indices, } => { if let Some(index_map) = field_map.get(collect) { if can_sroa_type(editor, types[&data]) { if let Some(data_idx) = field_map.get(data) { field_map.insert( node, index_map.clone().replace(indices, data_idx.clone()), ); to_delete.push(node); } else { worklist.push_back(WorkItem::Unhandled(node)); } } else { field_map.insert(node, index_map.clone().set(indices, *data)); to_delete.push(node); } } else { worklist.push_back(WorkItem::Unhandled(node)); } } Node::Read { collect, indices } => { if let Some(index_map) = field_map.get(collect) { let read_info = index_map.lookup(indices); match read_info { IndexTree::Leaf(field) => { to_replace.push((node, *field)); } _ => {} } field_map.insert(node, read_info.clone()); to_delete.push(node); } else { worklist.push_back(WorkItem::Unhandled(node)); } } _ => panic!("Unexpected node type"), } } match item { WorkItem::Unhandled(_) => {} WorkItem::AllocatedPhi { control, data, node, fields, } => { let mut data_fields = vec![]; let mut ready = true; for val in data.iter() { if let Some(val_fields) = field_map.get(val) { data_fields.push(val_fields); } else { ready = false; break; } } if ready { fields.zip_list(data_fields).for_each(|_, (res, data)| { to_insert.insert( res.idx(), Node::Phi { control, data: data.into_iter().map(|n| **n).collect::<Vec<_>>().into(), }, ); }); to_delete.push(node); } else { worklist.push_back(WorkItem::AllocatedPhi { control, data, node, fields, }); } } WorkItem::AllocatedReduce { control, init, reduct, node, fields, } => { if let (Some(init_fields), Some(reduct_fields)) = (field_map.get(&init), field_map.get(&reduct)) { fields.zip(init_fields).zip(reduct_fields).for_each( |_, ((res, init), reduct)| { to_insert.insert( res.idx(), Node::Reduce { control, init: **init, reduct: **reduct, }, ); }, ); to_delete.push(node); } else { worklist.push_back(WorkItem::AllocatedReduce { control, init, reduct, node, fields, }); } } WorkItem::AllocatedTernary { cond, thn, els, node, fields, } => { if let (Some(thn_fields), Some(els_fields)) = (field_map.get(&thn), field_map.get(&els)) { fields .zip(thn_fields) .zip(els_fields) .for_each(|_, ((res, thn), els)| { to_insert.insert( res.idx(), Node::Ternary { first: cond, second: **thn, third: **els, op: TernaryOperator::Select, }, ); }); to_delete.push(node); } else { worklist.push_back(WorkItem::AllocatedTernary { cond, thn, els, node, fields, }); } } } } // Create new nodes nodes editor.edit(|mut edit| { for (node_id, node) in to_insert { let id = edit.add_node(node); assert_eq!(node_id, id.idx()); } Ok(edit) }); // Replace uses of old reads // Because a read that is being replaced could also be the node some other read is being // replaced by (if the first read is then written into a product that is then read from again) // we need to track what nodes have already been replaced (and by what) so we can properly // replace uses without leaving users of nodes that should be deleted. // replaced_by tracks what a node has been replaced by while replaced_of tracks everything that // maps to a particular node (which is needed to maintain the data structure efficiently) let mut replaced_by: HashMap<NodeID, NodeID> = HashMap::new(); let mut replaced_of: HashMap<NodeID, Vec<NodeID>> = HashMap::new(); for (old, new) in to_replace { let new = match replaced_by.get(&new) { Some(res) => *res, None => new, }; editor.edit(|mut edit| { edit.sub_edit(old, new); edit.replace_all_uses(old, new) }); replaced_by.insert(old, new); let mut replaced = vec![]; match replaced_of.get_mut(&old) { Some(res) => { std::mem::swap(res, &mut replaced); } None => {} } let new_of = match replaced_of.get_mut(&new) { Some(res) => res, None => { replaced_of.insert(new, vec![]); replaced_of.get_mut(&new).unwrap() } }; new_of.push(old); for n in replaced { replaced_by.insert(n, new); new_of.push(n); } } // Remove nodes editor.edit(|mut edit| { for node in to_delete { edit = edit.delete_node(node)? } Ok(edit) }); } fn type_contains_array(editor: &FunctionEditor, typ: TypeID) -> bool { match &*editor.get_type(typ) { Type::Array(_, _) => true, Type::Product(ts) | Type::Summation(ts) => { ts.iter().any(|t| type_contains_array(editor, *t)) } _ => false, } } // An index tree is used to store results at many index lists #[derive(Clone, Debug)] pub enum IndexTree<T> { Leaf(T), Node(Vec<IndexTree<T>>), } impl<T: std::fmt::Debug> IndexTree<T> { pub fn lookup(&self, idx: &[Index]) -> &IndexTree<T> { self.lookup_idx(idx, 0) } fn lookup_idx(&self, idx: &[Index], n: usize) -> &IndexTree<T> { if n < idx.len() { if let Index::Field(i) = idx[n] { match self { IndexTree::Leaf(_) => panic!("Invalid field"), IndexTree::Node(ts) => ts[i].lookup_idx(idx, n + 1), } } else { panic!("Cannot process a mix of fields and other read indices; pre-processing should prevent this"); } } else { self } } pub fn set(self, idx: &[Index], val: T) -> IndexTree<T> { self.set_idx(idx, val, 0) } fn set_idx(self, idx: &[Index], val: T, n: usize) -> IndexTree<T> { if n < idx.len() { if let Index::Field(i) = idx[n] { match self { IndexTree::Leaf(_) => panic!("Invalid field"), IndexTree::Node(mut ts) => { if i + 1 == ts.len() { let t = ts.pop().unwrap(); ts.push(t.set_idx(idx, val, n + 1)); } else { let mut t = ts.pop().unwrap(); std::mem::swap(&mut ts[i], &mut t); t = t.set_idx(idx, val, n + 1); std::mem::swap(&mut ts[i], &mut t); ts.push(t); } IndexTree::Node(ts) } } } else { panic!("Cannot process a mix of fields and other read indices; pre-processing should prevent this"); } } else { IndexTree::Leaf(val) } } pub fn replace(self, idx: &[Index], val: IndexTree<T>) -> IndexTree<T> { self.replace_idx(idx, val, 0) } fn replace_idx(self, idx: &[Index], val: IndexTree<T>, n: usize) -> IndexTree<T> { if n < idx.len() { if let Index::Field(i) = idx[n] { match self { IndexTree::Leaf(_) => panic!("Invalid field"), IndexTree::Node(mut ts) => { if i + 1 == ts.len() { let t = ts.pop().unwrap(); ts.push(t.replace_idx(idx, val, n + 1)); } else { let mut t = ts.pop().unwrap(); std::mem::swap(&mut ts[i], &mut t); t = t.replace_idx(idx, val, n + 1); std::mem::swap(&mut ts[i], &mut t); ts.push(t); } IndexTree::Node(ts) } } } else { panic!("Cannot process a mix of fields and other read indices; pre-processing should prevent this"); } } else { val } } pub fn zip<'a, A>(self, other: &'a IndexTree<A>) -> IndexTree<(T, &'a A)> { match (self, other) { (IndexTree::Leaf(t), IndexTree::Leaf(a)) => IndexTree::Leaf((t, a)), (IndexTree::Node(t), IndexTree::Node(a)) => { let mut fields = vec![]; for (t, a) in t.into_iter().zip(a.iter()) { fields.push(t.zip(a)); } IndexTree::Node(fields) } _ => panic!("IndexTrees do not have the same fields, cannot zip"), } } pub fn zip_list<'a, A>(self, others: Vec<&'a IndexTree<A>>) -> IndexTree<(T, Vec<&'a A>)> { match self { IndexTree::Leaf(t) => { let mut res = vec![]; for other in others { match other { IndexTree::Leaf(a) => res.push(a), _ => panic!("IndexTrees do not have the same fields, cannot zip"), } } IndexTree::Leaf((t, res)) } IndexTree::Node(t) => { let mut fields: Vec<Vec<&'a IndexTree<A>>> = vec![vec![]; t.len()]; for other in others { match other { IndexTree::Node(a) => { for (i, a) in a.iter().enumerate() { fields[i].push(a); } } _ => panic!("IndexTrees do not have the same fields, cannot zip"), } } IndexTree::Node( t.into_iter() .zip(fields.into_iter()) .map(|(t, f)| t.zip_list(f)) .collect(), ) } } } pub fn for_each<F>(&self, mut f: F) where F: FnMut(&Vec<Index>, &T), { self.for_each_idx(&mut vec![], &mut f); } fn for_each_idx<F>(&self, idx: &mut Vec<Index>, f: &mut F) where F: FnMut(&Vec<Index>, &T), { match self { IndexTree::Leaf(t) => f(idx, t), IndexTree::Node(ts) => { for (i, t) in ts.iter().enumerate() { idx.push(Index::Field(i)); t.for_each_idx(idx, f); idx.pop(); } } } } } // Given the editor, type of some collection, and a list of indices to access that type at, returns // the TypeID of accessing the collection at the given indices fn type_at_index(editor: &FunctionEditor, typ: TypeID, idx: &[Index]) -> TypeID { let mut typ = typ; for index in idx { match index { Index::Field(i) => { let Type::Product(ref ts) = *editor.get_type(typ) else { panic!("Accessing a field of a non-product type; did typechecking succeed?"); }; typ = ts[*i]; } Index::Variant(i) => { let Type::Summation(ref ts) = *editor.get_type(typ) else { panic!( "Accessing a variant of a non-summation type; did typechecking succeed?" ); }; typ = ts[*i]; } Index::Position(pos) => { let Type::Array(elem, ref dims) = *editor.get_type(typ) else { panic!("Accessing an array position of a non-array type; did typechecking succeed?"); }; assert!(pos.len() == dims.len(), "Read mismatch array dimensions"); typ = elem; } } } return typ; } // Given a product value val of type typ, constructs a copy of that value by extracting all fields // from that value and then writing them into a new constant // This process also adds all the read nodes that are generated into the read_list so that the // reads can be eliminated by later parts of SROA fn reconstruct_product( editor: &mut FunctionEditor, typ: TypeID, val: NodeID, read_list: &mut Vec<NodeID>, ) -> NodeID { let fields = generate_reads(editor, typ, val); let new_const = generate_constant(editor, typ); // Create a constant node let mut const_node = None; editor.edit(|mut edit| { const_node = Some(edit.add_node(Node::Constant { id: new_const })); Ok(edit) }); // Generate writes for each field let mut value = const_node.expect("Add node cannot fail"); fields.for_each(|idx: &Vec<Index>, val: &NodeID| { read_list.push(*val); editor.edit(|mut edit| { value = edit.add_node(Node::Write { collect: value, data: *val, indices: idx.clone().into(), }); Ok(edit) }); }); value } // Given a node val of type typ, adds nodes to the function which read all (leaf) fields of val and // returns an IndexTree that tracks the nodes reading each leaf field fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> IndexTree<NodeID> { let res = generate_reads_at_index(editor, typ, val, vec![]); res } // Given a node val of type which at the indices idx has type typ, construct reads of all (leaf) // fields within this sub-value of val and return the correspondence list fn generate_reads_at_index( editor: &mut FunctionEditor, typ: TypeID, val: NodeID, idx: Vec<Index>, ) -> IndexTree<NodeID> { let ts: Option<Vec<TypeID>> = if let Some(ts) = editor.get_type(typ).try_product() { Some(ts.into()) } else { None }; if let Some(ts) = ts { // For product values, we will recurse down each of its fields with an extended index // and the appropriate type of that field let mut fields = vec![]; for (i, t) in ts.into_iter().enumerate() { let mut new_idx = idx.clone(); new_idx.push(Index::Field(i)); fields.push(generate_reads_at_index(editor, t, val, new_idx)); } IndexTree::Node(fields) } else { // For non-product types, we've reached a leaf so we generate the read and return it's // information let mut read_id = None; editor.edit(|mut edit| { read_id = Some(edit.add_node(Node::Read { collect: val, indices: idx.clone().into(), })); Ok(edit) }); IndexTree::Leaf(read_id.expect("Add node canont fail")) } } macro_rules! add_const { ($editor:ident, $const:expr) => {{ let mut res = None; $editor.edit(|mut edit| { res = Some(edit.add_constant($const)); Ok(edit) }); res.expect("Add constant cannot fail") }}; } // Given a type, builds a default constant of that type fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID { let t = editor.get_type(typ).clone(); match t { Type::Product(ts) => { let mut cs = vec![]; for t in ts { cs.push(generate_constant(editor, t)); } add_const!(editor, Constant::Product(typ, cs.into())) } Type::Boolean => add_const!(editor, Constant::Boolean(false)), Type::Integer8 => add_const!(editor, Constant::Integer8(0)), Type::Integer16 => add_const!(editor, Constant::Integer16(0)), Type::Integer32 => add_const!(editor, Constant::Integer32(0)), Type::Integer64 => add_const!(editor, Constant::Integer64(0)), Type::UnsignedInteger8 => add_const!(editor, Constant::UnsignedInteger8(0)), Type::UnsignedInteger16 => add_const!(editor, Constant::UnsignedInteger16(0)), Type::UnsignedInteger32 => add_const!(editor, Constant::UnsignedInteger32(0)), Type::UnsignedInteger64 => add_const!(editor, Constant::UnsignedInteger64(0)), Type::Float8 | Type::BFloat16 => panic!(), Type::Float32 => add_const!(editor, Constant::Float32(ordered_float::OrderedFloat(0.0))), Type::Float64 => add_const!(editor, Constant::Float64(ordered_float::OrderedFloat(0.0))), Type::Summation(ts) => { let const_id = generate_constant(editor, ts[0]); add_const!(editor, Constant::Summation(typ, 0, const_id)) } Type::Array(_, _) => { add_const!(editor, Constant::Array(typ)) } Type::Control => panic!("Cannot create constant of control type"), } } // Given a constant cnst adds node to the function which are the constant values of each field and // returns a list of pairs of indices and the node that holds that index fn generate_constant_fields(editor: &mut FunctionEditor, cnst: ConstantID) -> IndexTree<NodeID> { let cs: Option<Vec<ConstantID>> = if let Some(cs) = editor.get_constant(cnst).try_product_fields() { Some(cs.into()) } else { None }; if let Some(cs) = cs { let mut fields = vec![]; for c in cs { fields.push(generate_constant_fields(editor, c)); } IndexTree::Node(fields) } else { let mut node = None; editor.edit(|mut edit| { node = Some(edit.add_node(Node::Constant { id: cnst })); Ok(edit) }); IndexTree::Leaf(node.expect("Add node cannot fail")) } } // Given a type, return a list of the fields and new NodeIDs for them, with NodeIDs starting at the // id provided fn allocate_fields(editor: &FunctionEditor, typ: TypeID, id: &mut usize) -> IndexTree<NodeID> { let ts: Option<Vec<TypeID>> = if let Some(ts) = editor.get_type(typ).try_product() { Some(ts.into()) } else { None }; if let Some(ts) = ts { let mut fields = vec![]; for t in ts { fields.push(allocate_fields(editor, t, id)); } IndexTree::Node(fields) } else { let node = *id; *id += 1; IndexTree::Leaf(NodeID::new(node)) } }