From 595b9fbfd028045db3090e279dd2dbe27a883907 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 11 Feb 2025 13:52:43 -0600
Subject: [PATCH 1/7] add interchange pass

---
 hercules_opt/src/fork_transforms.rs           | 56 ++++++++++++++++++-
 juno_samples/fork_join_tests/src/cpu.sch      |  2 +-
 .../fork_join_tests/src/fork_join_tests.jn    |  8 +--
 juno_samples/fork_join_tests/src/gpu.sch      |  4 +-
 juno_samples/fork_join_tests/src/main.rs      |  2 +-
 juno_scheduler/src/ir.rs                      | 15 ++---
 juno_scheduler/src/pm.rs                      | 34 ++++++++++-
 7 files changed, 103 insertions(+), 18 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 7c423892..33c4e68c 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1,5 +1,4 @@
 use std::collections::{HashMap, HashSet};
-use std::hash::Hash;
 use std::iter::zip;
 
 use bimap::BiMap;
@@ -1076,3 +1075,58 @@ pub fn fork_dim_merge(
 
     new_fork
 }
+
+/*
+ * Run fork interchange on all fork-joins that are mutable in an editor.
+ */
+pub fn fork_interchange_all_forks(
+    editor: &mut FunctionEditor,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    first_dim: usize,
+    second_dim: usize,
+) {
+    for (fork, join) in fork_join_map {
+        if editor.is_mutable(*fork) {
+            fork_interchange(editor, *fork, *join, first_dim, second_dim);
+        }
+    }
+}
+
+fn fork_interchange(
+    editor: &mut FunctionEditor,
+    fork: NodeID,
+    join: NodeID,
+    first_dim: usize,
+    second_dim: usize,
+) {
+    // Check that every reduce on the join is parallel or tight associative.
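+    // Interchanging two dimensions reorders the reduction iterations, so only
+    // order-insensitive (parallel or tightly associative) reduces are safe to
+    // interchange over.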
+    let nodes = &editor.func().nodes;
+    let schedules = &editor.func().schedules;
+    if !editor
+        .get_users(join)
+        .filter(|id| nodes[id.idx()].is_reduce())
+        .all(|id| {
+            schedules[id.idx()].contains(&Schedule::ParallelReduce)
+                || schedules[id.idx()].contains(&Schedule::TightAssociative)
+        })
+    {
+        // If not, we can't necessarily do interchange.
+        return;
+    }
+
+    let Node::Fork {
+        control,
+        ref factors,
+    } = nodes[fork.idx()]
+    else {
+        panic!()
+    };
+    let mut factors = factors.clone();
+    factors.swap(first_dim, second_dim);
+    let new_fork = Node::Fork { control, factors };
+    editor.edit(|mut edit| {
+        let new_fork = edit.add_node(new_fork);
+        edit = edit.replace_all_uses(fork, new_fork)?;
+        edit.delete_node(fork)
+    });
+}
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index 38a38c2b..fe0a8802 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -50,7 +50,7 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
-fork-tile[32, 0, true](test6@loop);
+fork-tile[32, 0, false](test6@loop);
 let out = fork-split(test6@loop);
 let out = outline(out.test6.fj1);
 cpu(out);
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index 886ab13b..31752ec3 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -10,11 +10,11 @@ fn test1(input : i32) -> i32[4, 4] {
 }
 
 #[entry]
-fn test2(input : i32) -> i32[4, 4] {
-  let arr : i32[4, 4];
+fn test2(input : i32) -> i32[4, 5] {
+  let arr : i32[4, 5];
   for i = 0 to 8 {
-    for j = 0 to 4 {
-      for k = 0 to 4 {
+    for k = 0 to 5 {
+      for j = 0 to 4 {
         arr[j, k] += input;
       }
     }
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index 117cf37e..8e1cd687 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -42,7 +42,7 @@ fixpoint panic after 20 {
   infer-schedules(*);
 }
 
-fork-tile[32, 0, true](test6@loop);
+fork-tile[32, 0, false](test6@loop);
 let out = fork-split(test6@loop);
 let out = auto-outline(test6);
 gpu(out.test6);
@@ -69,5 +69,7 @@ phi-elim(*);
 dce(*);
 gcm(*);
 
+xdot[true](*);
+
 float-collections(test2, auto.test2, test4, auto.test4, test5, auto.test5);
 gcm(*);
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
index 1013a1f0..21ccd7c4 100644
--- a/juno_samples/fork_join_tests/src/main.rs
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -25,7 +25,7 @@ fn main() {
 
         let mut r = runner!(test2);
         let output = r.run(3).await;
-        let correct = vec![24i32; 16];
+        let correct = vec![24i32; 20];
         assert(&correct, output);
 
         let mut r = runner!(test3);
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 0ec6f52c..881c63a3 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -9,13 +9,14 @@ pub enum Pass {
     DCE,
     DeleteUncalled,
     FloatCollections,
+    ForkChunk,
+    ForkCoalesce,
+    ForkDimMerge,
+    ForkFissionBufferize,
     ForkGuardElim,
+    ForkInterchange,
     ForkSplit,
-    ForkCoalesce,
     Forkify,
-    ForkFissionBufferize,
-    ForkDimMerge,
-    ForkChunk,
     GCM,
     GVN,
     InferSchedules,
@@ -25,14 +26,14 @@ pub enum Pass {
     Outline,
     PhiElim,
     Predication,
-    SimplifyCFG,
     SLF,
     SROA,
+    Serialize,
+    SimplifyCFG,
     Unforkify,
-    WritePredication,
     Verify,
+    WritePredication,
     Xdot,
-    Serialize,
 }
 
 impl Pass {
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 20825c54..8581ee23 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1635,6 +1635,36 @@ fn run_pass(
                 fields: new_fork_joins,
             };
         }
+        Pass::ForkInterchange => {
+            assert_eq!(args.len(), 2);
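+            // The two integer arguments are the indices of the fork dimensions
+            // to swap.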
+            let Some(Value::Integer { val: first_dim }) = args.get(0) else {
+                return Err(SchedulerError::PassError {
+                    pass: "forkInterchange".to_string(),
+                    error: "expected integer argument".to_string(),
+                });
+            };
+            let Some(Value::Integer { val: second_dim }) = args.get(1) else {
+                return Err(SchedulerError::PassError {
+                    pass: "forkInterchange".to_string(),
+                    error: "expected integer argument".to_string(),
+                });
+            };
+
+            pm.make_fork_join_maps();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            for (func, fork_join_map) in build_selection(pm, selection, false)
+                .into_iter()
+                .zip(fork_join_maps.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                fork_interchange_all_forks(&mut func, fork_join_map, *first_dim, *second_dim);
+                changed |= func.modified();
+            }
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::Forkify => {
             assert!(args.is_empty());
             loop {
@@ -2042,7 +2072,7 @@ fn run_pass(
                 });
             };
 
-            assert_eq!(*guarded_flag, true);
+            assert!(!*guarded_flag);
             pm.make_fork_join_maps();
             let fork_join_maps = pm.fork_join_maps.take().unwrap();
             for (func, fork_join_map) in build_selection(pm, selection, false)
@@ -2059,8 +2089,6 @@ fn run_pass(
             pm.clear_analyses();
         }
         Pass::ForkFissionBufferize => {
-            pm.make_fork_join_maps();
-
             assert_eq!(args.len(), 2);
             let Some(Value::Label {
                 labels: fork_labels,
-- 
GitLab


From 4bc008359a15444c63e0c2bd3e58c91e44e85d44 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 11 Feb 2025 14:09:12 -0600
Subject: [PATCH 2/7] setup to do reduce slf

---
 juno_samples/fork_join_tests/src/fork_join_tests.jn | 6 +++---
 juno_samples/fork_join_tests/src/gpu.sch            | 9 ++++++---
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index 31752ec3..8f569cfb 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -12,9 +12,9 @@ fn test1(input : i32) -> i32[4, 4] {
 #[entry]
 fn test2(input : i32) -> i32[4, 5] {
   let arr : i32[4, 5];
-  for i = 0 to 8 {
-    for k = 0 to 5 {
-      for j = 0 to 4 {
+  @loop1 for i = 0 to 8 {
+    @loop2 for k = 0 to 5 {
+      @loop3 for j = 0 to 4 {
         arr[j, k] += input;
       }
     }
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index 8e1cd687..123313a0 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -31,7 +31,10 @@ dce(*);
 fixpoint panic after 20 {
   forkify(*);
   fork-guard-elim(*);
-  fork-coalesce(*);
+}
+
+fixpoint panic after 20 {
+  fork-coalesce(auto.test1, auto.test3, auto.test4, auto.test5, auto.test7, auto.test8);
 }
 
 gvn(*);
@@ -42,6 +45,8 @@ fixpoint panic after 20 {
   infer-schedules(*);
 }
 
+fork-coalesce(auto.test2@loop2);
+
 fork-tile[32, 0, false](test6@loop);
 let out = fork-split(test6@loop);
 let out = auto-outline(test6);
@@ -69,7 +74,5 @@ phi-elim(*);
 dce(*);
 gcm(*);
 
-xdot[true](*);
-
 float-collections(test2, auto.test2, test4, auto.test4, test5, auto.test5);
 gcm(*);
-- 
GitLab


From 0e36e8b01951fa9ea1c13ca741fb6b10b1919d45 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 11 Feb 2025 14:47:42 -0600
Subject: [PATCH 3/7] Figure out how to do reduce SLF

---
 hercules_opt/src/slf.rs | 63 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 63 insertions(+)

diff --git a/hercules_opt/src/slf.rs b/hercules_opt/src/slf.rs
index 024a442e..f8c29e7b 100644
--- a/hercules_opt/src/slf.rs
+++ b/hercules_opt/src/slf.rs
@@ -229,3 +229,66 @@ pub fn array_slf(
         }
     }
 }
+
+/*
+ * Top level function to run reduce store-to-load forwarding on a function.
+ * There are two variants of reduce SLF. One optimizes parallel reductions and
+ * the other optimizes scalar reductions. This pass just runs one after the
+ * other - it's up to the user to potentially wrap this pass in a fixpoint.
+ *
+ * The parallel variant looks for reductions on collections with the following
+ * form:
+ *
+ * 1. A write `reduct` use.
+ * 2. A single read user in the reduction cycle.
+ * 3. The write use and read user have identical indices.
+ * 4. The indices set directly refers to at least every thread ID produced by
+ *    the fork corresponding to the reduce.
+ *
+ * Condition #4 roughly corresponds to the condition needed to infer the
+ * ParallelReduce schedule - however, in this scenario it can't be inferred,
+ * because the data being written is in the reduce cycle (it is derived from
+ * the read user of the reduce). Since the indices change each iteration, the
+ * read could equivalently read from the `init` input of the reduce rather
+ * than from the reduce itself. This optimization
+ * replaces the use of the reduce in the read with the `init` input of the
+ * reduce, nominally so that ParallelReduce can get inferred.
+ *
+ * The scalar variant looks for reductions on collections with the following
+ * form:
+ *
+ * 1. A write `reduct` use.
+ * 2. A single read user in the reduction cycle.
+ * 3. The write use and read user have identical indices.
+ * 4. The indices set doesn't reference any nodes in the fork-join of the reduce
+ *    (doesn't depend on the thread IDs of the fork-join).
+ *
+ * Instead of repeatedly reading / writing the same collection item each
+ * iteration, the reduction can reduce over the scalar value directly, and do a
+ * single write into the collection after the fork-join (note that, technically, the
+ * "scalar" may itself be a collection in the case of nested collections, but
+ * this optimization most often handles scalars).
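+ *
+ * As a rough example: for a fork-join over `j` and `k` containing
+ * `arr[j, k] += x`, the read/write indices [j, k] cover every thread ID of
+ * the fork, so the parallel variant applies; if instead the body were
+ * `arr[0] += x`, the indices reference no thread IDs of the fork, so the
+ * scalar variant applies.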
+ */
+pub fn reduce_slf(
+    editor: &mut FunctionEditor,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+) {
+    // Helper to get write use and read user of reduce with identical indices.
+    let read_write_helper = |reduce: NodeID, editor: &FunctionEditor| -> Option<(NodeID, NodeID)> {
+        let nodes = &editor.func().nodes;
+        let reduct = nodes[reduce.idx()].try_reduce().unwrap().2;
+        if !nodes[reduct.idx()].is_write() {
+            return None;
+        }
+        let mut users = editor
+            .get_users(reduce)
+            .filter(|id| reduce_cycles[&reduce].contains(id));
+        let read = users.next()?;
+        if users.next().is_some() || !nodes[read.idx()].is_read() {
+            return None;
+        }
+        Some((read, reduct))
+    };
+}
-- 
GitLab


From 6210141d55decb42df595ec87c45ce0f5d4f8f0e Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 11 Feb 2025 16:05:32 -0600
Subject: [PATCH 4/7] Parallel reduce slf

---
 hercules_opt/src/fork_transforms.rs      | 34 +++++++++++++
 hercules_opt/src/slf.rs                  | 34 ++++++++++++-
 hercules_opt/src/utils.rs                | 62 ++++++++++++++++++++++--
 juno_samples/fork_join_tests/src/gpu.sch |  5 ++
 juno_scheduler/src/compile.rs            |  3 +-
 juno_scheduler/src/ir.rs                 |  2 +
 juno_scheduler/src/pm.rs                 | 23 +++++++++
 7 files changed, 155 insertions(+), 8 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 33c4e68c..fd6747d7 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -1121,10 +1121,44 @@ fn fork_interchange(
     else {
         panic!()
     };
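+    // Rebuild thread IDs that read either interchanged dimension so that they
+    // read the other dimension, since the fork's factor order is about to be
+    // swapped.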
+    let fix_tids: Vec<(NodeID, Node)> = editor
+        .get_users(fork)
+        .filter_map(|id| {
+            nodes[id.idx()]
+                .try_thread_id()
+                .map(|(_, dim)| {
+                    if dim == first_dim {
+                        Some((
+                            id,
+                            Node::ThreadID {
+                                control: fork,
+                                dimension: second_dim,
+                            },
+                        ))
+                    } else if dim == second_dim {
+                        Some((
+                            id,
+                            Node::ThreadID {
+                                control: fork,
+                                dimension: first_dim,
+                            },
+                        ))
+                    } else {
+                        None
+                    }
+                })
+                .flatten()
+        })
+        .collect();
     let mut factors = factors.clone();
     factors.swap(first_dim, second_dim);
     let new_fork = Node::Fork { control, factors };
     editor.edit(|mut edit| {
+        for (old_id, new_tid) in fix_tids {
+            let new_id = edit.add_node(new_tid);
+            edit = edit.replace_all_uses(old_id, new_id)?;
+            edit = edit.delete_node(old_id)?;
+        }
         let new_fork = edit.add_node(new_fork);
         edit = edit.replace_all_uses(fork, new_fork)?;
         edit.delete_node(fork)
diff --git a/hercules_opt/src/slf.rs b/hercules_opt/src/slf.rs
index f8c29e7b..b10f58ce 100644
--- a/hercules_opt/src/slf.rs
+++ b/hercules_opt/src/slf.rs
@@ -1,4 +1,5 @@
 use std::collections::{BTreeMap, HashMap, HashSet};
+use std::iter::once;
 
 use hercules_ir::*;
 
@@ -284,11 +285,40 @@ pub fn reduce_slf(
         }
         let mut users = editor
             .get_users(reduce)
-            .filter(|id| reduce_cycles[&reduce].contains(id));
+            .filter(|id| reduce_cycles[&reduce].contains(id) && *id != reduct);
         let read = users.next()?;
-        if users.next().is_some() || !nodes[read.idx()].is_read() {
+        if users.next().is_some()
+            || !nodes[read.idx()].is_read()
+            || nodes[read.idx()].try_read().unwrap().1 != nodes[reduct.idx()].try_write().unwrap().2
+        {
             return None;
         }
         Some((read, reduct))
     };
+
+    // First, optimize parallel reductions.
+    for (fork, join) in fork_join_map {
+        let reduces: Vec<_> = editor
+            .get_users(*join)
+            .filter(|id| editor.func().nodes[id.idx()].is_reduce())
+            .collect();
+        for reduce in reduces {
+            let Some((read, _)) = read_write_helper(reduce, editor) else {
+                continue;
+            };
+            let indices = editor.func().nodes[read.idx()].try_read().unwrap().1;
+            if indices_parallel_over_forks(editor, indices, once(*fork)) {
+                let init = editor.func().nodes[reduce.idx()].try_reduce().unwrap().1;
+                let new_read = Node::Read {
+                    collect: init,
+                    indices: indices.to_vec().into_boxed_slice(),
+                };
+                editor.edit(|mut edit| {
+                    let new_read = edit.add_node(new_read);
+                    edit = edit.replace_all_uses(read, new_read)?;
+                    edit.delete_node(read)
+                });
+            }
+        }
+    }
 }
diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs
index 022da629..1482a7f0 100644
--- a/hercules_opt/src/utils.rs
+++ b/hercules_opt/src/utils.rs
@@ -10,7 +10,7 @@ use crate::*;
  * Substitute all uses of dynamic constants in a type that are keys in the substs map with the
  * dynamic constant value for that key. Return the substituted version of the type, once memoized.
  */
-pub(crate) fn substitute_dynamic_constants_in_type(
+pub fn substitute_dynamic_constants_in_type(
     substs: &HashMap<DynamicConstantID, DynamicConstantID>,
     ty: TypeID,
     edit: &mut FunctionEdit,
@@ -62,7 +62,7 @@ pub(crate) fn substitute_dynamic_constants_in_type(
  * version of dc, once memoized. Takes a mutable edit instead of an editor since this may create
  * new dynamic constants, which can only be done inside an edit.
  */
-pub(crate) fn substitute_dynamic_constants(
+pub fn substitute_dynamic_constants(
     substs: &HashMap<DynamicConstantID, DynamicConstantID>,
     dc: DynamicConstantID,
     edit: &mut FunctionEdit,
@@ -154,7 +154,7 @@ pub(crate) fn substitute_dynamic_constants(
  * Substitute all uses of the dynamic constants specified by the subst map in a constant. Return
  * the substituted version of the constant, once memozied.
  */
-pub(crate) fn substitute_dynamic_constants_in_constant(
+pub fn substitute_dynamic_constants_in_constant(
     substs: &HashMap<DynamicConstantID, DynamicConstantID>,
     cons: ConstantID,
     edit: &mut FunctionEdit,
@@ -198,7 +198,7 @@ pub(crate) fn substitute_dynamic_constants_in_constant(
 /*
  * Substitute all uses of the dynamic constants specified by the subst map in a node.
  */
-pub(crate) fn substitute_dynamic_constants_in_node(
+pub fn substitute_dynamic_constants_in_node(
     substs: &HashMap<DynamicConstantID, DynamicConstantID>,
     node: &mut Node,
     edit: &mut FunctionEdit,
@@ -275,7 +275,7 @@ pub fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> {
     new_return
 }
 
-pub(crate) fn contains_between_control_flow(func: &Function) -> bool {
+pub fn contains_between_control_flow(func: &Function) -> bool {
     let num_control = func.nodes.iter().filter(|node| node.is_control()).count();
     assert!(num_control >= 2, "PANIC: A Hercules function must have at least two control nodes: a start node and at least one return node.");
     num_control > 2
@@ -477,3 +477,55 @@ pub fn materialize_simple_einsum_expr(
         _ => panic!(),
     }
 }
+
+/*
+ * Checks if a set of indices is fully parallel over a set of forks - that is,
+ * every thread ID from every fork appears at least once in positions in the
+ * indices set.
+ */
+pub fn indices_parallel_over_forks<I>(
+    editor: &FunctionEditor,
+    indices: &[Index],
+    mut forks: I,
+) -> bool
+where
+    I: Iterator<Item = NodeID>,
+{
+    // Get the node ids in the indices set.
+    let positions = indices
+        .iter()
+        .filter_map(|index| {
+            if let Index::Position(indices) = index {
+                Some(indices)
+            } else {
+                None
+            }
+        })
+        .flat_map(|pos| pos.iter());
+
+    // Get the forks corresponding to uses of bare thread ids.
+    let nodes = &editor.func().nodes;
+    let fork_thread_id_pairs = positions.filter_map(|id| {
+        if let Node::ThreadID { control, dimension } = nodes[id.idx()] {
+            Some((control, dimension))
+        } else {
+            None
+        }
+    });
+    let mut rep_forks = HashMap::<NodeID, Vec<usize>>::new();
+    for (fork, dim) in fork_thread_id_pairs {
+        rep_forks.entry(fork).or_default().push(dim);
+    }
+
+    // If each fork the query is over is represented and each of its dimensions
+    // is represented, then the indices are parallel over the forks.
+    forks.all(|fork| {
+        let Some(mut rep_dims) = rep_forks.remove(&fork) else {
+            return false;
+        };
+
+        rep_dims.sort();
+        rep_dims.dedup();
+        nodes[fork.idx()].try_fork().unwrap().1.len() == rep_dims.len()
+    })
+}
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index 123313a0..8f4ec9ad 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -46,6 +46,11 @@ fixpoint panic after 20 {
 }
 
 fork-coalesce(auto.test2@loop2);
+dce(auto.test2);
+reduce-slf(auto.test2);
+slf(auto.test2);
+infer-schedules(auto.test2);
+fork-interchange[0, 1](auto.test2);
 
 fork-tile[32, 0, false](test6@loop);
 let out = fork-split(test6@loop);
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 0ca36380..4ea8dfb5 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -8,7 +8,6 @@ use hercules_ir::ir::{Device, Schedule};
 
 use lrlex::DefaultLexerTypes;
 use lrpar::NonStreamingLexer;
-use lrpar::Span;
 
 use std::fmt;
 use std::str::FromStr;
@@ -112,11 +111,13 @@ impl FromStr for Appliable {
             }
             "fork-fission-bufferize" => Ok(Appliable::Pass(ir::Pass::ForkFissionBufferize)),
             "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)),
+            "fork-interchange" => Ok(Appliable::Pass(ir::Pass::ForkInterchange)),
             "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)),
             "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)),
             "outline" => Ok(Appliable::Pass(ir::Pass::Outline)),
             "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)),
             "predication" => Ok(Appliable::Pass(ir::Pass::Predication)),
+            "reduce-slf" => Ok(Appliable::Pass(ir::Pass::ReduceSLF)),
             "simplify-cfg" => Ok(Appliable::Pass(ir::Pass::SimplifyCFG)),
             "slf" | "store-load-forward" => Ok(Appliable::Pass(ir::Pass::SLF)),
             "sroa" => Ok(Appliable::Pass(ir::Pass::SROA)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 881c63a3..0ecac39a 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -26,6 +26,7 @@ pub enum Pass {
     Outline,
     PhiElim,
     Predication,
+    ReduceSLF,
     SLF,
     SROA,
     Serialize,
@@ -42,6 +43,7 @@ impl Pass {
             Pass::Xdot => 1,
             Pass::ForkChunk => 3,
             Pass::ForkFissionBufferize => 2,
+            Pass::ForkInterchange => 2,
             _ => 0,
         }
     }
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 8581ee23..e9c681cd 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1966,6 +1966,29 @@ fn run_pass(
             pm.delete_gravestones();
             pm.clear_analyses();
         }
+        Pass::ReduceSLF => {
+            assert!(args.is_empty());
+            pm.make_fork_join_maps();
+            pm.make_reduce_cycles();
+            pm.make_nodes_in_fork_joins();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            let reduce_cycles = pm.reduce_cycles.take().unwrap();
+            let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+
+            for (((func, fork_join_map), reduce_cycles), nodes_in_fork_joins) in
+                build_selection(pm, selection, false)
+                    .into_iter()
+                    .zip(fork_join_maps.iter())
+                    .zip(reduce_cycles.iter())
+                    .zip(nodes_in_fork_joins.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                reduce_slf(&mut func, fork_join_map, reduce_cycles, nodes_in_fork_joins);
+                changed |= func.modified();
+            }
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::SLF => {
             assert!(args.is_empty());
             pm.make_reverse_postorders();
-- 
GitLab


From 33801197254c852ba8a0df65ca152f4977a97878 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 11 Feb 2025 16:28:59 -0600
Subject: [PATCH 5/7] Scalar reduce slf

---
 hercules_opt/src/slf.rs   | 55 ++++++++++++++++++++++++++++++++++++++-
 hercules_opt/src/utils.rs | 33 +++++++++++++----------
 2 files changed, 73 insertions(+), 15 deletions(-)

diff --git a/hercules_opt/src/slf.rs b/hercules_opt/src/slf.rs
index b10f58ce..3133972f 100644
--- a/hercules_opt/src/slf.rs
+++ b/hercules_opt/src/slf.rs
@@ -277,10 +277,12 @@ pub fn reduce_slf(
     nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
 ) {
     // Helper to get write use and read user of reduce with identical indices.
+    // This checks conditions 1, 2, and 3 of both the parallel and the scalar
+    // reduction optimizations.
     let read_write_helper = |reduce: NodeID, editor: &FunctionEditor| -> Option<(NodeID, NodeID)> {
         let nodes = &editor.func().nodes;
         let reduct = nodes[reduce.idx()].try_reduce().unwrap().2;
-        if !nodes[reduct.idx()].is_write() {
+        if !nodes[reduct.idx()].is_write() || nodes[reduct.idx()].try_write().unwrap().0 != reduce {
             return None;
         }
         let mut users = editor
@@ -306,6 +308,8 @@ pub fn reduce_slf(
             let Some((read, _)) = read_write_helper(reduce, editor) else {
                 continue;
             };
+
+            // Check condition 4 of parallel reduction optimization.
             let indices = editor.func().nodes[read.idx()].try_read().unwrap().1;
             if indices_parallel_over_forks(editor, indices, once(*fork)) {
                 let init = editor.func().nodes[reduce.idx()].try_reduce().unwrap().1;
@@ -321,4 +325,53 @@ pub fn reduce_slf(
             }
         }
     }
+
+    // Second, optimize scalar reductions.
+    for (fork, join) in fork_join_map {
+        let reduces: Vec<_> = editor
+            .get_users(*join)
+            .filter(|id| editor.func().nodes[id.idx()].is_reduce())
+            .collect();
+        for reduce in reduces {
+            let Some((read, write)) = read_write_helper(reduce, editor) else {
+                continue;
+            };
+
+            // Check condition 4 of scalar reduction optimization.
+            let indices = editor.func().nodes[read.idx()].try_read().unwrap().1;
+            if node_indices(indices).all(|id| !nodes_in_fork_joins[fork].contains(&id)) {
+                let indices = indices.to_vec().into_boxed_slice();
+                let init = editor.func().nodes[reduce.idx()].try_reduce().unwrap().1;
+                let data = editor.func().nodes[write.idx()].try_write().unwrap().1;
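+                // Read the element once from the `init` collection, reduce
+                // over that value directly, and write the result back into the
+                // collection in place of the original reduce.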
+                let init_read = Node::Read {
+                    collect: init,
+                    indices: indices.clone(),
+                };
+                editor.edit(|mut edit| {
+                    let init_read = edit.add_node(init_read);
+                    let new_reduce = Node::Reduce {
+                        control: *join,
+                        init: init_read,
+                        reduct: data,
+                    };
+                    let new_reduce = edit.add_node(new_reduce);
+                    let post_write = Node::Write {
+                        collect: init,
+                        data: new_reduce,
+                        indices,
+                    };
+                    let post_write = edit.add_node(post_write);
+
+                    edit = edit.replace_all_uses(read, new_reduce)?;
+                    edit = edit.replace_all_uses(reduce, post_write)?;
+
+                    edit = edit.delete_node(read)?;
+                    edit = edit.delete_node(reduce)?;
+                    edit = edit.delete_node(write)?;
+
+                    Ok(edit)
+                });
+            }
+        }
+    }
 }
diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs
index 1482a7f0..c5e0d934 100644
--- a/hercules_opt/src/utils.rs
+++ b/hercules_opt/src/utils.rs
@@ -478,6 +478,23 @@ pub fn materialize_simple_einsum_expr(
     }
 }
 
+/*
+ * Get the node IDs referred to in position indices in an indices set.
+ */
+pub fn node_indices(indices: &[Index]) -> impl Iterator<Item = NodeID> + '_ {
+    indices
+        .iter()
+        .filter_map(|index| {
+            if let Index::Position(indices) = index {
+                Some(indices)
+            } else {
+                None
+            }
+        })
+        .flat_map(|pos| pos.iter())
+        .map(|id| *id)
+}
+
 /*
  * Checks if a set of indices is fully parallel over a set of forks - that is,
  * every thread ID from every fork appears at least once in positions in the
@@ -491,21 +508,9 @@ pub fn indices_parallel_over_forks<I>(
 where
     I: Iterator<Item = NodeID>,
 {
-    // Get the node ids in the indices set.
-    let positions = indices
-        .iter()
-        .filter_map(|index| {
-            if let Index::Position(indices) = index {
-                Some(indices)
-            } else {
-                None
-            }
-        })
-        .flat_map(|pos| pos.iter());
-
-    // Get the forks corresponding to uses of bare thread ids.
+    // Get the forks corresponding to position uses of bare thread ids.
     let nodes = &editor.func().nodes;
-    let fork_thread_id_pairs = positions.filter_map(|id| {
+    let fork_thread_id_pairs = node_indices(indices).filter_map(|id| {
         if let Node::ThreadID { control, dimension } = nodes[id.idx()] {
             Some((control, dimension))
         } else {
-- 
GitLab


From cc5f607f047ffb049e3c1eef36e0f75bd03e0e4d Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 11 Feb 2025 16:43:11 -0600
Subject: [PATCH 6/7] add extra case to parallel reduce slf

---
 hercules_opt/src/slf.rs | 44 +++++++++++++++++++++++++++--------------
 1 file changed, 29 insertions(+), 15 deletions(-)

diff --git a/hercules_opt/src/slf.rs b/hercules_opt/src/slf.rs
index 3133972f..aeacec76 100644
--- a/hercules_opt/src/slf.rs
+++ b/hercules_opt/src/slf.rs
@@ -255,6 +255,10 @@ pub fn array_slf(
  * replaces the use of the reduce in the read with the `init` input of the
  * reduce, nominally so that ParallelReduce can get inferred.
  *
+ * If a reduction has the ParallelReduce schedule on it, then any direct read
+ * of it can be straightforwardly optimized to refer to the `init` input
+ * rather than to the reduce itself.
+ *
  * The scalar variant looks for reductions on collections with the following
  * form:
  *
@@ -305,22 +309,32 @@ pub fn reduce_slf(
             .filter(|id| editor.func().nodes[id.idx()].is_reduce())
             .collect();
         for reduce in reduces {
-            let Some((read, _)) = read_write_helper(reduce, editor) else {
-                continue;
-            };
-
-            // Check condition 4 of parallel reduction optimization.
-            let indices = editor.func().nodes[read.idx()].try_read().unwrap().1;
-            if indices_parallel_over_forks(editor, indices, once(*fork)) {
+            if let Some((read, _)) = read_write_helper(reduce, editor) {
+                // Check condition 4 of parallel reduction optimization.
+                let indices = editor.func().nodes[read.idx()].try_read().unwrap().1;
+                if indices_parallel_over_forks(editor, indices, once(*fork)) {
+                    let init = editor.func().nodes[reduce.idx()].try_reduce().unwrap().1;
+                    let new_read = Node::Read {
+                        collect: init,
+                        indices: indices.to_vec().into_boxed_slice(),
+                    };
+                    editor.edit(|mut edit| {
+                        let new_read = edit.add_node(new_read);
+                        edit = edit.replace_all_uses(read, new_read)?;
+                        edit.delete_node(read)
+                    });
+                }
+            } else if editor.func().schedules[reduce.idx()].contains(&Schedule::ParallelReduce) {
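+                // The reduce is already known to be parallel, so reads of it
+                // inside the fork-join can refer to the `init` input instead.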
                 let init = editor.func().nodes[reduce.idx()].try_reduce().unwrap().1;
-                let new_read = Node::Read {
-                    collect: init,
-                    indices: indices.to_vec().into_boxed_slice(),
-                };
-                editor.edit(|mut edit| {
-                    let new_read = edit.add_node(new_read);
-                    edit = edit.replace_all_uses(read, new_read)?;
-                    edit.delete_node(read)
+                let read_users: Vec<_> = editor
+                    .get_users(reduce)
+                    .filter(|id| {
+                        editor.func().nodes[id.idx()].is_read()
+                            && nodes_in_fork_joins[fork].contains(&id)
+                    })
+                    .collect();
+                editor.edit(|edit| {
+                    edit.replace_all_uses_where(reduce, init, |id| read_users.contains(id))
                 });
             }
         }
-- 
GitLab


From ea6e12b43761fde64d31aff66c8ca22aa0a45dfc Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 11 Feb 2025 16:46:55 -0600
Subject: [PATCH 7/7] transform juno matmul sample to hercules matmul sample

---
 hercules_opt/src/slf.rs         |  1 +
 juno_samples/matmul/src/gpu.sch | 16 +++++++++++++++-
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/hercules_opt/src/slf.rs b/hercules_opt/src/slf.rs
index aeacec76..1d70676a 100644
--- a/hercules_opt/src/slf.rs
+++ b/hercules_opt/src/slf.rs
@@ -112,6 +112,7 @@ pub fn slf(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, typing:
 
                 value
             }
+            Node::Constant { id: _ } | Node::Undef { ty: _ } => SLFLattice::top(),
             _ => SLFLattice::bottom(),
         }
     });
diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch
index 3d3f919c..e6eb3641 100644
--- a/juno_samples/matmul/src/gpu.sch
+++ b/juno_samples/matmul/src/gpu.sch
@@ -12,7 +12,21 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
-infer-schedules(*);
+fixpoint panic after 20 {
+  forkify(out.matmul);
+  fork-guard-elim(out.matmul);
+}
+gvn(out.matmul);
+phi-elim(out.matmul);
+dce(out.matmul);
+
+fixpoint panic after 20 {
+  reduce-slf(out.matmul);
+  slf(out.matmul);
+  infer-schedules(out.matmul);
+}
+fork-coalesce(out.matmul);
+dce(out.matmul);
 
 gcm(*);
 float-collections(*);
-- 
GitLab