use std::collections::{BTreeSet, HashMap, HashSet};

use hercules_ir::def_use::*;
use hercules_ir::ir::*;

use crate::*;

/*
 * Infer parallel fork-joins. These are fork-joins with only parallel reduction
 * variables.
 */
pub fn infer_parallel_fork(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
    for id in editor.node_ids() {
        let func = editor.func();
        if !func.nodes[id.idx()].is_fork() {
            continue;
        }
        let join_id = fork_join_map[&id];
        // The fork is parallel when every user of its join is either a control
        // node or a reduce already carrying the ParallelReduce schedule.
        let all_parallel_reduce = editor.get_users(join_id).all(|user| {
            func.schedules[user.idx()].contains(&Schedule::ParallelReduce)
                || func.nodes[user.idx()].is_control()
        });
        if all_parallel_reduce {
            editor.edit(|edit| edit.add_schedule(id, Schedule::ParallelFork));
        }
    }
}

/*
 * Infer parallel reductions consisting of a simple cycle between a Reduce node
 * and a Write node, where indices of the Write are position indices using the
 * ThreadID nodes attached to the corresponding Fork, and data of the Write is
 * not in the Reduce node's cycle. This procedure also adds the ParallelReduce
 * schedule to Reduce nodes reducing over a parallelized Reduce, as long as the
 * base Write node also has position indices of the ThreadID of the outer fork.
 * In other words, the complete Reduce chain is annotated with ParallelReduce,
 * as long as each ThreadID dimension appears in the positional indexing of the
 * original Write.
 */
pub fn infer_parallel_reduce(
    editor: &mut FunctionEditor,
    fork_join_map: &HashMap<NodeID, NodeID>,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) {
    for id in editor.node_ids() {
        let func = editor.func();
        if !func.nodes[id.idx()].is_reduce() {
            continue;
        }

        let mut first_control = None;
        let mut last_reduce = id;
        let mut chain_id = id;

        // Walk down Reduce chain until we reach the Reduce potentially looping
        // with the Write. Note the control node of the first Reduce, since this
        // will tell us which Thread ID to look for in the Write.
        while let Node::Reduce {
            control,
            init: _,
            reduct,
        } = func.nodes[chain_id.idx()]
        {
            if first_control.is_none() {
                first_control = Some(control);
            }

            last_reduce = chain_id;
            chain_id = reduct;
        }

        // Check for a Write-Reduce tight cycle: the Write's collect input must
        // be the innermost Reduce, and the written data must come from outside
        // the Reduce's cycle.
        if let Node::Write {
            collect,
            data,
            indices,
        } = &func.nodes[chain_id.idx()]
            && *collect == last_reduce
            && !reduce_cycles[&last_reduce].contains(data)
        {
            // If there is a Write-Reduce tight cycle, get the position indices.
            let positions = indices
                .iter()
                .filter_map(|index| {
                    if let Index::Position(indices) = index {
                        Some(indices)
                    } else {
                        None
                    }
                })
                .flat_map(|pos| pos.iter());

            // Get the Forks corresponding to uses of bare ThreadIDs.
            let fork_thread_id_pairs = positions.filter_map(|id| {
                if let Node::ThreadID { control, dimension } = func.nodes[id.idx()] {
                    Some((control, dimension))
                } else {
                    None
                }
            });
            // Group the dimensions indexed per fork.
            let mut forks = HashMap::<NodeID, Vec<usize>>::new();
            for (fork, dim) in fork_thread_id_pairs {
                forks.entry(fork).or_default().push(dim);
            }

            // Check if one of the Forks correspond to the Join associated with
            // the Reduce being considered, and has all of its dimensions
            // represented in the indexing.
            let is_parallel = forks.into_iter().any(|(id, mut rep_dims)| {
                // Dedup so repeated uses of one dimension aren't over-counted
                // against the fork's dimension count.
                rep_dims.sort();
                rep_dims.dedup();
                fork_join_map[&id] == first_control.unwrap()
                    && func.nodes[id.idx()].try_fork().unwrap().1.len() == rep_dims.len()
            });

            if is_parallel {
                editor.edit(|edit| edit.add_schedule(id, Schedule::ParallelReduce));
            }
        }
    }
}

/*
 * Infer vectorizable fork-joins. Just check that there are no control nodes
 * between a fork and its join and the factor is a constant.
 */
pub fn infer_vectorizable(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
    for id in editor.node_ids() {
        let func = editor.func();
        if !func.nodes[id.idx()].is_join() {
            continue;
        }

        // First use of the join is its control predecessor; if that node is
        // the join's own fork, there is no control in between them.
        let u = get_uses(&func.nodes[id.idx()]).as_ref()[0];
        if let Some(join) = fork_join_map.get(&u)
            && *join == id
        {
            // Only single-dimension forks whose factor evaluates to a known
            // dynamic constant are marked vectorizable.
            let factors = func.nodes[u.idx()].try_fork().unwrap().1;
            if factors.len() == 1
                && evaluate_dynamic_constant(factors[0], &editor.get_dynamic_constants()).is_some()
            {
                editor.edit(|edit| edit.add_schedule(u, Schedule::Vectorizable));
            }
        }
    }
}

/*
 * Infer tight associative reduction loops. Exactly one of the associative
 * operation's operands must be the Reduce node, and all other operands must
 * not be in the Reduce node's cycle.
 */
pub fn infer_tight_associative(
    editor: &mut FunctionEditor,
    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) {
    // The set of binary operators treated as associative for this pass.
    let is_binop_associative = |op| {
        matches!(
            op,
            BinaryOperator::Add | BinaryOperator::Or | BinaryOperator::And | BinaryOperator::Xor
        )
    };
    // The set of intrinsics treated as associative for this pass.
    let is_intrinsic_associative = |intrinsic| matches!(intrinsic, Intrinsic::Max | Intrinsic::Min);

    for id in editor.node_ids() {
        let func = editor.func();
        // The reduce's reduct must be either an associative binary op where
        // exactly one operand is the reduce itself and the other operand is
        // outside the reduce's cycle, or an associative intrinsic call taking
        // the reduce as an argument with every other argument outside the
        // reduce's cycle.
        if let Node::Reduce {
            control: _,
            init: _,
            reduct,
        } = func.nodes[id.idx()]
            && (matches!(func.nodes[reduct.idx()], Node::Binary { left, right, op }
                    if ((left == id && !reduce_cycles[&id].contains(&right)) ||
                        (right == id && !reduce_cycles[&id].contains(&left))) &&
                        is_binop_associative(op))
                || matches!(&func.nodes[reduct.idx()], Node::IntrinsicCall { intrinsic, args }
                    if (args.contains(&id) && is_intrinsic_associative(*intrinsic) &&
                        args.iter().filter(|arg| **arg != id).all(|arg| !reduce_cycles[&id].contains(arg)))))
        {
            editor.edit(|edit| edit.add_schedule(id, Schedule::TightAssociative));
        }
    }
}

/*
 * From analysis result of which constants don't need to be reset, add schedules
 * to those constant nodes.
*/ pub fn infer_no_reset_constants( editor: &mut FunctionEditor, no_reset_constants: &BTreeSet<NodeID>, ) { for id in no_reset_constants { editor.edit(|edit| edit.add_schedule(*id, Schedule::NoResetConstant)); } }