diff --git a/hercules_ir/src/loops.rs b/hercules_ir/src/loops.rs index f7bd54684214bd4651c5799ae3d21660d48305fe..9a964cce71828cf18f943ae55e39c9eca63215e4 100644 --- a/hercules_ir/src/loops.rs +++ b/hercules_ir/src/loops.rs @@ -248,6 +248,14 @@ pub fn reduce_cycles( for reduce in reduces { let (join, _, reduct) = function.nodes[reduce.idx()].try_reduce().unwrap(); let fork = join_fork_map[&join]; + let isnt_outside_fork_join = |id: NodeID| { + let node = &function.nodes[id.idx()]; + node.try_phi() + .map(|(control, _)| control) + .or(node.try_reduce().map(|(control, _, _)| control)) + .map(|control| fork_join_nest[&fork].contains(&control)) + .unwrap_or(true) + }; // First, find all data nodes that are used by the `reduct` input of the // reduce, including the `reduct` itself. @@ -258,10 +266,7 @@ pub fn reduce_cycles( for u in get_uses(&function.nodes[item.idx()]).as_ref() { if !function.nodes[u.idx()].is_control() && !use_reachable.contains(u) - && function.nodes[u.idx()] - .try_phi() - .map(|(control, _)| fork_join_nest[&fork].contains(&control)) - .unwrap_or(true) + && isnt_outside_fork_join(*u) { use_reachable.insert(*u); worklist.push(*u); @@ -274,7 +279,10 @@ pub fn reduce_cycles( let mut worklist = vec![reduce]; while let Some(item) = worklist.pop() { for u in def_use.get_users(item) { - if !function.nodes[u.idx()].is_control() && !user_reachable.contains(u) { + if !function.nodes[u.idx()].is_control() + && !user_reachable.contains(u) + && isnt_outside_fork_join(*u) + { user_reachable.insert(*u); worklist.push(*u); }