From d3763dc2de4d0d9d1f12059391374544a49a4e95 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Wed, 29 Jan 2025 17:09:03 -0600
Subject: [PATCH 1/8] Skeleton

---
 hercules_ir/src/einsum.rs | 86 +++++++++++++++++++++++++++++++++++++++
 hercules_ir/src/lib.rs    |  2 +
 2 files changed, 88 insertions(+)
 create mode 100644 hercules_ir/src/einsum.rs

diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs
new file mode 100644
index 00000000..7c7d1b58
--- /dev/null
+++ b/hercules_ir/src/einsum.rs
@@ -0,0 +1,86 @@
+use std::collections::HashMap;
+
+use crate::*;
+
+/*
+ * Math expressions are stored as a simple tree.
+ */
+#[derive(Debug, Clone)]
+pub struct ForkDimension(pub usize, pub DynamicConstantID);
+
+#[derive(Debug, Clone)]
+pub enum MathExpr {
+    // Zero constant of a specific type.
+    Zero(TypeID),
+    // One constant of a specific type.
+    One(TypeID),
+    // Opaque value corresponding to a particular node in the original IR.
+    OpaqueNode(NodeID),
+
+    // Thread ID from the fork corresponding to the reduce being expressed.
+    // Thread IDs from outside forks are considered OpaqueNodes.
+    ThreadID(ForkDimension),
+
+    // Sum reduction over a dimension of a fork.
+    SumReduction(MathID, ForkDimension),
+    // Comprehend a scalar expression into an array over fork dimensions.
+    Comprehension(MathID, Box<[ForkDimension]>),
+
+    // Read from an array.
+    Read(MathID, Box<[MathID]>),
+
+    // Math ops.
+    Add(MathID, MathID),
+    Sub(MathID, MathID),
+    Mul(MathID, MathID),
+    Div(MathID, MathID),
+}
+
+pub type MathEnv = Vec<MathExpr>;
+
+define_id_type!(MathID);
+
+/*
+ * Top level function to run "einsum" analysis on fork-joins. This is a terrible
+ * name for this analysis, since it's actually more general than identifying
+ * einsums, but einsum syntax has a passing resemblance to the idea of this
+ * analysis and it's what we keep calling it, so we're doomed to this bad name.
+ * The idea of this analysis is to convert some fork-joins into pure math
+ * expressions that we can rewrite into intrinsic functions for higher level
+ * operators like matmul. Specifically, this function returns a map from each
+ * reduce node to a math expression.
+ */
+pub fn einsum(
+    function: &Function,
+    types: &Vec<Type>,
+    def_use: &ImmutableDefUseMap,
+    typing: &Vec<TypeID>,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
+) -> (MathEnv, HashMap<NodeID, MathID>) {
+    let mut env = vec![];
+    let mut result = HashMap::new();
+
+    // Iterate fork-joins bottom-up, since we need to compute the math
+    // expressions of inner reduces before getting to outer reduces. Since fork-
+    // joins are strictly nested, we can literally iterate entries of
+    // `fork_join_nest` in decreasing order of nesting size to accomplish this.
+    let mut nests: Vec<_> = fork_join_nest.into_iter().collect();
+    nests.sort_by(|a, b| b.1.len().cmp(&a.1.len()));
+    for fork in nests.into_iter().map(|(id, _)| *id) {
+        // Check that the fork-join has no control flow.
+        let join = fork_join_map[&fork];
+        if function.nodes[join.idx()].try_join().unwrap() != fork {
+            continue;
+        }
+
+        // Compute a math expression for each reduce node in the fork-join.
+        let reduces = def_use
+            .get_users(join)
+            .into_iter()
+            .filter(|id| function.nodes[id.idx()].is_reduce());
+        for reduce in reduces {}
+    }
+
+    (env, result)
+}
diff --git a/hercules_ir/src/lib.rs b/hercules_ir/src/lib.rs
index 85dc277f..adfe030c 100644
--- a/hercules_ir/src/lib.rs
+++ b/hercules_ir/src/lib.rs
@@ -14,6 +14,7 @@ pub mod def_use;
 pub mod device;
 pub mod dom;
 pub mod dot;
+pub mod einsum;
 pub mod fork_join_analysis;
 pub mod ir;
 pub mod loops;
@@ -30,6 +31,7 @@ pub use crate::def_use::*;
 pub use crate::device::*;
 pub use crate::dom::*;
 pub use crate::dot::*;
+pub use crate::einsum::*;
 pub use crate::fork_join_analysis::*;
 pub use crate::ir::*;
 pub use crate::loops::*;
-- 
GitLab


From 869dfa14ba1db4b7f506fcf76430eaa2758d2d41 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Wed, 29 Jan 2025 21:48:12 -0600
Subject: [PATCH 2/8] Setup of accepted pattersn

---
 hercules_ir/src/einsum.rs           | 118 ++++++++++++++++++++++++++--
 hercules_samples/matmul/src/cpu.sch |   3 +
 2 files changed, 114 insertions(+), 7 deletions(-)

diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs
index 7c7d1b58..58bb73fb 100644
--- a/hercules_ir/src/einsum.rs
+++ b/hercules_ir/src/einsum.rs
@@ -1,14 +1,14 @@
-use std::collections::HashMap;
+use std::collections::{HashMap, HashSet};
 
 use crate::*;
 
 /*
  * Math expressions are stored as a simple tree.
  */
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub struct ForkDimension(pub usize, pub DynamicConstantID);
 
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub enum MathExpr {
     // Zero constant of a specific type.
     Zero(TypeID),
@@ -22,7 +22,7 @@ pub enum MathExpr {
     ThreadID(ForkDimension),
 
     // Sum reduction over a dimension of a fork.
-    SumReduction(MathID, ForkDimension),
+    SumReduction(MathID, Box<[ForkDimension]>),
     // Comprehend a scalar expression into an array over fork dimensions.
     Comprehension(MathID, Box<[ForkDimension]>),
 
@@ -34,6 +34,7 @@ pub enum MathExpr {
     Sub(MathID, MathID),
     Mul(MathID, MathID),
     Div(MathID, MathID),
+    Rem(MathID, MathID),
 }
 
 pub type MathEnv = Vec<MathExpr>;
@@ -53,12 +54,15 @@ define_id_type!(MathID);
 pub fn einsum(
     function: &Function,
     types: &Vec<Type>,
+    constants: &Vec<Constant>,
     def_use: &ImmutableDefUseMap,
     typing: &Vec<TypeID>,
     fork_join_map: &HashMap<NodeID, NodeID>,
     fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
+    data_nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
 ) -> (MathEnv, HashMap<NodeID, MathID>) {
     let mut env = vec![];
+    let mut rev_env = HashMap::new();
     let mut result = HashMap::new();
 
     // Iterate fork-joins bottom-up, since we need to compute the math
@@ -73,14 +77,114 @@ pub fn einsum(
         if function.nodes[join.idx()].try_join().unwrap() != fork {
             continue;
         }
+        let Node::Fork {
+            control: _,
+            ref factors,
+        } = function.nodes[fork.idx()]
+        else {
+            panic!()
+        };
 
-        // Compute a math expression for each reduce node in the fork-join.
+        // Compute a math expression for each reduce node in the fork-join with
+        // appropriate schedules.
         let reduces = def_use
             .get_users(join)
             .into_iter()
-            .filter(|id| function.nodes[id.idx()].is_reduce());
-        for reduce in reduces {}
+            .filter_map(|id| function.nodes[id.idx()].try_reduce().map(|v| (*id, v)));
+        for (reduce, (_, init, reduct)) in reduces {
+            // The reduce defines an array where each fork dimension corresponds
+            // to one array dimension.
+            if function.schedules[reduce.idx()].contains(&Schedule::ParallelReduce)
+                && let Node::Write {
+                    collect,
+                    data,
+                    ref indices,
+                } = function.nodes[reduct.idx()]
+                && collect == reduce
+                && indices.len() == 1
+                && let Some(indices) = indices[0].try_position()
+                && let Some(dimension_bounds) = indices
+                    .into_iter()
+                    .map(|id| {
+                        function.nodes[id.idx()]
+                            .try_thread_id()
+                            .filter(|(tid_fork, _)| *tid_fork == fork)
+                            .map(|(_, dim)| dim)
+                    })
+                    .collect::<Option<Vec<usize>>>()
+            {
+                let data_expr = compute_einsum_expression(data);
+                let reduce_expr = MathExpr::Comprehension(
+                    data_expr,
+                    dimension_bounds
+                        .into_iter()
+                        .map(|dim| ForkDimension(dim, factors[dim]))
+                        .collect(),
+                );
+                result.insert(
+                    reduce,
+                    intern_math_expr(reduce_expr, &mut env, &mut rev_env),
+                );
+            }
+            // The reduce defines a sum reduction over a set of fork dimensions.
+            else if function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
+                && let Node::Binary {
+                    op: BinaryOperator::Add,
+                    left,
+                    right,
+                } = function.nodes[reduct.idx()]
+                && (left == reduce || right == reduce)
+            {
+                let data_expr =
+                    compute_einsum_expression(if left == reduce { right } else { left });
+                let reduce_expr = MathExpr::SumReduction(
+                    data_expr,
+                    factors
+                        .into_iter()
+                        .enumerate()
+                        .map(|(dim, factor)| ForkDimension(dim, *factor))
+                        .collect(),
+                );
+                if function.nodes[init.idx()]
+                    .try_constant()
+                    .map(|id| constants[id.idx()].is_zero())
+                    .unwrap_or(false)
+                {
+                    // If the initializer is zero, the expression is just the
+                    // sum reduction.
+                    result.insert(
+                        reduce,
+                        intern_math_expr(reduce_expr, &mut env, &mut rev_env),
+                    );
+                } else {
+                    // If the initializer is not zero, we need to add it.
+                    let reduce_expr_id = intern_math_expr(reduce_expr, &mut env, &mut rev_env);
+                    let init_expr = MathExpr::OpaqueNode(init);
+                    let init_expr_id = intern_math_expr(init_expr, &mut env, &mut rev_env);
+                    let add_expr = MathExpr::Add(init_expr_id, reduce_expr_id);
+                    result.insert(reduce, intern_math_expr(add_expr, &mut env, &mut rev_env));
+                }
+            }
+        }
     }
 
     (env, result)
 }
+
+fn compute_einsum_expression(id: NodeID) -> MathID {
+    todo!()
+}
+
+fn intern_math_expr(
+    expr: MathExpr,
+    env: &mut MathEnv,
+    rev_env: &mut HashMap<MathExpr, MathID>,
+) -> MathID {
+    if let Some(id) = rev_env.get(&expr) {
+        *id
+    } else {
+        let id = MathID::new(env.len());
+        rev_env.insert(expr, id);
+        id
+    }
+}
diff --git a/hercules_samples/matmul/src/cpu.sch b/hercules_samples/matmul/src/cpu.sch
index f7891b9b..b58ac566 100644
--- a/hercules_samples/matmul/src/cpu.sch
+++ b/hercules_samples/matmul/src/cpu.sch
@@ -6,6 +6,9 @@ auto-outline(*);
 
 ip-sroa(*);
 sroa(*);
+dce(*);
+infer-schedules(*);
+xdot[true](*);
 fork-split(*);
 unforkify(*);
 dce(*);
-- 
GitLab


From 9748467ed31aa3c5af849f4bec6a0959bb8489a6 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Thu, 30 Jan 2025 12:07:56 -0600
Subject: [PATCH 3/8] basic einsum analysis

---
 hercules_ir/src/einsum.rs | 208 ++++++++++++++++++++++++++++++--------
 hercules_ir/src/lib.rs    |   1 +
 2 files changed, 169 insertions(+), 40 deletions(-)

diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs
index 58bb73fb..ea1f5775 100644
--- a/hercules_ir/src/einsum.rs
+++ b/hercules_ir/src/einsum.rs
@@ -1,11 +1,12 @@
 use std::collections::{HashMap, HashSet};
+use std::iter::zip;
 
 use crate::*;
 
 /*
  * Math expressions are stored as a simple tree.
  */
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub struct ForkDimension(pub usize, pub DynamicConstantID);
 
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
@@ -84,6 +85,16 @@ pub fn einsum(
         else {
             panic!()
         };
+        let mut ctx = EinsumContext {
+            function,
+            typing,
+            constants,
+            data_nodes_in_fork_joins,
+            fork,
+            factors,
+            env: &mut env,
+            rev_env: &mut rev_env,
+        };
 
         // Compute a math expression for each reduce node in the fork-join with
         // appropriate schedules.
@@ -112,8 +123,11 @@ pub fn einsum(
                             .map(|(_, dim)| dim)
                     })
                     .collect::<Option<Vec<usize>>>()
+                && let Type::Array(_, ref array_bounds) = types[typing[reduce.idx()].idx()]
+                && zip(array_bounds.into_iter(), dimension_bounds.iter())
+                    .all(|(array, fork)| *array == factors[*fork])
             {
-                let data_expr = compute_einsum_expression(data);
+                let data_expr = ctx.compute_math_expr(data);
                 let reduce_expr = MathExpr::Comprehension(
                     data_expr,
                     dimension_bounds
@@ -121,10 +135,11 @@ pub fn einsum(
                         .map(|dim| ForkDimension(dim, factors[dim]))
                         .collect(),
                 );
-                result.insert(
-                    reduce,
-                    intern_math_expr(reduce_expr, &mut env, &mut rev_env),
-                );
+                // We don't need to consider the initializer, since the writes
+                // cover the whole array.
+                let total_id = ctx.intern_math_expr(reduce_expr);
+                ctx.debug_print_expr(total_id);
+                result.insert(reduce, total_id);
             }
             // The reduce defines a sum reduction over a set of fork dimensions.
             else if function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
@@ -135,8 +150,7 @@ pub fn einsum(
                 } = function.nodes[reduct.idx()]
                 && (left == reduce || right == reduce)
             {
-                let data_expr =
-                    compute_einsum_expression(if left == reduce { right } else { left });
+                let data_expr = ctx.compute_math_expr(if left == reduce { right } else { left });
                 let reduce_expr = MathExpr::SumReduction(
                     data_expr,
                     factors
@@ -145,25 +159,13 @@ pub fn einsum(
                         .map(|(dim, factor)| ForkDimension(dim, *factor))
                         .collect(),
                 );
-                if function.nodes[init.idx()]
-                    .try_constant()
-                    .map(|id| constants[id.idx()].is_zero())
-                    .unwrap_or(false)
-                {
-                    // If the initializer is zero, the expression is just the
-                    // sum reduction.
-                    result.insert(
-                        reduce,
-                        intern_math_expr(reduce_expr, &mut env, &mut rev_env),
-                    );
-                } else {
-                    // If the initializer is not zero, we need to add it.
-                    let reduce_expr_id = intern_math_expr(reduce_expr, &mut env, &mut rev_env);
-                    let init_expr = MathExpr::OpaqueNode(init);
-                    let init_expr_id = intern_math_expr(init_expr, &mut env, &mut rev_env);
-                    let add_expr = MathExpr::Add(init_expr_id, reduce_expr_id);
-                    result.insert(reduce, intern_math_expr(add_expr, &mut env, &mut rev_env));
-                }
+                // Add the initializer.
+                let reduce_expr_id = ctx.intern_math_expr(reduce_expr);
+                let init_expr_id = ctx.compute_math_expr(init);
+                let add_expr = MathExpr::Add(init_expr_id, reduce_expr_id);
+                let total_id = ctx.intern_math_expr(add_expr);
+                ctx.debug_print_expr(total_id);
+                result.insert(reduce, total_id);
             }
         }
     }
@@ -171,20 +173,146 @@ pub fn einsum(
     (env, result)
 }
 
-fn compute_einsum_expression(id: NodeID) -> MathID {
-    todo!()
+struct EinsumContext<'a> {
+    function: &'a Function,
+    typing: &'a Vec<TypeID>,
+    constants: &'a Vec<Constant>,
+    data_nodes_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>,
+    fork: NodeID,
+    factors: &'a [DynamicConstantID],
+    env: &'a mut MathEnv,
+    rev_env: &'a mut HashMap<MathExpr, MathID>,
+}
+
+impl<'a> EinsumContext<'a> {
+    fn compute_math_expr(&mut self, id: NodeID) -> MathID {
+        let math_expr = match self.function.nodes[id.idx()] {
+            Node::Constant { id: cons_id } if self.constants[cons_id.idx()].is_zero() => {
+                MathExpr::Zero(self.typing[id.idx()])
+            }
+            Node::Constant { id: cons_id } if self.constants[cons_id.idx()].is_one() => {
+                MathExpr::One(self.typing[id.idx()])
+            }
+            Node::ThreadID { control, dimension } if control == self.fork => {
+                MathExpr::ThreadID(ForkDimension(dimension, self.factors[dimension]))
+            }
+            Node::Binary { op, left, right } if representable(op) => {
+                let left = self.compute_math_expr(left);
+                let right = self.compute_math_expr(right);
+                match op {
+                    BinaryOperator::Add => MathExpr::Add(left, right),
+                    BinaryOperator::Sub => MathExpr::Sub(left, right),
+                    BinaryOperator::Mul => MathExpr::Mul(left, right),
+                    BinaryOperator::Div => MathExpr::Div(left, right),
+                    BinaryOperator::Rem => MathExpr::Rem(left, right),
+                    _ => unreachable!(),
+                }
+            }
+            Node::Read {
+                collect,
+                ref indices,
+            } if indices.len() == 1
+                && let Some(indices) = indices[0].try_position() =>
+            {
+                let collect = self.compute_math_expr(collect);
+                let indices = indices
+                    .into_iter()
+                    .map(|id| self.compute_math_expr(*id))
+                    .collect();
+                MathExpr::Read(collect, indices)
+            }
+            _ => MathExpr::OpaqueNode(id),
+        };
+        self.intern_math_expr(math_expr)
+    }
+
+    fn intern_math_expr(&mut self, expr: MathExpr) -> MathID {
+        if let Some(id) = self.rev_env.get(&expr) {
+            *id
+        } else {
+            let id = MathID::new(self.env.len());
+            self.rev_env.insert(expr, id);
+            id
+        }
+    }
+
+    fn debug_print_expr(&self, id: MathID) {
+        match self.env[id.idx()] {
+            MathExpr::Zero(_) => println!("0"),
+            MathExpr::One(_) => println!("1"),
+            MathExpr::OpaqueNode(id) => println!("{:?}", id),
+            MathExpr::ThreadID(dim) => println!("#{}", dim.0),
+            MathExpr::SumReduction(id, ref dims) => {
+                println!("Sum (");
+                for dim in dims {
+                    println!("#{}/{:?},", dim.0, dim.1);
+                }
+                println!(") ");
+                self.debug_print_expr(id);
+            }
+            MathExpr::Comprehension(id, ref dims) => {
+                println!("[");
+                for dim in dims {
+                    println!("#{}/{:?},", dim.0, dim.1);
+                }
+                println!("] ");
+                self.debug_print_expr(id);
+            }
+            MathExpr::Read(id, ref pos) => {
+                println!("read(");
+                self.debug_print_expr(id);
+                for pos in pos {
+                    println!(", ");
+                    self.debug_print_expr(*pos);
+                }
+                println!(")");
+            }
+            MathExpr::Add(left, right) => {
+                println!("+(");
+                self.debug_print_expr(left);
+                println!(", ");
+                self.debug_print_expr(right);
+                println!(")");
+            }
+            MathExpr::Sub(left, right) => {
+                println!("-(");
+                self.debug_print_expr(left);
+                println!(", ");
+                self.debug_print_expr(right);
+                println!(")");
+            }
+            MathExpr::Mul(left, right) => {
+                println!("*(");
+                self.debug_print_expr(left);
+                println!(", ");
+                self.debug_print_expr(right);
+                println!(")");
+            }
+            MathExpr::Div(left, right) => {
+                println!("/(");
+                self.debug_print_expr(left);
+                println!(", ");
+                self.debug_print_expr(right);
+                println!(")");
+            }
+            MathExpr::Rem(left, right) => {
+                println!("%(");
+                self.debug_print_expr(left);
+                println!(", ");
+                self.debug_print_expr(right);
+                println!(")");
+            }
+        }
+    }
 }
 
-fn intern_math_expr(
-    expr: MathExpr,
-    env: &mut MathEnv,
-    rev_env: &mut HashMap<MathExpr, MathID>,
-) -> MathID {
-    if let Some(id) = rev_env.get(&expr) {
-        *id
-    } else {
-        let id = MathID::new(env.len());
-        rev_env.insert(expr, id);
-        id
+fn representable(op: BinaryOperator) -> bool {
+    match op {
+        BinaryOperator::Add
+        | BinaryOperator::Sub
+        | BinaryOperator::Mul
+        | BinaryOperator::Div
+        | BinaryOperator::Rem => true,
+        _ => false,
     }
 }
diff --git a/hercules_ir/src/lib.rs b/hercules_ir/src/lib.rs
index adfe030c..185a2888 100644
--- a/hercules_ir/src/lib.rs
+++ b/hercules_ir/src/lib.rs
@@ -2,6 +2,7 @@
     coroutines,
     coroutine_trait,
     let_chains,
+    if_let_guard,
     stmt_expr_attributes,
     iter_intersperse
 )]
-- 
GitLab


From 072cbe7a6a583edced53b088333873e6545b44eb Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Thu, 30 Jan 2025 15:18:39 -0600
Subject: [PATCH 4/8] Einsum works w/ one reduce

---
 hercules_ir/src/einsum.rs | 70 ++++++++++++++++++++-------------------
 juno_scheduler/src/pm.rs  | 51 ++++++++++++++++++++++++++--
 2 files changed, 84 insertions(+), 37 deletions(-)

diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs
index ea1f5775..90b2f07d 100644
--- a/hercules_ir/src/einsum.rs
+++ b/hercules_ir/src/einsum.rs
@@ -70,14 +70,12 @@ pub fn einsum(
     // expressions of inner reduces before getting to outer reduces. Since fork-
     // joins are strictly nested, we can literally iterate entries of
     // `fork_join_nest` in decreasing order of nesting size to accomplish this.
-    let mut nests: Vec<_> = fork_join_nest.into_iter().collect();
+    let mut nests: Vec<_> = fork_join_nest
+        .into_iter()
+        .filter(|(id, _)| function.nodes[id.idx()].is_fork())
+        .collect();
     nests.sort_by(|a, b| b.1.len().cmp(&a.1.len()));
     for fork in nests.into_iter().map(|(id, _)| *id) {
-        // Check that the fork-join has no control flow.
-        let join = fork_join_map[&fork];
-        if function.nodes[join.idx()].try_join().unwrap() != fork {
-            continue;
-        }
         let Node::Fork {
             control: _,
             ref factors,
@@ -85,6 +83,7 @@ pub fn einsum(
         else {
             panic!()
         };
+        let join = fork_join_map[&fork];
         let mut ctx = EinsumContext {
             function,
             typing,
@@ -139,6 +138,7 @@ pub fn einsum(
                 // cover the whole array.
                 let total_id = ctx.intern_math_expr(reduce_expr);
                 ctx.debug_print_expr(total_id);
+                println!("");
                 result.insert(reduce, total_id);
             }
             // The reduce defines a sum reduction over a set of fork dimensions.
@@ -165,6 +165,7 @@ pub fn einsum(
                 let add_expr = MathExpr::Add(init_expr_id, reduce_expr_id);
                 let total_id = ctx.intern_math_expr(add_expr);
                 ctx.debug_print_expr(total_id);
+                println!("");
                 result.insert(reduce, total_id);
             }
         }
@@ -231,6 +232,7 @@ impl<'a> EinsumContext<'a> {
             *id
         } else {
             let id = MathID::new(self.env.len());
+            self.env.push(expr.clone());
             self.rev_env.insert(expr, id);
             id
         }
@@ -238,69 +240,69 @@ impl<'a> EinsumContext<'a> {
 
     fn debug_print_expr(&self, id: MathID) {
         match self.env[id.idx()] {
-            MathExpr::Zero(_) => println!("0"),
-            MathExpr::One(_) => println!("1"),
-            MathExpr::OpaqueNode(id) => println!("{:?}", id),
-            MathExpr::ThreadID(dim) => println!("#{}", dim.0),
+            MathExpr::Zero(_) => print!("0"),
+            MathExpr::One(_) => print!("1"),
+            MathExpr::OpaqueNode(id) => print!("{:?}", id),
+            MathExpr::ThreadID(dim) => print!("#{}", dim.0),
             MathExpr::SumReduction(id, ref dims) => {
-                println!("Sum (");
+                print!("Sum (");
                 for dim in dims {
-                    println!("#{}/{:?},", dim.0, dim.1);
+                    print!("#{}/{:?},", dim.0, dim.1);
                 }
-                println!(") ");
+                print!(") ");
                 self.debug_print_expr(id);
             }
             MathExpr::Comprehension(id, ref dims) => {
-                println!("[");
+                print!("[");
                 for dim in dims {
-                    println!("#{}/{:?},", dim.0, dim.1);
+                    print!("#{}/{:?},", dim.0, dim.1);
                 }
-                println!("] ");
+                print!("] ");
                 self.debug_print_expr(id);
             }
             MathExpr::Read(id, ref pos) => {
-                println!("read(");
+                print!("read(");
                 self.debug_print_expr(id);
                 for pos in pos {
-                    println!(", ");
+                    print!(", ");
                     self.debug_print_expr(*pos);
                 }
-                println!(")");
+                print!(")");
             }
             MathExpr::Add(left, right) => {
-                println!("+(");
+                print!("+(");
                 self.debug_print_expr(left);
-                println!(", ");
+                print!(", ");
                 self.debug_print_expr(right);
-                println!(")");
+                print!(")");
             }
             MathExpr::Sub(left, right) => {
-                println!("-(");
+                print!("-(");
                 self.debug_print_expr(left);
-                println!(", ");
+                print!(", ");
                 self.debug_print_expr(right);
-                println!(")");
+                print!(")");
             }
             MathExpr::Mul(left, right) => {
-                println!("*(");
+                print!("*(");
                 self.debug_print_expr(left);
-                println!(", ");
+                print!(", ");
                 self.debug_print_expr(right);
-                println!(")");
+                print!(")");
             }
             MathExpr::Div(left, right) => {
-                println!("/(");
+                print!("/(");
                 self.debug_print_expr(left);
-                println!(", ");
+                print!(", ");
                 self.debug_print_expr(right);
-                println!(")");
+                print!(")");
             }
             MathExpr::Rem(left, right) => {
-                println!("%(");
+                print!("%(");
                 self.debug_print_expr(left);
-                println!(", ");
+                print!(", ");
                 self.debug_print_expr(right);
-                println!(")");
+                print!(")");
             }
         }
     }
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 9888f3d2..910d3226 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -181,6 +181,7 @@ pub struct PassManager {
     pub loops: Option<Vec<LoopTree>>,
     pub reduce_cycles: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
     pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
+    pub reduce_einsums: Option<Vec<(MathEnv, HashMap<NodeID, MathID>)>>,
     pub collection_objects: Option<CollectionObjects>,
     pub callgraph: Option<CallGraph>,
     pub devices: Option<Vec<Device>>,
@@ -216,6 +217,7 @@ impl PassManager {
             loops: None,
             reduce_cycles: None,
             data_nodes_in_fork_joins: None,
+            reduce_einsums: None,
             collection_objects: None,
             callgraph: None,
             devices: None,
@@ -392,6 +394,48 @@ impl PassManager {
         }
     }
 
+    pub fn make_reduce_einsums(&mut self) {
+        if self.reduce_einsums.is_none() {
+            self.make_def_uses();
+            self.make_typing();
+            self.make_fork_join_maps();
+            self.make_fork_join_nests();
+            self.make_data_nodes_in_fork_joins();
+            let def_uses = self.def_uses.as_ref().unwrap().iter();
+            let typing = self.typing.as_ref().unwrap().iter();
+            let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter();
+            let fork_join_nests = self.fork_join_nests.as_ref().unwrap().iter();
+            let data_nodes_in_fork_joins = self.data_nodes_in_fork_joins.as_ref().unwrap().iter();
+            self.reduce_einsums = Some(
+                self.functions
+                    .iter()
+                    .zip(def_uses)
+                    .zip(typing)
+                    .zip(fork_join_maps)
+                    .zip(fork_join_nests)
+                    .zip(data_nodes_in_fork_joins)
+                    .map(
+                        |(
+                            ((((function, def_use), typing), fork_join_map), fork_join_nest),
+                            data_nodes_in_fork_joins,
+                        )| {
+                            einsum(
+                                function,
+                                &self.types.borrow(),
+                                &self.constants.borrow(),
+                                def_use,
+                                typing,
+                                fork_join_map,
+                                fork_join_nest,
+                                data_nodes_in_fork_joins,
+                            )
+                        },
+                    )
+                    .collect(),
+            );
+        }
+    }
+
     pub fn make_collection_objects(&mut self) {
         if self.collection_objects.is_none() {
             self.make_reverse_postorders();
@@ -463,6 +507,7 @@ impl PassManager {
         self.loops = None;
         self.reduce_cycles = None;
         self.data_nodes_in_fork_joins = None;
+        self.reduce_einsums = None;
         self.collection_objects = None;
         self.callgraph = None;
         self.devices = None;
@@ -643,7 +688,7 @@ impl PassManager {
 }
 
 pub fn schedule_codegen(
-    mut module: Module,
+    module: Module,
     schedule: ScheduleStmt,
     mut stringtab: StringTable,
     mut env: Env<usize, Value>,
@@ -913,8 +958,8 @@ fn add_schedule(pm: &mut PassManager, sched: Schedule, label_ids: Vec<LabelInfo>
             .labels
             .iter()
             .enumerate()
-            .filter(|(i, ls)| ls.contains(&label))
-            .map(|(i, ls)| i)
+            .filter(|(_, ls)| ls.contains(&label))
+            .map(|(i, _)| i)
             .collect::<Vec<_>>();
         for node in nodes {
             pm.functions[func.idx()].schedules[node].push(sched.clone());
-- 
GitLab


From 84c2ff659747e21301e67394084ab94be39cd775 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Thu, 30 Jan 2025 17:14:32 -0600
Subject: [PATCH 5/8] einsum across reduces of inner fork-joins

---
 hercules_ir/src/einsum.rs | 109 +++++++++++++++++++++++++++++++++++---
 juno_scheduler/src/pm.rs  |   3 +-
 2 files changed, 105 insertions(+), 7 deletions(-)

diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs
index 90b2f07d..bf91dc6d 100644
--- a/hercules_ir/src/einsum.rs
+++ b/hercules_ir/src/einsum.rs
@@ -84,6 +84,19 @@ pub fn einsum(
             panic!()
         };
         let join = fork_join_map[&fork];
+        let thread_ids: Vec<_> = def_use
+            .get_users(fork)
+            .into_iter()
+            .filter_map(|id| {
+                function.nodes[id.idx()]
+                    .try_thread_id()
+                    .map(|(_, dim)| (*id, dim))
+            })
+            .collect();
+        let reduces = def_use
+            .get_users(join)
+            .into_iter()
+            .filter_map(|id| function.nodes[id.idx()].try_reduce().map(|v| (*id, v)));
         let mut ctx = EinsumContext {
             function,
             typing,
@@ -91,16 +104,14 @@ pub fn einsum(
             data_nodes_in_fork_joins,
             fork,
             factors,
+            thread_ids: &thread_ids,
+            so_far: &mut result,
             env: &mut env,
             rev_env: &mut rev_env,
         };
 
         // Compute a math expression for each reduce node in the fork-join with
         // appropriate schedules.
-        let reduces = def_use
-            .get_users(join)
-            .into_iter()
-            .filter_map(|id| function.nodes[id.idx()].try_reduce().map(|v| (*id, v)));
         for (reduce, (_, init, reduct)) in reduces {
             // The reduce defines an array where each fork dimension corresponds
             // to one array dimension.
@@ -139,7 +150,7 @@ pub fn einsum(
                 let total_id = ctx.intern_math_expr(reduce_expr);
                 ctx.debug_print_expr(total_id);
                 println!("");
-                result.insert(reduce, total_id);
+                ctx.result_insert(reduce, total_id);
             }
             // The reduce defines a sum reduction over a set of fork dimensions.
             else if function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
@@ -166,7 +177,7 @@ pub fn einsum(
                 let total_id = ctx.intern_math_expr(add_expr);
                 ctx.debug_print_expr(total_id);
                 println!("");
-                result.insert(reduce, total_id);
+                ctx.result_insert(reduce, total_id);
             }
         }
     }
@@ -181,6 +192,8 @@ struct EinsumContext<'a> {
     data_nodes_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>,
     fork: NodeID,
     factors: &'a [DynamicConstantID],
+    thread_ids: &'a Vec<(NodeID, usize)>,
+    so_far: &'a mut HashMap<NodeID, MathID>,
     env: &'a mut MathEnv,
     rev_env: &'a mut HashMap<MathExpr, MathID>,
 }
@@ -222,6 +235,17 @@ impl<'a> EinsumContext<'a> {
                     .collect();
                 MathExpr::Read(collect, indices)
             }
+            Node::Reduce {
+                control: _,
+                init: _,
+                reduct: _,
+            } => {
+                let reduce = self.so_far[&id];
+                // Substitute opaque uses of thread ID nodes in inner expression
+                // with thread ID math expression, and increment inner-fork
+                // dimensions (alpha renaming).
+                return self.substitute_new_dims(reduce);
+            }
             _ => MathExpr::OpaqueNode(id),
         };
         self.intern_math_expr(math_expr)
@@ -238,6 +262,79 @@ impl<'a> EinsumContext<'a> {
         }
     }
 
+    fn result_insert(&mut self, node: NodeID, math: MathID) {
+        self.so_far.insert(node, math);
+    }
+
+    fn substitute_new_dims(&mut self, id: MathID) -> MathID {
+        match self.env[id.idx()] {
+            MathExpr::OpaqueNode(opaque)
+                if let Some((_, dim)) = self
+                    .thread_ids
+                    .into_iter()
+                    .filter(|(node, _)| *node == opaque)
+                    .next() =>
+            {
+                self.intern_math_expr(MathExpr::ThreadID(ForkDimension(*dim, self.factors[*dim])))
+            }
+            MathExpr::ThreadID(dim) => self.intern_math_expr(MathExpr::ThreadID(ForkDimension(
+                dim.0 + self.factors.len(),
+                dim.1,
+            ))),
+            MathExpr::SumReduction(id, ref dims) => {
+                let dims = dims
+                    .into_iter()
+                    .map(|dim| ForkDimension(dim.0 + self.factors.len(), dim.1))
+                    .collect();
+                let id = self.substitute_new_dims(id);
+                self.intern_math_expr(MathExpr::SumReduction(id, dims))
+            }
+            MathExpr::Comprehension(id, ref dims) => {
+                let dims = dims
+                    .into_iter()
+                    .map(|dim| ForkDimension(dim.0 + self.factors.len(), dim.1))
+                    .collect();
+                let id = self.substitute_new_dims(id);
+                self.intern_math_expr(MathExpr::Comprehension(id, dims))
+            }
+            MathExpr::Read(array, ref indices) => {
+                let indices = indices
+                    .clone()
+                    .iter()
+                    .map(|id| self.substitute_new_dims(*id))
+                    .collect();
+                let array = self.substitute_new_dims(array);
+                self.intern_math_expr(MathExpr::Read(array, indices))
+            }
+            MathExpr::Add(left, right) => {
+                let left = self.substitute_new_dims(left);
+                let right = self.substitute_new_dims(right);
+                self.intern_math_expr(MathExpr::Add(left, right))
+            }
+            MathExpr::Sub(left, right) => {
+                let left = self.substitute_new_dims(left);
+                let right = self.substitute_new_dims(right);
+                self.intern_math_expr(MathExpr::Sub(left, right))
+            }
+            MathExpr::Mul(left, right) => {
+                let left = self.substitute_new_dims(left);
+                let right = self.substitute_new_dims(right);
+                self.intern_math_expr(MathExpr::Mul(left, right))
+            }
+            MathExpr::Div(left, right) => {
+                let left = self.substitute_new_dims(left);
+                let right = self.substitute_new_dims(right);
+                self.intern_math_expr(MathExpr::Div(left, right))
+            }
+            MathExpr::Rem(left, right) => {
+                let left = self.substitute_new_dims(left);
+                let right = self.substitute_new_dims(right);
+                self.intern_math_expr(MathExpr::Rem(left, right))
+            }
+            _ => id,
+        }
+    }
+
     fn debug_print_expr(&self, id: MathID) {
         match self.env[id.idx()] {
             MathExpr::Zero(_) => print!("0"),
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 910d3226..3f43aca8 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1351,6 +1351,7 @@ fn run_pass(
         }
         Pass::ForkSplit => {
             assert!(args.is_empty());
+            pm.make_reduce_einsums();
             loop {
                 let mut inner_changed = false;
                 pm.make_fork_join_maps();
@@ -1395,7 +1396,7 @@ fn run_pass(
                 let Some(mut func) = func else {
                     continue;
                 };
-                // TODO: uses direct return from forkify for now instead of 
+                // TODO: uses direct return from forkify for now instead of
                 // func.modified, see comment on top of `forkify` for why. Fix
                 // this eventually.
                 changed |= forkify(&mut func, control_subgraph, fork_join_map, loop_nest);
-- 
GitLab


From fbc200a853cb35fe3265dfda37f682c28dd9c6cb Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Thu, 30 Jan 2025 17:17:58 -0600
Subject: [PATCH 6/8] remove debug stuff

---
 hercules_ir/src/einsum.rs | 4 ----
 juno_scheduler/src/pm.rs  | 1 -
 2 files changed, 5 deletions(-)

diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs
index bf91dc6d..c2521c22 100644
--- a/hercules_ir/src/einsum.rs
+++ b/hercules_ir/src/einsum.rs
@@ -148,8 +148,6 @@ pub fn einsum(
                 // We don't need to consider the initializer, since the writes
                 // cover the whole array.
                 let total_id = ctx.intern_math_expr(reduce_expr);
-                ctx.debug_print_expr(total_id);
-                println!("");
                 ctx.result_insert(reduce, total_id);
             }
             // The reduce defines a sum reduction over a set of fork dimensions.
@@ -175,8 +173,6 @@ pub fn einsum(
                 let init_expr_id = ctx.compute_math_expr(init);
                 let add_expr = MathExpr::Add(init_expr_id, reduce_expr_id);
                 let total_id = ctx.intern_math_expr(add_expr);
-                ctx.debug_print_expr(total_id);
-                println!("");
                 ctx.result_insert(reduce, total_id);
             }
         }
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 3f43aca8..81c46656 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1351,7 +1351,6 @@ fn run_pass(
         }
         Pass::ForkSplit => {
             assert!(args.is_empty());
-            pm.make_reduce_einsums();
             loop {
                 let mut inner_changed = false;
                 pm.make_fork_join_maps();
-- 
GitLab


From a92c8613365dff8ce7763767088cc9465191777d Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Thu, 30 Jan 2025 17:20:44 -0600
Subject: [PATCH 7/8] whoops

---
 hercules_samples/matmul/src/cpu.sch | 1 -
 1 file changed, 1 deletion(-)

diff --git a/hercules_samples/matmul/src/cpu.sch b/hercules_samples/matmul/src/cpu.sch
index b58ac566..c00a3314 100644
--- a/hercules_samples/matmul/src/cpu.sch
+++ b/hercules_samples/matmul/src/cpu.sch
@@ -8,7 +8,6 @@ ip-sroa(*);
 sroa(*);
 dce(*);
 infer-schedules(*);
-xdot[true](*);
 fork-split(*);
 unforkify(*);
 dce(*);
-- 
GitLab


From 1acde31d7535c76d70c4cc8025d386d03bdccfa0 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Thu, 30 Jan 2025 17:21:25 -0600
Subject: [PATCH 8/8] clean

---
 juno_samples/matmul/build.rs      |  2 -
 juno_samples/matmul/src/sched.sch | 76 -------------------------------
 2 files changed, 78 deletions(-)
 delete mode 100644 juno_samples/matmul/src/sched.sch

diff --git a/juno_samples/matmul/build.rs b/juno_samples/matmul/build.rs
index 511bf483..926fbc33 100644
--- a/juno_samples/matmul/build.rs
+++ b/juno_samples/matmul/build.rs
@@ -4,8 +4,6 @@ fn main() {
     JunoCompiler::new()
         .file_in_src("matmul.jn")
         .unwrap()
-        //.schedule_in_src("sched.sch")
-        //.unwrap()
         .build()
         .unwrap();
 }
diff --git a/juno_samples/matmul/src/sched.sch b/juno_samples/matmul/src/sched.sch
deleted file mode 100644
index 3999f923..00000000
--- a/juno_samples/matmul/src/sched.sch
+++ /dev/null
@@ -1,76 +0,0 @@
-macro juno-setup!(X) {
-  gvn(X);
-  dce(X);
-  phi-elim(X);
-}
-
-macro default!(X) {
-  dce(X);
-  crc(X);
-  dce(X);
-  slf(X);
-  dce(X);
-  inline(X);
-  ip-sroa(X);
-  sroa(X);
-  phi-elim(X);
-  dce(X);
-  ccp(X);
-  dce(X);
-  gvn(X);
-  dce(X);
-  write-predication(X);
-  phi-elim(X);
-  dce(X);
-  crc(X);
-  dce(X);
-  slf(X);
-  dce(X);
-  predication(X);
-  dce(X);
-  ccp(X);
-  dce(X);
-  gvn(X);
-  dce(X);
-  lift-dc-math(X);
-  dce(X);
-  gvn(X);
-  dce(X);
-}
-
-macro codegen-prep!(X) {
-  verify(*);
-  ip-sroa(*);
-  sroa(*);
-  infer-schedules(X);
-  dce(X);
-  gcm(X);
-  dce(X);
-  phi-elim(X);
-  float-collections(X);
-  gcm(X);
-}
-
-juno-setup!(*);
-default!(*);
-// your stuff here.
-
-fixpoint stop after 13 {
-  forkify(*);
-  fork-guard-elim(*);
-  fork-coalesce(*);
-  phi-elim(*);
-  dce(*);
-}
-
-xdot[true](*);
-// serialize(*);
-
-fork-split(*);
-unforkify(*);
-
-gvn(*);
-dce(*);
-
-auto-outline(*);
-codegen-prep!(*);
-- 
GitLab