use std::collections::{BTreeSet, HashMap, HashSet};

use hercules_ir::def_use::*;
use hercules_ir::ir::*;

use crate::*;

/*
 * Infer parallel fork-joins. These are fork-joins with only parallel reduction
 * variables: every non-control user of the Join must already carry the
 * ParallelReduce schedule.
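 *
 * As a rough sketch (loose textual IR, not exact Hercules syntax), the
 * following fork-join qualifies, since the only users of the Join are a
 * ParallelReduce-scheduled Reduce and a downstream control node:
 *
 *   fork = fork(pred, 16)
 *   join = join(fork)
 *   red  = reduce(join, init, reduct)  // already scheduled ParallelReduce
 *   ret  = return(join, red)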
*/
pub fn infer_parallel_fork(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
for id in editor.node_ids() {
let func = editor.func();
if !func.nodes[id.idx()].is_fork() {
continue;
}
let join_id = fork_join_map[&id];
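        // The Fork is parallel when every user of its Join is either another
        // control node or a Reduce already scheduled as ParallelReduce.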
let all_parallel_reduce = editor.get_users(join_id).all(|user| {
func.schedules[user.idx()].contains(&Schedule::ParallelReduce)
|| func.nodes[user.idx()].is_control()
});
if all_parallel_reduce {
editor.edit(|edit| edit.add_schedule(id, Schedule::ParallelFork));
}
}
}

/*
 * Infer parallel reductions consisting of a simple cycle between a Reduce node
 * and a Write node, where the indices of the Write are position indices using
 * the ThreadID nodes attached to the corresponding Fork, and where the data
 * being written is not in the Reduce node's cycle. This procedure also adds
 * the ParallelReduce schedule to Reduce nodes reducing over a parallelized
 * Reduce, as long as the base Write node also has position indices using the
 * ThreadIDs of the outer Fork. In other words, the complete Reduce chain is
 * annotated with ParallelReduce, as long as each ThreadID dimension appears in
 * the positional indexing of the original Write.
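 *
 * As a rough sketch (loose textual IR, not exact Hercules syntax):
 *
 *   fork  = fork(pred, 8, 8)
 *   tid0  = thread_id(fork, 0)
 *   tid1  = thread_id(fork, 1)
 *   join  = join(fork)
 *   red   = reduce(join, init, write)
 *   write = write(red, data, position(tid0, tid1))
 *
 * Here the Write forms a tight cycle with the Reduce, both dimensions of the
 * Fork appear as position indices, and data comes from outside the cycle, so
 * red is marked ParallelReduce.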
*/
pub fn infer_parallel_reduce(
editor: &mut FunctionEditor,
fork_join_map: &HashMap<NodeID, NodeID>,
reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) {
for id in editor.node_ids() {
let func = editor.func();
if !func.nodes[id.idx()].is_reduce() {
continue;
}
let mut first_control = None;
let mut last_reduce = id;
let mut chain_id = id;
        // Walk down the Reduce chain until we reach the Reduce that
        // potentially cycles with the Write. Note the control node of the
        // first Reduce, since it tells us which ThreadIDs to look for in the
        // Write.
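        // For example (rough sketch), with a nested reduction
        //   red_outer = reduce(join_outer, init0, red_inner)
        //   red_inner = reduce(join_inner, init1, write)
        // the walk from red_outer ends with chain_id at the Write, last_reduce
        // at red_inner, and first_control at join_outer.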
while let Node::Reduce {
control,
init: _,
reduct,
} = func.nodes[chain_id.idx()]
{
if first_control.is_none() {
first_control = Some(control);
}
last_reduce = chain_id;
chain_id = reduct;
}
// Check for a Write-Reduce tight cycle.
if let Node::Write {
collect,
data,
indices,
} = &func.nodes[chain_id.idx()]
&& *collect == last_reduce
&& !reduce_cycles[&last_reduce].contains(data)
{
// If there is a Write-Reduce tight cycle, get the position indices.
let positions = indices
.iter()
.filter_map(|index| {
if let Index::Position(indices) = index {
Some(indices)
} else {
None
}
})
.flat_map(|pos| pos.iter());
// Get the Forks corresponding to uses of bare ThreadIDs.
            let fork_thread_id_pairs = positions.filter_map(|pos_id| {
                if let Node::ThreadID { control, dimension } = func.nodes[pos_id.idx()] {
                    Some((control, dimension))
                } else {
                    None
                }
            });
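            // Group the ThreadID dimensions by the Fork that owns them, so
            // that each Fork's dimension coverage can be checked as a whole.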
let mut forks = HashMap::<NodeID, Vec<usize>>::new();
for (fork, dim) in fork_thread_id_pairs {
forks.entry(fork).or_default().push(dim);
}
            // Check if one of the Forks corresponds to the Join associated
            // with the Reduce being considered, and has all of its dimensions
            // represented in the indexing.
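            // (try_fork().unwrap().1.len() is the number of dimensions of the
            // Fork, so equality with the count of deduplicated ThreadID
            // dimensions means every dimension is covered.)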
            let is_parallel = forks.into_iter().any(|(fork_id, mut rep_dims)| {
                rep_dims.sort();
                rep_dims.dedup();
                fork_join_map[&fork_id] == first_control.unwrap()
                    && func.nodes[fork_id.idx()].try_fork().unwrap().1.len() == rep_dims.len()
            });
if is_parallel {
editor.edit(|edit| edit.add_schedule(id, Schedule::ParallelReduce));
}
}
}
}

/*
 * Infer monoid reduction loops. Exactly one of the associative operation's
 * operands must be the Reduce node itself, and none of the other operands may
 * be in the Reduce node's cycle.
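 *
 * As a rough sketch (loose textual IR, not exact Hercules syntax):
 *
 *   red = reduce(join, zero, sum)
 *   sum = add(red, x)  // x is not in red's cycle
 *
 * The Add uses the Reduce as exactly one operand and x comes from outside the
 * cycle, so red is marked MonoidReduce.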
*/
pub fn infer_monoid_reduce(
editor: &mut FunctionEditor,
reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) {
let is_binop_monoid = |op| {
matches!(
op,
BinaryOperator::Add | BinaryOperator::Mul | BinaryOperator::Or | BinaryOperator::And
)
};
let is_intrinsic_monoid = |intrinsic| matches!(intrinsic, Intrinsic::Max | Intrinsic::Min);
for id in editor.node_ids() {
let func = editor.func();
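        // Two shapes qualify: a binary Add/Mul/Or/And where the Reduce is
        // exactly one operand and the other operand lies outside the Reduce's
        // cycle, or a Max/Min intrinsic call where the Reduce appears among
        // the arguments and every other argument lies outside the cycle.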
if let Node::Reduce {
control: _,
init: _,
reduct,
} = func.nodes[id.idx()]
&& (matches!(func.nodes[reduct.idx()], Node::Binary { left, right, op }
if ((left == id && !reduce_cycles[&id].contains(&right)) ||
(right == id && !reduce_cycles[&id].contains(&left))) &&
is_binop_monoid(op))
|| matches!(&func.nodes[reduct.idx()], Node::IntrinsicCall { intrinsic, args }
if (args.contains(&id) && is_intrinsic_monoid(*intrinsic) &&
args.iter().filter(|arg| **arg != id).all(|arg| !reduce_cycles[&id].contains(arg)))))
{
editor.edit(|edit| edit.add_schedule(id, Schedule::MonoidReduce));
}
}
}

/*
 * From the analysis result indicating which constants don't need to be reset,
 * add the NoResetConstant schedule to those constant nodes.
*/
pub fn infer_no_reset_constants(
editor: &mut FunctionEditor,
no_reset_constants: &BTreeSet<NodeID>,
) {
for id in no_reset_constants {
editor.edit(|edit| edit.add_schedule(*id, Schedule::NoResetConstant));
}
}