Skip to content
Snippets Groups Projects
Commit 423235ed authored by rarbore2's avatar rarbore2
Browse files

Incorporate anti-dependencies into GCM

parent 595ab192
No related branches found
No related tags found
1 merge request!92Incorporate anti-dependencies into GCM
......@@ -43,24 +43,14 @@ pub fn antideps(
objects: &FunctionCollectionObjects,
) -> Vec<(NodeID, NodeID)> {
// First, we analyze "generations" of collections as they are mutated.
// Originating nodes and mutating nodes start a new generation of a
// Originating, mutating, phi, and reduce nodes start a new generation of a
// collection. Generations are not ordered due to loops, but are rather just
// node IDs of the originating or mutating node (parameter, constant, call,
// undef, write). Other nodes operating on collections mean reads / writes
// can operate on potentially different generations of multiple collections
// (phi, reduce, select).
// node IDs of the generating node (parameter, constant, call, undef, write, phi, reduce).
// Other nodes operating on collections mean reads / writes can operate on
// potentially different generations of multiple collections (select).
let lattice = forward_dataflow(function, reverse_postorder, |inputs, id| {
match function.nodes[id.idx()] {
Node::Phi {
control: _,
data: _,
}
| Node::Reduce {
control: _,
init: _,
reduct: _,
}
| Node::Ternary {
Node::Ternary {
op: TernaryOperator::Select,
first: _,
second: _,
......@@ -121,12 +111,21 @@ pub fn antideps(
collect: _,
indices: _,
} => inputs[0].clone(),
Node::Write {
Node::Phi {
control: _,
data: _,
}
| Node::Reduce {
control: _,
init: _,
reduct: _,
}
| Node::Write {
collect: _,
data: _,
indices: _,
} => {
// Writes update the generation to the write.
// Phis, reduces, and writes each start a new generation, identified by their own node ID.
let objs = inputs[0].objs.iter().map(|(obj, _)| (*obj, id)).collect();
GenerationLattice { objs }
}
......@@ -169,7 +168,7 @@ pub fn antideps(
}
Node::Write {
collect,
data: _,
data,
indices: _,
} => {
for pair in lattice[collect.idx()].objs.iter() {
......@@ -179,6 +178,19 @@ pub fn antideps(
.1
.push(id);
}
// When a write takes a collection on its `data` input, it
// memcpys that collection into the mutated collection. This is
// a read.
if !objects.objects(*data).is_empty() {
for pair in lattice[collect.idx()].objs.iter() {
reads_writes_calls_mut_calls_per_pair
.entry(*pair)
.or_default()
.0
.push(id);
}
}
}
Node::Call {
control: _,
......@@ -219,7 +231,7 @@ pub fn antideps(
// Case 1:
for read in reads.iter() {
for write in writes.iter() {
if *write != gen {
if *write != gen && *read != *write {
antideps.push((*read, *write));
}
}
......@@ -228,7 +240,7 @@ pub fn antideps(
// Case 2:
for read in reads.iter() {
for mut_call in mut_calls.iter() {
if *mut_call != gen {
if *mut_call != gen && *read != *mut_call {
antideps.push((*read, *mut_call));
}
}
......@@ -237,7 +249,7 @@ pub fn antideps(
// Case 3:
for call in calls.iter().chain(mut_calls.iter()) {
for write in writes.iter() {
if *write != gen {
if *write != gen && *call != *write {
antideps.push((*call, *write));
}
}
......
......@@ -30,12 +30,18 @@ pub fn gcm(
function: &Function,
def_use: &ImmutableDefUseMap,
reverse_postorder: &Vec<NodeID>,
control_subgraph: &Subgraph,
dom: &DomTree,
antideps: &Vec<(NodeID, NodeID)>,
loops: &LoopTree,
fork_join_map: &HashMap<NodeID, NodeID>,
) -> BasicBlocks {
let mut bbs: Vec<Option<NodeID>> = vec![None; function.nodes.len()];
let back_edges = control_subgraph.back_edges(NodeID::new(0));
let no_loop_reachability =
control_subgraph.pairwise_reachability(|src, dst| !back_edges.contains(&(src, dst)));
let antideps_users = map_antideps(antideps);
let antideps_uses = flip_antideps(antideps);
// Step 1: assign the basic block locations of all nodes that must be in a
// specific block. This includes control nodes as well as some special data
......@@ -103,7 +109,7 @@ pub fn gcm(
.map(|(fork, join)| (*join, *fork))
.collect();
let mut worklist = VecDeque::from_iter(reverse_postorder.into_iter().map(|id| *id).rev());
while let Some(id) = worklist.pop_front() {
'worklist: while let Some(id) = worklist.pop_front() {
if bbs[id.idx()].is_some() {
continue;
}
......@@ -169,13 +175,38 @@ pub fn gcm(
continue;
};
// Check if all anti-dependency users have been placed. If one of them
// hasn't, then add this node back on to the worklist. We need to know
// where the anti-dependency users are, so that we can place this
// read "above" mutators that anti-depend on it. The condition for a
// potential placement location is that in the CFG *without loop back-
// edges* the mutator cannot reach the read. Ask Russel about why this
// works, hopefully I'll have a convincing argument by then ;).
let mut antidep_user_locations = vec![];
for antidep_user in antideps_users.get(&id).unwrap_or(&vec![]) {
if let Some(location) = bbs[antidep_user.idx()] {
antidep_user_locations.push(location);
} else {
worklist.push_back(id);
continue 'worklist;
}
}
// Look between the LCA and the schedule early location to place the
// node.
let schedule_early = schedule_early[id.idx()].unwrap();
let mut chain = dom
// If the node has no users, then it doesn't really matter where we
// place it - just place it at the early placement.
.chain(lca.unwrap_or(schedule_early), schedule_early);
.chain(lca.unwrap_or(schedule_early), schedule_early)
// Only allow locations that don't violate the anti-dependence property
// listed above.
.filter(|location| {
!antidep_user_locations.iter().any(|antidep_user_location| {
antidep_user_location != location
&& no_loop_reachability[antidep_user_location.idx()][location.idx()]
})
});
let mut location = chain.next().unwrap();
while let Some(control_node) = chain.next() {
// If the next node further up the dominator tree is in a shallower
......@@ -211,7 +242,6 @@ pub fn gcm(
.filter(|id| !function.nodes[id.idx()].is_control()),
);
let mut visited = bitvec![u8, Lsb0; 0; function.nodes.len()];
let antideps = flip_antideps(&antideps);
while let Some(id) = worklist.pop_front() {
let node = &function.nodes[id.idx()];
if node.is_phi()
......@@ -219,7 +249,7 @@ pub fn gcm(
|| get_uses(node)
.as_ref()
.into_iter()
.chain(antideps.get(&id).into_iter().flatten())
.chain(antideps_uses.get(&id).into_iter().flatten())
.all(|u| {
function.nodes[u.idx()].is_control()
|| bbs[u.idx()] != bbs[id.idx()]
......
use crate::*;
extern crate bitvec;
use std::collections::{HashMap, HashSet};
use std::mem::take;
use self::bitvec::prelude::*;
use std::collections::HashMap;
use crate::*;
/*
* In various parts of the compiler, we want to consider a subset of a complete
......@@ -166,6 +171,68 @@ impl Subgraph {
original_num_nodes,
}
}
/*
 * Collect the back edges seen by an iterative depth-first traversal starting
 * at `root`. An edge (src, dst) is recorded when `dst` is still on the active
 * traversal path at the time `src`'s successors are scanned.
 */
pub fn back_edges(&self, root: NodeID) -> HashSet<(NodeID, NodeID)> {
    let num_nodes = self.original_num_nodes as usize;
    let mut discovered = bitvec![u8, Lsb0; 0; num_nodes];
    let mut on_path = bitvec![u8, Lsb0; 0; num_nodes];
    let mut found = HashSet::new();

    discovered.set(root.idx(), true);
    on_path.set(root.idx(), true);
    let mut pending = vec![root];

    while let Some(current) = pending.pop() {
        // A node stays pending until all of its successors have been
        // discovered; only then does it leave the active path.
        let fully_explored = self.succs(current).all(|next| discovered[next.idx()]);
        if fully_explored {
            on_path.set(current.idx(), false);
        } else {
            pending.push(current);
        }

        for next in self.succs(current) {
            if on_path[next.idx()] {
                // The target is an ancestor on the active path.
                found.insert((current, next));
                continue;
            }
            if discovered[next.idx()] {
                continue;
            }
            // Descend into the first undiscovered successor; the remaining
            // successors are scanned when `current` is popped again.
            discovered.set(next.idx(), true);
            on_path.set(next.idx(), true);
            pending.push(next);
            break;
        }
    }

    found
}
/*
 * Compute which nodes can reach which other nodes using only the edges
 * accepted by the predicate `p(src, dst)`. The result is indexed by node
 * index: bit `j` of row `i` is set when node `i` can reach node `j`. Every
 * node in the subgraph reaches itself. Panics if the subgraph contains a
 * self-edge.
 */
pub fn pairwise_reachability<P>(&self, p: P) -> Vec<BitVec<u8, Lsb0>>
where
    P: Fn(NodeID, NodeID) -> bool,
{
    let num_nodes = self.original_num_nodes as usize;
    let mut reachability = vec![bitvec![u8, Lsb0; 0; num_nodes]; num_nodes];

    // Every control node can reach itself.
    for node in self.iter() {
        reachability[node.idx()].set(node.idx(), true);
    }

    // Repeatedly fold each successor's row into its predecessor's row. The
    // longest simple path has at most N - 1 edges, so N sweeps is a safe upper
    // bound; in practice the rows stop changing much earlier, so stop as soon
    // as a whole sweep adds no new bits.
    for _ in 0..num_nodes {
        let mut changed = false;
        for node in self.iter() {
            for succ in self.succs(*node) {
                assert_ne!(*node, succ);
                if p(*node, succ) {
                    // Temporarily move the row out so that the successor's
                    // row can be borrowed from the same vector.
                    let mut node_reachability = take(&mut reachability[node.idx()]);
                    let before = node_reachability.count_ones();
                    node_reachability |= &reachability[succ.idx()];
                    // Bits are only ever added, so a changed population
                    // count is exactly a changed row.
                    changed |= node_reachability.count_ones() != before;
                    reachability[node.idx()] = node_reachability;
                }
            }
        }
        if !changed {
            break;
        }
    }

    reachability
}
}
/*
......
......@@ -12,6 +12,10 @@ use crate::*;
* in the backend that collections have reference, rather than value, semantics.
* The pass calling this function is mandatory for correctness.
*/
pub fn materialize_clones(editor: &mut FunctionEditor, objects: &FunctionCollectionObjects) {
pub fn materialize_clones(
editor: &mut FunctionEditor,
objects: &FunctionCollectionObjects,
bbs: &BasicBlocks,
) {
todo!()
}
......@@ -284,12 +284,14 @@ impl PassManager {
if self.bbs.is_none() {
self.make_def_uses();
self.make_reverse_postorders();
self.make_control_subgraphs();
self.make_doms();
self.make_antideps();
self.make_loops();
self.make_fork_join_maps();
let def_uses = self.def_uses.as_ref().unwrap().iter();
let reverse_postorders = self.reverse_postorders.as_ref().unwrap().iter();
let control_subgraphs = self.control_subgraphs.as_ref().unwrap().iter();
let doms = self.doms.as_ref().unwrap().iter();
let antideps = self.antideps.as_ref().unwrap().iter();
let loops = self.loops.as_ref().unwrap().iter();
......@@ -301,19 +303,29 @@ impl PassManager {
def_uses,
zip(
reverse_postorders,
zip(doms, zip(antideps, zip(loops, fork_join_maps))),
zip(
control_subgraphs,
zip(doms, zip(antideps, zip(loops, fork_join_maps))),
),
),
),
)
.map(
|(
function,
(def_use, (reverse_postorder, (dom, (antideps, (loops, fork_join_map))))),
(
def_use,
(
reverse_postorder,
(control_subgraph, (dom, (antideps, (loops, fork_join_map)))),
),
),
)| {
gcm(
function,
def_use,
reverse_postorder,
control_subgraph,
dom,
antideps,
loops,
......@@ -802,8 +814,10 @@ impl PassManager {
Pass::MaterializeClones => {
self.make_def_uses();
self.make_collection_objects();
self.make_bbs();
let def_uses = self.def_uses.as_ref().unwrap();
let collection_objects = self.collection_objects.as_ref().unwrap();
let bbs = self.bbs.as_ref().unwrap();
for idx in 0..self.module.functions.len() {
let constants_ref =
RefCell::new(std::mem::take(&mut self.module.constants));
......@@ -817,7 +831,11 @@ impl PassManager {
&types_ref,
&def_uses[idx],
);
materialize_clones(&mut editor, &collection_objects[&FunctionID::new(idx)]);
materialize_clones(
&mut editor,
&collection_objects[&FunctionID::new(idx)],
&bbs[idx],
);
self.module.constants = constants_ref.take();
self.module.dynamic_constants = dynamic_constants_ref.take();
......
......@@ -9,16 +9,68 @@ fn simple_antideps(a : usize, b : usize) -> i32 {
#[entry]
fn complex_antideps(x : i32) -> i32 {
let arr : i32[4];
arr[1] = 7;
let r = 0;
let arr2 : i32[12];
arr[1] = 7 + arr2[0];
let r = arr2[1];
arr2[1] = 99;
while x > 6 {
arr2[10] = 98;
if x > 5 {
r = arr[1];
r = arr[1] + arr2[(x + 1) as usize];
} else {
arr2[11] = 97;
arr[1] = 8;
r = arr[1];
}
arr2[x as usize] = 13;
x -= 1;
}
return r;
}
/*
 * Test program that reads and writes four arrays across three loops and two
 * branches (per its name, aimed at anti-dependency handling). The test driver
 * calls it with x = 3 and expects 144; the "With x = 3" notes below trace
 * that run.
 */
#[entry]
fn very_complex_antideps(x: usize) -> usize {
let arr1 : usize[203];
let arr2 : usize[202];
let arr3 : usize[201];
let arr4 : usize[200];
let y = x;
/* With x = 3: y takes 3 and 7, giving arr1[3] = 3, arr1[7] = 7, arr2[3] = arr2[7] = 3. */
while y < 10 {
arr1[y] = y;
arr2[y] = arr1[x];
y += 4;
continue;
}
/* With x = 3: z = 3 + 7 + 3 = 13, which the branch below rewrites to 12. */
let z = arr1[x] + arr1[x + 4] + arr2[x + 4];
if z == 13 {
z = 12;
}
/* The only exit from this loop sets z = 11; the other arm sets z = 10 and loops again. */
while true {
if arr1[x] == 3 {
z = 11;
break;
} else {
z = 10;
continue;
}
}
if z == 11 {
/* With x = 3: arr3[1] = 99, arr4[1] = 4, arr4[4] = 42. */
arr3[z - 10] = 99;
arr4[arr3[z - 10] - 98] = 4;
arr4[arr4[1]] = 42;
} else {
/*
 * These indices are at or beyond arr1's declared length of 203.
 * NOTE(review): presumably never executed, since the loop above only exits
 * with z = 11 - confirm.
 */
arr1[203] = 99999999;
arr1[204] = 99999999;
arr1[1000] = 99999999;
}
/* With x = 3: w = 42 + 3 = 45. */
let w = arr4[4] + arr2[7];
/* z doubles every iteration (11 -> 22 -> 44 -> 88 -> 176); the loop exits once z > 100. */
while z > 0 {
arr2[w] = arr3[1];
arr4[w] = arr2[w];
z += z;
if z > 100 {
break;
}
}
/* With x = 3: arr4[45] = 99, so the result is 99 + 45 = 144. */
return arr4[w] + w;
}
......@@ -13,7 +13,11 @@ fn main() {
let output = complex_antideps(9).await;
println!("{}", output);
assert_eq!(output, 7);
assert_eq!(output, 20);
let output = very_complex_antideps(3).await;
println!("{}", output);
assert_eq!(output, 144);
});
}
......
......@@ -9,7 +9,7 @@ fn main() {
async_std::task::block_on(async {
let output = implicit_clone(3).await;
println!("{}", output);
assert_eq!(output, 9);
assert_eq!(output, 11);
});
}
......
......@@ -31,7 +31,14 @@ fn main() {
J * K * 4,
);
};
let c_bytes = matmul(I as u64, J as u64, K as u64, a_bytes, b_bytes).await;
let c_bytes = matmul(
I as u64,
J as u64,
K as u64,
a_bytes.clone(),
b_bytes.clone(),
)
.await;
let mut c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
unsafe {
copy_nonoverlapping(
......@@ -40,6 +47,15 @@ fn main() {
I * K * 4,
);
};
let tiled_c_bytes = matmul(I as u64, J as u64, K as u64, a_bytes, b_bytes).await;
let mut tiled_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
unsafe {
copy_nonoverlapping(
Box::as_ptr(&tiled_c_bytes) as *const u8,
Box::as_mut_ptr(&mut tiled_c) as *mut u8,
I * K * 4,
);
};
let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
for i in 0..I {
for k in 0..K {
......@@ -49,6 +65,7 @@ fn main() {
}
}
assert_eq!(c, correct_c);
assert_eq!(tiled_c, correct_c);
});
}
......
......@@ -14,7 +14,6 @@ fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[
return res;
}
/*
#[entry]
fn tiled_64_matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] {
let res : i32[n, l];
......@@ -60,4 +59,3 @@ fn tiled_64_matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]
return res;
}
*/
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment