From 5f52017944ecca82421bd035264dcea62fa40dca Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Thu, 23 Jan 2025 21:51:38 -0600
Subject: [PATCH] SROA read chains

---
 hercules_opt/src/sroa.rs              | 255 ++++++++++++++++++++++++--
 juno_samples/antideps/src/antideps.jn |  15 +-
 juno_samples/antideps/src/main.rs     |   4 +
 3 files changed, 262 insertions(+), 12 deletions(-)

diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs
index 6461ad71..66d11d69 100644
--- a/hercules_opt/src/sroa.rs
+++ b/hercules_opt/src/sroa.rs
@@ -33,6 +33,8 @@ use crate::*;
  *
  * - Read: the read node reads primitive fields from product values - these get
  *   replaced by a direct use of the field value
+ *   A read can also extract a product from an array or sum; the value read out
+ *   will be broken into individual fields (by individual reads from the array)
  *
  * - Write: the write node writes primitive fields in product values - these get
  *   replaced by a direct def of the field value
@@ -54,15 +56,12 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
     // for the call's arguments or the return's value
     let mut call_return_nodes: Vec<NodeID> = vec![];
 
-    let func = editor.func();
-
     for node in reverse_postorder {
-        match func.nodes[node.idx()] {
+        match &editor.func().nodes[node.idx()] {
             Node::Phi { .. }
             | Node::Reduce { .. }
             | Node::Parameter { .. }
             | Node::Constant { .. }
-            | Node::Write { .. }
             | Node::Ternary {
                 first: _,
                 second: _,
@@ -70,8 +69,211 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                 op: TernaryOperator::Select,
             } if editor.get_type(types[&node]).is_product() => product_nodes.push(*node),
 
-            Node::Read { collect, .. } if editor.get_type(types[&collect]).is_product() => {
-                product_nodes.push(*node)
+            Node::Write {
+                collect,
+                data,
+                indices,
+            } => {
+                let data = *data;
+                let collect = *collect;
+
+                // For a write, we may need to split it into two pieces if the it contains a mix of
+                // field and non-field indices
+                let (fields_write, write_prod_into_non) = {
+                    let mut fields = vec![];
+                    let mut remainder = vec![];
+
+                    let mut indices = indices.iter();
+                    while let Some(idx) = indices.next() {
+                        if idx.is_field() {
+                            fields.push(idx.clone());
+                        } else {
+                            remainder.push(idx.clone());
+                            remainder.extend(indices.cloned());
+                            break;
+                        }
+                    }
+
+                    if fields.is_empty() {
+                        if editor.get_type(types[&data]).is_product() {
+                            (None, Some((*node, collect, remainder)))
+                        } else {
+                            (None, None)
+                        }
+                    } else if remainder.is_empty() {
+                        (Some(*node), None)
+                    } else {
+                        // Here we perform the split into two writes
+                        // We need to find the type of the collection that will be extracted from
+                        // the collection being modified when we read it at the fields index
+                        let after_fields_type = type_at_index(editor, types[&collect], &fields);
+
+                        let mut inner_collection = None;
+                        let mut fields_write = None;
+                        let mut remainder_write = None;
+                        editor.edit(|mut edit| {
+                            let read_inner = edit.add_node(Node::Read {
+                                collect,
+                                indices: fields.clone().into(),
+                            });
+                            types.insert(read_inner, after_fields_type);
+                            product_nodes.push(read_inner);
+                            inner_collection = Some(read_inner);
+
+                            let rem_write = edit.add_node(Node::Write {
+                                collect: read_inner,
+                                data,
+                                indices: remainder.clone().into(),
+                            });
+                            types.insert(rem_write, after_fields_type);
+                            remainder_write = Some(rem_write);
+
+                            let complete_write = edit.add_node(Node::Write {
+                                collect,
+                                data: rem_write,
+                                indices: fields.into(),
+                            });
+                            types.insert(complete_write, types[&collect]);
+                            fields_write = Some(complete_write);
+
+                            edit = edit.replace_all_uses(*node, complete_write)?;
+                            edit.delete_node(*node)
+                        });
+                        let inner_collection = inner_collection.unwrap();
+                        let fields_write = fields_write.unwrap();
+                        let remainder_write = remainder_write.unwrap();
+
+                        if editor.get_type(types[&data]).is_product() {
+                            (
+                                Some(fields_write),
+                                Some((remainder_write, inner_collection, remainder)),
+                            )
+                        } else {
+                            (Some(fields_write), None)
+                        }
+                    }
+                };
+
+                if let Some(node) = fields_write {
+                    product_nodes.push(node);
+                }
+
+                if let Some((write_node, collection, index)) = write_prod_into_non {
+                    let node = write_node;
+                    // If we're writing a product into a non-product we need to replace the write
+                    // by a sequence of writes that read each field of the product and write them
+                    // into the collection, then those write nodes can be ignored for SROA but the
+                    // reads will be handled by SROA
+
+                    // The value being written must be the data and so must be a product
+                    assert!(editor.get_type(types[&data]).is_product());
+                    let fields = generate_reads(editor, types[&data], data);
+
+                    let mut collection = collection;
+                    let collection_type = types[&collection];
+
+                    fields.for_each(|field: &Vec<Index>, val: &NodeID| {
+                        product_nodes.push(*val);
+                        editor.edit(|mut edit| {
+                            collection = edit.add_node(Node::Write {
+                                collect: collection,
+                                data: *val,
+                                indices: index
+                                    .iter()
+                                    .chain(field)
+                                    .cloned()
+                                    .collect::<Vec<_>>()
+                                    .into(),
+                            });
+                            types.insert(collection, collection_type);
+                            Ok(edit)
+                        });
+                    });
+
+                    editor.edit(|mut edit| {
+                        edit = edit.replace_all_uses(node, collection)?;
+                        edit.delete_node(node)
+                    });
+                }
+            }
+            Node::Read { collect, indices } => {
+                // For a read, we split the read into a series of reads where each piece has either
+                // only field reads or no field reads. Those with fields are the only ones
+                // considered during SROA but any read whose collection is not a product but
+                // produces a product (i.e. if there's an array of products) then following the
+                // read we replace the read that produces a product by reads of each field and add
+                // that information to the node map for the rest of SROA (this produces some reads
+                // that mix types of indices, since we only read leaves but that's okay since those
+                // reads are not handled by SROA)
+                let indices = indices
+                    .chunk_by(|i, j| i.is_field() && j.is_field())
+                    .collect::<Vec<_>>();
+
+                let (field_reads, non_fields_produce_prod) = {
+                    if indices.len() == 0 {
+                        // If there are no indices then there were no indices originally, this is
+                        // only used with clones of arrays
+                        (vec![], vec![])
+                    } else if indices.len() == 1 {
+                        // If once we perform chunking there's only one set of indices, we can just
+                        // use the original node
+                        if indices[0][0].is_field() {
+                            (vec![*node], vec![])
+                        } else if editor.get_type(types[node]).is_product() {
+                            (vec![], vec![*node])
+                        } else {
+                            (vec![], vec![])
+                        }
+                    } else {
+                        let mut field_reads = vec![];
+                        let mut non_field = vec![];
+
+                        // To construct the multiple reads we need to track the current collection
+                        // and the type of that collection
+                        let mut collect = *collect;
+                        let mut typ = types[&collect];
+
+                        let indices = indices
+                            .into_iter()
+                            .map(|i| i.into_iter().cloned().collect::<Vec<_>>())
+                            .collect::<Vec<_>>();
+                        for index in indices {
+                            let is_field_read = index[0].is_field();
+                            let field_type = type_at_index(editor, typ, &index);
+
+                            editor.edit(|mut edit| {
+                                collect = edit.add_node(Node::Read {
+                                    collect,
+                                    indices: index.into(),
+                                });
+                                types.insert(collect, field_type);
+                                typ = field_type;
+                                Ok(edit)
+                            });
+
+                            if is_field_read {
+                                field_reads.push(collect);
+                            } else if editor.get_type(typ).is_product() {
+                                non_field.push(collect);
+                            }
+                        }
+
+                        // Replace all uses of the original read (with mixed indices) with the
+                        // newly constructed reads
+                        editor.edit(|mut edit| {
+                            edit = edit.replace_all_uses(*node, collect)?;
+                            edit.delete_node(*node)
+                        });
+
+                        (field_reads, non_field)
+                    }
+                };
+
+                product_nodes.extend(field_reads);
+
+                for node in non_fields_produce_prod {
+                    field_map.insert(node, generate_reads(editor, types[&node], node));
+                }
             }
 
             // We add all calls to the call/return list and check their arguments later
@@ -516,8 +718,7 @@ impl<T: std::fmt::Debug> IndexTree<T> {
                     IndexTree::Node(ts) => ts[i].lookup_idx(idx, n + 1),
                 }
             } else {
-                // TODO: This could be hit because of an array inside of a product
-                panic!("Error handling lookup of field");
+                panic!("Cannot process a mix of fields and other read indices; pre-processing should prevent this");
             }
         } else {
             self
@@ -548,7 +749,7 @@ impl<T: std::fmt::Debug> IndexTree<T> {
                     }
                 }
             } else {
-                panic!("Error handling set of field");
+                panic!("Cannot process a mix of fields and other read indices; pre-processing should prevent this");
             }
         } else {
             IndexTree::Leaf(val)
@@ -579,7 +780,7 @@ impl<T: std::fmt::Debug> IndexTree<T> {
                     }
                 }
             } else {
-                panic!("Error handling set of field");
+                panic!("Cannot process a mix of fields and other read indices; pre-processing should prevent this");
             }
         } else {
             val
@@ -658,6 +859,38 @@ impl<T: std::fmt::Debug> IndexTree<T> {
     }
 }
 
