Skip to content
Snippets Groups Projects
Commit 3409decf authored by Russel Arbore's avatar Russel Arbore
Browse files

Plumb things

parent 8f806736
No related branches found
No related tags found
1 merge request: !187 Identify and lower library functions
Pipeline #201738 failed
......@@ -597,6 +597,59 @@ impl<'a> RTContext<'a> {
}
write!(block, "){};", postfix)?;
}
// Lower a library call node by emitting a call to the matching
// `::hercules_rt::__library_<device>_<function>` runtime routine.
Node::LibraryCall {
library_function,
ref args,
ty,
device,
} => match library_function {
LibraryFunction::GEMM => {
// GEMM takes exactly three arguments, referred to below as C (args[0]),
// A (args[1]) and B (args[2]). The node's own type must equal the type
// of args[0].
assert_eq!(args.len(), 3);
assert_eq!(self.typing[args[0].idx()], ty);
let c_ty = &self.module.types[self.typing[args[0].idx()].idx()];
let a_ty = &self.module.types[self.typing[args[1].idx()].idx()];
let b_ty = &self.module.types[self.typing[args[2].idx()].idx()];
// All three operands must be array types; anything else panics.
let (
Type::Array(c_elem, c_dims),
Type::Array(a_elem, a_dims),
Type::Array(b_elem, b_dims),
) = (c_ty, a_ty, b_ty)
else {
panic!();
};
// The three arrays must share one element type.
assert_eq!(a_elem, b_elem);
assert_eq!(a_elem, c_elem);
// Each operand must be two-dimensional.
assert_eq!(c_dims.len(), 2);
assert_eq!(a_dims.len(), 2);
assert_eq!(b_dims.len(), 2);
// Shape checks: A's second dimension equals B's first, and C's
// dimensions are (A's first, B's second). This is consistent with C
// holding the matrix product of A and B.
assert_eq!(a_dims[1], b_dims[0]);
assert_eq!(a_dims[0], c_dims[0]);
assert_eq!(b_dims[1], c_dims[1]);
let block = &mut blocks.get_mut(&bb).unwrap().data;
let prim_ty = self.library_prim_ty(*a_elem);
// Emitted text, in order:
//   ::hercules_rt::__library_<device>_gemm(<a_dims[0]>, <a_dims[1]>, <b_dims[1]>,
//       <C>.0, <A>.0, <B>.0, <PrimTy tag>);
write!(block, "::hercules_rt::__library_{}_gemm(", device.name())?;
self.codegen_dynamic_constant(a_dims[0], block)?;
write!(block, ", ")?;
self.codegen_dynamic_constant(a_dims[1], block)?;
write!(block, ", ")?;
self.codegen_dynamic_constant(b_dims[1], block)?;
// NOTE(review): `.0` presumably selects the raw pointer field of each
// operand's runtime value — confirm against the value representation.
write!(
block,
", {}.0, {}.0, {}.0, {});",
self.get_value(args[0], bb, false),
self.get_value(args[1], bb, false),
self.get_value(args[2], bb, false),
prim_ty
)?;
// The node's own value is then assigned from args[0] (the C operand).
// NOTE(review): the `true` flag presumably marks the left-hand side of
// the assignment — verify against `get_value`.
write!(
block,
"{} = {};",
self.get_value(id, bb, true),
self.get_value(args[0], bb, false)
)?;
}
},
Node::Unary { op, input } => {
let block = &mut blocks.get_mut(&bb).unwrap().data;
match op {
......@@ -1316,6 +1369,25 @@ impl<'a> RTContext<'a> {
/// Looks up the type registered under `id` in the module's type table and
/// returns the string that `convert_type` produces for it.
fn get_type(&self, id: TypeID) -> &'static str {
    let ty = &self.module.types[id.idx()];
    convert_type(ty)
}
/// Returns the path of the `::hercules_rt::PrimTy` variant, as source text,
/// that corresponds to the primitive type registered under `id`.
///
/// The returned string is meant to be spliced verbatim into generated code.
///
/// # Panics
///
/// Panics if the type registered under `id` is not one of the primitive
/// types listed below (for example an array or other compound type).
fn library_prim_ty(&self, id: TypeID) -> &'static str {
    match self.module.types[id.idx()] {
        Type::Boolean => "::hercules_rt::PrimTy::Bool",
        Type::Integer8 => "::hercules_rt::PrimTy::I8",
        Type::Integer16 => "::hercules_rt::PrimTy::I16",
        Type::Integer32 => "::hercules_rt::PrimTy::I32",
        Type::Integer64 => "::hercules_rt::PrimTy::I64",
        Type::UnsignedInteger8 => "::hercules_rt::PrimTy::U8",
        Type::UnsignedInteger16 => "::hercules_rt::PrimTy::U16",
        Type::UnsignedInteger32 => "::hercules_rt::PrimTy::U32",
        Type::UnsignedInteger64 => "::hercules_rt::PrimTy::U64",
        Type::Float8 => "::hercules_rt::PrimTy::F8",
        Type::BFloat16 => "::hercules_rt::PrimTy::BF16",
        Type::Float32 => "::hercules_rt::PrimTy::F32",
        Type::Float64 => "::hercules_rt::PrimTy::F64",
        // Reaching this arm is a compiler bug; say which type triggered it
        // instead of panicking with no message.
        _ => panic!(
            "library_prim_ty: type #{} is not a primitive type usable as a library call element type",
            id.idx()
        ),
    }
}
}
fn convert_type(ty: &Type) -> &'static str {
......
......@@ -29,7 +29,10 @@ pub unsafe fn __cpu_dealloc(ptr: *mut u8, size: usize) {
eprintln!("__cpu_dealloc: {:?}, {}", ptr, size);
assert!(!ptr.is_null() || size == 0);
}
dealloc(ptr, Layout::from_size_align(size, LARGEST_ALIGNMENT).unwrap())
dealloc(
ptr,
Layout::from_size_align(size, LARGEST_ALIGNMENT).unwrap(),
)
}
pub unsafe fn __cpu_zero_mem(ptr: *mut u8, size: usize) {
......@@ -103,14 +106,45 @@ pub unsafe fn __copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize) {
___copy_cuda_to_cuda(dst, src, size);
}
/// Tag naming a primitive element type, passed by value to library routines
/// such as `__library_cuda_gemm`.
///
/// The enum is `#[repr(u8)]` with implicit discriminants, so each variant's
/// numeric value is its position in the declaration (`Bool` = 0 … `F64` = 12).
/// Do not reorder the variants: doing so changes those values.
#[repr(u8)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum PrimTy {
    /// Boolean.
    Bool,
    /// Unsigned 8-bit integer.
    U8,
    /// Unsigned 16-bit integer.
    U16,
    /// Unsigned 32-bit integer.
    U32,
    /// Unsigned 64-bit integer.
    U64,
    /// Signed 8-bit integer.
    I8,
    /// Signed 16-bit integer.
    I16,
    /// Signed 32-bit integer.
    I32,
    /// Signed 64-bit integer.
    I64,
    /// 8-bit float.
    F8,
    /// bfloat16.
    BF16,
    /// 32-bit float.
    F32,
    /// 64-bit float.
    F64,
}
/// Runtime entry point for a GEMM library call on the CUDA device.
///
/// Currently a placeholder: it panics unconditionally, with a message that
/// prints every argument. Only compiled when the `cuda` feature is enabled.
///
/// NOTE(review): `i`, `j`, `k` are presumably the matrix dimensions, `c` the
/// output buffer, `a` and `b` the input buffers, and `ty` their element type —
/// confirm against the code generator that emits calls to this function.
///
/// # Safety
///
/// Declared `unsafe`, but the present body does not dereference any of the
/// pointers. A real implementation will presumably need them to be valid
/// device buffers — TODO document the contract once implemented.
#[cfg(feature = "cuda")]
pub unsafe fn __library_cuda_gemm(
i: u64,
j: u64,
k: u64,
c: *mut u8,
a: *const u8,
b: *const u8,
ty: PrimTy,
) {
panic!("{} {} {} {:?} {:?} {:?} {:?}", i, j, k, c, a, b, ty);
}
/// Foreign declarations of the CUDA support routines. Only compiled when the
/// `cuda` feature is enabled.
///
/// Each function is declared exactly once: a second declaration of the same
/// name in this block is a duplicate-definition error. The `pub` visibility
/// is kept so that any existing user of these symbols continues to compile.
///
/// NOTE(review): the double-underscore wrappers in this file (for example
/// `__copy_cuda_to_cuda`) forward to these; if nothing outside this module
/// calls the triple-underscore names directly, `pub` can be dropped.
#[cfg(feature = "cuda")]
extern "C" {
    pub fn ___cuda_alloc(size: usize) -> *mut u8;
    pub fn ___cuda_dealloc(ptr: *mut u8, size: usize);
    pub fn ___cuda_zero_mem(ptr: *mut u8, size: usize);
    pub fn ___copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
    pub fn ___copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize);
    pub fn ___copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize);
}
#[derive(Clone, Debug)]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment