use std::cmp::{max, min}; use std::collections::{BTreeMap, BTreeSet}; use itertools::Itertools; use hercules_ir::*; use crate::*; /* * Top level function to run predication on a function. Repeatedly looks for * acyclic control flow that can be converted into dataflow. */ pub fn predication(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { // Remove branches iteratively, since predicating an inside branch may cause // an outside branch to be available for predication. let mut bad_branches = BTreeSet::new(); loop { // First, look for a branch whose projections all point to the same // region. These are branches with no internal control flow. let nodes = &editor.func().nodes; let Some((region, branch, false_proj, true_proj)) = editor .node_ids() .filter_map(|id| { if let Node::Region { ref preds } = nodes[id.idx()] { // Look for two projections with the same branch. let preds = preds.into_iter().filter_map(|id| { nodes[id.idx()] .try_control_proj() .map(|(branch, selection)| (*id, branch, selection)) }); // Index projections by if branch. let mut pred_map: BTreeMap<NodeID, Vec<(NodeID, usize)>> = BTreeMap::new(); for (proj, branch, selection) in preds { if nodes[branch.idx()].is_if() && !bad_branches.contains(&branch) { pred_map.entry(branch).or_default().push((proj, selection)); } } // Look for an if branch with two projections going into the // same region. for (branch, projs) in pred_map { if projs.len() == 2 { let way = projs[0].1; assert_ne!(way, projs[1].1); return Some((id, branch, projs[way].0, projs[1 - way].0)); } } } None }) .next() else { break; }; let Node::Region { preds } = nodes[region.idx()].clone() else { panic!() }; let Node::If { control: if_pred, cond, } = nodes[branch.idx()] else { panic!() }; let phis: Vec<_> = editor .get_users(region) .filter(|id| nodes[id.idx()].is_phi()) .collect(); // Don't predicate branches where one of the phis is a collection. // Predicating this branch would result in a clone and probably woudln't // result in good vector code. if phis .iter() .any(|id| !editor.get_type(typing[id.idx()]).is_primitive()) { bad_branches.insert(branch); continue; } let false_pos = preds.iter().position(|id| *id == false_proj).unwrap(); let true_pos = preds.iter().position(|id| *id == true_proj).unwrap(); // Second, make all the modifications: // - Add the select nodes. // - Replace uses in phis with the select node. // - Remove the branch from the control flow. // This leaves the old region in place - if the region has one // predecessor, it may be removed by CCP. let success = editor.edit(|mut edit| { // Replace the branch projection predecessors of the region with the // predecessor of the branch. let Node::Region { preds } = edit.get_node(region).clone() else { panic!() }; let mut preds = Vec::from(preds); preds.remove(max(true_pos, false_pos)); preds.remove(min(true_pos, false_pos)); preds.push(if_pred); let new_region = edit.add_node(Node::Region { preds: preds.into_boxed_slice(), }); edit = edit.replace_all_uses(region, new_region)?; // Replace the corresponding inputs in the phi nodes with select // nodes selecting over the old inputs to the phis. for phi in phis { let Node::Phi { control: _, data } = edit.get_node(phi).clone() else { panic!() }; let mut data = Vec::from(data); let select = edit.add_node(Node::Ternary { op: TernaryOperator::Select, first: cond, second: data[true_pos], third: data[false_pos], }); data.remove(max(true_pos, false_pos)); data.remove(min(true_pos, false_pos)); data.push(select); let new_phi = edit.add_node(Node::Phi { control: new_region, data: data.into_boxed_slice(), }); edit = edit.replace_all_uses(phi, new_phi)?; edit = edit.delete_node(phi)?; } // Delete the old control nodes. edit = edit.delete_node(region)?; edit = edit.delete_node(branch)?; edit = edit.delete_node(false_proj)?; edit = edit.delete_node(true_proj)?; Ok(edit) }); if !success { bad_branches.insert(branch); } } // Do a quick and dirty rewrite to convert select(a, b, false) to a && b and // select(a, b, true) to a || b. for id in editor.node_ids() { let nodes = &editor.func().nodes; if let Node::Ternary { op: TernaryOperator::Select, first, second, third, } = nodes[id.idx()] { if let Some(cons) = nodes[second.idx()].try_constant() && editor.get_constant(cons).is_false() { editor.edit(|mut edit| { let inv = edit.add_node(Node::Unary { op: UnaryOperator::Not, input: first, }); let node = edit.add_node(Node::Binary { op: BinaryOperator::And, left: inv, right: third, }); edit = edit.replace_all_uses(id, node)?; edit.delete_node(id) }); } else if let Some(cons) = nodes[third.idx()].try_constant() && editor.get_constant(cons).is_false() { editor.edit(|mut edit| { let node = edit.add_node(Node::Binary { op: BinaryOperator::And, left: first, right: second, }); edit = edit.replace_all_uses(id, node)?; edit.delete_node(id) }); } else if let Some(cons) = nodes[second.idx()].try_constant() && editor.get_constant(cons).is_true() { editor.edit(|mut edit| { let node = edit.add_node(Node::Binary { op: BinaryOperator::Or, left: first, right: third, }); edit = edit.replace_all_uses(id, node)?; edit.delete_node(id) }); } else if let Some(cons) = nodes[third.idx()].try_constant() && editor.get_constant(cons).is_true() { editor.edit(|mut edit| { let inv = edit.add_node(Node::Unary { op: UnaryOperator::Not, input: first, }); let node = edit.add_node(Node::Binary { op: BinaryOperator::Or, left: inv, right: second, }); edit = edit.replace_all_uses(id, node)?; edit.delete_node(id) }); } } } } /* * Top level function to run write predication on a function. Repeatedly looks * for phi nodes where every data input is a write node and each write node has * matching types of indices. These writes are coalesced into a single write * using the phi as the `collect` input, and the phi selects over the old * `collect` inputs to the writes. New phis are added to select over the `data` * inputs and indices. */ pub fn write_predication(editor: &mut FunctionEditor) { let mut bad_phis = BTreeSet::new(); loop { // First, look for phis where every input is a write with the same // indexing structure. let nodes = &editor.func().nodes; let Some((phi, control, writes)) = editor .node_ids() .filter_map(|id| { if let Node::Phi { control, ref data, } = nodes[id.idx()] && !bad_phis.contains(&id) // Check that every input is a write - if this weren't true, // we'd have to insert dummy writes that write something that // was just read from the array. We could handle this case, but // it's probably not needed for now. Also check that the phi is // the only user of the write. && data.into_iter().all(|id| nodes[id.idx()].is_write() && editor.get_users(*id).count() == 1) // Check that every write input has equivalent indexing // structure. && data .into_iter() .filter_map(|id| nodes[id.idx()].try_write()) .tuple_windows() .all(|(w1, w2)| indices_structurally_equivalent(w1.2, w2.2)) { Some((id, control, data.clone())) } else { None } }) .next() else { break; }; let (collects, datas, indices): (Vec<_>, Vec<_>, Vec<_>) = writes .iter() .filter_map(|id| nodes[id.idx()].try_write()) .map(|(collect, data, indices)| (collect, data, indices.to_owned())) .multiunzip(); // Second, make all the modifications: // - Replace the old phi with a phi selecting over the `collect` inputs // of the old write inputs. // - Add phis for the `data` and `indices` inputs to the old writes. // - Add a write that uses the phi-ed data and indices. // Be a little careful over how the old phi and writes get replaced, // since the old phi itself may be used by the old writes. let success = editor.edit(|mut edit| { // Create all the phis. let collect_phi = edit.add_node(Node::Phi { control, data: collects.into_boxed_slice(), }); let data_phi = edit.add_node(Node::Phi { control, data: datas.into_boxed_slice(), }); let mut phied_indices = vec![]; for index in 0..indices[0].len() { match indices[0][index] { // For field and variant indices, the index is the same // across all writes, so just take the one from the first // set of indices. Index::Position(ref old_pos) => { let mut pos = vec![]; for pos_idx in 0..old_pos.len() { // This code is kind of messy due to three layers of // arrays. Basically, we are collecting every // indexing node on indices across each write. pos.push( edit.add_node(Node::Phi { control, data: indices .iter() .map(|indices| { indices[index].try_position().unwrap()[pos_idx] }) .collect(), }), ); } phied_indices.push(Index::Position(pos.into_boxed_slice())); } _ => { phied_indices.push(indices[0][index].clone()); } } } // Create the write. let new_write = edit.add_node(Node::Write { collect: collect_phi, data: data_phi, indices: phied_indices.into_boxed_slice(), }); // Replace the old phi with the new write. edit = edit.replace_all_uses(phi, new_write)?; // Delete the old phi and writes. edit = edit.delete_node(phi)?; for write in writes { edit = edit.delete_node(write)?; } Ok(edit) }); if !success { bad_phis.insert(phi); } } }