Skip to content
Snippets Groups Projects
Commit 84dc3898 authored by Russel Arbore
Browse files

e2e matmul cublas works

parent 3409decf
No related branches found
No related tags found
1 merge request: !187 Identify and lower library functions
Pipeline #201739 passed
......@@ -28,6 +28,7 @@ fn main() {
println!("cargo::rustc-link-search=native=/opt/cuda/lib/");
println!("cargo::rustc-link-lib=static=rtdefs");
println!("cargo::rustc-link-lib=cudart");
println!("cargo::rustc-link-lib=cublas");
println!("cargo::rerun-if-changed=src/rtdefs.cu");
}
}
......@@ -106,7 +106,6 @@ pub unsafe fn __copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize) {
___copy_cuda_to_cuda(dst, src, size);
}
#[repr(u8)]
#[derive(Debug, Copy, Clone)]
pub enum PrimTy {
Bool,
......@@ -134,7 +133,10 @@ pub unsafe fn __library_cuda_gemm(
b: *const u8,
ty: PrimTy,
) {
panic!("{} {} {} {:?} {:?} {:?} {:?}", i, j, k, c, a, b, ty);
match ty {
PrimTy::F32 => ___cublas_sgemm(i, j, k, c, a, b),
_ => todo!(),
}
}
#[cfg(feature = "cuda")]
......@@ -145,6 +147,7 @@ extern "C" {
fn ___copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
fn ___copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize);
fn ___copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
fn ___cublas_sgemm(i: u64, j: u64, k: u64, c: *mut u8, a: *const u8, b: *const u8);
}
#[derive(Clone, Debug)]
......
extern "C" {
/*
 * Allocate `size` bytes through cudaMalloc.
 *
 * Returns the pointer written by cudaMalloc when it reports cudaSuccess,
 * and NULL for any other status.
 */
void *___cuda_alloc(size_t size) {
    void *device_ptr = NULL;
    if (cudaMalloc(&device_ptr, size) == cudaSuccess) {
        return device_ptr;
    }
    // Any non-success status is reported to the caller as a NULL pointer.
    return NULL;
}
#include <stdint.h>
#include <cublas_v2.h>
/*
 * Release `ptr` with cudaFree. The status returned by cudaFree is discarded.
 * `size` is accepted but not used.
 */
void ___cuda_dealloc(void *ptr, size_t size) {
    cudaFree(ptr);
    (void) size;  // silence the unused-parameter warning
}
/*
 * Fill `size` bytes starting at `ptr` with zero via cudaMemset.
 * The status returned by cudaMemset is discarded.
 */
void ___cuda_zero_mem(void *ptr, size_t size) {
    const int fill_byte = 0;
    cudaMemset(ptr, fill_byte, size);
}
static cublasHandle_t cublas_handle = 0;
/*
 * Copy `size` bytes from `src` to `dst` with cudaMemcpy in the
 * host-to-device direction. The returned status is discarded.
 */
void ___copy_cpu_to_cuda(void *dst, void *src, size_t size) {
    const enum cudaMemcpyKind direction = cudaMemcpyHostToDevice;
    cudaMemcpy(dst, src, size, direction);
}
void ___copy_cuda_to_cpu(void *dst, void *src, size_t size) {
cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost);
extern "C" {
void *___cuda_alloc(size_t size) {
void *ptr = NULL;
cudaError_t res = cudaMalloc(&ptr, size);
if (res != cudaSuccess) {
ptr = NULL;
}
void ___copy_cuda_to_cuda(void *dst, void *src, size_t size) {
cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice);
return ptr;
}
/*
 * Release `ptr` with cudaFree. The status returned by cudaFree is discarded.
 * `size` is accepted but not used.
 */
void ___cuda_dealloc(void *ptr, size_t size) {
    cudaFree(ptr);
    (void) size;  // silence the unused-parameter warning
}
/*
 * Fill `size` bytes starting at `ptr` with zero via cudaMemset.
 * The status returned by cudaMemset is discarded.
 */
void ___cuda_zero_mem(void *ptr, size_t size) {
    const int fill_byte = 0;
    cudaMemset(ptr, fill_byte, size);
}
/*
 * Copy `size` bytes from `src` to `dst` with cudaMemcpy in the
 * host-to-device direction. The returned status is discarded.
 */
void ___copy_cpu_to_cuda(void *dst, void *src, size_t size) {
    const enum cudaMemcpyKind direction = cudaMemcpyHostToDevice;
    cudaMemcpy(dst, src, size, direction);
}
/*
 * Copy `size` bytes from `src` to `dst` with cudaMemcpy in the
 * device-to-host direction. The returned status is discarded.
 */
void ___copy_cuda_to_cpu(void *dst, void *src, size_t size) {
    const enum cudaMemcpyKind direction = cudaMemcpyDeviceToHost;
    cudaMemcpy(dst, src, size, direction);
}
/*
 * Copy `size` bytes from `src` to `dst` with cudaMemcpy in the
 * device-to-device direction. The returned status is discarded.
 */
void ___copy_cuda_to_cuda(void *dst, void *src, size_t size) {
    const enum cudaMemcpyKind direction = cudaMemcpyDeviceToDevice;
    cudaMemcpy(dst, src, size, direction);
}
/*
 * Single-precision matrix multiply through cuBLAS.
 *
 * The cublasSgemm call passes `b` as the first operand and `a` as the second,
 * with m = k, n = i, inner dimension j, and leading dimensions k, j, k.
 * cuBLAS works on column-major data, so this looks like the usual operand
 * swap for computing c = a * b on row-major buffers -- presumably a is i x j,
 * b is j x k and c is i x k, all device pointers (TODO confirm with caller).
 *
 * The cuBLAS handle is created on first use and cached in `cublas_handle`.
 * If handle creation fails the function returns without launching the GEMM,
 * instead of calling cublasSgemm with an invalid handle. There is no return
 * value, so the failure is not reported to the caller.
 *
 * Blocks until the device has finished (cudaDeviceSynchronize).
 */
void ___cublas_sgemm(uint64_t i, uint64_t j, uint64_t k, float *c, float *a, float *b) {
    if (!cublas_handle) {
        if (cublasCreate(&cublas_handle) != CUBLAS_STATUS_SUCCESS) {
            // Keep the cached handle cleared so a later call retries creation.
            cublas_handle = 0;
            return;
        }
    }

    // cublasSgemm takes int dimensions; make the narrowing from uint64_t explicit.
    const int dim_i = (int) i;
    const int dim_j = (int) j;
    const int dim_k = (int) k;

    // alpha = 1, beta = 0: c is overwritten with the product, not accumulated into.
    const float alpha = 1.0f;
    const float beta = 0.0f;

    cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N,
                dim_k, dim_i, dim_j,
                &alpha, b, dim_k, a, dim_j,
                &beta, c, dim_k);
    cudaDeviceSynchronize();
}
}
#![feature(concat_idents)]
use std::iter::zip;
use rand::random;
......@@ -13,9 +14,9 @@ fn main() {
const I: usize = 256;
const J: usize = 64;
const K: usize = 128;
let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect();
let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect();
let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
let a: Box<[f32]> = (0..I * J).map(|_| random::<f32>()).collect();
let b: Box<[f32]> = (0..J * K).map(|_| random::<f32>()).collect();
let mut correct_c: Box<[f32]> = (0..I * K).map(|_| 0.0).collect();
for i in 0..I {
for k in 0..K {
for j in 0..J {
......@@ -27,7 +28,8 @@ fn main() {
{
let mut r = runner!(matmul);
let c = r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await;
assert_eq!(c.as_slice::<i32>(), &*correct_c);
let c = c.as_slice::<f32>();
assert_eq!(c, &*correct_c);
}
#[cfg(feature = "cuda")]
{
......@@ -37,9 +39,9 @@ fn main() {
let c = r
.run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref())
.await;
let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice();
let mut c_cpu: Box<[f32]> = vec![0.0; correct_c.len()].into_boxed_slice();
c.to_cpu_ref(&mut c_cpu);
assert_eq!(&*c_cpu, &*correct_c);
assert!(zip(c_cpu, correct_c).all(|(calc, correct)| (calc - correct).abs() < 0.00001));
}
});
}
......
#[entry]
fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] {
let res : i32[n, l];
fn matmul<n : usize, m : usize, l : usize>(a : f32[n, m], b : f32[m, l]) -> f32[n, l] {
let res : f32[n, l];
@outer for i = 0 to n {
@middle for j = 0 to l {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment