From 4572a7ec7ba13913b86127bdd3fa2edfa73f26ac Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Thu, 27 Feb 2025 12:58:53 -0600
Subject: [PATCH] Inline constants

---
 hercules_opt/src/editor.rs           |  12 +++
 hercules_opt/src/inline.rs           | 126 +++++++++++++++++++++++----
 juno_samples/rodinia/bfs/src/cpu.sch |   1 +
 3 files changed, 121 insertions(+), 18 deletions(-)

diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs
index 17cea325..0e332a00 100644
--- a/hercules_opt/src/editor.rs
+++ b/hercules_opt/src/editor.rs
@@ -880,6 +880,18 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
         }
     }
 
+    pub fn get_param_types(&self) -> &Vec<TypeID> {
+        self.updated_param_types
+            .as_ref()
+            .unwrap_or(&self.editor.function.param_types)
+    }
+
+    pub fn get_return_types(&self) -> &Vec<TypeID> {
+        self.updated_return_types
+            .as_ref()
+            .unwrap_or(&self.editor.function.return_types)
+    }
+
     pub fn set_param_types(&mut self, tys: Vec<TypeID>) {
         self.updated_param_types = Some(tys);
     }
diff --git a/hercules_opt/src/inline.rs b/hercules_opt/src/inline.rs
index c94ad962..895f1f73 100644
--- a/hercules_opt/src/inline.rs
+++ b/hercules_opt/src/inline.rs
@@ -235,7 +235,7 @@ fn inline_func(
     }
 }
 
