From c9445e5256cab3279f3119fd2baa8bfd2e4cc8a8 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Mon, 18 Nov 2024 14:50:21 -0600
Subject: [PATCH 01/12] Start rewriting SROA

---
 hercules_ir/src/ir.rs    |   8 +
 hercules_opt/src/pass.rs |  31 ++-
 hercules_opt/src/sroa.rs | 507 ++++++++++++++++-----------------------
 juno_frontend/src/lib.rs |   6 +
 juno_samples/products.jn |   7 +
 5 files changed, 250 insertions(+), 309 deletions(-)
 create mode 100644 juno_samples/products.jn

diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 5cf549a8..1486f5ce 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -728,6 +728,14 @@ impl Type {
             None
         }
     }
+
+    pub fn try_product(&self) -> Option<&[TypeID]> {
+        if let Type::Product(ts) = self {
+            Some(ts)
+        } else {
+            None
+        }
+    }
 }
 
 impl Constant {
diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs
index cb56f709..a7d48efb 100644
--- a/hercules_opt/src/pass.rs
+++ b/hercules_opt/src/pass.rs
@@ -494,16 +494,37 @@ impl PassManager {
                     let reverse_postorders = self.reverse_postorders.as_ref().unwrap();
                     let typing = self.typing.as_ref().unwrap();
                     for idx in 0..self.module.functions.len() {
-                        sroa(
+                        let constants_ref =
+                            RefCell::new(std::mem::take(&mut self.module.constants));
+                        let dynamic_constants_ref =
+                            RefCell::new(std::mem::take(&mut self.module.dynamic_constants));
+                        let types_ref = RefCell::new(std::mem::take(&mut self.module.types));
+                        let mut editor = FunctionEditor::new(
                             &mut self.module.functions[idx],
+                            &constants_ref,
+                            &dynamic_constants_ref,
+                            &types_ref,
                             &def_uses[idx],
+                        );
+                        sroa(
+                            &mut editor,
                             &reverse_postorders[idx],
-                            &typing[idx],
-                            &self.module.types,
-                            &mut self.module.constants,
+                            &typing[idx]
                         );
+
+                        self.module.constants = constants_ref.take();
+                        self.module.dynamic_constants = dynamic_constants_ref.take();
+                        self.module.types = types_ref.take();
+
+                        let edits = &editor.edits();
+                        if let Some(plans) = self.plans.as_mut() {
+                            repair_plan(&mut plans[idx], &self.module.functions[idx], edits);
+                        }
+                        let grave_mapping = self.module.functions[idx].delete_gravestones();
+                        if let Some(plans) = self.plans.as_mut() {
+                            plans[idx].fix_gravestones(&grave_mapping);
+                        }
                     }
-                    self.legacy_repair_plan();
                     self.clear_analyses();
                 }
                 Pass::Inline => {
diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs
index cb5ecd25..ae5d76dc 100644
--- a/hercules_opt/src/sroa.rs
+++ b/hercules_opt/src/sroa.rs
@@ -1,15 +1,12 @@
-extern crate bitvec;
 extern crate hercules_ir;
 
-use std::collections::HashMap;
-use std::iter::zip;
-
-use self::bitvec::prelude::*;
+use std::collections::{HashMap, LinkedList};
 
 use self::hercules_ir::dataflow::*;
-use self::hercules_ir::def_use::*;
 use self::hercules_ir::ir::*;
 
+use crate::*;
+
 /*
  * Top level function to run SROA, intraprocedurally. Product values can be used
  * and created by a relatively small number of nodes. Here are *all* of them:
@@ -20,11 +17,11 @@ use self::hercules_ir::ir::*;
  * - Reduce: similarly to phis, reduce nodes can cycle product values through
  *   reduction loops - these get broken up into reduces on the fields
  *
- * + Return: can return a product - these are untouched, and are the sinks for
- *   unbroken product values
+ * - Return: can return a product - the product values will be constructed
+ *   at the return site
  *
- * + Parameter: can introduce a product - these are untouched, and are the
- *   sources for unbroken product values
+ * - Parameter: can introduce a product - reads will be introduced for each
+ *   field
  *
  * - Constant: can introduce a product - these are broken up into constants for
  *   the individual fields
@@ -32,334 +29,236 @@ use self::hercules_ir::ir::*;
  * - Ternary: the select ternary operator can select between products - these
  *   are broken up into ternary nodes for the individual fields
  *
- * + Call: the call node can use a product value as an argument to another
- *   function, and can produce a product value as a result - these are
- *   untouched, and are the sink and source for unbroken product values
+ * - Call: the call node can use a product value as an argument to another
+ *   function, and can produce a product value as a result. Argument values
+ *   will be constructed at the call site and the return value will be broken
+ *   into individual fields
  *
  * - Read: the read node reads primitive fields from product values - these get
- *   replaced by a direct use of the field value from the broken product value,
- *   but are retained when the product value is unbroken
+ *   replaced by a direct use of the field value
  *
  * - Write: the write node writes primitive fields in product values - these get
- *   replaced by a direct def of the field value from the broken product value,
- *   but are retained when the product value is unbroken
- *
- * The nodes above with the list marker "+" are retained for maintaining API/ABI
- * compatability with other Hercules functions and the host code. These are
- * called "sink" or "source" nodes in comments below.
+ *   replaced by a direct def of the field value
  */
-pub fn sroa(
-    function: &mut Function,
-    def_use: &ImmutableDefUseMap,
-    reverse_postorder: &Vec<NodeID>,
-    typing: &Vec<TypeID>,
-    types: &Vec<Type>,
-    constants: &mut Vec<Constant>,
-) {
-    // Determine which sources of product values we want to try breaking up. We
-    // can determine easily on the soure side if a node produces a product that
-    // shouldn't be broken up by just examining the node type. However, the way
-    // that products are used is also important for determining if the product
-    // can be broken up. We backward dataflow this info to the sources of
-    // product values.
-    #[derive(PartialEq, Eq, Clone, Debug)]
-    enum ProductUseLattice {
-        // The product value used by this node is eventually used by a sink.
-        UsedBySink,
-        // This node uses multiple product values - the stored node ID indicates
-        // which is eventually used by a sink. This lattice value is produced by
-        // read and write nodes implementing partial indexing.
-        SpecificUsedBySink(NodeID),
-        // This node doesn't use a product node, or the product node it does use
-        // is not in turn used by a sink.
-        UnusedBySink,
-    }
+pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: &Vec<TypeID>) {
+    // This map stores a map from NodeID and fields to the NodeID which contains just that field of
+    // the original value
+    let mut field_map : HashMap<(NodeID, Vec<Index>), NodeID> = HashMap::new();
 
-    impl Semilattice for ProductUseLattice {
-        fn meet(a: &Self, b: &Self) -> Self {
-            match (a, b) {
-                (Self::UsedBySink, _) | (_, Self::UsedBySink) => Self::UsedBySink,
-                (Self::SpecificUsedBySink(id1), Self::SpecificUsedBySink(id2)) => {
-                    if id1 == id2 {
-                        Self::SpecificUsedBySink(*id1)
-                    } else {
-                        Self::UsedBySink
-                    }
-                }
-                (Self::SpecificUsedBySink(id), _) | (_, Self::SpecificUsedBySink(id)) => {
-                    Self::SpecificUsedBySink(*id)
+    // First: determine all nodes which interact with products (as described above)
+    let mut product_nodes = vec![];
+    // We track call and return nodes separately since they (may) require constructing new products
+    // for the call's arguments or the return's value
+    let mut call_return_nodes = vec![];
+    // We track writes separately since they should be processed once their input product has been
+    // processed, so we handle them after all other nodes (though before reads)
+    let mut write_nodes = vec![];
+    // We also track reads separately because they aren't handled until all other nodes have been
+    // processed
+    let mut read_nodes = vec![];
+
+    let func = editor.func();
+
+    for node in reverse_postorder {
+        match func.nodes[node.idx()] {
+            Node::Phi { .. } | Node::Reduce { .. } | Node::Parameter { .. }
+            | Node::Constant { .. }
+            | Node::Ternary { first: _, second: _, third: _, op: TernaryOperator::Select }
+                if editor.get_type(types[node.idx()]).is_product() => product_nodes.push(node),
+            
+            Node::Write { .. }
+                if editor.get_type(types[node.idx()]).is_product() => write_nodes.push(node),
+            Node::Read { collect, .. }
+                if editor.get_type(types[collect.idx()]).is_product() => read_nodes.push(node),
+
+            // We add all calls to the call/return list and check their arguments later
+            Node::Call { .. } => {
+                call_return_nodes.push(node);
+                if editor.get_type(types[node.idx()]).is_product() {
+                    read_nodes.push(node);
                 }
-                _ => Self::UnusedBySink,
             }
-        }
+            Node::Return { control: _, data }
+                if editor.get_type(types[data.idx()]).is_product() =>
+                call_return_nodes.push(node),
 
-        fn bottom() -> Self {
-            Self::UsedBySink
-        }
-
-        fn top() -> Self {
-            Self::UnusedBySink
+            _ => ()
         }
     }
 
-    // Run dataflow analysis to find which product values are used by a sink.
-    let product_uses = backward_dataflow(function, def_use, reverse_postorder, |succ_outs, id| {
-        match function.nodes[id.idx()] {
-            Node::Return {
-                control: _,
-                data: _,
-            } => {
-                if types[typing[id.idx()].idx()].is_product() {
-                    ProductUseLattice::UsedBySink
-                } else {
-                    ProductUseLattice::UnusedBySink
-                }
-            }
-            Node::Call {
-                control: _,
-                function: _,
-                dynamic_constants: _,
-                args: _,
-            } => todo!(),
-            // For reads and writes, we only want to propagate the use of the
-            // product to the collect input of the node.
-            Node::Read {
-                collect,
-                indices: _,
+    // Next, we handle calls and returns. For returns, we will insert nodes that read each field of
+    // the returned product and then write them into a new product. These writes are not put into
+    // the list of product nodes since they must remain but the reads are so that they will be
+    // replaced later on.
+    // For calls, we do a similar process for each (product) argument. Additionally, if the call
+    // returns a product, we create reads for each field in that product and store it into our
+    // field map
+    for node in call_return_nodes {
+        match &editor.func().nodes[node.idx()] {
+            Node::Return { control, data } => {
+                assert!(editor.get_type(types[data.idx()]).is_product());
+                let control = *control;
+                let new_data = reconstruct_product(editor, types[data.idx()], *data);
+                editor.edit(|mut edit| {
+                    edit.add_node(Node::Return { control, data: new_data });
+                    edit.delete_node(*node)
+                });
             }
-            | Node::Write {
-                collect,
-                data: _,
-                indices: _,
-            } => {
-                let meet = succ_outs
-                    .iter()
-                    .fold(ProductUseLattice::top(), |acc, latt| {
-                        ProductUseLattice::meet(&acc, latt)
-                    });
-                if meet == ProductUseLattice::UnusedBySink {
-                    ProductUseLattice::UnusedBySink
-                } else {
-                    ProductUseLattice::SpecificUsedBySink(collect)
-                }
-            }
-            // For non-sink nodes.
-            _ => {
-                if function.nodes[id.idx()].is_control() {
-                    return ProductUseLattice::UnusedBySink;
-                }
-                let meet = succ_outs
-                    .iter()
-                    .fold(ProductUseLattice::top(), |acc, latt| {
-                        ProductUseLattice::meet(&acc, latt)
-                    });
-                if let ProductUseLattice::SpecificUsedBySink(meet_id) = meet {
-                    if meet_id == id {
-                        ProductUseLattice::UsedBySink
+            Node::Call { control, function, dynamic_constants, args } => {
+                let control = *control;
+                let function = *function;
+                let dynamic_constants = dynamic_constants.clone();
+                let args = args.clone();
+
+                // If the call returns a product, we generate reads for each field
+                let fields =
+                    if editor.get_type(types[node.idx()]).is_product() {
+                        Some(generate_reads(editor, types[node.idx()], *node))
                     } else {
-                        ProductUseLattice::UnusedBySink
+                        None
+                    };
+
+                let mut new_args = vec![];
+                for arg in args {
+                    if editor.get_type(types[arg.idx()]).is_product() {
+                        new_args.push(reconstruct_product(editor, types[arg.idx()], arg));
+                    } else {
+                        new_args.push(arg);
                     }
-                } else {
-                    meet
                 }
+                editor.dit(|mut edit| {
+                    let new_call = edit.add_node(Node::Call { control, function, dynamic_constants, args: new_args.into() });
+                    let edit = edit.replace_all_uses(*node, new_call)?;
+                    let edit = edit.delete_node(*node)?;
+
+                    match fields {
+                        None => {}
+                        Some(fields) => {
+                            for (idx, node) in fields {
+                                field_map.insert((new_call, idx), node);
+                            }
+    }
+                    }
+
+                    Ok(edit)
+                });
             }
+            _ => panic!("Processing non-call or return node")
         }
-    });
+    }
 
-    // Only product values introduced as constants can be replaced by scalars.
-    let to_sroa: Vec<(NodeID, ConstantID)> = product_uses
-        .into_iter()
-        .enumerate()
-        .filter_map(|(node_idx, product_use)| {
-            if ProductUseLattice::UnusedBySink == product_use
-                && types[typing[node_idx].idx()].is_product()
-            {
-                function.nodes[node_idx]
-                    .try_constant()
-                    .map(|cons_id| (NodeID::new(node_idx), cons_id))
-            } else {
-                None
-            }
-        })
-        .collect();
+    // Now, we process all other non-read/write nodes that deal with products.
+    // The first step is to identify the NodeIDs that contain each field of each of these nodes
+    todo!()
+}
 
-    // Perform SROA. TODO: repair def-use when there are multiple product
-    // constants to SROA away.
-    assert!(to_sroa.len() < 2);
-    for (constant_node_id, constant_id) in to_sroa {
-        // Get the field constants to replace the product constant with.
-        let product_constant = constants[constant_id.idx()].clone();
-        let constant_fields = product_constant.try_product_fields().unwrap();
+// Given a product value val of type typ, constructs a copy of that value by extracting all fields
+// from that value and then writing them into a new constant
+fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> NodeID {
+    let fields = generate_reads(editor, typ, val);
+    let new_const = generate_constant(editor, typ);
 
-        // DFS to find all data nodes that use the product constant.
-        let to_replace = sroa_dfs(constant_node_id, function, def_use);
+    // Create a constant node
+    let mut const_node = None;
+    editor.edit(|mut edit| {
+        const_node = Some(edit.add_node(Node::Constant { id: new_const }));
+        Ok(edit)
+    });
 
-        // Assemble a mapping from old nodes IDs acting on the product constant
-        // to new nodes IDs operating on the field constants.
-        let old_to_new_id_map: HashMap<NodeID, Vec<NodeID>> = to_replace
-            .iter()
-            .map(|old_id| match function.nodes[old_id.idx()] {
-                Node::Phi {
-                    control: _,
-                    data: _,
-                }
-                | Node::Reduce {
-                    control: _,
-                    init: _,
-                    reduct: _,
-                }
-                | Node::Constant { id: _ }
-                | Node::Ternary {
-                    op: _,
-                    first: _,
-                    second: _,
-                    third: _,
-                }
-                | Node::Write {
-                    collect: _,
-                    data: _,
-                    indices: _,
-                } => {
-                    let new_ids = (0..constant_fields.len())
-                        .map(|_| {
-                            let id = NodeID::new(function.nodes.len());
-                            function.nodes.push(Node::Start);
-                            id
-                        })
-                        .collect();
-                    (*old_id, new_ids)
-                }
-                Node::Read {
-                    collect: _,
-                    indices: _,
-                } => (*old_id, vec![]),
-                _ => panic!("PANIC: Invalid node using a constant product found during SROA."),
-            })
-            .collect();
+    // Generate writes for each field
+    let mut value = const_node.expect("Add node cannot fail");
+    for (idx, val) in fields {
+        editor.edit(|mut edit| {
+            value = edit.add_node(Node::Write { collect: value, data: val, indices: idx.into() });
+            Ok(edit)
+        });
+    }
 
-        // Replace the old nodes with the new nodes. Since we've already
-        // allocated the node IDs, at this point we can iterate through the to-
-        // replace nodes in an arbitrary order.
-        for (old_id, new_ids) in &old_to_new_id_map {
-            // First, add the new nodes to the node list.
-            let node = function.nodes[old_id.idx()].clone();
-            match node {
-                // Replace the original constant with constants for each field.
-                Node::Constant { id: _ } => {
-                    for (new_id, field_id) in zip(new_ids.iter(), constant_fields.iter()) {
-                        function.nodes[new_id.idx()] = Node::Constant { id: *field_id };
-                    }
-                }
-                // Replace writes using the constant as the data use with a
-                // series of writes writing the invidiual constant fields. TODO:
-                // handle the case where the constant is the collect use of the
-                // write node.
-                Node::Write {
-                    collect,
-                    data,
-                    ref indices,
-                } => {
-                    // Create the write chain.
-                    assert!(old_to_new_id_map.contains_key(&data), "PANIC: Can't handle case where write node depends on constant to SROA in the collect use yet.");
-                    let mut collect_def = collect;
-                    for (idx, (new_id, new_data_def)) in
-                        zip(new_ids.iter(), old_to_new_id_map[&data].iter()).enumerate()
-                    {
-                        let mut new_indices = indices.clone().into_vec();
-                        new_indices.push(Index::Field(idx));
-                        function.nodes[new_id.idx()] = Node::Write {
-                            collect: collect_def,
-                            data: *new_data_def,
-                            indices: new_indices.into_boxed_slice(),
-                        };
-                        collect_def = *new_id;
-                    }
+    value
+}
 
-                    // Replace uses of the old write with the new write.
-                    for user in def_use.get_users(*old_id) {
-                        get_uses_mut(&mut function.nodes[user.idx()]).map(*old_id, collect_def);
-                    }
-                }
-                _ => todo!(),
-            }
+// Given a node val of type typ, adds nodes to the function which read all (leaf) fields of val and
+// returns a list of pairs of the indices and the node that reads that index
+fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> Vec<(Vec<Index>, NodeID)> {
+    generate_reads_at_index(editor, typ, val, vec![]).into_iter().collect::<_>()
+}
 
-            // Delete the old node.
-            function.nodes[old_id.idx()] = Node::Start;
+// Given a node val of type which at the indices idx has type typ, construct reads of all (leaf)
+// fields within this sub-value of val and return the correspondence list
+fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, idx: Vec<Index>) -> LinkedList<(Vec<Index>, NodeID)> {
+    let ts : Option<Vec<TypeID>> = {
+        if let Some(ts) = editor.get_type(typ).try_product() {
+            Some(ts.into())
+        } else {
+            None
         }
+    };
+
+    if let Some(ts) = ts {
+        // For product values, we will recurse down each of its fields with an extended index
+        // and the appropriate type of that field
+        let mut result = LinkedList::new();
+        for (i, t) in ts.into_iter().enumerate() {
+            let mut new_idx = idx.clone();
+            new_idx.push(Index::Field(i));
+            result.append(&mut generate_reads_at_index(editor, t, val, new_idx));
+        }
+
+        result
+    } else {
+        // For non-product types, we've reached a leaf so we generate the read and return it's
+        // information
+        let mut read_id = None;
+        editor.edit(|mut edit| {
+            read_id = Some(edit.add_node(Node::Read { collect: val, indices: idx.clone().into() }));
+            Ok(edit)
+        });
+
+        LinkedList::from([(idx, read_id.expect("Add node cannot fail"))])
     }
 }
 
-fn sroa_dfs(src: NodeID, function: &Function, def_uses: &ImmutableDefUseMap) -> Vec<NodeID> {
-    // Initialize order vector and bitset for tracking which nodes have been
-    // visited.
-    let order = Vec::with_capacity(def_uses.num_nodes());
-    let visited = bitvec![u8, Lsb0; 0; def_uses.num_nodes()];
-
-    // Order and visited are threaded through arguments / return pair of
-    // sroa_dfs_helper for ownership reasons.
-    let (order, _) = sroa_dfs_helper(src, src, function, def_uses, order, visited);
-    order
+macro_rules! add_const {
+    ($editor:ident, $const:expr) => {{
+        let mut res = None;
+        $editor.edit(|mut edit| {
+            res = Some(edit.add_constant($const));
+            Ok(edit)
+        });
+        res.expect("Add constant cannot fail")
+    }}
 }
 
-fn sroa_dfs_helper(
-    node: NodeID,
-    def: NodeID,
-    function: &Function,
-    def_uses: &ImmutableDefUseMap,
-    mut order: Vec<NodeID>,
-    mut visited: BitVec<u8, Lsb0>,
-) -> (Vec<NodeID>, BitVec<u8, Lsb0>) {
-    if visited[node.idx()] {
-        // If already visited, return early.
-        (order, visited)
-    } else {
-        // Set visited to true.
-        visited.set(node.idx(), true);
+// Given a type, builds a default constant of that type
+fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID {
+    let t = editor.get_type(typ).clone();
 
-        // Before iterating users, push this node.
-        order.push(node);
-        match function.nodes[node.idx()] {
-            Node::Phi {
-                control: _,
-                data: _,
+    match t {
+        Type::Product(ts) => {
+            let mut cs = vec![];
+            for t in ts {
+                cs.push(generate_constant(editor, t));
             }
-            | Node::Reduce {
-                control: _,
-                init: _,
-                reduct: _,
-            }
-            | Node::Constant { id: _ }
-            | Node::Ternary {
-                op: _,
-                first: _,
-                second: _,
-                third: _,
-            } => {}
-            Node::Read {
-                collect,
-                indices: _,
-            } => {
-                assert_eq!(def, collect);
-                return (order, visited);
-            }
-            Node::Write {
-                collect,
-                data,
-                indices: _,
-            } => {
-                if def == data {
-                    return (order, visited);
-                }
-                assert_eq!(def, collect);
-            }
-            _ => panic!("PANIC: Invalid node using a constant product found during SROA."),
+            add_const!(editor, Constant::Product(typ, cs.into()))
         }
-
-        // Iterate over users, if we shouldn't stop here.
-        for user in def_uses.get_users(node) {
-            (order, visited) = sroa_dfs_helper(*user, node, function, def_uses, order, visited);
+        Type::Boolean => add_const!(editor, Constant::Boolean(false)),
+        Type::Integer8 => add_const!(editor, Constant::Integer8(0)),
+        Type::Integer16 => add_const!(editor, Constant::Integer16(0)),
+        Type::Integer32 => add_const!(editor, Constant::Integer32(0)),
+        Type::Integer64 => add_const!(editor, Constant::Integer64(0)),
+        Type::UnsignedInteger8 => add_const!(editor, Constant::UnsignedInteger8(0)),
+        Type::UnsignedInteger16 => add_const!(editor, Constant::UnsignedInteger16(0)),
+        Type::UnsignedInteger32 => add_const!(editor, Constant::UnsignedInteger32(0)),
+        Type::UnsignedInteger64 => add_const!(editor, Constant::UnsignedInteger64(0)),
+        Type::Float32 => add_const!(editor, Constant::Float32(ordered_float::OrderedFloat(0.0))),
+        Type::Float64 => add_const!(editor, Constant::Float64(ordered_float::OrderedFloat(0.0))),
+        Type::Summation(ts) => {
+            let const_id = generate_constant(editor, ts[0]);
+            add_const!(editor, Constant::Summation(typ, 0, const_id))
         }
-
-        (order, visited)
+        Type::Array(elem, _) => {
+            add_const!(editor, Constant::Array(typ))
+        }
+        Type::Control => panic!("Cannot create constant of control type")
     }
 }
diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs
index cccadbdd..d4263fe8 100644
--- a/juno_frontend/src/lib.rs
+++ b/juno_frontend/src/lib.rs
@@ -151,6 +151,12 @@ pub fn compile_ir(
     if x_dot {
         pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
     }
+    // TEMPORARY
+    add_pass!(pm, verify, SROA);
+    if x_dot {
+        pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
+    }
+    // TEMPORARY
     add_pass!(pm, verify, Inline);
     add_pass!(pm, verify, CCP);
     add_pass!(pm, verify, DCE);
diff --git a/juno_samples/products.jn b/juno_samples/products.jn
new file mode 100644
index 00000000..a6dd8862
--- /dev/null
+++ b/juno_samples/products.jn
@@ -0,0 +1,7 @@
+fn test_call(x : i32, y : f32) -> (i32, f32) {
+  return (x, y);
+}
+
+fn test(x : i32, y : f32) -> (i32, f32) {
+  return test_call(x, y);
+}
-- 
GitLab


From 6013f4791dbc2e49765e8117365f21cfbdb31f86 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Mon, 18 Nov 2024 21:20:00 -0600
Subject: [PATCH 02/12] Acyclic SROA working (mostly)

---
 hercules_opt/src/sroa.rs | 369 +++++++++++++++++++++++++++++++++------
 juno_frontend/src/lib.rs |   2 +-
 2 files changed, 317 insertions(+), 54 deletions(-)

diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs
index ae5d76dc..b9380197 100644
--- a/hercules_opt/src/sroa.rs
+++ b/hercules_opt/src/sroa.rs
@@ -1,6 +1,6 @@
 extern crate hercules_ir;
 
-use std::collections::{HashMap, LinkedList};
+use std::collections::{HashMap, LinkedList, VecDeque};
 
 use self::hercules_ir::dataflow::*;
 use self::hercules_ir::ir::*;
@@ -41,46 +41,33 @@ use crate::*;
  *   replaced by a direct def of the field value
  */
 pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: &Vec<TypeID>) {
-    // This map stores a map from NodeID and fields to the NodeID which contains just that field of
-    // the original value
-    let mut field_map : HashMap<(NodeID, Vec<Index>), NodeID> = HashMap::new();
+    // This map stores a map from NodeID to an index tree which can be used to lookup the NodeID
+    // that contains the corresponding fields of the original value
+    let mut field_map : HashMap<NodeID, IndexTree<NodeID>> = HashMap::new();
 
     // First: determine all nodes which interact with products (as described above)
-    let mut product_nodes = vec![];
+    let mut product_nodes : Vec<NodeID> = vec![];
     // We track call and return nodes separately since they (may) require constructing new products
     // for the call's arguments or the return's value
-    let mut call_return_nodes = vec![];
-    // We track writes separately since they should be processed once their input product has been
-    // processed, so we handle them after all other nodes (though before reads)
-    let mut write_nodes = vec![];
-    // We also track reads separately because they aren't handled until all other nodes have been
-    // processed
-    let mut read_nodes = vec![];
+    let mut call_return_nodes : Vec<NodeID> = vec![];
 
     let func = editor.func();
 
     for node in reverse_postorder {
         match func.nodes[node.idx()] {
             Node::Phi { .. } | Node::Reduce { .. } | Node::Parameter { .. }
-            | Node::Constant { .. }
+            | Node::Constant { .. } | Node::Write { .. }
             | Node::Ternary { first: _, second: _, third: _, op: TernaryOperator::Select }
-                if editor.get_type(types[node.idx()]).is_product() => product_nodes.push(node),
+                if editor.get_type(types[node.idx()]).is_product() => product_nodes.push(*node),
             
-            Node::Write { .. }
-                if editor.get_type(types[node.idx()]).is_product() => write_nodes.push(node),
             Node::Read { collect, .. }
-                if editor.get_type(types[collect.idx()]).is_product() => read_nodes.push(node),
+                if editor.get_type(types[collect.idx()]).is_product() => product_nodes.push(*node),
 
             // We add all calls to the call/return list and check their arguments later
-            Node::Call { .. } => {
-                call_return_nodes.push(node);
-                if editor.get_type(types[node.idx()]).is_product() {
-                    read_nodes.push(node);
-                }
-            }
+            Node::Call { .. } => call_return_nodes.push(*node),
             Node::Return { control: _, data }
                 if editor.get_type(types[data.idx()]).is_product() =>
-                call_return_nodes.push(node),
+                call_return_nodes.push(*node),
 
             _ => ()
         }
@@ -98,10 +85,10 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
             Node::Return { control, data } => {
                 assert!(editor.get_type(types[data.idx()]).is_product());
                 let control = *control;
-                let new_data = reconstruct_product(editor, types[data.idx()], *data);
+                let new_data = reconstruct_product(editor, types[data.idx()], *data, &mut product_nodes);
                 editor.edit(|mut edit| {
                     edit.add_node(Node::Return { control, data: new_data });
-                    edit.delete_node(*node)
+                    edit.delete_node(node)
                 });
             }
             Node::Call { control, function, dynamic_constants, args } => {
@@ -113,7 +100,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                 // If the call returns a product, we generate reads for each field
                 let fields =
                     if editor.get_type(types[node.idx()]).is_product() {
-                        Some(generate_reads(editor, types[node.idx()], *node))
+                        Some(generate_reads(editor, types[node.idx()], node))
                     } else {
                         None
                     };
@@ -121,23 +108,21 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                 let mut new_args = vec![];
                 for arg in args {
                     if editor.get_type(types[arg.idx()]).is_product() {
-                        new_args.push(reconstruct_product(editor, types[arg.idx()], arg));
+                        new_args.push(reconstruct_product(editor, types[arg.idx()], arg, &mut product_nodes));
                     } else {
                         new_args.push(arg);
                     }
                 }
-                editor.dit(|mut edit| {
+                editor.edit(|mut edit| {
                     let new_call = edit.add_node(Node::Call { control, function, dynamic_constants, args: new_args.into() });
-                    let edit = edit.replace_all_uses(*node, new_call)?;
-                    let edit = edit.delete_node(*node)?;
+                    let edit = edit.replace_all_uses(node, new_call)?;
+                    let edit = edit.delete_node(node)?;
 
                     match fields {
                         None => {}
                         Some(fields) => {
-                            for (idx, node) in fields {
-                                field_map.insert((new_call, idx), node);
-                            }
-    }
+                            field_map.insert(new_call, fields);
+                        }
                     }
 
                     Ok(edit)
@@ -147,14 +132,243 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
         }
     }
 
-    // Now, we process all other non-read/write nodes that deal with products.
-    // The first step is to identify the NodeIDs that contain each field of each of these nodes
-    todo!()
+    enum WorkItem {
+        Unhandled(NodeID),
+        AllocatedPhi { control: NodeID, data: Vec<NodeID>, fields: IndexTree<NodeID> },
+        AllocatedReduce { control: NodeID, init: NodeID, reduct: NodeID, fields: IndexTree<NodeID> },
+        AllocatedTernary { first: NodeID, second: NodeID, third: NodeID, fields: IndexTree<NodeID> },
+    }
+
+    // Now, we process the other nodes that deal with products.
+    // The first step is to assign new NodeIDs to the nodes that will be split into multiple: phi,
+    // reduce, parameter, constant, and ternary.
+    // We do this in several steps: first we break apart parameters and constants
+    let mut to_delete = vec![];
+    let mut worklist: VecDeque<WorkItem> = VecDeque::new();
+
+    for node in product_nodes {
+        match editor.func().nodes[node.idx()] {
+            Node::Parameter { .. } => {
+                field_map.insert(node, generate_reads(editor, types[node.idx()], node));
+            }
+            Node::Constant { id } => {
+                field_map.insert(node, generate_constant_fields(editor, id));
+                to_delete.push(node);
+            }
+            _ => { worklist.push_back(WorkItem::Unhandled(node)); }
+        }
+    }
+
+    // Now, we process the remaining nodes, allocating NodeIDs for them and updating the field_map.
+    // We track the current NodeID and add nodes whenever possible, otherwise we return values to
+    // the worklist
+    let mut next_id : usize = editor.func().nodes.len();
+    let mut cur_id : usize = editor.func().nodes.len();
+    while let Some(mut item) = worklist.pop_front() {
+        if let WorkItem::Unhandled(node) = item {
+            match &editor.func().nodes[node.idx()] {
+                // For phi, reduce, and ternary, we break them apart into separate nodes for each field
+                Node::Phi { control, data } => {
+                    let control = *control;
+                    let data = data.clone();
+                    let fields = allocate_fields(editor, types[node.idx()], &mut next_id);
+                    field_map.insert(node, fields.clone());
+
+                    item = WorkItem::AllocatedPhi {
+                        control,
+                        data: data.into(),
+                        fields };
+                }
+                Node::Reduce { control, init, reduct } => {
+                    let control = *control;
+                    let init = *init;
+                    let reduct = *reduct;
+                    let fields = allocate_fields(editor, types[node.idx()], &mut next_id);
+                    field_map.insert(node, fields.clone());
+
+                    item = WorkItem::AllocatedReduce { control, init, reduct, fields };
+                }
+                Node::Ternary { first, second, third, .. } => {
+                    let first = *first;
+                    let second = *second;
+                    let third = *third;
+                    let fields = allocate_fields(editor, types[node.idx()], &mut next_id);
+                    field_map.insert(node, fields.clone());
+
+                    item = WorkItem::AllocatedTernary { first, second, third, fields };
+                }
+
+                Node::Write { collect, data, indices } => {
+                    if let Some(index_map) = field_map.get(collect) {
+                        if editor.get_type(types[data.idx()]).is_product() {
+                            if let Some(data_idx) = field_map.get(data) {
+                                field_map.insert(node, index_map.clone().replace(indices, data_idx.clone()));
+                                to_delete.push(node);
+                            } else {
+                                worklist.push_back(WorkItem::Unhandled(node));
+                            }
+                        } else {
+                            field_map.insert(node, index_map.clone().set(indices, *data));
+                            to_delete.push(node);
+                        }
+                    } else {
+                        worklist.push_back(WorkItem::Unhandled(node));
+                    }
+                }
+                Node::Read { collect, indices } => {
+                    if let Some(index_map) = field_map.get(collect) {
+                        let read_info = index_map.lookup(indices);
+                        match read_info {
+                            IndexTree::Leaf(field) => {
+                                editor.edit(|edit| {
+                                    edit.replace_all_uses(node, *field)
+                                });
+                            }
+                            _ => {}
+                        }
+                        field_map.insert(node, read_info.clone());
+                        to_delete.push(node);
+                    } else {
+                        worklist.push_back(WorkItem::Unhandled(node));
+                    }
+                }
+
+                _ => panic!("Unexpected node type")
+            }
+        }
+        match item {
+            WorkItem::Unhandled(_) => {}
+            _ => todo!()
+        }
+    }
+
+    // Actually deleting nodes seems to break things right now
+    /*
+    println!("{:?}", to_delete);
+    editor.edit(|mut edit| {
+        for node in to_delete {
+            edit = edit.delete_node(node)?
+        }
+        Ok(edit)
+    });
+    */
+}
+
+// An index tree is used to store results at many index lists
+#[derive(Clone, Debug)]
+enum IndexTree<T> {
+    Leaf(T),
+    Node(Vec<IndexTree<T>>),
+}
+
+impl<T: std::fmt::Debug> IndexTree<T> {
+    fn lookup(&self, idx: &[Index]) -> &IndexTree<T> {
+        self.lookup_idx(idx, 0)
+    }
+
+    fn lookup_idx(&self, idx: &[Index], n: usize) -> &IndexTree<T> {
+        if n < idx.len() {
+            if let Index::Field(i) = idx[n] {
+                match self {
+                    IndexTree::Leaf(_) => panic!("Invalid field"),
+                    IndexTree::Node(ts) => ts[i].lookup_idx(idx, n+1),
+                }
+            } else {
+                // TODO: This could be hit because of an array inside of a product
+                panic!("Error handling lookup of field");
+            }
+        } else {
+            self
+        }
+    }
+
+    fn set(self, idx: &[Index], val: T) -> IndexTree<T> {
+        self.set_idx(idx, val, 0)
+    }
+
+    fn set_idx(self, idx: &[Index], val: T, n: usize) -> IndexTree<T> {
+        if n < idx.len() {
+            if let Index::Field(i) = idx[n] {
+                match self {
+                    IndexTree::Leaf(_) => panic!("Invalid field"),
+                    IndexTree::Node(mut ts) => {
+                        if i + 1 == ts.len() {
+                            let t = ts.pop().unwrap();
+                            ts.push(t.set_idx(idx, val, n+1));
+                        } else {
+                            let mut t = ts.pop().unwrap();
+                            std::mem::swap(&mut ts[i], &mut t);
+                            t = t.set_idx(idx, val, n+1);
+                            std::mem::swap(&mut ts[i], &mut t);
+                            ts.push(t);
+                        }
+                        IndexTree::Node(ts)
+                    }
+                }
+            } else {
+                panic!("Error handling set of field");
+            }
+        } else {
+            IndexTree::Leaf(val)
+        }
+    }
+
+    fn replace(self, idx: &[Index], val: IndexTree<T>) -> IndexTree<T> {
+        self.replace_idx(idx, val, 0)
+    }
+
+    fn replace_idx(self, idx: &[Index], val: IndexTree<T>, n: usize) -> IndexTree<T> {
+        if n < idx.len() {
+            if let Index::Field(i) = idx[n] {
+                match self {
+                    IndexTree::Leaf(_) => panic!("Invalid field"),
+                    IndexTree::Node(mut ts) => {
+                        if i + 1 == ts.len() {
+                            let t = ts.pop().unwrap();
+                            ts.push(t.replace_idx(idx, val, n+1));
+                        } else {
+                            let mut t = ts.pop().unwrap();
+                            std::mem::swap(&mut ts[i], &mut t);
+                            t = t.replace_idx(idx, val, n+1);
+                            std::mem::swap(&mut ts[i], &mut t);
+                            ts.push(t);
+                        }
+                        IndexTree::Node(ts)
+                    }
+                }
+            } else {
+                panic!("Error handling set of field");
+            }
+        } else {
+            val
+        }
+    }
+
+    fn for_each<F>(&self, mut f: F)
+        where F: FnMut(&Vec<Index>, &T) {
+        self.for_each_idx(&mut vec![], &mut f);
+    }
+
+    fn for_each_idx<F>(&self, idx: &mut Vec<Index>, f: &mut F)
+        where F: FnMut(&Vec<Index>, &T) {
+        match self {
+            IndexTree::Leaf(t) => f(idx, t),
+            IndexTree::Node(ts) => {
+                for (i, t) in ts.iter().enumerate() {
+                    idx.push(Index::Field(i));
+                    t.for_each_idx(idx, f);
+                    idx.pop();
+                }
+            }
+        }
+    }
 }
 
 // Given a product value val of type typ, constructs a copy of that value by extracting all fields
 // from that value and then writing them into a new constant
-fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> NodeID {
+// This process also adds all the read nodes that are generated into the read_list so that the
+// reads can be eliminated by later parts of SROA
+fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, read_list: &mut Vec<NodeID>) -> NodeID {
     let fields = generate_reads(editor, typ, val);
     let new_const = generate_constant(editor, typ);
 
@@ -167,44 +381,44 @@ fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) ->
 
     // Generate writes for each field
     let mut value = const_node.expect("Add node cannot fail");
-    for (idx, val) in fields {
+    fields.for_each(|idx: &Vec<Index>, val: &NodeID| {
+        read_list.push(*val);
         editor.edit(|mut edit| {
-            value = edit.add_node(Node::Write { collect: value, data: val, indices: idx.into() });
+            value = edit.add_node(Node::Write { collect: value, data: *val, indices: idx.clone().into() });
             Ok(edit)
         });
-    }
+    });
 
     value
 }
 
 // Given a node val of type typ, adds nodes to the function which read all (leaf) fields of val and
 // returns a list of pairs of the indices and the node that reads that index
-fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> Vec<(Vec<Index>, NodeID)> {
-    generate_reads_at_index(editor, typ, val, vec![]).into_iter().collect::<_>()
+fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> IndexTree<NodeID> {
+    let res = generate_reads_at_index(editor, typ, val, vec![]);
+    res
 }
 
 // Given a node val of type which at the indices idx has type typ, construct reads of all (leaf)
 // fields within this sub-value of val and return the correspondence list
-fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, idx: Vec<Index>) -> LinkedList<(Vec<Index>, NodeID)> {
-    let ts : Option<Vec<TypeID>> = {
+fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, idx: Vec<Index>) -> IndexTree<NodeID> {
+    let ts : Option<Vec<TypeID>> =
         if let Some(ts) = editor.get_type(typ).try_product() {
             Some(ts.into())
         } else {
             None
-        }
-    };
+        };
 
     if let Some(ts) = ts {
         // For product values, we will recurse down each of its fields with an extended index
         // and the appropriate type of that field
-        let mut result = LinkedList::new();
+        let mut fields = vec![];
         for (i, t) in ts.into_iter().enumerate() {
             let mut new_idx = idx.clone();
             new_idx.push(Index::Field(i));
-            result.append(&mut generate_reads_at_index(editor, t, val, new_idx));
+            fields.push(generate_reads_at_index(editor, t, val, new_idx));
         }
-
-        result
+        IndexTree::Node(fields)
     } else {
         // For non-product types, we've reached a leaf so we generate the read and return it's
         // information
@@ -214,7 +428,7 @@ fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID
             Ok(edit)
         });
 
-        LinkedList::from([(idx, read_id.expect("Add node cannot fail"))])
+        IndexTree::Leaf(read_id.expect("Add node canont fail"))
     }
 }
 
@@ -262,3 +476,52 @@ fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID {
         Type::Control => panic!("Cannot create constant of control type")
     }
 }
+
+// Given a constant cnst adds node to the function which are the constant values of each field and
+// returns a list of pairs of indices and the node that holds that index
+fn generate_constant_fields(editor: &mut FunctionEditor, cnst: ConstantID) -> IndexTree<NodeID> {
+    let cs : Option<Vec<ConstantID>> =
+        if let Some(cs) = editor.get_constant(cnst).try_product_fields() {
+            Some(cs.into())
+        } else {
+            None
+        };
+
+    if let Some(cs) = cs {
+        let mut fields = vec![];
+        for c in cs {
+            fields.push(generate_constant_fields(editor, c));
+        }
+        IndexTree::Node(fields)
+    } else {
+        let mut node = None;
+        editor.edit(|mut edit| {
+            node = Some(edit.add_node(Node::Constant { id: cnst }));
+            Ok(edit)
+        });
+        IndexTree::Leaf(node.expect("Add node cannot fail"))
+    }
+}
+
+// Given a type, return a list of the fields and new NodeIDs for them, with NodeIDs starting at the
+// id provided
+fn allocate_fields(editor: &FunctionEditor, typ: TypeID, id: &mut usize) -> IndexTree<NodeID> {
+    let ts : Option<Vec<TypeID>> =
+        if let Some(ts) = editor.get_type(typ).try_product() {
+            Some(ts.into())
+        } else {
+            None
+        };
+
+    if let Some(ts) = ts {
+        let mut fields = vec![];
+        for t in ts {
+            fields.push(allocate_fields(editor, t, id));
+        }
+        IndexTree::Node(fields)
+    } else {
+        let node = *id;
+        *id += 1;
+        IndexTree::Leaf(NodeID::new(node))
+    }
+}
diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs
index d4263fe8..c3dc262a 100644
--- a/juno_frontend/src/lib.rs
+++ b/juno_frontend/src/lib.rs
@@ -157,7 +157,7 @@ pub fn compile_ir(
         pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
     }
     // TEMPORARY
-    add_pass!(pm, verify, Inline);
+    //add_pass!(pm, verify, Inline);
     add_pass!(pm, verify, CCP);
     add_pass!(pm, verify, DCE);
     add_pass!(pm, verify, GVN);
-- 
GitLab


From b44f450c359dd94899de5d700943cb65f379d044 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Mon, 18 Nov 2024 21:23:10 -0600
Subject: [PATCH 03/12] Re-enabled inlining

---
 juno_frontend/src/lib.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs
index c3dc262a..d4263fe8 100644
--- a/juno_frontend/src/lib.rs
+++ b/juno_frontend/src/lib.rs
@@ -157,7 +157,7 @@ pub fn compile_ir(
         pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
     }
     // TEMPORARY
-    //add_pass!(pm, verify, Inline);
+    add_pass!(pm, verify, Inline);
     add_pass!(pm, verify, CCP);
     add_pass!(pm, verify, DCE);
     add_pass!(pm, verify, GVN);
-- 
GitLab


From 881248849fda7222741aca6bb4c37ccb8864fa1b Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Tue, 19 Nov 2024 08:10:43 -0600
Subject: [PATCH 04/12] Fix editor bug when deleting created node

---
 hercules_opt/src/editor.rs | 4 +++-
 hercules_opt/src/sroa.rs   | 3 ---
 juno_frontend/src/lib.rs   | 7 +++----
 3 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs
index 95fd1669..bf7cede2 100644
--- a/hercules_opt/src/editor.rs
+++ b/hercules_opt/src/editor.rs
@@ -505,7 +505,9 @@ pub fn repair_plan(plan: &mut Plan, new_function: &Function, edits: &[Edit]) {
     // Step 2: drop schedules for deleted nodes and create empty schedule lists
     // for added nodes.
     for deleted in total_edit.0.iter() {
-        plan.schedules[deleted.idx()] = vec![];
+        if deleted.idx() < plan.schedules.len() {
+            plan.schedules[deleted.idx()] = vec![];
+        }
     }
     if !total_edit.1.is_empty() {
         assert_eq!(
diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs
index b9380197..2dd0049b 100644
--- a/hercules_opt/src/sroa.rs
+++ b/hercules_opt/src/sroa.rs
@@ -243,15 +243,12 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
     }
 
     // Actually deleting nodes seems to break things right now
-    /*
-    println!("{:?}", to_delete);
     editor.edit(|mut edit| {
         for node in to_delete {
             edit = edit.delete_node(node)?
         }
         Ok(edit)
     });
-    */
 }
 
 // An index tree is used to store results at many index lists
diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs
index d4263fe8..49249fc7 100644
--- a/juno_frontend/src/lib.rs
+++ b/juno_frontend/src/lib.rs
@@ -151,18 +151,17 @@ pub fn compile_ir(
     if x_dot {
         pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
     }
-    // TEMPORARY
+    add_pass!(pm, verify, Inline);
+    // Run SROA pretty early (though after inlining which can make SROA more effective) so that
+    // CCP, GVN, etc. can work on the result of SROA
     add_pass!(pm, verify, SROA);
     if x_dot {
         pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
     }
-    // TEMPORARY
-    add_pass!(pm, verify, Inline);
     add_pass!(pm, verify, CCP);
     add_pass!(pm, verify, DCE);
     add_pass!(pm, verify, GVN);
     add_pass!(pm, verify, DCE);
-    //add_pass!(pm, verify, SROA);
     if x_dot {
         pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
     }
-- 
GitLab


From 2021ea3b0b5f45add9409b4c649470ac7a642ac6 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Tue, 19 Nov 2024 08:14:04 -0600
Subject: [PATCH 05/12] Comment change in editor and format

---
 hercules_opt/src/editor.rs |   2 +
 hercules_opt/src/pass.rs   |   6 +-
 hercules_opt/src/sroa.rs   | 224 ++++++++++++++++++++++++++-----------
 3 files changed, 159 insertions(+), 73 deletions(-)

diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs
index bf7cede2..2410a8ef 100644
--- a/hercules_opt/src/editor.rs
+++ b/hercules_opt/src/editor.rs
@@ -505,6 +505,8 @@ pub fn repair_plan(plan: &mut Plan, new_function: &Function, edits: &[Edit]) {
     // Step 2: drop schedules for deleted nodes and create empty schedule lists
     // for added nodes.
     for deleted in total_edit.0.iter() {
+        // Nodes that were created and deleted using the same editor don't have
+        // an existing schedule, so ignore them
         if deleted.idx() < plan.schedules.len() {
             plan.schedules[deleted.idx()] = vec![];
         }
diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs
index a7d48efb..f8310ab7 100644
--- a/hercules_opt/src/pass.rs
+++ b/hercules_opt/src/pass.rs
@@ -506,11 +506,7 @@ impl PassManager {
                             &types_ref,
                             &def_uses[idx],
                         );
-                        sroa(
-                            &mut editor,
-                            &reverse_postorders[idx],
-                            &typing[idx]
-                        );
+                        sroa(&mut editor, &reverse_postorders[idx], &typing[idx]);
 
                         self.module.constants = constants_ref.take();
                         self.module.dynamic_constants = dynamic_constants_ref.take();
diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs
index 2dd0049b..d8e2021a 100644
--- a/hercules_opt/src/sroa.rs
+++ b/hercules_opt/src/sroa.rs
@@ -43,33 +43,43 @@ use crate::*;
 pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: &Vec<TypeID>) {
     // This map stores a map from NodeID to an index tree which can be used to lookup the NodeID
     // that contains the corresponding fields of the original value
-    let mut field_map : HashMap<NodeID, IndexTree<NodeID>> = HashMap::new();
+    let mut field_map: HashMap<NodeID, IndexTree<NodeID>> = HashMap::new();
 
     // First: determine all nodes which interact with products (as described above)
-    let mut product_nodes : Vec<NodeID> = vec![];
+    let mut product_nodes: Vec<NodeID> = vec![];
     // We track call and return nodes separately since they (may) require constructing new products
     // for the call's arguments or the return's value
-    let mut call_return_nodes : Vec<NodeID> = vec![];
+    let mut call_return_nodes: Vec<NodeID> = vec![];
 
     let func = editor.func();
 
     for node in reverse_postorder {
         match func.nodes[node.idx()] {
-            Node::Phi { .. } | Node::Reduce { .. } | Node::Parameter { .. }
-            | Node::Constant { .. } | Node::Write { .. }
-            | Node::Ternary { first: _, second: _, third: _, op: TernaryOperator::Select }
-                if editor.get_type(types[node.idx()]).is_product() => product_nodes.push(*node),
-            
-            Node::Read { collect, .. }
-                if editor.get_type(types[collect.idx()]).is_product() => product_nodes.push(*node),
+            Node::Phi { .. }
+            | Node::Reduce { .. }
+            | Node::Parameter { .. }
+            | Node::Constant { .. }
+            | Node::Write { .. }
+            | Node::Ternary {
+                first: _,
+                second: _,
+                third: _,
+                op: TernaryOperator::Select,
+            } if editor.get_type(types[node.idx()]).is_product() => product_nodes.push(*node),
+
+            Node::Read { collect, .. } if editor.get_type(types[collect.idx()]).is_product() => {
+                product_nodes.push(*node)
+            }
 
             // We add all calls to the call/return list and check their arguments later
             Node::Call { .. } => call_return_nodes.push(*node),
             Node::Return { control: _, data }
                 if editor.get_type(types[data.idx()]).is_product() =>
-                call_return_nodes.push(*node),
+            {
+                call_return_nodes.push(*node)
+            }
 
-            _ => ()
+            _ => (),
         }
     }
 
@@ -85,36 +95,54 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
             Node::Return { control, data } => {
                 assert!(editor.get_type(types[data.idx()]).is_product());
                 let control = *control;
-                let new_data = reconstruct_product(editor, types[data.idx()], *data, &mut product_nodes);
+                let new_data =
+                    reconstruct_product(editor, types[data.idx()], *data, &mut product_nodes);
                 editor.edit(|mut edit| {
-                    edit.add_node(Node::Return { control, data: new_data });
+                    edit.add_node(Node::Return {
+                        control,
+                        data: new_data,
+                    });
                     edit.delete_node(node)
                 });
             }
-            Node::Call { control, function, dynamic_constants, args } => {
+            Node::Call {
+                control,
+                function,
+                dynamic_constants,
+                args,
+            } => {
                 let control = *control;
                 let function = *function;
                 let dynamic_constants = dynamic_constants.clone();
                 let args = args.clone();
 
                 // If the call returns a product, we generate reads for each field
-                let fields =
-                    if editor.get_type(types[node.idx()]).is_product() {
-                        Some(generate_reads(editor, types[node.idx()], node))
-                    } else {
-                        None
-                    };
+                let fields = if editor.get_type(types[node.idx()]).is_product() {
+                    Some(generate_reads(editor, types[node.idx()], node))
+                } else {
+                    None
+                };
 
                 let mut new_args = vec![];
                 for arg in args {
                     if editor.get_type(types[arg.idx()]).is_product() {
-                        new_args.push(reconstruct_product(editor, types[arg.idx()], arg, &mut product_nodes));
+                        new_args.push(reconstruct_product(
+                            editor,
+                            types[arg.idx()],
+                            arg,
+                            &mut product_nodes,
+                        ));
                     } else {
                         new_args.push(arg);
                     }
                 }
                 editor.edit(|mut edit| {
-                    let new_call = edit.add_node(Node::Call { control, function, dynamic_constants, args: new_args.into() });
+                    let new_call = edit.add_node(Node::Call {
+                        control,
+                        function,
+                        dynamic_constants,
+                        args: new_args.into(),
+                    });
                     let edit = edit.replace_all_uses(node, new_call)?;
                     let edit = edit.delete_node(node)?;
 
@@ -128,15 +156,29 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                     Ok(edit)
                 });
             }
-            _ => panic!("Processing non-call or return node")
+            _ => panic!("Processing non-call or return node"),
         }
     }
 
     enum WorkItem {
         Unhandled(NodeID),
-        AllocatedPhi { control: NodeID, data: Vec<NodeID>, fields: IndexTree<NodeID> },
-        AllocatedReduce { control: NodeID, init: NodeID, reduct: NodeID, fields: IndexTree<NodeID> },
-        AllocatedTernary { first: NodeID, second: NodeID, third: NodeID, fields: IndexTree<NodeID> },
+        AllocatedPhi {
+            control: NodeID,
+            data: Vec<NodeID>,
+            fields: IndexTree<NodeID>,
+        },
+        AllocatedReduce {
+            control: NodeID,
+            init: NodeID,
+            reduct: NodeID,
+            fields: IndexTree<NodeID>,
+        },
+        AllocatedTernary {
+            first: NodeID,
+            second: NodeID,
+            third: NodeID,
+            fields: IndexTree<NodeID>,
+        },
     }
 
     // Now, we process the other nodes that deal with products.
@@ -155,15 +197,17 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                 field_map.insert(node, generate_constant_fields(editor, id));
                 to_delete.push(node);
             }
-            _ => { worklist.push_back(WorkItem::Unhandled(node)); }
+            _ => {
+                worklist.push_back(WorkItem::Unhandled(node));
+            }
         }
     }
 
     // Now, we process the remaining nodes, allocating NodeIDs for them and updating the field_map.
     // We track the current NodeID and add nodes whenever possible, otherwise we return values to
     // the worklist
-    let mut next_id : usize = editor.func().nodes.len();
-    let mut cur_id : usize = editor.func().nodes.len();
+    let mut next_id: usize = editor.func().nodes.len();
+    let mut cur_id: usize = editor.func().nodes.len();
     while let Some(mut item) = worklist.pop_front() {
         if let WorkItem::Unhandled(node) = item {
             match &editor.func().nodes[node.idx()] {
@@ -177,32 +221,59 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                     item = WorkItem::AllocatedPhi {
                         control,
                         data: data.into(),
-                        fields };
+                        fields,
+                    };
                 }
-                Node::Reduce { control, init, reduct } => {
+                Node::Reduce {
+                    control,
+                    init,
+                    reduct,
+                } => {
                     let control = *control;
                     let init = *init;
                     let reduct = *reduct;
                     let fields = allocate_fields(editor, types[node.idx()], &mut next_id);
                     field_map.insert(node, fields.clone());
 
-                    item = WorkItem::AllocatedReduce { control, init, reduct, fields };
+                    item = WorkItem::AllocatedReduce {
+                        control,
+                        init,
+                        reduct,
+                        fields,
+                    };
                 }
-                Node::Ternary { first, second, third, .. } => {
+                Node::Ternary {
+                    first,
+                    second,
+                    third,
+                    ..
+                } => {
                     let first = *first;
                     let second = *second;
                     let third = *third;
                     let fields = allocate_fields(editor, types[node.idx()], &mut next_id);
                     field_map.insert(node, fields.clone());
 
-                    item = WorkItem::AllocatedTernary { first, second, third, fields };
+                    item = WorkItem::AllocatedTernary {
+                        first,
+                        second,
+                        third,
+                        fields,
+                    };
                 }
 
-                Node::Write { collect, data, indices } => {
+                Node::Write {
+                    collect,
+                    data,
+                    indices,
+                } => {
                     if let Some(index_map) = field_map.get(collect) {
                         if editor.get_type(types[data.idx()]).is_product() {
                             if let Some(data_idx) = field_map.get(data) {
-                                field_map.insert(node, index_map.clone().replace(indices, data_idx.clone()));
+                                field_map.insert(
+                                    node,
+                                    index_map.clone().replace(indices, data_idx.clone()),
+                                );
                                 to_delete.push(node);
                             } else {
                                 worklist.push_back(WorkItem::Unhandled(node));
@@ -220,9 +291,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                         let read_info = index_map.lookup(indices);
                         match read_info {
                             IndexTree::Leaf(field) => {
-                                editor.edit(|edit| {
-                                    edit.replace_all_uses(node, *field)
-                                });
+                                editor.edit(|edit| edit.replace_all_uses(node, *field));
                             }
                             _ => {}
                         }
@@ -233,12 +302,12 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                     }
                 }
 
-                _ => panic!("Unexpected node type")
+                _ => panic!("Unexpected node type"),
             }
         }
         match item {
             WorkItem::Unhandled(_) => {}
-            _ => todo!()
+            _ => todo!(),
         }
     }
 
@@ -268,7 +337,7 @@ impl<T: std::fmt::Debug> IndexTree<T> {
             if let Index::Field(i) = idx[n] {
                 match self {
                     IndexTree::Leaf(_) => panic!("Invalid field"),
-                    IndexTree::Node(ts) => ts[i].lookup_idx(idx, n+1),
+                    IndexTree::Node(ts) => ts[i].lookup_idx(idx, n + 1),
                 }
             } else {
                 // TODO: This could be hit because of an array inside of a product
@@ -291,11 +360,11 @@ impl<T: std::fmt::Debug> IndexTree<T> {
                     IndexTree::Node(mut ts) => {
                         if i + 1 == ts.len() {
                             let t = ts.pop().unwrap();
-                            ts.push(t.set_idx(idx, val, n+1));
+                            ts.push(t.set_idx(idx, val, n + 1));
                         } else {
                             let mut t = ts.pop().unwrap();
                             std::mem::swap(&mut ts[i], &mut t);
-                            t = t.set_idx(idx, val, n+1);
+                            t = t.set_idx(idx, val, n + 1);
                             std::mem::swap(&mut ts[i], &mut t);
                             ts.push(t);
                         }
@@ -322,11 +391,11 @@ impl<T: std::fmt::Debug> IndexTree<T> {
                     IndexTree::Node(mut ts) => {
                         if i + 1 == ts.len() {
                             let t = ts.pop().unwrap();
-                            ts.push(t.replace_idx(idx, val, n+1));
+                            ts.push(t.replace_idx(idx, val, n + 1));
                         } else {
                             let mut t = ts.pop().unwrap();
                             std::mem::swap(&mut ts[i], &mut t);
-                            t = t.replace_idx(idx, val, n+1);
+                            t = t.replace_idx(idx, val, n + 1);
                             std::mem::swap(&mut ts[i], &mut t);
                             ts.push(t);
                         }
@@ -342,12 +411,16 @@ impl<T: std::fmt::Debug> IndexTree<T> {
     }
 
     fn for_each<F>(&self, mut f: F)
-        where F: FnMut(&Vec<Index>, &T) {
+    where
+        F: FnMut(&Vec<Index>, &T),
+    {
         self.for_each_idx(&mut vec![], &mut f);
     }
 
     fn for_each_idx<F>(&self, idx: &mut Vec<Index>, f: &mut F)
-        where F: FnMut(&Vec<Index>, &T) {
+    where
+        F: FnMut(&Vec<Index>, &T),
+    {
         match self {
             IndexTree::Leaf(t) => f(idx, t),
             IndexTree::Node(ts) => {
@@ -365,7 +438,12 @@ impl<T: std::fmt::Debug> IndexTree<T> {
 // from that value and then writing them into a new constant
 // This process also adds all the read nodes that are generated into the read_list so that the
 // reads can be eliminated by later parts of SROA
-fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, read_list: &mut Vec<NodeID>) -> NodeID {
+fn reconstruct_product(
+    editor: &mut FunctionEditor,
+    typ: TypeID,
+    val: NodeID,
+    read_list: &mut Vec<NodeID>,
+) -> NodeID {
     let fields = generate_reads(editor, typ, val);
     let new_const = generate_constant(editor, typ);
 
@@ -381,7 +459,11 @@ fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, re
     fields.for_each(|idx: &Vec<Index>, val: &NodeID| {
         read_list.push(*val);
         editor.edit(|mut edit| {
-            value = edit.add_node(Node::Write { collect: value, data: *val, indices: idx.clone().into() });
+            value = edit.add_node(Node::Write {
+                collect: value,
+                data: *val,
+                indices: idx.clone().into(),
+            });
             Ok(edit)
         });
     });
@@ -398,13 +480,17 @@ fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> Inde
 
 // Given a node val of type which at the indices idx has type typ, construct reads of all (leaf)
 // fields within this sub-value of val and return the correspondence list
-fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, idx: Vec<Index>) -> IndexTree<NodeID> {
-    let ts : Option<Vec<TypeID>> =
-        if let Some(ts) = editor.get_type(typ).try_product() {
-            Some(ts.into())
-        } else {
-            None
-        };
+fn generate_reads_at_index(
+    editor: &mut FunctionEditor,
+    typ: TypeID,
+    val: NodeID,
+    idx: Vec<Index>,
+) -> IndexTree<NodeID> {
+    let ts: Option<Vec<TypeID>> = if let Some(ts) = editor.get_type(typ).try_product() {
+        Some(ts.into())
+    } else {
+        None
+    };
 
     if let Some(ts) = ts {
         // For product values, we will recurse down each of its fields with an extended index
@@ -421,7 +507,10 @@ fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID
         // information
         let mut read_id = None;
         editor.edit(|mut edit| {
-            read_id = Some(edit.add_node(Node::Read { collect: val, indices: idx.clone().into() }));
+            read_id = Some(edit.add_node(Node::Read {
+                collect: val,
+                indices: idx.clone().into(),
+            }));
             Ok(edit)
         });
 
@@ -437,7 +526,7 @@ macro_rules! add_const {
             Ok(edit)
         });
         res.expect("Add constant cannot fail")
-    }}
+    }};
 }
 
 // Given a type, builds a default constant of that type
@@ -470,14 +559,14 @@ fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID {
         Type::Array(elem, _) => {
             add_const!(editor, Constant::Array(typ))
         }
-        Type::Control => panic!("Cannot create constant of control type")
+        Type::Control => panic!("Cannot create constant of control type"),
     }
 }
 
 // Given a constant cnst adds node to the function which are the constant values of each field and
 // returns a list of pairs of indices and the node that holds that index
 fn generate_constant_fields(editor: &mut FunctionEditor, cnst: ConstantID) -> IndexTree<NodeID> {
-    let cs : Option<Vec<ConstantID>> =
+    let cs: Option<Vec<ConstantID>> =
         if let Some(cs) = editor.get_constant(cnst).try_product_fields() {
             Some(cs.into())
         } else {
@@ -503,12 +592,11 @@ fn generate_constant_fields(editor: &mut FunctionEditor, cnst: ConstantID) -> In
 // Given a type, return a list of the fields and new NodeIDs for them, with NodeIDs starting at the
 // id provided
 fn allocate_fields(editor: &FunctionEditor, typ: TypeID, id: &mut usize) -> IndexTree<NodeID> {
-    let ts : Option<Vec<TypeID>> =
-        if let Some(ts) = editor.get_type(typ).try_product() {
-            Some(ts.into())
-        } else {
-            None
-        };
+    let ts: Option<Vec<TypeID>> = if let Some(ts) = editor.get_type(typ).try_product() {
+        Some(ts.into())
+    } else {
+        None
+    };
 
     if let Some(ts) = ts {
         let mut fields = vec![];
-- 
GitLab


From b8f77a5e2fe55d6008c3958563bb0d6354b740a2 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Tue, 19 Nov 2024 08:58:41 -0600
Subject: [PATCH 06/12] Start on SROA of remaining nodes

---
 hercules_opt/src/sroa.rs | 184 +++++++++++++++++++++++++++++++++++++--
 juno_frontend/src/lib.rs |   4 +
 juno_samples/products.jn |   4 +-
 3 files changed, 185 insertions(+), 7 deletions(-)

diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs
index d8e2021a..ccc2605c 100644
--- a/hercules_opt/src/sroa.rs
+++ b/hercules_opt/src/sroa.rs
@@ -1,6 +1,6 @@
 extern crate hercules_ir;
 
-use std::collections::{HashMap, LinkedList, VecDeque};
+use std::collections::{BTreeMap, HashMap, LinkedList, VecDeque};
 
 use self::hercules_ir::dataflow::*;
 use self::hercules_ir::ir::*;
@@ -165,18 +165,21 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
         AllocatedPhi {
             control: NodeID,
             data: Vec<NodeID>,
+            node: NodeID,
             fields: IndexTree<NodeID>,
         },
         AllocatedReduce {
             control: NodeID,
             init: NodeID,
             reduct: NodeID,
+            node: NodeID,
             fields: IndexTree<NodeID>,
         },
         AllocatedTernary {
             first: NodeID,
             second: NodeID,
             third: NodeID,
+            node: NodeID,
             fields: IndexTree<NodeID>,
         },
     }
@@ -204,10 +207,12 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
     }
 
     // Now, we process the remaining nodes, allocating NodeIDs for them and updating the field_map.
-    // We track the current NodeID and add nodes whenever possible, otherwise we return values to
-    // the worklist
+    // We track the current NodeID and add nodes to a set we maintain of nodes to add (since we
+    // need to add nodes in a particular order we wait to do that until the end). If we don't have
+    // enough information to process a particular node, we add it back to the worklist
     let mut next_id: usize = editor.func().nodes.len();
-    let mut cur_id: usize = editor.func().nodes.len();
+    let mut to_insert = BTreeMap::new();
+
     while let Some(mut item) = worklist.pop_front() {
         if let WorkItem::Unhandled(node) = item {
             match &editor.func().nodes[node.idx()] {
@@ -221,6 +226,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                     item = WorkItem::AllocatedPhi {
                         control,
                         data: data.into(),
+                        node,
                         fields,
                     };
                 }
@@ -239,6 +245,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                         control,
                         init,
                         reduct,
+                        node,
                         fields,
                     };
                 }
@@ -258,6 +265,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                         first,
                         second,
                         third,
+                        node,
                         fields,
                     };
                 }
@@ -307,11 +315,127 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
         }
         match item {
             WorkItem::Unhandled(_) => {}
-            _ => todo!(),
+            WorkItem::AllocatedPhi {
+                control,
+                data,
+                node,
+                fields,
+            } => {
+                let mut data_fields = vec![];
+                let mut ready = true;
+                for val in data.iter() {
+                    if let Some(val_fields) = field_map.get(val) {
+                        data_fields.push(val_fields);
+                    } else {
+                        ready = false;
+                        break;
+                    }
+                }
+
+                if ready {
+                    fields.zip_list(data_fields).for_each(|idx, (res, data)| {
+                        to_insert.insert(
+                            res.idx(),
+                            Node::Phi {
+                                control,
+                                data: data.into_iter().map(|n| **n).collect::<Vec<_>>().into(),
+                            },
+                        );
+                    });
+                } else {
+                    worklist.push_back(WorkItem::AllocatedPhi {
+                        control,
+                        data,
+                        node,
+                        fields,
+                    });
+                }
+            }
+            WorkItem::AllocatedReduce {
+                control,
+                init,
+                reduct,
+                node,
+                fields,
+            } => {
+                if let (Some(init_fields), Some(reduct_fields)) =
+                    (field_map.get(&init), field_map.get(&reduct))
+                {
+                    fields.zip(init_fields).zip(reduct_fields).for_each(
+                        |idx, ((res, init), reduct)| {
+                            to_insert.insert(
+                                res.idx(),
+                                Node::Reduce {
+                                    control,
+                                    init: **init,
+                                    reduct: **reduct,
+                                },
+                            );
+                        },
+                    );
+                    to_delete.push(node);
+                } else {
+                    worklist.push_back(WorkItem::AllocatedReduce {
+                        control,
+                        init,
+                        reduct,
+                        node,
+                        fields,
+                    });
+                }
+            }
+            WorkItem::AllocatedTernary {
+                first,
+                second,
+                third,
+                node,
+                fields,
+            } => {
+                if let (Some(fst_fields), Some(snd_fields), Some(thd_fields)) = (
+                    field_map.get(&first),
+                    field_map.get(&second),
+                    field_map.get(&third),
+                ) {
+                    fields
+                        .zip(fst_fields)
+                        .zip(snd_fields)
+                        .zip(thd_fields)
+                        .for_each(|idx, (((res, fst), snd), thd)| {
+                            to_insert.insert(
+                                res.idx(),
+                                Node::Ternary {
+                                    first: **fst,
+                                    second: **snd,
+                                    third: **thd,
+                                    op: TernaryOperator::Select,
+                                },
+                            );
+                        });
+                } else {
+                    worklist.push_back(WorkItem::AllocatedTernary {
+                        first,
+                        second,
+                        third,
+                        node,
+                        fields,
+                    });
+                }
+            }
         }
     }
 
-    // Actually deleting nodes seems to break things right now
+    // Create new nodes nodes
+    for (node_id, node) in to_insert {
+        assert!(node_id == editor.func().nodes.len());
+        println!("Inserting {:?} : {:?}", node_id, node);
+        editor.edit(|mut edit| {
+            let id = edit.add_node(node);
+            assert!(node_id == id.idx());
+            Ok(edit)
+        });
+    }
+
+    // Remove nodes
     editor.edit(|mut edit| {
         for node in to_delete {
             edit = edit.delete_node(node)?
@@ -410,6 +534,54 @@ impl<T: std::fmt::Debug> IndexTree<T> {
         }
     }
 
+    fn zip<'a, A>(self, other: &'a IndexTree<A>) -> IndexTree<(T, &'a A)> {
+        match (self, other) {
+            (IndexTree::Leaf(t), IndexTree::Leaf(a)) => IndexTree::Leaf((t, a)),
+            (IndexTree::Node(t), IndexTree::Node(a)) => {
+                let mut fields = vec![];
+                for (t, a) in t.into_iter().zip(a.iter()) {
+                    fields.push(t.zip(a));
+                }
+                IndexTree::Node(fields)
+            }
+            _ => panic!("IndexTrees do not have the same fields, cannot zip"),
+        }
+    }
+
+    fn zip_list<'a, A>(self, others: Vec<&'a IndexTree<A>>) -> IndexTree<(T, Vec<&'a A>)> {
+        match self {
+            IndexTree::Leaf(t) => {
+                let mut res = vec![];
+                for other in others {
+                    match other {
+                        IndexTree::Leaf(a) => res.push(a),
+                        _ => panic!("IndexTrees do not have the same fields, cannot zip"),
+                    }
+                }
+                IndexTree::Leaf((t, res))
+            }
+            IndexTree::Node(t) => {
+                let mut fields: Vec<Vec<&'a IndexTree<A>>> = vec![vec![]; t.len()];
+                for other in others {
+                    match other {
+                        IndexTree::Node(a) => {
+                            for (i, a) in a.iter().enumerate() {
+                                fields[i].push(a);
+                            }
+                        }
+                        _ => panic!("IndexTrees do not have the same fields, cannot zip"),
+                    }
+                }
+                IndexTree::Node(
+                    t.into_iter()
+                        .zip(fields.into_iter())
+                        .map(|(t, f)| t.zip_list(f))
+                        .collect(),
+                )
+            }
+        }
+    }
+
     fn for_each<F>(&self, mut f: F)
     where
         F: FnMut(&Vec<Index>, &T),
diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs
index 49249fc7..e54e88fc 100644
--- a/juno_frontend/src/lib.rs
+++ b/juno_frontend/src/lib.rs
@@ -151,6 +151,10 @@ pub fn compile_ir(
     if x_dot {
         pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
     }
+    add_pass!(pm, verify, SROA);
+    if x_dot {
+        pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
+    }
     add_pass!(pm, verify, Inline);
     // Run SROA pretty early (though after inlining which can make SROA more effective) so that
     // CCP, GVN, etc. can work on the result of SROA
diff --git a/juno_samples/products.jn b/juno_samples/products.jn
index a6dd8862..00b66aab 100644
--- a/juno_samples/products.jn
+++ b/juno_samples/products.jn
@@ -1,5 +1,7 @@
 fn test_call(x : i32, y : f32) -> (i32, f32) {
-  return (x, y);
+  let res = (x, y);
+  if x < 13 { res = (x + 1, y); }
+  return res;
 }
 
 fn test(x : i32, y : f32) -> (i32, f32) {
-- 
GitLab


From 8df0ef85e304465d964011d74bb8c6f120681e0a Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Tue, 19 Nov 2024 11:39:40 -0600
Subject: [PATCH 07/12] Fix: delete old phi nodes

---
 hercules_opt/src/sroa.rs | 3 +--
 juno_frontend/src/lib.rs | 7 +++----
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs
index ccc2605c..efa7bd90 100644
--- a/hercules_opt/src/sroa.rs
+++ b/hercules_opt/src/sroa.rs
@@ -342,6 +342,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                             },
                         );
                     });
+                    to_delete.push(node);
                 } else {
                     worklist.push_back(WorkItem::AllocatedPhi {
                         control,
@@ -426,8 +427,6 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
 
     // Create new nodes nodes
     for (node_id, node) in to_insert {
-        assert!(node_id == editor.func().nodes.len());
-        println!("Inserting {:?} : {:?}", node_id, node);
         editor.edit(|mut edit| {
             let id = edit.add_node(node);
             assert!(node_id == id.idx());
diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs
index e54e88fc..59240fe9 100644
--- a/juno_frontend/src/lib.rs
+++ b/juno_frontend/src/lib.rs
@@ -151,14 +151,13 @@ pub fn compile_ir(
     if x_dot {
         pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
     }
-    add_pass!(pm, verify, SROA);
-    if x_dot {
-        pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
-    }
     add_pass!(pm, verify, Inline);
     // Run SROA pretty early (though after inlining which can make SROA more effective) so that
     // CCP, GVN, etc. can work on the result of SROA
     add_pass!(pm, verify, SROA);
+    // We run phi-elim again because SROA can introduce new phis that might be able to be
+    // simplified
+    add_verified_pass!(pm, verify, PhiElim);
     if x_dot {
         pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
     }
-- 
GitLab


From 09c58d7553f34d92233fd6c8d75c3a9f44a0ab6c Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Tue, 19 Nov 2024 12:36:23 -0600
Subject: [PATCH 08/12] Fix SROA edit issues

---
 hercules_opt/src/sroa.rs | 18 ++++++++++++------
 juno_frontend/src/lib.rs |  4 ++--
 juno_samples/products.jn |  8 +++++++-
 3 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs
index efa7bd90..924c8d93 100644
--- a/hercules_opt/src/sroa.rs
+++ b/hercules_opt/src/sroa.rs
@@ -212,6 +212,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
     // enough information to process a particular node, we add it back to the worklist
     let mut next_id: usize = editor.func().nodes.len();
     let mut to_insert = BTreeMap::new();
+    let mut to_replace : Vec<(NodeID, NodeID)> = vec![];
 
     while let Some(mut item) = worklist.pop_front() {
         if let WorkItem::Unhandled(node) = item {
@@ -299,7 +300,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                         let read_info = index_map.lookup(indices);
                         match read_info {
                             IndexTree::Leaf(field) => {
-                                editor.edit(|edit| edit.replace_all_uses(node, *field));
+                                to_replace.push((node, *field));
                             }
                             _ => {}
                         }
@@ -426,12 +427,17 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
     }
 
     // Create new nodes nodes
-    for (node_id, node) in to_insert {
-        editor.edit(|mut edit| {
+    editor.edit(|mut edit| {
+        for (node_id, node) in to_insert {
             let id = edit.add_node(node);
-            assert!(node_id == id.idx());
-            Ok(edit)
-        });
+            assert_eq!(node_id, id.idx());
+        }
+        Ok(edit)
+    });
+
+    // Replace uses of old reads
+    for (old, new) in to_replace {
+        editor.edit(|edit| edit.replace_all_uses(old, new));
     }
 
     // Remove nodes
diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs
index 59240fe9..7e6b61fa 100644
--- a/juno_frontend/src/lib.rs
+++ b/juno_frontend/src/lib.rs
@@ -69,11 +69,11 @@ impl fmt::Display for ErrorMessage {
         match self {
             ErrorMessage::SemanticError(errs) => {
                 for err in errs {
-                    write!(f, "{}", err)?;
+                    write!(f, "{}\n", err)?;
                 }
             }
             ErrorMessage::SchedulingError(msg) => {
-                write!(f, "{}", msg)?;
+                write!(f, "{}\n", msg)?;
             }
         }
         Ok(())
diff --git a/juno_samples/products.jn b/juno_samples/products.jn
index 00b66aab..6c14e67a 100644
--- a/juno_samples/products.jn
+++ b/juno_samples/products.jn
@@ -1,6 +1,12 @@
 fn test_call(x : i32, y : f32) -> (i32, f32) {
   let res = (x, y);
-  if x < 13 { res = (x + 1, y); }
+  for i = 0 to 10 {
+    if i % 2 == 0 {
+      res.0 += 1;
+    } else {
+      res.1 *= 2.0;
+    }
+  }
   return res;
 }
 
-- 
GitLab


From fb0e5c58edc4b755b3ecb238e751c1d8315b5cc1 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Wed, 20 Nov 2024 19:04:00 -0600
Subject: [PATCH 09/12] Fix SROA handling of select

---
 hercules_opt/src/sroa.rs      | 45 +++++++++++++++++------------------
 hercules_samples/products.hir |  3 +++
 juno_samples/products.jn      |  6 +----
 3 files changed, 26 insertions(+), 28 deletions(-)
 create mode 100644 hercules_samples/products.hir

diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs
index 924c8d93..6205421f 100644
--- a/hercules_opt/src/sroa.rs
+++ b/hercules_opt/src/sroa.rs
@@ -160,6 +160,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
         }
     }
 
+    #[derive(Debug)]
     enum WorkItem {
         Unhandled(NodeID),
         AllocatedPhi {
@@ -176,9 +177,9 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
             fields: IndexTree<NodeID>,
         },
         AllocatedTernary {
-            first: NodeID,
-            second: NodeID,
-            third: NodeID,
+            cond: NodeID,
+            thn: NodeID,
+            els: NodeID,
             node: NodeID,
             fields: IndexTree<NodeID>,
         },
@@ -263,9 +264,9 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                     field_map.insert(node, fields.clone());
 
                     item = WorkItem::AllocatedTernary {
-                        first,
-                        second,
-                        third,
+                        cond: first,
+                        thn: second,
+                        els: third,
                         node,
                         fields,
                     };
@@ -387,37 +388,35 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                 }
             }
             WorkItem::AllocatedTernary {
-                first,
-                second,
-                third,
+                cond,
+                thn,
+                els,
                 node,
                 fields,
             } => {
-                if let (Some(fst_fields), Some(snd_fields), Some(thd_fields)) = (
-                    field_map.get(&first),
-                    field_map.get(&second),
-                    field_map.get(&third),
+                if let (Some(thn_fields), Some(els_fields)) = (
+                    field_map.get(&thn),
+                    field_map.get(&els),
                 ) {
                     fields
-                        .zip(fst_fields)
-                        .zip(snd_fields)
-                        .zip(thd_fields)
-                        .for_each(|idx, (((res, fst), snd), thd)| {
+                        .zip(thn_fields)
+                        .zip(els_fields)
+                        .for_each(|idx, ((res, thn), els)| {
                             to_insert.insert(
                                 res.idx(),
                                 Node::Ternary {
-                                    first: **fst,
-                                    second: **snd,
-                                    third: **thd,
+                                    first: cond,
+                                    second: **thn,
+                                    third: **els,
                                     op: TernaryOperator::Select,
                                 },
                             );
                         });
                 } else {
                     worklist.push_back(WorkItem::AllocatedTernary {
-                        first,
-                        second,
-                        third,
+                        cond,
+                        thn,
+                        els,
                         node,
                         fields,
                     });
diff --git a/hercules_samples/products.hir b/hercules_samples/products.hir
new file mode 100644
index 00000000..d09bb0fa
--- /dev/null
+++ b/hercules_samples/products.hir
@@ -0,0 +1,3 @@
+fn test(x : prod(i32, f32), y: prod(i32, f32), b: bool) -> prod(i32, f32)
+  res = select(b, x, y)
+  r = return(start, res)
diff --git a/juno_samples/products.jn b/juno_samples/products.jn
index 6c14e67a..b97f1088 100644
--- a/juno_samples/products.jn
+++ b/juno_samples/products.jn
@@ -1,11 +1,7 @@
 fn test_call(x : i32, y : f32) -> (i32, f32) {
   let res = (x, y);
   for i = 0 to 10 {
-    if i % 2 == 0 {
-      res.0 += 1;
-    } else {
-      res.1 *= 2.0;
-    }
+    res.0 += 1;
   }
   return res;
 }
-- 
GitLab


From 0bcdf93009233ea2522ce6d4c36013bd89d81532 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Wed, 20 Nov 2024 19:28:57 -0600
Subject: [PATCH 10/12] Delete selects

---
 hercules_opt/src/sroa.rs      |  1 +
 hercules_samples/products.hir | 26 +++++++++++++++++++++++---
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs
index 6205421f..afbc775b 100644
--- a/hercules_opt/src/sroa.rs
+++ b/hercules_opt/src/sroa.rs
@@ -412,6 +412,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                                 },
                             );
                         });
+                    to_delete.push(node);
                 } else {
                     worklist.push_back(WorkItem::AllocatedTernary {
                         cond,
diff --git a/hercules_samples/products.hir b/hercules_samples/products.hir
index d09bb0fa..9d191beb 100644
--- a/hercules_samples/products.hir
+++ b/hercules_samples/products.hir
@@ -1,3 +1,23 @@
-fn test(x : prod(i32, f32), y: prod(i32, f32), b: bool) -> prod(i32, f32)
-  res = select(b, x, y)
-  r = return(start, res)
+fn test(x : prod(i32, f32), b: bool) -> prod(i32, f32)
+  zero = constant(u64, 0)
+  one = constant(i32, 1)
+  two = constant(u64, 2)
+  three = constant(f32, 3.0)
+
+  f_ctrl = fork(start, 10)
+  idx = thread_id(f_ctrl, 0)
+  
+  mod2 = rem(idx, two)
+  is_even = eq(mod2, zero)
+  field0 = read(res, field(0))
+  field1 = read(res, field(1))
+  add = add(field0, one)
+  mul = mul(field1, three)
+  upd0 = write(res, add, field(0))
+  upd1 = write(res, mul, field(1))
+  select = select(is_even, upd0, upd1)
+  
+  j_ctrl = join(f_ctrl)
+  res = reduce(j_ctrl, x, select)
+
+  r = return(j_ctrl, res)
-- 
GitLab


From 97c4bbcfd3f4a7e4d58a6d7d887b98267156ff2a Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Thu, 21 Nov 2024 09:02:13 -0600
Subject: [PATCH 11/12] Formatting

---
 hercules_opt/src/sroa.rs | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs
index afbc775b..7430dbca 100644
--- a/hercules_opt/src/sroa.rs
+++ b/hercules_opt/src/sroa.rs
@@ -213,7 +213,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
     // enough information to process a particular node, we add it back to the worklist
     let mut next_id: usize = editor.func().nodes.len();
     let mut to_insert = BTreeMap::new();
-    let mut to_replace : Vec<(NodeID, NodeID)> = vec![];
+    let mut to_replace: Vec<(NodeID, NodeID)> = vec![];
 
     while let Some(mut item) = worklist.pop_front() {
         if let WorkItem::Unhandled(node) = item {
@@ -394,10 +394,9 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                 node,
                 fields,
             } => {
-                if let (Some(thn_fields), Some(els_fields)) = (
-                    field_map.get(&thn),
-                    field_map.get(&els),
-                ) {
+                if let (Some(thn_fields), Some(els_fields)) =
+                    (field_map.get(&thn), field_map.get(&els))
+                {
                     fields
                         .zip(thn_fields)
                         .zip(els_fields)
-- 
GitLab


From cb4634e4cbebb396795ae8d2ceb9f94b408ebe89 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Fri, 22 Nov 2024 09:47:17 -0600
Subject: [PATCH 12/12] Fix bug when read later used in write

- Need to track nodes that are replaced as we perform replacements
---
 hercules_opt/src/sroa.rs | 36 ++++++++++++++++++++++++++++++++++++
 juno_samples/products.jn |  5 +++--
 2 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs
index 7430dbca..59ee4a8a 100644
--- a/hercules_opt/src/sroa.rs
+++ b/hercules_opt/src/sroa.rs
@@ -435,8 +435,44 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
     });
 
     // Replace uses of old reads
+    // A read that is being replaced may itself be the replacement for some other read (this
+    // happens when the first read is written into a product that is then read from again), so
+    // we must track which nodes have already been replaced, and by what, in order to replace
+    // uses without leaving users of nodes that should be deleted.
+    // `replaced_by` maps a node to the node it was replaced by; `replaced_of` maps a node to
+    // every node that now maps to it (needed to keep `replaced_by` up to date efficiently).
+    let mut replaced_by: HashMap<NodeID, NodeID> = HashMap::new();
+    let mut replaced_of: HashMap<NodeID, Vec<NodeID>> = HashMap::new();
     for (old, new) in to_replace {
+        let new = match replaced_by.get(&new) {
+            Some(res) => *res,
+            None => new,
+        };
+
         editor.edit(|edit| edit.replace_all_uses(old, new));
+        replaced_by.insert(old, new);
+
+        let mut replaced = vec![];
+        match replaced_of.get_mut(&old) {
+            Some(res) => {
+                std::mem::swap(res, &mut replaced);
+            }
+            None => {}
+        }
+
+        let new_of = match replaced_of.get_mut(&new) {
+            Some(res) => res,
+            None => {
+                replaced_of.insert(new, vec![]);
+                replaced_of.get_mut(&new).unwrap()
+            }
+        };
+        new_of.push(old);
+
+        for n in replaced {
+            replaced_by.insert(n, new);
+            new_of.push(n);
+        }
     }
 
     // Remove nodes
diff --git a/juno_samples/products.jn b/juno_samples/products.jn
index b97f1088..d39ca246 100644
--- a/juno_samples/products.jn
+++ b/juno_samples/products.jn
@@ -6,6 +6,7 @@ fn test_call(x : i32, y : f32) -> (i32, f32) {
   return res;
 }
 
-fn test(x : i32, y : f32) -> (i32, f32) {
-  return test_call(x, y);
+fn test(x : i32, y : f32) -> (f32, i32) {
+  let res = test_call(x, y);
+  return (res.1, res.0);
 }
-- 
GitLab