diff --git a/Cargo.lock b/Cargo.lock
index 43eaa90984783995b495e7e059e5414506e5ef92..c438e846d87f12fd7eea2c27d73a8fc7e043836d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -23,6 +23,12 @@ version = "0.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1"
 
+[[package]]
+name = "allocator-api2"
+version = "0.2.21"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
+
 [[package]]
 name = "anstream"
 version = "0.6.18"
@@ -621,6 +627,27 @@ version = "1.0.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813"
 
+[[package]]
+name = "egg"
+version = "0.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "abb749745461743bb477fba3ef87c663d5965876155c676c9489cfe0963de5ab"
+dependencies = [
+ "env_logger",
+ "hashbrown",
+ "indexmap",
+ "log",
+ "num-bigint",
+ "num-traits",
+ "quanta",
+ "rustc-hash",
+ "saturating",
+ "smallvec",
+ "symbol_table",
+ "symbolic_expressions",
+ "thiserror",
+]
+
 [[package]]
 name = "either"
 version = "1.13.0"
@@ -639,6 +666,15 @@ version = "0.6.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d"
 
+[[package]]
+name = "env_logger"
+version = "0.9.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a12e6657c4c97ebab115a42dcee77225f7f482cdd841cf7088c657a42e9e00e7"
+dependencies = [
+ "log",
+]
+
 [[package]]
 name = "equivalent"
 version = "1.0.1"
@@ -752,6 +788,12 @@ version = "1.0.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
 
+[[package]]
+name = "foldhash"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f"
+
 [[package]]
 name = "funty"
 version = "1.1.0"
@@ -882,6 +924,11 @@ name = "hashbrown"
 version = "0.15.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289"
+dependencies = [
+ "allocator-api2",
+ "equivalent",
+ "foldhash",
+]
 
 [[package]]
 name = "heapless"
@@ -955,6 +1002,7 @@ version = "0.1.0"
 dependencies = [
  "bimap",
  "bitvec 1.0.1",
+ "egg",
  "either",
  "hercules_cg",
  "hercules_ir",
@@ -1214,7 +1262,7 @@ dependencies = [
  "async-std",
  "hercules_rt",
  "juno_build",
- "rand 0.8.5",
+ "rand 0.9.0",
  "with_builtin_macros",
 ]
 
@@ -1970,6 +2018,21 @@ dependencies = [
  "bytemuck",
 ]
 
+[[package]]
+name = "quanta"
+version = "0.12.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3bd1fe6824cea6538803de3ff1bc0cf3949024db3d43c9643024bfb33a807c0e"
+dependencies = [
+ "crossbeam-utils",
+ "libc",
+ "once_cell",
+ "raw-cpuid",
+ "wasi 0.11.0+wasi-snapshot-preview1",
+ "web-sys",
+ "winapi",
+]
+
 [[package]]
 name = "quick-error"
 version = "2.0.1"
@@ -2110,6 +2173,15 @@ dependencies = [
  "rgb",
 ]
 
+[[package]]
+name = "raw-cpuid"
+version = "11.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "529468c1335c1c03919960dfefdb1b3648858c20d7ec2d0663e728e4a717efbc"
+dependencies = [
+ "bitflags 2.8.0",
+]
+
 [[package]]
 name = "rayon"
 version = "1.10.0"
@@ -2168,6 +2240,12 @@ version = "0.8.50"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a"
 
+[[package]]
+name = "rustc-hash"
+version = "2.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
+
 [[package]]
 name = "rustc_version"
 version = "0.4.1"
@@ -2202,6 +2280,12 @@ version = "1.0.19"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd"
 
+[[package]]
+name = "saturating"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ece8e78b2f38ec51c51f5d475df0a7187ba5111b2a28bdc761ee05b075d40a71"
+
 [[package]]
 name = "scopeguard"
 version = "1.2.0"
@@ -2333,6 +2417,23 @@ version = "0.11.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
 
+[[package]]
+name = "symbol_table"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f19bffd69fb182e684d14e3c71d04c0ef33d1641ac0b9e81c712c734e83703bc"
+dependencies = [
+ "crossbeam-utils",
+ "foldhash",
+ "hashbrown",
+]
+
+[[package]]
+name = "symbolic_expressions"
+version = "5.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7c68d531d83ec6c531150584c42a4290911964d5f0d79132b193b67252a23b71"
+
 [[package]]
 name = "syn"
 version = "1.0.109"
@@ -2697,6 +2798,28 @@ version = "0.1.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082"
 
+[[package]]
+name = "winapi"
+version = "0.3.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
+dependencies = [
+ "winapi-i686-pc-windows-gnu",
+ "winapi-x86_64-pc-windows-gnu",
+]
+
+[[package]]
+name = "winapi-i686-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
+
+[[package]]
+name = "winapi-x86_64-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
+
 [[package]]
 name = "windows"
 version = "0.59.0"
diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 4d9a6cf63605362cd1ddb96a8486974431cc791c..5edddd86df2901c29109148e2b107c0b923dbd0e 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -597,6 +597,59 @@ impl<'a> RTContext<'a> {
                 }
                 write!(block, "){};", postfix)?;
             }