-#[derive(Clone, Debug, Copy)]
+#[derive(Clone, Debug, Copy, PartialEq, Eq)]
 enum ParameterLattice {
     Top,
     Constant(ConstantID),
@@ -331,6 +331,7 @@ pub fn const_inline(editors: &mut [FunctionEditor], callgraph: &CallGraph) {
         // constant) to this function.
         let mut param_lattice = vec![ParameterLattice::Top; func.param_types.len()];
         let mut dc_param_lattice = vec![ParameterLattice::Top; func.num_dynamic_constants as usize];
+        let mut callers = vec![];
         for caller in callgraph.get_callers(func_id) {
             let editor = &editors[caller.idx()];
             let nodes = &editor.func().nodes;
@@ -338,28 +339,117 @@ pub fn const_inline(editors: &mut [FunctionEditor], callgraph: &CallGraph) {
                 if let Some((_, callee, dc_args, args)) = nodes[id.idx()].try_call()
                     && callee == func_id
                 {
-                    for (idx, id) in args.into_iter().enumerate() {
-                        let lattice = ParameterLattice::from_node(&nodes[id.idx()], callee);
-                        param_lattice[idx].meet(
-                            lattice,
-                            editor.get_constants(),
-                            editor.get_dynamic_constants(),
-                        );
+                    if editor.is_mutable(id) {
+                        for (idx, id) in args.into_iter().enumerate() {
+                            let lattice = ParameterLattice::from_node(&nodes[id.idx()], callee);
+                            param_lattice[idx].meet(
+                                lattice,
+                                editor.get_constants(),
+                                editor.get_dynamic_constants(),
+                            );
+                        }
+
+                        for (idx, id) in dc_args.into_iter().enumerate() {
+                            let lattice = ParameterLattice::DynamicConstant(*id, func_id);
+                            dc_param_lattice[idx].meet(
+                                lattice,
+                                editor.get_constants(),
+                                editor.get_dynamic_constants(),
+                            );
+                        }
+                    } else {
+                        // If we can't modify the call node in the caller, then
+                        // we can't perform the inlining.
+                        param_lattice = vec![ParameterLattice::Bottom; func.param_types.len()];
+                        dc_param_lattice =
+                            vec![ParameterLattice::Bottom; func.num_dynamic_constants as usize];
                     }
+                    callers.push((caller, id));
+                }
+            }
+        }
+        if param_lattice.iter().all(|v| *v == ParameterLattice::Bottom)
+            && dc_param_lattice
+                .iter()
+                .all(|v| *v == ParameterLattice::Bottom)
+        {
+            continue;
+        }
 
-                    for (idx, id) in dc_args.into_iter().enumerate() {
-                        let lattice = ParameterLattice::DynamicConstant(*id, func_id);
-                        dc_param_lattice[idx].meet(
-                            lattice,
-                            editor.get_constants(),
-                            editor.get_dynamic_constants(),
-                        );
+        // Replace the arguments.
+        let editor = &mut editors[func_id.idx()];
+        let mut param_idx_to_ids: HashMap<usize, Vec<NodeID>> = HashMap::new();
+        for id in editor.node_ids() {
+            if let Some(idx) = editor.func().nodes[id.idx()].try_parameter() {
+                param_idx_to_ids.entry(idx).or_default().push(id);
+            }
+        }
+        let mut params_to_remove = vec![];
+        let success = editor.edit(|mut edit| {
+            let mut param_tys = edit.get_param_types().clone();
+            let mut decrement_index_by = 0;
+            for idx in 0..param_tys.len() {
+                if let Some(node) = match param_lattice[idx] {
+                    ParameterLattice::Top => Some(Node::Undef { ty: param_tys[idx] }),
+                    ParameterLattice::Constant(id) => Some(Node::Constant { id }),
+                    _ => None,
+                } && let Some(ids) = param_idx_to_ids.get(&idx)
+                {
+                    let node = edit.add_node(node);
+                    for id in ids {
+                        edit = edit.replace_all_uses(*id, node)?;
+                        edit = edit.delete_node(*id)?;
+                    }
+                    param_tys.remove(idx - decrement_index_by);
+                    params_to_remove.push(idx);
+                    decrement_index_by += 1;
+                } else if decrement_index_by != 0
+                    && let Some(ids) = param_idx_to_ids.get(&idx)
+                {
+                    let node = edit.add_node(Node::Parameter {
+                        index: idx - decrement_index_by,
+                    });
+                    for id in ids {
+                        edit = edit.replace_all_uses(*id, node)?;
+                        edit = edit.delete_node(*id)?;
                     }
                 }
             }
+            edit.set_param_types(param_tys);
+            Ok(edit)
+        });
+        params_to_remove.reverse();
+
+        // Update callers.
+        if success {
+            for (caller, call) in callers {
+                let editor = &mut editors[caller.idx()];
+                let success = editor.edit(|mut edit| {
+                    let Node::Call {
+                        control,
+                        function,
+                        dynamic_constants,
+                        args,
+                    } = edit.get_node(call).clone()
+                    else {
+                        panic!();
+                    };
+                    let mut args = args.into_vec();
+                    for idx in params_to_remove.iter() {
+                        args.remove(*idx);
+                    }
+                    let node = edit.add_node(Node::Call {
+                        control,
+                        function,
+                        dynamic_constants,
+                        args: args.into_boxed_slice(),
+                    });
+                    edit = edit.replace_all_uses(call, node)?;
+                    edit = edit.delete_node(call)?;
+                    Ok(edit)
+                });
+                assert!(success);
+            }
         }
-        println!("{}:", func.name);
-        println!("{:?}", param_lattice);
-        println!("{:?}", dc_param_lattice);
     }
 }
diff --git a/juno_samples/rodinia/bfs/src/cpu.sch b/juno_samples/rodinia/bfs/src/cpu.sch
index d201765e..04ffeaa6 100644
--- a/juno_samples/rodinia/bfs/src/cpu.sch
+++ b/juno_samples/rodinia/bfs/src/cpu.sch
@@ -23,5 +23,6 @@ fixpoint {
   fork-coalesce(*);
 }
 const-inline(*);
+simpl!(*);
 
 gcm(*);
-- 
GitLab