//! Passes that infer schedules on nodes in a function.
use std::collections::{BTreeSet, HashMap, HashSet};
use crate::*;
/*
* Infer parallel fork-joins. These are fork-joins with only parallel reduction
* variables.
*/
pub fn infer_parallel_fork(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
for id in editor.node_ids() {
let func = editor.func();
let join_id = fork_join_map[&id];
let all_parallel_reduce = editor.get_users(join_id).all(|user| {
func.schedules[user.idx()].contains(&Schedule::ParallelReduce)
|| func.nodes[user.idx()].is_control()
});
if all_parallel_reduce {
editor.edit(|edit| edit.add_schedule(id, Schedule::ParallelFork));
}
}
}
/*
* Infer parallel reductions consisting of a simple cycle between a Reduce node
* and a Write node, where indices of the Write are position indices using the
* ThreadID nodes attached to the corresponding Fork, and data of the Write is
* not in the Reduce node's cycle. This procedure also adds the ParallelReduce
* schedule to Reduce nodes reducing over a parallelized Reduce, as long as the
* base Write node also has position indices of the ThreadID of the outer fork.
* In other words, the complete Reduce chain is annotated with ParallelReduce,
* as long as each ThreadID dimension appears in the positional indexing of the
 * original Write.
 */
pub fn infer_parallel_reduce(
editor: &mut FunctionEditor,
fork_join_map: &HashMap<NodeID, NodeID>,
reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) {
for id in editor.node_ids() {
let func = editor.func();
if !func.nodes[id.idx()].is_reduce() {
continue;
}
let mut first_control = None;
let mut last_reduce = id;
let mut chain_id = id;
// Walk down Reduce chain until we reach the Reduce potentially looping
// with the Write. Note the control node of the first Reduce, since this
// will tell us which Thread ID to look for in the Write.
while let Node::Reduce {
control,
init: _,
reduct,
} = func.nodes[chain_id.idx()]
{
if first_control.is_none() {
first_control = Some(control);
}
last_reduce = chain_id;
chain_id = reduct;
}
// If the use is a phi that uses the reduce and a write, then we might
// want to parallelize this still. Set the chain ID to the write.
if let Node::Phi {
control: _,
ref data,
} = func.nodes[chain_id.idx()]
&& data.len()
== data
.into_iter()
.filter(|phi_use| **phi_use == last_reduce)
.count()
+ 1
{
chain_id = *data
.into_iter()
.filter(|phi_use| **phi_use != last_reduce)
.next()
.unwrap();
}
// Check for a Write-Reduce tight cycle.
if let Node::Write {
collect,
indices,
} = &func.nodes[chain_id.idx()]
&& *collect == last_reduce
&& !reduce_cycles[&last_reduce].contains(data)
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
{
// If there is a Write-Reduce tight cycle, get the position indices.
let positions = indices
.iter()
.filter_map(|index| {
if let Index::Position(indices) = index {
Some(indices)
} else {
None
}
})
.flat_map(|pos| pos.iter());
// Get the Forks corresponding to uses of bare ThreadIDs.
let fork_thread_id_pairs = positions.filter_map(|id| {
if let Node::ThreadID { control, dimension } = func.nodes[id.idx()] {
Some((control, dimension))
} else {
None
}
});
let mut forks = HashMap::<NodeID, Vec<usize>>::new();
for (fork, dim) in fork_thread_id_pairs {
forks.entry(fork).or_default().push(dim);
}
// Check if one of the Forks correspond to the Join associated with
// the Reduce being considered, and has all of its dimensions
// represented in the indexing.
let is_parallel = forks.into_iter().any(|(id, mut rep_dims)| {
rep_dims.sort();
rep_dims.dedup();
fork_join_map[&id] == first_control.unwrap()
&& func.nodes[id.idx()].try_fork().unwrap().1.len() == rep_dims.len()
});
if is_parallel {
editor.edit(|edit| edit.add_schedule(id, Schedule::ParallelReduce));
}
}
}
}
/*
* Infer monoid reduction loops. Exactly one of the associative operation's
* operands must be the Reduce node, and all other operands must not be in the
 * Reduce node's cycle.
 */
editor: &mut FunctionEditor,
reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) {
op == BinaryOperator::Add
|| op == BinaryOperator::Mul
|| op == BinaryOperator::Or
|| op == BinaryOperator::And
let is_intrinsic_monoid =
|intrinsic| intrinsic == Intrinsic::Max || intrinsic == Intrinsic::Min;
for id in editor.node_ids() {
let func = editor.func();
if let Node::Reduce {
control: _,
init: _,
reduct,
} = func.nodes[id.idx()]
&& (matches!(func.nodes[reduct.idx()], Node::Binary { left, right, op }
if ((left == id && !reduce_cycles[&id].contains(&right)) ||
(right == id && !reduce_cycles[&id].contains(&left))) &&
|| matches!(&func.nodes[reduct.idx()], Node::IntrinsicCall { intrinsic, args }
args.iter().filter(|arg| **arg != id).all(|arg| !reduce_cycles[&id].contains(arg)))))
editor.edit(|edit| edit.add_schedule(id, Schedule::MonoidReduce));
/*
* From analysis result of which constants don't need to be reset, add schedules
* to those constant nodes.
*/
pub fn infer_no_reset_constants(
    editor: &mut FunctionEditor,
    no_reset_constants: &BTreeSet<NodeID>,
) {
    // Attach the NoResetConstant schedule to every node the analysis flagged.
    no_reset_constants.iter().for_each(|id| {
        editor.edit(|edit| edit.add_schedule(*id, Schedule::NoResetConstant));
    });
}