+            Node::LibraryCall {
+                library_function,
+                ref args,
+                ty,
+                device,
+            } => match library_function {
+                LibraryFunction::GEMM => {
+                    assert_eq!(args.len(), 3);
+                    assert_eq!(self.typing[args[0].idx()], ty);
+                    let c_ty = &self.module.types[self.typing[args[0].idx()].idx()];
+                    let a_ty = &self.module.types[self.typing[args[1].idx()].idx()];
+                    let b_ty = &self.module.types[self.typing[args[2].idx()].idx()];
+                    let (
+                        Type::Array(c_elem, c_dims),
+                        Type::Array(a_elem, a_dims),
+                        Type::Array(b_elem, b_dims),
+                    ) = (c_ty, a_ty, b_ty)
+                    else {
+                        panic!();
+                    };
+                    assert_eq!(a_elem, b_elem);
+                    assert_eq!(a_elem, c_elem);
+                    assert_eq!(c_dims.len(), 2);
+                    assert_eq!(a_dims.len(), 2);
+                    assert_eq!(b_dims.len(), 2);
+                    assert_eq!(a_dims[1], b_dims[0]);
+                    assert_eq!(a_dims[0], c_dims[0]);
+                    assert_eq!(b_dims[1], c_dims[1]);
+
+                    let block = &mut blocks.get_mut(&bb).unwrap().data;
+                    let prim_ty = self.library_prim_ty(*a_elem);
+                    write!(block, "::hercules_rt::__library_{}_gemm(", device.name())?;
+                    self.codegen_dynamic_constant(a_dims[0], block)?;
+                    write!(block, ", ")?;
+                    self.codegen_dynamic_constant(a_dims[1], block)?;
+                    write!(block, ", ")?;
+                    self.codegen_dynamic_constant(b_dims[1], block)?;
+                    write!(
+                        block,
+                        ", {}.0, {}.0, {}.0, {});",
+                        self.get_value(args[0], bb, false),
+                        self.get_value(args[1], bb, false),
+                        self.get_value(args[2], bb, false),
+                        prim_ty
+                    )?;
+                    write!(
+                        block,
+                        "{} = {};",
+                        self.get_value(id, bb, true),
+                        self.get_value(args[0], bb, false)
+                    )?;
+                }
+            },
             Node::Unary { op, input } => {
                 let block = &mut blocks.get_mut(&bb).unwrap().data;
                 match op {
@@ -1316,6 +1369,25 @@ impl<'a> RTContext<'a> {
     fn get_type(&self, id: TypeID) -> &'static str {
         convert_type(&self.module.types[id.idx()])
     }
+
+    fn library_prim_ty(&self, id: TypeID) -> &'static str {
+        match self.module.types[id.idx()] {
+            Type::Boolean => "::hercules_rt::PrimTy::Bool",
+            Type::Integer8 => "::hercules_rt::PrimTy::I8",
+            Type::Integer16 => "::hercules_rt::PrimTy::I16",
+            Type::Integer32 => "::hercules_rt::PrimTy::I32",
+            Type::Integer64 => "::hercules_rt::PrimTy::I64",
+            Type::UnsignedInteger8 => "::hercules_rt::PrimTy::U8",
+            Type::UnsignedInteger16 => "::hercules_rt::PrimTy::U16",
+            Type::UnsignedInteger32 => "::hercules_rt::PrimTy::U32",
+            Type::UnsignedInteger64 => "::hercules_rt::PrimTy::U64",
+            Type::Float8 => "::hercules_rt::PrimTy::F8",
+            Type::BFloat16 => "::hercules_rt::PrimTy::BF16",
+            Type::Float32 => "::hercules_rt::PrimTy::F32",
+            Type::Float64 => "::hercules_rt::PrimTy::F64",
+            _ => panic!(),
+        }
+    }
 }
 
 fn convert_type(ty: &Type) -> &'static str {
diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs
index f3474ae069c98c7d7cd2cd46bdbcc7a2719c0c56..6b631519d69cf2a548164eaad58cb2574c6b70c3 100644
--- a/hercules_ir/src/collections.rs
+++ b/hercules_ir/src/collections.rs
@@ -218,6 +218,8 @@ pub fn collection_objects(
         // - Constant: may originate an object.
         // - Call: may originate an object and may return an object passed in as
         //   a parameter.
+        // - LibraryCall: may return an object passed in as a parameter, but may
+        //   not originate an object.
         // - Read: may extract a smaller object from the input - this is
         //   considered to be the same object as the input, as no copy takes
         //   place.
@@ -288,6 +290,14 @@ pub fn collection_objects(
                     }
                     CollectionObjectLattice { objs }
                 }
+                Node::LibraryCall {
+                    library_function,
+                    args: _,
+                    ty: _,
+                    device: _,
+                } => match library_function {
+                    LibraryFunction::GEMM => inputs[0].clone(),
+                },
                 Node::Undef { ty: _ } => {
                     let obj = origins
                         .iter()
@@ -332,7 +342,13 @@ pub fn collection_objects(
                 for object in objects_per_node[idx].iter() {
                     mutated[object.idx()].push(NodeID::new(idx));
                 }
-            } else if let Some((_, callee, _, args)) = node.try_call() {
+            } else if let Node::Call {
+                control: _,
+                function: callee,
+                dynamic_constants: _,
+                args,
+            } = node
+            {
                 let fco = &collection_objects[&callee];
                 for (param_idx, arg) in args.into_iter().enumerate() {
                     // If this parameter corresponds to an object and it's
@@ -347,6 +363,20 @@ pub fn collection_objects(
                         }
                     }
                 }
+            } else if let Node::LibraryCall {
+                library_function,
+                args,
+                ty: _,
+                device: _,
+            } = node
+            {
+                match library_function {
+                    LibraryFunction::GEMM => {
+                        for object in objects_per_node[args[0].idx()].iter() {
+                            mutated[object.idx()].push(NodeID::new(idx));
+                        }
+                    }
+                }
             }
         }
 
