From b336327e69c03752217bcc613ac3374e880367b1 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Thu, 27 Feb 2025 10:52:10 -0600
Subject: [PATCH] Skeleton and analysis for const inline pass

---
 hercules_opt/src/inline.rs           | 134 ++++++++++++++++++++++++++-
 juno_samples/rodinia/bfs/src/cpu.sch |   1 +
 juno_scheduler/src/compile.rs        |   1 +
 juno_scheduler/src/ir.rs             |   3 +-
 juno_scheduler/src/pm.rs             |  18 ++++
 5 files changed, 153 insertions(+), 4 deletions(-)

diff --git a/hercules_opt/src/inline.rs b/hercules_opt/src/inline.rs
index 99187dd2..c94ad962 100644
--- a/hercules_opt/src/inline.rs
+++ b/hercules_opt/src/inline.rs
@@ -1,8 +1,7 @@
+use std::cell::Ref;
 use std::collections::HashMap;
 
-use hercules_ir::callgraph::*;
-use hercules_ir::def_use::*;
-use hercules_ir::ir::*;
+use hercules_ir::*;
 
 use crate::*;
 
@@ -235,3 +234,132 @@ fn inline_func(
         });
     }
 }
+
+#[derive(Clone, Debug, Copy)]
+enum ParameterLattice {
+    Top,
+    Constant(ConstantID),
+    // Dynamic constant
+    DynamicConstant(DynamicConstantID, FunctionID),
+    Bottom,
+}
+
+impl ParameterLattice {
+    fn from_node(node: &Node, func_id: FunctionID) -> Self {
+        use ParameterLattice::*;
+        match node {
+            Node::Undef { ty: _ } => Top,
+            Node::Constant { id } => Constant(*id),
+            Node::DynamicConstant { id } => DynamicConstant(*id, func_id),
+            _ => Bottom,
+        }
+    }
+
+    fn try_const_dc(self, dcs: Ref<'_, Vec<DynamicConstant>>) -> Option<usize> {
+        if let ParameterLattice::DynamicConstant(id, _) = self
+            && let DynamicConstant::Constant(val) = &dcs[id.idx()]
+        {
+            Some(*val)
+        } else {
+            None
+        }
+    }
+
+    fn meet(&mut self, b: Self, cons: Ref<'_, Vec<Constant>>, dcs: Ref<'_, Vec<DynamicConstant>>) {
+        use ParameterLattice::*;
+        *self = match (*self, b) {
+            (Top, b) => b,
+            (a, Top) => a,
+            (Bottom, _) | (_, Bottom) => Bottom,
+            (Constant(id_a), Constant(id_b)) => {
+                if id_a == id_b {
+                    Constant(id_a)
+                } else {
+                    Bottom
+                }
+            }
+            (DynamicConstant(dc_a, f_a), DynamicConstant(dc_b, f_b)) => {
+                if dc_a == dc_b && f_a == f_b {
+                    DynamicConstant(dc_a, f_a)
+                } else if let (
+                    ir::DynamicConstant::Constant(dcv_a),
+                    ir::DynamicConstant::Constant(dcv_b),
+                ) = (&dcs[dc_a.idx()], &dcs[dc_b.idx()])
+                    && *dcv_a == *dcv_b
+                {
+                    DynamicConstant(dc_a, f_a)
+                } else {
+                    Bottom
+                }
+            }
+            (DynamicConstant(dc, _), Constant(con)) | (Constant(con), DynamicConstant(dc, _)) => {
+                match (&cons[con.idx()], &dcs[dc.idx()]) {
+                    (ir::Constant::UnsignedInteger64(conv), ir::DynamicConstant::Constant(dcv))
+                        if *conv as usize == *dcv =>
+                    {
+                        Constant(con)
+                    }
+                    _ => Bottom,
+                }
+            }
+        }
+    }
+}
+
+/*
+ * Top level function to inline constant parameters and constant dynamic
+ * constant parameters. Identifies functions that are:
+ *
+ * 1. Not marked as entry.
+ * 2. At every call site, a particular parameter is always a specific constant
+ *    or dynamic constant OR a particular dynamic constant parameter is always a
+ *    specific constant.
+ *
+ * These functions can have that constant "inlined" - the parameter is removed
+ * and all uses of the parameter becomes uses of the constant directly.
+ */
+pub fn const_inline(editors: &mut [FunctionEditor], callgraph: &CallGraph) {
+    // Run const inlining on each function, starting at the most shallow
+    // function first, since we want to propagate constants down the call graph.
+    for func_id in callgraph.topo().into_iter().rev() {
+        let func = editors[func_id.idx()].func();
+        if func.entry || callgraph.num_callers(func_id) == 0 {
+            continue;
+        }
+
+        // Figure out what we know about the parameters (both normal and dynamic
+        // constant) to this function.
+        let mut param_lattice = vec![ParameterLattice::Top; func.param_types.len()];
+        let mut dc_param_lattice = vec![ParameterLattice::Top; func.num_dynamic_constants as usize];
+        for caller in callgraph.get_callers(func_id) {
+            let editor = &editors[caller.idx()];
+            let nodes = &editor.func().nodes;
+            for id in editor.node_ids() {
+                if let Some((_, callee, dc_args, args)) = nodes[id.idx()].try_call()
+                    && callee == func_id
+                {
+                    for (idx, id) in args.into_iter().enumerate() {
+                        let lattice = ParameterLattice::from_node(&nodes[id.idx()], callee);
+                        param_lattice[idx].meet(
+                            lattice,
+                            editor.get_constants(),
+                            editor.get_dynamic_constants(),
+                        );
+                    }
+
+                    for (idx, id) in dc_args.into_iter().enumerate() {
+                        let lattice = ParameterLattice::DynamicConstant(*id, func_id);
+                        dc_param_lattice[idx].meet(
+                            lattice,
+                            editor.get_constants(),
+                            editor.get_dynamic_constants(),
+                        );
+                    }
+                }
+            }
+        }
+        println!("{}:", func.name);
+        println!("{:?}", param_lattice);
+        println!("{:?}", dc_param_lattice);
+    }
+}
diff --git a/juno_samples/rodinia/bfs/src/cpu.sch b/juno_samples/rodinia/bfs/src/cpu.sch
index 1a94ea55..d201765e 100644
--- a/juno_samples/rodinia/bfs/src/cpu.sch
+++ b/juno_samples/rodinia/bfs/src/cpu.sch
@@ -22,5 +22,6 @@ fixpoint {
   fork-guard-elim(*);
   fork-coalesce(*);
 }
