Skip to content
Snippets Groups Projects
Commit 1b78bc2e authored by Russel Arbore's avatar Russel Arbore
Browse files

Refactor infer parallelreduce

parent da2cf755
No related branches found
No related tags found
1 merge request!206More misc. rodinia opts
Pipeline #201946 canceled
use std::collections::{BTreeSet, HashMap, HashSet};
use std::iter::once;
use hercules_ir::def_use::*;
use hercules_ir::ir::*;
use crate::*;
......@@ -42,6 +42,10 @@ pub fn infer_parallel_reduce(
fork_join_map: &HashMap<NodeID, NodeID>,
reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) {
let join_fork_map: HashMap<_, _> = fork_join_map
.into_iter()
.map(|(fork, join)| (*join, *fork))
.collect();
for id in editor.node_ids() {
let func = editor.func();
if !func.nodes[id.idx()].is_reduce() {
......@@ -98,40 +102,11 @@ pub fn infer_parallel_reduce(
&& *collect == last_reduce
&& !reduce_cycles[&last_reduce].contains(data)
{
// If there is a Write-Reduce tight cycle, get the position indices.
let positions = indices
.iter()
.filter_map(|index| {
if let Index::Position(indices) = index {
Some(indices)
} else {
None
}
})
.flat_map(|pos| pos.iter());
// Get the Forks corresponding to uses of bare ThreadIDs.
let fork_thread_id_pairs = positions.filter_map(|id| {
if let Node::ThreadID { control, dimension } = func.nodes[id.idx()] {
Some((control, dimension))
} else {
None
}
});
let mut forks = HashMap::<NodeID, Vec<usize>>::new();
for (fork, dim) in fork_thread_id_pairs {
forks.entry(fork).or_default().push(dim);
}
// Check if one of the Forks correspond to the Join associated with
// the Reduce being considered, and has all of its dimensions
// represented in the indexing.
let is_parallel = forks.into_iter().any(|(id, mut rep_dims)| {
rep_dims.sort();
rep_dims.dedup();
fork_join_map[&id] == first_control.unwrap()
&& func.nodes[id.idx()].try_fork().unwrap().1.len() == rep_dims.len()
});
let is_parallel = indices_parallel_over_forks(
editor,
indices,
once(join_fork_map[&first_control.unwrap()]),
);
if is_parallel {
editor.edit(|edit| edit.add_schedule(id, Schedule::ParallelReduce));
......@@ -145,6 +120,7 @@ pub fn infer_parallel_reduce(
* operands must be the Reduce node, and all other operands must not be in the
* Reduce node's cycle.
*/
#[rustfmt::skip]
pub fn infer_monoid_reduce(
editor: &mut FunctionEditor,
reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
......
......@@ -23,6 +23,7 @@ fixpoint {
fork-guard-elim(*);
fork-coalesce(*);
}
reduce-slf(*);
simpl!(*);
fork-split(*);
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment