From 0b1a9eb3c1a023f695fad05fd070c62610cd2a22 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Fri, 14 Feb 2025 15:02:58 -0600
Subject: [PATCH] Add intrinsics to math exprs, array slf gets rid of first
 ctrl pts look in gamut

---
 hercules_ir/src/einsum.rs     | 31 +++++++++++++++++++++++++++++++
 hercules_opt/src/utils.rs     |  7 +++++++
 juno_samples/cava/src/cpu.sch |  3 +++
 3 files changed, 41 insertions(+)

diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs
index 5dd0fe5d..b222e1bc 100644
--- a/hercules_ir/src/einsum.rs
+++ b/hercules_ir/src/einsum.rs
@@ -34,6 +34,7 @@ pub enum MathExpr {
     Unary(UnaryOperator, MathID),
     Binary(BinaryOperator, MathID, MathID),
     Ternary(TernaryOperator, MathID, MathID, MathID),
+    IntrinsicFunc(Intrinsic, Box<[MathID]>),
 }
 
 pub type MathEnv = Vec<MathExpr>;
@@ -224,6 +225,16 @@ impl<'a> EinsumContext<'a> {
                 let third = self.compute_math_expr(third);
                 MathExpr::Ternary(op, first, second, third)
             }
+            Node::IntrinsicCall {
+                intrinsic,
+                ref args,
+            } => {
+                let args = args
+                    .into_iter()
+                    .map(|id| self.compute_math_expr(*id))
+                    .collect();
+                MathExpr::IntrinsicFunc(intrinsic, args)
+            }
             Node::Read {
                 collect,
                 ref indices,
@@ -322,6 +333,14 @@ impl<'a> EinsumContext<'a> {
                 let third = self.substitute_new_dims(third);
                 self.intern_math_expr(MathExpr::Ternary(op, first, second, third))
             }
+            MathExpr::IntrinsicFunc(intrinsic, ref args) => {
+                let args = args
+                    .clone()
+                    .iter()
+                    .map(|id| self.substitute_new_dims(*id))
+                    .collect();
+                self.intern_math_expr(MathExpr::IntrinsicFunc(intrinsic, args))
+            }
             _ => id,
         }
     }
@@ -355,6 +374,9 @@ pub fn opaque_nodes_in_expr(env: &MathEnv, id: MathID) -> HashSet<NodeID> {
                 stack.push(second);
                 stack.push(third);
             }
+            MathExpr::IntrinsicFunc(_, ref args) => {
+                stack.extend(args);
+            }
         }
     }
     set
@@ -412,5 +434,14 @@ pub fn debug_print_math_expr(id: MathID, env: &MathEnv) {
             debug_print_math_expr(third, env);
             print!(")");
         }
+        MathExpr::IntrinsicFunc(intrinsic, ref args) => {
+            print!("{}(", intrinsic.lower_case_name());
+            // Print only the arguments; recursing on `id` here would never terminate.
+            for (idx, arg) in args.iter().enumerate() {
+                print!("{}", if idx > 0 { ", " } else { "" });
+                debug_print_math_expr(*arg, env);
+            }
+            print!(")");
+        }
     }
 }
diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs
index c5e0d934..3f12ad7c 100644
--- a/hercules_opt/src/utils.rs
+++ b/hercules_opt/src/utils.rs
@@ -474,6 +474,13 @@ pub fn materialize_simple_einsum_expr(
                 third,
             })
         }
+        MathExpr::IntrinsicFunc(intrinsic, ref args) => {
+            let args = args
+                .into_iter()
+                .map(|id| materialize_simple_einsum_expr(edit, *id, env, dim_substs))
+                .collect();
+            edit.add_node(Node::IntrinsicCall { intrinsic, args })
+        }
         _ => panic!(),
     }
 }
diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch
index 3021f6a0..1ae1dc13 100644
--- a/juno_samples/cava/src/cpu.sch
+++ b/juno_samples/cava/src/cpu.sch
@@ -102,6 +102,9 @@ simpl!(fuse4);
 fork-unroll(fuse4@channel_loop);
 simpl!(fuse4);
 //fork-fusion(fuse4@channel_loop);
+simpl!(fuse4);
+array-slf(fuse4);
+simpl!(fuse4);
 fork-split(fuse4);
 unforkify(fuse4);
 
-- 
GitLab