diff --git a/hercules_ir/src/fork_join_analysis.rs b/hercules_ir/src/fork_join_analysis.rs index 3fcc6af029a50839c0b382be371db6fa593e1119..bba6ac42a4a2479f8309b575fe1fb1030f5b5a21 100644 --- a/hercules_ir/src/fork_join_analysis.rs +++ b/hercules_ir/src/fork_join_analysis.rs @@ -85,8 +85,9 @@ pub fn compute_fork_join_nesting( */ pub fn reduce_cycles( function: &Function, + def_use: &ImmutableDefUseMap, fork_join_map: &HashMap<NodeID, NodeID>, - fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, ) -> HashMap<NodeID, HashSet<NodeID>> { let reduces = (0..function.nodes.len()) .filter(|idx| function.nodes[*idx].is_reduce()) @@ -101,69 +102,50 @@ pub fn reduce_cycles( let (join, _, reduct) = function.nodes[reduce.idx()].try_reduce().unwrap(); let fork = join_fork_map[&join]; - // DFS the uses of `reduct` until finding the reduce itself. - let mut current_visited = HashSet::new(); - let mut in_cycle = HashSet::new(); - reduce_cycle_dfs_helper( - function, - reduct, - fork, - reduce, - &mut current_visited, - &mut in_cycle, - fork_join_nest, - ); - result.insert(reduce, in_cycle); - } + // Find nodes in the fork-join that the reduce can reach through uses. + let mut reachable_uses = HashSet::new(); + let mut workset = vec![]; + reachable_uses.insert(reduct); + workset.push(reduct); + while let Some(pop) = workset.pop() { + for u in get_uses(&function.nodes[pop.idx()]).as_ref() { + if !reachable_uses.contains(u) + && nodes_in_fork_joins[&fork].contains(u) + && *u != reduce + { + reachable_uses.insert(*u); + workset.push(*u); + } + } + } - result -} + // Find nodes in the fork-join that the reduce can reach through users. + let mut reachable_users = HashSet::new(); + workset.clear(); + reachable_users.insert(reduce); + workset.push(reduce); + while let Some(pop) = workset.pop() { + for u in def_use.get_users(pop) { + if !reachable_users.contains(u) + && nodes_in_fork_joins[&fork].contains(u) + && *u != reduce + { + reachable_users.insert(*u); + workset.push(*u); + } + } + } -fn reduce_cycle_dfs_helper( - function: &Function, - iter: NodeID, - fork: NodeID, - reduce: NodeID, - current_visited: &mut HashSet<NodeID>, - in_cycle: &mut HashSet<NodeID>, - fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, -) -> bool { - let isnt_outside_fork_join = |id: NodeID| { - let node = &function.nodes[id.idx()]; - node.try_phi() - .map(|(control, _)| control) - .or(node.try_reduce().map(|(control, _, _)| control)) - .map(|control| fork_join_nest[&control].contains(&fork)) - .unwrap_or(true) - }; - - if iter == reduce || in_cycle.contains(&iter) { - return true; + // The reduce cycle is the insersection of nodes reachable through uses + // and users. + let intersection = reachable_uses + .intersection(&reachable_users) + .map(|id| *id) + .collect(); + result.insert(reduce, intersection); } - current_visited.insert(iter); - let mut found_reduce = false; - - // This doesn't short circuit on purpose. - for u in get_uses(&function.nodes[iter.idx()]).as_ref() { - found_reduce |= !current_visited.contains(u) - && !function.nodes[u.idx()].is_control() - && isnt_outside_fork_join(*u) - && reduce_cycle_dfs_helper( - function, - *u, - fork, - reduce, - current_visited, - in_cycle, - fork_join_nest, - ) - } - if found_reduce { - in_cycle.insert(iter); - } - current_visited.remove(&iter); - found_reduce + result } /* diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 9c51276b82aa52fde5aa8dd13295b69aea7fb064..8c2ecb1969c1dd8a340a0f4bb54fcb61b6fc852e 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -399,18 +399,23 @@ impl PassManager { pub fn make_reduce_cycles(&mut self) { if self.reduce_cycles.is_none() { + self.make_def_uses(); self.make_fork_join_maps(); - self.make_fork_join_nests(); + self.make_nodes_in_fork_joins(); + let def_uses = self.def_uses.as_ref().unwrap().iter(); let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter(); - let fork_join_nests = self.fork_join_nests.as_ref().unwrap().iter(); + let nodes_in_fork_joins = self.nodes_in_fork_joins.as_ref().unwrap().iter(); self.reduce_cycles = Some( self.functions .iter() .zip(fork_join_maps) - .zip(fork_join_nests) - .map(|((function, fork_join_map), fork_join_nest)| { - reduce_cycles(function, fork_join_map, fork_join_nest) - }) + .zip(nodes_in_fork_joins) + .zip(def_uses) + .map( + |(((function, fork_join_map), nodes_in_fork_joins), def_use)| { + reduce_cycles(function, def_use, fork_join_map, nodes_in_fork_joins) + }, + ) .collect(), ); }