From 6013f4791dbc2e49765e8117365f21cfbdb31f86 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Mon, 18 Nov 2024 21:20:00 -0600
Subject: [PATCH] Acyclic SROA working (mostly)

---
 hercules_opt/src/sroa.rs | 369 +++++++++++++++++++++++++++++++++------
 juno_frontend/src/lib.rs |   2 +-
 2 files changed, 317 insertions(+), 54 deletions(-)

diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs
index ae5d76dc..b9380197 100644
--- a/hercules_opt/src/sroa.rs
+++ b/hercules_opt/src/sroa.rs
@@ -1,6 +1,6 @@
 extern crate hercules_ir;
 
-use std::collections::{HashMap, LinkedList};
+use std::collections::{HashMap, LinkedList, VecDeque};
 
 use self::hercules_ir::dataflow::*;
 use self::hercules_ir::ir::*;
@@ -41,46 +41,33 @@ use crate::*;
  *   replaced by a direct def of the field value
  */
 pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: &Vec<TypeID>) {
-    // This map stores a map from NodeID and fields to the NodeID which contains just that field of
-    // the original value
-    let mut field_map : HashMap<(NodeID, Vec<Index>), NodeID> = HashMap::new();
+    // This map stores a map from NodeID to an index tree which can be used to lookup the NodeID
+    // that contains the corresponding fields of the original value
+    let mut field_map : HashMap<NodeID, IndexTree<NodeID>> = HashMap::new();
 
     // First: determine all nodes which interact with products (as described above)
-    let mut product_nodes = vec![];
+    let mut product_nodes : Vec<NodeID> = vec![];
     // We track call and return nodes separately since they (may) require constructing new products
     // for the call's arguments or the return's value
-    let mut call_return_nodes = vec![];
-    // We track writes separately since they should be processed once their input product has been
-    // processed, so we handle them after all other nodes (though before reads)
-    let mut write_nodes = vec![];
-    // We also track reads separately because they aren't handled until all other nodes have been
-    // processed
-    let mut read_nodes = vec![];
+    let mut call_return_nodes : Vec<NodeID> = vec![];
 
     let func = editor.func();
 
     for node in reverse_postorder {
         match func.nodes[node.idx()] {
             Node::Phi { .. } | Node::Reduce { .. } | Node::Parameter { .. }
-            | Node::Constant { .. }
+            | Node::Constant { .. } | Node::Write { .. }
             | Node::Ternary { first: _, second: _, third: _, op: TernaryOperator::Select }
-                if editor.get_type(types[node.idx()]).is_product() => product_nodes.push(node),
+                if editor.get_type(types[node.idx()]).is_product() => product_nodes.push(*node),
             
-            Node::Write { .. }
-                if editor.get_type(types[node.idx()]).is_product() => write_nodes.push(node),
             Node::Read { collect, .. }
-                if editor.get_type(types[collect.idx()]).is_product() => read_nodes.push(node),
+                if editor.get_type(types[collect.idx()]).is_product() => product_nodes.push(*node),
 
             // We add all calls to the call/return list and check their arguments later
-            Node::Call { .. } => {
-                call_return_nodes.push(node);
-                if editor.get_type(types[node.idx()]).is_product() {
-                    read_nodes.push(node);
-                }
-            }
+            Node::Call { .. } => call_return_nodes.push(*node),
             Node::Return { control: _, data }
                 if editor.get_type(types[data.idx()]).is_product() =>
-                call_return_nodes.push(node),
+                call_return_nodes.push(*node),
 
             _ => ()
         }
@@ -98,10 +85,10 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
             Node::Return { control, data } => {
                 assert!(editor.get_type(types[data.idx()]).is_product());
                 let control = *control;
-                let new_data = reconstruct_product(editor, types[data.idx()], *data);
+                let new_data = reconstruct_product(editor, types[data.idx()], *data, &mut product_nodes);
                 editor.edit(|mut edit| {
                     edit.add_node(Node::Return { control, data: new_data });
-                    edit.delete_node(*node)
+                    edit.delete_node(node)
                 });
             }
             Node::Call { control, function, dynamic_constants, args } => {
@@ -113,7 +100,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                 // If the call returns a product, we generate reads for each field
                 let fields =
                     if editor.get_type(types[node.idx()]).is_product() {
-                        Some(generate_reads(editor, types[node.idx()], *node))
+                        Some(generate_reads(editor, types[node.idx()], node))
                     } else {
                         None
                     };
@@ -121,23 +108,21 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                 let mut new_args = vec![];
                 for arg in args {
                     if editor.get_type(types[arg.idx()]).is_product() {
-                        new_args.push(reconstruct_product(editor, types[arg.idx()], arg));
+                        new_args.push(reconstruct_product(editor, types[arg.idx()], arg, &mut product_nodes));
                     } else {
                         new_args.push(arg);
                     }
                 }
-                editor.dit(|mut edit| {
+                editor.edit(|mut edit| {
                     let new_call = edit.add_node(Node::Call { control, function, dynamic_constants, args: new_args.into() });
-                    let edit = edit.replace_all_uses(*node, new_call)?;
-                    let edit = edit.delete_node(*node)?;
+                    let edit = edit.replace_all_uses(node, new_call)?;
+                    let edit = edit.delete_node(node)?;
 
                     match fields {
                         None => {}
                         Some(fields) => {
-                            for (idx, node) in fields {
-                                field_map.insert((new_call, idx), node);
-                            }
-    }
+                            field_map.insert(new_call, fields);
+                        }
                     }
 
                     Ok(edit)
@@ -147,14 +132,243 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
         }
     }
 
-    // Now, we process all other non-read/write nodes that deal with products.
-    // The first step is to identify the NodeIDs that contain each field of each of these nodes
-    todo!()
+    enum WorkItem {
+        Unhandled(NodeID),
+        AllocatedPhi { control: NodeID, data: Vec<NodeID>, fields: IndexTree<NodeID> },
+        AllocatedReduce { control: NodeID, init: NodeID, reduct: NodeID, fields: IndexTree<NodeID> },
+        AllocatedTernary { first: NodeID, second: NodeID, third: NodeID, fields: IndexTree<NodeID> },
+    }
+
+    // Now, we process the other nodes that deal with products.
+    // The first step is to assign new NodeIDs to the nodes that will be split into multiple: phi,
+    // reduce, parameter, constant, and ternary.
+    // We do this in several steps: first we break apart parameters and constants
+    let mut to_delete = vec![];
+    let mut worklist: VecDeque<WorkItem> = VecDeque::new();
+
+    for node in product_nodes {
+        match editor.func().nodes[node.idx()] {
+            Node::Parameter { .. } => {
+                field_map.insert(node, generate_reads(editor, types[node.idx()], node));
+            }
+            Node::Constant { id } => {
+                field_map.insert(node, generate_constant_fields(editor, id));
+                to_delete.push(node);
+            }
+            _ => { worklist.push_back(WorkItem::Unhandled(node)); }
+        }
+    }
+
+    // Now, we process the remaining nodes, allocating NodeIDs for them and updating the field_map.
+    // We track the current NodeID and add nodes whenever possible, otherwise we return values to
+    // the worklist
+    let mut next_id : usize = editor.func().nodes.len();
+    let mut cur_id : usize = editor.func().nodes.len();
+    while let Some(mut item) = worklist.pop_front() {
+        if let WorkItem::Unhandled(node) = item {
+            match &editor.func().nodes[node.idx()] {
+                // For phi, reduce, and ternary, we break them apart into separate nodes for each field
+                Node::Phi { control, data } => {
+                    let control = *control;
+                    let data = data.clone();
+                    let fields = allocate_fields(editor, types[node.idx()], &mut next_id);
+                    field_map.insert(node, fields.clone());
+
+                    item = WorkItem::AllocatedPhi {
+                        control,
+                        data: data.into(),
+                        fields };
+                }
+                Node::Reduce { control, init, reduct } => {
+                    let control = *control;
+                    let init = *init;
+                    let reduct = *reduct;
+                    let fields = allocate_fields(editor, types[node.idx()], &mut next_id);
+                    field_map.insert(node, fields.clone());
+
+                    item = WorkItem::AllocatedReduce { control, init, reduct, fields };
+                }
+                Node::Ternary { first, second, third, .. } => {
+                    let first = *first;
+                    let second = *second;
+                    let third = *third;
+                    let fields = allocate_fields(editor, types[node.idx()], &mut next_id);
+                    field_map.insert(node, fields.clone());
+
+                    item = WorkItem::AllocatedTernary { first, second, third, fields };
+                }
+
+                Node::Write { collect, data, indices } => {
+                    if let Some(index_map) = field_map.get(collect) {
+                        if editor.get_type(types[data.idx()]).is_product() {
+                            if let Some(data_idx) = field_map.get(data) {
+                                field_map.insert(node, index_map.clone().replace(indices, data_idx.clone()));
+                                to_delete.push(node);
+                            } else {
+                                worklist.push_back(WorkItem::Unhandled(node));
+                            }
+                        } else {
+                            field_map.insert(node, index_map.clone().set(indices, *data));
+                            to_delete.push(node);
+                        }
+                    } else {
+                        worklist.push_back(WorkItem::Unhandled(node));
+                    }
+                }
+                Node::Read { collect, indices } => {
+                    if let Some(index_map) = field_map.get(collect) {
+                        let read_info = index_map.lookup(indices);
+                        match read_info {
+                            IndexTree::Leaf(field) => {
+                                editor.edit(|edit| {
+                                    edit.replace_all_uses(node, *field)
+                                });
+                            }
+                            _ => {}
+                        }
+                        field_map.insert(node, read_info.clone());
+                        to_delete.push(node);
+                    } else {
+                        worklist.push_back(WorkItem::Unhandled(node));
+                    }
+                }
+
+                _ => panic!("Unexpected node type")
+            }
+        }
+        match item {
+            WorkItem::Unhandled(_) => {}
+            _ => todo!()
+        }
+    }
+
+    // Actually deleting nodes seems to break things right now
+    /*
+    println!("{:?}", to_delete);
+    editor.edit(|mut edit| {
+        for node in to_delete {
+            edit = edit.delete_node(node)?
+        }
+        Ok(edit)
+    });
+    */
+}
+
+// An index tree stores values by index list: Leaf holds a value, Node holds one subtree per field
+#[derive(Clone, Debug)]
+enum IndexTree<T> {
+    Leaf(T),
+    Node(Vec<IndexTree<T>>),
+}
+
+impl<T: std::fmt::Debug> IndexTree<T> {
+    fn lookup(&self, idx: &[Index]) -> &IndexTree<T> {
+        self.lookup_idx(idx, 0)
+    }
+
+    fn lookup_idx(&self, idx: &[Index], n: usize) -> &IndexTree<T> {
+        if n < idx.len() {
+            if let Index::Field(i) = idx[n] {
+                match self {
+                    IndexTree::Leaf(_) => panic!("Invalid field"),
+                    IndexTree::Node(ts) => ts[i].lookup_idx(idx, n+1),
+                }
+            } else {
+                // TODO: This could be hit because of an array inside of a product
+                panic!("Error handling lookup of field");
+            }
+        } else {
+            self
+        }
+    }
+
+    fn set(self, idx: &[Index], val: T) -> IndexTree<T> {
+        self.set_idx(idx, val, 0)
+    }
+
+    fn set_idx(self, idx: &[Index], val: T, n: usize) -> IndexTree<T> {
+        if n < idx.len() {
+            if let Index::Field(i) = idx[n] {
+                match self {
+                    IndexTree::Leaf(_) => panic!("Invalid field"),
+                    IndexTree::Node(mut ts) => {
+                        if i + 1 == ts.len() {
+                            let t = ts.pop().unwrap();
+                            ts.push(t.set_idx(idx, val, n+1));
+                        } else {
+                            let mut t = ts.pop().unwrap();
+                            std::mem::swap(&mut ts[i], &mut t);
+                            t = t.set_idx(idx, val, n+1);
+                            std::mem::swap(&mut ts[i], &mut t);
+                            ts.push(t);
+                        }
+                        IndexTree::Node(ts)
+                    }
+                }
+            } else {
+                panic!("Error handling set of field");
+            }
+        } else {
+            IndexTree::Leaf(val)
+        }
+    }
+
+    fn replace(self, idx: &[Index], val: IndexTree<T>) -> IndexTree<T> {
+        self.replace_idx(idx, val, 0)
+    }
+
+    fn replace_idx(self, idx: &[Index], val: IndexTree<T>, n: usize) -> IndexTree<T> {
+        if n < idx.len() {
+            if let Index::Field(i) = idx[n] {
+                match self {
+                    IndexTree::Leaf(_) => panic!("Invalid field"),
+                    IndexTree::Node(mut ts) => {
+                        if i + 1 == ts.len() {
+                            let t = ts.pop().unwrap();
+                            ts.push(t.replace_idx(idx, val, n+1));
+                        } else {
+                            let mut t = ts.pop().unwrap();
+                            std::mem::swap(&mut ts[i], &mut t);
+                            t = t.replace_idx(idx, val, n+1);
+                            std::mem::swap(&mut ts[i], &mut t);
+                            ts.push(t);
+                        }
+                        IndexTree::Node(ts)
+                    }
+                }
+            } else {
+                panic!("Error handling set of field");
+            }
+        } else {
+            val
+        }
+    }
+
+    fn for_each<F>(&self, mut f: F)
+        where F: FnMut(&Vec<Index>, &T) {
+        self.for_each_idx(&mut vec![], &mut f);
+    }
+
+    fn for_each_idx<F>(&self, idx: &mut Vec<Index>, f: &mut F)
+        where F: FnMut(&Vec<Index>, &T) {
+        match self {
+            IndexTree::Leaf(t) => f(idx, t),
+            IndexTree::Node(ts) => {
+                for (i, t) in ts.iter().enumerate() {
+                    idx.push(Index::Field(i));
+                    t.for_each_idx(idx, f);
+                    idx.pop();
+                }
+            }
+        }
+    }
 }
 
 // Given a product value val of type typ, constructs a copy of that value by extracting all fields
 // from that value and then writing them into a new constant
-fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> NodeID {
+// This process also adds all the read nodes that are generated into the read_list so that the
+// reads can be eliminated by later parts of SROA
+fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, read_list: &mut Vec<NodeID>) -> NodeID {
     let fields = generate_reads(editor, typ, val);
     let new_const = generate_constant(editor, typ);
 
@@ -167,44 +381,44 @@ fn reconstruct_product(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) ->
 
     // Generate writes for each field
     let mut value = const_node.expect("Add node cannot fail");
-    for (idx, val) in fields {
+    fields.for_each(|idx: &Vec<Index>, val: &NodeID| {
+        read_list.push(*val);
         editor.edit(|mut edit| {
-            value = edit.add_node(Node::Write { collect: value, data: val, indices: idx.into() });
+            value = edit.add_node(Node::Write { collect: value, data: *val, indices: idx.clone().into() });
             Ok(edit)
         });
-    }
+    });
 
     value
 }
 
 // Given a node val of type typ, adds nodes to the function which read all (leaf) fields of val and
 // returns a list of pairs of the indices and the node that reads that index
-fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> Vec<(Vec<Index>, NodeID)> {
-    generate_reads_at_index(editor, typ, val, vec![]).into_iter().collect::<_>()
+fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> IndexTree<NodeID> {
+    let res = generate_reads_at_index(editor, typ, val, vec![]);
+    res
 }
 
 // Given a node val of type which at the indices idx has type typ, construct reads of all (leaf)
 // fields within this sub-value of val and return the correspondence list
-fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, idx: Vec<Index>) -> LinkedList<(Vec<Index>, NodeID)> {
-    let ts : Option<Vec<TypeID>> = {
+fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID, idx: Vec<Index>) -> IndexTree<NodeID> {
+    let ts : Option<Vec<TypeID>> =
         if let Some(ts) = editor.get_type(typ).try_product() {
             Some(ts.into())
         } else {
             None
-        }
-    };
+        };
 
     if let Some(ts) = ts {
         // For product values, we will recurse down each of its fields with an extended index
         // and the appropriate type of that field
-        let mut result = LinkedList::new();
+        let mut fields = vec![];
         for (i, t) in ts.into_iter().enumerate() {
             let mut new_idx = idx.clone();
             new_idx.push(Index::Field(i));
-            result.append(&mut generate_reads_at_index(editor, t, val, new_idx));
+            fields.push(generate_reads_at_index(editor, t, val, new_idx));
         }
-
-        result
+        IndexTree::Node(fields)
     } else {
         // For non-product types, we've reached a leaf so we generate the read and return it's
         // information
@@ -214,7 +428,7 @@ fn generate_reads_at_index(editor: &mut FunctionEditor, typ: TypeID, val: NodeID
             Ok(edit)
         });
 
-        LinkedList::from([(idx, read_id.expect("Add node cannot fail"))])
+        IndexTree::Leaf(read_id.expect("Add node canont fail"))
     }
 }
 
@@ -262,3 +476,52 @@ fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID {
         Type::Control => panic!("Cannot create constant of control type")
     }
 }
