Skip to content
Snippets Groups Projects
Commit f0eaad15 authored by rarbore2's avatar rarbore2
Browse files

Merge branch 'fix-juno_matmul-schedule' into 'main'

Fix reduce cycles and nodes in fork joins

See merge request !180
parents dc012b1a 544024f2
No related branches found
No related tags found
1 merge request: !180 — Fix reduce cycles and nodes in fork joins
Pipeline #201671 passed
......@@ -87,7 +87,7 @@ pub fn reduce_cycles(
function: &Function,
def_use: &ImmutableDefUseMap,
fork_join_map: &HashMap<NodeID, NodeID>,
nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
) -> HashMap<NodeID, HashSet<NodeID>> {
let reduces = (0..function.nodes.len())
.filter(|idx| function.nodes[*idx].is_reduce())
......@@ -101,6 +101,24 @@ pub fn reduce_cycles(
for reduce in reduces {
let (join, _, reduct) = function.nodes[reduce.idx()].try_reduce().unwrap();
let fork = join_fork_map[&join];
let might_be_in_fork_join = |id| {
fork_join_nest
.get(&id)
.map(|nest| nest.contains(&fork))
.unwrap_or(true)
&& function.nodes[id.idx()]
.try_phi()
.map(|(control, _)| fork_join_nest[&control].contains(&fork))
.unwrap_or(true)
&& function.nodes[id.idx()]
.try_reduce()
.map(|(control, _, _)| fork_join_nest[&control].contains(&fork))
.unwrap_or(true)
&& !function.nodes[id.idx()].is_parameter()
&& !function.nodes[id.idx()].is_constant()
&& !function.nodes[id.idx()].is_dynamic_constant()
&& !function.nodes[id.idx()].is_undef()
};
// Find nodes in the fork-join that the reduce can reach through uses.
let mut reachable_uses = HashSet::new();
......@@ -109,10 +127,7 @@ pub fn reduce_cycles(
workset.push(reduct);
while let Some(pop) = workset.pop() {
for u in get_uses(&function.nodes[pop.idx()]).as_ref() {
if !reachable_uses.contains(u)
&& nodes_in_fork_joins[&fork].contains(u)
&& *u != reduce
{
if !reachable_uses.contains(u) && might_be_in_fork_join(*u) && *u != reduce {
reachable_uses.insert(*u);
workset.push(*u);
}
......@@ -126,10 +141,7 @@ pub fn reduce_cycles(
workset.push(reduce);
while let Some(pop) = workset.pop() {
for u in def_use.get_users(pop) {
if !reachable_users.contains(u)
&& nodes_in_fork_joins[&fork].contains(u)
&& *u != reduce
{
if !reachable_users.contains(u) && might_be_in_fork_join(*u) && *u != reduce {
reachable_users.insert(*u);
workset.push(*u);
}
......@@ -155,6 +167,7 @@ pub fn nodes_in_fork_joins(
function: &Function,
def_use: &ImmutableDefUseMap,
fork_join_map: &HashMap<NodeID, NodeID>,
reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) -> HashMap<NodeID, HashSet<NodeID>> {
let mut result = HashMap::new();
......@@ -164,6 +177,7 @@ pub fn nodes_in_fork_joins(
let mut set = HashSet::new();
set.insert(*fork);
// Iterate uses of the fork.
while let Some(item) = worklist.pop() {
for u in def_use.get_users(item) {
let terminate = *u == *join
......@@ -177,6 +191,15 @@ pub fn nodes_in_fork_joins(
set.insert(*u);
}
}
assert!(set.contains(join));
// Add all the nodes in the reduce cycle. Some of these nodes may not
// use thread IDs of the fork, so do this explicitly.
for u in def_use.get_users(*join) {
if let Some(cycle) = reduce_cycles.get(u) {
set.extend(cycle);
}
}
result.insert(*fork, set);
}
......
......@@ -292,7 +292,11 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
// gravestone.
for id in deleted_nodeids.iter() {
// Check that there are no users of deleted nodes.
assert!(editor.mut_def_use[id.idx()].is_empty());
assert!(
editor.mut_def_use[id.idx()].is_empty(),
"PANIC: Attempted to delete node {:?}, but there are still users of this node ({:?}).",
id, editor.mut_def_use[id.idx()]
);
editor.function.nodes[id.idx()] = Node::Start;
}
......
gvn(*);
phi-elim(*);
forkify(*);
fork-guard-elim(*);
dce(*);
let out = auto-outline(*);
gpu(out.matmul, out.tiled_64_matmul);
fixpoint {
reduce-slf(*);
slf(*);
infer-schedules(*);
}
fork-coalesce(*);
infer-schedules(*);
dce(*);
let out = auto-outline(*);
gpu(out.matmul);
ip-sroa(*);
sroa(*);
dce(*);
gvn(*);
phi-elim(*);
dce(*);
fixpoint panic after 20 {
forkify(out.matmul);
fork-guard-elim(out.matmul);
}
gvn(out.matmul);
phi-elim(out.matmul);
dce(out.matmul);
fixpoint panic after 20 {
reduce-slf(out.matmul);
slf(out.matmul);
infer-schedules(out.matmul);
}
fork-coalesce(out.matmul);
dce(out.matmul);
gcm(*);
float-collections(*);
dce(*);
gcm(*);
......@@ -32,11 +32,6 @@ fn main() {
.run(I as u64, J as u64, K as u64, a.clone(), b.clone())
.await;
assert_eq!(c.as_slice::<i32>(), &*correct_c);
let mut r = runner!(tiled_64_matmul);
let tiled_c = r
.run(I as u64, J as u64, K as u64, a.clone(), b.clone())
.await;
assert_eq!(tiled_c.as_slice::<i32>(), &*correct_c);
}
#[cfg(feature = "cuda")]
{
......@@ -49,13 +44,6 @@ fn main() {
let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice();
c.to_cpu_ref(&mut c_cpu);
assert_eq!(&*c_cpu, &*correct_c);
let mut r = runner!(tiled_64_matmul);
let tiled_c = r
.run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref())
.await;
let mut tiled_c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice();
tiled_c.to_cpu_ref(&mut tiled_c_cpu);
assert_eq!(&*tiled_c_cpu, &*correct_c);
}
});
}
......
......@@ -10,52 +10,5 @@ fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[
}
}
@exit
return res;
}
#[entry]
// Cache-tiled matrix multiply: computes res = a * b using 64x64 blocks.
// a is i32[n, m], b is i32[m, l], result is i32[n, l].
// NOTE(review): the block loops run to n / 64, l / 64, and m / 64, so this
// assumes n, m, and l are multiples of 64 — remainder rows/columns of `res`
// are never written and keep their initial value. TODO confirm callers
// guarantee this (the test harness appears to pass sizes it controls).
fn tiled_64_matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] {
let res : i32[n, l];
// Per-block staging buffers: one 64x64 tile each of a, b, and the
// accumulating output block.
let atile : i32[64, 64];
let btile : i32[64, 64];
let ctile : i32[64, 64];
// Iterate over the grid of 64x64 output blocks: bi selects the block row,
// bk the block column of `res`.
for bi = 0 to n / 64 {
for bk = 0 to l / 64 {
// Zero all three tiles before accumulating this output block.
// (Zeroing atile/btile is redundant with the full overwrite below,
// but is kept as written.)
for ti = 0 to 64 {
for tk = 0 to 64 {
atile[ti, tk] = 0;
btile[ti, tk] = 0;
ctile[ti, tk] = 0;
}
}
// Walk the shared dimension m in 64-wide strips; each strip
// contributes a partial product into ctile.
for tile_idx = 0 to m / 64 {
// Stage the current 64x64 strip of a (row block bi) and of
// b (column block bk) into the local tiles.
for ti = 0 to 64 {
for tk = 0 to 64 {
atile[ti, tk] = a[bi * 64 + ti, tile_idx * 64 + tk];
btile[ti, tk] = b[tile_idx * 64 + ti, bk * 64 + tk];
}
}
// Multiply the staged tiles and accumulate into ctile.
for ti = 0 to 64 {
for tk = 0 to 64 {
// c_acc carries the running dot-product for element (ti, tk)
// across the 64-element inner dimension of this strip.
let c_acc = ctile[ti, tk];
for inner_idx = 0 to 64 {
c_acc += atile[ti, inner_idx] * btile[inner_idx, tk];
}
ctile[ti, tk] = c_acc;
}
}
}
// Write the finished 64x64 block back to its position in res.
for ti = 0 to 64 {
for tk = 0 to 64 {
res[bi * 64 + ti, bk * 64 + tk] = ctile[ti, tk];
}
}
}
}
return res;
}
......@@ -401,21 +401,19 @@ impl PassManager {
if self.reduce_cycles.is_none() {
self.make_def_uses();
self.make_fork_join_maps();
self.make_nodes_in_fork_joins();
self.make_fork_join_nests();
let def_uses = self.def_uses.as_ref().unwrap().iter();
let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter();
let nodes_in_fork_joins = self.nodes_in_fork_joins.as_ref().unwrap().iter();
let fork_join_nests = self.fork_join_nests.as_ref().unwrap().iter();
self.reduce_cycles = Some(
self.functions
.iter()
.zip(fork_join_maps)
.zip(nodes_in_fork_joins)
.zip(fork_join_nests)
.zip(def_uses)
.map(
|(((function, fork_join_map), nodes_in_fork_joins), def_use)| {
reduce_cycles(function, def_use, fork_join_map, nodes_in_fork_joins)
},
)
.map(|(((function, fork_join_map), fork_join_nests), def_use)| {
reduce_cycles(function, def_use, fork_join_map, fork_join_nests)
})
.collect(),
);
}
......@@ -425,16 +423,20 @@ impl PassManager {
if self.nodes_in_fork_joins.is_none() {
self.make_def_uses();
self.make_fork_join_maps();
self.make_reduce_cycles();
self.nodes_in_fork_joins = Some(
zip(
self.functions.iter(),
zip(
self.def_uses.as_ref().unwrap().iter(),
self.fork_join_maps.as_ref().unwrap().iter(),
zip(
self.fork_join_maps.as_ref().unwrap().iter(),
self.reduce_cycles.as_ref().unwrap().iter(),
),
),
)
.map(|(function, (def_use, fork_join_map))| {
nodes_in_fork_joins(function, def_use, fork_join_map)
.map(|(function, (def_use, (fork_join_map, reduce_cycles)))| {
nodes_in_fork_joins(function, def_use, fork_join_map, reduce_cycles)
})
.collect(),
);
......
0% — Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment