From 555cd43175109a332d966fb124977d3ad00a0334 Mon Sep 17 00:00:00 2001 From: rarbore2 <rarbore2@illinois.edu> Date: Mon, 3 Feb 2025 20:12:25 -0600 Subject: [PATCH] Fix incorrect inference of no-memset --- hercules_ir/src/collections.rs | 16 ++-------- hercules_ir/src/einsum.rs | 29 +++++++++++++++++++ hercules_samples/matmul/src/gpu.sch | 2 ++ .../fork_join_tests/src/fork_join_tests.jn | 12 ++++---- juno_samples/fork_join_tests/src/gpu.sch | 8 +++-- juno_scheduler/src/pm.rs | 2 +- 6 files changed, 46 insertions(+), 23 deletions(-) diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index d9d6c8f7..1bc650e9 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -1,4 +1,4 @@ -use std::collections::{BTreeMap, BTreeSet, HashMap}; +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::iter::{once, repeat, zip}; use either::Either; @@ -432,20 +432,8 @@ pub fn no_reset_constant_collections( init: _, reduct: _, } => { - // If the einsum for this reduce node is a full array - // comprehension, then every array element is written to, and - // the empty indices set (the whole collection) is considered as - // written to. - let (env, exprs) = reduce_einsum; - if let Some(expr) = exprs.get(&id) - && let MathExpr::Comprehension(_, _) = env[expr.idx()] - { - ZeroLattice::top() - } // Otherwise, meet the `init` and `reduct` inputs. - else { - ZeroLattice::meet(&inputs[0], &inputs[1]) - } + ZeroLattice::meet(&inputs[0], &inputs[1]) } Node::Write { collect: _, diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs index d1c276e5..8d3bec3a 100644 --- a/hercules_ir/src/einsum.rs +++ b/hercules_ir/src/einsum.rs @@ -400,6 +400,35 @@ impl<'a> EinsumContext<'a> { } } +pub fn opaque_nodes_in_expr(env: &MathEnv, id: MathID) -> HashSet<NodeID> { + let mut set = HashSet::new(); + let mut stack = vec![id]; + while let Some(id) = stack.pop() { + match env[id.idx()] { + MathExpr::Zero(_) | MathExpr::One(_) | MathExpr::ThreadID(_) => {} + MathExpr::OpaqueNode(id) => { + set.insert(id); + } + MathExpr::SumReduction(id, _) | MathExpr::Comprehension(id, _) => { + stack.push(id); + } + MathExpr::Read(id, ref ids) => { + stack.push(id); + stack.extend(ids); + } + MathExpr::Add(left, right) + | MathExpr::Sub(left, right) + | MathExpr::Mul(left, right) + | MathExpr::Div(left, right) + | MathExpr::Rem(left, right) => { + stack.push(left); + stack.push(right); + } + } + } + set +} + fn representable(op: BinaryOperator) -> bool { match op { BinaryOperator::Add diff --git a/hercules_samples/matmul/src/gpu.sch b/hercules_samples/matmul/src/gpu.sch index 4303d376..e5d3bbde 100644 --- a/hercules_samples/matmul/src/gpu.sch +++ b/hercules_samples/matmul/src/gpu.sch @@ -1,3 +1,5 @@ +no-memset(matmul@c); + gvn(*); phi-elim(*); dce(*); diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index fdfa51a8..6e5db4cb 100644 --- a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -1,6 +1,6 @@ #[entry] fn test1(input : i32) -> i32[4, 4] { - let arr : i32[4, 4]; + @const let arr : i32[4, 4]; for i = 0 to 4 { for j = 0 to 4 { arr[i, j] = input; @@ -24,19 +24,19 @@ fn test2(input : i32) -> i32[4, 4] { #[entry] fn test3(input : i32) -> i32[3, 3] { - let arr1 : i32[3, 3]; + @const1 let arr1 : i32[3, 3]; for i = 0 to 3 { for j = 0 to 3 { arr1[i, j] = (i + j) as i32 + input; } } - let arr2 : i32[3, 3]; + @const2 let arr2 : i32[3, 3]; for i = 0 to 3 { for j = 0 to 3 { arr2[i, j] = arr1[2 - i, 2 - j]; } } - let arr3 : i32[3, 3]; + @const3 let arr3 : i32[3, 3]; for i = 0 to 3 { for j = 0 to 3 { arr3[i, j] = arr2[i, j] + 7; @@ -54,7 +54,7 @@ fn test4(input : i32) -> i32[4, 4] { for k = 0 to 7 { acc += input; } - arr[i, j] = acc; + @reduce arr[i, j] = acc; } } return arr; @@ -62,7 +62,7 @@ fn test4(input : i32) -> i32[4, 4] { #[entry] fn test5(input : i32) -> i32[4] { - @cons let arr1 : i32[4]; + let arr1 : i32[4]; for i = 0 to 4 { let red = arr1[i]; for k = 0 to 3 { diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index 0987083e..f096ea50 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -1,5 +1,9 @@ -no-memset(test5@cons); +parallel-reduce(test4@reduce); parallel-reduce(test5@reduce); +no-memset(test1@const); +no-memset(test3@const1); +no-memset(test3@const2); +no-memset(test3@const3); gvn(*); phi-elim(*); @@ -33,5 +37,5 @@ fixpoint panic after 20 { infer-schedules(*); } -float-collections(test2, out.test2, test4, out.test4); +float-collections(test2, out.test2, test4, out.test4, test5, out.test5); gcm(*); diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 20a3dba8..9c445519 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1329,7 +1329,7 @@ fn run_pass( pm: &mut PassManager, pass: Pass, args: Vec<Value>, - mut selection: Option<Vec<CodeLocation>>, + selection: Option<Vec<CodeLocation>>, ) -> Result<(Value, bool), SchedulerError> { let mut result = Value::Record { fields: HashMap::new(), -- GitLab