use std::collections::{BTreeSet, HashMap, HashSet};
use std::iter::once;

use hercules_ir::ir::*;

use crate::*;

/*
 * Infer parallel fork-joins. These are fork-joins with only parallel reduction
 * variables.
 */
pub fn infer_parallel_fork(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
    for id in editor.node_ids() {
        let func = editor.func();
        // Only Fork nodes can carry the ParallelFork schedule.
        if !func.nodes[id.idx()].is_fork() {
            continue;
        }
        // Find the Join paired with this Fork.
        let join_id = fork_join_map[&id];
        // The fork-join is parallel when every non-control user of the join
        // already carries the ParallelReduce schedule.
        // NOTE(review): presumably the non-control users of a Join are exactly
        // the fork-join's Reduce nodes — confirm against the IR invariants.
        let all_parallel_reduce = editor.get_users(join_id).all(|user| {
            func.schedules[user.idx()].contains(&Schedule::ParallelReduce)
                || func.nodes[user.idx()].is_control()
        });
        if all_parallel_reduce {
            editor.edit(|edit| edit.add_schedule(id, Schedule::ParallelFork));
        }
    }
}

/*
 * Infer parallel reductions consisting of a simple cycle between a Reduce node
 * and a Write node, where indices of the Write are position indices using the
 * ThreadID nodes attached to the corresponding Fork, and data of the Write is
 * not in the Reduce node's cycle. This procedure also adds the ParallelReduce
 * schedule to Reduce nodes reducing over a parallelized Reduce, as long as the
 * base Write node also has position indices of the ThreadID of the outer fork.
 * In other words, the complete Reduce chain is annotated with ParallelReduce,
 * as long as each ThreadID dimension appears in the positional indexing of the
 * original Write.
 */
pub fn infer_parallel_reduce(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) {
    // Invert the fork -> join map, so a Reduce's control (a Join) can be
    // mapped back to its Fork.
    let join_fork_map: HashMap<_, _> = fork_join_map
        .into_iter()
        .map(|(fork, join)| (*join, *fork))
        .collect();
    for id in editor.node_ids() {
        let func = editor.func();
        // Only Reduce nodes can start a parallel reduction chain.
        if !func.nodes[id.idx()].is_reduce() {
            continue;
        }
        // Control (Join) of the outermost Reduce in the chain.
        let mut first_control = None;
        // The last Reduce seen while walking the chain.
        let mut last_reduce = id;
        // Cursor walking down the reduct edges of the chain.
        let mut chain_id = id;

        // Walk down Reduce chain until we reach the Reduce potentially looping
        // with the Write. Note the control node of the first Reduce, since this
        // will tell us which Thread ID to look for in the Write.
        while let Node::Reduce {
            control,
            init: _,
            reduct,
        } = func.nodes[chain_id.idx()]
        {
            if first_control.is_none() {
                first_control = Some(control);
            }
            last_reduce = chain_id;
            chain_id = reduct;
        }

        // If the use is a phi that uses the reduce and a write, then we might
        // want to parallelize this still. Set the chain ID to the write.
        // The phi qualifies only if all of its inputs but exactly one are the
        // last Reduce; the single remaining input becomes the new chain tip.
        if let Node::Phi {
            control: _,
            ref data,
        } = func.nodes[chain_id.idx()]
            && data.len()
                == data
                    .into_iter()
                    .filter(|phi_use| **phi_use == last_reduce)
                    .count()
                    + 1
        {
            // Pick out the one phi input that isn't the last Reduce.
            chain_id = *data
                .into_iter()
                .filter(|phi_use| **phi_use != last_reduce)
                .next()
                .unwrap();
        }

        // Check for a Write-Reduce tight cycle: the Write's collection input
        // must be the last Reduce, and the written data must come from outside
        // the Reduce's cycle.
        if let Node::Write {
            collect,
            data,
            indices,
        } = &func.nodes[chain_id.idx()]
            && *collect == last_reduce
            && !reduce_cycles[&last_reduce].contains(data)
        {
            // The Write's indices must be parallel over the Fork paired with
            // the outermost Reduce's Join (see module doc comment above).
            let is_parallel = indices_parallel_over_forks(
                editor,
                indices,
                once(join_fork_map[&first_control.unwrap()]),
            );

            if is_parallel {
                editor.edit(|edit| edit.add_schedule(id, Schedule::ParallelReduce));
            }
        }
    }
}

/*
 * Infer monoid reduction loops. Exactly one of the associative operation's
 * operands must be the Reduce node, and all other operands must not be in the
 * Reduce node's cycle.
 */
#[rustfmt::skip]
pub fn infer_monoid_reduce(
    editor: &mut FunctionEditor,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) {
    // Binary operators that form a monoid (associative with an identity).
    let is_binop_monoid = |op| {
        op == BinaryOperator::Add
            || op == BinaryOperator::Mul
            || op == BinaryOperator::Or
            || op == BinaryOperator::And
    };
    // Intrinsics that form a monoid.
    let is_intrinsic_monoid =
        |intrinsic| intrinsic == Intrinsic::Max || intrinsic == Intrinsic::Min;
    for id in editor.node_ids() {
        let func = editor.func();
        // A Reduce is a monoid reduction when its reduct is either:
        //  - a monoid Binary op with this Reduce on exactly one side and the
        //    other side not inside the Reduce's cycle, or
        //  - a monoid IntrinsicCall that takes this Reduce as an argument and
        //    whose every other argument is not inside the Reduce's cycle.
        if let Node::Reduce {
            control: _,
            init: _,
            reduct,
        } = func.nodes[id.idx()]
            && (matches!(func.nodes[reduct.idx()], Node::Binary { left, right, op }
                    if ((left == id && !reduce_cycles[&id].contains(&right))
                        || (right == id && !reduce_cycles[&id].contains(&left)))
                        && is_binop_monoid(op))
                || matches!(&func.nodes[reduct.idx()], Node::IntrinsicCall { intrinsic, args }
                    if (args.contains(&id)
                        && is_intrinsic_monoid(*intrinsic)
                        && args.iter().filter(|arg| **arg != id)
                            .all(|arg| !reduce_cycles[&id].contains(arg)))))
        {
            editor.edit(|edit| edit.add_schedule(id, Schedule::MonoidReduce));
        }
    }
}

/*
 * From analysis result of which constants don't need to be reset, add schedules
 * to those constant nodes.
 */
pub fn infer_no_reset_constants(
    editor: &mut FunctionEditor,
    no_reset_constants: &BTreeSet<NodeID>,
) {
    // The analysis already identified the nodes; just attach the schedule.
    for id in no_reset_constants {
        editor.edit(|edit| edit.add_schedule(*id, Schedule::NoResetConstant));
    }
}