Commit 8ac51ec6 authored by rarbore2

Merge branch 'predication' into 'main'

Predication pass

See merge request !17
parents fc79696c 97b62c4c
......@@ -206,6 +206,7 @@ dependencies = [
name = "hercules_opt"
version = "0.1.0"
dependencies = [
"bitvec",
"hercules_ir",
"ordered-float",
]
......
......@@ -61,7 +61,8 @@ pub fn gcm(
let highest =
dom.lowest_amongst(immediate_control_uses[idx].nodes(function.nodes.len() as u32));
let lowest = dom
.common_ancestor(immediate_control_users[idx].nodes(function.nodes.len() as u32));
.common_ancestor(immediate_control_users[idx].nodes(function.nodes.len() as u32))
.unwrap_or(highest);
// Collect into vector to reverse, since we want to traverse down
// the dom tree, not up it.
......
......@@ -80,13 +80,16 @@ impl DomTree {
.1
}
pub fn common_ancestor<I>(&self, x: I) -> NodeID
pub fn common_ancestor<I>(&self, x: I) -> Option<NodeID>
where
I: Iterator<Item = NodeID>,
{
let mut positions: HashMap<NodeID, u32> = x
.map(|x| (x, if x == self.root { 0 } else { self.idom[&x].0 }))
.collect();
if positions.len() == 0 {
return None;
}
let mut current_level = *positions.iter().map(|(_, level)| level).max().unwrap();
while positions.len() > 1 {
let at_current_level: Vec<NodeID> = positions
......@@ -102,7 +105,7 @@ impl DomTree {
}
current_level -= 1;
}
positions.into_iter().next().unwrap().0
Some(positions.into_iter().next().unwrap().0)
}
pub fn chain<'a>(&'a self, bottom: NodeID, top: NodeID) -> DomChainIterator<'a> {
......
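The Option return makes the empty-iterator case explicit; a minimal sketch of the new contract, assuming a DomTree value `dom` is in scope:

// An empty iterator now yields None instead of panicking inside max().unwrap();
// callers such as gcm above choose their own fallback via unwrap_or.
assert!(dom.common_ancestor(std::iter::empty::<NodeID>()).is_none());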
......@@ -565,7 +565,7 @@ impl Function {
let old_id = **u;
let new_id = node_mapping[old_id.idx()];
if new_id == NodeID::new(0) && old_id != NodeID::new(0) {
panic!("While deleting gravestones, came across a use of a gravestoned node. The user has ID {} and was using {}.", idx, old_id.idx());
panic!("While deleting gravestones, came across a use of a gravestoned node. The user has ID {} and was using ID {}. Here's the user: {:?}", idx, old_id.idx(), node);
}
**u = new_id;
}
......@@ -766,6 +766,14 @@ impl Index {
}
}
pub fn try_control(&self) -> Option<usize> {
if let Index::Control(val) = self {
Some(*val)
} else {
None
}
}
pub fn lower_case_name(&self) -> &'static str {
match self {
Index::Field(_) => "field",
......@@ -834,6 +842,14 @@ impl Node {
);
define_pattern_predicate!(is_match, Node::Match { control: _, sum: _ });
pub fn try_region(&self) -> Option<&[NodeID]> {
if let Node::Region { preds } = self {
Some(preds)
} else {
None
}
}
pub fn try_if(&self) -> Option<(NodeID, NodeID)> {
if let Node::If { control, cond } = self {
Some((*control, *cond))
......
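These try_* accessors are what the new predication pass uses to walk from a branch projection back to its condition; a minimal sketch, assuming `function: &Function` and a control read node id `proj_id` (hypothetical names):

// A control read projects one side of an if; recover the side and the condition.
if let Node::Read { collect, ref indices } = function.nodes[proj_id.idx()] {
    let side = indices[0].try_control().unwrap(); // 0 = false projection, 1 = true
    let (_if_control, cond) = function.nodes[collect.idx()].try_if().unwrap();
    // `side` and `cond` are exactly what pred.rs records per control node below.
}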
......@@ -8,7 +8,7 @@ use crate::*;
* consideration at some point during the compilation pipeline. Each schedule is
* associated with a single node.
*/
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Schedule {
ParallelReduce,
Vectorize,
......
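The new PartialEq/Eq derives appear to exist so schedule lists can be queried by value, as in the (currently commented-out) fork filter in pred.rs below; a sketch, with `node`, `idx`, and `schedules` as hypothetical bindings matching that pass:

// Vec::contains requires Schedule: PartialEq.
let is_vector_fork = node.is_fork() && schedules[idx].contains(&Schedule::Vectorize);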
......@@ -5,4 +5,5 @@ authors = ["Russel Arbore <rarbore2@illinois.edu>"]
[dependencies]
ordered-float = "*"
bitvec = "*"
hercules_ir = { path = "../hercules_ir" }
......@@ -224,6 +224,9 @@ pub fn forkify(
function.nodes[idx_phi.idx()] = Node::Start;
// Delete the old loop control nodes, remapping uses of the header to the new fork.
for user in def_use.get_users(*header) {
get_uses_mut(&mut function.nodes[user.idx()]).map(*header, fork_id);
}
function.nodes[header.idx()] = Node::Start;
function.nodes[loop_end.idx()] = Node::Start;
function.nodes[loop_true_read.idx()] = Node::Start;
......
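The replace-uses-then-gravestone idiom used here is the same one the new predication pass applies to phis; a generic sketch, with `def_use`, `function`, `old_id`, and `new_id` as hypothetical in-scope names:

// Redirect every user of old_id to new_id, then gravestone the old node so a
// later DCE pass deletes it.
for user in def_use.get_users(old_id) {
    get_uses_mut(&mut function.nodes[user.idx()]).map(old_id, new_id);
}
function.nodes[old_id.idx()] = Node::Start;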
......@@ -3,9 +3,11 @@ pub mod dce;
pub mod forkify;
pub mod gvn;
pub mod pass;
pub mod pred;
pub use crate::ccp::*;
pub use crate::dce::*;
pub use crate::forkify::*;
pub use crate::gvn::*;
pub use crate::pass::*;
pub use crate::pred::*;
......@@ -24,6 +24,7 @@ pub enum Pass {
CCP,
GVN,
Forkify,
Predication,
Verify,
Xdot,
}
......@@ -219,6 +220,26 @@ impl PassManager {
)
}
}
Pass::Predication => {
self.make_def_uses();
self.make_reverse_postorders();
self.make_doms();
self.make_fork_join_maps();
let def_uses = self.def_uses.as_ref().unwrap();
let reverse_postorders = self.reverse_postorders.as_ref().unwrap();
let doms = self.doms.as_ref().unwrap();
let fork_join_maps = self.fork_join_maps.as_ref().unwrap();
for idx in 0..self.module.functions.len() {
predication(
&mut self.module.functions[idx],
&def_uses[idx],
&reverse_postorders[idx],
&doms[idx],
&fork_join_maps[idx],
&vec![],
)
}
}
Pass::Verify => {
let (
def_uses,
......
extern crate bitvec;
extern crate hercules_ir;
use std::collections::HashMap;
use std::collections::HashSet;
use std::collections::VecDeque;
use self::bitvec::prelude::*;
use self::hercules_ir::def_use::*;
use self::hercules_ir::dom::*;
use self::hercules_ir::ir::*;
use self::hercules_ir::schedule::*;
/*
* Top level function to convert acyclic control flow in vectorized fork-joins
* into predicated data flow.
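 *
 * The pass works in three steps: (1) find candidate fork-joins and reject any
 * whose body contains a cycle, a nested join, or a match node; (2) for each
 * remaining fork-join, record per control node which branch conditions are
 * implied by reaching it; (3) rewrite each two-predecessor region's phis into
 * select nodes over those conditions, then delete the interior control nodes.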
*/
pub fn predication(
function: &mut Function,
def_use: &ImmutableDefUseMap,
reverse_postorder: &Vec<NodeID>,
dom: &DomTree,
fork_join_map: &HashMap<NodeID, NodeID>,
schedules: &Vec<Vec<Schedule>>,
) {
// Detect forks with vectorize schedules.
let vector_forks: Vec<_> = function
.nodes
.iter()
.enumerate()
//.filter(|(idx, n)| n.is_fork() && schedules[*idx].contains(&Schedule::Vectorize))
.filter(|(_, n)| n.is_fork())
.map(|(idx, _)| NodeID::new(idx))
.collect();
// Filter forks that can't actually be vectorized, and yell at the user if
// they're being silly.
let actual_vector_forks: Vec<_> = vector_forks
.into_iter()
.filter_map(|fork_id| {
// Detect cycles in control flow between fork and join. Start at the
// join, and work backwards.
let mut visited = bitvec![u8, Lsb0; 0; function.nodes.len()];
let join_id = fork_join_map[&fork_id];
let mut stack = vec![join_id];
while let Some(pop) = stack.pop() {
// Only detect cycles between fork and join, and don't revisit
// nodes.
if visited[pop.idx()] || function.nodes[pop.idx()].is_fork() {
continue;
}
// Filter if there is a cycle, or if there is a nested fork, or
// if there is a match node. We know there is a loop if a node
// dominates one of its predecessors.
let control_uses: Vec<_> = get_uses(&function.nodes[pop.idx()])
.as_ref()
.iter()
.filter(|id| function.nodes[id.idx()].is_control())
.map(|x| *x)
.collect();
if control_uses
.iter()
.any(|pred_id| dom.does_dom(pop, *pred_id))
|| (function.nodes[pop.idx()].is_join() && pop != join_id)
|| function.nodes[pop.idx()].is_match()
{
eprintln!(
"WARNING: Vectorize schedule attached to fork that cannot be vectorized."
);
return None;
}
// Recurse up the control subgraph.
visited.set(pop.idx(), true);
stack.extend(control_uses);
}
Some((fork_id, visited))
})
.collect();
// For each control node, collect which condition values must be true, and
// which condition values must be false to reach that node. Each phi's
// corresponding region will have at least one condition value that differs
// between the predecessors. These differing condition values anded together
// form the select condition.
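// For example, in the alt_sum test below (once forkify has turned the loop
// into a fork-join), the two projections of negate_if each record the odd
// condition, one per projection side, while negate_bottom records nothing,
// since the intersection of its predecessors' sets is empty; odd is then the
// differing condition that forms the select condition for read_phi.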
let mut condition_valuations: HashMap<NodeID, (HashSet<NodeID>, HashSet<NodeID>)> =
HashMap::new();
for (fork_id, control_in_fork_join) in actual_vector_forks.iter() {
// Within a fork-join, there are no condition requirements on the fork.
condition_valuations.insert(*fork_id, (HashSet::new(), HashSet::new()));
// Iterate the nodes in the fork-join in reverse postorder, top-down.
let local_reverse_postorder = reverse_postorder
.iter()
.filter(|id| control_in_fork_join[id.idx()]);
for control_id in local_reverse_postorder {
match function.nodes[control_id.idx()] {
Node::If { control, cond: _ } | Node::Join { control } => {
condition_valuations
.insert(*control_id, condition_valuations[&control].clone());
}
// Introduce condition variables into sets, as this is where
// branching occurs.
Node::Read {
collect,
ref indices,
} => {
assert_eq!(indices.len(), 1);
let truth_value = indices[0].try_control().unwrap();
assert!(truth_value < 2);
let mut sets = condition_valuations[&collect].clone();
let condition = function.nodes[collect.idx()].try_if().unwrap().1;
if truth_value == 0 {
sets.0.insert(condition);
} else {
sets.1.insert(condition);
}
condition_valuations.insert(*control_id, sets);
}
// The only required conditions for a region are those required
// for all predecessors. Thus, the condition sets for a region
// are the intersections of the predecessor condition sets.
Node::Region { ref preds } => {
let (prev_true_set, prev_false_set) = condition_valuations[&preds[0]].clone();
let int_true_set = preds[1..].iter().fold(prev_true_set, |a, b| {
a.intersection(&condition_valuations[b].0)
.map(|x| *x)
.collect::<HashSet<NodeID>>()
});
let int_false_set = preds[1..].iter().fold(prev_false_set, |a, b| {
a.intersection(&condition_valuations[b].1)
.map(|x| *x)
.collect::<HashSet<NodeID>>()
});
condition_valuations.insert(*control_id, (int_true_set, int_false_set));
}
_ => {
panic!()
}
}
}
}
// Convert control flow to predicated data flow.
for (fork_id, control_in_fork_join) in actual_vector_forks.into_iter() {
// Worklist of control nodes - traverse control backwards breadth-first.
let mut queue = VecDeque::new();
let mut visited = bitvec![u8, Lsb0; 0; function.nodes.len()];
let join_id = fork_join_map[&fork_id];
queue.push_back(join_id);
while let Some(pop) = queue.pop_front() {
// Stop at forks, and don't revisit nodes.
if visited[pop.idx()] || function.nodes[pop.idx()].is_fork() {
continue;
}
// The only type of node we need to handle at this point are region
// nodes. Region nodes are what have phi users, and those phis are
// what need to get converted to select nodes.
if let Node::Region { preds } = &function.nodes[pop.idx()] {
// Get the unique true and false conditions per predecessor.
// These are the conditions attached to the predecessor that
// aren't attached to this region.
assert_eq!(preds.len(), 2);
let (region_true_conds, region_false_conds) = &condition_valuations[&pop];
let unique_conditions = preds
.iter()
.map(|pred_id| {
let (pred_true_conds, pred_false_conds) = &condition_valuations[pred_id];
(
pred_true_conds
.iter()
.filter(|cond_id| !region_true_conds.contains(cond_id))
.map(|x| *x)
.collect::<HashSet<NodeID>>(),
pred_false_conds
.iter()
.filter(|cond_id| !region_false_conds.contains(cond_id))
.map(|x| *x)
.collect::<HashSet<NodeID>>(),
)
})
.collect::<Vec<_>>();
// Currently, we only handle if branching. The unique conditions
// for a region's predecessors must be exact inverses of each
// other. Given this is true, we just use unique_conditions[0]
// to calculate the select condition.
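// For instance, in a plain if/else diamond over a single condition, one
// predecessor's unique sets are ({cond}, {}) and the other's ({}, {cond}),
// so the asserts below hold and the reduced condition is just that one
// condition, possibly negated.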
assert_eq!(unique_conditions[0].0, unique_conditions[1].1);
assert_eq!(unique_conditions[0].1, unique_conditions[1].0);
let negated_conditions = unique_conditions[0]
.1
.iter()
.map(|cond_id| {
let id = NodeID::new(function.nodes.len());
function.nodes.push(Node::Unary {
input: *cond_id,
op: UnaryOperator::Not,
});
id
})
.collect::<Vec<NodeID>>();
let mut all_conditions = unique_conditions[0]
.0
.iter()
.map(|x| *x)
.chain(negated_conditions.into_iter());
// And together the positive conditions and the negated negative conditions.
let first_cond = all_conditions.next().unwrap();
let reduced_cond = all_conditions.into_iter().fold(first_cond, |a, b| {
let id = NodeID::new(function.nodes.len());
function.nodes.push(Node::Binary {
left: a,
right: b,
op: BinaryOperator::And,
});
id
});
// Create the select nodes, corresponding to all phi users.
for phi in def_use.get_users(pop) {
if let Node::Phi { control: _, data } = &function.nodes[phi.idx()] {
let select_id = NodeID::new(function.nodes.len());
function.nodes.push(Node::Ternary {
first: reduced_cond,
second: data[1],
third: data[0],
op: TernaryOperator::Select,
});
for user in def_use.get_users(*phi) {
get_uses_mut(&mut function.nodes[user.idx()]).map(*phi, select_id);
}
function.nodes[phi.idx()] = Node::Start;
}
}
}
// Add users of this control node to queue.
visited.set(pop.idx(), true);
queue.extend(
get_uses(&function.nodes[pop.idx()])
.as_ref()
.iter()
.filter(|id| function.nodes[id.idx()].is_control() && !visited[id.idx()]),
);
}
// Now that we've converted all the phis to selects, delete all the
// control nodes.
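// The join itself survives: its control use is remapped to the fork, so the
// fork-join body collapses to fork -> join; everything else is gravestoned.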
for control_idx in control_in_fork_join.iter_ones() {
if let Node::Join { control } = function.nodes[control_idx] {
get_uses_mut(&mut function.nodes[control_idx]).map(control, fork_id);
} else {
function.nodes[control_idx] = Node::Start;
}
}
}
}
......@@ -14,3 +14,29 @@ fn sum(a: array(f32, 16)) -> f32
if_false = read(if, control(0))
if_true = read(if, control(1))
r = return(if_false, red_add)
fn alt_sum<1>(a: array(f32, #0)) -> f32
zero_idx = constant(u64, 0)
one_idx = constant(u64, 1)
two_idx = constant(u64, 2)
zero_inc = constant(f32, 0)
bound = dynamic_constant(#0)
loop = region(start, if_true)
idx = phi(loop, zero_idx, idx_inc)
idx_inc = add(idx, one_idx)
red = phi(loop, zero_inc, red_add)
rem = rem(idx, two_idx)
odd = eq(rem, one_idx)
negate_if = if(loop, odd)
negate_if_false = read(negate_if, control(0))
negate_if_true = read(negate_if, control(1))
negate_bottom = region(negate_if_false, negate_if_true)
read = read(a, position(idx))
read_neg = neg(read)
read_phi = phi(negate_bottom, read, read_neg)
red_add = add(red, read_phi)
in_bounds = lt(idx_inc, bound)
if = if(negate_bottom, in_bounds)
if_false = read(if, control(0))
if_true = read(if, control(1))
r = return(if_false, red_add)
\ No newline at end of file
......@@ -36,6 +36,8 @@ fn main() {
pm.add_pass(hercules_opt::pass::Pass::DCE);
pm.add_pass(hercules_opt::pass::Pass::Forkify);
pm.add_pass(hercules_opt::pass::Pass::DCE);
pm.add_pass(hercules_opt::pass::Pass::Predication);
pm.add_pass(hercules_opt::pass::Pass::DCE);
let mut module = pm.run_passes();
let (def_uses, reverse_postorders, typing, subgraphs, doms, _postdoms, fork_join_maps) =
......