diff --git a/Cargo.lock b/Cargo.lock
index a3d1a27af3e4a9faead9499168cff5879d3ef895..ed83f6be73b90235394dc221fa6f32222c8d1f4a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -206,6 +206,7 @@ dependencies = [
 name = "hercules_opt"
 version = "0.1.0"
 dependencies = [
+ "bitvec",
  "hercules_ir",
  "ordered-float",
 ]
diff --git a/hercules_cg/src/gcm.rs b/hercules_cg/src/gcm.rs
index baaeee330c41d1098d4c1546ddb60732d1ee6dbc..1f840a5cfeb04722273d18365a1eea45a436d676 100644
--- a/hercules_cg/src/gcm.rs
+++ b/hercules_cg/src/gcm.rs
@@ -61,7 +61,8 @@ pub fn gcm(
         let highest = dom
             .lowest_amongst(immediate_control_uses[idx].nodes(function.nodes.len() as u32));
         let lowest = dom
-            .common_ancestor(immediate_control_users[idx].nodes(function.nodes.len() as u32));
+            .common_ancestor(immediate_control_users[idx].nodes(function.nodes.len() as u32))
+            .unwrap_or(highest);
 
         // Collect into vector to reverse, since we want to traverse down
         // the dom tree, not up it.
diff --git a/hercules_ir/src/dom.rs b/hercules_ir/src/dom.rs
index 8e11efacbde3521f4dc307ad8a132bca03668333..e9cb07cd428fa006e3588076b9b94633c2c1e24a 100644
--- a/hercules_ir/src/dom.rs
+++ b/hercules_ir/src/dom.rs
@@ -80,13 +80,16 @@ impl DomTree {
             .1
     }
 
-    pub fn common_ancestor<I>(&self, x: I) -> NodeID
+    pub fn common_ancestor<I>(&self, x: I) -> Option<NodeID>
     where
         I: Iterator<Item = NodeID>,
     {
         let mut positions: HashMap<NodeID, u32> = x
            .map(|x| (x, if x == self.root { 0 } else { self.idom[&x].0 }))
            .collect();
+        if positions.len() == 0 {
+            return None;
+        }
         let mut current_level = *positions.iter().map(|(_, level)| level).max().unwrap();
         while positions.len() > 1 {
             let at_current_level: Vec<NodeID> = positions
@@ -102,7 +105,7 @@
             }
             current_level -= 1;
         }
-        positions.into_iter().next().unwrap().0
+        Some(positions.into_iter().next().unwrap().0)
     }
 
     pub fn chain<'a>(&'a self, bottom: NodeID, top: NodeID) -> DomChainIterator<'a> {
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 63d8831318f9817878c560056f04ef0b22f2b9bb..7e3b03a06018f7e30c77d496b80a52a05cb6db35 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -565,7 +565,7 @@ impl Function {
                 let old_id = **u;
                 let new_id = node_mapping[old_id.idx()];
                 if new_id == NodeID::new(0) && old_id != NodeID::new(0) {
-                    panic!("While deleting gravestones, came across a use of a gravestoned node. The user has ID {} and was using {}.", idx, old_id.idx());
+                    panic!("While deleting gravestones, came across a use of a gravestoned node. The user has ID {} and was using ID {}. Here's the user: {:?}", idx, old_id.idx(), node);
                 }
                 **u = new_id;
             }
         }
@@ -766,6 +766,14 @@ impl Index {
         }
     }
 
+    pub fn try_control(&self) -> Option<usize> {
+        if let Index::Control(val) = self {
+            Some(*val)
+        } else {
+            None
+        }
+    }
+
     pub fn lower_case_name(&self) -> &'static str {
         match self {
             Index::Field(_) => "field",
@@ -834,6 +842,14 @@
     );
     define_pattern_predicate!(is_match, Node::Match { control: _, sum: _ });
 
+    pub fn try_region(&self) -> Option<&[NodeID]> {
+        if let Node::Region { preds } = self {
+            Some(preds)
+        } else {
+            None
+        }
+    }
+
     pub fn try_if(&self) -> Option<(NodeID, NodeID)> {
         if let Node::If { control, cond } = self {
             Some((*control, *cond))
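
Note on the dom.rs / gcm.rs change above: `common_ancestor` now returns `Option<NodeID>` and yields `None` for an empty iterator, so callers choose an explicit fallback instead of hitting the old `unwrap` panic on an empty user set. A minimal caller-side sketch, mirroring the gcm.rs pattern; the helper name is made up and `hercules_ir` is assumed as a dependency:

```rust
use hercules_ir::dom::DomTree;
use hercules_ir::ir::NodeID;

// Hypothetical helper: if a node has no control users, fall back to the
// highest legal placement rather than panicking inside common_ancestor.
fn clamp_lowest(dom: &DomTree, highest: NodeID, users: impl Iterator<Item = NodeID>) -> NodeID {
    dom.common_ancestor(users).unwrap_or(highest)
}
```
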
diff --git a/hercules_ir/src/schedule.rs b/hercules_ir/src/schedule.rs
index 2151b1320a8cf85a934f8edec1bd5315962e5999..08f8103dac29219850558f16cae91d37cd98f970 100644
--- a/hercules_ir/src/schedule.rs
+++ b/hercules_ir/src/schedule.rs
@@ -8,7 +8,7 @@ use crate::*;
  * consideration at some point during the compilation pipeline. Each schedule is
  * associated with a single node.
  */
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, PartialEq, Eq)]
 pub enum Schedule {
     ParallelReduce,
     Vectorize,
diff --git a/hercules_opt/Cargo.toml b/hercules_opt/Cargo.toml
index bc30540595ff660080e21de3b76b2d7287adbd90..b1f2b468c5f0a81646187ce5ff4ccd8ef2454437 100644
--- a/hercules_opt/Cargo.toml
+++ b/hercules_opt/Cargo.toml
@@ -5,4 +5,5 @@ authors = ["Russel Arbore <rarbore2@illinois.edu>"]
 
 [dependencies]
 ordered-float = "*"
+bitvec = "*"
 hercules_ir = { path = "../hercules_ir" }
diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs
index 5b0920c57cc5d1e92cf7b6016bfeacf1d6087268..22496b20f2b2d5d3f2d0fe005bacdf5320d94e54 100644
--- a/hercules_opt/src/forkify.rs
+++ b/hercules_opt/src/forkify.rs
@@ -224,6 +224,9 @@ pub fn forkify(
         function.nodes[idx_phi.idx()] = Node::Start;
 
         // Delete old loop control nodes;
+        for user in def_use.get_users(*header) {
+            get_uses_mut(&mut function.nodes[user.idx()]).map(*header, fork_id);
+        }
         function.nodes[header.idx()] = Node::Start;
         function.nodes[loop_end.idx()] = Node::Start;
         function.nodes[loop_true_read.idx()] = Node::Start;
diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs
index 2cdf4c14965d53e07c5872d0835c4415e44a8d1f..53ebc1da7f99639e93ab57f192a3d5dc389aadd6 100644
--- a/hercules_opt/src/lib.rs
+++ b/hercules_opt/src/lib.rs
@@ -3,9 +3,11 @@ pub mod dce;
 pub mod forkify;
 pub mod gvn;
 pub mod pass;
+pub mod pred;
 
 pub use crate::ccp::*;
 pub use crate::dce::*;
 pub use crate::forkify::*;
 pub use crate::gvn::*;
 pub use crate::pass::*;
+pub use crate::pred::*;
diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs
index fcb0e5116d96858868771bc3ddf199adcaea61a8..d94de66e8dc022a48f5b69374205063c00a339d2 100644
--- a/hercules_opt/src/pass.rs
+++ b/hercules_opt/src/pass.rs
@@ -24,6 +24,7 @@
     CCP,
     GVN,
     Forkify,
+    Predication,
     Verify,
     Xdot,
 }
@@ -219,6 +220,26 @@
                     )
                 }
             }
+            Pass::Predication => {
+                self.make_def_uses();
+                self.make_reverse_postorders();
+                self.make_doms();
+                self.make_fork_join_maps();
+                let def_uses = self.def_uses.as_ref().unwrap();
+                let reverse_postorders = self.reverse_postorders.as_ref().unwrap();
+                let doms = self.doms.as_ref().unwrap();
+                let fork_join_maps = self.fork_join_maps.as_ref().unwrap();
+                for idx in 0..self.module.functions.len() {
+                    predication(
+                        &mut self.module.functions[idx],
+                        &def_uses[idx],
+                        &reverse_postorders[idx],
+                        &doms[idx],
+                        &fork_join_maps[idx],
+                        &vec![],
+                    )
+                }
+            }
             Pass::Verify => {
                 let (
                     def_uses,
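
The new `Pass::Predication` variant is intended to run after `Forkify`, with `DCE` in between so the `Node::Start` gravestones left by each pass are cleaned up (see the `hercules_dot` driver change at the end of this diff). A sketch of that ordering, using only the `add_pass` calls shown in this diff; the wrapper function itself is hypothetical:

```rust
use hercules_opt::pass::{Pass, PassManager};

// Hypothetical convenience wrapper: forkify first, then predication, with DCE
// after each pass to delete the gravestoned nodes they leave behind.
fn add_forkify_and_predication(pm: &mut PassManager) {
    pm.add_pass(Pass::Forkify);
    pm.add_pass(Pass::DCE);
    pm.add_pass(Pass::Predication);
    pm.add_pass(Pass::DCE);
}
```
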
diff --git a/hercules_opt/src/pred.rs b/hercules_opt/src/pred.rs
new file mode 100644
index 0000000000000000000000000000000000000000..0bd7171def8fbf2e55be16fb1ddc1fc929e3f971
--- /dev/null
+++ b/hercules_opt/src/pred.rs
@@ -0,0 +1,264 @@
+extern crate bitvec;
+extern crate hercules_ir;
+
+use std::collections::HashMap;
+use std::collections::HashSet;
+use std::collections::VecDeque;
+
+use self::bitvec::prelude::*;
+
+use self::hercules_ir::def_use::*;
+use self::hercules_ir::dom::*;
+use self::hercules_ir::ir::*;
+use self::hercules_ir::schedule::*;
+
+/*
+ * Top level function to convert acyclic control flow in vectorized fork-joins
+ * into predicated data flow.
+ */
+pub fn predication(
+    function: &mut Function,
+    def_use: &ImmutableDefUseMap,
+    reverse_postorder: &Vec<NodeID>,
+    dom: &DomTree,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    schedules: &Vec<Vec<Schedule>>,
+) {
+    // Detect forks with vectorize schedules.
+    let vector_forks: Vec<_> = function
+        .nodes
+        .iter()
+        .enumerate()
+        //.filter(|(idx, n)| n.is_fork() && schedules[*idx].contains(&Schedule::Vectorize))
+        .filter(|(_, n)| n.is_fork())
+        .map(|(idx, _)| NodeID::new(idx))
+        .collect();
+
+    // Filter forks that can't actually be vectorized, and yell at the user if
+    // they're being silly.
+    let actual_vector_forks: Vec<_> = vector_forks
+        .into_iter()
+        .filter_map(|fork_id| {
+            // Detect cycles in control flow between fork and join. Start at the
+            // join, and work backwards.
+            let mut visited = bitvec![u8, Lsb0; 0; function.nodes.len()];
+            let join_id = fork_join_map[&fork_id];
+            let mut stack = vec![join_id];
+            while let Some(pop) = stack.pop() {
+                // Only detect cycles between fork and join, and don't revisit
+                // nodes.
+                if visited[pop.idx()] || function.nodes[pop.idx()].is_fork() {
+                    continue;
+                }
+
+                // Filter if there is a cycle, or if there is a nested fork, or
+                // if there is a match node. We know there is a loop if a node
+                // dominates one of its predecessors.
+                let control_uses: Vec<_> = get_uses(&function.nodes[pop.idx()])
+                    .as_ref()
+                    .iter()
+                    .filter(|id| function.nodes[id.idx()].is_control())
+                    .map(|x| *x)
+                    .collect();
+                if control_uses
+                    .iter()
+                    .any(|pred_id| dom.does_dom(pop, *pred_id))
+                    || (function.nodes[pop.idx()].is_join() && pop != join_id)
+                    || function.nodes[pop.idx()].is_match()
+                {
+                    eprintln!(
+                        "WARNING: Vectorize schedule attached to fork that cannot be vectorized."
+                    );
+                    return None;
+                }
+
+                // Recurse up the control subgraph.
+                visited.set(pop.idx(), true);
+                stack.extend(control_uses);
+            }
+
+            Some((fork_id, visited))
+        })
+        .collect();
+
+    // For each control node, collect which condition values must be true, and
+    // which condition values must be false to reach that node. Each phi's
+    // corresponding region will have at least one condition value that differs
+    // between the predecessors. These differing condition values anded together
+    // form the select condition.
+    let mut condition_valuations: HashMap<NodeID, (HashSet<NodeID>, HashSet<NodeID>)> =
+        HashMap::new();
+    for (fork_id, control_in_fork_join) in actual_vector_forks.iter() {
+        // Within a fork-join, there are no condition requirements on the fork.
+        condition_valuations.insert(*fork_id, (HashSet::new(), HashSet::new()));
+
+        // Iterate the nodes in the fork-join in reverse postorder, top-down.
+        let local_reverse_postorder = reverse_postorder
+            .iter()
+            .filter(|id| control_in_fork_join[id.idx()]);
+        for control_id in local_reverse_postorder {
+            match function.nodes[control_id.idx()] {
+                Node::If { control, cond: _ } | Node::Join { control } => {
+                    condition_valuations
+                        .insert(*control_id, condition_valuations[&control].clone());
+                }
+                // Introduce condition variables into sets, as this is where
+                // branching occurs.
+                Node::Read {
+                    collect,
+                    ref indices,
+                } => {
+                    assert_eq!(indices.len(), 1);
+                    let truth_value = indices[0].try_control().unwrap();
+                    assert!(truth_value < 2);
+                    let mut sets = condition_valuations[&collect].clone();
+                    let condition = function.nodes[collect.idx()].try_if().unwrap().1;
+                    if truth_value == 0 {
+                        sets.0.insert(condition);
+                    } else {
+                        sets.1.insert(condition);
+                    }
+                    condition_valuations.insert(*control_id, sets);
+                }
+                // The only required conditions for a region are those required
+                // for all predecessors. Thus, the condition sets for a region
+                // are the intersections of the predecessor condition sets.
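+                // For example, for a single if-diamond on condition `c` inside
+                // the fork-join, one projection carries ({c}, {}) and the other
+                // carries ({}, {c}), so both intersections at the merge region
+                // are empty.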
+                Node::Region { ref preds } => {
+                    let (prev_true_set, prev_false_set) = condition_valuations[&preds[0]].clone();
+                    let int_true_set = preds[1..].iter().fold(prev_true_set, |a, b| {
+                        a.intersection(&condition_valuations[b].0)
+                            .map(|x| *x)
+                            .collect::<HashSet<NodeID>>()
+                    });
+                    let int_false_set = preds[1..].iter().fold(prev_false_set, |a, b| {
+                        a.intersection(&condition_valuations[b].1)
+                            .map(|x| *x)
+                            .collect::<HashSet<NodeID>>()
+                    });
+
+                    condition_valuations.insert(*control_id, (int_true_set, int_false_set));
+                }
+                _ => {
+                    panic!()
+                }
+            }
+        }
+    }
+
+    // Convert control flow to predicated data flow.
+    for (fork_id, control_in_fork_join) in actual_vector_forks.into_iter() {
+        // Worklist of control nodes - traverse control backwards breadth-first.
+        let mut queue = VecDeque::new();
+        let mut visited = bitvec![u8, Lsb0; 0; function.nodes.len()];
+        let join_id = fork_join_map[&fork_id];
+        queue.push_back(join_id);
+
+        while let Some(pop) = queue.pop_front() {
+            // Stop at forks, and don't revisit nodes.
+            if visited[pop.idx()] || function.nodes[pop.idx()].is_fork() {
+                continue;
+            }
+
+            // The only type of node we need to handle at this point are region
+            // nodes. Region nodes are what have phi users, and those phis are
+            // what need to get converted to select nodes.
+            if let Node::Region { preds } = &function.nodes[pop.idx()] {
+                // Get the unique true and false conditions per predecessor.
+                // These are the conditions attached to the predecessor that
+                // aren't attached to this region.
+                assert_eq!(preds.len(), 2);
+                let (region_true_conds, region_false_conds) = &condition_valuations[&pop];
+                let unique_conditions = preds
+                    .iter()
+                    .map(|pred_id| {
+                        let (pred_true_conds, pred_false_conds) = &condition_valuations[pred_id];
+                        (
+                            pred_true_conds
+                                .iter()
+                                .filter(|cond_id| !region_true_conds.contains(cond_id))
+                                .map(|x| *x)
+                                .collect::<HashSet<NodeID>>(),
+                            pred_false_conds
+                                .iter()
+                                .filter(|cond_id| !region_false_conds.contains(cond_id))
+                                .map(|x| *x)
+                                .collect::<HashSet<NodeID>>(),
+                        )
+                    })
+                    .collect::<Vec<_>>();
+
+                // Currently, we only handle if branching. The unique conditions
+                // for a region's predecessors must be exact inverses of each
+                // other. Given this is true, we just use unique_conditions[0]
+                // to calculate the select condition.
+                assert_eq!(unique_conditions[0].0, unique_conditions[1].1);
+                assert_eq!(unique_conditions[0].1, unique_conditions[1].0);
+                let negated_conditions = unique_conditions[0]
+                    .1
+                    .iter()
+                    .map(|cond_id| {
+                        let id = NodeID::new(function.nodes.len());
+                        function.nodes.push(Node::Unary {
+                            input: *cond_id,
+                            op: UnaryOperator::Not,
+                        });
+                        id
+                    })
+                    .collect::<Vec<NodeID>>();
+                let mut all_conditions = unique_conditions[0]
+                    .0
+                    .iter()
+                    .map(|x| *x)
+                    .chain(negated_conditions.into_iter());
+
+                // And together the positive conditions and the negated negative conditions.
+                let first_cond = all_conditions.next().unwrap();
+                let reduced_cond = all_conditions.into_iter().fold(first_cond, |a, b| {
+                    let id = NodeID::new(function.nodes.len());
+                    function.nodes.push(Node::Binary {
+                        left: a,
+                        right: b,
+                        op: BinaryOperator::And,
+                    });
+                    id
+                });
+
+                // Create the select nodes, corresponding to all phi users.
+                for phi in def_use.get_users(pop) {
+                    if let Node::Phi { control: _, data } = &function.nodes[phi.idx()] {
+                        let select_id = NodeID::new(function.nodes.len());
+                        function.nodes.push(Node::Ternary {
+                            first: reduced_cond,
+                            second: data[1],
+                            third: data[0],
+                            op: TernaryOperator::Select,
+                        });
+                        for user in def_use.get_users(*phi) {
+                            get_uses_mut(&mut function.nodes[user.idx()]).map(*phi, select_id);
+                        }
+                        function.nodes[phi.idx()] = Node::Start;
+                    }
+                }
+            }
+
+            // Add the control uses (predecessors) of this control node to the queue.
+            visited.set(pop.idx(), true);
+            queue.extend(
+                get_uses(&function.nodes[pop.idx()])
+                    .as_ref()
+                    .iter()
+                    .filter(|id| function.nodes[id.idx()].is_control() && !visited[id.idx()]),
+            );
+        }
+
+        // Now that we've converted all the phis to selects, delete all the
+        // control nodes.
+        for control_idx in control_in_fork_join.iter_ones() {
+            if let Node::Join { control } = function.nodes[control_idx] {
+                get_uses_mut(&mut function.nodes[control_idx]).map(control, fork_id);
+            } else {
+                function.nodes[control_idx] = Node::Start;
+            }
+        }
+    }
+}
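
The `alt_sum` function added below exercises this pass: the if-diamond on `odd` inside the loop is the control flow that gets predicated away. A scalar Rust model of that rewrite (illustrative only, assuming the usual select(cond, true_value, false_value) ordering, not the IR itself):

```rust
// Scalar model of the alt_sum loop body below. Before predication, control
// branches on `odd` and a phi merges the two values; after predication, the
// phi becomes select(odd, read_neg, read) and the branch disappears.
fn alt_sum_body(red: f32, read: f32, idx: u64) -> f32 {
    let odd = idx % 2 == 1;
    let read_neg = -read;
    let read_phi = if odd { read_neg } else { read }; // select(odd, read_neg, read)
    red + read_phi
}
```
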
diff --git a/hercules_samples/sum_sample.hir b/hercules_samples/sum_sample.hir
index 55852e7f008115a0f86f4b68dfc8daa08710283d..8b8c0024fb548ac8c0d1dde4e029f76230d71886 100644
--- a/hercules_samples/sum_sample.hir
+++ b/hercules_samples/sum_sample.hir
@@ -14,3 +14,29 @@ fn sum(a: array(f32, 16)) -> f32
   if_false = read(if, control(0))
   if_true = read(if, control(1))
   r = return(if_false, red_add)
+
+fn alt_sum<1>(a: array(f32, #0)) -> f32
+  zero_idx = constant(u64, 0)
+  one_idx = constant(u64, 1)
+  two_idx = constant(u64, 2)
+  zero_inc = constant(f32, 0)
+  bound = dynamic_constant(#0)
+  loop = region(start, if_true)
+  idx = phi(loop, zero_idx, idx_inc)
+  idx_inc = add(idx, one_idx)
+  red = phi(loop, zero_inc, red_add)
+  rem = rem(idx, two_idx)
+  odd = eq(rem, one_idx)
+  negate_if = if(loop, odd)
+  negate_if_false = read(negate_if, control(0))
+  negate_if_true = read(negate_if, control(1))
+  negate_bottom = region(negate_if_false, negate_if_true)
+  read = read(a, position(idx))
+  read_neg = neg(read)
+  read_phi = phi(negate_bottom, read, read_neg)
+  red_add = add(red, read_phi)
+  in_bounds = lt(idx_inc, bound)
+  if = if(negate_bottom, in_bounds)
+  if_false = read(if, control(0))
+  if_true = read(if, control(1))
+  r = return(if_false, red_add)
\ No newline at end of file
diff --git a/hercules_tools/hercules_dot/src/main.rs b/hercules_tools/hercules_dot/src/main.rs
index dfe8db49c5acdcdfdb65a7c782fea820bcc373f4..425b9478399cea4a753abb5726b527bb03c271f6 100644
--- a/hercules_tools/hercules_dot/src/main.rs
+++ b/hercules_tools/hercules_dot/src/main.rs
@@ -36,6 +36,8 @@ fn main() {
     pm.add_pass(hercules_opt::pass::Pass::DCE);
     pm.add_pass(hercules_opt::pass::Pass::Forkify);
     pm.add_pass(hercules_opt::pass::Pass::DCE);
+    pm.add_pass(hercules_opt::pass::Pass::Predication);
+    pm.add_pass(hercules_opt::pass::Pass::DCE);
     let mut module = pm.run_passes();
 
     let (def_uses, reverse_postorders, typing, subgraphs, doms, _postdoms, fork_join_maps) =