Commit 0d27e11f authored by Russel Arbore

Merge branch 'main' into forkify-fixes

parents 6382ef42 259b6268
Related merge requests: !157 Fork fission bufferize, !144 Forkify fixes
Pipeline #201377 failed
@@ -1521,6 +1521,8 @@ extern \"C\" {} {}(",
                 let collect_variable = self.get_value(*collect, false, false);
                 write!(w, "{}{} = {};\n", tabs, define_variable, collect_variable)?;
             }
+            // Undef nodes never need to be assigned to.
+            Node::Undef { ty: _ } => {}
            _ => {
                panic!(
                    "Unsupported data node type: {:?}",
......
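The hunk above teaches the runtime codegen's match over data nodes to accept Node::Undef instead of falling through to the catch-all panic: an undefined value never has to be materialized, so no assignment is emitted for it. A minimal sketch of that pattern in plain Rust (the enum and helper below are made up for illustration, not the backend's actual API):

    use std::fmt::Write;

    // Hypothetical, simplified node type -- not the actual backend's.
    enum DataNode {
        Constant(i64),
        Undef,
    }

    fn emit_assignment(w: &mut String, var: &str, node: &DataNode) -> std::fmt::Result {
        match node {
            DataNode::Constant(c) => writeln!(w, "{} = {};", var, c),
            // Nothing to emit: an undefined value never needs to be materialized.
            DataNode::Undef => Ok(()),
        }
    }

    fn main() {
        let mut out = String::new();
        emit_assignment(&mut out, "v0", &DataNode::Constant(42)).unwrap();
        emit_assignment(&mut out, "v1", &DataNode::Undef).unwrap();
        print!("{}", out); // prints only "v0 = 42;"
    }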
@@ -181,8 +181,13 @@ fn preliminary_fixups(
     reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
 ) -> bool {
     let nodes = &editor.func().nodes;
+    let schedules = &editor.func().schedules;
+
+    // Sequentialize non-parallel forks that contain problematic reduce cycles.
     for (reduce, cycle) in reduce_cycles {
-        if cycle.into_iter().any(|id| nodes[id.idx()].is_reduce()) {
+        if !schedules[reduce.idx()].contains(&Schedule::ParallelReduce)
+            && cycle.into_iter().any(|id| nodes[id.idx()].is_reduce())
+        {
             let join = nodes[reduce.idx()].try_reduce().unwrap().0;
             let fork = fork_join_map
                 .into_iter()
@@ -198,6 +203,31 @@ fn preliminary_fixups(
             return true;
         }
     }
+
+    // Get rid of the backward edge on parallel reduces in fork-joins.
+    for (_, join) in fork_join_map {
+        let parallel_reduces: Vec<_> = editor
+            .get_users(*join)
+            .filter(|id| {
+                nodes[id.idx()].is_reduce()
+                    && schedules[id.idx()].contains(&Schedule::ParallelReduce)
+            })
+            .collect();
+        for reduce in parallel_reduces {
+            if reduce_cycles[&reduce].is_empty() {
+                continue;
+            }
+            let (_, init, _) = nodes[reduce.idx()].try_reduce().unwrap();
+
+            // Replace uses of the reduce in its cycle with the init.
+            let success = editor.edit(|edit| {
+                edit.replace_all_uses_where(reduce, init, |id| reduce_cycles[&reduce].contains(id))
+            });
+            assert!(success);
+
+            return true;
+        }
+    }
     false
 }
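The two hunks above hinge on what the ParallelReduce schedule guarantees: when a reduce is marked parallel, no iteration needs the value produced by a previous one, so uses of the reduce inside its own cycle can be rewired to the init and the backward edge disappears. A small, self-contained illustration of that distinction (ordinary Rust, not the compiler's IR; the names are made up):

    // A sequential reduction threads the accumulator through every iteration,
    // so each step genuinely depends on the previous one (the "backward edge").
    fn sequential_sum(xs: &[i32]) -> i32 {
        let mut acc = 0; // init
        for &x in xs {
            acc = acc + x; // the reduct reads the reduce value itself
        }
        acc
    }

    // A "parallel" reduction over disjoint slots never reads a value written by
    // an earlier iteration, so each iteration can start from the init directly;
    // nothing is lost by cutting the dependence on the previous iteration.
    fn parallel_doubles(xs: &[i32]) -> Vec<i32> {
        let init = vec![0; xs.len()];
        let mut out = init;
        for (i, &x) in xs.iter().enumerate() {
            out[i] = 2 * x; // disjoint writes: no cross-iteration dependence
        }
        out
    }

    fn main() {
        assert_eq!(sequential_sum(&[1, 2, 3]), 6);
        assert_eq!(parallel_doubles(&[1, 2, 3]), vec![2, 4, 6]);
    }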
@@ -511,7 +541,8 @@ fn basic_blocks(
             // outside of reduce loops. Nodes that do need to be in a reduce
             // loop use the reduce node forming the loop, so the dominator chain
             // will consist of one block, and this loop won't ever iterate.
-            let currently_at_join = function.nodes[location.idx()].is_join();
+            let currently_at_join = function.nodes[location.idx()].is_join()
+                && !function.nodes[control_node.idx()].is_join();
 
             if (!is_constant_or_undef || is_gpu_returned)
                 && (shallower_nest || currently_at_join)
@@ -811,7 +842,14 @@ fn spill_clones(
                 .into_iter()
                 .any(|u| *u == *b)
                 && (editor.func().nodes[a.idx()].is_phi()
-                    || editor.func().nodes[a.idx()].is_reduce()))
+                    || editor.func().nodes[a.idx()].is_reduce())
+                && !editor.func().nodes[a.idx()]
+                    .try_reduce()
+                    .map(|(_, init, _)| {
+                        init == *b
+                            && editor.func().schedules[a.idx()].contains(&Schedule::ParallelReduce)
+                    })
+                    .unwrap_or(false))
     });
 
     // Step 3: if there is a spill edge, spill it and return true. Otherwise,
@@ -989,15 +1027,16 @@ fn liveness_dataflow(
     }
     let mut num_phis_reduces = vec![0; function.nodes.len()];
     let mut has_phi = vec![false; function.nodes.len()];
-    let mut has_reduce = vec![false; function.nodes.len()];
+    let mut has_seq_reduce = vec![false; function.nodes.len()];
     for (node_idx, bb) in bbs.0.iter().enumerate() {
         let node = &function.nodes[node_idx];
         if node.is_phi() || node.is_reduce() {
             num_phis_reduces[bb.idx()] += 1;
         }
         has_phi[bb.idx()] = node.is_phi();
-        has_reduce[bb.idx()] = node.is_reduce();
-        assert!(!has_phi[bb.idx()] || !has_reduce[bb.idx()]);
+        has_seq_reduce[bb.idx()] =
+            node.is_reduce() && !function.schedules[node_idx].contains(&Schedule::ParallelReduce);
+        assert!(!node.is_phi() || !node.is_reduce());
     }
 
     let is_obj = |id: NodeID| !objects[&func_id].objects(id).is_empty();
@@ -1009,11 +1048,14 @@ fn liveness_dataflow(
             let last_pt = bbs.1[bb.idx()].len();
             let old_value = &liveness[&bb][last_pt];
             let mut new_value = BTreeSet::new();
-            for succ in control_subgraph.succs(*bb).chain(if has_reduce[bb.idx()] {
-                Either::Left(once(*bb))
-            } else {
-                Either::Right(empty())
-            }) {
+            for succ in control_subgraph
+                .succs(*bb)
+                .chain(if has_seq_reduce[bb.idx()] {
+                    Either::Left(once(*bb))
+                } else {
+                    Either::Right(empty())
+                })
+            {
                 // The liveness at the bottom of a basic block is the union of:
                 // 1. The liveness of each succecessor right after its phis and
                 //    reduces.
@@ -1041,7 +1083,9 @@ fn liveness_dataflow(
                         assert_eq!(control, succ);
                         if succ == *bb {
                             new_value.insert(reduct);
-                        } else {
+                        } else if !function.schedules[id.idx()]
+                            .contains(&Schedule::ParallelReduce)
+                        {
                             new_value.insert(init);
                         }
                     }
@@ -1058,6 +1102,7 @@ fn liveness_dataflow(
                 let mut new_value = liveness[&bb][pt + 1].clone();
                 let id = bbs.1[bb.idx()][pt];
                 let uses = get_uses(&function.nodes[id.idx()]);
+                let is_obj = |id: &NodeID| is_obj(*id);
                 new_value.remove(&id);
                 new_value.extend(
                     if let Node::Write {
@@ -1070,14 +1115,19 @@ fn liveness_dataflow(
                         // If this write is a cloning write, the `collect` input
                         // isn't actually live, because its value doesn't
                         // matter.
-                        Either::Left(once(data).filter(|id| is_obj(*id)))
+                        Either::Left(once(data).filter(is_obj))
+                    } else if let Node::Reduce {
+                        control: _,
+                        init: _,
+                        reduct,
+                    } = function.nodes[id.idx()]
+                        && function.schedules[id.idx()].contains(&Schedule::ParallelReduce)
+                    {
+                        // If this reduce is a parallel reduce, the `init` input
+                        // isn't actually live.
+                        Either::Left(once(reduct).filter(is_obj))
                     } else {
-                        Either::Right(
-                            uses.as_ref()
-                                .into_iter()
-                                .map(|id| *id)
-                                .filter(|id| is_obj(*id)),
-                        )
+                        Either::Right(uses.as_ref().into_iter().map(|id| *id).filter(is_obj))
                     },
                 );
                 changed |= *old_value != new_value;
......
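The liveness hunks above adjust two points in the backward dataflow: a block only feeds liveness back into itself for sequential reduces, and a parallel reduce no longer keeps its init input live (only its reduct). A compact sketch of the per-node transfer function under those rules; the types and names here are stand-ins for illustration, not the compiler's actual IR:

    use std::collections::BTreeSet;

    #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
    struct NodeId(usize);

    // Stand-in node representation for the sketch.
    enum Node {
        // A cloning write: `collect` is overwritten, so only `data` matters.
        Write { collect: NodeId, data: NodeId },
        Reduce { init: NodeId, reduct: NodeId, parallel: bool },
        Other { uses: Vec<NodeId> },
    }

    // Backward transfer: given liveness just after `id`, compute liveness just
    // before it -- kill the defined value, then add the inputs that are read.
    fn transfer(live_after: &BTreeSet<NodeId>, id: NodeId, node: &Node) -> BTreeSet<NodeId> {
        let mut live = live_after.clone();
        live.remove(&id);
        match node {
            // The old value of `collect` is irrelevant for a cloning write.
            Node::Write { data, .. } => {
                live.insert(*data);
            }
            // A parallel reduce never reads its init across iterations.
            Node::Reduce { reduct, parallel: true, .. } => {
                live.insert(*reduct);
            }
            // A sequential reduce keeps both the init and the reduct live.
            Node::Reduce { init, reduct, .. } => {
                live.insert(*init);
                live.insert(*reduct);
            }
            Node::Other { uses } => {
                live.extend(uses.iter().copied());
            }
        }
        live
    }

    fn main() {
        let after: BTreeSet<NodeId> = [NodeId(7)].into_iter().collect();
        let reduce = Node::Reduce { init: NodeId(1), reduct: NodeId(2), parallel: true };
        let before = transfer(&after, NodeId(7), &reduce);
        // Only the reduct (2) is live before a parallel reduce, not the init (1).
        assert!(before.contains(&NodeId(2)) && !before.contains(&NodeId(1)));
    }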
@@ -62,7 +62,7 @@ fn test4(input : i32) -> i32[4, 4] {
 #[entry]
 fn test5(input : i32) -> i32[4] {
-  let arr1 : i32[4];
+  @cons let arr1 : i32[4];
   for i = 0 to 4 {
     let red = arr1[i];
     for k = 0 to 3 {
......
+no-memset(test5@cons);
 parallel-reduce(test5@reduce);
 gvn(*);
......
@@ -46,6 +46,6 @@ fn main() {
 }
 
 #[test]
-fn implicit_clone_test() {
+fn fork_join_test() {
     main();
 }