diff --git a/hercules_ir/src/loops.rs b/hercules_ir/src/loops.rs index 1d706c7834cf30fa3bf5e556d812917942a48d8b..f7bd54684214bd4651c5799ae3d21660d48305fe 100644 --- a/hercules_ir/src/loops.rs +++ b/hercules_ir/src/loops.rs @@ -233,14 +233,21 @@ fn loop_reachability_helper( pub fn reduce_cycles( function: &Function, def_use: &ImmutableDefUseMap, + fork_join_map: &HashMap<NodeID, NodeID>, + fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, ) -> HashMap<NodeID, HashSet<NodeID>> { let reduces = (0..function.nodes.len()) .filter(|idx| function.nodes[*idx].is_reduce()) .map(NodeID::new); let mut result = HashMap::new(); + let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map + .into_iter() + .map(|(fork, join)| (*join, *fork)) + .collect(); for reduce in reduces { - let (_, _, reduct) = function.nodes[reduce.idx()].try_reduce().unwrap(); + let (join, _, reduct) = function.nodes[reduce.idx()].try_reduce().unwrap(); + let fork = join_fork_map[&join]; // First, find all data nodes that are used by the `reduct` input of the // reduce, including the `reduct` itself. @@ -249,7 +256,13 @@ pub fn reduce_cycles( let mut worklist = vec![reduct]; while let Some(item) = worklist.pop() { for u in get_uses(&function.nodes[item.idx()]).as_ref() { - if !function.nodes[u.idx()].is_control() && !use_reachable.contains(u) { + if !function.nodes[u.idx()].is_control() + && !use_reachable.contains(u) + && function.nodes[u.idx()] + .try_phi() + .map(|(control, _)| fork_join_nest[&fork].contains(&control)) + .unwrap_or(true) + { use_reachable.insert(*u); worklist.push(*u); } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index aa54006491d9ee11b7cb92c22d9c507104a546cb..570aa3f12f74cf4b92fa8a22045a91f543634e15 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -354,10 +354,20 @@ impl PassManager { pub fn make_reduce_cycles(&mut self) { if self.reduce_cycles.is_none() { self.make_def_uses(); + self.make_fork_join_maps(); + self.make_fork_join_nests(); let def_uses = self.def_uses.as_ref().unwrap().iter(); + let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter(); + let fork_join_nests = self.fork_join_nests.as_ref().unwrap().iter(); self.reduce_cycles = Some( - zip(self.functions.iter(), def_uses) - .map(|(function, def_use)| reduce_cycles(function, def_use)) + self.functions + .iter() + .zip(def_uses) + .zip(fork_join_maps) + .zip(fork_join_nests) + .map(|(((function, def_use), fork_join_map), fork_join_nest)| { + reduce_cycles(function, def_use, fork_join_map, fork_join_nest) + }) .collect(), ); }