diff --git a/hercules_ir/src/def_use.rs b/hercules_ir/src/def_use.rs
index d60f7fd5a84605cc31d22347cf7797b22e55a3cd..ff0e08edc8c15f76ea38243eebea7cba1940cd19 100644
--- a/hercules_ir/src/def_use.rs
+++ b/hercules_ir/src/def_use.rs
@@ -178,7 +178,13 @@ pub fn get_uses(node: &Node) -> NodeUses {
             uses.extend(args);
             NodeUses::Variable(uses.into_boxed_slice())
         }
-        Node::IntrinsicCall { intrinsic: _, args } => NodeUses::Variable(args.clone()),
+        Node::IntrinsicCall { intrinsic: _, args }
+        | Node::LibraryCall {
+            library_function: _,
+            args,
+            ty: _,
+            device: _,
+        } => NodeUses::Variable(args.clone()),
         Node::Read { collect, indices } => {
             let mut uses = vec![];
             for index in indices.iter() {
@@ -276,9 +282,13 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> {
             uses.extend(args);
             NodeUsesMut::Variable(uses.into_boxed_slice())
         }
-        Node::IntrinsicCall { intrinsic: _, args } => {
-            NodeUsesMut::Variable(args.iter_mut().collect())
-        }
+        Node::IntrinsicCall { intrinsic: _, args }
+        | Node::LibraryCall {
+            library_function: _,
+            args,
+            ty: _,
+            device: _,
+        } => NodeUsesMut::Variable(args.iter_mut().collect()),
         Node::Read { collect, indices } => {
             let mut uses = vec![];
             for index in indices.iter_mut() {
diff --git a/hercules_ir/src/device.rs b/hercules_ir/src/device.rs
index cbf8d6349857b6110cc4a9e03f574b358646b29a..c4a5454bc9ea3daaadc602c0c7517d8310c7be93 100644
--- a/hercules_ir/src/device.rs
+++ b/hercules_ir/src/device.rs
@@ -23,118 +23,3 @@ pub fn device_placement(functions: &Vec<Function>, callgraph: &CallGraph) -> Vec
 
     devices
 }
-
-pub type FunctionObjectDeviceDemands = Vec<BTreeSet<Device>>;
-pub type ObjectDeviceDemands = Vec<FunctionObjectDeviceDemands>;
-
-/*
- * This analysis figures out which device each collection object may be on. At
- * first, an object may need to be on different devices at different times. This
- * is fine during optimization.
- */
-pub fn object_device_demands(
-    functions: &Vec<Function>,
-    types: &Vec<Type>,
-    typing: &ModuleTyping,
-    callgraph: &CallGraph,
-    objects: &CollectionObjects,
-    devices: &Vec<Device>,
-) -> ObjectDeviceDemands {
-    // An object is "demanded" on a device when:
-    // 1. The object is used by a primitive read node or write node in a device
-    //    function. This includes objects on the `data` input to write nodes.
-    //    Non-primitive reads don't demand an object on a device since they are
-    //    lowered to pointer math and no actual memory transfers.
-    // 2. The object is a constant / undef defined in a device function.
-    // 3. The object is passed as input to a call node where the corresponding
-    //    object in the callee is demanded on a device.
-    // 4. The object is returned from a call node where the corresponding object
-    //    in the callee is demanded on a device.
-    // Note that reads and writes in a RT function don't induce a device demand.
-    // This is because RT functions can  call device functions as necessary to
-    // arbitrarily move data onto / off of devices (though this may be slow).
-    // Traverse the functions in a module in reverse topological order, since
-    // the analysis of a function depends on all functions it calls.
-    let mut demands: ObjectDeviceDemands = vec![vec![]; functions.len()];
-    let topo = callgraph.topo();
-
-    for func_id in topo {
-        let function = &functions[func_id.idx()];
-        let typing = &typing[func_id.idx()];
-        let device = devices[func_id.idx()];
-
-        demands[func_id.idx()].resize(objects[&func_id].num_objects(), BTreeSet::new());
-        match device {
-            Device::LLVM | Device::CUDA => {
-                for (idx, node) in function.nodes.iter().enumerate() {
-                    match node {
-                        // Condition #1.
-                        Node::Read {
-                            collect,
-                            indices: _,
-                        } if types[typing[idx].idx()].is_primitive() => {
-                            for object in objects[&func_id].objects(*collect) {
-                                demands[func_id.idx()][object.idx()].insert(device);
-                            }
-                        }
-                        Node::Write {
-                            collect,
-                            data,
-                            indices: _,
-                        } => {
-                            for object in objects[&func_id]
-                                .objects(*collect)
-                                .into_iter()
-                                .chain(objects[&func_id].objects(*data).into_iter())
-                            {
-                                demands[func_id.idx()][object.idx()].insert(device);
-                            }
-                        }
-                        // Condition #2.
-                        Node::Constant { id: _ } | Node::Undef { ty: _ } => {
-                            for object in objects[&func_id].objects(NodeID::new(idx)) {
-                                demands[func_id.idx()][object.idx()].insert(device);
-                            }
-                        }
-                        _ => {}
-                    }
-                }
-            }
-            Device::AsyncRust => {
-                for (idx, node) in function.nodes.iter().enumerate() {
-                    if let Node::Call {
-                        control: _,
-                        function: callee,
-                        dynamic_constants: _,
-                        args,
-                    } = node
-                    {
-                        // Condition #3.
-                        for (param_idx, arg) in args.into_iter().enumerate() {
-                            if let Some(callee_obj) = objects[callee].param_to_object(param_idx) {
-                                let callee_demands =
-                                    take(&mut demands[callee.idx()][callee_obj.idx()]);
-                                for object in objects[&func_id].objects(*arg) {
-                                    demands[func_id.idx()][object.idx()]
-                                        .extend(callee_demands.iter());
-                                }
-                                demands[callee.idx()][callee_obj.idx()] = callee_demands;
-                            }
-                        }
-
-                        // Condition #4.
-                        for callee_obj in objects[callee].returned_objects() {
-                            let callee_demands = take(&mut demands[callee.idx()][callee_obj.idx()]);
-                            for object in objects[&func_id].objects(NodeID::new(idx)) {
-                                demands[func_id.idx()][object.idx()].extend(callee_demands.iter());
-                            }
-                            demands[callee.idx()][callee_obj.idx()] = callee_demands;
-                        }
-                    }
-                }
-            }
-        }
-    }
-
-    demands
-}
diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs
index 0e0840853bdcb83caa954991f40397b33e1fd5eb..921a813d7c9d680defe2e9dd8fa750ddf960d0cb 100644
--- a/hercules_ir/src/dot.rs
+++ b/hercules_ir/src/dot.rs
@@ -318,6 +318,12 @@ fn write_node<W: Write>(
         Node::IntrinsicCall { intrinsic, args: _ } => {
             write!(&mut suffix, "{}", intrinsic.lower_case_name())?
         }
+        Node::LibraryCall {
+            library_function,
+            args: _,
+            ty: _,
+            device,
+        } => write!(&mut suffix, "{:?} on {:?}", library_function, device)?,
         Node::Read {
             collect: _,
             indices,
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index f96cc10c003315aaab022d69aaad32838ec53cf3..bf9698b32125832b047ee9a94668bd0a09b93ac9 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -222,6 +222,12 @@ pub enum Node {
         intrinsic: Intrinsic,
         args: Box<[NodeID]>,
     },
+    LibraryCall {
+        library_function: LibraryFunction,
+        args: Box<[NodeID]>,
+        ty: TypeID,
+        device: Device,
+    },
     Read {
         collect: NodeID,
         indices: Box<[Index]>,
@@ -336,7 +342,7 @@ pub enum Schedule {
  * The authoritative enumeration of supported backends. Multiple backends may
  * correspond to the same kind of hardware.
  */
-#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
 pub enum Device {
     LLVM,
     CUDA,
@@ -345,6 +351,14 @@ pub enum Device {
     AsyncRust,
 }
 
+/*
+ * The authoritative enumeration of supported library calls.
+ */
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
+pub enum LibraryFunction {
+    GEMM,
+}
+
 /*
  * A single node may have multiple schedules.
  */
@@ -1531,6 +1545,12 @@ impl Node {
                 intrinsic: _,
                 args: _,
             } => "Intrinsic",
+            Node::LibraryCall {
+                library_function: _,
+                args: _,
+                ty: _,
+                device: _,
+            } => "Library",
             Node::Read {
                 collect: _,
                 indices: _,
@@ -1604,6 +1624,12 @@ impl Node {
                 intrinsic: _,
                 args: _,
             } => "intrinsic",
+            Node::LibraryCall {
+                library_function: _,
+                args: _,
+                ty: _,
+                device: _,
+            } => "library",
             Node::Read {
                 collect: _,
                 indices: _,
diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs
index d01a5c58d8de15ba291a75a6f1b0528433e36d2c..1ff890db4ff4ed3678be0338106f4add2b829aba 100644
--- a/hercules_ir/src/typecheck.rs
+++ b/hercules_ir/src/typecheck.rs
@@ -935,6 +935,12 @@ fn typeflow(
                 }
             }
         }
+        Node::LibraryCall {
+            library_function: _,
+            args: _,
+            ty,
+            device: _,
+        } => Concrete(*ty),
         Node::Read {
             collect: _,
             indices,
diff --git a/hercules_opt/Cargo.toml b/hercules_opt/Cargo.toml
index 0b250d28d9460bd0b19f574c27916e82802664c8..92e0533938eab0cc76fe4b69a205f4dceae2f244 100644
--- a/hercules_opt/Cargo.toml
+++ b/hercules_opt/Cargo.toml
@@ -21,4 +21,5 @@ serde = { version = "*", features = ["derive"] }
 hercules_cg = { path = "../hercules_cg" }
 hercules_ir = { path = "../hercules_ir" }
 nestify = "*"
-bimap = "*"
\ No newline at end of file
+bimap = "*"
+egg = "*"
diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs
index 1969430a95ea020e9d1fb87d4165ea219d3f96a9..b626148c936e384b6bc0a9aaf951c35a9c4b4736 100644
--- a/hercules_opt/src/ccp.rs
+++ b/hercules_opt/src/ccp.rs
@@ -933,6 +933,17 @@ fn ccp_flow_function(
                 constant: new_constant,
             }
         }
+        Node::LibraryCall {
+            library_function: _,
+            args,
+            ty: _,
+            device: _,
+        } => CCPLattice {
+            reachability: args.iter().fold(ReachabilityLattice::bottom(), |val, id| {
+                ReachabilityLattice::join(&val, &inputs[id.idx()].reachability)
+            }),
+            constant: ConstantLattice::bottom(),
+        },
         Node::Read { collect, indices } => {
             let mut reachability = inputs[collect.idx()].reachability.clone();
             for index in indices.iter() {
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index 821d02ea886fdcbd5ffa35ad9a298128a92dbb07..446b31849b20fd3fa3d5361a74c15f0514d1dbad 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -655,6 +655,14 @@ fn terminating_reads<'a>(
                 None
             }
         })),
+        Node::LibraryCall {
+            library_function,
+            ref args,
+            ty: _,
+            device: _,
+        } => match library_function {
+            LibraryFunction::GEMM => Box::new(once(args[1]).chain(once(args[2]))),
+        },
         _ => Box::new(empty()),
     }
 }
@@ -728,6 +736,16 @@ fn mutating_objects<'a>(
                 })
                 .flatten(),
         ),
+        Node::LibraryCall {
+            library_function,
+            ref args,
+            ty: _,
+            device: _,
+        } => match library_function {
+            LibraryFunction::GEMM => {
+                Box::new(objects[&func_id].objects(args[0]).into_iter().map(|id| *id))
+            }
+        },
         _ => Box::new(empty()),
     }
 }