+const-inline(*);
 
 gcm(*);
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 3c288ca7..bd27350a 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -112,6 +112,7 @@ impl FromStr for Appliable {
             "ccp" => Ok(Appliable::Pass(ir::Pass::CCP)),
             "crc" | "collapse-read-chains" => Ok(Appliable::Pass(ir::Pass::CRC)),
             "clean-monoid-reduces" => Ok(Appliable::Pass(ir::Pass::CleanMonoidReduces)),
+            "const-inline" => Ok(Appliable::Pass(ir::Pass::ConstInline)),
             "dce" => Ok(Appliable::Pass(ir::Pass::DCE)),
             "delete-uncalled" => Ok(Appliable::DeleteUncalled),
             "float-collections" | "collections" => Ok(Appliable::Pass(ir::Pass::FloatCollections)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 3a087c0d..a0db8844 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -8,8 +8,9 @@ pub enum Pass {
     ArrayToProduct,
     AutoOutline,
     CCP,
-    CleanMonoidReduces,
     CRC,
+    CleanMonoidReduces,
+    ConstInline,
     DCE,
     FloatCollections,
     ForkChunk,
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 5f2fa4cc..3c142700 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1833,6 +1833,24 @@ fn run_pass(
             pm.delete_gravestones();
             pm.clear_analyses();
         }
+        Pass::ConstInline => {
+            assert!(args.is_empty());
+            pm.make_callgraph();
+            let callgraph = pm.callgraph.take().unwrap();
+
+            let mut editors: Vec<_> = build_selection(pm, selection, true)
+                .into_iter()
+                .map(|editor| editor.unwrap())
+                .collect();
+            const_inline(&mut editors, &callgraph);
+
+            for func in editors {
+                changed |= func.modified();
+            }
+
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::CRC => {
             assert!(args.is_empty());
             for func in build_selection(pm, selection, false) {
-- 
GitLab