From 84dc38986eb6af2815adf526fa14fbab968b0000 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Tue, 18 Feb 2025 10:23:49 -0600 Subject: [PATCH] e2e matmul cublas works --- hercules_rt/build.rs | 1 + hercules_rt/src/lib.rs | 7 +++- hercules_rt/src/rtdefs.cu | 70 +++++++++++++++++++------------ juno_samples/matmul/src/main.rs | 14 ++++--- juno_samples/matmul/src/matmul.jn | 4 +- 5 files changed, 60 insertions(+), 36 deletions(-) diff --git a/hercules_rt/build.rs b/hercules_rt/build.rs index 2a1538d6..ab9dda2e 100644 --- a/hercules_rt/build.rs +++ b/hercules_rt/build.rs @@ -28,6 +28,7 @@ fn main() { println!("cargo::rustc-link-search=native=/opt/cuda/lib/"); println!("cargo::rustc-link-lib=static=rtdefs"); println!("cargo::rustc-link-lib=cudart"); + println!("cargo::rustc-link-lib=cublas"); println!("cargo::rerun-if-changed=src/rtdefs.cu"); } } diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs index e360076e..714ac7a1 100644 --- a/hercules_rt/src/lib.rs +++ b/hercules_rt/src/lib.rs @@ -106,7 +106,6 @@ pub unsafe fn __copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize) { ___copy_cuda_to_cuda(dst, src, size); } -#[repr(u8)] #[derive(Debug, Copy, Clone)] pub enum PrimTy { Bool, @@ -134,7 +133,10 @@ pub unsafe fn __library_cuda_gemm( b: *const u8, ty: PrimTy, ) { - panic!("{} {} {} {:?} {:?} {:?} {:?}", i, j, k, c, a, b, ty); + match ty { + PrimTy::F32 => ___cublas_sgemm(i, j, k, c, a, b), + _ => todo!(), + } } #[cfg(feature = "cuda")] @@ -145,6 +147,7 @@ extern "C" { fn ___copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize); fn ___copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize); fn ___copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize); + fn ___cublas_sgemm(i: u64, j: u64, k: u64, c: *mut u8, a: *const u8, b: *const u8); } #[derive(Clone, Debug)] diff --git a/hercules_rt/src/rtdefs.cu b/hercules_rt/src/rtdefs.cu index 50e11fa6..26e69821 100644 --- a/hercules_rt/src/rtdefs.cu +++ b/hercules_rt/src/rtdefs.cu @@ -1,31 +1,49 @@ -extern "C" { - void *___cuda_alloc(size_t size) { - void *ptr = NULL; - cudaError_t res = cudaMalloc(&ptr, size); - if (res != cudaSuccess) { - ptr = NULL; - } - return ptr; - } +#include <stdint.h> +#include <cublas_v2.h> - void ___cuda_dealloc(void *ptr, size_t size) { - (void) size; - cudaFree(ptr); - } - - void ___cuda_zero_mem(void *ptr, size_t size) { - cudaMemset(ptr, 0, size); - } +static cublasHandle_t cublas_handle = 0; - void ___copy_cpu_to_cuda(void *dst, void *src, size_t size) { - cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice); - } - - void ___copy_cuda_to_cpu(void *dst, void *src, size_t size) { - cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost); +extern "C" { + void *___cuda_alloc(size_t size) { + void *ptr = NULL; + cudaError_t res = cudaMalloc(&ptr, size); + if (res != cudaSuccess) { + ptr = NULL; } - - void ___copy_cuda_to_cuda(void *dst, void *src, size_t size) { - cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice); + return ptr; + } + + void ___cuda_dealloc(void *ptr, size_t size) { + (void) size; + cudaFree(ptr); + } + + void ___cuda_zero_mem(void *ptr, size_t size) { + cudaMemset(ptr, 0, size); + } + + void ___copy_cpu_to_cuda(void *dst, void *src, size_t size) { + cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice); + } + + void ___copy_cuda_to_cpu(void *dst, void *src, size_t size) { + cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost); + } + + void ___copy_cuda_to_cuda(void *dst, void *src, size_t size) { + cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice); + } + + void ___cublas_sgemm(uint64_t i, uint64_t j, uint64_t k, float *c, float *a, float *b) { + if (!cublas_handle) { + cublasCreate(&cublas_handle); } + float alf = 1.0; + float beta = 0.0; + cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, + k, i, j, + &alf, b, k, a, j, + &beta, c, k); + cudaDeviceSynchronize(); + } } diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs index 3cb7d7f0..c0e228da 100644 --- a/juno_samples/matmul/src/main.rs +++ b/juno_samples/matmul/src/main.rs @@ -1,4 +1,5 @@ #![feature(concat_idents)] +use std::iter::zip; use rand::random; @@ -13,9 +14,9 @@ fn main() { const I: usize = 256; const J: usize = 64; const K: usize = 128; - let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect(); - let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect(); - let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect(); + let a: Box<[f32]> = (0..I * J).map(|_| random::<f32>()).collect(); + let b: Box<[f32]> = (0..J * K).map(|_| random::<f32>()).collect(); + let mut correct_c: Box<[f32]> = (0..I * K).map(|_| 0.0).collect(); for i in 0..I { for k in 0..K { for j in 0..J { @@ -27,7 +28,8 @@ fn main() { { let mut r = runner!(matmul); let c = r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await; - assert_eq!(c.as_slice::<i32>(), &*correct_c); + let c = c.as_slice::<f32>(); + assert_eq!(c, &*correct_c); } #[cfg(feature = "cuda")] { @@ -37,9 +39,9 @@ fn main() { let c = r .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()) .await; - let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice(); + let mut c_cpu: Box<[f32]> = vec![0.0; correct_c.len()].into_boxed_slice(); c.to_cpu_ref(&mut c_cpu); - assert_eq!(&*c_cpu, &*correct_c); + assert!(zip(c_cpu, correct_c).all(|(calc, correct)| (calc - correct).abs() < 0.00001)); } }); } diff --git a/juno_samples/matmul/src/matmul.jn b/juno_samples/matmul/src/matmul.jn index e36d94e2..460ce41c 100644 --- a/juno_samples/matmul/src/matmul.jn +++ b/juno_samples/matmul/src/matmul.jn @@ -1,6 +1,6 @@ #[entry] -fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] { - let res : i32[n, l]; +fn matmul<n : usize, m : usize, l : usize>(a : f32[n, m], b : f32[m, l]) -> f32[n, l] { + let res : f32[n, l]; @outer for i = 0 to n { @middle for j = 0 to l { -- GitLab