From b156fad632e6d0a7746b81c8bcf9e9dda0ef9977 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Mon, 3 Mar 2025 15:56:23 -0600
Subject: [PATCH] Parallelize backprop

---
 juno_samples/rodinia/backprop/src/backprop.jn |  4 +--
 juno_samples/rodinia/backprop/src/cpu.sch     | 27 +++++++++++++++++--
 2 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/juno_samples/rodinia/backprop/src/backprop.jn b/juno_samples/rodinia/backprop/src/backprop.jn
index 2ca57c9f..70894c17 100644
--- a/juno_samples/rodinia/backprop/src/backprop.jn
+++ b/juno_samples/rodinia/backprop/src/backprop.jn
@@ -69,8 +69,8 @@ fn adjust_weights<n, m: usize>(
   weights: f32[n + 1, m + 1],
   prev_weights: f32[n + 1, m + 1]
 ) -> f32[n + 1, m + 1], f32[n + 1, m + 1] {
-  for j in 1..=m {
-    for k in 0..=n {
+  @outer_loop for j in 1..=m {
+    @inner_loop for k in 0..=n {
       let new_dw = ETA * delta[j] * vals[k] + MOMENTUM * prev_weights[k, j];
       weights[k, j] += new_dw;
       prev_weights[k, j] = new_dw;
diff --git a/juno_samples/rodinia/backprop/src/cpu.sch b/juno_samples/rodinia/backprop/src/cpu.sch
index 661ec531..865cc1a2 100644
--- a/juno_samples/rodinia/backprop/src/cpu.sch
+++ b/juno_samples/rodinia/backprop/src/cpu.sch
@@ -28,6 +28,29 @@ simpl!(*);
 fork-interchange[0, 1](adjust_weights);
 simpl!(*);
 
-fork-split(*);
-unforkify(*);
+infer-schedules(*);
+
+fork-tile[32, 0, false, true](layer_forward@outer_loop \ layer_forward@inner_loop);
+let (forward_outer, forward_inner) = fork-reshape[[1], [0]](layer_forward@outer_loop \ layer_forward@inner_loop);
+
+fork-tile[32, 0, false, true](adjust_weights);
+let (adjust_outer, adjust_inner) = fork-reshape[[1], [0, 2]](adjust_weights);
+
+let forward_body = outline(forward_inner);
+let adjust_body = outline(adjust_inner);
+
+rename["output_error"](output_error);
+rename["hidden_error"](hidden_error);
+
+let output_error_body = auto-outline(output_error).output_error;
+let hidden_error_body = auto-outline(hidden_error).hidden_error;
+
+inline(backprop);
+delete-uncalled(*);
+const-inline(*);
+
+simpl!(*);
+fork-split(forward_body, adjust_body, output_error_body, hidden_error_body);
+unforkify(forward_body, adjust_body, output_error_body, hidden_error_body);
+
 gcm(*);
-- 
GitLab