From 1cdace83ebf1b518d58aec9d04d1c1dcb67586b3 Mon Sep 17 00:00:00 2001
From: rarbore2 <rarbore2@illinois.edu>
Date: Thu, 30 Jan 2025 17:23:39 -0600
Subject: [PATCH] Einsum analysis

---
 hercules_ir/src/einsum.rs           | 413 ++++++++++++++++++++++++++++
 hercules_ir/src/lib.rs              |   3 +
 hercules_samples/matmul/src/cpu.sch |   2 +
 juno_samples/matmul/build.rs        |   2 -
 juno_samples/matmul/src/sched.sch   |  76 -----
 juno_scheduler/src/pm.rs            |  53 +++-
 6 files changed, 467 insertions(+), 82 deletions(-)
 create mode 100644 hercules_ir/src/einsum.rs
 delete mode 100644 juno_samples/matmul/src/sched.sch

diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs
new file mode 100644
index 00000000..c2521c22
--- /dev/null
+++ b/hercules_ir/src/einsum.rs
@@ -0,0 +1,413 @@
+use std::collections::{HashMap, HashSet};
+use std::iter::zip;
+
+use crate::*;
+
+/*
+ * Math expressions are stored in a flat environment (`MathEnv`) and referenced
+ * by `MathID`; structurally identical subexpressions are interned and shared.
+ */
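+// A fork dimension: its index in the fork node paired with its dynamic
+// constant factor (the dimension's bound).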
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub struct ForkDimension(pub usize, pub DynamicConstantID);
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub enum MathExpr {
+    // Zero constant of a specific type.
+    Zero(TypeID),
+    // One constant of a specific type.
+    One(TypeID),
+    // Opaque value corresponding to a particular node in the original IR.
+    OpaqueNode(NodeID),
+
+    // Thread ID from the fork corresponding to the reduce being expressed.
+    // Thread IDs from other forks are treated as OpaqueNodes.
+    ThreadID(ForkDimension),
+
+    // Sum reduction over one or more fork dimensions.
+    SumReduction(MathID, Box<[ForkDimension]>),
+    // Comprehend a scalar expression into an array over fork dimensions.
+    Comprehension(MathID, Box<[ForkDimension]>),
+
+    // Read from an array.
+    Read(MathID, Box<[MathID]>),
+
+    // Math ops.
+    Add(MathID, MathID),
+    Sub(MathID, MathID),
+    Mul(MathID, MathID),
+    Div(MathID, MathID),
+    Rem(MathID, MathID),
+}
+
+pub type MathEnv = Vec<MathExpr>;
+
+define_id_type!(MathID);
+
+/*
+ * Top level function to run "einsum" analysis on fork-joins. This is a terrible
+ * name for this analysis, since it's actually more general than identifying
+ * einsums, but einsum syntax has a passing resemblance to the idea of this
+ * analysis and it's what we keep calling it, so we're doomed to this bad name.
+ * The idea of this analysis is to convert some fork-joins into pure math
+ * expressions that we can rewrite into intrinsic functions for higher level
+ * operators like matmul. Specifically, this function returns a map from each
+ * reduce node to a math expression.
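+ *
+ * As an illustrative example (loosely in the notation of `debug_print_expr`,
+ * with dimension numbering shown only for exposition): a matrix multiply
+ * `c[i, j] = sum_k a[i, k] * b[k, j]`, written as a two-dimensional fork-join
+ * over (i, j) containing a one-dimensional fork-join over k, would yield
+ * roughly `[#0, #1] +(0, Sum (#2) *(read(a, #0, #2), read(b, #2, #1)))` for
+ * the outer reduce.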
+ */
+pub fn einsum(
+    function: &Function,
+    types: &Vec<Type>,
+    constants: &Vec<Constant>,
+    def_use: &ImmutableDefUseMap,
+    typing: &Vec<TypeID>,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
+    data_nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+) -> (MathEnv, HashMap<NodeID, MathID>) {
+    let mut env = vec![];
+    let mut rev_env = HashMap::new();
+    let mut result = HashMap::new();
+
+    // Iterate fork-joins bottom-up, since we need to compute the math
+    // expressions of inner reduces before getting to outer reduces. Since
+    // fork-joins are strictly nested, we can simply iterate the entries of
+    // `fork_join_nest` in decreasing order of nesting size to accomplish this.
+    let mut nests: Vec<_> = fork_join_nest
+        .into_iter()
+        .filter(|(id, _)| function.nodes[id.idx()].is_fork())
+        .collect();
+    nests.sort_by(|a, b| b.1.len().cmp(&a.1.len()));
+    for fork in nests.into_iter().map(|(id, _)| *id) {
+        let Node::Fork {
+            control: _,
+            ref factors,
+        } = function.nodes[fork.idx()]
+        else {
+            panic!()
+        };
+        let join = fork_join_map[&fork];
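+        // Collect the thread ID nodes that read this fork, paired with their
+        // dimension index.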
+        let thread_ids: Vec<_> = def_use
+            .get_users(fork)
+            .into_iter()
+            .filter_map(|id| {
+                function.nodes[id.idx()]
+                    .try_thread_id()
+                    .map(|(_, dim)| (*id, dim))
+            })
+            .collect();
+        let reduces = def_use
+            .get_users(join)
+            .into_iter()
+            .filter_map(|id| function.nodes[id.idx()].try_reduce().map(|v| (*id, v)));
+        let mut ctx = EinsumContext {
+            function,
+            typing,
+            constants,
+            data_nodes_in_fork_joins,
+            fork,
+            factors,
+            thread_ids: &thread_ids,
+            so_far: &mut result,
+            env: &mut env,
+            rev_env: &mut rev_env,
+        };
+
+        // Compute a math expression for each reduce node in the fork-join that
+        // has the appropriate schedules.
+        for (reduce, (_, init, reduct)) in reduces {
+            // The reduce defines an array where each fork dimension corresponds
+            // to one array dimension.
+            if function.schedules[reduce.idx()].contains(&Schedule::ParallelReduce)
+                && let Node::Write {
+                    collect,
+                    data,
+                    ref indices,
+                } = function.nodes[reduct.idx()]
+                && collect == reduce
+                && indices.len() == 1
+                && let Some(indices) = indices[0].try_position()
+                && let Some(dimension_bounds) = indices
+                    .into_iter()
+                    .map(|id| {
+                        function.nodes[id.idx()]
+                            .try_thread_id()
+                            .filter(|(tid_fork, _)| *tid_fork == fork)
+                            .map(|(_, dim)| dim)
+                    })
+                    .collect::<Option<Vec<usize>>>()
+                && let Type::Array(_, ref array_bounds) = types[typing[reduce.idx()].idx()]
+                && zip(array_bounds.into_iter(), dimension_bounds.iter())
+                    .all(|(array, fork)| *array == factors[*fork])
+            {
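+                // Each iteration writes a distinct array element indexed by
+                // this fork's thread IDs, so comprehending the written value
+                // over the fork dimensions describes the entire array.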
+                let data_expr = ctx.compute_math_expr(data);
+                let reduce_expr = MathExpr::Comprehension(
+                    data_expr,
+                    dimension_bounds
+                        .into_iter()
+                        .map(|dim| ForkDimension(dim, factors[dim]))
+                        .collect(),
+                );
+                // We don't need to consider the initializer, since the writes
+                // cover the whole array.
+                let total_id = ctx.intern_math_expr(reduce_expr);
+                ctx.result_insert(reduce, total_id);
+            }
+            // The reduce defines a sum reduction over a set of fork dimensions.
+            else if function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
+                && let Node::Binary {
+                    op: BinaryOperator::Add,
+                    left,
+                    right,
+                } = function.nodes[reduct.idx()]
+                && (left == reduce || right == reduce)
+            {
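+                // The loop accumulates with `+`, so the final value is the
+                // initializer plus the sum of the non-reduce operand over all
+                // fork dimensions.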
+                let data_expr = ctx.compute_math_expr(if left == reduce { right } else { left });
+                let reduce_expr = MathExpr::SumReduction(
+                    data_expr,
+                    factors
+                        .into_iter()
+                        .enumerate()
+                        .map(|(dim, factor)| ForkDimension(dim, *factor))
+                        .collect(),
+                );
+                // Add the initializer.
+                let reduce_expr_id = ctx.intern_math_expr(reduce_expr);
+                let init_expr_id = ctx.compute_math_expr(init);
+                let add_expr = MathExpr::Add(init_expr_id, reduce_expr_id);
+                let total_id = ctx.intern_math_expr(add_expr);
+                ctx.result_insert(reduce, total_id);
+            }
+        }
+    }
+
+    (env, result)
+}
+
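+/*
+ * Per-fork-join state used while constructing math expressions: the current
+ * fork, its dimension factors and thread ID users, plus the shared interning
+ * environment and the results computed so far.
+ */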
+struct EinsumContext<'a> {
+    function: &'a Function,
+    typing: &'a Vec<TypeID>,
+    constants: &'a Vec<Constant>,
+    data_nodes_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>,
+    fork: NodeID,
+    factors: &'a [DynamicConstantID],
+    thread_ids: &'a Vec<(NodeID, usize)>,
+    so_far: &'a mut HashMap<NodeID, MathID>,
+    env: &'a mut MathEnv,
+    rev_env: &'a mut HashMap<MathExpr, MathID>,
+}
+
+impl<'a> EinsumContext<'a> {
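+    // Build (and intern) the math expression for a data node. Nodes that
+    // can't be represented are wrapped as OpaqueNodes.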
+    fn compute_math_expr(&mut self, id: NodeID) -> MathID {
+        let math_expr = match self.function.nodes[id.idx()] {
+            Node::Constant { id: cons_id } if self.constants[cons_id.idx()].is_zero() => {
+                MathExpr::Zero(self.typing[id.idx()])
+            }
+            Node::Constant { id: cons_id } if self.constants[cons_id.idx()].is_one() => {
+                MathExpr::One(self.typing[id.idx()])
+            }
+            Node::ThreadID { control, dimension } if control == self.fork => {
+                MathExpr::ThreadID(ForkDimension(dimension, self.factors[dimension]))
+            }
+            Node::Binary { op, left, right } if representable(op) => {
+                let left = self.compute_math_expr(left);
+                let right = self.compute_math_expr(right);
+                match op {
+                    BinaryOperator::Add => MathExpr::Add(left, right),
+                    BinaryOperator::Sub => MathExpr::Sub(left, right),
+                    BinaryOperator::Mul => MathExpr::Mul(left, right),
+                    BinaryOperator::Div => MathExpr::Div(left, right),
+                    BinaryOperator::Rem => MathExpr::Rem(left, right),
+                    _ => unreachable!(),
+                }
+            }
+            Node::Read {
+                collect,
+                ref indices,
+            } if indices.len() == 1
+                && let Some(indices) = indices[0].try_position() =>
+            {
+                let collect = self.compute_math_expr(collect);
+                let indices = indices
+                    .into_iter()
+                    .map(|id| self.compute_math_expr(*id))
+                    .collect();
+                MathExpr::Read(collect, indices)
+            }
+            Node::Reduce {
+                control: _,
+                init: _,
+                reduct: _,
+            } => {
+                let reduce = self.so_far[&id];
+                // Substitute opaque uses of this fork's thread ID nodes in the
+                // inner expression with thread ID math expressions, and shift
+                // the inner fork's dimensions (alpha renaming).
+                return self.substitute_new_dims(reduce);
+            }
+            _ => MathExpr::OpaqueNode(id),
+        };
+        self.intern_math_expr(math_expr)
+    }
+
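+    // Intern an expression, reusing the existing ID if a structurally
+    // identical expression was interned before (hash-consing).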
+    fn intern_math_expr(&mut self, expr: MathExpr) -> MathID {
+        if let Some(id) = self.rev_env.get(&expr) {
+            *id
+        } else {
+            let id = MathID::new(self.env.len());
+            self.env.push(expr.clone());
+            self.rev_env.insert(expr, id);
+            id
+        }
+    }
+
+    fn result_insert(&mut self, node: NodeID, math: MathID) {
+        self.so_far.insert(node, math);
+    }
+
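+    // Re-express a previously computed (inner fork-join) expression in the
+    // context of the current fork: opaque uses of this fork's thread IDs
+    // become ThreadID expressions, and the inner fork's dimension indices are
+    // shifted past this fork's dimensions so the two sets don't collide.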
+    fn substitute_new_dims(&mut self, id: MathID) -> MathID {
+        match self.env[id.idx()] {
+            MathExpr::OpaqueNode(opaque)
+                if let Some((_, dim)) = self
+                    .thread_ids
+                    .into_iter()
+                    .filter(|(node, _)| *node == opaque)
+                    .next() =>
+            {
+                self.intern_math_expr(MathExpr::ThreadID(ForkDimension(*dim, self.factors[*dim])))
+            }
+            MathExpr::ThreadID(dim) => self.intern_math_expr(MathExpr::ThreadID(ForkDimension(
+                dim.0 + self.factors.len(),
+                dim.1,
+            ))),
+            MathExpr::SumReduction(id, ref dims) => {
+                let dims = dims
+                    .into_iter()
+                    .map(|dim| ForkDimension(dim.0 + self.factors.len(), dim.1))
+                    .collect();
+                let id = self.substitute_new_dims(id);
+                self.intern_math_expr(MathExpr::SumReduction(id, dims))
+            }
+            MathExpr::Comprehension(id, ref dims) => {
+                let dims = dims
+                    .into_iter()
+                    .map(|dim| ForkDimension(dim.0 + self.factors.len(), dim.1))
+                    .collect();
+                let id = self.substitute_new_dims(id);
+                self.intern_math_expr(MathExpr::Comprehension(id, dims))
+            }
+            MathExpr::Read(array, ref indices) => {
+                let indices = indices
+                    .clone()
+                    .iter()
+                    .map(|id| self.substitute_new_dims(*id))
+                    .collect();
+                let array = self.substitute_new_dims(array);
+                self.intern_math_expr(MathExpr::Read(array, indices))
+            }
+            MathExpr::Add(left, right) => {
+                let left = self.substitute_new_dims(left);
+                let right = self.substitute_new_dims(right);
+                self.intern_math_expr(MathExpr::Add(left, right))
+            }
+            MathExpr::Sub(left, right) => {
+                let left = self.substitute_new_dims(left);
+                let right = self.substitute_new_dims(right);
+                self.intern_math_expr(MathExpr::Sub(left, right))
+            }
+            MathExpr::Mul(left, right) => {
+                let left = self.substitute_new_dims(left);
+                let right = self.substitute_new_dims(right);
+                self.intern_math_expr(MathExpr::Mul(left, right))
+            }
+            MathExpr::Div(left, right) => {
+                let left = self.substitute_new_dims(left);
+                let right = self.substitute_new_dims(right);
+                self.intern_math_expr(MathExpr::Div(left, right))
+            }
+            MathExpr::Rem(left, right) => {
+                let left = self.substitute_new_dims(left);
+                let right = self.substitute_new_dims(right);
+                self.intern_math_expr(MathExpr::Rem(left, right))
+            }
+            _ => id,
+        }
+    }
+
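+    // Print a math expression in a compact textual form, for debugging.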
+    fn debug_print_expr(&self, id: MathID) {
+        match self.env[id.idx()] {
+            MathExpr::Zero(_) => print!("0"),
+            MathExpr::One(_) => print!("1"),
+            MathExpr::OpaqueNode(id) => print!("{:?}", id),
+            MathExpr::ThreadID(dim) => print!("#{}", dim.0),
+            MathExpr::SumReduction(id, ref dims) => {
+                print!("Sum (");
+                for dim in dims {
+                    print!("#{}/{:?},", dim.0, dim.1);
+                }
+                print!(") ");
+                self.debug_print_expr(id);
+            }
+            MathExpr::Comprehension(id, ref dims) => {
+                print!("[");
+                for dim in dims {
+                    print!("#{}/{:?},", dim.0, dim.1);
+                }
+                print!("] ");
+                self.debug_print_expr(id);
+            }
+            MathExpr::Read(id, ref pos) => {
+                print!("read(");
+                self.debug_print_expr(id);
+                for pos in pos {
+                    print!(", ");
+                    self.debug_print_expr(*pos);
+                }
+                print!(")");
+            }
+            MathExpr::Add(left, right) => {
+                print!("+(");
+                self.debug_print_expr(left);
+                print!(", ");
+                self.debug_print_expr(right);
+                print!(")");
+            }
+            MathExpr::Sub(left, right) => {
+                print!("-(");
+                self.debug_print_expr(left);
+                print!(", ");
+                self.debug_print_expr(right);
+                print!(")");
+            }
+            MathExpr::Mul(left, right) => {
+                print!("*(");
+                self.debug_print_expr(left);
+                print!(", ");
+                self.debug_print_expr(right);
+                print!(")");
+            }
+            MathExpr::Div(left, right) => {
+                print!("/(");
+                self.debug_print_expr(left);
+                print!(", ");
+                self.debug_print_expr(right);
+                print!(")");
+            }
+            MathExpr::Rem(left, right) => {
+                print!("%(");
+                self.debug_print_expr(left);
+                print!(", ");
+                self.debug_print_expr(right);
+                print!(")");
+            }
+        }
+    }
+}
+
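+// Binary operators that can be represented directly as MathExprs; all other
+// operators become OpaqueNodes.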
+fn representable(op: BinaryOperator) -> bool {
+    matches!(
+        op,
+        BinaryOperator::Add
+            | BinaryOperator::Sub
+            | BinaryOperator::Mul
+            | BinaryOperator::Div
+            | BinaryOperator::Rem
+    )
+}
diff --git a/hercules_ir/src/lib.rs b/hercules_ir/src/lib.rs
index 85dc277f..185a2888 100644
--- a/hercules_ir/src/lib.rs
+++ b/hercules_ir/src/lib.rs
@@ -2,6 +2,7 @@
     coroutines,
     coroutine_trait,
     let_chains,
+    if_let_guard,
     stmt_expr_attributes,
     iter_intersperse
 )]
@@ -14,6 +15,7 @@ pub mod def_use;
 pub mod device;
 pub mod dom;
 pub mod dot;
+pub mod einsum;
 pub mod fork_join_analysis;
 pub mod ir;
 pub mod loops;
@@ -30,6 +32,7 @@ pub use crate::def_use::*;
 pub use crate::device::*;
 pub use crate::dom::*;
 pub use crate::dot::*;
+pub use crate::einsum::*;
 pub use crate::fork_join_analysis::*;
 pub use crate::ir::*;
 pub use crate::loops::*;
diff --git a/hercules_samples/matmul/src/cpu.sch b/hercules_samples/matmul/src/cpu.sch
index f7891b9b..c00a3314 100644
--- a/hercules_samples/matmul/src/cpu.sch
+++ b/hercules_samples/matmul/src/cpu.sch
@@ -6,6 +6,8 @@ auto-outline(*);
 
 ip-sroa(*);
 sroa(*);
+dce(*);
+infer-schedules(*);
 fork-split(*);
 unforkify(*);
 dce(*);
diff --git a/juno_samples/matmul/build.rs b/juno_samples/matmul/build.rs
index 511bf483..926fbc33 100644
--- a/juno_samples/matmul/build.rs
+++ b/juno_samples/matmul/build.rs
@@ -4,8 +4,6 @@ fn main() {
     JunoCompiler::new()
         .file_in_src("matmul.jn")
         .unwrap()
-        //.schedule_in_src("sched.sch")
-        //.unwrap()
         .build()
         .unwrap();
 }
diff --git a/juno_samples/matmul/src/sched.sch b/juno_samples/matmul/src/sched.sch
deleted file mode 100644
index 3999f923..00000000
--- a/juno_samples/matmul/src/sched.sch
+++ /dev/null
@@ -1,76 +0,0 @@
-macro juno-setup!(X) {
-  gvn(X);
-  dce(X);
-  phi-elim(X);
-}
-
-macro default!(X) {
-  dce(X);
-  crc(X);
-  dce(X);
-  slf(X);
-  dce(X);
-  inline(X);
-  ip-sroa(X);
-  sroa(X);
-  phi-elim(X);
-  dce(X);
-  ccp(X);
-  dce(X);
-  gvn(X);
-  dce(X);
-  write-predication(X);
-  phi-elim(X);
-  dce(X);
-  crc(X);
-  dce(X);
-  slf(X);
-  dce(X);
-  predication(X);
-  dce(X);
-  ccp(X);
-  dce(X);
-  gvn(X);
-  dce(X);
-  lift-dc-math(X);
-  dce(X);
-  gvn(X);
-  dce(X);
-}
-
-macro codegen-prep!(X) {
-  verify(*);
-  ip-sroa(*);
-  sroa(*);
-  infer-schedules(X);
-  dce(X);
-  gcm(X);
-  dce(X);
-  phi-elim(X);
-  float-collections(X);
-  gcm(X);
-}
-
-juno-setup!(*);
-default!(*);
-// your stuff here.
-
-fixpoint stop after 13 {
-  forkify(*);
-  fork-guard-elim(*);
-  fork-coalesce(*);
-  phi-elim(*);
-  dce(*);
-}
-
-xdot[true](*);
-// serialize(*);
-
-fork-split(*);
-unforkify(*);
-
-gvn(*);
-dce(*);
-
-auto-outline(*);
-codegen-prep!(*);
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 9888f3d2..81c46656 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -181,6 +181,7 @@ pub struct PassManager {
     pub loops: Option<Vec<LoopTree>>,
     pub reduce_cycles: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
     pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
+    pub reduce_einsums: Option<Vec<(MathEnv, HashMap<NodeID, MathID>)>>,
     pub collection_objects: Option<CollectionObjects>,
     pub callgraph: Option<CallGraph>,
     pub devices: Option<Vec<Device>>,
@@ -216,6 +217,7 @@ impl PassManager {
             loops: None,
             reduce_cycles: None,
             data_nodes_in_fork_joins: None,
+            reduce_einsums: None,
             collection_objects: None,
             callgraph: None,
             devices: None,
@@ -392,6 +394,48 @@ impl PassManager {
         }
     }
 
+    pub fn make_reduce_einsums(&mut self) {
+        if self.reduce_einsums.is_none() {
+            self.make_def_uses();
+            self.make_typing();
+            self.make_fork_join_maps();
+            self.make_fork_join_nests();
+            self.make_data_nodes_in_fork_joins();
+            let def_uses = self.def_uses.as_ref().unwrap().iter();
+            let typing = self.typing.as_ref().unwrap().iter();
+            let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter();
+            let fork_join_nests = self.fork_join_nests.as_ref().unwrap().iter();
+            let data_nodes_in_fork_joins = self.data_nodes_in_fork_joins.as_ref().unwrap().iter();
+            self.reduce_einsums = Some(
+                self.functions
+                    .iter()
+                    .zip(def_uses)
+                    .zip(typing)
+                    .zip(fork_join_maps)
+                    .zip(fork_join_nests)
+                    .zip(data_nodes_in_fork_joins)
+                    .map(
+                        |(
+                            ((((function, def_use), typing), fork_join_map), fork_join_nest),
+                            data_nodes_in_fork_joins,
+                        )| {
+                            einsum(
+                                function,
+                                &self.types.borrow(),
+                                &self.constants.borrow(),
+                                def_use,
+                                typing,
+                                fork_join_map,
+                                fork_join_nest,
+                                data_nodes_in_fork_joins,
+                            )
+                        },
+                    )
+                    .collect(),
+            );
+        }
+    }
+
     pub fn make_collection_objects(&mut self) {
         if self.collection_objects.is_none() {
             self.make_reverse_postorders();
@@ -463,6 +507,7 @@ impl PassManager {
         self.loops = None;
         self.reduce_cycles = None;
         self.data_nodes_in_fork_joins = None;
+        self.reduce_einsums = None;
         self.collection_objects = None;
         self.callgraph = None;
         self.devices = None;
@@ -643,7 +688,7 @@ impl PassManager {
 }
 
 pub fn schedule_codegen(
-    mut module: Module,
+    module: Module,
     schedule: ScheduleStmt,
     mut stringtab: StringTable,
     mut env: Env<usize, Value>,
@@ -913,8 +958,8 @@ fn add_schedule(pm: &mut PassManager, sched: Schedule, label_ids: Vec<LabelInfo>
             .labels
             .iter()
             .enumerate()
-            .filter(|(i, ls)| ls.contains(&label))
-            .map(|(i, ls)| i)
+            .filter(|(_, ls)| ls.contains(&label))
+            .map(|(i, _)| i)
             .collect::<Vec<_>>();
         for node in nodes {
             pm.functions[func.idx()].schedules[node].push(sched.clone());
@@ -1350,7 +1395,7 @@ fn run_pass(
                 let Some(mut func) = func else {
                     continue;
                 };
-                // TODO: uses direct return from forkify for now instead of 
+                // TODO: uses direct return from forkify for now instead of
                 // func.modified, see comment on top of `forkify` for why. Fix
                 // this eventually.
                 changed |= forkify(&mut func, control_subgraph, fork_join_map, loop_nest);
-- 
GitLab