+
+// Given a constant cnst, adds nodes to the function holding the constant value of each leaf field
+// and returns an index tree whose leaves are the nodes holding those field values
+fn generate_constant_fields(editor: &mut FunctionEditor, cnst: ConstantID) -> IndexTree<NodeID> {
+    let cs : Option<Vec<ConstantID>> =
+        if let Some(cs) = editor.get_constant(cnst).try_product_fields() {
+            Some(cs.into())
+        } else {
+            None
+        };
+
+    if let Some(cs) = cs {
+        let mut fields = vec![];
+        for c in cs {
+            fields.push(generate_constant_fields(editor, c));
+        }
+        IndexTree::Node(fields)
+    } else {
+        let mut node = None;
+        editor.edit(|mut edit| {
+            node = Some(edit.add_node(Node::Constant { id: cnst }));
+            Ok(edit)
+        });
+        IndexTree::Leaf(node.expect("Add node cannot fail"))
+    }
+}
+
+// Given a type, return an index tree with a new NodeID for each (leaf) field, with NodeIDs starting
+// at the id provided; id is advanced past the last NodeID allocated
+fn allocate_fields(editor: &FunctionEditor, typ: TypeID, id: &mut usize) -> IndexTree<NodeID> {
+    let ts : Option<Vec<TypeID>> =
+        if let Some(ts) = editor.get_type(typ).try_product() {
+            Some(ts.into())
+        } else {
+            None
+        };
+
+    if let Some(ts) = ts {
+        let mut fields = vec![];
+        for t in ts {
+            fields.push(allocate_fields(editor, t, id));
+        }
+        IndexTree::Node(fields)
+    } else {
+        let node = *id;
+        *id += 1;
+        IndexTree::Leaf(NodeID::new(node))
+    }
+}
diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs
index d4263fe8..c3dc262a 100644
--- a/juno_frontend/src/lib.rs
+++ b/juno_frontend/src/lib.rs
@@ -157,7 +157,7 @@ pub fn compile_ir(
         pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
     }
     // TEMPORARY
-    add_pass!(pm, verify, Inline);
+    //add_pass!(pm, verify, Inline);
     add_pass!(pm, verify, CCP);
     add_pass!(pm, verify, DCE);
     add_pass!(pm, verify, GVN);
-- 
GitLab