diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 3b35f73ed370984753fa8e416f3b26c24951283f..1bfc01fef2885dbb2b6b6c5fdbda97f594e32010 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -126,7 +126,7 @@ impl<'a> RTContext<'a> { // Dump the function signature. write!( w, - "#[allow(unused_variables,unused_mut,unused_parens,unused_unsafe,non_snake_case)]\nasync unsafe fn {}(", + "#[allow(unused_assignments,unused_variables,unused_mut,unused_parens,unused_unsafe,non_snake_case)]\nasync unsafe fn {}(", func.name )?; let mut first_param = true; diff --git a/hercules_ir/src/fork_join_analysis.rs b/hercules_ir/src/fork_join_analysis.rs index 1d089d7630ce1b90e549e5e2ddf96dcfa2678995..263fa952f834ae82ba2bd0b35ea2ff8d47a56c25 100644 --- a/hercules_ir/src/fork_join_analysis.rs +++ b/hercules_ir/src/fork_join_analysis.rs @@ -73,3 +73,132 @@ pub fn compute_fork_join_nesting( }) .collect() } + +/* + * Top level function to calculate reduce cycles. Returns for each reduce node + * what other nodes form a cycle with that reduce node. The strict definition of + * a reduce cycle is the iterated set of uses of the `reduct` input to the + * reduce that through iterated uses use the reduce, without going through phis + * or reduces outside the fork-join. + */ +pub fn reduce_cycles( + function: &Function, + fork_join_map: &HashMap<NodeID, NodeID>, + fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, +) -> HashMap<NodeID, HashSet<NodeID>> { + let reduces = (0..function.nodes.len()) + .filter(|idx| function.nodes[*idx].is_reduce()) + .map(NodeID::new); + let mut result = HashMap::new(); + let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map + .into_iter() + .map(|(fork, join)| (*join, *fork)) + .collect(); + + for reduce in reduces { + let (join, _, reduct) = function.nodes[reduce.idx()].try_reduce().unwrap(); + let fork = join_fork_map[&join]; + + // DFS the uses of `reduct` until finding the reduce itself. + let mut current_visited = HashSet::new(); + let mut in_cycle = HashSet::new(); + reduce_cycle_dfs_helper( + function, + reduct, + fork, + reduce, + &mut current_visited, + &mut in_cycle, + fork_join_nest, + ); + result.insert(reduce, in_cycle); + } + + result +} + +fn reduce_cycle_dfs_helper( + function: &Function, + iter: NodeID, + fork: NodeID, + reduce: NodeID, + current_visited: &mut HashSet<NodeID>, + in_cycle: &mut HashSet<NodeID>, + fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, +) -> bool { + let isnt_outside_fork_join = |id: NodeID| { + let node = &function.nodes[id.idx()]; + node.try_phi() + .map(|(control, _)| control) + .or(node.try_reduce().map(|(control, _, _)| control)) + .map(|control| fork_join_nest[&control].contains(&fork)) + .unwrap_or(true) + }; + + if iter == reduce || in_cycle.contains(&iter) { + return true; + } + + current_visited.insert(iter); + let found_reduce = get_uses(&function.nodes[iter.idx()]) + .as_ref() + .into_iter() + .any(|u| { + !current_visited.contains(u) + && !function.nodes[u.idx()].is_control() + && isnt_outside_fork_join(*u) + && reduce_cycle_dfs_helper( + function, + *u, + fork, + reduce, + current_visited, + in_cycle, + fork_join_nest, + ) + }); + if found_reduce { + in_cycle.insert(iter); + } + current_visited.remove(&iter); + found_reduce +} + +/* + * Top level function to calculate which data nodes are "inside" a fork-join, + * not including its reduces. + */ +pub fn data_nodes_in_fork_joins( + function: &Function, + def_use: &ImmutableDefUseMap, + fork_join_map: &HashMap<NodeID, NodeID>, +) -> HashMap<NodeID, HashSet<NodeID>> { + let mut result = HashMap::new(); + + for (fork, join) in fork_join_map { + let mut worklist = vec![*fork]; + let mut set = HashSet::new(); + + while let Some(item) = worklist.pop() { + for u in def_use.get_users(item) { + if function.nodes[u.idx()].is_control() + || function.nodes[u.idx()] + .try_reduce() + .map(|(control, _, _)| control == *join) + .unwrap_or(false) + { + // Ignore control users and reduces of the fork-join. + continue; + } + if !set.contains(u) { + set.insert(*u); + worklist.push(*u); + } + } + } + + result.insert(*fork, set); + } + + result +} diff --git a/hercules_ir/src/loops.rs b/hercules_ir/src/loops.rs index 9a964cce71828cf18f943ae55e39c9eca63215e4..a425c4428eb627f6cbba9ecf38a200023c27ef36 100644 --- a/hercules_ir/src/loops.rs +++ b/hercules_ir/src/loops.rs @@ -1,6 +1,6 @@ use std::collections::hash_map; +use std::collections::HashMap; use std::collections::VecDeque; -use std::collections::{HashMap, HashSet}; use bitvec::prelude::*; @@ -225,119 +225,3 @@ fn loop_reachability_helper( visited } } - -/* - * Top level function to calculate reduce cycles. Returns for each reduce node - * what other nodes form a cycle with that reduce node. - */ -pub fn reduce_cycles( - function: &Function, - def_use: &ImmutableDefUseMap, - fork_join_map: &HashMap<NodeID, NodeID>, - fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, -) -> HashMap<NodeID, HashSet<NodeID>> { - let reduces = (0..function.nodes.len()) - .filter(|idx| function.nodes[*idx].is_reduce()) - .map(NodeID::new); - let mut result = HashMap::new(); - let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map - .into_iter() - .map(|(fork, join)| (*join, *fork)) - .collect(); - - for reduce in reduces { - let (join, _, reduct) = function.nodes[reduce.idx()].try_reduce().unwrap(); - let fork = join_fork_map[&join]; - let isnt_outside_fork_join = |id: NodeID| { - let node = &function.nodes[id.idx()]; - node.try_phi() - .map(|(control, _)| control) - .or(node.try_reduce().map(|(control, _, _)| control)) - .map(|control| fork_join_nest[&fork].contains(&control)) - .unwrap_or(true) - }; - - // First, find all data nodes that are used by the `reduct` input of the - // reduce, including the `reduct` itself. - let mut use_reachable = HashSet::new(); - use_reachable.insert(reduct); - let mut worklist = vec![reduct]; - while let Some(item) = worklist.pop() { - for u in get_uses(&function.nodes[item.idx()]).as_ref() { - if !function.nodes[u.idx()].is_control() - && !use_reachable.contains(u) - && isnt_outside_fork_join(*u) - { - use_reachable.insert(*u); - worklist.push(*u); - } - } - } - - // Second, find all data nodes thare are users of the reduce node. - let mut user_reachable = HashSet::new(); - let mut worklist = vec![reduce]; - while let Some(item) = worklist.pop() { - for u in def_use.get_users(item) { - if !function.nodes[u.idx()].is_control() - && !user_reachable.contains(u) - && isnt_outside_fork_join(*u) - { - user_reachable.insert(*u); - worklist.push(*u); - } - } - } - - // Nodes that are both use-reachable and user-reachable by the reduce - // node are in the reduce node's cycle. - result.insert( - reduce, - use_reachable - .intersection(&user_reachable) - .map(|id| *id) - .collect(), - ); - } - - result -} - -/* - * Top level function to calculate which data nodes are "inside" a fork-join, - * not including its reduces. - */ -pub fn data_nodes_in_fork_joins( - function: &Function, - def_use: &ImmutableDefUseMap, - fork_join_map: &HashMap<NodeID, NodeID>, -) -> HashMap<NodeID, HashSet<NodeID>> { - let mut result = HashMap::new(); - - for (fork, join) in fork_join_map { - let mut worklist = vec![*fork]; - let mut set = HashSet::new(); - - while let Some(item) = worklist.pop() { - for u in def_use.get_users(item) { - if function.nodes[u.idx()].is_control() - || function.nodes[u.idx()] - .try_reduce() - .map(|(control, _, _)| control == *join) - .unwrap_or(false) - { - // Ignore control users and reduces of the fork-join. - continue; - } - if !set.contains(u) { - set.insert(*u); - worklist.push(*u); - } - } - } - - result.insert(*fork, set); - } - - result -} diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 570aa3f12f74cf4b92fa8a22045a91f543634e15..e62bc78dc4a6e4bb2e1c369fcd2fffc67255ea52 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -353,20 +353,17 @@ impl PassManager { pub fn make_reduce_cycles(&mut self) { if self.reduce_cycles.is_none() { - self.make_def_uses(); self.make_fork_join_maps(); self.make_fork_join_nests(); - let def_uses = self.def_uses.as_ref().unwrap().iter(); let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter(); let fork_join_nests = self.fork_join_nests.as_ref().unwrap().iter(); self.reduce_cycles = Some( self.functions .iter() - .zip(def_uses) .zip(fork_join_maps) .zip(fork_join_nests) - .map(|(((function, def_use), fork_join_map), fork_join_nest)| { - reduce_cycles(function, def_use, fork_join_map, fork_join_nest) + .map(|((function, fork_join_map), fork_join_nest)| { + reduce_cycles(function, fork_join_map, fork_join_nest) }) .collect(), );