@@ -757,6 +775,14 @@ fn mutating_writes<'a>(
                 None
             }
         })),
+        Node::LibraryCall {
+            library_function,
+            ref args,
+            ty: _,
+            device: _,
+        } => match library_function {
+            LibraryFunction::GEMM => Box::new(once(args[0])),
+        },
         _ => Box::new(empty()),
     }
 }
@@ -1311,6 +1337,17 @@ fn color_nodes(
                     }
                 }
             }
+            Node::LibraryCall {
+                library_function: _,
+                ref args,
+                ty: _,
+                device,
+            } => {
+                for arg in args {
+                    equations.push((UTerm::Node(*arg), UTerm::Device(device)));
+                }
+                equations.push((UTerm::Node(id), UTerm::Device(device)));
+            }
             _ => {}
         }
     }
diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs
index b56f94086204444b49672ee6baefccdaa0b0cb8b..b25449e7bd143e82c6d8b2fb39f2d8e238bd5cfa 100644
--- a/hercules_opt/src/lib.rs
+++ b/hercules_opt/src/lib.rs
@@ -21,6 +21,7 @@ pub mod outline;
 pub mod phi_elim;
 pub mod pred;
 pub mod reuse_products;
+pub mod rewrite_math_expressions;
 pub mod schedule;
 pub mod simplify_cfg;
 pub mod slf;
