From e7d5bde00fc5dc4aa258f2686a03496e63fa3ba3 Mon Sep 17 00:00:00 2001 From: Praneet Rathi <prrathi10@gmail.com> Date: Sun, 5 Jan 2025 13:42:02 -0600 Subject: [PATCH 1/3] associative intrinsic --- hercules_opt/src/pass.rs | 1 + hercules_opt/src/schedule.rs | 15 ++++-- juno_samples/matmul/src/matmul_indented.jn | 57 ++++++++++++++++++++++ 3 files changed, 68 insertions(+), 5 deletions(-) create mode 100644 juno_samples/matmul/src/matmul_indented.jn diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs index bcaba374..60a6ee24 100644 --- a/hercules_opt/src/pass.rs +++ b/hercules_opt/src/pass.rs @@ -867,6 +867,7 @@ impl PassManager { infer_parallel_fork(&mut editor, &fork_join_maps[idx]); infer_vectorizable(&mut editor, &fork_join_maps[idx]); infer_associative(&mut editor); + infer_tight_associative(&mut editor); self.module.constants = constants_ref.take(); self.module.dynamic_constants = dynamic_constants_ref.take(); diff --git a/hercules_opt/src/schedule.rs b/hercules_opt/src/schedule.rs index b65b4d04..9d73b11d 100644 --- a/hercules_opt/src/schedule.rs +++ b/hercules_opt/src/schedule.rs @@ -149,14 +149,18 @@ pub fn infer_vectorizable(editor: &mut FunctionEditor, fork_join_map: &HashMap<N * Infer associative reduction loops. */ pub fn infer_associative(editor: &mut FunctionEditor) { - let is_associative = |op| match op { + let is_binop_associative = |op| match op { BinaryOperator::Add - | BinaryOperator::Mul | BinaryOperator::Or | BinaryOperator::And | BinaryOperator::Xor => true, _ => false, }; + let is_intrinsic_associative = |intrinsic| match intrinsic { + Intrinsic::Max + | Intrinsic::Min => true, + _ => false, + }; for id in editor.node_ids() { let func = editor.func(); @@ -165,9 +169,10 @@ pub fn infer_associative(editor: &mut FunctionEditor) { init: _, reduct, } = func.nodes[id.idx()] - && let Node::Binary { left, right, op } = func.nodes[reduct.idx()] - && (left == id || right == id) - && is_associative(op) + && (matches!(func.nodes[reduct.idx()], Node::Binary { left, right, op } + if (left == id || right == id) && is_binop_associative(op)) || + matches!(&func.nodes[reduct.idx()], Node::IntrinsicCall { intrinsic, args } + if (args.contains(&id) && is_intrinsic_associative(*intrinsic)))) { editor.edit(|edit| edit.add_schedule(id, Schedule::Associative)); } diff --git a/juno_samples/matmul/src/matmul_indented.jn b/juno_samples/matmul/src/matmul_indented.jn new file mode 100644 index 00000000..12039f52 --- /dev/null +++ b/juno_samples/matmul/src/matmul_indented.jn @@ -0,0 +1,57 @@ +#[entry] +fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] { + let res : i32[n, l]; + + @outer for i = 0 to n { + @middle for j = 0 to l { + @inner for k = 0 to m { + res[i, j] += a[i, k] * b[k, j]; + } + } + } + + @exit + return res; +} + +/* +#[entry] +fn tiled_64_matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] { + let res : i32[n, l]; + + for bi = 0 to n / 64 { + for bk = 0 to l / 64 { + let atile : i32[66, 64]; + let btile : i32[65, 64]; + let ctile : i32[64, 64]; + + for tile_idx = 0 to m / 64 { + for ti = 0 to 64 { + for tk = 0 to 64 { + atile[ti, tk] = a[bi * 64 + ti, tile_idx * 64 + tk]; + btile[ti, tk] = b[tile_idx * 64 + ti, bk * 64 + tk]; + ctile[ti, tk] = 0; + } + } + for ti = 0 to 64 { + for tk = 0 to 64 { + let c_acc = ctile[ti, tk]; + for inner_idx = 0 to 64 { + c_acc += atile[ti, inner_idx] * btile[inner_idx, tk]; + } + ctile[ti, tk] = c_acc; + } + } + } + + for ti = 0 to 64 { + for tk = 0 to 64 { + res[bi * 
64 + ti, bk * 64 + tk] = ctile[ti, tk]; + } + } + } + } + + return res; +} +*/ \ No newline at end of file -- GitLab From d3ebaeed8b5a5b00d318a6e0a9b2435dcd24c5d0 Mon Sep 17 00:00:00 2001 From: Praneet Rathi <prrathi10@gmail.com> Date: Sun, 5 Jan 2025 14:17:49 -0600 Subject: [PATCH 2/3] add tightness check --- hercules_ir/src/ir.rs | 2 +- hercules_opt/src/pass.rs | 7 +++-- hercules_opt/src/schedule.rs | 59 +++++++++++++++++------------------- 3 files changed, 32 insertions(+), 36 deletions(-) diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index b46e2dda..11d23e61 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -322,7 +322,7 @@ pub enum Schedule { Vectorizable, // This reduce can be re-associated. This may lower a sequential dependency // chain into a reduction tree. - Associative, + TightAssociative, } /* diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs index 60a6ee24..3b7c81ed 100644 --- a/hercules_opt/src/pass.rs +++ b/hercules_opt/src/pass.rs @@ -848,8 +848,10 @@ impl PassManager { Pass::InferSchedules => { self.make_def_uses(); self.make_fork_join_maps(); + self.make_reduce_cycles(); let def_uses = self.def_uses.as_ref().unwrap(); let fork_join_maps = self.fork_join_maps.as_ref().unwrap(); + let reduce_cycles = self.reduce_cycles.as_ref().unwrap(); for idx in 0..self.module.functions.len() { let constants_ref = RefCell::new(std::mem::take(&mut self.module.constants)); @@ -863,11 +865,10 @@ impl PassManager { &types_ref, &def_uses[idx], ); - infer_parallel_reduce(&mut editor, &fork_join_maps[idx]); + infer_parallel_reduce(&mut editor, &fork_join_maps[idx], &reduce_cycles[idx]); infer_parallel_fork(&mut editor, &fork_join_maps[idx]); infer_vectorizable(&mut editor, &fork_join_maps[idx]); - infer_associative(&mut editor); - infer_tight_associative(&mut editor); + infer_tight_associative(&mut editor, &reduce_cycles[idx]); self.module.constants = constants_ref.take(); self.module.dynamic_constants = dynamic_constants_ref.take(); diff --git a/hercules_opt/src/schedule.rs b/hercules_opt/src/schedule.rs index 9d73b11d..ff895b16 100644 --- a/hercules_opt/src/schedule.rs +++ b/hercules_opt/src/schedule.rs @@ -1,6 +1,6 @@ extern crate hercules_ir; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use self::hercules_ir::def_use::*; use self::hercules_ir::ir::*; @@ -14,13 +14,9 @@ use crate::*; pub fn infer_parallel_fork(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) { for id in editor.node_ids() { let func = editor.func(); - let Node::Fork { - control: _, - factors: _, - } = func.nodes[id.idx()] - else { + if !func.nodes[id.idx()].is_fork() { continue; - }; + } let join_id = fork_join_map[&id]; let all_parallel_reduce = editor.get_users(join_id).all(|user| { func.schedules[user.idx()].contains(&Schedule::ParallelReduce) @@ -35,14 +31,15 @@ pub fn infer_parallel_fork(editor: &mut FunctionEditor, fork_join_map: &HashMap< /* * Infer parallel reductions consisting of a simple cycle between a Reduce node * and a Write node, where indices of the Write are position indices using the - * ThreadID nodes attached to the corresponding Fork. This procedure also adds - * the ParallelReduce schedule to Reduce nodes reducing over a parallelized - * Reduce, as long as the base Write node also has position indices of the - * ThreadID of the outer fork. 
In other words, the complete Reduce chain is - * annotated with ParallelReduce, as long as each ThreadID dimension appears in - * the positional indexing of the original Write. + * ThreadID nodes attached to the corresponding Fork, and data of the Write is + * not in the Reduce node's cycle. This procedure also adds the ParallelReduce + * schedule to Reduce nodes reducing over a parallelized Reduce, as long as the + * base Write node also has position indices of the ThreadID of the outer fork. + * In other words, the complete Reduce chain is annotated with ParallelReduce, + * as long as each ThreadID dimension appears in the positional indexing of the + * original Write. */ -pub fn infer_parallel_reduce(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) { +pub fn infer_parallel_reduce(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>) { for id in editor.node_ids() { let func = editor.func(); if !func.nodes[id.idx()].is_reduce() { @@ -73,10 +70,11 @@ pub fn infer_parallel_reduce(editor: &mut FunctionEditor, fork_join_map: &HashMa // Check for a Write-Reduce tight cycle. if let Node::Write { collect, - data: _, + data, indices, } = &func.nodes[chain_id.idx()] && *collect == last_reduce + && !reduce_cycles[&last_reduce].contains(data) { // If there is a Write-Reduce tight cycle, get the position indices. let positions = indices @@ -146,21 +144,15 @@ pub fn infer_vectorizable(editor: &mut FunctionEditor, fork_join_map: &HashMap<N } /* - * Infer associative reduction loops. + * Infer tight associative reduction loops. Exactly one of the associative + * operation's operands must be the Reduce node, and all other operands must + * not be in the Reduce node's cycle. */ -pub fn infer_associative(editor: &mut FunctionEditor) { - let is_binop_associative = |op| match op { - BinaryOperator::Add - | BinaryOperator::Or - | BinaryOperator::And - | BinaryOperator::Xor => true, - _ => false, - }; - let is_intrinsic_associative = |intrinsic| match intrinsic { - Intrinsic::Max - | Intrinsic::Min => true, - _ => false, - }; +pub fn infer_tight_associative(editor: &mut FunctionEditor, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>) { + let is_binop_associative = |op| matches!(op, + BinaryOperator::Add | BinaryOperator::Or | BinaryOperator::And | BinaryOperator::Xor); + let is_intrinsic_associative = |intrinsic| matches!(intrinsic, + Intrinsic::Max | Intrinsic::Min); for id in editor.node_ids() { let func = editor.func(); @@ -170,11 +162,14 @@ pub fn infer_associative(editor: &mut FunctionEditor) { reduct, } = func.nodes[id.idx()] && (matches!(func.nodes[reduct.idx()], Node::Binary { left, right, op } - if (left == id || right == id) && is_binop_associative(op)) || + if ((left == id && !reduce_cycles[&id].contains(&right)) || + (right == id && !reduce_cycles[&id].contains(&left))) && + is_binop_associative(op)) || matches!(&func.nodes[reduct.idx()], Node::IntrinsicCall { intrinsic, args } - if (args.contains(&id) && is_intrinsic_associative(*intrinsic)))) + if (args.contains(&id) && is_intrinsic_associative(*intrinsic) && + args.iter().filter(|arg| **arg != id).all(|arg| !reduce_cycles[&id].contains(arg))))) { - editor.edit(|edit| edit.add_schedule(id, Schedule::Associative)); + editor.edit(|edit| edit.add_schedule(id, Schedule::TightAssociative)); } } } -- GitLab From 05ca328518e2201db0337dc78fb68b9b5706d512 Mon Sep 17 00:00:00 2001 From: prathi3 <prathi3@illinois.edu> Date: Sun, 5 Jan 2025 20:50:26 -0600 Subject: [PATCH 
3/3] Delete matmul_indented.jn --- juno_samples/matmul/src/matmul_indented.jn | 57 ---------------------- 1 file changed, 57 deletions(-) delete mode 100644 juno_samples/matmul/src/matmul_indented.jn diff --git a/juno_samples/matmul/src/matmul_indented.jn b/juno_samples/matmul/src/matmul_indented.jn deleted file mode 100644 index 12039f52..00000000 --- a/juno_samples/matmul/src/matmul_indented.jn +++ /dev/null @@ -1,57 +0,0 @@ -#[entry] -fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] { - let res : i32[n, l]; - - @outer for i = 0 to n { - @middle for j = 0 to l { - @inner for k = 0 to m { - res[i, j] += a[i, k] * b[k, j]; - } - } - } - - @exit - return res; -} - -/* -#[entry] -fn tiled_64_matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] { - let res : i32[n, l]; - - for bi = 0 to n / 64 { - for bk = 0 to l / 64 { - let atile : i32[66, 64]; - let btile : i32[65, 64]; - let ctile : i32[64, 64]; - - for tile_idx = 0 to m / 64 { - for ti = 0 to 64 { - for tk = 0 to 64 { - atile[ti, tk] = a[bi * 64 + ti, tile_idx * 64 + tk]; - btile[ti, tk] = b[tile_idx * 64 + ti, bk * 64 + tk]; - ctile[ti, tk] = 0; - } - } - for ti = 0 to 64 { - for tk = 0 to 64 { - let c_acc = ctile[ti, tk]; - for inner_idx = 0 to 64 { - c_acc += atile[ti, inner_idx] * btile[inner_idx, tk]; - } - ctile[ti, tk] = c_acc; - } - } - } - - for ti = 0 to 64 { - for tk = 0 to 64 { - res[bi * 64 + ti, bk * 64 + tk] = ctile[ti, tk]; - } - } - } - } - - return res; -} -*/ \ No newline at end of file -- GitLab
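
Note on the TightAssociative check introduced in patch 2/3: re-associating a reduce chain into a reduction tree is only sound when the non-reduce operand of every step is independent of the accumulator, which is what the reduce_cycles test in infer_tight_associative enforces. Below is a minimal sketch in plain Rust (not Hercules IR and not the pass's actual API; function names and data are made up for illustration) contrasting a reduction that could legally be marked TightAssociative with one that must stay a sequential dependency chain.

// Minimal sketch (plain Rust, not Hercules IR): why tightness matters for
// re-association. Both folds run sequentially here; the point is which one
// a compiler may legally rewrite into a balanced reduction tree.

// Tight: each step is acc = max(acc, x), and the other operand (x) does not
// depend on the accumulator, so the chain can be re-associated, e.g. into
// max(max(x0, x1), max(x2, x3)).
fn max_tight(xs: &[i32]) -> i32 {
    xs.iter().fold(i32::MIN, |acc, &x| acc.max(x))
}

// Not tight: the second operand (acc ^ x) is itself a function of the
// accumulator, i.e. it lies on the reduce cycle, so the steps cannot be
// reordered and the reduction must remain sequential.
fn xor_feedback(xs: &[i32]) -> i32 {
    xs.iter().fold(1, |acc, &x| acc + (acc ^ x))
}

fn main() {
    let xs = [3, 7, 1, 9];
    assert_eq!(max_tight(&xs), 9);
    println!("tight = {}, non-tight = {}", max_tight(&xs), xor_feedback(&xs));
}

In IR terms, the first case corresponds to a Reduce whose reduct is an IntrinsicCall to Max whose only arguments are the Reduce node itself and loop-invariant data; the second corresponds to a reduct whose other operand is contained in reduce_cycles[&id], which the new check rejects.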