From 28ea01a1afcbb630b570e360693defb3e76c70d9 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Fri, 28 Feb 2025 13:24:48 -0600
Subject: [PATCH] Infer more indices as parallel

---
 hercules_opt/src/utils.rs                     | 18 ++++++++++++++++++
 juno_samples/rodinia/backprop/src/backprop.jn |  4 ++--
 juno_samples/rodinia/backprop/src/cpu.sch     | 10 +++-------
 3 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs
index e962b81d..b910a128 100644
--- a/hercules_opt/src/utils.rs
+++ b/hercules_opt/src/utils.rs
@@ -532,6 +532,24 @@ where
     let fork_thread_id_pairs = node_indices(indices).filter_map(|id| {
         if let Node::ThreadID { control, dimension } = nodes[id.idx()] {
             Some((control, dimension))
+        } else if let Node::Binary {
+            op: BinaryOperator::Add,
+            left: tid,
+            right: cons,
+        } = nodes[id.idx()]
+            && let Node::ThreadID { control, dimension } = nodes[tid.idx()]
+            && (nodes[cons.idx()].is_constant() || nodes[cons.idx()].is_dynamic_constant())
+        {
+            Some((control, dimension))
+        } else if let Node::Binary {
+            op: BinaryOperator::Add,
+            left: cons,
+            right: tid,
+        } = nodes[id.idx()]
+            && let Node::ThreadID { control, dimension } = nodes[tid.idx()]
+            && (nodes[cons.idx()].is_constant() || nodes[cons.idx()].is_dynamic_constant())
+        {
+            Some((control, dimension))
         } else {
             None
         }
diff --git a/juno_samples/rodinia/backprop/src/backprop.jn b/juno_samples/rodinia/backprop/src/backprop.jn
index 356bb3d9..94c4334c 100644
--- a/juno_samples/rodinia/backprop/src/backprop.jn
+++ b/juno_samples/rodinia/backprop/src/backprop.jn
@@ -7,9 +7,9 @@ fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f
   @res let result : f32[m + 1];
   result[0] = 1.0;
 
-  for j in 1..=m {
+  @outer_loop for j in 1..=m {
     let sum = 0.0;
-    for k in 0..=n {
+    @inner_loop for k in 0..=n {
       sum += weights[k, j] * vals[k];
     }
     result[j] = squash(sum);
diff --git a/juno_samples/rodinia/backprop/src/cpu.sch b/juno_samples/rodinia/backprop/src/cpu.sch
index d1fe8953..d59fd5f5 100644
--- a/juno_samples/rodinia/backprop/src/cpu.sch
+++ b/juno_samples/rodinia/backprop/src/cpu.sch
@@ -15,20 +15,16 @@ delete-uncalled(*);
 no-memset(layer_forward@res);
 lift-dc-math(*);
 loop-bound-canon(*);
-dce(*);
+simpl!(*);
 lift-dc-math(*);
+slf(*);
 fixpoint {
   forkify(*);
   fork-guard-elim(*);
   fork-coalesce(*);
 }
+simpl!(*);
 
 fork-split(*);
-gvn(*);
-phi-elim(*);
-dce(*);
 unforkify(*);
-gvn(*);
-phi-elim(*);
-dce(*);
 gcm(*);
-- 
GitLab