@@ -49,6 +50,7 @@ pub use crate::outline::*;
 pub use crate::phi_elim::*;
 pub use crate::pred::*;
 pub use crate::reuse_products::*;
+pub use crate::rewrite_math_expressions::*;
 pub use crate::schedule::*;
 pub use crate::simplify_cfg::*;
 pub use crate::slf::*;
diff --git a/hercules_opt/src/rewrite_math_expressions.rs b/hercules_opt/src/rewrite_math_expressions.rs
new file mode 100644
index 0000000000000000000000000000000000000000..6f52dc58ee32751b96190b42a0add01d42ce3b1e
--- /dev/null
+++ b/hercules_opt/src/rewrite_math_expressions.rs
@@ -0,0 +1,166 @@
+use std::collections::{HashMap, HashSet};
+use std::fmt::{Error, Write};
+
+use hercules_ir::*;
+
+use egg::*;
+
+use crate::*;
+
+define_language! {
+    enum MathLanguage {
+        "zero" = Zero,
+        "one" = One,
+
+        ForkDim(i64),
+        "tid" = ThreadID(Id),
+        "sum" = SumReduction(Box<[Id]>),
+        "array" = Comprehension(Box<[Id]>),
+
+        "read" = Read(Box<[Id]>),
+
+        "+" = Add([Id; 2]),
+        "*" = Mul([Id; 2]),
+
+        "library_gemm" = LibraryGemm([Id; 2]),
+
+        Opaque(Symbol),
+    }
+}
+
+fn make_rules() -> Vec<Rewrite<MathLanguage, ()>> {
+    vec![
+        rewrite!("add-zero"; "(+ zero ?a)" => "?a"),
+        rewrite!("mul-zero"; "(* zero ?a)" => "zero"),
+        rewrite!("mul-one"; "(* one ?a)" => "?a"),
+        rewrite!("library-gemm"; "(array ?i ?k (sum ?j (* (read ?i ?j ?A) (read ?j ?k ?B))))" => "(library_gemm ?A ?B)"),
+    ]
+}
+
+pub fn rewrite_math_expressions(
+    editor: &mut FunctionEditor,
+    device: Device,
+    typing: &Vec<TypeID>,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+    reduce_einsums: &(MathEnv, HashMap<NodeID, MathID>),
+) {
+    let join_fork_map: HashMap<_, _> = fork_join_map
+        .into_iter()
+        .map(|(fork, join)| (*join, *fork))
+        .collect();
+
+    // Step 1: figure out how many fork-joins each reduce is in. We want to
+    // rewrite outer reductions before inner ones.
+    let mut depth: HashMap<NodeID, u32> = HashMap::new();
+    for (_, inside) in nodes_in_fork_joins {
+        for id in inside {
+            if editor.func().nodes[id.idx()].is_reduce() {
+                *depth.entry(*id).or_default() += 1;
+            }
+        }
+    }
+    let mut reduces: Vec<(u32, NodeID)> =
+        depth.into_iter().map(|(id, depth)| (depth, id)).collect();
+    reduces.sort();
+
+    for (_, reduce) in reduces {
+        let (join, init, _) = editor.func().nodes[reduce.idx()].try_reduce().unwrap();
+        let fork = join_fork_map[&join];
+
+        // Step 2: convert the reduce to an expression in egg and rewrite it.
+        let mut s = String::new();
+        egg_print_math_expr(reduce_einsums.1[&reduce], &reduce_einsums.0, &mut s).unwrap();
+
+        let expr: RecExpr<MathLanguage> = s.parse().unwrap();
+        let egraph = Runner::default().with_expr(&expr).run(&make_rules()).egraph;
+
+        // Step 3: match the smallest expression against patterns for known
+        // library functions.
+        let gemm_pattern: Pattern<MathLanguage> = "(library_gemm ?A ?B)".parse().unwrap();
+        let mut matches = gemm_pattern.search(&egraph);
+        if matches.len() > 0
+            && let m = matches.remove(0)
+            && m.substs.len() > 0
+        {
+            let left_id = get_node_ids_from_subst("?A", &m.substs[0], &egraph);
+            let right_id = get_node_ids_from_subst("?B", &m.substs[0], &egraph);
+            let call = Node::LibraryCall {
+                library_function: LibraryFunction::GEMM,
+                args: vec![init, left_id, right_id].into_boxed_slice(),
+                ty: typing[reduce.idx()],
+                device,
+            };
+            let success = editor.edit(|mut edit| {
+                let call = edit.add_node(call);
+                edit.replace_all_uses_where(reduce, call, |id| {
+                    !nodes_in_fork_joins[&fork].contains(id)
+                })
+            });
+            if success {
+                return;
+            }
+        }
+    }
+}
+
+fn get_node_ids_from_subst(var: &str, subst: &Subst, egraph: &EGraph<MathLanguage, ()>) -> NodeID {
+    let id = *subst.get(var.parse::<Var>().unwrap()).unwrap();
+    let expr = egraph.id_to_expr(id);
+    let MathLanguage::Opaque(sym) = expr.last().unwrap() else {
+        todo!();
+    };
+    let sym = sym.as_str();
+    assert!(sym.chars().nth(0).unwrap() == 'n');
+    NodeID::new(sym[1..].parse().unwrap())
+}
+
+fn egg_print_math_expr<W: Write>(id: MathID, env: &MathEnv, w: &mut W) -> Result<(), Error> {
+    match env[id.idx()] {
+        MathExpr::Zero(_) => write!(w, "zero"),
+        MathExpr::One(_) => write!(w, "one"),
+        MathExpr::OpaqueNode(id) => write!(w, "n{}", id.idx()),
+        MathExpr::ThreadID(dim) => write!(w, "{}", dim.0),
+        MathExpr::SumReduction(id, ref dims) => {
+            write!(w, "(sum")?;
+            for dim in dims {
+                write!(w, " {}", dim.0)?;
+            }
+            write!(w, " ")?;
+            egg_print_math_expr(id, env, w)?;
+            write!(w, ")")
+        }
+        MathExpr::Comprehension(id, ref dims) => {
+            write!(w, "(array")?;
+            for dim in dims {
+                write!(w, " {}", dim.0)?;
+            }
+            write!(w, " ")?;
+            egg_print_math_expr(id, env, w)?;
+            write!(w, ")")
+        }
+        MathExpr::Read(id, ref pos) => {
+            write!(w, "(read")?;
+            for pos in pos {
+                write!(w, " ")?;
+                egg_print_math_expr(*pos, env, w)?;
+            }
+            write!(w, " ")?;
+            egg_print_math_expr(id, env, w)?;
+            write!(w, ")")
+        }
+        MathExpr::Binary(op, left, right) => {
+            write!(w, "(")?;
+            match op {
+                BinaryOperator::Add => write!(w, "+ "),
+                BinaryOperator::Mul => write!(w, "* "),
+                _ => Err(Error::default()),
+            }?;
+            egg_print_math_expr(left, env, w)?;
+            write!(w, " ")?;
+            egg_print_math_expr(right, env, w)?;
+            write!(w, ")")
+        }
+        _ => Err(Error::default()),
+    }
+}
diff --git a/hercules_opt/src/simplify_cfg.rs b/hercules_opt/src/simplify_cfg.rs
index 14a152dcb3dae335e5f9724bdb1f467cca5e1c19..cf39db2b35283da7eb641101478df68f858cb8c3 100644
--- a/hercules_opt/src/simplify_cfg.rs
+++ b/hercules_opt/src/simplify_cfg.rs
@@ -126,9 +126,7 @@ fn remove_useless_fork_joins(
 
     // Third, get rid of fork-joins.
     for (fork, join) in fork_join_map {
-        if editor.get_users(*join).len() == 1 {
-            assert_eq!(editor.get_users(*fork).len(), 1);
-
+        if editor.get_users(*fork).len() == 1 && editor.get_users(*join).len() == 1 {
             let fork_use = get_uses(&editor.func().nodes[fork.idx()]).as_ref()[0];
             let join_use = get_uses(&editor.func().nodes[join.idx()]).as_ref()[0];
 
diff --git a/hercules_rt/build.rs b/hercules_rt/build.rs
index 2a1538d600e1632e9311ef9464c38a83fdf44415..ab9dda2e0948e40d3c0bf226a855d56e273289c6 100644
--- a/hercules_rt/build.rs
+++ b/hercules_rt/build.rs
@@ -28,6 +28,7 @@ fn main() {
         println!("cargo::rustc-link-search=native=/opt/cuda/lib/");
         println!("cargo::rustc-link-lib=static=rtdefs");
         println!("cargo::rustc-link-lib=cudart");
+        println!("cargo::rustc-link-lib=cublas");
         println!("cargo::rerun-if-changed=src/rtdefs.cu");
     }
 }
diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs
index 79b5cbac72422e82e53ac2e74dc15798968ea155..419a760fa49647c28295625bd7df9db0e87707c1 100644
--- a/hercules_rt/src/lib.rs
+++ b/hercules_rt/src/lib.rs
@@ -29,7 +29,10 @@ pub unsafe fn __cpu_dealloc(ptr: *mut u8, size: usize) {
         eprintln!("__cpu_dealloc: {:?}, {}", ptr, size);
         assert!(!ptr.is_null() || size == 0);
     }
-    dealloc(ptr, Layout::from_size_align(size, LARGEST_ALIGNMENT).unwrap())
+    dealloc(
+        ptr,
+        Layout::from_size_align(size, LARGEST_ALIGNMENT).unwrap(),
+    )
 }
 
 pub unsafe fn __cpu_zero_mem(ptr: *mut u8, size: usize) {
@@ -103,14 +106,48 @@ pub unsafe fn __copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize) {
     ___copy_cuda_to_cuda(dst, src, size);
 }
 
+#[derive(Debug, Copy, Clone)]
+pub enum PrimTy {
+    Bool,
+    U8,
+    U16,
+    U32,
+    U64,
+    I8,
+    I16,
+    I32,
+    I64,
+    F8,
+    BF16,
+    F32,
+    F64,
+}
+
+#[cfg(feature = "cuda")]
+pub unsafe fn __library_cuda_gemm(
+    i: u64,
+    j: u64,
+    k: u64,
+    c: *mut u8,
+    a: *const u8,
+    b: *const u8,
+    ty: PrimTy,
+) {
+    match ty {
+        PrimTy::F32 => ___cublas_sgemm(i, j, k, c, a, b),
+        _ => todo!(),
+    }
+}
+
 #[cfg(feature = "cuda")]
 extern "C" {
-    pub fn ___cuda_alloc(size: usize) -> *mut u8;
-    pub fn ___cuda_dealloc(ptr: *mut u8, size: usize);
-    pub fn ___cuda_zero_mem(ptr: *mut u8, size: usize);
-    pub fn ___copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
-    pub fn ___copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize);
-    pub fn ___copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
+    fn ___cuda_alloc(size: usize) -> *mut u8;
+    fn ___cuda_dealloc(ptr: *mut u8, size: usize);
+    fn ___cuda_zero_mem(ptr: *mut u8, size: usize);
+    fn ___copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
+    fn ___copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize);
+    fn ___copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
+    fn ___cublas_sgemm(i: u64, j: u64, k: u64, c: *mut u8, a: *const u8, b: *const u8);
 }
 
 #[derive(Clone, Debug)]
diff --git a/hercules_rt/src/rtdefs.cu b/hercules_rt/src/rtdefs.cu
index 50e11fa6748d8569b511b226ca460f5ba1855e8d..26e698218b92b88cbae013daeb3bea3db8f4f5ce 100644
--- a/hercules_rt/src/rtdefs.cu
+++ b/hercules_rt/src/rtdefs.cu
@@ -1,31 +1,49 @@
-extern "C" {
-	void *___cuda_alloc(size_t size) {
-		void *ptr = NULL;
-		cudaError_t res = cudaMalloc(&ptr, size);
-		if (res != cudaSuccess) {
-			ptr = NULL;
-		}
-		return ptr;
-	}
+#include <stdint.h>
+#include <cublas_v2.h>
 
-	void ___cuda_dealloc(void *ptr, size_t size) {
-		(void) size;
-		cudaFree(ptr);
-	}
-
-	void ___cuda_zero_mem(void *ptr, size_t size) {
-		cudaMemset(ptr, 0, size);
-	}
+static cublasHandle_t cublas_handle = 0;
 
-	void ___copy_cpu_to_cuda(void *dst, void *src, size_t size) {
-		cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice);
-	}
-
-	void ___copy_cuda_to_cpu(void *dst, void *src, size_t size) {
-		cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost);
+extern "C" {
+    void *___cuda_alloc(size_t size) {
+	void *ptr = NULL;
+	cudaError_t res = cudaMalloc(&ptr, size);
+	if (res != cudaSuccess) {
+	    ptr = NULL;
 	}
-
-	void ___copy_cuda_to_cuda(void *dst, void *src, size_t size) {
-		cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice);
+	return ptr;
+    }
+    
+    void ___cuda_dealloc(void *ptr, size_t size) {
+	(void) size;
+	cudaFree(ptr);
+    }
+    
+    void ___cuda_zero_mem(void *ptr, size_t size) {
+	cudaMemset(ptr, 0, size);
+    }
+    
+    void ___copy_cpu_to_cuda(void *dst, void *src, size_t size) {
+	cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice);
+    }
+    
+    void ___copy_cuda_to_cpu(void *dst, void *src, size_t size) {
+	cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost);
+    }
+    
+    void ___copy_cuda_to_cuda(void *dst, void *src, size_t size) {
+	cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice);
+    }
+    
+    void ___cublas_sgemm(uint64_t i, uint64_t j, uint64_t k, float *c, float *a, float *b) {
+	if (!cublas_handle) {
+	    cublasCreate(&cublas_handle);
 	}
+	float alf = 1.0;
+	float beta = 0.0;
+	cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N,
+		    k, i, j,
+		    &alf, b, k, a, j,
+		    &beta, c, k);
+	cudaDeviceSynchronize();
+    }
 }
diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch
index edb83d74a54bd72c0207923b82485eff8905cbb6..76808149e7b99dce97e43f0e936536d3d13b7417 100644
--- a/juno_samples/matmul/src/gpu.sch
+++ b/juno_samples/matmul/src/gpu.sch
@@ -12,9 +12,12 @@ fixpoint {
 fork-coalesce(*);
 infer-schedules(*);
 dce(*);
+rewrite(*);
+fixpoint {
+  simplify-cfg(*);
+  dce(*);
+}
 
-let out = auto-outline(*);
-gpu(out.matmul);
 ip-sroa(*);
 sroa(*);
 dce(*);
diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs
index 3cb7d7f0f35e847b7c88bf4ee8dc2b6536d30647..c0e228daa04704b90156a592d27b761aeba6591c 100644
--- a/juno_samples/matmul/src/main.rs
+++ b/juno_samples/matmul/src/main.rs
@@ -1,4 +1,5 @@
 #![feature(concat_idents)]
+use std::iter::zip;
 
 use rand::random;
 
@@ -13,9 +14,9 @@ fn main() {
         const I: usize = 256;
         const J: usize = 64;
         const K: usize = 128;
-        let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect();
-        let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect();
-        let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
+        let a: Box<[f32]> = (0..I * J).map(|_| random::<f32>()).collect();
+        let b: Box<[f32]> = (0..J * K).map(|_| random::<f32>()).collect();
+        let mut correct_c: Box<[f32]> = (0..I * K).map(|_| 0.0).collect();
         for i in 0..I {
             for k in 0..K {
                 for j in 0..J {
@@ -27,7 +28,8 @@ fn main() {
         {
             let mut r = runner!(matmul);
             let c = r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await;
-            assert_eq!(c.as_slice::<i32>(), &*correct_c);
+            let c = c.as_slice::<f32>();
+            assert_eq!(c, &*correct_c);
         }
         #[cfg(feature = "cuda")]
         {
@@ -37,9 +39,9 @@ fn main() {
             let c = r
                 .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref())
                 .await;
-            let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice();
+            let mut c_cpu: Box<[f32]> = vec![0.0; correct_c.len()].into_boxed_slice();
             c.to_cpu_ref(&mut c_cpu);
-            assert_eq!(&*c_cpu, &*correct_c);
+            assert!(zip(c_cpu, correct_c).all(|(calc, correct)| (calc - correct).abs() < 0.00001));
         }
     });
 }
diff --git a/juno_samples/matmul/src/matmul.jn b/juno_samples/matmul/src/matmul.jn
index e36d94e209a2385fc28e54b1cfe4c432564d3706..460ce41c3742486100062b52942870fb35379081 100644
--- a/juno_samples/matmul/src/matmul.jn
+++ b/juno_samples/matmul/src/matmul.jn
@@ -1,6 +1,6 @@
 #[entry]
-fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] {
-  let res : i32[n, l];
+fn matmul<n : usize, m : usize, l : usize>(a : f32[n, m], b : f32[m, l]) -> f32[n, l] {
+  let res : f32[n, l];
 
   @outer for i = 0 to n {
     @middle for j = 0 to l {
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 43871c908fbb809e9795d727125454fa4a3f80ff..e9132fd20650d1c06a9b13f8a8f0815f16f83337 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -141,6 +141,9 @@ impl FromStr for Appliable {
             "reduce-slf" => Ok(Appliable::Pass(ir::Pass::ReduceSLF)),
             "rename" => Ok(Appliable::Pass(ir::Pass::Rename)),
             "reuse-products" => Ok(Appliable::Pass(ir::Pass::ReuseProducts)),
+            "rewrite" | "rewrite-math" | "rewrite-math-expressions" => {
+                Ok(Appliable::Pass(ir::Pass::RewriteMathExpressions))
+            }
             "simplify-cfg" => Ok(Appliable::Pass(ir::Pass::SimplifyCFG)),
             "slf" | "store-load-forward" => Ok(Appliable::Pass(ir::Pass::SLF)),
             "sroa" => Ok(Appliable::Pass(ir::Pass::SROA)),
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index 378e1c49b67e31d1c83a876489b4880ed62b1121..4ac5a732d12f87b6eb76c2771c2c5addfa42cb50 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -34,6 +34,7 @@ pub enum Pass {
     ReduceSLF,
     Rename,
     ReuseProducts,
+    RewriteMathExpressions,
     SLF,
     SROA,
     Serialize,
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index aa5a8bf13f04b9d8dfb98ce7f4bbc39e1b86a920..d5c0af27e280d523683cec2b39da387a3f2c3f45 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2279,6 +2279,40 @@ fn run_pass(
             pm.delete_gravestones();
             pm.clear_analyses();
         }
+        Pass::RewriteMathExpressions => {
+            assert!(args.is_empty());
+            pm.make_typing();
+            pm.make_fork_join_maps();
+            pm.make_nodes_in_fork_joins();
+            pm.make_reduce_einsums();
+            let typing = pm.typing.take().unwrap();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+            let reduce_einsums = pm.reduce_einsums.take().unwrap();
+            for ((((func, typing), fork_join_map), nodes_in_fork_joins), reduce_einsums) in
+                build_selection(pm, selection, false)
+                    .into_iter()
+                    .zip(typing.iter())
+                    .zip(fork_join_maps.iter())
+                    .zip(nodes_in_fork_joins.iter())
+                    .zip(reduce_einsums.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                rewrite_math_expressions(
+                    &mut func,
+                    Device::CUDA,
+                    typing,
+                    fork_join_map,
+                    nodes_in_fork_joins,
+                    reduce_einsums,
+                );
+                changed |= func.modified();
+            }
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::SLF => {
             assert!(args.is_empty());
             pm.make_reverse_postorders();