From 6114f5fe4e8b9bf40e2cce6abb6be8f978d5ee7a Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Tue, 26 Nov 2024 12:18:16 -0600
Subject: [PATCH] CCP using editor

---
 hercules_opt/src/ccp.rs  | 289 ++++++++++++++++-----------------------
 hercules_opt/src/lib.rs  |   4 +-
 hercules_opt/src/pass.rs |  27 +++-
 3 files changed, 146 insertions(+), 174 deletions(-)

diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs
index 57559079..dba29f23 100644
--- a/hercules_opt/src/ccp.rs
+++ b/hercules_opt/src/ccp.rs
@@ -1,12 +1,14 @@
 extern crate hercules_ir;
 
-use std::collections::HashMap;
+use std::collections::HashSet;
 use std::iter::zip;
 
 use self::hercules_ir::dataflow::*;
 use self::hercules_ir::def_use::*;
 use self::hercules_ir::ir::*;
 
+use crate::*;
+
 /*
  * The ccp lattice tracks, for each node, the following information:
  * 1. Reachability - is it possible for this node to be reached during any
@@ -160,173 +162,138 @@ macro_rules! binary_float_intrinsic {
 /*
  * Top level function to run conditional constant propagation.
  */
-pub fn ccp(
-    function: &mut Function,
-    constants: &mut Vec<Constant>,
-    def_use: &ImmutableDefUseMap,
-    reverse_postorder: &Vec<NodeID>,
-) {
+pub fn ccp(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>) {
     // Step 1: run ccp analysis to understand the function.
-    let result = dataflow_global(&function, reverse_postorder, |inputs, node_id| {
-        ccp_flow_function(inputs, node_id, &function, &constants)
+    let result = dataflow_global(editor.func(), reverse_postorder, |inputs, node_id| {
+        ccp_flow_function(inputs, node_id, editor)
     });
 
-    // Step 2: update uses of constants. Any node that doesn't produce a
-    // constant value, but does use a newly found constant value, needs to be
-    // updated to use the newly found constant.
-
-    // Step 2.1: assemble reverse constant map. We created a bunch of constants
-    // during the analysis, so we need to intern them.
-    let mut reverse_constant_map: HashMap<Constant, ConstantID> = constants
-        .iter()
-        .enumerate()
-        .map(|(idx, cons)| (cons.clone(), ConstantID::new(idx)))
-        .collect();
-
-    // Helper function for interning constants in the lattice.
-    let mut get_constant_id = |cons| {
-        if let Some(id) = reverse_constant_map.get(&cons) {
-            *id
-        } else {
-            let id = ConstantID::new(reverse_constant_map.len());
-            reverse_constant_map.insert(cons.clone(), id);
-            id
+    // Step 2: propagate constants. For each node that was found to have a constant value, we
+    // create a node for that constant value, replace uses of the original node with the constant,
+    // and finally delete the original node
+    let mut unreachable: HashSet<NodeID> = HashSet::new();
+    for (idx, res) in result.into_iter().enumerate() {
+        let old_id = NodeID::new(idx);
+        if let Some(cons) = res.get_constant() {
+            assert!(!editor.func().nodes[idx].is_control());
+            editor.edit(|mut edit| {
+                // Get the ConstantID of this constant value.
+                let cons_id = edit.add_constant(cons);
+                // Add a constant IR node for this constant
+                let cons_node = edit.add_node(Node::Constant { id: cons_id });
+                // Replace the original node with the constant node
+                edit = edit.replace_all_uses(old_id, cons_node)?;
+                edit.delete_node(old_id)
+            });
         }
-    };
-
-    // Step 2.2: for every node, update uses of now constant nodes. We need to
-    // separately create constant nodes, since we are mutably looping over the
-    // function nodes separately.
-    let mut new_constant_nodes = vec![];
-    let base_cons_node_idx = function.nodes.len();
-    for node in function.nodes.iter_mut() {
-        for u in get_uses_mut(node).as_mut() {
-            let old_id = **u;
-            if let Some(cons) = result[old_id.idx()].get_constant() {
-                // Get ConstantID for this constant.
-                let cons_id = get_constant_id(cons);
-
-                // Search new_constant_nodes for a constant IR node that already
-                // referenced this ConstantID.
-                if let Some(new_nodes_idx) = new_constant_nodes
-                    .iter()
-                    .enumerate()
-                    .filter(|(_, id)| **id == cons_id)
-                    .map(|(idx, _)| idx)
-                    .next()
-                {
-                    // If there is already a constant IR node, calculate what
-                    // the NodeID will be for it, and set the use to that ID.
-                    **u = NodeID::new(base_cons_node_idx + new_nodes_idx);
-                } else {
-                    // If there is not already a constant IR node for this
-                    // ConstantID, add this ConstantID to the new_constant_nodes
-                    // list. Set the use to the corresponding NodeID for the new
-                    // constant IR node.
-                    let cons_node_id = NodeID::new(base_cons_node_idx + new_constant_nodes.len());
-                    new_constant_nodes.push(cons_id);
-                    **u = cons_node_id;
-                }
-            }
+        if !res.is_reachable() {
+            unreachable.insert(old_id);
         }
     }
 
-    // Step 2.3: add new constant nodes into nodes of function.
-    for node in new_constant_nodes {
-        function.nodes.push(Node::Constant { id: node });
-    }
-
-    // Step 2.4: re-create module's constants vector from interning map.
-    *constants = vec![Constant::Boolean(false); reverse_constant_map.len()];
-    for (cons, id) in reverse_constant_map {
-        constants[id.idx()] = cons;
-    }
-
     // Step 3: delete dead branches. Any nodes that are unreachable should be
     // deleted. Any if or match nodes that now are light on users need to be
     // removed immediately, since if and match nodes have requirements on the
     // number of users.
 
-    // Step 3.1: delete unreachable nodes. Loop over the length of the dataflow
-    // result instead of the function's node list, since in step 2, constant
-    // nodes were added that don't have a corresponding lattice result.
-    for idx in 0..result.len() {
-        if !result[idx].is_reachable() {
-            function.nodes[idx] = Node::Start;
+    // Step 3.1: remove uses of data nodes in phi nodes corresponding to
+    // unreachable uses in corresponding region nodes.
+    for phi_id in (0..editor.func().nodes.len()).map(NodeID::new) {
+        if unreachable.contains(&phi_id) {
+            continue;
         }
-    }
 
-    // Step 3.2: remove uses of data nodes in phi nodes corresponding to
-    // unreachable uses in corresponding region nodes.
-    for phi_id in (0..result.len()).map(NodeID::new) {
-        if let Node::Phi { control, data } = &function.nodes[phi_id.idx()] {
-            if let Node::Region { preds } = &function.nodes[control.idx()] {
+        if let Node::Phi { control, data } = &editor.func().nodes[phi_id.idx()] {
+            if let Node::Region { preds } = &editor.func().nodes[control.idx()] {
+                let control = *control;
                 let new_data = zip(preds.iter(), data.iter())
-                    .filter(|(pred, _)| result[pred.idx()].is_reachable())
+                    .filter(|(pred, _)| !unreachable.contains(pred))
                     .map(|(_, datum)| *datum)
                     .collect();
-                function.nodes[phi_id.idx()] = Node::Phi {
-                    control: *control,
-                    data: new_data,
-                };
+                editor.edit(|mut edit| {
+                    let new_node = edit.add_node(Node::Phi {
+                        control,
+                        data: new_data,
+                    });
+                    edit = edit.replace_all_uses(phi_id, new_node)?;
+                    edit.delete_node(phi_id)
+                });
             }
         }
     }
 
-    // Step 3.3: remove uses of unreachable nodes in region nodes.
-    for node in function.nodes.iter_mut() {
-        if let Node::Region { preds } = node {
-            *preds = preds
+    // Step 3.2: remove uses of unreachable nodes in region nodes.
+    for region_id in (0..editor.func().nodes.len()).map(NodeID::new) {
+        if unreachable.contains(&region_id) {
+            continue;
+        }
+
+        if let Node::Region { preds } = &editor.func().nodes[region_id.idx()] {
+            let new_preds = preds
                 .iter()
-                .filter(|pred| result[pred.idx()].is_reachable())
+                .filter(|pred| !unreachable.contains(pred))
                 .map(|x| *x)
                 .collect();
+            editor.edit(|mut edit| {
+                let new_node = edit.add_node(Node::Region { preds: new_preds });
+                edit = edit.replace_all_uses(region_id, new_node)?;
+                edit.delete_node(region_id)
+            });
         }
     }
 
-    // Step 3.4: remove if and match nodes with one reachable user.
-    for branch_id in (0..result.len()).map(NodeID::new) {
+    // Step 3.3: remove if and match nodes with one reachable user.
+    for branch_id in (0..editor.func().nodes.len()).map(NodeID::new) {
+        if unreachable.contains(&branch_id) {
+            continue;
+        }
+
         if let Node::If { control, cond: _ } | Node::Match { control, sum: _ } =
-            function.nodes[branch_id.idx()].clone()
+            &editor.func().nodes[branch_id.idx()]
         {
-            let users = def_use.get_users(branch_id);
-            let mut reachable_users = users
+            let control = *control;
+
+            let users = editor.get_users(branch_id).collect::<Vec<NodeID>>();
+            let reachable_users = users
                 .iter()
-                .filter(|user| result[user.idx()].is_reachable());
-            let the_reachable_user = reachable_users
-                .next()
-                .expect("During CCP, found a branch with no reachable users.");
+                .map(|x| *x)
+                .filter(|user| !unreachable.contains(user))
+                .collect::<Vec<NodeID>>();
+            let the_reachable_user = reachable_users[0];
 
             // The reachable users iterator will contain one user if we need to
             // remove this branch node.
-            if let None = reachable_users.next() {
+            if reachable_users.len() == 1 {
                 // The user is a Read node, which in turn has one user.
                 assert!(
-                    def_use.get_users(*the_reachable_user).len() == 1,
+                    editor.get_users(the_reachable_user).len() == 1,
                     "Control Read node doesn't have exactly one user."
                 );
-                let target = def_use.get_users(*the_reachable_user)[0];
 
-                // For each use in the target of the reachable Read, turn it
-                // into a use of the node proceeding this branch node.
-                for u in get_uses_mut(&mut function.nodes[target.idx()]).as_mut() {
-                    if **u == *the_reachable_user {
-                        **u = control;
+                editor.edit(|mut edit| {
+                    // Replace all uses of the single reachable user with the node preceeding the
+                    // branch node
+                    edit = edit.replace_all_uses(the_reachable_user, control)?;
+                    // Delete all users and the branch node
+                    for user in users {
+                        edit = edit.delete_node(user)?;
                     }
-                }
-
-                // Remove this branch node, since it is malformed. Also remove
-                // all successor Read nodes.
-                function.nodes[branch_id.idx()] = Node::Start;
-                for user in users {
-                    function.nodes[user.idx()] = Node::Start;
-                }
+                    edit.delete_node(branch_id)
+                });
             }
         }
     }
 
+    // Step 3.4: delete unreachable nodes.
+    editor.edit(|mut edit| {
+        for node in unreachable {
+            edit = edit.delete_node(node)?;
+        }
+        Ok(edit)
+    });
+
     // Step 4: collapse region chains.
-    collapse_region_chains(function, def_use);
+    collapse_region_chains(editor);
 }
 
 /*
@@ -335,47 +302,39 @@ pub fn ccp(
  * deleted. The use of the head of the chain can turn into the use by the user
  * of the tail of the chain.
  */
-pub fn collapse_region_chains(function: &mut Function, def_use: &ImmutableDefUseMap) {
+pub fn collapse_region_chains(editor: &mut FunctionEditor) {
+    let num_nodes = editor.func().nodes.len();
     // Loop over all region nodes. It's fine to modify the function as we loop
     // over it.
-    for id in (0..function.nodes.len()).map(NodeID::new) {
-        if let Node::Region { preds } = &function.nodes[id.idx()] {
-            let has_call_user = def_use
+    for id in (0..num_nodes).map(NodeID::new) {
+        if let Node::Region { preds } = &editor.func().nodes[id.idx()] {
+            let has_call_user = editor
                 .get_users(id)
-                .iter()
-                .any(|x| function.nodes[x.idx()].is_call());
+                .any(|x| editor.func().nodes[x.idx()].is_call());
 
             if preds.len() == 1 && !has_call_user {
                 // Step 1: bridge gap between use and user.
                 let predecessor = preds[0];
-                let successor = def_use
+                let successor = editor
                     .get_users(id)
-                    .iter()
-                    .filter(|x| !function.nodes[x.idx()].is_phi())
+                    .filter(|x| !editor.func().nodes[x.idx()].is_phi())
                     .next()
                     .expect("Region node doesn't have a non-phi user.");
 
-                // Set successor's use of this region to use the region's use.
-                for u in get_uses_mut(&mut function.nodes[successor.idx()]).as_mut() {
-                    if **u == id {
-                        **u = predecessor;
-                    }
-                }
-
-                // Delete this region.
-                function.nodes[id.idx()] = Node::Start;
+                editor.edit(|edit| {
+                    // Set successor's use of this region to use the region's use.
+                    edit.replace_all_uses_where(id, predecessor, |n| *n == successor)
+                });
 
                 // Step 2: bridge gap between uses and users of corresponding
                 // phi nodes.
-                let phis: Vec<NodeID> = def_use
+                let phis: Vec<NodeID> = editor
                     .get_users(id)
-                    .iter()
-                    .map(|x| *x)
-                    .filter(|x| function.nodes[x.idx()].is_phi())
+                    .filter(|x| editor.func().nodes[x.idx()].is_phi())
                     .collect();
                 for phi_id in phis {
                     let data_uses =
-                        if let Node::Phi { control, data } = &function.nodes[phi_id.idx()] {
+                        if let Node::Phi { control, data } = &editor.func().nodes[phi_id.idx()] {
                             assert!(*control == id);
                             data
                         } else {
@@ -384,18 +343,16 @@ pub fn collapse_region_chains(function: &mut Function, def_use: &ImmutableDefUse
                     assert!(data_uses.len() == 1, "Phi node doesn't have exactly one data use, while corresponding region had exactly one control use.");
                     let predecessor = data_uses[0];
 
-                    // Set successors' use of this phi to use the phi's use.
-                    for successor in def_use.get_users(phi_id) {
-                        for u in get_uses_mut(&mut function.nodes[successor.idx()]).as_mut() {
-                            if **u == phi_id {
-                                **u = predecessor;
-                            }
-                        }
-                    }
-
-                    // Delete this phi.
-                    function.nodes[phi_id.idx()] = Node::Start;
+                    editor.edit(|mut edit| {
+                        // Set successors' use of this phi to use the phi's use.
+                        edit = edit.replace_all_uses(phi_id, predecessor)?;
+                        // Delete this phi.
+                        edit.delete_node(phi_id)
+                    });
                 }
+
+                // Delete this region.
+                editor.edit(|edit| edit.delete_node(id));
             }
         }
     }
@@ -404,10 +361,9 @@ pub fn collapse_region_chains(function: &mut Function, def_use: &ImmutableDefUse
 fn ccp_flow_function(
     inputs: &[CCPLattice],
     node_id: NodeID,
-    function: &Function,
-    old_constants: &Vec<Constant>,
+    editor: &FunctionEditor,
 ) -> CCPLattice {
-    let node = &function.nodes[node_id.idx()];
+    let node = &editor.func().nodes[node_id.idx()];
     match node {
         Node::Start => CCPLattice::bottom(),
         Node::Region { preds } => preds.iter().fold(CCPLattice::top(), |val, id| {
@@ -427,7 +383,7 @@ fn ccp_flow_function(
         // output.
         Node::Phi { control, data } => {
             // Get the control predecessors of the corresponding region.
-            let region_preds = if let Node::Region { preds } = &function.nodes[control.idx()] {
+            let region_preds = if let Node::Region { preds } = &editor.func().nodes[control.idx()] {
                 preds
             } else {
                 panic!("A phi's control input must be a region node.")
@@ -461,15 +417,12 @@ fn ccp_flow_function(
             init: _,
             reduct: _,
         } => inputs[control.idx()].clone(),
-        Node::Return { control, data } => CCPLattice {
-            reachability: inputs[control.idx()].reachability.clone(),
-            constant: inputs[data.idx()].constant.clone(),
-        },
+        Node::Return { control, data } => inputs[control.idx()].clone(),
         Node::Parameter { index: _ } => CCPLattice::bottom(),
         // A constant node is the "source" of concrete constant lattice values.
         Node::Constant { id } => CCPLattice {
             reachability: ReachabilityLattice::bottom(),
-            constant: ConstantLattice::Constant(old_constants[id.idx()].clone()),
+            constant: ConstantLattice::Constant(editor.get_constant(*id).clone()),
         },
         // TODO: This should really be constant interpreted, since dynamic
         // constants as values are used frequently.
@@ -886,7 +839,7 @@ fn ccp_flow_function(
             constant: ConstantLattice::bottom(),
         },
         // Projection handles reachability when following an if or match.
-        Node::Projection { control, selection } => match &function.nodes[control.idx()] {
+        Node::Projection { control, selection } => match &editor.func().nodes[control.idx()] {
             Node::If { control: _, cond } => {
                 let cond_constant = &inputs[cond.idx()].constant;
                 let if_reachability = &inputs[control.idx()].reachability;
diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs
index dbd66012..87a894ed 100644
--- a/hercules_opt/src/lib.rs
+++ b/hercules_opt/src/lib.rs
@@ -6,8 +6,8 @@ pub mod editor;
 pub mod fork_guard_elim;
 pub mod forkify;
 pub mod gvn;
-pub mod interprocedural_sroa;
 pub mod inline;
+pub mod interprocedural_sroa;
 pub mod pass;
 pub mod phi_elim;
 pub mod pred;
@@ -19,8 +19,8 @@ pub use crate::editor::*;
 pub use crate::fork_guard_elim::*;
 pub use crate::forkify::*;
 pub use crate::gvn::*;
-pub use crate::interprocedural_sroa::*;
 pub use crate::inline::*;
+pub use crate::interprocedural_sroa::*;
 pub use crate::pass::*;
 pub use crate::phi_elim::*;
 pub use crate::pred::*;
diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs
index f69958e2..177ce0c5 100644
--- a/hercules_opt/src/pass.rs
+++ b/hercules_opt/src/pass.rs
@@ -427,14 +427,33 @@ impl PassManager {
                     let def_uses = self.def_uses.as_ref().unwrap();
                     let reverse_postorders = self.reverse_postorders.as_ref().unwrap();
                     for idx in 0..self.module.functions.len() {
-                        ccp(
+                        let constants_ref =
+                            RefCell::new(std::mem::take(&mut self.module.constants));
+                        let dynamic_constants_ref =
+                            RefCell::new(std::mem::take(&mut self.module.dynamic_constants));
+                        let types_ref = RefCell::new(std::mem::take(&mut self.module.types));
+                        let mut editor = FunctionEditor::new(
                             &mut self.module.functions[idx],
-                            &mut self.module.constants,
+                            &constants_ref,
+                            &dynamic_constants_ref,
+                            &types_ref,
                             &def_uses[idx],
-                            &reverse_postorders[idx],
                         );
+                        ccp(&mut editor, &reverse_postorders[idx]);
+
+                        self.module.constants = constants_ref.take();
+                        self.module.dynamic_constants = dynamic_constants_ref.take();
+                        self.module.types = types_ref.take();
+
+                        let edits = &editor.edits();
+                        if let Some(plans) = self.plans.as_mut() {
+                            repair_plan(&mut plans[idx], &self.module.functions[idx], edits);
+                        }
+                        let grave_mapping = self.module.functions[idx].delete_gravestones();
+                        if let Some(plans) = self.plans.as_mut() {
+                            plans[idx].fix_gravestones(&grave_mapping);
+                        }
                     }
-                    self.legacy_repair_plan();
                     self.clear_analyses();
                 }
                 Pass::GVN => {
-- 
GitLab