+// Given the editor, type of some collection, and a list of indices to access that type at, returns
+// the TypeID of accessing the collection at the given indices
+fn type_at_index(editor: &FunctionEditor, typ: TypeID, idx: &[Index]) -> TypeID {
+    let mut typ = typ;
+    for index in idx {
+        match index {
+            Index::Field(i) => {
+                let Type::Product(ref ts) = *editor.get_type(typ) else {
+                    panic!("Accessing a field of a non-product type; did typechecking succeed?");
+                };
+                typ = ts[*i];
+            }
+            Index::Variant(i) => {
+                let Type::Summation(ref ts) = *editor.get_type(typ) else {
+                    panic!(
+                        "Accessing a variant of a non-summation type; did typechecking succeed?"
+                    );
+                };
+                typ = ts[*i];
+            }
+            Index::Position(pos) => {
+                let Type::Array(elem, ref dims) = *editor.get_type(typ) else {
+                    panic!("Accessing an array position of a non-array type; did typechecking succeed?");
+                };
+                assert!(pos.len() == dims.len(), "Read mismatch array dimensions");
+                typ = elem;
+            }
+        }
+    }
+    return typ;
+}
+
 // Given a product value val of type typ, constructs a copy of that value by extracting all fields
 // from that value and then writing them into a new constant
 // This process also adds all the read nodes that are generated into the read_list so that the
