From b8f77a5e2fe55d6008c3958563bb0d6354b740a2 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Tue, 19 Nov 2024 08:58:41 -0600
Subject: [PATCH] Start on SROA of remaining nodes

---
 hercules_opt/src/sroa.rs | 184 +++++++++++++++++++++++++++++++++++++--
 juno_frontend/src/lib.rs |   4 +
 juno_samples/products.jn |   4 +-
 3 files changed, 185 insertions(+), 7 deletions(-)

diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs
index d8e2021a..ccc2605c 100644
--- a/hercules_opt/src/sroa.rs
+++ b/hercules_opt/src/sroa.rs
@@ -1,6 +1,6 @@
 extern crate hercules_ir;
 
-use std::collections::{HashMap, LinkedList, VecDeque};
+use std::collections::{BTreeMap, HashMap, LinkedList, VecDeque};
 
 use self::hercules_ir::dataflow::*;
 use self::hercules_ir::ir::*;
@@ -165,18 +165,21 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
         AllocatedPhi {
             control: NodeID,
             data: Vec<NodeID>,
+            node: NodeID,
             fields: IndexTree<NodeID>,
         },
         AllocatedReduce {
             control: NodeID,
             init: NodeID,
             reduct: NodeID,
+            node: NodeID,
             fields: IndexTree<NodeID>,
         },
         AllocatedTernary {
             first: NodeID,
             second: NodeID,
             third: NodeID,
+            node: NodeID,
             fields: IndexTree<NodeID>,
         },
     }
@@ -204,10 +207,12 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
     }
 
     // Now, we process the remaining nodes, allocating NodeIDs for them and updating the field_map.
-    // We track the current NodeID and add nodes whenever possible, otherwise we return values to
-    // the worklist
+    // We track the current NodeID and add nodes to a set we maintain of nodes to add (since we
+    // need to add nodes in a particular order we wait to do that until the end). If we don't have
+    // enough information to process a particular node, we add it back to the worklist
     let mut next_id: usize = editor.func().nodes.len();
-    let mut cur_id: usize = editor.func().nodes.len();
+    let mut to_insert = BTreeMap::new();
+
     while let Some(mut item) = worklist.pop_front() {
         if let WorkItem::Unhandled(node) = item {
             match &editor.func().nodes[node.idx()] {
@@ -221,6 +226,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                     item = WorkItem::AllocatedPhi {
                         control,
                         data: data.into(),
+                        node,
                         fields,
                     };
                 }
@@ -239,6 +245,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                         control,
                         init,
                         reduct,
+                        node,
                         fields,
                     };
                 }
@@ -258,6 +265,7 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
                         first,
                         second,
                         third,
+                        node,
                         fields,
                     };
                 }
@@ -307,11 +315,127 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types:
         }
         match item {
             WorkItem::Unhandled(_) => {}
-            _ => todo!(),
+            WorkItem::AllocatedPhi {
+                control,
+                data,
+                node,
+                fields,
+            } => {
+                let mut data_fields = vec![];
+                let mut ready = true;
+                for val in data.iter() {
+                    if let Some(val_fields) = field_map.get(val) {
+                        data_fields.push(val_fields);
+                    } else {
+                        ready = false;
+                        break;
+                    }
+                }
+
+                if ready {
+                    fields.zip_list(data_fields).for_each(|idx, (res, data)| {
+                        to_insert.insert(
+                            res.idx(),
+                            Node::Phi {
+                                control,
+                                data: data.into_iter().map(|n| **n).collect::<Vec<_>>().into(),
+                            },
+                        );
+                    });
+                } else {
+                    worklist.push_back(WorkItem::AllocatedPhi {
+                        control,
+                        data,
+                        node,
+                        fields,
+                    });
+                }
+            }
+            WorkItem::AllocatedReduce {
+                control,
+                init,
+                reduct,
+                node,
+                fields,
+            } => {
+                if let (Some(init_fields), Some(reduct_fields)) =
+                    (field_map.get(&init), field_map.get(&reduct))
+                {
+                    fields.zip(init_fields).zip(reduct_fields).for_each(
+                        |idx, ((res, init), reduct)| {
+                            to_insert.insert(
+                                res.idx(),
+                                Node::Reduce {
+                                    control,
+                                    init: **init,
+                                    reduct: **reduct,
+                                },
+                            );
+                        },
+                    );
+                    to_delete.push(node);
+                } else {
+                    worklist.push_back(WorkItem::AllocatedReduce {
+                        control,
+                        init,
+                        reduct,
+                        node,
+                        fields,
+                    });
+                }
+            }
+            WorkItem::AllocatedTernary {
+                first,
+                second,
+                third,
+                node,
+                fields,
+            } => {
+                if let (Some(fst_fields), Some(snd_fields), Some(thd_fields)) = (
+                    field_map.get(&first),
+                    field_map.get(&second),
+                    field_map.get(&third),
+                ) {
+                    fields
+                        .zip(fst_fields)
+                        .zip(snd_fields)
+                        .zip(thd_fields)
+                        .for_each(|idx, (((res, fst), snd), thd)| {
+                            to_insert.insert(
+                                res.idx(),
+                                Node::Ternary {
+                                    first: **fst,
+                                    second: **snd,
+                                    third: **thd,
+                                    op: TernaryOperator::Select,
+                                },
+                            );
+                        });
+                } else {
+                    worklist.push_back(WorkItem::AllocatedTernary {
+                        first,
+                        second,
+                        third,
+                        node,
+                        fields,
+                    });
+                }
+            }
         }
     }
 
