From e7d5bde00fc5dc4aa258f2686a03496e63fa3ba3 Mon Sep 17 00:00:00 2001
From: Praneet Rathi <prrathi10@gmail.com>
Date: Sun, 5 Jan 2025 13:42:02 -0600
Subject: [PATCH 1/3] Infer associative schedule for Max/Min intrinsic reductions

---
 hercules_opt/src/pass.rs                   |  1 +
 hercules_opt/src/schedule.rs               | 15 ++++--
 juno_samples/matmul/src/matmul_indented.jn | 57 ++++++++++++++++++++++
 3 files changed, 68 insertions(+), 5 deletions(-)
 create mode 100644 juno_samples/matmul/src/matmul_indented.jn

diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs
index bcaba374..60a6ee24 100644
--- a/hercules_opt/src/pass.rs
+++ b/hercules_opt/src/pass.rs
@@ -867,6 +867,7 @@ impl PassManager {
                         infer_parallel_fork(&mut editor, &fork_join_maps[idx]);
                         infer_vectorizable(&mut editor, &fork_join_maps[idx]);
                         infer_associative(&mut editor);
+                        infer_tight_associative(&mut editor);
 
                         self.module.constants = constants_ref.take();
                         self.module.dynamic_constants = dynamic_constants_ref.take();
diff --git a/hercules_opt/src/schedule.rs b/hercules_opt/src/schedule.rs
index b65b4d04..9d73b11d 100644
--- a/hercules_opt/src/schedule.rs
+++ b/hercules_opt/src/schedule.rs
@@ -149,14 +149,18 @@ pub fn infer_vectorizable(editor: &mut FunctionEditor, fork_join_map: &HashMap<N
  * Infer associative reduction loops.
  */
 pub fn infer_associative(editor: &mut FunctionEditor) {
-    let is_associative = |op| match op {
+    let is_binop_associative = |op| match op {
         BinaryOperator::Add
-        | BinaryOperator::Mul
         | BinaryOperator::Or
         | BinaryOperator::And
         | BinaryOperator::Xor => true,
         _ => false,
     };
+    let is_intrinsic_associative = |intrinsic| match intrinsic {
+        Intrinsic::Max
+        | Intrinsic::Min => true,
+        _ => false,
+    };
 
     for id in editor.node_ids() {
         let func = editor.func();
@@ -165,9 +169,10 @@ pub fn infer_associative(editor: &mut FunctionEditor) {
             init: _,
             reduct,
         } = func.nodes[id.idx()]
-            && let Node::Binary { left, right, op } = func.nodes[reduct.idx()]
-            && (left == id || right == id)
-            && is_associative(op)
+            && (matches!(func.nodes[reduct.idx()], Node::Binary { left, right, op } 
+                if (left == id || right == id) && is_binop_associative(op)) || 
+            matches!(&func.nodes[reduct.idx()], Node::IntrinsicCall { intrinsic, args }
+                if (args.contains(&id) && is_intrinsic_associative(*intrinsic))))
         {
             editor.edit(|edit| edit.add_schedule(id, Schedule::Associative));
         }
diff --git a/juno_samples/matmul/src/matmul_indented.jn b/juno_samples/matmul/src/matmul_indented.jn
new file mode 100644
index 00000000..12039f52
--- /dev/null
+++ b/juno_samples/matmul/src/matmul_indented.jn
@@ -0,0 +1,57 @@
+#[entry]
+fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] {
+  let res : i32[n, l];
+
+  @outer for i = 0 to n {
+    @middle for j = 0 to l {
+      @inner for k = 0 to m {
+        res[i, j] += a[i, k] * b[k, j];
+      }
+    }
+  }
+
+  @exit
+  return res;
+}
+
+/*
+#[entry]
+fn tiled_64_matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] {
+  let res : i32[n, l];
+  
+  for bi = 0 to n / 64 {
+    for bk = 0 to l / 64 {
+      let atile : i32[66, 64];
+      let btile : i32[65, 64];
+      let ctile : i32[64, 64];
+
+      for tile_idx = 0 to m / 64 {
+        for ti = 0 to 64 {
+          for tk = 0 to 64 {
+            atile[ti, tk] = a[bi * 64 + ti, tile_idx * 64 + tk];
+            btile[ti, tk] = b[tile_idx * 64 + ti, bk * 64 + tk];
+            ctile[ti, tk] = 0;
+          }
+        }
+        for ti = 0 to 64 {
+          for tk = 0 to 64 {
+            let c_acc = ctile[ti, tk];
+            for inner_idx = 0 to 64 {
+              c_acc += atile[ti, inner_idx] * btile[inner_idx, tk];
+            } 
+            ctile[ti, tk] = c_acc;
+          }
+        }
+      }
+
+      for ti = 0 to 64 {
+        for tk = 0 to 64 {
+          res[bi * 64 + ti, bk * 64 + tk] = ctile[ti, tk];
+        }
+      }
+    }
+  }
+
+  return res;
+}
+*/
\ No newline at end of file
-- 
GitLab


From d3ebaeed8b5a5b00d318a6e0a9b2435dcd24c5d0 Mon Sep 17 00:00:00 2001
From: Praneet Rathi <prrathi10@gmail.com>
Date: Sun, 5 Jan 2025 14:17:49 -0600
Subject: [PATCH 2/3] Add reduce-cycle tightness checks to schedule inference

---
 hercules_ir/src/ir.rs        |  2 +-
 hercules_opt/src/pass.rs     |  7 +++--
 hercules_opt/src/schedule.rs | 59 +++++++++++++++++-------------------
 3 files changed, 32 insertions(+), 36 deletions(-)

diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index b46e2dda..11d23e61 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -322,7 +322,7 @@ pub enum Schedule {
     Vectorizable,
     // This reduce can be re-associated. This may lower a sequential dependency
     // chain into a reduction tree.
-    Associative,
+    TightAssociative,
 }
 
 /*
diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs
index 60a6ee24..3b7c81ed 100644
--- a/hercules_opt/src/pass.rs
+++ b/hercules_opt/src/pass.rs
@@ -848,8 +848,10 @@ impl PassManager {
                 Pass::InferSchedules => {
                     self.make_def_uses();
                     self.make_fork_join_maps();
+                    self.make_reduce_cycles();
                     let def_uses = self.def_uses.as_ref().unwrap();
                     let fork_join_maps = self.fork_join_maps.as_ref().unwrap();
+                    let reduce_cycles = self.reduce_cycles.as_ref().unwrap();
                     for idx in 0..self.module.functions.len() {
                         let constants_ref =
                             RefCell::new(std::mem::take(&mut self.module.constants));
@@ -863,11 +865,10 @@ impl PassManager {
                             &types_ref,
                             &def_uses[idx],
                         );
-                        infer_parallel_reduce(&mut editor, &fork_join_maps[idx]);
+                        infer_parallel_reduce(&mut editor, &fork_join_maps[idx], &reduce_cycles[idx]);
                         infer_parallel_fork(&mut editor, &fork_join_maps[idx]);
                         infer_vectorizable(&mut editor, &fork_join_maps[idx]);
-                        infer_associative(&mut editor);
-                        infer_tight_associative(&mut editor);
+                        infer_tight_associative(&mut editor, &reduce_cycles[idx]);
 
                         self.module.constants = constants_ref.take();
                         self.module.dynamic_constants = dynamic_constants_ref.take();
diff --git a/hercules_opt/src/schedule.rs b/hercules_opt/src/schedule.rs
index 9d73b11d..ff895b16 100644
--- a/hercules_opt/src/schedule.rs
+++ b/hercules_opt/src/schedule.rs
@@ -1,6 +1,6 @@
 extern crate hercules_ir;
 
-use std::collections::HashMap;
+use std::collections::{HashMap, HashSet};
 
 use self::hercules_ir::def_use::*;
 use self::hercules_ir::ir::*;
@@ -14,13 +14,9 @@ use crate::*;
 pub fn infer_parallel_fork(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
     for id in editor.node_ids() {
         let func = editor.func();
-        let Node::Fork {
-            control: _,
-            factors: _,
-        } = func.nodes[id.idx()]
-        else {
+        if !func.nodes[id.idx()].is_fork() {
             continue;
-        };
+        }
         let join_id = fork_join_map[&id];
         let all_parallel_reduce = editor.get_users(join_id).all(|user| {
             func.schedules[user.idx()].contains(&Schedule::ParallelReduce)
@@ -35,14 +31,15 @@ pub fn infer_parallel_fork(editor: &mut FunctionEditor, fork_join_map: &HashMap<
 /*
  * Infer parallel reductions consisting of a simple cycle between a Reduce node
  * and a Write node, where indices of the Write are position indices using the
- * ThreadID nodes attached to the corresponding Fork. This procedure also adds
- * the ParallelReduce schedule to Reduce nodes reducing over a parallelized
- * Reduce, as long as the base Write node also has position indices of the
- * ThreadID of the outer fork. In other words, the complete Reduce chain is
- * annotated with ParallelReduce, as long as each ThreadID dimension appears in
- * the positional indexing of the original Write.
+ * ThreadID nodes attached to the corresponding Fork, and the data operand of
+ * the Write is not in the Reduce node's cycle. This procedure also adds the
+ * ParallelReduce schedule to Reduce nodes reducing over a parallelized Reduce,
+ * as long as the base Write node also has position indices of the ThreadID of
+ * the outer fork. In other words, the complete Reduce chain is annotated with
+ * ParallelReduce, as long as each ThreadID dimension appears in the positional
+ * indexing of the original Write.
  */
-pub fn infer_parallel_reduce(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
+pub fn infer_parallel_reduce(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>) {
     for id in editor.node_ids() {
         let func = editor.func();
         if !func.nodes[id.idx()].is_reduce() {
@@ -73,10 +70,11 @@ pub fn infer_parallel_reduce(editor: &mut FunctionEditor, fork_join_map: &HashMa
         // Check for a Write-Reduce tight cycle.
         if let Node::Write {
             collect,
-            data: _,
+            data,
             indices,
         } = &func.nodes[chain_id.idx()]
             && *collect == last_reduce
+            && !reduce_cycles[&last_reduce].contains(data)
         {
             // If there is a Write-Reduce tight cycle, get the position indices.
             let positions = indices
@@ -146,21 +144,15 @@ pub fn infer_vectorizable(editor: &mut FunctionEditor, fork_join_map: &HashMap<N
 }
 
 /*
- * Infer associative reduction loops.
+ * Infer tight associative reduction loops. Exactly one of the associative
+ * operation's operands must be the Reduce node, and all other operands must
+ * not be in the Reduce node's cycle.
  */
-pub fn infer_associative(editor: &mut FunctionEditor) {
-    let is_binop_associative = |op| match op {
-        BinaryOperator::Add
-        | BinaryOperator::Or
-        | BinaryOperator::And
-        | BinaryOperator::Xor => true,
-        _ => false,
-    };
-    let is_intrinsic_associative = |intrinsic| match intrinsic {
-        Intrinsic::Max
-        | Intrinsic::Min => true,
-        _ => false,
-    };
+pub fn infer_tight_associative(editor: &mut FunctionEditor, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>) {
+    let is_binop_associative = |op| matches!(op,
+        BinaryOperator::Add | BinaryOperator::Or | BinaryOperator::And | BinaryOperator::Xor);
+    let is_intrinsic_associative = |intrinsic| matches!(intrinsic, 
+        Intrinsic::Max | Intrinsic::Min);
 
     for id in editor.node_ids() {
         let func = editor.func();
@@ -170,11 +162,14 @@ pub fn infer_associative(editor: &mut FunctionEditor) {
             reduct,
         } = func.nodes[id.idx()]
             && (matches!(func.nodes[reduct.idx()], Node::Binary { left, right, op } 
-                if (left == id || right == id) && is_binop_associative(op)) || 
+                if ((left == id && !reduce_cycles[&id].contains(&right)) || 
+                    (right == id && !reduce_cycles[&id].contains(&left))) && 
+                    is_binop_associative(op)) || 
             matches!(&func.nodes[reduct.idx()], Node::IntrinsicCall { intrinsic, args }
-                if (args.contains(&id) && is_intrinsic_associative(*intrinsic))))
+                if (args.contains(&id) && is_intrinsic_associative(*intrinsic) && 
+                    args.iter().filter(|arg| **arg != id).all(|arg| !reduce_cycles[&id].contains(arg)))))
         {
-            editor.edit(|edit| edit.add_schedule(id, Schedule::Associative));
+            editor.edit(|edit| edit.add_schedule(id, Schedule::TightAssociative));
         }
     }
 }
-- 
GitLab


From 05ca328518e2201db0337dc78fb68b9b5706d512 Mon Sep 17 00:00:00 2001
From: prathi3 <prathi3@illinois.edu>
Date: Sun, 5 Jan 2025 20:50:26 -0600
Subject: [PATCH 3/3] Delete matmul_indented.jn

---
 juno_samples/matmul/src/matmul_indented.jn | 57 ----------------------
 1 file changed, 57 deletions(-)
 delete mode 100644 juno_samples/matmul/src/matmul_indented.jn

diff --git a/juno_samples/matmul/src/matmul_indented.jn b/juno_samples/matmul/src/matmul_indented.jn
deleted file mode 100644
index 12039f52..00000000
--- a/juno_samples/matmul/src/matmul_indented.jn
+++ /dev/null
@@ -1,57 +0,0 @@
-#[entry]
-fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] {
-  let res : i32[n, l];
-
-  @outer for i = 0 to n {
-    @middle for j = 0 to l {
-      @inner for k = 0 to m {
-        res[i, j] += a[i, k] * b[k, j];
-      }
-    }
-  }
-
-  @exit
-  return res;
-}
-
-/*
-#[entry]
-fn tiled_64_matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] {
-  let res : i32[n, l];
-  
-  for bi = 0 to n / 64 {
-    for bk = 0 to l / 64 {
-      let atile : i32[66, 64];
-      let btile : i32[65, 64];
-      let ctile : i32[64, 64];
-
-      for tile_idx = 0 to m / 64 {
-        for ti = 0 to 64 {
-          for tk = 0 to 64 {
-            atile[ti, tk] = a[bi * 64 + ti, tile_idx * 64 + tk];
-            btile[ti, tk] = b[tile_idx * 64 + ti, bk * 64 + tk];
-            ctile[ti, tk] = 0;
-          }
-        }
-        for ti = 0 to 64 {
-          for tk = 0 to 64 {
-            let c_acc = ctile[ti, tk];
-            for inner_idx = 0 to 64 {
-              c_acc += atile[ti, inner_idx] * btile[inner_idx, tk];
-            } 
-            ctile[ti, tk] = c_acc;
-          }
-        }
-      }
-
-      for ti = 0 to 64 {
-        for tk = 0 to 64 {
-          res[bi * 64 + ti, bk * 64 + tk] = ctile[ti, tk];
-        }
-      }
-    }
-  }
-
-  return res;
-}
-*/
\ No newline at end of file
-- 
GitLab