From d4280e531b346797827cc23653b23b196cdc08f4 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Mon, 24 Feb 2025 14:22:44 -0600 Subject: [PATCH] Simple conv test --- juno_samples/fork_join_tests/src/cpu.sch | 9 ++++--- .../fork_join_tests/src/fork_join_tests.jn | 24 +++++++++++++++++++ juno_samples/fork_join_tests/src/gpu.sch | 6 +++-- juno_samples/fork_join_tests/src/main.rs | 19 ++++++++++++++- 4 files changed, 52 insertions(+), 6 deletions(-) diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index 76dcbdf6..f46c91d6 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -3,7 +3,7 @@ gvn(*); phi-elim(*); dce(*); -let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8); +let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9); cpu(auto.test1); cpu(auto.test2); cpu(auto.test3); @@ -11,6 +11,7 @@ cpu(auto.test4); cpu(auto.test5); cpu(auto.test7); cpu(auto.test8); +cpu(auto.test9); let test1_cpu = auto.test1; rename["test1_cpu"](test1_cpu); @@ -51,11 +52,11 @@ fixpoint panic after 20 { unroll(auto.test1); } -fork-split(auto.test2, auto.test3, auto.test4, auto.test5); +fork-split(auto.test2, auto.test3, auto.test4, auto.test5, auto.test9); gvn(*); phi-elim(*); dce(*); -unforkify(auto.test2, auto.test3, auto.test4, auto.test5); +unforkify(auto.test2, auto.test3, auto.test4, auto.test5, auto.test9); ccp(*); gvn(*); phi-elim(*); @@ -93,4 +94,6 @@ dce(auto.test8); simplify-cfg(auto.test8); dce(auto.test8); +no-memset(test9@const); + gcm(*); diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index bfb5564b..3b7c7833 100644 --- a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -122,3 +122,27 @@ fn test8(input : i32) -> i32[8] { } return out; } + +#[entry] +fn test9<r, c, z : usize>(input : i32[r, c]) -> i32[r, c] { + const rad = z / 2; + @const let out : i32[r, c]; + + for ir = 0 to r { + for ic = 0 to c { + let acc = 0; + @filter_loop for sr = 0 to z { + for sc = 0 to z { + acc += if ir + sr < rad then 0 + else if ir + sr - rad > r - 1 then 0 + else if ic + sc < rad then 0 + else if ic + sc - rad > c - 1 then 0 + else input[ir + sr - rad, ic + sc - rad]; + } + } + out[ir, ic] = acc; + } + } + + return out; +} diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index 364673cd..c554fd50 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -7,12 +7,13 @@ no-memset(test3@const3); no-memset(test6@const); no-memset(test8@const1); no-memset(test8@const2); +no-memset(test9@const); gvn(*); phi-elim(*); dce(*); -let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8); +let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8, test9); gpu(auto.test1); gpu(auto.test2); gpu(auto.test3); @@ -20,6 +21,7 @@ gpu(auto.test4); gpu(auto.test5); gpu(auto.test7); gpu(auto.test8); +gpu(auto.test9); ip-sroa(*); sroa(*); @@ -34,7 +36,7 @@ fixpoint panic after 20 { } fixpoint panic after 20 { - fork-coalesce(auto.test1, auto.test3, auto.test4, auto.test5, auto.test7, auto.test8); + fork-coalesce(auto.test1, auto.test3, auto.test4, auto.test5, auto.test7, auto.test8, auto.test9); } gvn(*); diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs index cd715cac..fa99f759 100644 --- a/juno_samples/fork_join_tests/src/main.rs +++ b/juno_samples/fork_join_tests/src/main.rs @@ -1,6 +1,6 @@ #![feature(concat_idents)] -use hercules_rt::runner; +use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo}; juno_build::juno!("fork_join_tests"); @@ -57,6 +57,23 @@ fn main() { let output = r.run(0).await; let correct = vec![10, 17, 24, 31, 38, 45, 52, 59]; assert(&correct, output); + + let mut r = runner!(test9); + let input = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; + let input = HerculesImmBox::from(&input as &[i32]); + let output = r.run(3, 3, 3, input.to()).await; + let correct = vec![ + 1 + 2 + 4 + 5, + 1 + 2 + 3 + 4 + 5 + 6, + 2 + 3 + 5 + 6, + 1 + 2 + 4 + 5 + 7 + 8, + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9, + 2 + 3 + 5 + 6 + 8 + 9, + 4 + 5 + 7 + 8, + 4 + 5 + 6 + 7 + 8 + 9, + 5 + 6 + 8 + 9, + ]; + assert(&correct, output); }); } -- GitLab