From a8452c2f86002512eb9f956881eaf3483bfaf4ad Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 17 Feb 2025 20:56:25 -0600
Subject: [PATCH] Use enum for library call kinds

---
 hercules_ir/src/dot.rs                       |  2 +-
 hercules_ir/src/ir.rs                        | 10 +++++++++-
 hercules_opt/src/rewrite_math_expressions.rs | 19 ++++++++-----------
 3 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs
index 1773f7dd..921a813d 100644
--- a/hercules_ir/src/dot.rs
+++ b/hercules_ir/src/dot.rs
@@ -323,7 +323,7 @@ fn write_node<W: Write>(
             args: _,
             ty: _,
             device,
-        } => write!(&mut suffix, "{} on {:?}", library_function, device)?,
+        } => write!(&mut suffix, "{:?} on {:?}", library_function, device)?,
         Node::Read {
             collect: _,
             indices,
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 5abc9adf..bf9698b3 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -223,7 +223,7 @@ pub enum Node {
         args: Box<[NodeID]>,
     },
     LibraryCall {
-        library_function: String,
+        library_function: LibraryFunction,
         args: Box<[NodeID]>,
         ty: TypeID,
         device: Device,
@@ -351,6 +351,14 @@ pub enum Device {
     AsyncRust,
 }
 
+/*
+ * The authoritative enumeration of supported library calls.
+ */
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
+pub enum LibraryFunction {
+    GEMM,
+}
+
 /*
  * A single node may have multiple schedules.
  */
diff --git a/hercules_opt/src/rewrite_math_expressions.rs b/hercules_opt/src/rewrite_math_expressions.rs
index 4d76c608..6f52dc58 100644
--- a/hercules_opt/src/rewrite_math_expressions.rs
+++ b/hercules_opt/src/rewrite_math_expressions.rs
@@ -22,7 +22,7 @@ define_language! {
         "+" = Add([Id; 2]),
         "*" = Mul([Id; 2]),
 
-        "library_matmul" = LibraryMatmul([Id; 2]),
+        "library_gemm" = LibraryGemm([Id; 2]),
 
         Opaque(Symbol),
     }
@@ -33,7 +33,7 @@ fn make_rules() -> Vec<Rewrite<MathLanguage, ()>> {
         rewrite!("add-zero"; "(+ zero ?a)" => "?a"),
         rewrite!("mul-zero"; "(* zero ?a)" => "zero"),
         rewrite!("mul-one"; "(* one ?a)" => "?a"),
-        rewrite!("library-matmul"; "(array ?i ?k (sum ?j (* (read ?i ?j ?A) (read ?j ?k ?B))))" => "(library_matmul ?A ?B)"),
+        rewrite!("library-gemm"; "(array ?i ?k (sum ?j (* (read ?i ?j ?A) (read ?j ?k ?B))))" => "(library_gemm ?A ?B)"),
     ]
 }
 
@@ -65,7 +65,7 @@ pub fn rewrite_math_expressions(
     reduces.sort();
 
     for (_, reduce) in reduces {
-        let join = editor.func().nodes[reduce.idx()].try_reduce().unwrap().0;
+        let (join, init, _) = editor.func().nodes[reduce.idx()].try_reduce().unwrap();
         let fork = join_fork_map[&join];
 
         // Step 2: convert the reduce to an expression in egg and rewrite it.
@@ -77,21 +77,18 @@ pub fn rewrite_math_expressions(
 
         // Step 3: match the smallest expression against patterns for known
         // library functions.
-        let matmul_pattern: Pattern<MathLanguage> = "(library_matmul ?A ?B)".parse().unwrap();
-        let mut matches = matmul_pattern.search(&egraph);
+        let gemm_pattern: Pattern<MathLanguage> = "(library_gemm ?A ?B)".parse().unwrap();
+        let mut matches = gemm_pattern.search(&egraph);
         if matches.len() > 0
             && let m = matches.remove(0)
             && m.substs.len() > 0
         {
             let left_id = get_node_ids_from_subst("?A", &m.substs[0], &egraph);
             let right_id = get_node_ids_from_subst("?B", &m.substs[0], &egraph);
-            let library_function = "gemm".to_string();
-            let args = vec![left_id, right_id].into_boxed_slice();
-            let ty = typing[reduce.idx()];
             let call = Node::LibraryCall {
-                library_function,
-                args,
-                ty,
+                library_function: LibraryFunction::GEMM,
+                args: vec![init, left_id, right_id].into_boxed_slice(),
+                ty: typing[reduce.idx()],
                 device,
             };
             let success = editor.edit(|mut edit| {
-- 
GitLab