Skip to content
Snippets Groups Projects

Fix reduce cycles and nodes in fork joins

Merged rarbore2 requested to merge fix-juno_matmul-schedule into main
6 files
+ 63
104
Compare changes
  • Side-by-side
  • Inline
Files
6
@@ -87,7 +87,7 @@ pub fn reduce_cycles(
function: &Function,
def_use: &ImmutableDefUseMap,
fork_join_map: &HashMap<NodeID, NodeID>,
nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
) -> HashMap<NodeID, HashSet<NodeID>> {
let reduces = (0..function.nodes.len())
.filter(|idx| function.nodes[*idx].is_reduce())
@@ -101,6 +101,24 @@ pub fn reduce_cycles(
for reduce in reduces {
let (join, _, reduct) = function.nodes[reduce.idx()].try_reduce().unwrap();
let fork = join_fork_map[&join];
let might_be_in_fork_join = |id| {
fork_join_nest
.get(&id)
.map(|nest| nest.contains(&fork))
.unwrap_or(true)
&& function.nodes[id.idx()]
.try_phi()
.map(|(control, _)| fork_join_nest[&control].contains(&fork))
.unwrap_or(true)
&& function.nodes[id.idx()]
.try_reduce()
.map(|(control, _, _)| fork_join_nest[&control].contains(&fork))
.unwrap_or(true)
&& !function.nodes[id.idx()].is_parameter()
&& !function.nodes[id.idx()].is_constant()
&& !function.nodes[id.idx()].is_dynamic_constant()
&& !function.nodes[id.idx()].is_undef()
};
// Find nodes in the fork-join that the reduce can reach through uses.
let mut reachable_uses = HashSet::new();
@@ -109,10 +127,7 @@ pub fn reduce_cycles(
workset.push(reduct);
while let Some(pop) = workset.pop() {
for u in get_uses(&function.nodes[pop.idx()]).as_ref() {
if !reachable_uses.contains(u)
&& nodes_in_fork_joins[&fork].contains(u)
&& *u != reduce
{
if !reachable_uses.contains(u) && might_be_in_fork_join(*u) && *u != reduce {
reachable_uses.insert(*u);
workset.push(*u);
}
@@ -126,10 +141,7 @@ pub fn reduce_cycles(
workset.push(reduce);
while let Some(pop) = workset.pop() {
for u in def_use.get_users(pop) {
if !reachable_users.contains(u)
&& nodes_in_fork_joins[&fork].contains(u)
&& *u != reduce
{
if !reachable_users.contains(u) && might_be_in_fork_join(*u) && *u != reduce {
reachable_users.insert(*u);
workset.push(*u);
}
@@ -155,6 +167,7 @@ pub fn nodes_in_fork_joins(
function: &Function,
def_use: &ImmutableDefUseMap,
fork_join_map: &HashMap<NodeID, NodeID>,
reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) -> HashMap<NodeID, HashSet<NodeID>> {
let mut result = HashMap::new();
@@ -164,6 +177,7 @@ pub fn nodes_in_fork_joins(
let mut set = HashSet::new();
set.insert(*fork);
// Iterate uses of the fork.
while let Some(item) = worklist.pop() {
for u in def_use.get_users(item) {
let terminate = *u == *join
@@ -177,6 +191,15 @@ pub fn nodes_in_fork_joins(
set.insert(*u);
}
}
assert!(set.contains(join));
// Add all the nodes in the reduce cycle. Some of these nodes may not
// use thread IDs of the fork, so do this explicitly.
for u in def_use.get_users(*join) {
if let Some(cycle) = reduce_cycles.get(u) {
set.extend(cycle);
}
}
result.insert(*fork, set);
}
Loading