diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 4d9a6cf63605362cd1ddb96a8486974431cc791c..5edddd86df2901c29109148e2b107c0b923dbd0e 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -597,6 +597,59 @@ impl<'a> RTContext<'a> {
                 }
                 write!(block, "){};", postfix)?;
             }
+            Node::LibraryCall {
+                library_function,
+                ref args,
+                ty,
+                device,
+            } => match library_function {
+                LibraryFunction::GEMM => {
+                    assert_eq!(args.len(), 3);
+                    assert_eq!(self.typing[args[0].idx()], ty);
+                    let c_ty = &self.module.types[self.typing[args[0].idx()].idx()];
+                    let a_ty = &self.module.types[self.typing[args[1].idx()].idx()];
+                    let b_ty = &self.module.types[self.typing[args[2].idx()].idx()];
+                    let (
+                        Type::Array(c_elem, c_dims),
+                        Type::Array(a_elem, a_dims),
+                        Type::Array(b_elem, b_dims),
+                    ) = (c_ty, a_ty, b_ty)
+                    else {
+                        panic!();
+                    };
+                    assert_eq!(a_elem, b_elem);
+                    assert_eq!(a_elem, c_elem);
+                    assert_eq!(c_dims.len(), 2);
+                    assert_eq!(a_dims.len(), 2);
+                    assert_eq!(b_dims.len(), 2);
+                    assert_eq!(a_dims[1], b_dims[0]);
+                    assert_eq!(a_dims[0], c_dims[0]);
+                    assert_eq!(b_dims[1], c_dims[1]);
+
+                    let block = &mut blocks.get_mut(&bb).unwrap().data;
+                    let prim_ty = self.library_prim_ty(*a_elem);
+                    write!(block, "::hercules_rt::__library_{}_gemm(", device.name())?;
+                    self.codegen_dynamic_constant(a_dims[0], block)?;
+                    write!(block, ", ")?;
+                    self.codegen_dynamic_constant(a_dims[1], block)?;
+                    write!(block, ", ")?;
+                    self.codegen_dynamic_constant(b_dims[1], block)?;
+                    write!(
+                        block,
+                        ", {}.0, {}.0, {}.0, {});",
+                        self.get_value(args[0], bb, false),
+                        self.get_value(args[1], bb, false),
+                        self.get_value(args[2], bb, false),
+                        prim_ty
+                    )?;
+                    write!(
+                        block,
+                        "{} = {};",
+                        self.get_value(id, bb, true),
+                        self.get_value(args[0], bb, false)
+                    )?;
+                }
+            },
             Node::Unary { op, input } => {
                 let block = &mut blocks.get_mut(&bb).unwrap().data;
                 match op {
@@ -1316,6 +1369,25 @@ impl<'a> RTContext<'a> {
     fn get_type(&self, id: TypeID) -> &'static str {
         convert_type(&self.module.types[id.idx()])
     }
+
+    fn library_prim_ty(&self, id: TypeID) -> &'static str {
+        match self.module.types[id.idx()] {
+            Type::Boolean => "::hercules_rt::PrimTy::Bool",
+            Type::Integer8 => "::hercules_rt::PrimTy::I8",
+            Type::Integer16 => "::hercules_rt::PrimTy::I16",
+            Type::Integer32 => "::hercules_rt::PrimTy::I32",
+            Type::Integer64 => "::hercules_rt::PrimTy::I64",
+            Type::UnsignedInteger8 => "::hercules_rt::PrimTy::U8",
+            Type::UnsignedInteger16 => "::hercules_rt::PrimTy::U16",
+            Type::UnsignedInteger32 => "::hercules_rt::PrimTy::U32",
+            Type::UnsignedInteger64 => "::hercules_rt::PrimTy::U64",
+            Type::Float8 => "::hercules_rt::PrimTy::F8",
+            Type::BFloat16 => "::hercules_rt::PrimTy::BF16",
+            Type::Float32 => "::hercules_rt::PrimTy::F32",
+            Type::Float64 => "::hercules_rt::PrimTy::F64",
+            _ => panic!(),
+        }
+    }
 }
 
 fn convert_type(ty: &Type) -> &'static str {
diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs
index 841c6f44dc004fd333410f77d6c1b60e34dbda62..e360076eb2840e315503ca74cac2ef99ab2abe18 100644
--- a/hercules_rt/src/lib.rs
+++ b/hercules_rt/src/lib.rs
@@ -29,7 +29,10 @@ pub unsafe fn __cpu_dealloc(ptr: *mut u8, size: usize) {
         eprintln!("__cpu_dealloc: {:?}, {}", ptr, size);
         assert!(!ptr.is_null() || size == 0);
     }
-    dealloc(ptr, Layout::from_size_align(size, LARGEST_ALIGNMENT).unwrap())
+    dealloc(
+        ptr,
+        Layout::from_size_align(size, LARGEST_ALIGNMENT).unwrap(),
+    )
 }
 
 pub unsafe fn __cpu_zero_mem(ptr: *mut u8, size: usize) {
@@ -103,14 +106,45 @@ pub unsafe fn __copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize) {
     ___copy_cuda_to_cuda(dst, src, size);
 }
 
+#[repr(u8)]
+#[derive(Debug, Copy, Clone)]
+pub enum PrimTy {
+    Bool,
+    U8,
+    U16,
+    U32,
+    U64,
+    I8,
+    I16,
+    I32,
+    I64,
+    F8,
+    BF16,
+    F32,
+    F64,
+}
+
+#[cfg(feature = "cuda")]
+pub unsafe fn __library_cuda_gemm(
+    i: u64,
+    j: u64,
+    k: u64,
+    c: *mut u8,
+    a: *const u8,
+    b: *const u8,
+    ty: PrimTy,
+) {
+    panic!("{} {} {} {:?} {:?} {:?} {:?}", i, j, k, c, a, b, ty);
+}
+
 #[cfg(feature = "cuda")]
 extern "C" {
-    pub fn ___cuda_alloc(size: usize) -> *mut u8;
-    pub fn ___cuda_dealloc(ptr: *mut u8, size: usize);
-    pub fn ___cuda_zero_mem(ptr: *mut u8, size: usize);
-    pub fn ___copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
-    pub fn ___copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize);
-    pub fn ___copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
+    fn ___cuda_alloc(size: usize) -> *mut u8;
+    fn ___cuda_dealloc(ptr: *mut u8, size: usize);
+    fn ___cuda_zero_mem(ptr: *mut u8, size: usize);
+    fn ___copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
+    fn ___copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize);
+    fn ___copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
 }
 
 #[derive(Clone, Debug)]