From 204084e2830561b675c77eca1c4c714590c5121e Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 17 Feb 2025 15:00:57 -0600
Subject: [PATCH 01/13] Rewrite skeleton

---
 hercules_opt/src/lib.rs                      |  2 ++
 hercules_opt/src/rewrite_math_expressions.rs | 17 ++++++++++++++++
 juno_samples/matmul/src/gpu.sch              |  2 ++
 juno_scheduler/src/compile.rs                |  3 +++
 juno_scheduler/src/ir.rs                     |  1 +
 juno_scheduler/src/pm.rs                     | 21 ++++++++++++++++++++
 6 files changed, 46 insertions(+)
 create mode 100644 hercules_opt/src/rewrite_math_expressions.rs

diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs
index b56f9408..b25449e7 100644
--- a/hercules_opt/src/lib.rs
+++ b/hercules_opt/src/lib.rs
@@ -21,6 +21,7 @@ pub mod outline;
 pub mod phi_elim;
 pub mod pred;
 pub mod reuse_products;
+pub mod rewrite_math_expressions;
 pub mod schedule;
 pub mod simplify_cfg;
 pub mod slf;
@@ -49,6 +50,7 @@ pub use crate::outline::*;
 pub use crate::phi_elim::*;
 pub use crate::pred::*;
 pub use crate::reuse_products::*;
+pub use crate::rewrite_math_expressions::*;
 pub use crate::schedule::*;
 pub use crate::simplify_cfg::*;
 pub use crate::slf::*;
