diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 4d9a6cf63605362cd1ddb96a8486974431cc791c..5edddd86df2901c29109148e2b107c0b923dbd0e 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -597,6 +597,59 @@ impl<'a> RTContext<'a> {
                 }
                 write!(block, "){};", postfix)?;
             }
+            Node::LibraryCall {
+                library_function,
+                ref args,
+                ty,
+                device,
+            } => match library_function {
+                LibraryFunction::GEMM => {
+                    assert_eq!(args.len(), 3, "GEMM takes exactly three arguments (C, A, B)");
+                    assert_eq!(self.typing[args[0].idx()], ty);
+                    let c_ty = &self.module.types[self.typing[args[0].idx()].idx()];
+                    let a_ty = &self.module.types[self.typing[args[1].idx()].idx()];
+                    let b_ty = &self.module.types[self.typing[args[2].idx()].idx()];
+                    let (
+                        Type::Array(c_elem, c_dims),
+                        Type::Array(a_elem, a_dims),
+                        Type::Array(b_elem, b_dims),
+                    ) = (c_ty, a_ty, b_ty)
+                    else {
+                        panic!("GEMM library call: all three arguments must have array types");
+                    };
+                    assert_eq!(a_elem, b_elem, "GEMM: A and B element types differ");
+                    assert_eq!(a_elem, c_elem, "GEMM: A and C element types differ");
+                    assert_eq!(c_dims.len(), 2, "GEMM: C must be two-dimensional");
+                    assert_eq!(a_dims.len(), 2, "GEMM: A must be two-dimensional");
+                    assert_eq!(b_dims.len(), 2, "GEMM: B must be two-dimensional");
+                    assert_eq!(a_dims[1], b_dims[0], "GEMM: A dim 1 must equal B dim 0");
+                    assert_eq!(a_dims[0], c_dims[0], "GEMM: A dim 0 must equal C dim 0");
+                    assert_eq!(b_dims[1], c_dims[1], "GEMM: B dim 1 must equal C dim 1");
+                    // Call the runtime GEMM, then bind this node's value to argument 0.
+                    let block = &mut blocks.get_mut(&bb).unwrap().data;
+                    let prim_ty = self.library_prim_ty(*a_elem);
+                    write!(block, "::hercules_rt::__library_{}_gemm(", device.name())?;
+                    self.codegen_dynamic_constant(a_dims[0], block)?;
+                    write!(block, ", ")?;
+                    self.codegen_dynamic_constant(a_dims[1], block)?;
+                    write!(block, ", ")?;
+                    self.codegen_dynamic_constant(b_dims[1], block)?;
+                    write!(
+                        block,
+                        ", {}.0, {}.0, {}.0, {});",
+                        self.get_value(args[0], bb, false),
+                        self.get_value(args[1], bb, false),
+                        self.get_value(args[2], bb, false),
+                        prim_ty
+                    )?;
+                    write!(
+                        block,
+                        "{} = {};",
+                        self.get_value(id, bb, true),
+                        self.get_value(args[0], bb, false)
+                    )?;
+                }
+            },
             Node::Unary { op, input } => {
                 let block = &mut blocks.get_mut(&bb).unwrap().data;
                 match op {
@@ -1316,6 +1369,25 @@ impl<'a> RTContext<'a> {
     fn get_type(&self, id: TypeID) -> &'static str {
         convert_type(&self.module.types[id.idx()])
     }
+
+    fn library_prim_ty(&self, id: TypeID) -> &'static str {
+        match self.module.types[id.idx()] { // result is a path spliced into generated Rust
+            Type::Boolean => "::hercules_rt::PrimTy::Bool",
+            Type::Integer8 => "::hercules_rt::PrimTy::I8",
+            Type::Integer16 => "::hercules_rt::PrimTy::I16",
+            Type::Integer32 => "::hercules_rt::PrimTy::I32",
+            Type::Integer64 => "::hercules_rt::PrimTy::I64",
+            Type::UnsignedInteger8 => "::hercules_rt::PrimTy::U8",
+            Type::UnsignedInteger16 => "::hercules_rt::PrimTy::U16",
+            Type::UnsignedInteger32 => "::hercules_rt::PrimTy::U32",
+            Type::UnsignedInteger64 => "::hercules_rt::PrimTy::U64",
+            Type::Float8 => "::hercules_rt::PrimTy::F8",
+            Type::BFloat16 => "::hercules_rt::PrimTy::BF16",
+            Type::Float32 => "::hercules_rt::PrimTy::F32",
+            Type::Float64 => "::hercules_rt::PrimTy::F64",
+            _ => panic!("type #{} has no ::hercules_rt::PrimTy equivalent", id.idx()),
+        }
+    }
 }
 
 fn convert_type(ty: &Type) -> &'static str {
diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs
index 841c6f44dc004fd333410f77d6c1b60e34dbda62..e360076eb2840e315503ca74cac2ef99ab2abe18 100644
--- a/hercules_rt/src/lib.rs
+++ b/hercules_rt/src/lib.rs
@@ -29,7 +29,10 @@ pub unsafe fn __cpu_dealloc(ptr: *mut u8, size: usize) {
         eprintln!("__cpu_dealloc: {:?}, {}", ptr, size);
         assert!(!ptr.is_null() || size == 0);
     }
-    dealloc(ptr, Layout::from_size_align(size, LARGEST_ALIGNMENT).unwrap())
+    dealloc(
+        ptr,
+        Layout::from_size_align(size, LARGEST_ALIGNMENT).unwrap(),
+    )
 }
 
 pub unsafe fn __cpu_zero_mem(ptr: *mut u8, size: usize) {
@@ -103,14 +106,45 @@ pub unsafe fn __copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize) {
     ___copy_cuda_to_cuda(dst, src, size);
 }
 
+#[repr(u8)] // element-type tag handed to the `__library_*` entry points by generated code
+#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
+pub enum PrimTy {
+    Bool,
+    U8,
+    U16,
+    U32,
+    U64,
+    I8,
+    I16,
+    I32,
+    I64,
+    F8,
+    BF16,
+    F32,
+    F64,
+}
+
+#[cfg(feature = "cuda")]
+pub unsafe fn __library_cuda_gemm(
+    i: u64,
+    j: u64,
+    k: u64,
+    c: *mut u8,
+    a: *const u8,
+    b: *const u8,
+    ty: PrimTy,
+) {
+    todo!("CUDA GEMM: i={i} j={j} k={k} c={c:?} a={a:?} b={b:?} ty={ty:?}");
+}
+
 #[cfg(feature = "cuda")]
 extern "C" {
-    pub fn ___cuda_alloc(size: usize) -> *mut u8;
-    pub fn ___cuda_dealloc(ptr: *mut u8, size: usize);
-    pub fn ___cuda_zero_mem(ptr: *mut u8, size: usize);
-    pub fn ___copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
-    pub fn ___copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize);
-    pub fn ___copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
+    fn ___cuda_alloc(size: usize) -> *mut u8; // NOTE(review): now private; check other crates
+    fn ___cuda_dealloc(ptr: *mut u8, size: usize);
+    fn ___cuda_zero_mem(ptr: *mut u8, size: usize);
+    fn ___copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
+    fn ___copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize);
+    fn ___copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
 }
 
 #[derive(Clone, Debug)]