From 996018c62c6a921a88d1bfa98534a78478c89eab Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 2 Jun 2024 20:59:41 -0700
Subject: [PATCH] Get the products that are / aren't used in a sink

---
 hercules_ir/src/ir.rs     |   8 +++
 hercules_opt/src/pass.rs  |  14 ++++-
 hercules_opt/src/sroa.rs  | 124 +++++++++++++++++++++++++++++++++++++-
 juno_frontend/src/main.rs |   2 +-
 4 files changed, 144 insertions(+), 4 deletions(-)

diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 6f44bff7..7c892117 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -659,6 +659,14 @@ impl Type {
         self.is_bool() || self.is_fixed() || self.is_float()
     }
 
+    /// Reports whether this type is a product type.
+    ///
+    /// Returns `true` exactly when `self` is the `Type::Product` variant;
+    /// the variant's payload is not inspected.
+    pub fn is_product(&self) -> bool {
+        matches!(self, Type::Product(_))
+    }
+
     pub fn is_array(&self) -> bool {
         if let Type::Array(_, _) = self {
             true
diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs
index d6b8aca5..909b6295 100644
--- a/hercules_opt/src/pass.rs
+++ b/hercules_opt/src/pass.rs
@@ -366,8 +366,20 @@ impl PassManager {
                     }
                 }
                 Pass::SROA => {
+                    self.make_def_uses();
+                    self.make_reverse_postorders();
+                    self.make_typing();
+                    let def_uses = self.def_uses.as_ref().unwrap();
+                    let reverse_postorders = self.reverse_postorders.as_ref().unwrap();
+                    let typing = self.typing.as_ref().unwrap();
                     for idx in 0..self.module.functions.len() {
-                        sroa(&mut self.module.functions[idx])
+                        sroa(
+                            &mut self.module.functions[idx],
+                            &def_uses[idx],
+                            &reverse_postorders[idx],
+                            &typing[idx],
+                            &self.module.types,
+                        );
                     }
                 }
                 Pass::Verify => {
diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs
index 2f1ba1a0..a62209dc 100644
--- a/hercules_opt/src/sroa.rs
+++ b/hercules_opt/src/sroa.rs
@@ -1,5 +1,7 @@
 extern crate hercules_ir;
 
+use self::hercules_ir::dataflow::*;
+use self::hercules_ir::def_use::*;
 use self::hercules_ir::ir::*;
 
 /*
@@ -37,6 +39,124 @@ use self::hercules_ir::ir::*;
  *   but are retained when the product value is unbroken
  *
  * The nodes above with the list marker "+" are retained for maintaining API/ABI
- * compatability with other Hercules functions and the host code.
+ * compatibility with other Hercules functions and the host code. These are
+ * called "sink" or "source" nodes in comments below.
  */
-pub fn sroa(function: &mut Function) {}
+pub fn sroa(
+    function: &mut Function,
+    def_use: &ImmutableDefUseMap,
+    reverse_postorder: &Vec<NodeID>,
+    typing: &Vec<TypeID>, // indexed by node index; gives that node's type ID
+    types: &Vec<Type>, // indexed by type ID index
+) {
+    // Step 1: determine which sources of product values we want to try breaking
+    // up. We can easily determine on the source side whether a node produces a
+    // product that shouldn't be broken up, just by examining the node type.
+    // However, the way that products are used is also important for determining
+    // if the product can be broken up. We backward dataflow this info to the
+    // sources of product values.
+    #[derive(PartialEq, Eq, Clone, Debug)]
+    enum ProductUseLattice {
+        // The product value used by this node is eventually used by a sink.
+        UsedBySink,
+        // This node uses multiple product values - the stored node ID indicates
+        // which is eventually used by a sink. This lattice value is produced by
+        // read and write nodes implementing partial indexing. We handle this
+        // specially since we want to optimize away partial indexing completely.
+        SpecificUsedBySink(NodeID),
+        // This node doesn't use a product node, or the product node it does use
+        // is not in turn used by a sink.
+        UnusedBySink,
+    }
+
+    impl Semilattice for ProductUseLattice {
+        fn meet(a: &Self, b: &Self) -> Self {
+            match (a, b) {
+                (Self::UsedBySink, _) | (_, Self::UsedBySink) => Self::UsedBySink,
+                (Self::SpecificUsedBySink(id1), Self::SpecificUsedBySink(id2)) => {
+                    if id1 == id2 {
+                        Self::SpecificUsedBySink(*id1)
+                    } else {
+                        Self::UsedBySink // disagreeing IDs collapse to the general "used" value
+                    }
+                }
+                (Self::SpecificUsedBySink(id), _) | (_, Self::SpecificUsedBySink(id)) => {
+                    Self::SpecificUsedBySink(*id)
+                }
+                _ => Self::UnusedBySink, // only reached when both sides are UnusedBySink
+            }
+        }
+
+        fn bottom() -> Self {
+            Self::UsedBySink // absorbing element of `meet` above
+        }
+
+        fn top() -> Self {
+            Self::UnusedBySink // identity element of `meet` above
+        }
+    }
+
+    let product_uses = backward_dataflow(function, def_use, reverse_postorder, |succ_outs, id| {
+        match function.nodes[id.idx()] {
+            Node::Return {
+                control: _,
+                data: _,
+            } => {
+                if types[typing[id.idx()].idx()].is_product() {
+                    ProductUseLattice::UsedBySink
+                } else {
+                    ProductUseLattice::UnusedBySink
+                }
+            }
+            Node::Call {
+                function: _,
+                dynamic_constants: _,
+                args: _,
+            } => todo!(), // NOTE(review): panics if this closure is ever invoked on a call node
+            // For reads and writes, we only want to propagate the use of the
+            // product to the collect input of the node.
+            Node::Read {
+                collect,
+                indices: _,
+            }
+            | Node::Write {
+                collect,
+                data: _,
+                indices: _,
+            } => {
+                let meet = succ_outs
+                    .iter()
+                    .fold(ProductUseLattice::top(), |acc, latt| {
+                        ProductUseLattice::meet(&acc, latt)
+                    });
+                if meet == ProductUseLattice::UnusedBySink {
+                    ProductUseLattice::UnusedBySink
+                } else {
+                    ProductUseLattice::SpecificUsedBySink(collect)
+                }
+            }
+            // All remaining nodes: control nodes report unused; others meet `succ_outs`.
+            _ => {
+                if function.nodes[id.idx()].is_control() {
+                    return ProductUseLattice::UnusedBySink;
+                }
+                let meet = succ_outs
+                    .iter()
+                    .fold(ProductUseLattice::top(), |acc, latt| {
+                        ProductUseLattice::meet(&acc, latt)
+                    });
+                if let ProductUseLattice::SpecificUsedBySink(meet_id) = meet {
+                    if meet_id == id { // the met value names this very node
+                        ProductUseLattice::UsedBySink
+                    } else {
+                        ProductUseLattice::UnusedBySink
+                    }
+                } else {
+                    meet
+                }
+            }
+        }
+    });
+
+    println!("{:?}", product_uses); // NOTE(review): debug dump only; `function` is not modified here
+}
diff --git a/juno_frontend/src/main.rs b/juno_frontend/src/main.rs
index 0428d73a..72855c0d 100644
--- a/juno_frontend/src/main.rs
+++ b/juno_frontend/src/main.rs
@@ -61,11 +61,11 @@ fn main() {
                 pm.add_pass(hercules_opt::pass::Pass::Verify);
             }
             add_verified_pass!(pm, args, PhiElim);
-            add_pass!(pm, args, SROA);
             add_pass!(pm, args, CCP);
             add_pass!(pm, args, DCE);
             add_pass!(pm, args, GVN);
             add_pass!(pm, args, DCE);
+            add_pass!(pm, args, SROA);
             if args.x_dot {
                 pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
             }
-- 
GitLab