From 5f52017944ecca82421bd035264dcea62fa40dca Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Thu, 23 Jan 2025 21:51:38 -0600 Subject: [PATCH] SROA read chains --- hercules_opt/src/sroa.rs | 255 ++++++++++++++++++++++++-- juno_samples/antideps/src/antideps.jn | 15 +- juno_samples/antideps/src/main.rs | 4 + 3 files changed, 262 insertions(+), 12 deletions(-) diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index 6461ad71..66d11d69 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -33,6 +33,8 @@ use crate::*; * * - Read: the read node reads primitive fields from product values - these get * replaced by a direct use of the field value + * A read can also extract a product from an array or sum; the value read out + * will be broken into individual fields (by individual reads from the array) * * - Write: the write node writes primitive fields in product values - these get * replaced by a direct def of the field value @@ -54,15 +56,12 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: // for the call's arguments or the return's value let mut call_return_nodes: Vec<NodeID> = vec![]; - let func = editor.func(); - for node in reverse_postorder { - match func.nodes[node.idx()] { + match &editor.func().nodes[node.idx()] { Node::Phi { .. } | Node::Reduce { .. } | Node::Parameter { .. } | Node::Constant { .. } - | Node::Write { .. } | Node::Ternary { first: _, second: _, @@ -70,8 +69,211 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: op: TernaryOperator::Select, } if editor.get_type(types[&node]).is_product() => product_nodes.push(*node), - Node::Read { collect, .. } if editor.get_type(types[&collect]).is_product() => { - product_nodes.push(*node) + Node::Write { + collect, + data, + indices, + } => { + let data = *data; + let collect = *collect; + + // For a write, we may need to split it into two pieces if the it contains a mix of + // field and non-field indices + let (fields_write, write_prod_into_non) = { + let mut fields = vec![]; + let mut remainder = vec![]; + + let mut indices = indices.iter(); + while let Some(idx) = indices.next() { + if idx.is_field() { + fields.push(idx.clone()); + } else { + remainder.push(idx.clone()); + remainder.extend(indices.cloned()); + break; + } + } + + if fields.is_empty() { + if editor.get_type(types[&data]).is_product() { + (None, Some((*node, collect, remainder))) + } else { + (None, None) + } + } else if remainder.is_empty() { + (Some(*node), None) + } else { + // Here we perform the split into two writes + // We need to find the type of the collection that will be extracted from + // the collection being modified when we read it at the fields index + let after_fields_type = type_at_index(editor, types[&collect], &fields); + + let mut inner_collection = None; + let mut fields_write = None; + let mut remainder_write = None; + editor.edit(|mut edit| { + let read_inner = edit.add_node(Node::Read { + collect, + indices: fields.clone().into(), + }); + types.insert(read_inner, after_fields_type); + product_nodes.push(read_inner); + inner_collection = Some(read_inner); + + let rem_write = edit.add_node(Node::Write { + collect: read_inner, + data, + indices: remainder.clone().into(), + }); + types.insert(rem_write, after_fields_type); + remainder_write = Some(rem_write); + + let complete_write = edit.add_node(Node::Write { + collect, + data: rem_write, + indices: fields.into(), + }); + types.insert(complete_write, types[&collect]); + fields_write = Some(complete_write); + + edit = edit.replace_all_uses(*node, complete_write)?; + edit.delete_node(*node) + }); + let inner_collection = inner_collection.unwrap(); + let fields_write = fields_write.unwrap(); + let remainder_write = remainder_write.unwrap(); + + if editor.get_type(types[&data]).is_product() { + ( + Some(fields_write), + Some((remainder_write, inner_collection, remainder)), + ) + } else { + (Some(fields_write), None) + } + } + }; + + if let Some(node) = fields_write { + product_nodes.push(node); + } + + if let Some((write_node, collection, index)) = write_prod_into_non { + let node = write_node; + // If we're writing a product into a non-product we need to replace the write + // by a sequence of writes that read each field of the product and write them + // into the collection, then those write nodes can be ignored for SROA but the + // reads will be handled by SROA + + // The value being written must be the data and so must be a product + assert!(editor.get_type(types[&data]).is_product()); + let fields = generate_reads(editor, types[&data], data); + + let mut collection = collection; + let collection_type = types[&collection]; + + fields.for_each(|field: &Vec<Index>, val: &NodeID| { + product_nodes.push(*val); + editor.edit(|mut edit| { + collection = edit.add_node(Node::Write { + collect: collection, + data: *val, + indices: index + .iter() + .chain(field) + .cloned() + .collect::<Vec<_>>() + .into(), + }); + types.insert(collection, collection_type); + Ok(edit) + }); + }); + + editor.edit(|mut edit| { + edit = edit.replace_all_uses(node, collection)?; + edit.delete_node(node) + }); + } + } + Node::Read { collect, indices } => { + // For a read, we split the read into a series of reads where each piece has either + // only field reads or no field reads. Those with fields are the only ones + // considered during SROA but any read whose collection is not a product but + // produces a product (i.e. if there's an array of products) then following the + // read we replace the read that produces a product by reads of each field and add + // that information to the node map for the rest of SROA (this produces some reads + // that mix types of indices, since we only read leaves but that's okay since those + // reads are not handled by SROA) + let indices = indices + .chunk_by(|i, j| i.is_field() && j.is_field()) + .collect::<Vec<_>>(); + + let (field_reads, non_fields_produce_prod) = { + if indices.len() == 0 { + // If there are no indices then there were no indices originally, this is + // only used with clones of arrays + (vec![], vec![]) + } else if indices.len() == 1 { + // If once we perform chunking there's only one set of indices, we can just + // use the original node + if indices[0][0].is_field() { + (vec![*node], vec![]) + } else if editor.get_type(types[node]).is_product() { + (vec![], vec![*node]) + } else { + (vec![], vec![]) + } + } else { + let mut field_reads = vec![]; + let mut non_field = vec![]; + + // To construct the multiple reads we need to track the current collection + // and the type of that collection + let mut collect = *collect; + let mut typ = types[&collect]; + + let indices = indices + .into_iter() + .map(|i| i.into_iter().cloned().collect::<Vec<_>>()) + .collect::<Vec<_>>(); + for index in indices { + let is_field_read = index[0].is_field(); + let field_type = type_at_index(editor, typ, &index); + + editor.edit(|mut edit| { + collect = edit.add_node(Node::Read { + collect, + indices: index.into(), + }); + types.insert(collect, field_type); + typ = field_type; + Ok(edit) + }); + + if is_field_read { + field_reads.push(collect); + } else if editor.get_type(typ).is_product() { + non_field.push(collect); + } + } + + // Replace all uses of the original read (with mixed indices) with the + // newly constructed reads + editor.edit(|mut edit| { + edit = edit.replace_all_uses(*node, collect)?; + edit.delete_node(*node) + }); + + (field_reads, non_field) + } + }; + + product_nodes.extend(field_reads); + + for node in non_fields_produce_prod { + field_map.insert(node, generate_reads(editor, types[&node], node)); + } } // We add all calls to the call/return list and check their arguments later @@ -516,8 +718,7 @@ impl<T: std::fmt::Debug> IndexTree<T> { IndexTree::Node(ts) => ts[i].lookup_idx(idx, n + 1), } } else { - // TODO: This could be hit because of an array inside of a product - panic!("Error handling lookup of field"); + panic!("Cannot process a mix of fields and other read indices; pre-processing should prevent this"); } } else { self @@ -548,7 +749,7 @@ impl<T: std::fmt::Debug> IndexTree<T> { } } } else { - panic!("Error handling set of field"); + panic!("Cannot process a mix of fields and other read indices; pre-processing should prevent this"); } } else { IndexTree::Leaf(val) @@ -579,7 +780,7 @@ impl<T: std::fmt::Debug> IndexTree<T> { } } } else { - panic!("Error handling set of field"); + panic!("Cannot process a mix of fields and other read indices; pre-processing should prevent this"); } } else { val @@ -658,6 +859,38 @@ impl<T: std::fmt::Debug> IndexTree<T> { } } +// Given the editor, type of some collection, and a list of indices to access that type at, returns +// the TypeID of accessing the collection at the given indices +fn type_at_index(editor: &FunctionEditor, typ: TypeID, idx: &[Index]) -> TypeID { + let mut typ = typ; + for index in idx { + match index { + Index::Field(i) => { + let Type::Product(ref ts) = *editor.get_type(typ) else { + panic!("Accessing a field of a non-product type; did typechecking succeed?"); + }; + typ = ts[*i]; + } + Index::Variant(i) => { + let Type::Summation(ref ts) = *editor.get_type(typ) else { + panic!( + "Accessing a variant of a non-summation type; did typechecking succeed?" + ); + }; + typ = ts[*i]; + } + Index::Position(pos) => { + let Type::Array(elem, ref dims) = *editor.get_type(typ) else { + panic!("Accessing an array position of a non-array type; did typechecking succeed?"); + }; + assert!(pos.len() == dims.len(), "Read mismatch array dimensions"); + typ = elem; + } + } + } + return typ; +} + // Given a product value val of type typ, constructs a copy of that value by extracting all fields // from that value and then writing them into a new constant // This process also adds all the read nodes that are generated into the read_list so that the @@ -696,7 +929,7 @@ fn reconstruct_product( } // Given a node val of type typ, adds nodes to the function which read all (leaf) fields of val and -// returns a list of pairs of the indices and the node that reads that index +// returns an IndexTree that tracks the nodes reading each leaf field fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> IndexTree<NodeID> { let res = generate_reads_at_index(editor, typ, val, vec![]); res diff --git a/juno_samples/antideps/src/antideps.jn b/juno_samples/antideps/src/antideps.jn index 738ee6da..f40640d2 100644 --- a/juno_samples/antideps/src/antideps.jn +++ b/juno_samples/antideps/src/antideps.jn @@ -121,4 +121,17 @@ fn read_chains(input : i32) -> i32 { sub[1] = 99; arrs.0[1] = 99; return result + sub[1] - arrs.0[1]; -} \ No newline at end of file +} + +#[entry] +fn array_of_structs(input: i32) -> i32 { + let arr : (i32, i32)[2]; + let sub = arr[0]; + sub.1 = input + 7; + arr[0] = sub; + arr[0].1 = input + 3; + let result = sub.1 + arr[0].1; + sub.1 = 99; + arr[0].1 = 99; + return result + sub.1 - arr[0].1; +} diff --git a/juno_samples/antideps/src/main.rs b/juno_samples/antideps/src/main.rs index 6e5ed7a3..2f1e8efc 100644 --- a/juno_samples/antideps/src/main.rs +++ b/juno_samples/antideps/src/main.rs @@ -27,6 +27,10 @@ fn main() { let output = read_chains(2).await; println!("{}", output); assert_eq!(output, 14); + + let output = array_of_structs(2).await; + println!("{}", output); + assert_eq!(output, 14); }); } -- GitLab