diff --git a/hercules_opt/src/rewrite_math_expressions.rs b/hercules_opt/src/rewrite_math_expressions.rs
new file mode 100644
index 00000000..32161a79
--- /dev/null
+++ b/hercules_opt/src/rewrite_math_expressions.rs
@@ -0,0 +1,17 @@
+use std::collections::{HashMap, HashSet};
+
+use hercules_ir::*;
+
+use crate::*;
+
+pub fn rewrite_math_expressions(
+    editor: &mut FunctionEditor,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+    reduce_einsums: &(MathEnv, HashMap<NodeID, MathID>),
+) {
+    for (reduce, einsum) in reduce_einsums.1.iter() {
+        print!("{:?}: ", reduce);
+        debug_print_math_expr(*einsum, &reduce_einsums.0);
+        println!("");
+    }
+}
diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch
index edb83d74..c785fd5e 100644
--- a/juno_samples/matmul/src/gpu.sch
+++ b/juno_samples/matmul/src/gpu.sch
@@ -12,6 +12,8 @@ fixpoint {
 fork-coalesce(*);
 infer-schedules(*);
 dce(*);
+rewrite(*);
+xdot[true](*);
 
 let out = auto-outline(*);
 gpu(out.matmul);
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 43871c90..e9132fd2 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -141,6 +141,9 @@ impl FromStr for Appliable {
             "reduce-slf" => Ok(Appliable::Pass(ir::Pass::ReduceSLF)),
             "rename" => Ok(Appliable::Pass(ir::Pass::Rename)),
             "reuse-products" => Ok(Appliable::Pass(ir::Pass::ReuseProducts)),
+            "rewrite" | "rewrite-math" | "rewrite-math-expressions" => {
+                Ok(Appliable::Pass(ir::Pass::RewriteMathExpressions))
+            }
             "simplify-cfg" => Ok(Appliable::Pass(ir::Pass::SimplifyCFG)),
             "slf" | "store-load-forward" => Ok(Appliable::Pass(ir::Pass::SLF)),
             "sroa" => Ok(Appliable::Pass(ir::Pass::SROA)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 5b6bd297..25cc5ef8 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -34,6 +34,7 @@ pub enum Pass {
     ReduceSLF,
     Rename,
     ReuseProducts,
+    RewriteMathExpressions,
     SLF,
     SROA,
     Serialize,
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index d83ff0bb..461bc645 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2279,6 +2279,27 @@ fn run_pass(
             pm.delete_gravestones();
             pm.clear_analyses();
         }
+        Pass::RewriteMathExpressions => {
+            assert!(args.is_empty());
+            pm.make_nodes_in_fork_joins();
+            pm.make_reduce_einsums();
+            let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+            let reduce_einsums = pm.reduce_einsums.take().unwrap();
+            for ((func, nodes_in_fork_joins), reduce_einsums) in
+                build_selection(pm, selection, false)
+                    .into_iter()
+                    .zip(nodes_in_fork_joins.iter())
+                    .zip(reduce_einsums.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                rewrite_math_expressions(&mut func, nodes_in_fork_joins, reduce_einsums);
+                changed |= func.modified();
+            }
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::SLF => {
             assert!(args.is_empty());
             pm.make_reverse_postorders();
-- 
GitLab


From 6e2acddfc47ef27fb048e84a42d34bce046e831d Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 17 Feb 2025 15:43:22 -0600
Subject: [PATCH 02/13] Construct egg expr

---
 Cargo.lock                                   | 125 ++++++++++++++++++-
 hercules_opt/Cargo.toml                      |   3 +-
 hercules_opt/src/rewrite_math_expressions.rs |  92 +++++++++++++-
 3 files changed, 216 insertions(+), 4 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 81c37d79..b6ca23d3 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -23,6 +23,12 @@ version = "0.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1"
 
+[[package]]
+name = "allocator-api2"
+version = "0.2.21"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
+
 [[package]]
 name = "anstream"
 version = "0.6.18"
@@ -621,6 +627,27 @@ version = "1.0.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
 
+[[package]]
+name = "egg"
+version = "0.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "abb749745461743bb477fba3ef87c663d5965876155c676c9489cfe0963de5ab"
+dependencies = [
+ "env_logger",
+ "hashbrown",
+ "indexmap",
+ "log",
+ "num-bigint",
+ "num-traits",
+ "quanta",
+ "rustc-hash",
+ "saturating",
+ "smallvec",
+ "symbol_table",
+ "symbolic_expressions",
+ "thiserror",
+]
+
 [[package]]
 name = "either"
 version = "1.13.0"
@@ -639,6 +666,15 @@ version = "0.6.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d"
 
+[[package]]
+name = "env_logger"
+version = "0.9.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a12e6657c4c97ebab115a42dcee77225f7f482cdd841cf7088c657a42e9e00e7"
+dependencies = [
+ "log",
+]
+
 [[package]]
 name = "equivalent"
 version = "1.0.1"
@@ -752,6 +788,12 @@ version = "1.0.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
 
+[[package]]
+name = "foldhash"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f"
+
 [[package]]
 name = "funty"
 version = "1.1.0"
@@ -882,6 +924,11 @@ name = "hashbrown"
 version = "0.15.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289"
+dependencies = [
+ "allocator-api2",
+ "equivalent",
+ "foldhash",
+]
 
 [[package]]
 name = "heapless"
@@ -955,6 +1002,7 @@ version = "0.1.0"
 dependencies = [
  "bimap",
  "bitvec 1.0.1",
+ "egg",
  "either",
  "hercules_cg",
  "hercules_ir",
@@ -1177,7 +1225,7 @@ dependencies = [
  "async-std",
  "hercules_rt",
  "juno_build",
- "rand 0.8.5",
+ "rand 0.9.0",
  "with_builtin_macros",
 ]
 
@@ -1921,6 +1969,21 @@ dependencies = [
  "bytemuck",
 ]
 
+[[package]]
+name = "quanta"
+version = "0.12.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3bd1fe6824cea6538803de3ff1bc0cf3949024db3d43c9643024bfb33a807c0e"
+dependencies = [
+ "crossbeam-utils",
+ "libc",
+ "once_cell",
+ "raw-cpuid",
+ "wasi 0.11.0+wasi-snapshot-preview1",
+ "web-sys",
+ "winapi",
+]
+
 [[package]]
 name = "quick-error"
 version = "2.0.1"
@@ -2061,6 +2124,15 @@ dependencies = [
  "rgb",
 ]
 
+[[package]]
+name = "raw-cpuid"
+version = "11.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "529468c1335c1c03919960dfefdb1b3648858c20d7ec2d0663e728e4a717efbc"
+dependencies = [
+ "bitflags 2.8.0",
+]
+
 [[package]]
 name = "rayon"
 version = "1.10.0"
@@ -2119,6 +2191,12 @@ version = "0.8.50"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a"
 
+[[package]]
+name = "rustc-hash"
+version = "2.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
+
 [[package]]
 name = "rustc_version"
 version = "0.4.1"
@@ -2153,6 +2231,12 @@ version = "1.0.19"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd"
 
+[[package]]
+name = "saturating"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ece8e78b2f38ec51c51f5d475df0a7187ba5111b2a28bdc761ee05b075d40a71"
+
 [[package]]
 name = "scopeguard"
 version = "1.2.0"
@@ -2284,6 +2368,23 @@ version = "0.11.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
 
+[[package]]
+name = "symbol_table"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f19bffd69fb182e684d14e3c71d04c0ef33d1641ac0b9e81c712c734e83703bc"
+dependencies = [
+ "crossbeam-utils",
+ "foldhash",
+ "hashbrown",
+]
+
+[[package]]
+name = "symbolic_expressions"
+version = "5.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7c68d531d83ec6c531150584c42a4290911964d5f0d79132b193b67252a23b71"
+
 [[package]]
 name = "syn"
 version = "1.0.109"
@@ -2648,6 +2749,28 @@ version = "0.1.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082"
 
+[[package]]
+name = "winapi"
+version = "0.3.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
+dependencies = [
+ "winapi-i686-pc-windows-gnu",
+ "winapi-x86_64-pc-windows-gnu",
+]
+
+[[package]]
+name = "winapi-i686-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
+
+[[package]]
+name = "winapi-x86_64-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
+
 [[package]]
 name = "windows"
 version = "0.59.0"
diff --git a/hercules_opt/Cargo.toml b/hercules_opt/Cargo.toml
index 0b250d28..92e05339 100644
--- a/hercules_opt/Cargo.toml
+++ b/hercules_opt/Cargo.toml
@@ -21,4 +21,5 @@ serde = { version = "*", features = ["derive"] }
 hercules_cg = { path = "../hercules_cg" }
 hercules_ir = { path = "../hercules_ir" }
 nestify = "*"
-bimap = "*"
\ No newline at end of file
+bimap = "*"
+egg = "*"
diff --git a/hercules_opt/src/rewrite_math_expressions.rs b/hercules_opt/src/rewrite_math_expressions.rs
index 32161a79..53526a95 100644
--- a/hercules_opt/src/rewrite_math_expressions.rs
+++ b/hercules_opt/src/rewrite_math_expressions.rs
@@ -1,9 +1,39 @@
 use std::collections::{HashMap, HashSet};
+use std::fmt::{Error, Write};
 
 use hercules_ir::*;
 
+use egg::*;
+
 use crate::*;
 
+define_language! {
+    enum MathLanguage {
+        "zero" = Zero,
+        "one" = One,
+
+        ForkDim(i64),
+        "tid" = ThreadID(Id),
+        "sum" = SumReduction(Box<[Id]>),
+        "array" = Comprehension(Box<[Id]>),
+
+        "read" = Read(Box<[Id]>),
+
+        "+" = Add([Id; 2]),
+        "*" = Mul([Id; 2]),
+
+        Opaque(Symbol),
+    }
+}
+
+fn make_rules() -> Vec<Rewrite<MathLanguage, ()>> {
+    vec![
+        rewrite!("add-zero"; "(+ zero ?a)" => "?a"),
+        rewrite!("mul-zero"; "(* zero ?a)" => "zero"),
+        rewrite!("mul-one"; "(* one ?a)" => "?a"),
+    ]
+}
+
 pub fn rewrite_math_expressions(
     editor: &mut FunctionEditor,
     nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
@@ -11,7 +41,65 @@ pub fn rewrite_math_expressions(
 ) {
     for (reduce, einsum) in reduce_einsums.1.iter() {
         print!("{:?}: ", reduce);
-        debug_print_math_expr(*einsum, &reduce_einsums.0);
-        println!("");
+        let mut s = String::new();
+        egg_print_math_expr(*einsum, &reduce_einsums.0, &mut s).unwrap();
+        println!("{}", s);
+
+        let expr: RecExpr<MathLanguage> = s.parse().unwrap();
+        let runner = Runner::default().with_expr(&expr).run(&make_rules());
+        let root = runner.roots[0];
+        let extractor = Extractor::new(&runner.egraph, AstSize);
+        let (best_cost, best) = extractor.find_best(root);
+        println!("Simplified {} to {} with cost {}", expr, best, best_cost);
+    }
+}
+
+pub fn egg_print_math_expr<W: Write>(id: MathID, env: &MathEnv, w: &mut W) -> Result<(), Error> {
+    match env[id.idx()] {
+        MathExpr::Zero(_) => write!(w, "zero"),
+        MathExpr::One(_) => write!(w, "one"),
+        MathExpr::OpaqueNode(id) => write!(w, "n{}", id.idx()),
+        MathExpr::ThreadID(dim) => write!(w, "{}", dim.0),
+        MathExpr::SumReduction(id, ref dims) => {
+            write!(w, "(sum")?;
+            for dim in dims {
+                write!(w, " {}", dim.0)?;
+            }
+            write!(w, " ")?;
+            egg_print_math_expr(id, env, w)?;
+            write!(w, ")")
+        }
+        MathExpr::Comprehension(id, ref dims) => {
+            write!(w, "(array")?;
+            for dim in dims {
+                write!(w, " {}", dim.0)?;
+            }
+            write!(w, " ")?;
+            egg_print_math_expr(id, env, w)?;
+            write!(w, ")")
+        }
+        MathExpr::Read(id, ref pos) => {
+            write!(w, "(read")?;
+            for pos in pos {
+                write!(w, " ")?;
+                egg_print_math_expr(*pos, env, w)?;
+            }
+            write!(w, " ")?;
+            egg_print_math_expr(id, env, w)?;
+            write!(w, ")")
+        }
+        MathExpr::Binary(op, left, right) => {
+            write!(w, "(")?;
+            match op {
+                BinaryOperator::Add => write!(w, "+ "),
+                BinaryOperator::Mul => write!(w, "* "),
+                _ => Err(Error::default()),
+            }?;
+            egg_print_math_expr(left, env, w)?;
+            write!(w, " ")?;
+            egg_print_math_expr(right, env, w)?;
+            write!(w, ")")
+        }
+        _ => Err(Error::default()),
     }
 }
-- 
GitLab


From 0c2289d49779500adeb9c33d0ec135cdd0bfc504 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 17 Feb 2025 15:55:54 -0600
Subject: [PATCH 03/13] Identify matmul

---
 hercules_opt/src/rewrite_math_expressions.rs | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/hercules_opt/src/rewrite_math_expressions.rs b/hercules_opt/src/rewrite_math_expressions.rs
index 53526a95..a0493e54 100644
--- a/hercules_opt/src/rewrite_math_expressions.rs
+++ b/hercules_opt/src/rewrite_math_expressions.rs
@@ -22,6 +22,8 @@ define_language! {
         "+" = Add([Id; 2]),
         "*" = Mul([Id; 2]),
 
+        "library_matmul" = LibraryMatmul([Id; 2]),
+
         Opaque(Symbol),
     }
 }
@@ -31,6 +33,7 @@ fn make_rules() -> Vec<Rewrite<MathLanguage, ()>> {
         rewrite!("add-zero"; "(+ zero ?a)" => "?a"),
         rewrite!("mul-zero"; "(* zero ?a)" => "zero"),
         rewrite!("mul-one"; "(* one ?a)" => "?a"),
+        rewrite!("library-matmul"; "(array ?i ?k (sum ?j (* (read ?i ?j ?A) (read ?j ?k ?B))))" => "(library_matmul ?A ?B)"),
     ]
 }
 
@@ -40,10 +43,9 @@ pub fn rewrite_math_expressions(
     reduce_einsums: &(MathEnv, HashMap<NodeID, MathID>),
 ) {
     for (reduce, einsum) in reduce_einsums.1.iter() {
-        print!("{:?}: ", reduce);
+        println!("{:?}", reduce);
         let mut s = String::new();
         egg_print_math_expr(*einsum, &reduce_einsums.0, &mut s).unwrap();
-        println!("{}", s);
 
         let expr: RecExpr<MathLanguage> = s.parse().unwrap();
         let runner = Runner::default().with_expr(&expr).run(&make_rules());
-- 
GitLab


From f74b4e644602b3e0dbd48f3a7c2562df355c6aeb Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 17 Feb 2025 16:13:35 -0600
Subject: [PATCH 04/13] Get outer reduces first

---
 hercules_opt/src/rewrite_math_expressions.rs | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/hercules_opt/src/rewrite_math_expressions.rs b/hercules_opt/src/rewrite_math_expressions.rs
index a0493e54..2fd4bbf7 100644
--- a/hercules_opt/src/rewrite_math_expressions.rs
+++ b/hercules_opt/src/rewrite_math_expressions.rs
@@ -42,10 +42,26 @@ pub fn rewrite_math_expressions(
     nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
     reduce_einsums: &(MathEnv, HashMap<NodeID, MathID>),
 ) {
-    for (reduce, einsum) in reduce_einsums.1.iter() {
+    // Step 1: figure out how many fork-joins each reduce is in. We want to
+    // rewrite outer reductions before inner ones.
+    let nodes = &editor.func().nodes;
+    let mut depth: HashMap<NodeID, u32> = HashMap::new();
+    for (_, inside) in nodes_in_fork_joins {
+        for id in inside {
+            if nodes[id.idx()].is_reduce() {
+                *depth.entry(*id).or_default() += 1;
+            }
+        }
+    }
+    let mut reduces: Vec<(u32, NodeID)> =
+        depth.into_iter().map(|(id, depth)| (depth, id)).collect();
+    reduces.sort();
+
+    for (_, reduce) in reduces {
+        // Step 2: convert the reduce to an expression in egg and rewrite it.
         println!("{:?}", reduce);
         let mut s = String::new();
-        egg_print_math_expr(*einsum, &reduce_einsums.0, &mut s).unwrap();
+        egg_print_math_expr(reduce_einsums.1[&reduce], &reduce_einsums.0, &mut s).unwrap();
 
         let expr: RecExpr<MathLanguage> = s.parse().unwrap();
         let runner = Runner::default().with_expr(&expr).run(&make_rules());
-- 
GitLab


From 88ed34d67fd97589c926be9db48fbb6563cf8114 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 17 Feb 2025 17:31:37 -0600
Subject: [PATCH 05/13] Get the node ID arguments for matmul

---
 hercules_opt/src/rewrite_math_expressions.rs | 35 ++++++++++++++++----
 juno_scheduler/src/pm.rs                     |  7 +++-
 2 files changed, 34 insertions(+), 8 deletions(-)

diff --git a/hercules_opt/src/rewrite_math_expressions.rs b/hercules_opt/src/rewrite_math_expressions.rs
index 2fd4bbf7..e1680e43 100644
--- a/hercules_opt/src/rewrite_math_expressions.rs
+++ b/hercules_opt/src/rewrite_math_expressions.rs
@@ -39,6 +39,7 @@ fn make_rules() -> Vec<Rewrite<MathLanguage, ()>> {
 
 pub fn rewrite_math_expressions(
     editor: &mut FunctionEditor,
+    device: Device,
     nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
     reduce_einsums: &(MathEnv, HashMap<NodeID, MathID>),
 ) {
@@ -59,20 +60,40 @@ pub fn rewrite_math_expressions(
 
     for (_, reduce) in reduces {
         // Step 2: convert the reduce to an expression in egg and rewrite it.
-        println!("{:?}", reduce);
         let mut s = String::new();
         egg_print_math_expr(reduce_einsums.1[&reduce], &reduce_einsums.0, &mut s).unwrap();
 
         let expr: RecExpr<MathLanguage> = s.parse().unwrap();
-        let runner = Runner::default().with_expr(&expr).run(&make_rules());
-        let root = runner.roots[0];
-        let extractor = Extractor::new(&runner.egraph, AstSize);
-        let (best_cost, best) = extractor.find_best(root);
-        println!("Simplified {} to {} with cost {}", expr, best, best_cost);
+        let egraph = Runner::default().with_expr(&expr).run(&make_rules()).egraph;
+
+        // Step 3: match the smallest expression against patterns for known
+        // library function.
+        let matmul_pattern: Pattern<MathLanguage> = "(library_matmul ?A ?B)".parse().unwrap();
+        let mut matches = matmul_pattern.search(&egraph);
+        if matches.len() > 0
+            && let m = matches.remove(0)
+            && m.substs.len() > 0
+        {
+            let left_id = get_node_ids_from_subst("?A", &m.substs[0], &egraph);
+            let right_id = get_node_ids_from_subst("?B", &m.substs[0], &egraph);
+            println!("{:?} {:?}", left_id, right_id);
+            return;
+        }
     }
 }
 
-pub fn egg_print_math_expr<W: Write>(id: MathID, env: &MathEnv, w: &mut W) -> Result<(), Error> {
+fn get_node_ids_from_subst(var: &str, subst: &Subst, egraph: &EGraph<MathLanguage, ()>) -> NodeID {
+    let id = *subst.get(var.parse::<Var>().unwrap()).unwrap();
+    let expr = egraph.id_to_expr(id);
+    let MathLanguage::Opaque(sym) = expr.last().unwrap() else {
+        todo!();
+    };
+    let sym = sym.as_str();
+    assert!(sym.chars().nth(0).unwrap() == 'n');
+    NodeID::new(sym[1..].parse().unwrap())
+}
+
+fn egg_print_math_expr<W: Write>(id: MathID, env: &MathEnv, w: &mut W) -> Result<(), Error> {
     match env[id.idx()] {
         MathExpr::Zero(_) => write!(w, "zero"),
         MathExpr::One(_) => write!(w, "one"),
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 461bc645..fbecfc1f 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2294,7 +2294,12 @@ fn run_pass(
                 let Some(mut func) = func else {
                     continue;
                 };
-                rewrite_math_expressions(&mut func, nodes_in_fork_joins, reduce_einsums);
+                rewrite_math_expressions(
+                    &mut func,
+                    Device::LLVM,
+                    nodes_in_fork_joins,
+                    reduce_einsums,
+                );
                 changed |= func.modified();
             }
             pm.delete_gravestones();
-- 
GitLab


From b4fc0252499d01ce9ce571a8e1021afb374347f4 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 17 Feb 2025 17:45:52 -0600
Subject: [PATCH 06/13] LibraryCall node

---
 hercules_ir/src/def_use.rs   | 16 ++++++++++++----
 hercules_ir/src/ir.rs        | 15 +++++++++++++++
 hercules_ir/src/typecheck.rs |  5 +++++
 hercules_opt/src/ccp.rs      | 10 ++++++++++
 4 files changed, 42 insertions(+), 4 deletions(-)

diff --git a/hercules_ir/src/def_use.rs b/hercules_ir/src/def_use.rs
index d60f7fd5..77ef83fb 100644
--- a/hercules_ir/src/def_use.rs
+++ b/hercules_ir/src/def_use.rs
@@ -178,7 +178,12 @@ pub fn get_uses(node: &Node) -> NodeUses {
             uses.extend(args);
             NodeUses::Variable(uses.into_boxed_slice())
         }
-        Node::IntrinsicCall { intrinsic: _, args } => NodeUses::Variable(args.clone()),
+        Node::IntrinsicCall { intrinsic: _, args }
+        | Node::LibraryCall {
+            library_function: _,
+            args,
+            ty: _,
+        } => NodeUses::Variable(args.clone()),
         Node::Read { collect, indices } => {
             let mut uses = vec![];
             for index in indices.iter() {
@@ -276,9 +281,12 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> {
             uses.extend(args);
             NodeUsesMut::Variable(uses.into_boxed_slice())
         }
-        Node::IntrinsicCall { intrinsic: _, args } => {
-            NodeUsesMut::Variable(args.iter_mut().collect())
-        }
+        Node::IntrinsicCall { intrinsic: _, args }
+        | Node::LibraryCall {
+            library_function: _,
+            args,
+            ty: _,
+        } => NodeUsesMut::Variable(args.iter_mut().collect()),
         Node::Read { collect, indices } => {
             let mut uses = vec![];
             for index in indices.iter_mut() {
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index f96cc10c..5ba48629 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -222,6 +222,11 @@ pub enum Node {
         intrinsic: Intrinsic,
         args: Box<[NodeID]>,
     },
+    LibraryCall {
+        library_function: String,
+        args: Box<[NodeID]>,
+        ty: TypeID,
+    },
     Read {
         collect: NodeID,
         indices: Box<[Index]>,
@@ -1531,6 +1536,11 @@ impl Node {
                 intrinsic: _,
                 args: _,
             } => "Intrinsic",
+            Node::LibraryCall {
+                library_function: _,
+                args: _,
+                ty: _,
+            } => "Library",
             Node::Read {
                 collect: _,
                 indices: _,
@@ -1604,6 +1614,11 @@ impl Node {
                 intrinsic: _,
                 args: _,
             } => "intrinsic",
+            Node::LibraryCall {
+                library_function: _,
+                args: _,
+                ty: _,
+            } => "library",
             Node::Read {
                 collect: _,
                 indices: _,
diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs
index d01a5c58..fca258a7 100644
--- a/hercules_ir/src/typecheck.rs
+++ b/hercules_ir/src/typecheck.rs
@@ -935,6 +935,11 @@ fn typeflow(
                 }
             }
         }
+        Node::LibraryCall {
+            library_function: _,
+            args: _,
+            ty,
+        } => Concrete(*ty),
         Node::Read {
             collect: _,
             indices,
diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs
index 1969430a..a7df71f5 100644
--- a/hercules_opt/src/ccp.rs
+++ b/hercules_opt/src/ccp.rs
@@ -933,6 +933,16 @@ fn ccp_flow_function(
                 constant: new_constant,
             }
         }
+        Node::LibraryCall {
+            library_function: _,
+            args,
+            ty: _,
+        } => CCPLattice {
+            reachability: args.iter().fold(ReachabilityLattice::bottom(), |val, id| {
+                ReachabilityLattice::join(&val, &inputs[id.idx()].reachability)
+            }),
+            constant: ConstantLattice::bottom(),
+        },
         Node::Read { collect, indices } => {
             let mut reachability = inputs[collect.idx()].reachability.clone();
             for index in indices.iter() {
-- 
GitLab


From 2e3bd4d76078dff45c606ce80ad08ed912fac319 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 17 Feb 2025 18:02:02 -0600
Subject: [PATCH 07/13] Replace reduce with library call node

---
 hercules_opt/src/rewrite_math_expressions.rs | 39 +++++++++++++++++---
 hercules_opt/src/simplify_cfg.rs             |  4 +-
 juno_samples/matmul/src/gpu.sch              |  5 ++-
 juno_scheduler/src/pm.rs                     | 10 ++++-
 4 files changed, 48 insertions(+), 10 deletions(-)

diff --git a/hercules_opt/src/rewrite_math_expressions.rs b/hercules_opt/src/rewrite_math_expressions.rs
index e1680e43..1b968243 100644
--- a/hercules_opt/src/rewrite_math_expressions.rs
+++ b/hercules_opt/src/rewrite_math_expressions.rs
@@ -40,16 +40,22 @@ fn make_rules() -> Vec<Rewrite<MathLanguage, ()>> {
 pub fn rewrite_math_expressions(
     editor: &mut FunctionEditor,
     device: Device,
+    typing: &Vec<TypeID>,
+    fork_join_map: &HashMap<NodeID, NodeID>,
     nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
     reduce_einsums: &(MathEnv, HashMap<NodeID, MathID>),
 ) {
+    let join_fork_map: HashMap<_, _> = fork_join_map
+        .into_iter()
+        .map(|(fork, join)| (*join, *fork))
+        .collect();
+
     // Step 1: figure out how many fork-joins each reduce is in. We want to
     // rewrite outer reductions before inner ones.
-    let nodes = &editor.func().nodes;
     let mut depth: HashMap<NodeID, u32> = HashMap::new();
     for (_, inside) in nodes_in_fork_joins {
         for id in inside {
-            if nodes[id.idx()].is_reduce() {
+            if editor.func().nodes[id.idx()].is_reduce() {
                 *depth.entry(*id).or_default() += 1;
             }
         }
@@ -59,6 +65,9 @@ pub fn rewrite_math_expressions(
     reduces.sort();
 
     for (_, reduce) in reduces {
+        let join = editor.func().nodes[reduce.idx()].try_reduce().unwrap().0;
+        let fork = join_fork_map[&join];
+
         // Step 2: convert the reduce to an expression in egg and rewrite it.
         let mut s = String::new();
         egg_print_math_expr(reduce_einsums.1[&reduce], &reduce_einsums.0, &mut s).unwrap();
@@ -67,7 +76,7 @@ pub fn rewrite_math_expressions(
         let egraph = Runner::default().with_expr(&expr).run(&make_rules()).egraph;
 
         // Step 3: match the smallest expression against patterns for known
-        // library function.
+        // library functions.
         let matmul_pattern: Pattern<MathLanguage> = "(library_matmul ?A ?B)".parse().unwrap();
         let mut matches = matmul_pattern.search(&egraph);
         if matches.len() > 0
@@ -76,8 +85,28 @@ pub fn rewrite_math_expressions(
         {
             let left_id = get_node_ids_from_subst("?A", &m.substs[0], &egraph);
             let right_id = get_node_ids_from_subst("?B", &m.substs[0], &egraph);
-            println!("{:?} {:?}", left_id, right_id);
-            return;
+            let library_function = match device {
+                Device::LLVM => "blas_gemm",
+                Device::CUDA => todo!(),
+                _ => panic!(),
+            }
+            .to_string();
+            let args = vec![left_id, right_id].into_boxed_slice();
+            let ty = typing[reduce.idx()];
+            let call = Node::LibraryCall {
+                library_function,
+                args,
+                ty,
+            };
+            let success = editor.edit(|mut edit| {
+                let call = edit.add_node(call);
+                edit.replace_all_uses_where(reduce, call, |id| {
+                    !nodes_in_fork_joins[&fork].contains(id)
+                })
+            });
+            if success {
+                return;
+            }
         }
     }
 }
diff --git a/hercules_opt/src/simplify_cfg.rs b/hercules_opt/src/simplify_cfg.rs
index 14a152dc..cf39db2b 100644
--- a/hercules_opt/src/simplify_cfg.rs
+++ b/hercules_opt/src/simplify_cfg.rs
@@ -126,9 +126,7 @@ fn remove_useless_fork_joins(
 
     // Third, get rid of fork-joins.
     for (fork, join) in fork_join_map {
-        if editor.get_users(*join).len() == 1 {
-            assert_eq!(editor.get_users(*fork).len(), 1);
-
+        if editor.get_users(*fork).len() == 1 && editor.get_users(*join).len() == 1 {
             let fork_use = get_uses(&editor.func().nodes[fork.idx()]).as_ref()[0];
             let join_use = get_uses(&editor.func().nodes[join.idx()]).as_ref()[0];
 
diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch
index c785fd5e..39139cdb 100644
--- a/juno_samples/matmul/src/gpu.sch
+++ b/juno_samples/matmul/src/gpu.sch
@@ -13,7 +13,10 @@ fork-coalesce(*);
 infer-schedules(*);
 dce(*);
 rewrite(*);
-xdot[true](*);
+fixpoint {
+  simplify-cfg(*);
+  dce(*);
+}
 
 let out = auto-outline(*);
 gpu(out.matmul);
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index fbecfc1f..2a270070 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2281,13 +2281,19 @@ fn run_pass(
         }
         Pass::RewriteMathExpressions => {
             assert!(args.is_empty());
+            pm.make_typing();
+            pm.make_fork_join_maps();
             pm.make_nodes_in_fork_joins();
             pm.make_reduce_einsums();
+            let typing = pm.typing.take().unwrap();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
             let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
             let reduce_einsums = pm.reduce_einsums.take().unwrap();
-            for ((func, nodes_in_fork_joins), reduce_einsums) in
+            for ((((func, typing), fork_join_map), nodes_in_fork_joins), reduce_einsums) in
                 build_selection(pm, selection, false)
                     .into_iter()
+                    .zip(typing.iter())
+                    .zip(fork_join_maps.iter())
                     .zip(nodes_in_fork_joins.iter())
                     .zip(reduce_einsums.iter())
             {
@@ -2297,6 +2303,8 @@ fn run_pass(
                 rewrite_math_expressions(
                     &mut func,
                     Device::LLVM,
+                    typing,
+                    fork_join_map,
                     nodes_in_fork_joins,
                     reduce_einsums,
                 );
-- 
GitLab


From 1b67a3db5b1a6d7b175156edc448d1ddcc38ce46 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 17 Feb 2025 20:43:43 -0600
Subject: [PATCH 08/13] Add device member to LibraryCall

---
 hercules_ir/src/def_use.rs                   | 2 ++
 hercules_ir/src/dot.rs                       | 6 ++++++
 hercules_ir/src/ir.rs                        | 5 ++++-
 hercules_ir/src/typecheck.rs                 | 1 +
 hercules_opt/src/ccp.rs                      | 1 +
 hercules_opt/src/rewrite_math_expressions.rs | 8 ++------
 juno_samples/matmul/src/gpu.sch              | 3 +--
 juno_scheduler/src/pm.rs                     | 2 +-
 8 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/hercules_ir/src/def_use.rs b/hercules_ir/src/def_use.rs
index 77ef83fb..ff0e08ed 100644
--- a/hercules_ir/src/def_use.rs
+++ b/hercules_ir/src/def_use.rs
@@ -183,6 +183,7 @@ pub fn get_uses(node: &Node) -> NodeUses {
             library_function: _,
             args,
             ty: _,
+            device: _,
         } => NodeUses::Variable(args.clone()),
         Node::Read { collect, indices } => {
             let mut uses = vec![];
@@ -286,6 +287,7 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> {
             library_function: _,
             args,
             ty: _,
+            device: _,
         } => NodeUsesMut::Variable(args.iter_mut().collect()),
         Node::Read { collect, indices } => {
             let mut uses = vec![];
diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs
index 0e084085..1773f7dd 100644
--- a/hercules_ir/src/dot.rs
+++ b/hercules_ir/src/dot.rs
@@ -318,6 +318,12 @@ fn write_node<W: Write>(
         Node::IntrinsicCall { intrinsic, args: _ } => {
             write!(&mut suffix, "{}", intrinsic.lower_case_name())?
         }
+        Node::LibraryCall {
+            library_function,
+            args: _,
+            ty: _,
+            device,
+        } => write!(&mut suffix, "{} on {:?}", library_function, device)?,
         Node::Read {
             collect: _,
             indices,
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 5ba48629..5abc9adf 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -226,6 +226,7 @@ pub enum Node {
         library_function: String,
         args: Box<[NodeID]>,
         ty: TypeID,
+        device: Device,
     },
     Read {
         collect: NodeID,
@@ -341,7 +342,7 @@ pub enum Schedule {
  * The authoritative enumeration of supported backends. Multiple backends may
  * correspond to the same kind of hardware.
  */
-#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
 pub enum Device {
     LLVM,
     CUDA,
@@ -1540,6 +1541,7 @@ impl Node {
                 library_function: _,
                 args: _,
                 ty: _,
+                device: _,
             } => "Library",
             Node::Read {
                 collect: _,
@@ -1618,6 +1620,7 @@ impl Node {
                 library_function: _,
                 args: _,
                 ty: _,
+                device: _,
             } => "library",
             Node::Read {
                 collect: _,
diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs
index fca258a7..1ff890db 100644
--- a/hercules_ir/src/typecheck.rs
+++ b/hercules_ir/src/typecheck.rs
@@ -939,6 +939,7 @@ fn typeflow(
             library_function: _,
             args: _,
             ty,
+            device: _,
         } => Concrete(*ty),
         Node::Read {
             collect: _,
diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs
index a7df71f5..b626148c 100644
--- a/hercules_opt/src/ccp.rs
+++ b/hercules_opt/src/ccp.rs
@@ -937,6 +937,7 @@ fn ccp_flow_function(
             library_function: _,
             args,
             ty: _,
+            device: _,
         } => CCPLattice {
             reachability: args.iter().fold(ReachabilityLattice::bottom(), |val, id| {
                 ReachabilityLattice::join(&val, &inputs[id.idx()].reachability)
diff --git a/hercules_opt/src/rewrite_math_expressions.rs b/hercules_opt/src/rewrite_math_expressions.rs
index 1b968243..4d76c608 100644
--- a/hercules_opt/src/rewrite_math_expressions.rs
+++ b/hercules_opt/src/rewrite_math_expressions.rs
@@ -85,18 +85,14 @@ pub fn rewrite_math_expressions(
         {
             let left_id = get_node_ids_from_subst("?A", &m.substs[0], &egraph);
             let right_id = get_node_ids_from_subst("?B", &m.substs[0], &egraph);
-            let library_function = match device {
-                Device::LLVM => "blas_gemm",
-                Device::CUDA => todo!(),
-                _ => panic!(),
-            }
-            .to_string();
+            let library_function = "gemm".to_string();
             let args = vec![left_id, right_id].into_boxed_slice();
             let ty = typing[reduce.idx()];
             let call = Node::LibraryCall {
                 library_function,
                 args,
                 ty,
+                device,
             };
             let success = editor.edit(|mut edit| {
                 let call = edit.add_node(call);
diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch
index 39139cdb..35ed1e84 100644
--- a/juno_samples/matmul/src/gpu.sch
+++ b/juno_samples/matmul/src/gpu.sch
@@ -18,11 +18,10 @@ fixpoint {
   dce(*);
 }
 
-let out = auto-outline(*);
-gpu(out.matmul);
 ip-sroa(*);
 sroa(*);
 dce(*);
+xdot[true](*);
 
 float-collections(*);
 gcm(*);
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 2a270070..e8899fae 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2302,7 +2302,7 @@ fn run_pass(
                 };
                 rewrite_math_expressions(
                     &mut func,
-                    Device::LLVM,
+                    Device::CUDA,
                     typing,
                     fork_join_map,
                     nodes_in_fork_joins,
-- 
GitLab


From a8452c2f86002512eb9f956881eaf3483bfaf4ad Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 17 Feb 2025 20:56:25 -0600
Subject: [PATCH 09/13] Use enum for library call kinds

---
 hercules_ir/src/dot.rs                       |  2 +-
 hercules_ir/src/ir.rs                        | 10 +++++++++-
 hercules_opt/src/rewrite_math_expressions.rs | 19 ++++++++-----------
 3 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs
index 1773f7dd..921a813d 100644
--- a/hercules_ir/src/dot.rs
+++ b/hercules_ir/src/dot.rs
@@ -323,7 +323,7 @@ fn write_node<W: Write>(
             args: _,
             ty: _,
             device,
-        } => write!(&mut suffix, "{} on {:?}", library_function, device)?,
+        } => write!(&mut suffix, "{:?} on {:?}", library_function, device)?,
         Node::Read {
             collect: _,
             indices,
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 5abc9adf..bf9698b3 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -223,7 +223,7 @@ pub enum Node {
         args: Box<[NodeID]>,
     },
     LibraryCall {
-        library_function: String,
+        library_function: LibraryFunction,
         args: Box<[NodeID]>,
         ty: TypeID,
         device: Device,
@@ -351,6 +351,14 @@ pub enum Device {
     AsyncRust,
 }
 
+/*
+ * The authoritative enumeration of supported library calls.
+ */
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
+pub enum LibraryFunction {
+    GEMM,
+}
+
 /*
  * A single node may have multiple schedules.
  */
diff --git a/hercules_opt/src/rewrite_math_expressions.rs b/hercules_opt/src/rewrite_math_expressions.rs
index 4d76c608..6f52dc58 100644
--- a/hercules_opt/src/rewrite_math_expressions.rs
+++ b/hercules_opt/src/rewrite_math_expressions.rs
@@ -22,7 +22,7 @@ define_language! {
         "+" = Add([Id; 2]),
         "*" = Mul([Id; 2]),
 
-        "library_matmul" = LibraryMatmul([Id; 2]),
+        "library_gemm" = LibraryGemm([Id; 2]),
 
         Opaque(Symbol),
     }
@@ -33,7 +33,7 @@ fn make_rules() -> Vec<Rewrite<MathLanguage, ()>> {
         rewrite!("add-zero"; "(+ zero ?a)" => "?a"),
         rewrite!("mul-zero"; "(* zero ?a)" => "zero"),
         rewrite!("mul-one"; "(* one ?a)" => "?a"),
-        rewrite!("library-matmul"; "(array ?i ?k (sum ?j (* (read ?i ?j ?A) (read ?j ?k ?B))))" => "(library_matmul ?A ?B)"),
+        rewrite!("library-gemm"; "(array ?i ?k (sum ?j (* (read ?i ?j ?A) (read ?j ?k ?B))))" => "(library_gemm ?A ?B)"),
     ]
 }
 
@@ -65,7 +65,7 @@ pub fn rewrite_math_expressions(
     reduces.sort();
 
     for (_, reduce) in reduces {
-        let join = editor.func().nodes[reduce.idx()].try_reduce().unwrap().0;
+        let (join, init, _) = editor.func().nodes[reduce.idx()].try_reduce().unwrap();
         let fork = join_fork_map[&join];
 
         // Step 2: convert the reduce to an expression in egg and rewrite it.
@@ -77,21 +77,18 @@ pub fn rewrite_math_expressions(
 
         // Step 3: match the smallest expression against patterns for known
         // library functions.
-        let matmul_pattern: Pattern<MathLanguage> = "(library_matmul ?A ?B)".parse().unwrap();
-        let mut matches = matmul_pattern.search(&egraph);
+        let gemm_pattern: Pattern<MathLanguage> = "(library_gemm ?A ?B)".parse().unwrap();
+        let mut matches = gemm_pattern.search(&egraph);
         if matches.len() > 0
             && let m = matches.remove(0)
             && m.substs.len() > 0
         {
             let left_id = get_node_ids_from_subst("?A", &m.substs[0], &egraph);
             let right_id = get_node_ids_from_subst("?B", &m.substs[0], &egraph);
-            let library_function = "gemm".to_string();
-            let args = vec![left_id, right_id].into_boxed_slice();
-            let ty = typing[reduce.idx()];
             let call = Node::LibraryCall {
-                library_function,
-                args,
-                ty,
+                library_function: LibraryFunction::GEMM,
+                args: vec![init, left_id, right_id].into_boxed_slice(),
+                ty: typing[reduce.idx()],
                 device,
             };
             let success = editor.edit(|mut edit| {
-- 
GitLab


From 8f806736e67dd0661578464c681ed76aa60a4d99 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 17 Feb 2025 21:14:47 -0600
Subject: [PATCH 10/13] Add all the stuff related to collections for
 LibraryCall nodes

---
 hercules_ir/src/collections.rs  |  32 ++++++++-
 hercules_ir/src/device.rs       | 115 --------------------------------
 hercules_opt/src/gcm.rs         |  37 ++++++++++
 juno_samples/matmul/src/gpu.sch |   2 +-
 4 files changed, 69 insertions(+), 117 deletions(-)

diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs
index f3474ae0..6b631519 100644
--- a/hercules_ir/src/collections.rs
+++ b/hercules_ir/src/collections.rs
@@ -218,6 +218,8 @@ pub fn collection_objects(
         // - Constant: may originate an object.
         // - Call: may originate an object and may return an object passed in as
         //   a parameter.
+        // - LibraryCall: may return an object passed in as a parameter, but may
+        //   not originate an object.
         // - Read: may extract a smaller object from the input - this is
         //   considered to be the same object as the input, as no copy takes
         //   place.
@@ -288,6 +290,14 @@ pub fn collection_objects(
                     }
                     CollectionObjectLattice { objs }
                 }
+                Node::LibraryCall {
+                    library_function,
+                    args: _,
+                    ty: _,
+                    device: _,
+                } => match library_function {
+                    LibraryFunction::GEMM => inputs[0].clone(),
+                },
                 Node::Undef { ty: _ } => {
                     let obj = origins
                         .iter()
@@ -332,7 +342,13 @@ pub fn collection_objects(
                 for object in objects_per_node[idx].iter() {
                     mutated[object.idx()].push(NodeID::new(idx));
                 }
-            } else if let Some((_, callee, _, args)) = node.try_call() {
+            } else if let Node::Call {
+                control: _,
+                function: callee,
+                dynamic_constants: _,
+                args,
+            } = node
+            {
                 let fco = &collection_objects[&callee];
                 for (param_idx, arg) in args.into_iter().enumerate() {
                     // If this parameter corresponds to an object and it's
@@ -347,6 +363,20 @@ pub fn collection_objects(
                         }
                     }
                 }
+            } else if let Node::LibraryCall {
+                library_function,
+                args,
+                ty: _,
+                device: _,
+            } = node
+            {
+                match library_function {
+                    LibraryFunction::GEMM => {
+                        for object in objects_per_node[args[0].idx()].iter() {
+                            mutated[object.idx()].push(NodeID::new(idx));
+                        }
+                    }
+                }
             }
         }
 
diff --git a/hercules_ir/src/device.rs b/hercules_ir/src/device.rs
index cbf8d634..c4a5454b 100644
--- a/hercules_ir/src/device.rs
+++ b/hercules_ir/src/device.rs
@@ -23,118 +23,3 @@ pub fn device_placement(functions: &Vec<Function>, callgraph: &CallGraph) -> Vec
 
     devices
 }
-
-pub type FunctionObjectDeviceDemands = Vec<BTreeSet<Device>>;
-pub type ObjectDeviceDemands = Vec<FunctionObjectDeviceDemands>;
-
-/*
- * This analysis figures out which device each collection object may be on. At
- * first, an object may need to be on different devices at different times. This
- * is fine during optimization.
- */
-pub fn object_device_demands(
-    functions: &Vec<Function>,
-    types: &Vec<Type>,
-    typing: &ModuleTyping,
-    callgraph: &CallGraph,
-    objects: &CollectionObjects,
-    devices: &Vec<Device>,
-) -> ObjectDeviceDemands {
-    // An object is "demanded" on a device when:
-    // 1. The object is used by a primitive read node or write node in a device
-    //    function. This includes objects on the `data` input to write nodes.
-    //    Non-primitive reads don't demand an object on a device since they are
-    //    lowered to pointer math and no actual memory transfers.
-    // 2. The object is a constant / undef defined in a device function.
-    // 3. The object is passed as input to a call node where the corresponding
-    //    object in the callee is demanded on a device.
-    // 4. The object is returned from a call node where the corresponding object
-    //    in the callee is demanded on a device.
-    // Note that reads and writes in a RT function don't induce a device demand.
-    // This is because RT functions can  call device functions as necessary to
-    // arbitrarily move data onto / off of devices (though this may be slow).
-    // Traverse the functions in a module in reverse topological order, since
-    // the analysis of a function depends on all functions it calls.
-    let mut demands: ObjectDeviceDemands = vec![vec![]; functions.len()];
-    let topo = callgraph.topo();
-
-    for func_id in topo {
-        let function = &functions[func_id.idx()];
-        let typing = &typing[func_id.idx()];
-        let device = devices[func_id.idx()];
-
-        demands[func_id.idx()].resize(objects[&func_id].num_objects(), BTreeSet::new());
-        match device {
-            Device::LLVM | Device::CUDA => {
-                for (idx, node) in function.nodes.iter().enumerate() {
-                    match node {
-                        // Condition #1.
-                        Node::Read {
-                            collect,
-                            indices: _,
-                        } if types[typing[idx].idx()].is_primitive() => {
-                            for object in objects[&func_id].objects(*collect) {
-                                demands[func_id.idx()][object.idx()].insert(device);
-                            }
-                        }
-                        Node::Write {
-                            collect,
-                            data,
-                            indices: _,
-                        } => {
-                            for object in objects[&func_id]
-                                .objects(*collect)
-                                .into_iter()
-                                .chain(objects[&func_id].objects(*data).into_iter())
-                            {
-                                demands[func_id.idx()][object.idx()].insert(device);
-                            }
-                        }
-                        // Condition #2.
-                        Node::Constant { id: _ } | Node::Undef { ty: _ } => {
-                            for object in objects[&func_id].objects(NodeID::new(idx)) {
-                                demands[func_id.idx()][object.idx()].insert(device);
-                            }
-                        }
-                        _ => {}
-                    }
-                }
-            }
-            Device::AsyncRust => {
-                for (idx, node) in function.nodes.iter().enumerate() {
-                    if let Node::Call {
-                        control: _,
-                        function: callee,
-                        dynamic_constants: _,
-                        args,
-                    } = node
-                    {
-                        // Condition #3.
-                        for (param_idx, arg) in args.into_iter().enumerate() {
-                            if let Some(callee_obj) = objects[callee].param_to_object(param_idx) {
-                                let callee_demands =
-                                    take(&mut demands[callee.idx()][callee_obj.idx()]);
-                                for object in objects[&func_id].objects(*arg) {
-                                    demands[func_id.idx()][object.idx()]
-                                        .extend(callee_demands.iter());
-                                }
-                                demands[callee.idx()][callee_obj.idx()] = callee_demands;
-                            }
-                        }
-
-                        // Condition #4.
-                        for callee_obj in objects[callee].returned_objects() {
-                            let callee_demands = take(&mut demands[callee.idx()][callee_obj.idx()]);
-                            for object in objects[&func_id].objects(NodeID::new(idx)) {
-                                demands[func_id.idx()][object.idx()].extend(callee_demands.iter());
-                            }
-                            demands[callee.idx()][callee_obj.idx()] = callee_demands;
-                        }
-                    }
-                }
-            }
-        }
-    }
-
-    demands
-}
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index 821d02ea..446b3184 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -655,6 +655,14 @@ fn terminating_reads<'a>(
                 None
             }
         })),
+        Node::LibraryCall {
+            library_function,
+            ref args,
+            ty: _,
+            device: _,
+        } => match library_function {
+            LibraryFunction::GEMM => Box::new(once(args[1]).chain(once(args[2]))),
+        },
         _ => Box::new(empty()),
     }
 }
@@ -728,6 +736,16 @@ fn mutating_objects<'a>(
                 })
                 .flatten(),
         ),
+        Node::LibraryCall {
+            library_function,
+            ref args,
+            ty: _,
+            device: _,
+        } => match library_function {
+            LibraryFunction::GEMM => {
+                Box::new(objects[&func_id].objects(args[0]).into_iter().map(|id| *id))
+            }
+        },
         _ => Box::new(empty()),
     }
 }
@@ -757,6 +775,14 @@ fn mutating_writes<'a>(
                 None
             }
         })),
+        Node::LibraryCall {
+            library_function,
+            ref args,
+            ty: _,
+            device: _,
+        } => match library_function {
+            LibraryFunction::GEMM => Box::new(once(args[0])),
+        },
         _ => Box::new(empty()),
     }
 }
@@ -1311,6 +1337,17 @@ fn color_nodes(
                     }
                 }
             }
+            Node::LibraryCall {
+                library_function: _,
+                ref args,
+                ty: _,
+                device,
+            } => {
+                for arg in args {
+                    equations.push((UTerm::Node(*arg), UTerm::Device(device)));
+                }
+                equations.push((UTerm::Node(id), UTerm::Device(device)));
+            }
             _ => {}
         }
     }
diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch
index 35ed1e84..76159ef7 100644
--- a/juno_samples/matmul/src/gpu.sch
+++ b/juno_samples/matmul/src/gpu.sch
@@ -21,7 +21,7 @@ fixpoint {
 ip-sroa(*);
 sroa(*);
 dce(*);
-xdot[true](*);
 
 float-collections(*);
 gcm(*);
+xdot[true](*);
-- 
GitLab


From 3409decf7a8a04121678d2f78bcb043c255dfeb7 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 18 Feb 2025 09:50:43 -0600
Subject: [PATCH 11/13] Plumb things

---
 hercules_cg/src/rt.rs  | 72 ++++++++++++++++++++++++++++++++++++++++++
 hercules_rt/src/lib.rs | 48 ++++++++++++++++++++++++----
 2 files changed, 113 insertions(+), 7 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 4d9a6cf6..5edddd86 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -597,6 +597,59 @@ impl<'a> RTContext<'a> {
                 }
                 write!(block, "){};", postfix)?;
             }
+            Node::LibraryCall {
+                library_function,
+                ref args,
+                ty,
+                device,
+            } => match library_function {
+                LibraryFunction::GEMM => {
+                    assert_eq!(args.len(), 3);
+                    assert_eq!(self.typing[args[0].idx()], ty);
+                    let c_ty = &self.module.types[self.typing[args[0].idx()].idx()];
+                    let a_ty = &self.module.types[self.typing[args[1].idx()].idx()];
+                    let b_ty = &self.module.types[self.typing[args[2].idx()].idx()];
+                    let (
+                        Type::Array(c_elem, c_dims),
+                        Type::Array(a_elem, a_dims),
+                        Type::Array(b_elem, b_dims),
+                    ) = (c_ty, a_ty, b_ty)
+                    else {
+                        panic!();
+                    };
+                    assert_eq!(a_elem, b_elem);
+                    assert_eq!(a_elem, c_elem);
+                    assert_eq!(c_dims.len(), 2);
+                    assert_eq!(a_dims.len(), 2);
+                    assert_eq!(b_dims.len(), 2);
+                    assert_eq!(a_dims[1], b_dims[0]);
+                    assert_eq!(a_dims[0], c_dims[0]);
+                    assert_eq!(b_dims[1], c_dims[1]);
+
+                    let block = &mut blocks.get_mut(&bb).unwrap().data;
+                    let prim_ty = self.library_prim_ty(*a_elem);
+                    write!(block, "::hercules_rt::__library_{}_gemm(", device.name())?;
+                    self.codegen_dynamic_constant(a_dims[0], block)?;
+                    write!(block, ", ")?;
+                    self.codegen_dynamic_constant(a_dims[1], block)?;
+                    write!(block, ", ")?;
+                    self.codegen_dynamic_constant(b_dims[1], block)?;
+                    write!(
+                        block,
+                        ", {}.0, {}.0, {}.0, {});",
+                        self.get_value(args[0], bb, false),
+                        self.get_value(args[1], bb, false),
+                        self.get_value(args[2], bb, false),
+                        prim_ty
+                    )?;
+                    write!(
+                        block,
+                        "{} = {};",
+                        self.get_value(id, bb, true),
+                        self.get_value(args[0], bb, false)
+                    )?;
+                }
+            },
             Node::Unary { op, input } => {
                 let block = &mut blocks.get_mut(&bb).unwrap().data;
                 match op {
@@ -1316,6 +1369,25 @@ impl<'a> RTContext<'a> {
     fn get_type(&self, id: TypeID) -> &'static str {
         convert_type(&self.module.types[id.idx()])
     }
+
+    fn library_prim_ty(&self, id: TypeID) -> &'static str {
+        match self.module.types[id.idx()] {
+            Type::Boolean => "::hercules_rt::PrimTy::Bool",
+            Type::Integer8 => "::hercules_rt::PrimTy::I8",
+            Type::Integer16 => "::hercules_rt::PrimTy::I16",
+            Type::Integer32 => "::hercules_rt::PrimTy::I32",
+            Type::Integer64 => "::hercules_rt::PrimTy::I64",
+            Type::UnsignedInteger8 => "::hercules_rt::PrimTy::U8",
+            Type::UnsignedInteger16 => "::hercules_rt::PrimTy::U16",
+            Type::UnsignedInteger32 => "::hercules_rt::PrimTy::U32",
+            Type::UnsignedInteger64 => "::hercules_rt::PrimTy::U64",
+            Type::Float8 => "::hercules_rt::PrimTy::F8",
+            Type::BFloat16 => "::hercules_rt::PrimTy::BF16",
+            Type::Float32 => "::hercules_rt::PrimTy::F32",
+            Type::Float64 => "::hercules_rt::PrimTy::F64",
+            _ => panic!(),
+        }
+    }
 }
 
 fn convert_type(ty: &Type) -> &'static str {
diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs
index 841c6f44..e360076e 100644
--- a/hercules_rt/src/lib.rs
+++ b/hercules_rt/src/lib.rs
@@ -29,7 +29,10 @@ pub unsafe fn __cpu_dealloc(ptr: *mut u8, size: usize) {
         eprintln!("__cpu_dealloc: {:?}, {}", ptr, size);
         assert!(!ptr.is_null() || size == 0);
     }
-    dealloc(ptr, Layout::from_size_align(size, LARGEST_ALIGNMENT).unwrap())
+    dealloc(
+        ptr,
+        Layout::from_size_align(size, LARGEST_ALIGNMENT).unwrap(),
+    )
 }
 
 pub unsafe fn __cpu_zero_mem(ptr: *mut u8, size: usize) {
@@ -103,14 +106,45 @@ pub unsafe fn __copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize) {
     ___copy_cuda_to_cuda(dst, src, size);
 }
 
+#[repr(u8)]
+#[derive(Debug, Copy, Clone)]
+pub enum PrimTy {
+    Bool,
+    U8,
+    U16,
+    U32,
+    U64,
+    I8,
+    I16,
+    I32,
+    I64,
+    F8,
+    BF16,
+    F32,
+    F64,
+}
+
+#[cfg(feature = "cuda")]
+pub unsafe fn __library_cuda_gemm(
+    i: u64,
+    j: u64,
+    k: u64,
+    c: *mut u8,
+    a: *const u8,
+    b: *const u8,
+    ty: PrimTy,
+) {
+    panic!("{} {} {} {:?} {:?} {:?} {:?}", i, j, k, c, a, b, ty);
+}
+
 #[cfg(feature = "cuda")]
 extern "C" {
-    pub fn ___cuda_alloc(size: usize) -> *mut u8;
-    pub fn ___cuda_dealloc(ptr: *mut u8, size: usize);
-    pub fn ___cuda_zero_mem(ptr: *mut u8, size: usize);
-    pub fn ___copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
-    pub fn ___copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize);
-    pub fn ___copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
+    fn ___cuda_alloc(size: usize) -> *mut u8;
+    fn ___cuda_dealloc(ptr: *mut u8, size: usize);
+    fn ___cuda_zero_mem(ptr: *mut u8, size: usize);
+    fn ___copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
+    fn ___copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize);
+    fn ___copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
 }
 
 #[derive(Clone, Debug)]
-- 
GitLab


From 84dc38986eb6af2815adf526fa14fbab968b0000 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 18 Feb 2025 10:23:49 -0600
Subject: [PATCH 12/13] e2e matmul cublas works

---
 hercules_rt/build.rs              |  1 +
 hercules_rt/src/lib.rs            |  7 +++-
 hercules_rt/src/rtdefs.cu         | 70 +++++++++++++++++++------------
 juno_samples/matmul/src/main.rs   | 14 ++++---
 juno_samples/matmul/src/matmul.jn |  4 +-
 5 files changed, 60 insertions(+), 36 deletions(-)

diff --git a/hercules_rt/build.rs b/hercules_rt/build.rs
index 2a1538d6..ab9dda2e 100644
--- a/hercules_rt/build.rs
+++ b/hercules_rt/build.rs
@@ -28,6 +28,7 @@ fn main() {
         println!("cargo::rustc-link-search=native=/opt/cuda/lib/");
         println!("cargo::rustc-link-lib=static=rtdefs");
         println!("cargo::rustc-link-lib=cudart");
+        println!("cargo::rustc-link-lib=cublas");
         println!("cargo::rerun-if-changed=src/rtdefs.cu");
     }
 }
diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs
index e360076e..714ac7a1 100644
--- a/hercules_rt/src/lib.rs
+++ b/hercules_rt/src/lib.rs
@@ -106,7 +106,6 @@ pub unsafe fn __copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize) {
     ___copy_cuda_to_cuda(dst, src, size);
 }
 
-#[repr(u8)]
 #[derive(Debug, Copy, Clone)]
 pub enum PrimTy {
     Bool,
@@ -134,7 +133,10 @@ pub unsafe fn __library_cuda_gemm(
     b: *const u8,
     ty: PrimTy,
 ) {
-    panic!("{} {} {} {:?} {:?} {:?} {:?}", i, j, k, c, a, b, ty);
+    match ty {
+        PrimTy::F32 => ___cublas_sgemm(i, j, k, c, a, b),
+        _ => todo!(),
+    }
 }
 
 #[cfg(feature = "cuda")]
@@ -145,6 +147,7 @@ extern "C" {
     fn ___copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
     fn ___copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize);
     fn ___copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
+    fn ___cublas_sgemm(i: u64, j: u64, k: u64, c: *mut u8, a: *const u8, b: *const u8);
 }
 
 #[derive(Clone, Debug)]
diff --git a/hercules_rt/src/rtdefs.cu b/hercules_rt/src/rtdefs.cu
index 50e11fa6..26e69821 100644
--- a/hercules_rt/src/rtdefs.cu
+++ b/hercules_rt/src/rtdefs.cu
@@ -1,31 +1,49 @@
-extern "C" {
-	void *___cuda_alloc(size_t size) {
-		void *ptr = NULL;
-		cudaError_t res = cudaMalloc(&ptr, size);
-		if (res != cudaSuccess) {
-			ptr = NULL;
-		}
-		return ptr;
-	}
+#include <stdint.h>
+#include <cublas_v2.h>
 
-	void ___cuda_dealloc(void *ptr, size_t size) {
-		(void) size;
-		cudaFree(ptr);
-	}
-
-	void ___cuda_zero_mem(void *ptr, size_t size) {
-		cudaMemset(ptr, 0, size);
-	}
+static cublasHandle_t cublas_handle = 0;
 
-	void ___copy_cpu_to_cuda(void *dst, void *src, size_t size) {
-		cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice);
-	}
-
-	void ___copy_cuda_to_cpu(void *dst, void *src, size_t size) {
-		cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost);
+extern "C" {
+    void *___cuda_alloc(size_t size) {
+	void *ptr = NULL;
+	cudaError_t res = cudaMalloc(&ptr, size);
+	if (res != cudaSuccess) {
+	    ptr = NULL;
 	}
-
-	void ___copy_cuda_to_cuda(void *dst, void *src, size_t size) {
-		cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice);
+	return ptr;
+    }
+    
+    void ___cuda_dealloc(void *ptr, size_t size) {
+	(void) size;
+	cudaFree(ptr);
+    }
+    
+    void ___cuda_zero_mem(void *ptr, size_t size) {
+	cudaMemset(ptr, 0, size);
+    }
+    
+    void ___copy_cpu_to_cuda(void *dst, void *src, size_t size) {
+	cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice);
+    }
+    
+    void ___copy_cuda_to_cpu(void *dst, void *src, size_t size) {
+	cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost);
+    }
+    
+    void ___copy_cuda_to_cuda(void *dst, void *src, size_t size) {
+	cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice);
+    }
+    
+    void ___cublas_sgemm(uint64_t i, uint64_t j, uint64_t k, float *c, float *a, float *b) {
+	if (!cublas_handle) {
+	    cublasCreate(&cublas_handle);
 	}
+	float alf = 1.0;
+	float beta = 0.0;
+	cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N,
+		    k, i, j,
+		    &alf, b, k, a, j,
+		    &beta, c, k);
+	cudaDeviceSynchronize();
+    }
 }
diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs
index 3cb7d7f0..c0e228da 100644
--- a/juno_samples/matmul/src/main.rs
+++ b/juno_samples/matmul/src/main.rs
@@ -1,4 +1,5 @@
 #![feature(concat_idents)]
+use std::iter::zip;
 
 use rand::random;
 
@@ -13,9 +14,9 @@ fn main() {
         const I: usize = 256;
         const J: usize = 64;
         const K: usize = 128;
-        let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect();
-        let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect();
-        let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
+        let a: Box<[f32]> = (0..I * J).map(|_| random::<f32>()).collect();
+        let b: Box<[f32]> = (0..J * K).map(|_| random::<f32>()).collect();
+        let mut correct_c: Box<[f32]> = (0..I * K).map(|_| 0.0).collect();
         for i in 0..I {
             for k in 0..K {
                 for j in 0..J {
@@ -27,7 +28,8 @@ fn main() {
         {
             let mut r = runner!(matmul);
             let c = r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await;
-            assert_eq!(c.as_slice::<i32>(), &*correct_c);
+            let c = c.as_slice::<f32>();
+            assert_eq!(c, &*correct_c);
         }
         #[cfg(feature = "cuda")]
         {
@@ -37,9 +39,9 @@ fn main() {
             let c = r
                 .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref())
                 .await;
-            let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice();
+            let mut c_cpu: Box<[f32]> = vec![0.0; correct_c.len()].into_boxed_slice();
             c.to_cpu_ref(&mut c_cpu);
-            assert_eq!(&*c_cpu, &*correct_c);
+            assert!(zip(c_cpu, correct_c).all(|(calc, correct)| (calc - correct).abs() < 0.00001));
         }
     });
 }
diff --git a/juno_samples/matmul/src/matmul.jn b/juno_samples/matmul/src/matmul.jn
index e36d94e2..460ce41c 100644
--- a/juno_samples/matmul/src/matmul.jn
+++ b/juno_samples/matmul/src/matmul.jn
@@ -1,6 +1,6 @@
 #[entry]
-fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] {
-  let res : i32[n, l];
+fn matmul<n : usize, m : usize, l : usize>(a : f32[n, m], b : f32[m, l]) -> f32[n, l] {
+  let res : f32[n, l];
 
   @outer for i = 0 to n {
     @middle for j = 0 to l {
-- 
GitLab


From 82cd5671a37cd180159e2ca19ce69869016328d1 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 18 Feb 2025 10:33:05 -0600
Subject: [PATCH 13/13] remove xdot

---
 juno_samples/matmul/src/gpu.sch | 1 -
 1 file changed, 1 deletion(-)

diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch
index 76159ef7..76808149 100644
--- a/juno_samples/matmul/src/gpu.sch
+++ b/juno_samples/matmul/src/gpu.sch
@@ -24,4 +24,3 @@ dce(*);
 
 float-collections(*);
 gcm(*);
-xdot[true](*);
-- 
GitLab