From ff03fe68238c360455e1240766a77bd3ce494052 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Fri, 28 Feb 2025 18:06:10 -0600
Subject: [PATCH] ip-sroa params first attempt

---
 hercules_ir/src/typecheck.rs             |   8 +-
 hercules_opt/src/interprocedural_sroa.rs | 123 ++++++++++++++++++-----
 2 files changed, 106 insertions(+), 25 deletions(-)

diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs
index 2a3f9fb1..b2567b8f 100644
--- a/hercules_ir/src/typecheck.rs
+++ b/hercules_ir/src/typecheck.rs
@@ -716,8 +716,10 @@ fn typeflow(
             // Check number of run-time arguments.
             if inputs.len() - 1 != callee.param_types.len() {
                 return Error(format!(
-                    "Call node has {} inputs, but calls a function with {} parameters.",
+                    "Call node in {} has {} inputs, but calls a function ({}) with {} parameters.",
+                    function.name,
                     inputs.len() - 1,
+                    callee.name,
                     callee.param_types.len(),
                 ));
             }
@@ -725,8 +727,10 @@ fn typeflow(
             // Check number of dynamic constant arguments.
             if dc_args.len() != callee.num_dynamic_constants as usize {
                 return Error(format!(
-                    "Call node references {} dynamic constants, but calls a function expecting {} dynamic constants.",
+                    "Call node in {} references {} dynamic constants, but calls a function ({}) expecting {} dynamic constants.",
+                    function.name,
                     dc_args.len(),
+                    callee.name,
                     callee.num_dynamic_constants
                 ));
             }
diff --git a/hercules_opt/src/interprocedural_sroa.rs b/hercules_opt/src/interprocedural_sroa.rs
index ad4ce19e..044d3590 100644
--- a/hercules_opt/src/interprocedural_sroa.rs
+++ b/hercules_opt/src/interprocedural_sroa.rs
@@ -39,46 +39,91 @@ pub fn interprocedural_sroa(
         }
 
         let editor: &mut FunctionEditor = &mut editors[func_id.idx()];
+        let param_types = &editor.func().param_types.to_vec();
         let return_types = &editor.func().return_types.to_vec();
 
-        // We determine the new return types of the function and track a map
-        // that tells us how the old return values are constructed from the
-        // new ones
+        // We determine the new param/return types of the function and track a
+        // map that tells us how the old param/return values are constructed
+        // from the new ones.
+        let mut new_param_types = vec![];
+        let mut old_param_type_map = vec![];
         let mut new_return_types = vec![];
         let mut old_return_type_map = vec![];
         let mut changed = false;
 
-        for ret_typ in return_types.iter() {
+        for ret_typ in param_types.iter() {
             if !can_sroa_type(editor, *ret_typ) {
+                old_param_type_map.push(IndexTree::Leaf(new_param_types.len()));
+                new_param_types.push(*ret_typ);
+            } else {
+                let (types, index) = sroa_type(editor, *ret_typ, new_param_types.len());
+                old_param_type_map.push(index);
+                new_param_types.extend(types);
+                changed = true;
+            }
+        }
+
+        for par_typ in return_types.iter() {
+            if !can_sroa_type(editor, *par_typ) {
                 old_return_type_map.push(IndexTree::Leaf(new_return_types.len()));
-                new_return_types.push(*ret_typ);
+                new_return_types.push(*par_typ);
             } else {
-                let (types, index) = sroa_type(editor, *ret_typ, new_return_types.len());
+                let (types, index) = sroa_type(editor, *par_typ, new_return_types.len());
                 old_return_type_map.push(index);
                 new_return_types.extend(types);
                 changed = true;
             }
         }
 
-        // If the return type is not changed by IP SROA, skip to the next function
+        // If the param/return types aren't changed by IP SROA, skip to the next
+        // function.
         if !changed {
             continue;
         }
 
-        // Now, modify each return in the current function and the return type
-        let return_nodes = editor
-            .func()
-            .nodes
-            .iter()
-            .enumerate()
-            .filter_map(|(idx, node)| {
-                if node.try_return().is_some() {
-                    Some(NodeID::new(idx))
+        // Modify each parameter in the current function and the param types.
+        let mut param_nodes: Vec<_> = vec![vec![]; param_types.len()];
+        for id in editor.node_ids() {
+            if let Some(idx) = editor.func().nodes[id.idx()].try_parameter() {
+                param_nodes[idx].push(id);
+            }
+        }
+        println!("{}", editor.func().name);
+        let success = editor.edit(|mut edit| {
+            for (idx, ids) in param_nodes.into_iter().enumerate() {
+                let new_indices = &old_param_type_map[idx];
+                let built = if let IndexTree::Leaf(new_idx) = new_indices {
+                    edit.add_node(Node::Parameter { index: *new_idx })
                 } else {
-                    None
+                    let prod_ty = param_types[idx];
+                    let cons = edit.add_zero_constant(prod_ty);
+                    let mut cons = edit.add_node(Node::Constant { id: cons });
+                    new_indices.for_each(|idx: &Vec<Index>, param_idx: &usize| {
+                        let param = edit.add_node(Node::Parameter { index: *param_idx });
+                        cons = edit.add_node(Node::Write {
+                            collect: cons,
+                            data: param,
+                            indices: idx.clone().into_boxed_slice(),
+                        });
+                    });
+                    cons
+                };
+                for id in ids {
+                    edit = edit.replace_all_uses(id, built)?;
+                    edit = edit.delete_node(id)?;
                 }
-            })
-            .collect::<Vec<_>>();
+            }
+
+            edit.set_param_types(new_param_types);
+            Ok(edit)
+        });
+        assert!(success, "IP SROA expects to be able to edit everything, specify what functions to IP SROA via the func_selection argument");
+
+        // Modify each return in the current function and the return types.
+        let return_nodes: Vec<_> = editor
+            .node_ids()
+            .filter(|id| editor.func().nodes[id.idx()].is_return())
+            .collect();
         let success = editor.edit(|mut edit| {
             for node in return_nodes {
                 let Node::Return { control, data } = edit.get_node(node) else {
@@ -114,17 +159,15 @@ pub fn interprocedural_sroa(
             }
 
             edit.set_return_types(new_return_types);
-
             Ok(edit)
         });
         assert!(success, "IP SROA expects to be able to edit everything, specify what functions to IP SROA via the func_selection argument");
 
-        // Finally, update calls of this function
-        // In particular, we actually don't have to update the call node at all but have to update
-        // its DataProjection users
+        // Finally, update calls of this function.
         for (caller, callsite) in callsites {
             let editor = &mut editors[caller.idx()];
             assert!(editor.func_id() == caller);
+
             let projs = editor.get_users(callsite).collect::<Vec<_>>();
             for proj_id in projs {
                 let Node::DataProjection { data: _, selection } = editor.node(proj_id) else {
@@ -134,6 +177,40 @@ pub fn interprocedural_sroa(
                 let typ = types[caller.idx()][proj_id.idx()];
                 replace_returned_value(editor, proj_id, typ, new_return_info, callsite);
             }
+
+            let (control, callee, dc_args, args) =
+                editor.func().nodes[callsite.idx()].try_call().unwrap();
+            let dc_args = dc_args.clone();
+            let args = args.clone();
+            let success = editor.edit(|mut edit| {
+                let mut new_args = vec![];
+                for (idx, (data_id, update_info)) in
+                    args.iter().zip(old_param_type_map.iter()).enumerate()
+                {
+                    if let IndexTree::Leaf(new_idx) = update_info {
+                        // Unchanged parameter value
+                        assert!(new_args.len() == *new_idx);
+                        new_args.push(*data_id);
+                    } else {
+                        // SROA'd parameter value
+                        let reads = generate_reads_edit(&mut edit, param_types[idx], *data_id);
+                        reads.zip(update_info).for_each(|_, (read_id, ret_idx)| {
+                            assert!(new_args.len() == **ret_idx);
+                            new_args.push(*read_id);
+                        });
+                    }
+                }
+                let new_call = edit.add_node(Node::Call {
+                    control,
+                    function: callee,
+                    dynamic_constants: dc_args,
+                    args: new_args.into_boxed_slice(),
+                });
+                edit = edit.replace_all_uses(callsite, new_call)?;
+                edit = edit.delete_node(callsite)?;
+                Ok(edit)
+            });
+            assert!(success);
         }
     }
 }
-- 
GitLab