From db19181bc14821a05828a5e76a0eb68410c22457 Mon Sep 17 00:00:00 2001 From: rarbore2 <rarbore2@illinois.edu> Date: Tue, 14 Jan 2025 23:28:22 -0600 Subject: [PATCH] Re-write Predication to use FunctionEditor --- hercules_opt/src/pass.rs | 60 +++- hercules_opt/src/pred.rs | 492 ++++++++++++++------------ juno_frontend/src/lib.rs | 9 + juno_samples/antideps/src/antideps.jn | 8 +- 4 files changed, 323 insertions(+), 246 deletions(-) diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs index 12444b36..d24b6563 100644 --- a/hercules_opt/src/pass.rs +++ b/hercules_opt/src/pass.rs @@ -25,6 +25,7 @@ pub enum Pass { PhiElim, Forkify, ForkGuardElim, + WritePredication, Predication, SROA, Inline, @@ -469,27 +470,58 @@ impl PassManager { } self.clear_analyses(); } + Pass::WritePredication => { + self.make_def_uses(); + let def_uses = self.def_uses.as_ref().unwrap(); + for idx in 0..self.module.functions.len() { + let constants_ref = + RefCell::new(std::mem::take(&mut self.module.constants)); + let dynamic_constants_ref = + RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); + let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); + let mut editor = FunctionEditor::new( + &mut self.module.functions[idx], + FunctionID::new(idx), + &constants_ref, + &dynamic_constants_ref, + &types_ref, + &def_uses[idx], + ); + write_predication(&mut editor); + + self.module.constants = constants_ref.take(); + self.module.dynamic_constants = dynamic_constants_ref.take(); + self.module.types = types_ref.take(); + + self.module.functions[idx].delete_gravestones(); + } + self.clear_analyses(); + } Pass::Predication => { self.make_def_uses(); - self.make_reverse_postorders(); - self.make_doms(); - self.make_fork_join_maps(); + self.make_typing(); let def_uses = self.def_uses.as_ref().unwrap(); - let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); - let doms = self.doms.as_ref().unwrap(); - let fork_join_maps = 
self.fork_join_maps.as_ref().unwrap(); + let typing = self.typing.as_ref().unwrap(); for idx in 0..self.module.functions.len() { - predication( + let constants_ref = + RefCell::new(std::mem::take(&mut self.module.constants)); + let dynamic_constants_ref = + RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); + let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); + let mut editor = FunctionEditor::new( &mut self.module.functions[idx], + FunctionID::new(idx), + &constants_ref, + &dynamic_constants_ref, + &types_ref, &def_uses[idx], - &reverse_postorders[idx], - &doms[idx], - &fork_join_maps[idx], ); - let num_nodes = self.module.functions[idx].nodes.len(); - self.module.functions[idx] - .schedules - .resize(num_nodes, vec![]); + predication(&mut editor, &typing[idx]); + + self.module.constants = constants_ref.take(); + self.module.dynamic_constants = dynamic_constants_ref.take(); + self.module.types = types_ref.take(); + self.module.functions[idx].delete_gravestones(); } self.clear_analyses(); diff --git a/hercules_opt/src/pred.rs b/hercules_opt/src/pred.rs index be1b4a0b..cfad7d1c 100644 --- a/hercules_opt/src/pred.rs +++ b/hercules_opt/src/pred.rs @@ -1,257 +1,287 @@ -use std::collections::HashMap; -use std::collections::HashSet; -use std::collections::VecDeque; +use std::cmp::{max, min}; +use std::collections::{BTreeMap, BTreeSet}; +use std::iter::zip; -use bitvec::prelude::*; +use itertools::Itertools; -use hercules_ir::def_use::*; -use hercules_ir::dom::*; -use hercules_ir::ir::*; +use hercules_ir::*; + +use crate::*; /* - * Top level function to convert acyclic control flow in vectorized fork-joins - * into predicated data flow. + * Top level function to run predication on a function. Repeatedly looks for + * acyclic control flow that can be converted into dataflow. 
*/ -pub fn predication( - function: &mut Function, - def_use: &ImmutableDefUseMap, - reverse_postorder: &Vec<NodeID>, - dom: &DomTree, - fork_join_map: &HashMap<NodeID, NodeID>, -) { - // Detect forks with vectorize schedules. - let vector_forks: Vec<_> = function - .nodes - .iter() - .enumerate() - //.filter(|(idx, n)| n.is_fork() && schedules[*idx].contains(&Schedule::Vectorize)) - .filter(|(_, n)| n.is_fork()) - .map(|(idx, _)| NodeID::new(idx)) - .collect(); - - // Filter forks that can't actually be vectorized, and yell at the user if - // they're being silly. - let actual_vector_forks: Vec<_> = vector_forks - .into_iter() - .filter_map(|fork_id| { - // Detect cycles in control flow between fork and join. Start at the - // join, and work backwards. - let mut visited = bitvec![u8, Lsb0; 0; function.nodes.len()]; - let join_id = fork_join_map[&fork_id]; - let mut stack = vec![join_id]; - while let Some(pop) = stack.pop() { - // Only detect cycles between fork and join, and don't revisit - // nodes. - if visited[pop.idx()] || function.nodes[pop.idx()].is_fork() { - continue; +pub fn predication(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { + // Remove branches iteratively, since predicating an inside branch may cause + // an outside branch to be available for predication. + let mut bad_branches = BTreeSet::new(); + loop { + // First, look for a branch whose projections all point to the same + // region. These are branches with no internal control flow. + let nodes = &editor.func().nodes; + let Some((region, branch, false_proj, true_proj)) = editor + .node_ids() + .filter_map(|id| { + if let Node::Region { ref preds } = nodes[id.idx()] { + // Look for two projections with the same branch. + let preds = preds.into_iter().filter_map(|id| { + nodes[id.idx()] + .try_proj() + .map(|(branch, selection)| (*id, branch, selection)) + }); + // Index projections by if branch. 
+ let mut pred_map: BTreeMap<NodeID, Vec<(NodeID, usize)>> = BTreeMap::new(); + for (proj, branch, selection) in preds { + if nodes[branch.idx()].is_if() && !bad_branches.contains(&branch) { + pred_map.entry(branch).or_default().push((proj, selection)); + } + } + // Look for an if branch with two projections going into the + // same region. + for (branch, projs) in pred_map { + if projs.len() == 2 { + let way = projs[0].1; + assert_ne!(way, projs[1].1); + return Some((id, branch, projs[way].0, projs[1 - way].0)); + } + } } + None + }) + .next() + else { + break; + }; + let Node::Region { preds } = nodes[region.idx()].clone() else { + panic!() + }; + let Node::If { + control: if_pred, + cond, + } = nodes[branch.idx()] + else { + panic!() + }; + let phis: Vec<_> = editor + .get_users(region) + .filter(|id| nodes[id.idx()].is_phi()) + .collect(); + // Don't predicate branches where one of the phis is a collection. + // Predicating this branch would result in a clone and probably wouldn't + // result in good vector code. + if phis + .iter() + .any(|id| !editor.get_type(typing[id.idx()]).is_primitive()) + { + bad_branches.insert(branch); + continue; + } + let false_pos = preds.iter().position(|id| *id == false_proj).unwrap(); + let true_pos = preds.iter().position(|id| *id == true_proj).unwrap(); - // Filter if there is a cycle, or if there is a nested fork, or - // if there is a match node. We know there is a loop if a node - // dominates one of its predecessors. - let control_uses: Vec<_> = get_uses(&function.nodes[pop.idx()]) - .as_ref() - .iter() - .filter(|id| function.nodes[id.idx()].is_control()) - .map(|x| *x) - .collect(); - if control_uses - .iter() - .any(|pred_id| dom.does_dom(pop, *pred_id)) - || (function.nodes[pop.idx()].is_join() && pop != join_id) - || function.nodes[pop.idx()].is_match() - { - eprintln!( - "WARNING: Vectorize schedule attached to fork that cannot be vectorized." 
- ); - return None; - } + // Second, make all the modifications: + // - Add the select nodes. + // - Replace uses in phis with the select node. + // - Remove the branch from the control flow. + // This leaves the old region in place - if the region has one + // predecessor, it may be removed by CCP. + let success = editor.edit(|mut edit| { + // Replace the branch projection predecessors of the region with the + // predecessor of the branch. + let Node::Region { preds } = edit.get_node(region).clone() else { + panic!() + }; + let mut preds = Vec::from(preds); + preds.remove(max(true_pos, false_pos)); + preds.remove(min(true_pos, false_pos)); + preds.push(if_pred); + let new_region = edit.add_node(Node::Region { + preds: preds.into_boxed_slice(), + }); + edit = edit.replace_all_uses(region, new_region)?; - // Recurse up the control subgraph. - visited.set(pop.idx(), true); - stack.extend(control_uses); + // Replace the corresponding inputs in the phi nodes with select + // nodes selecting over the old inputs to the phis. + for phi in phis { + let Node::Phi { control: _, data } = edit.get_node(phi).clone() else { + panic!() + }; + let mut data = Vec::from(data); + let select = edit.add_node(Node::Ternary { + op: TernaryOperator::Select, + first: cond, + second: data[true_pos], + third: data[false_pos], + }); + data.remove(max(true_pos, false_pos)); + data.remove(min(true_pos, false_pos)); + data.push(select); + let new_phi = edit.add_node(Node::Phi { + control: new_region, + data: data.into_boxed_slice(), + }); + edit = edit.replace_all_uses(phi, new_phi)?; + edit = edit.delete_node(phi)?; } - Some((fork_id, visited)) - }) - .collect(); - - // For each control node, collect which condition values must be true, and - // which condition values must be false to reach that node. Each phi's - // corresponding region will have at least one condition value that differs - // between the predecessors. 
These differing condition values anded together - // form the select condition. - let mut condition_valuations: HashMap<NodeID, (HashSet<NodeID>, HashSet<NodeID>)> = - HashMap::new(); - for (fork_id, control_in_fork_join) in actual_vector_forks.iter() { - // Within a fork-join, there are no condition requirements on the fork. - condition_valuations.insert(*fork_id, (HashSet::new(), HashSet::new())); - - // Iterate the nodes in the fork-join in reverse postorder, top-down. - let local_reverse_postorder = reverse_postorder - .iter() - .filter(|id| control_in_fork_join[id.idx()]); - for control_id in local_reverse_postorder { - match function.nodes[control_id.idx()] { - Node::If { control, cond: _ } | Node::Join { control } => { - condition_valuations - .insert(*control_id, condition_valuations[&control].clone()); - } - // Introduce condition variables into sets, as this is where - // branching occurs. - Node::Projection { - control, - ref selection, - } => { - assert!(*selection < 2); - let mut sets = condition_valuations[&control].clone(); - let condition = function.nodes[control.idx()].try_if().unwrap().1; - if *selection == 0 { - sets.0.insert(condition); - } else { - sets.1.insert(condition); - } - condition_valuations.insert(*control_id, sets); - } - // The only required conditions for a region are those required - // for all predecessors. Thus, the condition sets for a region - // are the intersections of the predecessor condition sets. - Node::Region { ref preds } => { - let (prev_true_set, prev_false_set) = condition_valuations[&preds[0]].clone(); - let int_true_set = preds[1..].iter().fold(prev_true_set, |a, b| { - a.intersection(&condition_valuations[b].0) - .map(|x| *x) - .collect::<HashSet<NodeID>>() - }); - let int_false_set = preds[1..].iter().fold(prev_false_set, |a, b| { - a.intersection(&condition_valuations[b].0) - .map(|x| *x) - .collect::<HashSet<NodeID>>() - }); + // Delete the old control nodes. 
+ edit = edit.delete_node(region)?; + edit = edit.delete_node(branch)?; + edit = edit.delete_node(false_proj)?; + edit = edit.delete_node(true_proj)?; - condition_valuations.insert(*control_id, (int_true_set, int_false_set)); - } - _ => { - panic!() - } - } + Ok(edit) + }); + if !success { + bad_branches.insert(branch); } } +} - // Convert control flow to predicated data flow. - for (fork_id, control_in_fork_join) in actual_vector_forks.into_iter() { - // Worklist of control nodes - traverse control backwards breadth-first. - let mut queue = VecDeque::new(); - let mut visited = bitvec![u8, Lsb0; 0; function.nodes.len()]; - let join_id = fork_join_map[&fork_id]; - queue.push_back(join_id); +/* + * Top level function to run write predication on a function. Repeatedly looks + * for phi nodes where every data input is a write node and each write node has + * matching types of indices. These writes are coalesced into a single write + * using the phi as the `collect` input, and the phi selects over the old + * `collect` inputs to the writes. New phis are added to select over the `data` + * inputs and indices. + */ +pub fn write_predication(editor: &mut FunctionEditor) { + let mut bad_phis = BTreeSet::new(); + loop { + // First, look for phis where every input is a write with the same + // indexing structure. + let nodes = &editor.func().nodes; + let Some((phi, control, writes)) = editor + .node_ids() + .filter_map(|id| { + if let Node::Phi { + control, + ref data, + } = nodes[id.idx()] + && !bad_phis.contains(&id) + // Check that every input is a write - if this weren't true, + // we'd have to insert dummy writes that write something that + // was just read from the array. We could handle this case, but + // it's probably not needed for now. Also check that the phi is + // the only user of the write. + && data.into_iter().all(|id| nodes[id.idx()].is_write() && editor.get_users(*id).count() == 1) + // Check that every write input has equivalent indexing + // structure. 
+ && data + .into_iter() + .filter_map(|id| nodes[id.idx()].try_write()) + .tuple_windows() + .all(|(w1, w2)| indices_structurally_equivalent(w1.2, w2.2)) + { + Some((id, control, data.clone())) + } else { + None + } + }) + .next() + else { + break; + }; + let (collects, datas, indices): (Vec<_>, Vec<_>, Vec<_>) = writes + .iter() + .filter_map(|id| nodes[id.idx()].try_write()) + .map(|(collect, data, indices)| (collect, data, indices.to_owned())) + .multiunzip(); - while let Some(pop) = queue.pop_front() { - // Stop at forks, and don't revisit nodes. - if visited[pop.idx()] || function.nodes[pop.idx()].is_fork() { - continue; + // Second, make all the modifications: + // - Replace the old phi with a phi selecting over the `collect` inputs + // of the old write inputs. + // - Add phis for the `data` and `indices` inputs to the old writes. + // - Add a write that uses the phi-ed data and indices. + // Be a little careful over how the old phi and writes get replaced, + // since the old phi itself may be used by the old writes. + let success = editor.edit(|mut edit| { + // Create all the phis. + let collect_phi = edit.add_node(Node::Phi { + control, + data: collects.into_boxed_slice(), + }); + let data_phi = edit.add_node(Node::Phi { + control, + data: datas.into_boxed_slice(), + }); + let mut phied_indices = vec![]; + for index in 0..indices[0].len() { + match indices[0][index] { + // For field and variant indices, the index is the same + // across all writes, so just take the one from the first + // set of indices. + Index::Position(ref old_pos) => { + let mut pos = vec![]; + for pos_idx in 0..old_pos.len() { + // This code is kind of messy due to three layers of + // arrays. Basically, we are collecting every + // indexing node on indices across each write. 
+ pos.push( + edit.add_node(Node::Phi { + control, + data: indices + .iter() + .map(|indices| { + indices[index].try_position().unwrap()[pos_idx] + }) + .collect(), + }), + ); + } + phied_indices.push(Index::Position(pos.into_boxed_slice())); + } + _ => { + phied_indices.push(indices[0][index].clone()); + } + } } - // The only type of node we need to handle at this point are region - // nodes. Region nodes are what have phi users, and those phis are - // what need to get converted to select nodes. - if let Node::Region { preds } = &function.nodes[pop.idx()] { - // Get the unique true and false conditions per predecessor. - // These are the conditions attached to the predecessor that - // aren't attached to this region. - assert_eq!(preds.len(), 2); - let (region_true_conds, region_false_conds) = &condition_valuations[&pop]; - let unique_conditions = preds - .iter() - .map(|pred_id| { - let (pred_true_conds, pred_false_conds) = &condition_valuations[pred_id]; - ( - pred_true_conds - .iter() - .filter(|cond_id| !region_true_conds.contains(cond_id)) - .map(|x| *x) - .collect::<HashSet<NodeID>>(), - pred_false_conds - .iter() - .filter(|cond_id| !region_false_conds.contains(cond_id)) - .map(|x| *x) - .collect::<HashSet<NodeID>>(), - ) - }) - .collect::<Vec<_>>(); + // Create the write. + let new_write = edit.add_node(Node::Write { + collect: collect_phi, + data: data_phi, + indices: phied_indices.into_boxed_slice(), + }); - // Currently, we only handle if branching. The unique conditions - // for a region's predecessors must be exact inverses of each - // other. Given this is true, we just use unique_conditions[0] - // to calculate the select condition. 
- assert_eq!(unique_conditions[0].0, unique_conditions[1].1); - assert_eq!(unique_conditions[0].1, unique_conditions[1].0); - let negated_conditions = unique_conditions[0] - .1 - .iter() - .map(|cond_id| { - let id = NodeID::new(function.nodes.len()); - function.nodes.push(Node::Unary { - input: *cond_id, - op: UnaryOperator::Not, - }); - id - }) - .collect::<Vec<NodeID>>(); - let mut all_conditions = unique_conditions[0] - .0 - .iter() - .map(|x| *x) - .chain(negated_conditions.into_iter()); + // Replace the old phi with the new write. + edit = edit.replace_all_uses(phi, new_write)?; - // And together the negated negative and position conditions. - let first_cond = all_conditions.next().unwrap(); - let reduced_cond = all_conditions.into_iter().fold(first_cond, |a, b| { - let id = NodeID::new(function.nodes.len()); - function.nodes.push(Node::Binary { - left: a, - right: b, - op: BinaryOperator::And, - }); - id - }); - - // Create the select nodes, corresponding to all phi users. - for phi in def_use.get_users(pop) { - if let Node::Phi { control: _, data } = &function.nodes[phi.idx()] { - let select_id = NodeID::new(function.nodes.len()); - function.nodes.push(Node::Ternary { - first: reduced_cond, - second: data[1], - third: data[0], - op: TernaryOperator::Select, - }); - for user in def_use.get_users(*phi) { - get_uses_mut(&mut function.nodes[user.idx()]).map(*phi, select_id); - } - function.nodes[phi.idx()] = Node::Start; - } - } + // Delete the old phi and writes. + edit = edit.delete_node(phi)?; + for write in writes { + edit = edit.delete_node(write)?; } - // Add users of this control node to queue. - visited.set(pop.idx(), true); - queue.extend( - get_uses(&function.nodes[pop.idx()]) - .as_ref() - .iter() - .filter(|id| function.nodes[id.idx()].is_control() && !visited[id.idx()]), - ); + Ok(edit) + }); + if !success { + bad_phis.insert(phi); } + } +} - // Now that we've converted all the phis to selects, delete all the - // control nodes. 
- for control_idx in control_in_fork_join.iter_ones() { - if let Node::Join { control } = function.nodes[control_idx] { - get_uses_mut(&mut function.nodes[control_idx]).map(control, fork_id); - } else { - function.nodes[control_idx] = Node::Start; - } +/* + * Helper function to tell if two lists of indices have the same structure. + */ +fn indices_structurally_equivalent(indices1: &[Index], indices2: &[Index]) -> bool { + if indices1.len() == indices2.len() { + let mut equiv = true; + for pair in zip(indices1, indices2) { + equiv = equiv + && match pair { + (Index::Field(idx1), Index::Field(idx2)) => idx1 == idx2, + (Index::Variant(idx1), Index::Variant(idx2)) => idx1 == idx2, + (Index::Position(ref pos1), Index::Position(ref pos2)) => { + pos1.len() == pos2.len() + } + _ => false, + }; } + equiv + } else { + false } } diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs index 906d7805..cfaf7a26 100644 --- a/juno_frontend/src/lib.rs +++ b/juno_frontend/src/lib.rs @@ -179,6 +179,15 @@ pub fn compile_ir( add_pass!(pm, verify, DCE); add_pass!(pm, verify, GVN); add_pass!(pm, verify, DCE); + add_pass!(pm, verify, WritePredication); + add_pass!(pm, verify, PhiElim); + add_pass!(pm, verify, DCE); + add_pass!(pm, verify, Predication); + add_pass!(pm, verify, DCE); + add_pass!(pm, verify, CCP); + add_pass!(pm, verify, DCE); + add_pass!(pm, verify, GVN); + add_pass!(pm, verify, DCE); if x_dot { pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); } diff --git a/juno_samples/antideps/src/antideps.jn b/juno_samples/antideps/src/antideps.jn index 9efe71f1..6886741b 100644 --- a/juno_samples/antideps/src/antideps.jn +++ b/juno_samples/antideps/src/antideps.jn @@ -2,7 +2,13 @@ fn simple_antideps(a : usize, b : usize) -> i32 { let arr : i32[3]; let r = arr[b]; - arr[a] = 5; + let x : i32[1]; + if a == b { + x[0] = 5; + } else { + x[0] = 7; + } + arr[a] = x[0]; return r + arr[b]; } -- GitLab