@@ -696,7 +929,7 @@ fn reconstruct_product(
 }
 
 // Given a node val of type typ, adds nodes to the function which read all (leaf) fields of val and
-// returns a list of pairs of the indices and the node that reads that index
+// returns an IndexTree that tracks the nodes reading each leaf field
 fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> IndexTree<NodeID> {
     let res = generate_reads_at_index(editor, typ, val, vec![]);
     res
diff --git a/juno_samples/antideps/src/antideps.jn b/juno_samples/antideps/src/antideps.jn
index 738ee6da..f40640d2 100644
--- a/juno_samples/antideps/src/antideps.jn
+++ b/juno_samples/antideps/src/antideps.jn
@@ -121,4 +121,17 @@ fn read_chains(input : i32) -> i32 {
   sub[1] = 99;
   arrs.0[1] = 99;
   return result + sub[1] - arrs.0[1];
-}
\ No newline at end of file
+}
+
+#[entry]
+fn array_of_structs(input: i32) -> i32 {
+  let arr : (i32, i32)[2];
+  let sub = arr[0];
+  sub.1 = input + 7;
+  arr[0] = sub;
+  arr[0].1 = input + 3;
+  let result = sub.1 + arr[0].1;
+  sub.1 = 99;
+  arr[0].1 = 99;
+  return result + sub.1 - arr[0].1;
+}
diff --git a/juno_samples/antideps/src/main.rs b/juno_samples/antideps/src/main.rs
index 6e5ed7a3..2f1e8efc 100644
--- a/juno_samples/antideps/src/main.rs
+++ b/juno_samples/antideps/src/main.rs
@@ -27,6 +27,10 @@ fn main() {
         let output = read_chains(2).await;
         println!("{}", output);
         assert_eq!(output, 14);
+
+        let output = array_of_structs(2).await;
+        println!("{}", output);
+        assert_eq!(output, 14);
     });
 }
 
-- 
GitLab