From 6114f5fe4e8b9bf40e2cce6abb6be8f978d5ee7a Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Tue, 26 Nov 2024 12:18:16 -0600 Subject: [PATCH] CCP using editor --- hercules_opt/src/ccp.rs | 289 ++++++++++++++++----------------------- hercules_opt/src/lib.rs | 4 +- hercules_opt/src/pass.rs | 27 +++- 3 files changed, 146 insertions(+), 174 deletions(-) diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs index 57559079..dba29f23 100644 --- a/hercules_opt/src/ccp.rs +++ b/hercules_opt/src/ccp.rs @@ -1,12 +1,14 @@ extern crate hercules_ir; -use std::collections::HashMap; +use std::collections::HashSet; use std::iter::zip; use self::hercules_ir::dataflow::*; use self::hercules_ir::def_use::*; use self::hercules_ir::ir::*; +use crate::*; + /* * The ccp lattice tracks, for each node, the following information: * 1. Reachability - is it possible for this node to be reached during any @@ -160,173 +162,138 @@ macro_rules! binary_float_intrinsic { /* * Top level function to run conditional constant propagation. */ -pub fn ccp( - function: &mut Function, - constants: &mut Vec<Constant>, - def_use: &ImmutableDefUseMap, - reverse_postorder: &Vec<NodeID>, -) { +pub fn ccp(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>) { // Step 1: run ccp analysis to understand the function. - let result = dataflow_global(&function, reverse_postorder, |inputs, node_id| { - ccp_flow_function(inputs, node_id, &function, &constants) + let result = dataflow_global(editor.func(), reverse_postorder, |inputs, node_id| { + ccp_flow_function(inputs, node_id, editor) }); - // Step 2: update uses of constants. Any node that doesn't produce a - // constant value, but does use a newly found constant value, needs to be - // updated to use the newly found constant. - - // Step 2.1: assemble reverse constant map. We created a bunch of constants - // during the analysis, so we need to intern them. - let mut reverse_constant_map: HashMap<Constant, ConstantID> = constants - .iter() - .enumerate() - .map(|(idx, cons)| (cons.clone(), ConstantID::new(idx))) - .collect(); - - // Helper function for interning constants in the lattice. - let mut get_constant_id = |cons| { - if let Some(id) = reverse_constant_map.get(&cons) { - *id - } else { - let id = ConstantID::new(reverse_constant_map.len()); - reverse_constant_map.insert(cons.clone(), id); - id + // Step 2: propagate constants. For each node that was found to have a constant value, we + // create a node for that constant value, replace uses of the original node with the constant, + // and finally delete the original node + let mut unreachable: HashSet<NodeID> = HashSet::new(); + for (idx, res) in result.into_iter().enumerate() { + let old_id = NodeID::new(idx); + if let Some(cons) = res.get_constant() { + assert!(!editor.func().nodes[idx].is_control()); + editor.edit(|mut edit| { + // Get the ConstantID of this constant value. + let cons_id = edit.add_constant(cons); + // Add a constant IR node for this constant + let cons_node = edit.add_node(Node::Constant { id: cons_id }); + // Replace the original node with the constant node + edit = edit.replace_all_uses(old_id, cons_node)?; + edit.delete_node(old_id) + }); } - }; - - // Step 2.2: for every node, update uses of now constant nodes. We need to - // separately create constant nodes, since we are mutably looping over the - // function nodes separately. - let mut new_constant_nodes = vec![]; - let base_cons_node_idx = function.nodes.len(); - for node in function.nodes.iter_mut() { - for u in get_uses_mut(node).as_mut() { - let old_id = **u; - if let Some(cons) = result[old_id.idx()].get_constant() { - // Get ConstantID for this constant. - let cons_id = get_constant_id(cons); - - // Search new_constant_nodes for a constant IR node that already - // referenced this ConstantID. - if let Some(new_nodes_idx) = new_constant_nodes - .iter() - .enumerate() - .filter(|(_, id)| **id == cons_id) - .map(|(idx, _)| idx) - .next() - { - // If there is already a constant IR node, calculate what - // the NodeID will be for it, and set the use to that ID. - **u = NodeID::new(base_cons_node_idx + new_nodes_idx); - } else { - // If there is not already a constant IR node for this - // ConstantID, add this ConstantID to the new_constant_nodes - // list. Set the use to the corresponding NodeID for the new - // constant IR node. - let cons_node_id = NodeID::new(base_cons_node_idx + new_constant_nodes.len()); - new_constant_nodes.push(cons_id); - **u = cons_node_id; - } - } + if !res.is_reachable() { + unreachable.insert(old_id); } } - // Step 2.3: add new constant nodes into nodes of function. - for node in new_constant_nodes { - function.nodes.push(Node::Constant { id: node }); - } - - // Step 2.4: re-create module's constants vector from interning map. - *constants = vec![Constant::Boolean(false); reverse_constant_map.len()]; - for (cons, id) in reverse_constant_map { - constants[id.idx()] = cons; - } - // Step 3: delete dead branches. Any nodes that are unreachable should be // deleted. Any if or match nodes that now are light on users need to be // removed immediately, since if and match nodes have requirements on the // number of users. - // Step 3.1: delete unreachable nodes. Loop over the length of the dataflow - // result instead of the function's node list, since in step 2, constant - // nodes were added that don't have a corresponding lattice result. - for idx in 0..result.len() { - if !result[idx].is_reachable() { - function.nodes[idx] = Node::Start; + // Step 3.1: remove uses of data nodes in phi nodes corresponding to + // unreachable uses in corresponding region nodes. + for phi_id in (0..editor.func().nodes.len()).map(NodeID::new) { + if unreachable.contains(&phi_id) { + continue; } - } - // Step 3.2: remove uses of data nodes in phi nodes corresponding to - // unreachable uses in corresponding region nodes. - for phi_id in (0..result.len()).map(NodeID::new) { - if let Node::Phi { control, data } = &function.nodes[phi_id.idx()] { - if let Node::Region { preds } = &function.nodes[control.idx()] { + if let Node::Phi { control, data } = &editor.func().nodes[phi_id.idx()] { + if let Node::Region { preds } = &editor.func().nodes[control.idx()] { + let control = *control; let new_data = zip(preds.iter(), data.iter()) - .filter(|(pred, _)| result[pred.idx()].is_reachable()) + .filter(|(pred, _)| !unreachable.contains(pred)) .map(|(_, datum)| *datum) .collect(); - function.nodes[phi_id.idx()] = Node::Phi { - control: *control, - data: new_data, - }; + editor.edit(|mut edit| { + let new_node = edit.add_node(Node::Phi { + control, + data: new_data, + }); + edit = edit.replace_all_uses(phi_id, new_node)?; + edit.delete_node(phi_id) + }); } } } - // Step 3.3: remove uses of unreachable nodes in region nodes. - for node in function.nodes.iter_mut() { - if let Node::Region { preds } = node { - *preds = preds + // Step 3.2: remove uses of unreachable nodes in region nodes. + for region_id in (0..editor.func().nodes.len()).map(NodeID::new) { + if unreachable.contains(®ion_id) { + continue; + } + + if let Node::Region { preds } = &editor.func().nodes[region_id.idx()] { + let new_preds = preds .iter() - .filter(|pred| result[pred.idx()].is_reachable()) + .filter(|pred| !unreachable.contains(pred)) .map(|x| *x) .collect(); + editor.edit(|mut edit| { + let new_node = edit.add_node(Node::Region { preds: new_preds }); + edit = edit.replace_all_uses(region_id, new_node)?; + edit.delete_node(region_id) + }); } } - // Step 3.4: remove if and match nodes with one reachable user. - for branch_id in (0..result.len()).map(NodeID::new) { + // Step 3.3: remove if and match nodes with one reachable user. + for branch_id in (0..editor.func().nodes.len()).map(NodeID::new) { + if unreachable.contains(&branch_id) { + continue; + } + if let Node::If { control, cond: _ } | Node::Match { control, sum: _ } = - function.nodes[branch_id.idx()].clone() + &editor.func().nodes[branch_id.idx()] { - let users = def_use.get_users(branch_id); - let mut reachable_users = users + let control = *control; + + let users = editor.get_users(branch_id).collect::<Vec<NodeID>>(); + let reachable_users = users .iter() - .filter(|user| result[user.idx()].is_reachable()); - let the_reachable_user = reachable_users - .next() - .expect("During CCP, found a branch with no reachable users."); + .map(|x| *x) + .filter(|user| !unreachable.contains(user)) + .collect::<Vec<NodeID>>(); + let the_reachable_user = reachable_users[0]; // The reachable users iterator will contain one user if we need to // remove this branch node. - if let None = reachable_users.next() { + if reachable_users.len() == 1 { // The user is a Read node, which in turn has one user. assert!( - def_use.get_users(*the_reachable_user).len() == 1, + editor.get_users(the_reachable_user).len() == 1, "Control Read node doesn't have exactly one user." ); - let target = def_use.get_users(*the_reachable_user)[0]; - // For each use in the target of the reachable Read, turn it - // into a use of the node proceeding this branch node. - for u in get_uses_mut(&mut function.nodes[target.idx()]).as_mut() { - if **u == *the_reachable_user { - **u = control; + editor.edit(|mut edit| { + // Replace all uses of the single reachable user with the node preceeding the + // branch node + edit = edit.replace_all_uses(the_reachable_user, control)?; + // Delete all users and the branch node + for user in users { + edit = edit.delete_node(user)?; } - } - - // Remove this branch node, since it is malformed. Also remove - // all successor Read nodes. - function.nodes[branch_id.idx()] = Node::Start; - for user in users { - function.nodes[user.idx()] = Node::Start; - } + edit.delete_node(branch_id) + }); } } } + // Step 3.4: delete unreachable nodes. + editor.edit(|mut edit| { + for node in unreachable { + edit = edit.delete_node(node)?; + } + Ok(edit) + }); + // Step 4: collapse region chains. - collapse_region_chains(function, def_use); + collapse_region_chains(editor); } /* @@ -335,47 +302,39 @@ pub fn ccp( * deleted. The use of the head of the chain can turn into the use by the user * of the tail of the chain. */ -pub fn collapse_region_chains(function: &mut Function, def_use: &ImmutableDefUseMap) { +pub fn collapse_region_chains(editor: &mut FunctionEditor) { + let num_nodes = editor.func().nodes.len(); // Loop over all region nodes. It's fine to modify the function as we loop // over it. - for id in (0..function.nodes.len()).map(NodeID::new) { - if let Node::Region { preds } = &function.nodes[id.idx()] { - let has_call_user = def_use + for id in (0..num_nodes).map(NodeID::new) { + if let Node::Region { preds } = &editor.func().nodes[id.idx()] { + let has_call_user = editor .get_users(id) - .iter() - .any(|x| function.nodes[x.idx()].is_call()); + .any(|x| editor.func().nodes[x.idx()].is_call()); if preds.len() == 1 && !has_call_user { // Step 1: bridge gap between use and user. let predecessor = preds[0]; - let successor = def_use + let successor = editor .get_users(id) - .iter() - .filter(|x| !function.nodes[x.idx()].is_phi()) + .filter(|x| !editor.func().nodes[x.idx()].is_phi()) .next() .expect("Region node doesn't have a non-phi user."); - // Set successor's use of this region to use the region's use. - for u in get_uses_mut(&mut function.nodes[successor.idx()]).as_mut() { - if **u == id { - **u = predecessor; - } - } - - // Delete this region. - function.nodes[id.idx()] = Node::Start; + editor.edit(|edit| { + // Set successor's use of this region to use the region's use. + edit.replace_all_uses_where(id, predecessor, |n| *n == successor) + }); // Step 2: bridge gap between uses and users of corresponding // phi nodes. - let phis: Vec<NodeID> = def_use + let phis: Vec<NodeID> = editor .get_users(id) - .iter() - .map(|x| *x) - .filter(|x| function.nodes[x.idx()].is_phi()) + .filter(|x| editor.func().nodes[x.idx()].is_phi()) .collect(); for phi_id in phis { let data_uses = - if let Node::Phi { control, data } = &function.nodes[phi_id.idx()] { + if let Node::Phi { control, data } = &editor.func().nodes[phi_id.idx()] { assert!(*control == id); data } else { @@ -384,18 +343,16 @@ pub fn collapse_region_chains(function: &mut Function, def_use: &ImmutableDefUse assert!(data_uses.len() == 1, "Phi node doesn't have exactly one data use, while corresponding region had exactly one control use."); let predecessor = data_uses[0]; - // Set successors' use of this phi to use the phi's use. - for successor in def_use.get_users(phi_id) { - for u in get_uses_mut(&mut function.nodes[successor.idx()]).as_mut() { - if **u == phi_id { - **u = predecessor; - } - } - } - - // Delete this phi. - function.nodes[phi_id.idx()] = Node::Start; + editor.edit(|mut edit| { + // Set successors' use of this phi to use the phi's use. + edit = edit.replace_all_uses(phi_id, predecessor)?; + // Delete this phi. + edit.delete_node(phi_id) + }); } + + // Delete this region. + editor.edit(|edit| edit.delete_node(id)); } } } @@ -404,10 +361,9 @@ pub fn collapse_region_chains(function: &mut Function, def_use: &ImmutableDefUse fn ccp_flow_function( inputs: &[CCPLattice], node_id: NodeID, - function: &Function, - old_constants: &Vec<Constant>, + editor: &FunctionEditor, ) -> CCPLattice { - let node = &function.nodes[node_id.idx()]; + let node = &editor.func().nodes[node_id.idx()]; match node { Node::Start => CCPLattice::bottom(), Node::Region { preds } => preds.iter().fold(CCPLattice::top(), |val, id| { @@ -427,7 +383,7 @@ fn ccp_flow_function( // output. Node::Phi { control, data } => { // Get the control predecessors of the corresponding region. - let region_preds = if let Node::Region { preds } = &function.nodes[control.idx()] { + let region_preds = if let Node::Region { preds } = &editor.func().nodes[control.idx()] { preds } else { panic!("A phi's control input must be a region node.") @@ -461,15 +417,12 @@ fn ccp_flow_function( init: _, reduct: _, } => inputs[control.idx()].clone(), - Node::Return { control, data } => CCPLattice { - reachability: inputs[control.idx()].reachability.clone(), - constant: inputs[data.idx()].constant.clone(), - }, + Node::Return { control, data } => inputs[control.idx()].clone(), Node::Parameter { index: _ } => CCPLattice::bottom(), // A constant node is the "source" of concrete constant lattice values. Node::Constant { id } => CCPLattice { reachability: ReachabilityLattice::bottom(), - constant: ConstantLattice::Constant(old_constants[id.idx()].clone()), + constant: ConstantLattice::Constant(editor.get_constant(*id).clone()), }, // TODO: This should really be constant interpreted, since dynamic // constants as values are used frequently. @@ -886,7 +839,7 @@ fn ccp_flow_function( constant: ConstantLattice::bottom(), }, // Projection handles reachability when following an if or match. - Node::Projection { control, selection } => match &function.nodes[control.idx()] { + Node::Projection { control, selection } => match &editor.func().nodes[control.idx()] { Node::If { control: _, cond } => { let cond_constant = &inputs[cond.idx()].constant; let if_reachability = &inputs[control.idx()].reachability; diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs index dbd66012..87a894ed 100644 --- a/hercules_opt/src/lib.rs +++ b/hercules_opt/src/lib.rs @@ -6,8 +6,8 @@ pub mod editor; pub mod fork_guard_elim; pub mod forkify; pub mod gvn; -pub mod interprocedural_sroa; pub mod inline; +pub mod interprocedural_sroa; pub mod pass; pub mod phi_elim; pub mod pred; @@ -19,8 +19,8 @@ pub use crate::editor::*; pub use crate::fork_guard_elim::*; pub use crate::forkify::*; pub use crate::gvn::*; -pub use crate::interprocedural_sroa::*; pub use crate::inline::*; +pub use crate::interprocedural_sroa::*; pub use crate::pass::*; pub use crate::phi_elim::*; pub use crate::pred::*; diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs index f69958e2..177ce0c5 100644 --- a/hercules_opt/src/pass.rs +++ b/hercules_opt/src/pass.rs @@ -427,14 +427,33 @@ impl PassManager { let def_uses = self.def_uses.as_ref().unwrap(); let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); for idx in 0..self.module.functions.len() { - ccp( + let constants_ref = + RefCell::new(std::mem::take(&mut self.module.constants)); + let dynamic_constants_ref = + RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); + let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); + let mut editor = FunctionEditor::new( &mut self.module.functions[idx], - &mut self.module.constants, + &constants_ref, + &dynamic_constants_ref, + &types_ref, &def_uses[idx], - &reverse_postorders[idx], ); + ccp(&mut editor, &reverse_postorders[idx]); + + self.module.constants = constants_ref.take(); + self.module.dynamic_constants = dynamic_constants_ref.take(); + self.module.types = types_ref.take(); + + let edits = &editor.edits(); + if let Some(plans) = self.plans.as_mut() { + repair_plan(&mut plans[idx], &self.module.functions[idx], edits); + } + let grave_mapping = self.module.functions[idx].delete_gravestones(); + if let Some(plans) = self.plans.as_mut() { + plans[idx].fix_gravestones(&grave_mapping); + } } - self.legacy_repair_plan(); self.clear_analyses(); } Pass::GVN => { -- GitLab