diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index b46e2dda12b41fe5b2b3905d53a8bd11387eda4d..11d23e618c7b872e43ef04e6c02e1b18348e473c 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -322,7 +322,7 @@ pub enum Schedule { Vectorizable, // This reduce can be re-associated. This may lower a sequential dependency // chain into a reduction tree. - Associative, + TightAssociative, } /* diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs index bcaba374464355a8ee31254afb2bb3b1f6fa9551..3b7c81eda174a27fa544fe6dd136fcfc24cce695 100644 --- a/hercules_opt/src/pass.rs +++ b/hercules_opt/src/pass.rs @@ -848,8 +848,10 @@ impl PassManager { Pass::InferSchedules => { self.make_def_uses(); self.make_fork_join_maps(); + self.make_reduce_cycles(); let def_uses = self.def_uses.as_ref().unwrap(); let fork_join_maps = self.fork_join_maps.as_ref().unwrap(); + let reduce_cycles = self.reduce_cycles.as_ref().unwrap(); for idx in 0..self.module.functions.len() { let constants_ref = RefCell::new(std::mem::take(&mut self.module.constants)); @@ -863,10 +865,10 @@ impl PassManager { &types_ref, &def_uses[idx], ); - infer_parallel_reduce(&mut editor, &fork_join_maps[idx]); + infer_parallel_reduce(&mut editor, &fork_join_maps[idx], &reduce_cycles[idx]); infer_parallel_fork(&mut editor, &fork_join_maps[idx]); infer_vectorizable(&mut editor, &fork_join_maps[idx]); - infer_associative(&mut editor); + infer_tight_associative(&mut editor, &reduce_cycles[idx]); self.module.constants = constants_ref.take(); self.module.dynamic_constants = dynamic_constants_ref.take(); diff --git a/hercules_opt/src/schedule.rs b/hercules_opt/src/schedule.rs index b65b4d0481c66e6eb0a5307f9119633a6d35e3cd..ff895b1651b298cc59a8b0afbd921da2aaf04f82 100644 --- a/hercules_opt/src/schedule.rs +++ b/hercules_opt/src/schedule.rs @@ -1,6 +1,6 @@ extern crate hercules_ir; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use self::hercules_ir::def_use::*; use self::hercules_ir::ir::*; @@ -14,13 +14,9 @@ use crate::*; pub fn infer_parallel_fork(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) { for id in editor.node_ids() { let func = editor.func(); - let Node::Fork { - control: _, - factors: _, - } = func.nodes[id.idx()] - else { + if !func.nodes[id.idx()].is_fork() { continue; - }; + } let join_id = fork_join_map[&id]; let all_parallel_reduce = editor.get_users(join_id).all(|user| { func.schedules[user.idx()].contains(&Schedule::ParallelReduce) @@ -35,14 +31,15 @@ pub fn infer_parallel_fork(editor: &mut FunctionEditor, fork_join_map: &HashMap< /* * Infer parallel reductions consisting of a simple cycle between a Reduce node * and a Write node, where indices of the Write are position indices using the - * ThreadID nodes attached to the corresponding Fork. This procedure also adds - * the ParallelReduce schedule to Reduce nodes reducing over a parallelized - * Reduce, as long as the base Write node also has position indices of the - * ThreadID of the outer fork. In other words, the complete Reduce chain is - * annotated with ParallelReduce, as long as each ThreadID dimension appears in - * the positional indexing of the original Write. + * ThreadID nodes attached to the corresponding Fork, and data of the Write is + * not in the Reduce node's cycle. This procedure also adds the ParallelReduce + * schedule to Reduce nodes reducing over a parallelized Reduce, as long as the + * base Write node also has position indices of the ThreadID of the outer fork. + * In other words, the complete Reduce chain is annotated with ParallelReduce, + * as long as each ThreadID dimension appears in the positional indexing of the + * original Write. */ -pub fn infer_parallel_reduce(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) { +pub fn infer_parallel_reduce(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>) { for id in editor.node_ids() { let func = editor.func(); if !func.nodes[id.idx()].is_reduce() { @@ -73,10 +70,11 @@ pub fn infer_parallel_reduce(editor: &mut FunctionEditor, fork_join_map: &HashMa // Check for a Write-Reduce tight cycle. if let Node::Write { collect, - data: _, + data, indices, } = &func.nodes[chain_id.idx()] && *collect == last_reduce + && !reduce_cycles[&last_reduce].contains(data) { // If there is a Write-Reduce tight cycle, get the position indices. let positions = indices @@ -146,17 +144,15 @@ pub fn infer_vectorizable(editor: &mut FunctionEditor, fork_join_map: &HashMap<N } /* - * Infer associative reduction loops. + * Infer tight associative reduction loops. Exactly one of the associative + * operation's operands must be the Reduce node, and all other operands must + * not be in the Reduce node's cycle. */ -pub fn infer_associative(editor: &mut FunctionEditor) { - let is_associative = |op| match op { - BinaryOperator::Add - | BinaryOperator::Mul - | BinaryOperator::Or - | BinaryOperator::And - | BinaryOperator::Xor => true, - _ => false, - }; +pub fn infer_tight_associative(editor: &mut FunctionEditor, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>) { + let is_binop_associative = |op| matches!(op, + BinaryOperator::Add | BinaryOperator::Or | BinaryOperator::And | BinaryOperator::Xor); + let is_intrinsic_associative = |intrinsic| matches!(intrinsic, + Intrinsic::Max | Intrinsic::Min); for id in editor.node_ids() { let func = editor.func(); @@ -165,11 +161,15 @@ pub fn infer_associative(editor: &mut FunctionEditor) { init: _, reduct, } = func.nodes[id.idx()] - && let Node::Binary { left, right, op } = func.nodes[reduct.idx()] - && (left == id || right == id) - && is_associative(op) + && (matches!(func.nodes[reduct.idx()], Node::Binary { left, right, op } + if ((left == id && !reduce_cycles[&id].contains(&right)) || + (right == id && !reduce_cycles[&id].contains(&left))) && + is_binop_associative(op)) || + matches!(&func.nodes[reduct.idx()], Node::IntrinsicCall { intrinsic, args } + if (args.contains(&id) && is_intrinsic_associative(*intrinsic) && + args.iter().filter(|arg| **arg != id).all(|arg| !reduce_cycles[&id].contains(arg))))) { - editor.edit(|edit| edit.add_schedule(id, Schedule::Associative)); + editor.edit(|edit| edit.add_schedule(id, Schedule::TightAssociative)); } } }