-    // Actually deleting nodes seems to break things right now
+    // Create new nodes nodes
+    for (node_id, node) in to_insert {
+        assert!(node_id == editor.func().nodes.len());
+        println!("Inserting {:?} : {:?}", node_id, node);
+        editor.edit(|mut edit| {
+            let id = edit.add_node(node);
+            assert!(node_id == id.idx());
+            Ok(edit)
+        });
+    }
+
+    // Remove nodes
     editor.edit(|mut edit| {
         for node in to_delete {
             edit = edit.delete_node(node)?
@@ -410,6 +534,54 @@ impl<T: std::fmt::Debug> IndexTree<T> {
         }
     }
 
+    fn zip<'a, A>(self, other: &'a IndexTree<A>) -> IndexTree<(T, &'a A)> {
+        match (self, other) {
+            (IndexTree::Leaf(t), IndexTree::Leaf(a)) => IndexTree::Leaf((t, a)),
+            (IndexTree::Node(t), IndexTree::Node(a)) => {
+                let mut fields = vec![];
+                for (t, a) in t.into_iter().zip(a.iter()) {
+                    fields.push(t.zip(a));
+                }
+                IndexTree::Node(fields)
+            }
+            _ => panic!("IndexTrees do not have the same fields, cannot zip"),
+        }
+    }
+
+    fn zip_list<'a, A>(self, others: Vec<&'a IndexTree<A>>) -> IndexTree<(T, Vec<&'a A>)> {
+        match self {
+            IndexTree::Leaf(t) => {
+                let mut res = vec![];
+                for other in others {
+                    match other {
+                        IndexTree::Leaf(a) => res.push(a),
+                        _ => panic!("IndexTrees do not have the same fields, cannot zip"),
+                    }
+                }
+                IndexTree::Leaf((t, res))
+            }
+            IndexTree::Node(t) => {
+                let mut fields: Vec<Vec<&'a IndexTree<A>>> = vec![vec![]; t.len()];
+                for other in others {
+                    match other {
+                        IndexTree::Node(a) => {
+                            for (i, a) in a.iter().enumerate() {
+                                fields[i].push(a);
+                            }
+                        }
+                        _ => panic!("IndexTrees do not have the same fields, cannot zip"),
+                    }
+                }
+                IndexTree::Node(
+                    t.into_iter()
+                        .zip(fields.into_iter())
+                        .map(|(t, f)| t.zip_list(f))
+                        .collect(),
+                )
+            }
+        }
+    }
+
     fn for_each<F>(&self, mut f: F)
     where
         F: FnMut(&Vec<Index>, &T),
diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs
index 49249fc7..e54e88fc 100644
--- a/juno_frontend/src/lib.rs
+++ b/juno_frontend/src/lib.rs
@@ -151,6 +151,10 @@ pub fn compile_ir(
     if x_dot {
         pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
     }
+    add_pass!(pm, verify, SROA);
+    if x_dot {
+        pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
+    }
     add_pass!(pm, verify, Inline);
     // Run SROA pretty early (though after inlining which can make SROA more effective) so that
     // CCP, GVN, etc. can work on the result of SROA
diff --git a/juno_samples/products.jn b/juno_samples/products.jn
index a6dd8862..00b66aab 100644
--- a/juno_samples/products.jn
+++ b/juno_samples/products.jn
@@ -1,5 +1,7 @@
 fn test_call(x : i32, y : f32) -> (i32, f32) {
-  return (x, y);
+  let res = (x, y);
+  if x < 13 { res = (x + 1, y); }
+  return res;
 }
 
 fn test(x : i32, y : f32) -> (i32, f32) {
-- 
GitLab