From 51e246a500198d922b424e34f6ea6b3d5f2fda77 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 18 Feb 2025 12:57:50 -0600
Subject: [PATCH] gpu schedule

---
 juno_samples/cava/src/cava.jn |   2 +-
 juno_samples/cava/src/cpu.sch |   3 +
 juno_samples/cava/src/gpu.sch | 144 ++++++++++++++++++++++++++++++----
 3 files changed, 133 insertions(+), 16 deletions(-)

diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn
index 8158bf0a..0c74646c 100644
--- a/juno_samples/cava/src/cava.jn
+++ b/juno_samples/cava/src/cava.jn
@@ -40,7 +40,7 @@ fn scale<row : usize, col : usize>(input : u8[CHAN, row, col]) -> f32[CHAN, row,
 fn demosaic<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, row, col] {
   @res2 let res : f32[CHAN, row, col];
 
-  for r = 1 to row-1 {
+  @loop for r = 1 to row-1 {
     for c = 1 to col-1 {
       if r % 2 == 0 && c % 2 == 0 {
         let R1 = input[0, r, c-1];
diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch
index 6cd33a3b..3ae1c6bf 100644
--- a/juno_samples/cava/src/cpu.sch
+++ b/juno_samples/cava/src/cpu.sch
@@ -46,6 +46,9 @@ fixpoint {
 }
 predication(fuse1);
 simpl!(fuse1);
+write-predication(fuse1);
+simpl!(fuse1);
+parallel-reduce(fuse1@loop);
 
 inline(fuse2);
 no-memset(fuse2@res);
diff --git a/juno_samples/cava/src/gpu.sch b/juno_samples/cava/src/gpu.sch
index f440dacd..5772d56d 100644
--- a/juno_samples/cava/src/gpu.sch
+++ b/juno_samples/cava/src/gpu.sch
@@ -1,23 +1,137 @@
-gvn(*);
-phi-elim(*);
-dce(*);
+macro simpl!(X) {
+  ccp(X);
+  simplify-cfg(X);
+  lift-dc-math(X);
+  gvn(X);
+  phi-elim(X);
+  dce(X);
+  infer-schedules(X);
+}
+
+simpl!(*);
+
+let fuse1 = outline(cava@fuse1);
+inline(fuse1);
+gpu(fuse1);
+
+let fuse2 = outline(cava@fuse2);
+inline(fuse2);
+gpu(fuse2);
+
+let fuse3 = outline(cava@fuse3);
+inline(fuse3);
+gpu(fuse3);
 
-inline(denoise);
-gpu(scale, demosaic, denoise, transform, gamut, tone_map, descale);
+let fuse4 = outline(cava@fuse4);
+inline(fuse4);
+gpu(fuse4);
+
+let fuse5 = outline(cava@fuse5);
+inline(fuse5);
+gpu(fuse5);
 
 ip-sroa(*);
 sroa(*);
-dce(*);
-gvn(*);
-phi-elim(*);
-dce(*);
+simpl!(*);
 
-// forkify(*);
-infer-schedules(*);
+no-memset(fuse1@res1);
+no-memset(fuse1@res2);
+fixpoint {
+  forkify(fuse1);
+  fork-guard-elim(fuse1);
+  fork-coalesce(fuse1);
+}
+simpl!(fuse1);
+array-slf(fuse1);
+loop-bound-canon(fuse1);
+fixpoint {
+  forkify(fuse1);
+  fork-guard-elim(fuse1);
+  fork-coalesce(fuse1);
+}
+predication(fuse1);
+simpl!(fuse1);
+write-predication(fuse1);
+simpl!(fuse1);
+parallel-reduce(fuse1@loop);
 
-gcm(*);
+inline(fuse2);
+no-memset(fuse2@res);
+no-memset(fuse2@filter);
+no-memset(fuse2@tmp);
+fixpoint {
+  forkify(fuse2);
+  fork-guard-elim(fuse2);
+  fork-coalesce(fuse2);
+}
+simpl!(fuse2);
+predication(fuse2);
+simpl!(fuse2);
+
+let median = outline(fuse2@median);
+fork-unroll(median@medianOuter);
+simpl!(median);
+fixpoint {
+  forkify(median);
+  fork-guard-elim(median);
+}
+simpl!(median);
+fixpoint {
+  fork-unroll(median);
+}
+ccp(median);
+array-to-product(median);
+sroa(median);
+phi-elim(median);
+predication(median);
+simpl!(median);
+
+inline(fuse2);
+ip-sroa(*);
+sroa(*);
+array-slf(fuse2);
+write-predication(fuse2);
+simpl!(fuse2);
+
+no-memset(fuse3@res);
 fixpoint {
-  float-collections(*);
-  dce(*);
-  gcm(*);
+  forkify(fuse3);
+  fork-guard-elim(fuse3);
+  fork-coalesce(fuse3);
 }
+simpl!(fuse3);
+
+no-memset(fuse4@res);
+no-memset(fuse4@l2);
+fixpoint {
+  forkify(fuse4);
+  fork-guard-elim(fuse4);
+  fork-coalesce(fuse4);
+}
+simpl!(fuse4);
+fork-unroll(fuse4@channel_loop);
+simpl!(fuse4);
+fixpoint {
+  fork-fusion(fuse4@channel_loop);
+}
+simpl!(fuse4);
+array-slf(fuse4);
+simpl!(fuse4);
+
+no-memset(fuse5@res1);
+no-memset(fuse5@res2);
+fixpoint {
+  forkify(fuse5);
+  fork-guard-elim(fuse5);
+  fork-coalesce(fuse5);
+}
+simpl!(fuse5);
+array-slf(fuse5);
+simpl!(fuse5);
+
+delete-uncalled(*);
+simpl!(*);
+
+delete-uncalled(*);
+gcm(*);
+
-- 
GitLab