From 1aceb18f344505c2433cc54ab40e0a4538318f5e Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 24 Feb 2025 20:35:12 -0600
Subject: [PATCH] Parallelize gamut in cava on CPU

---
 hercules_opt/src/fork_transforms.rs |  4 +++-
 juno_samples/cava/src/cava.jn       |  2 +-
 juno_samples/cava/src/cpu.sch       | 12 ++++++++++--
 3 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 0e943973..ff0f0283 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -916,7 +916,9 @@ pub fn chunk_all_forks_unguarded(
     };
 
     for (fork, _) in fork_join_map {
-        chunk_fork_unguarded(editor, *fork, dim_idx, dc_id, order);
+        if editor.is_mutable(*fork) {
+            chunk_fork_unguarded(editor, *fork, dim_idx, dc_id, order);
+        }
     }
 }
 // Splits a dimension of a single fork join into multiple.
diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn
index dbe799f9..4d02b2cd 100644
--- a/juno_samples/cava/src/cava.jn
+++ b/juno_samples/cava/src/cava.jn
@@ -142,7 +142,7 @@ fn gamut<row : usize, col : usize, num_ctrl_pts : usize>(
 ) -> f32[CHAN, row, col] {
   @res let result : f32[CHAN, row, col];
 
-  for r = 0 to row {
+  @image_loop for r = 0 to row {
     for c = 0 to col {
       @l2 let l2_dist : f32[num_ctrl_pts];
       for cp = 0 to num_ctrl_pts {
diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch
index efa7302e..8f22b37d 100644
--- a/juno_samples/cava/src/cpu.sch
+++ b/juno_samples/cava/src/cpu.sch
@@ -113,6 +113,14 @@ fixpoint {
 simpl!(fuse4);
 array-slf(fuse4);
 simpl!(fuse4);
+let par = fuse4@image_loop \ fuse4@channel_loop;
+fork-tile[4, 1, false, false](par);
+fork-tile[4, 0, false, false](par);
+fork-interchange[1, 2](par);
+let split = fork-split(par);
+let fuse4_body = outline(split.cava_3.fj2);
+fork-coalesce(fuse4, fuse4_body);
+simpl!(fuse4, fuse4_body);
 
 no-memset(fuse5@res1);
 no-memset(fuse5@res2);
@@ -128,8 +136,8 @@ simpl!(fuse5);
 delete-uncalled(*);
 simpl!(*);
 
-fork-split(fuse1, fuse2, fuse3, fuse4, fuse5);
-unforkify(fuse1, fuse2, fuse3, fuse4, fuse5);
+fork-split(fuse1, fuse2, fuse3, fuse4_body, fuse5);
+unforkify(fuse1, fuse2, fuse3, fuse4_body, fuse5);
 
 simpl!(*);
 
-- 
GitLab