diff --git a/hercules_opt/src/schedule.rs b/hercules_opt/src/schedule.rs index 9bc7823ee7f5837cf49387170e548a9174340f42..10eca72e4b9c9aa6285f04cc8812d58f459d556f 100644 --- a/hercules_opt/src/schedule.rs +++ b/hercules_opt/src/schedule.rs @@ -1,6 +1,6 @@ use std::collections::{BTreeSet, HashMap, HashSet}; +use std::iter::once; -use hercules_ir::def_use::*; use hercules_ir::ir::*; use crate::*; @@ -42,6 +42,10 @@ pub fn infer_parallel_reduce( fork_join_map: &HashMap<NodeID, NodeID>, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, ) { + let join_fork_map: HashMap<_, _> = fork_join_map + .into_iter() + .map(|(fork, join)| (*join, *fork)) + .collect(); for id in editor.node_ids() { let func = editor.func(); if !func.nodes[id.idx()].is_reduce() { @@ -98,40 +102,11 @@ pub fn infer_parallel_reduce( && *collect == last_reduce && !reduce_cycles[&last_reduce].contains(data) { - // If there is a Write-Reduce tight cycle, get the position indices. - let positions = indices - .iter() - .filter_map(|index| { - if let Index::Position(indices) = index { - Some(indices) - } else { - None - } - }) - .flat_map(|pos| pos.iter()); - - // Get the Forks corresponding to uses of bare ThreadIDs. - let fork_thread_id_pairs = positions.filter_map(|id| { - if let Node::ThreadID { control, dimension } = func.nodes[id.idx()] { - Some((control, dimension)) - } else { - None - } - }); - let mut forks = HashMap::<NodeID, Vec<usize>>::new(); - for (fork, dim) in fork_thread_id_pairs { - forks.entry(fork).or_default().push(dim); - } - - // Check if one of the Forks correspond to the Join associated with - // the Reduce being considered, and has all of its dimensions - // represented in the indexing. - let is_parallel = forks.into_iter().any(|(id, mut rep_dims)| { - rep_dims.sort(); - rep_dims.dedup(); - fork_join_map[&id] == first_control.unwrap() - && func.nodes[id.idx()].try_fork().unwrap().1.len() == rep_dims.len() - }); + let is_parallel = indices_parallel_over_forks( + editor, + indices, + once(join_fork_map[&first_control.unwrap()]), + ); if is_parallel { editor.edit(|edit| edit.add_schedule(id, Schedule::ParallelReduce)); @@ -145,6 +120,7 @@ pub fn infer_parallel_reduce( * operands must be the Reduce node, and all other operands must not be in the * Reduce node's cycle. */ +#[rustfmt::skip] pub fn infer_monoid_reduce( editor: &mut FunctionEditor, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, diff --git a/juno_samples/rodinia/backprop/src/cpu.sch b/juno_samples/rodinia/backprop/src/cpu.sch index d59fd5f582dbb33ce5d786896b17673e0986776f..de34d660bcc5e3d95d58aa63524bffdbc0b8f67e 100644 --- a/juno_samples/rodinia/backprop/src/cpu.sch +++ b/juno_samples/rodinia/backprop/src/cpu.sch @@ -23,6 +23,7 @@ fixpoint { fork-guard-elim(*); fork-coalesce(*); } +reduce-slf(*); simpl!(*); fork-split(*);