From be6c3a90a2d3f35f251915e216c81bb9a85f2449 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Mon, 3 Mar 2025 16:36:31 -0600
Subject: [PATCH] New backprop schedule

---
 juno_samples/rodinia/backprop/src/backprop.jn |  8 ++++----
 juno_samples/rodinia/backprop/src/cpu.sch     | 15 ++++++++++++++-
 2 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/juno_samples/rodinia/backprop/src/backprop.jn b/juno_samples/rodinia/backprop/src/backprop.jn
index 70894c17..7851cf47 100644
--- a/juno_samples/rodinia/backprop/src/backprop.jn
+++ b/juno_samples/rodinia/backprop/src/backprop.jn
@@ -91,15 +91,15 @@ fn backprop<input_n, hidden_n, output_n: usize>(
 ) -> f32, f32,
      f32[input_n + 1, hidden_n + 1], f32[input_n + 1, hidden_n + 1],
      f32[hidden_n + 1, output_n + 1], f32[hidden_n + 1, output_n + 1] {
-  let hidden_vals = layer_forward::<input_n, hidden_n>(input_vals, input_weights);
-  let output_vals = layer_forward::<hidden_n, output_n>(hidden_vals, hidden_weights);
+  @forward_input let hidden_vals = layer_forward::<input_n, hidden_n>(input_vals, input_weights);
+  @forward_hidden let output_vals = layer_forward::<hidden_n, output_n>(hidden_vals, hidden_weights);
 
   @output_error let out_err, out_delta = output_error::<output_n>(target, output_vals);
   @hidden_error let hid_err, hid_delta = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals);
 
-  let hidden_weights, hidden_prev_weights
+  @adjust_hidden let hidden_weights, hidden_prev_weights
     = adjust_weights::<hidden_n, output_n>(out_delta, hidden_vals, hidden_weights, hidden_prev_weights);
-  let input_weights, input_prev_weights
+  @adjust_input let input_weights, input_prev_weights
     = adjust_weights::<input_n, hidden_n>(hid_delta, input_vals, input_weights, input_prev_weights);
 
   return out_err, hid_err, input_weights, input_prev_weights, hidden_weights, hidden_prev_weights;
diff --git a/juno_samples/rodinia/backprop/src/cpu.sch b/juno_samples/rodinia/backprop/src/cpu.sch
index 6899523e..9a5f4d75 100644
--- a/juno_samples/rodinia/backprop/src/cpu.sch
+++ b/juno_samples/rodinia/backprop/src/cpu.sch
@@ -30,11 +30,24 @@ simpl!(*);
 
 infer-schedules(*);
 
+// The first call to layer_forward can be parallelized by 16 (the size of the
+// hidden layer) and the second can't be parallelized at all (the size of the
+// output layer is 1)
+inline(backprop@forward_input, backprop@forward_hidden);
+let forward_input = outline(backprop@forward_input);
+let forward_hidden = outline(backprop@forward_hidden);
+
+fork-tile[16, 0, false, true](forward_input@outer_loop \ forward_input@inner_loop);
+let (outer, inner) = fork-reshape[[1], [0]](forward_input@outer_loop \ forward_input@inner_loop);
+let forward_input = outline(inner);
+inline(backprop@forward_input);
+
 delete-uncalled(*);
 const-inline(*);
 
 simpl!(*);
 fork-split(*);
-unforkify(*);
+unforkify(output_error, hidden_error, adjust_weights, forward_hidden, forward_input);
+simpl!(*);
 
 gcm(*);
-- 
GitLab