Skip to content
Snippets Groups Projects
Commit 1b78bc2e authored by Russel Arbore's avatar Russel Arbore
Browse files

Refactor infer parallelreduce

parent da2cf755
No related branches found
No related tags found
1 merge request!206More misc. rodinia opts
Pipeline #201946 canceled
This commit is part of merge request !206. Comments created here will be created in the context of that merge request.
use std::collections::{BTreeSet, HashMap, HashSet}; use std::collections::{BTreeSet, HashMap, HashSet};
use std::iter::once;
use hercules_ir::def_use::*;
use hercules_ir::ir::*; use hercules_ir::ir::*;
use crate::*; use crate::*;
...@@ -42,6 +42,10 @@ pub fn infer_parallel_reduce( ...@@ -42,6 +42,10 @@ pub fn infer_parallel_reduce(
fork_join_map: &HashMap<NodeID, NodeID>, fork_join_map: &HashMap<NodeID, NodeID>,
reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) { ) {
let join_fork_map: HashMap<_, _> = fork_join_map
.into_iter()
.map(|(fork, join)| (*join, *fork))
.collect();
for id in editor.node_ids() { for id in editor.node_ids() {
let func = editor.func(); let func = editor.func();
if !func.nodes[id.idx()].is_reduce() { if !func.nodes[id.idx()].is_reduce() {
...@@ -98,40 +102,11 @@ pub fn infer_parallel_reduce( ...@@ -98,40 +102,11 @@ pub fn infer_parallel_reduce(
&& *collect == last_reduce && *collect == last_reduce
&& !reduce_cycles[&last_reduce].contains(data) && !reduce_cycles[&last_reduce].contains(data)
{ {
// If there is a Write-Reduce tight cycle, get the position indices. let is_parallel = indices_parallel_over_forks(
let positions = indices editor,
.iter() indices,
.filter_map(|index| { once(join_fork_map[&first_control.unwrap()]),
if let Index::Position(indices) = index { );
Some(indices)
} else {
None
}
})
.flat_map(|pos| pos.iter());
// Get the Forks corresponding to uses of bare ThreadIDs.
let fork_thread_id_pairs = positions.filter_map(|id| {
if let Node::ThreadID { control, dimension } = func.nodes[id.idx()] {
Some((control, dimension))
} else {
None
}
});
let mut forks = HashMap::<NodeID, Vec<usize>>::new();
for (fork, dim) in fork_thread_id_pairs {
forks.entry(fork).or_default().push(dim);
}
// Check if one of the Forks correspond to the Join associated with
// the Reduce being considered, and has all of its dimensions
// represented in the indexing.
let is_parallel = forks.into_iter().any(|(id, mut rep_dims)| {
rep_dims.sort();
rep_dims.dedup();
fork_join_map[&id] == first_control.unwrap()
&& func.nodes[id.idx()].try_fork().unwrap().1.len() == rep_dims.len()
});
if is_parallel { if is_parallel {
editor.edit(|edit| edit.add_schedule(id, Schedule::ParallelReduce)); editor.edit(|edit| edit.add_schedule(id, Schedule::ParallelReduce));
...@@ -145,6 +120,7 @@ pub fn infer_parallel_reduce( ...@@ -145,6 +120,7 @@ pub fn infer_parallel_reduce(
* operands must be the Reduce node, and all other operands must not be in the * operands must be the Reduce node, and all other operands must not be in the
* Reduce node's cycle. * Reduce node's cycle.
*/ */
#[rustfmt::skip]
pub fn infer_monoid_reduce( pub fn infer_monoid_reduce(
editor: &mut FunctionEditor, editor: &mut FunctionEditor,
reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
......
...@@ -23,6 +23,7 @@ fixpoint { ...@@ -23,6 +23,7 @@ fixpoint {
fork-guard-elim(*); fork-guard-elim(*);
fork-coalesce(*); fork-coalesce(*);
} }
reduce-slf(*);
simpl!(*); simpl!(*);
fork-split(*); fork-split(*);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment