diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs index 20a3e6cb1a2b663e2bd3a7fc41df82fe0ef443db..ba78e8e2408e75c129cc4bd787928ea481c119a4 100644 --- a/hercules_cg/src/cpu.rs +++ b/hercules_cg/src/cpu.rs @@ -811,8 +811,10 @@ impl<'a> CPUContext<'a> { fn codegen_type_size(&self, ty: TypeID, body: &mut String) -> Result<String, Error> { match self.types[ty.idx()] { Type::Control => panic!(), - Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 => Ok("1".to_string()), - Type::Integer16 | Type::UnsignedInteger16 => Ok("2".to_string()), + Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => { + Ok("1".to_string()) + } + Type::Integer16 | Type::UnsignedInteger16 | Type::BFloat16 => Ok("2".to_string()), Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => Ok("4".to_string()), Type::Integer64 | Type::UnsignedInteger64 | Type::Float64 => Ok("8".to_string()), Type::Product(ref fields) => { diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index f20bfeb727dc6233dbb212a72c3a25b0c720cf36..e6b540aed6218f56e0a1fa33b0e4b1e1db09e9b0 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -359,6 +359,7 @@ impl GPUContext<'_> { #include <mma.h> #include <cooperative_groups.h> #include <cooperative_groups/reduce.h> +#include <cuda_bf16.h> namespace cg = cooperative_groups; #define uabs(a) (a) @@ -2063,8 +2064,8 @@ extern \"C\" {} {}(", .map(|field| self.get_alignment(*field)) .max() .unwrap_or(0), - Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 => 1, - Type::Integer16 | Type::UnsignedInteger16 => 2, + Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => 1, + Type::Integer16 | Type::UnsignedInteger16 | Type::BFloat16 => 2, Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => 4, Type::Integer64 | Type::UnsignedInteger64 | Type::Float64 => 8, _ => panic!("Unsupported type for alignment"), @@ -2286,6 +2287,8 @@ fn convert_type(ty: &Type, make_pointer: bool) -> String { Type::UnsignedInteger32 => "unsigned int".to_string(), Type::Integer64 => "long long".to_string(), Type::UnsignedInteger64 => "unsigned long long".to_string(), + Type::Float8 => "__nv_fp8_e4m3".to_string(), + Type::BFloat16 => "nv_bfloat16".to_string(), Type::Float32 => "float".to_string(), Type::Float64 => "double".to_string(), _ => panic!("Unsupported type"), diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs index 217f5879ab7a76ec74aafa63839a49fcad68d855..15946f72f86c44eaa2bf537450c1d22a57ba6992 100644 --- a/hercules_cg/src/lib.rs +++ b/hercules_cg/src/lib.rs @@ -24,8 +24,8 @@ pub const LARGEST_ALIGNMENT: usize = 8; pub fn get_type_alignment(types: &Vec<Type>, ty: TypeID) -> usize { match types[ty.idx()] { Type::Control => panic!(), - Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 => 1, - Type::Integer16 | Type::UnsignedInteger16 => 2, + Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => 1, + Type::Integer16 | Type::UnsignedInteger16 | Type::BFloat16 => 2, Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => 4, Type::Integer64 | Type::UnsignedInteger64 | Type::Float64 => 8, Type::Product(ref members) | Type::Summation(ref members) => members diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index b79e4953346f2f47a81bf928eadb647ed7217645..62f683cec8a5b9a090515010a1d0832047a6e379 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -832,8 +832,10 @@ impl<'a> RTContext<'a> { fn codegen_type_size(&self, ty: TypeID) -> String { match self.module.types[ty.idx()] { Type::Control => panic!(), - Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 => "1".to_string(), - Type::Integer16 | Type::UnsignedInteger16 => "2".to_string(), + Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => { + "1".to_string() + } + Type::Integer16 | Type::UnsignedInteger16 | Type::BFloat16 => "2".to_string(), Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => "4".to_string(), Type::Integer64 | Type::UnsignedInteger64 | Type::Float64 => "8".to_string(), Type::Product(ref fields) => { diff --git a/hercules_ir/src/build.rs b/hercules_ir/src/build.rs index b804404524bb26e1e52c8f751bc416b7df84040d..40538cef34e34b886348ce38bbacbb4e596d00a7 100644 --- a/hercules_ir/src/build.rs +++ b/hercules_ir/src/build.rs @@ -143,6 +143,14 @@ impl<'a> Builder<'a> { self.intern_type(Type::UnsignedInteger64) } + pub fn create_type_fp8(&mut self) -> TypeID { + self.intern_type(Type::Float8) + } + + pub fn create_type_bf16(&mut self) -> TypeID { + self.intern_type(Type::BFloat16) + } + pub fn create_type_f32(&mut self) -> TypeID { self.intern_type(Type::Float32) } @@ -393,6 +401,7 @@ impl<'a> Builder<'a> { Type::UnsignedInteger16 => self.create_constant_u16(0), Type::UnsignedInteger32 => self.create_constant_u32(0), Type::UnsignedInteger64 => self.create_constant_u64(0), + Type::Float8 | Type::BFloat16 => panic!(), Type::Float32 => self.create_constant_f32(0.0), Type::Float64 => self.create_constant_f64(0.0), Type::Product(fs) => { diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index d236d5b55cd61f5f0cbbd721660779f2d3ba5e61..f3474ae069c98c7d7cd2cd46bdbcc7a2719c0c56 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -1,4 +1,4 @@ -use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; +use std::collections::{BTreeMap, BTreeSet, HashMap}; use std::iter::{once, repeat, zip}; use either::Either; diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index c2ebd86ca07f7ada3d2cb582a1154644bc3f978d..1dce5cfc07a7d67878e891dca551746433606180 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -70,6 +70,8 @@ pub enum Type { UnsignedInteger16, UnsignedInteger32, UnsignedInteger64, + Float8, + BFloat16, Float32, Float64, Product(Box<[TypeID]>), @@ -384,6 +386,8 @@ impl Module { Type::UnsignedInteger16 => write!(w, "UnsignedInteger16"), Type::UnsignedInteger32 => write!(w, "UnsignedInteger32"), Type::UnsignedInteger64 => write!(w, "UnsignedInteger64"), + Type::Float8 => write!(w, "Float8"), + Type::BFloat16 => write!(w, "BFloat16"), Type::Float32 => write!(w, "Float32"), Type::Float64 => write!(w, "Float64"), Type::Product(fields) => { @@ -874,6 +878,8 @@ impl Type { pub fn is_float(&self) -> bool { match self { + Type::Float8 => true, + Type::BFloat16 => true, Type::Float32 => true, Type::Float64 => true, _ => false, diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index 257dd4d998341feb2ad6e326a1e4b9e58b31c1e3..f1f4153a1e5d202d6ff6d3e75e5e8353bab658e5 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -1004,6 +1004,7 @@ fn parse_constant<'a>( Type::UnsignedInteger16 => parse_unsigned_integer16(ir_text)?, Type::UnsignedInteger32 => parse_unsigned_integer32(ir_text)?, Type::UnsignedInteger64 => parse_unsigned_integer64(ir_text)?, + Type::Float8 | Type::BFloat16 => panic!(), Type::Float32 => parse_float32(ir_text)?, Type::Float64 => parse_float64(ir_text)?, Type::Product(ref tys) => { diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index f7ea397e49355029c7b5cbc0fa534494bc747c6d..d01a5c58d8de15ba291a75a6f1b0528433e36d2c 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -1144,6 +1144,8 @@ impl<'a> DCSubst<'a> { | Type::UnsignedInteger16 | Type::UnsignedInteger32 | Type::UnsignedInteger64 + | Type::Float8 + | Type::BFloat16 | Type::Float32 | Type::Float64 => typ, Type::Product(ts) => { diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index 9f2b7ef40adc2933de4d2715c654fecceb087903..8c339d728bf3e139a8960673fcd687806c7def70 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -753,6 +753,7 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { Type::UnsignedInteger16 => Constant::UnsignedInteger16(0), Type::UnsignedInteger32 => Constant::UnsignedInteger32(0), Type::UnsignedInteger64 => Constant::UnsignedInteger64(0), + Type::Float8 | Type::BFloat16 => panic!(), Type::Float32 => Constant::Float32(ordered_float::OrderedFloat(0.0)), Type::Float64 => Constant::Float64(ordered_float::OrderedFloat(0.0)), Type::Control => panic!("PANIC: Can't create zero constant for the control type."), @@ -779,6 +780,7 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { Type::UnsignedInteger16 => Constant::UnsignedInteger16(1), Type::UnsignedInteger32 => Constant::UnsignedInteger32(1), Type::UnsignedInteger64 => Constant::UnsignedInteger64(1), + Type::Float8 | Type::BFloat16 => panic!(), Type::Float32 => Constant::Float32(ordered_float::OrderedFloat(1.0)), Type::Float64 => Constant::Float64(ordered_float::OrderedFloat(1.0)), Type::Control => panic!("PANIC: Can't create one constant for the control type."), diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index 79bd2851303a125775385d08ea3a7c30784094a1..99c44d52ced301fd67be5277aa71f5b412965ae5 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -1379,10 +1379,10 @@ fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) -> let ty = edit.get_type(ty_id).clone(); let size = match ty { Type::Control => panic!(), - Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 => { + Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => { edit.add_dynamic_constant(DynamicConstant::Constant(1)) } - Type::Integer16 | Type::UnsignedInteger16 => { + Type::Integer16 | Type::UnsignedInteger16 | Type::BFloat16 => { edit.add_dynamic_constant(DynamicConstant::Constant(2)) } Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => { diff --git a/hercules_opt/src/interprocedural_sroa.rs b/hercules_opt/src/interprocedural_sroa.rs index f22c1fe8410bdb008dbffe4acb90fa1679f9e44e..b345f9bcaaab82da8c7b5cb178bf39e3f8761dbc 100644 --- a/hercules_opt/src/interprocedural_sroa.rs +++ b/hercules_opt/src/interprocedural_sroa.rs @@ -322,8 +322,6 @@ fn compress_return_products(editors: &mut Vec<FunctionEditor>, all_callsites_edi .collect::<HashMap<_, _>>(); let edit_successful = editor.edit(|mut edit| { - let mut substituted = old_return_type_ids[function_id.idx()]; - let substituted = substitute_dynamic_constants_in_type( &substs, old_return_type_ids[function_id.idx()], diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index 66d11d69c33d1a77ce5a54bfd13ad88618916bfd..3210094ded7006001e6a82ee356161221cc0b468 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -1007,6 +1007,7 @@ fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID { Type::UnsignedInteger16 => add_const!(editor, Constant::UnsignedInteger16(0)), Type::UnsignedInteger32 => add_const!(editor, Constant::UnsignedInteger32(0)), Type::UnsignedInteger64 => add_const!(editor, Constant::UnsignedInteger64(0)), + Type::Float8 | Type::BFloat16 => panic!(), Type::Float32 => add_const!(editor, Constant::Float32(ordered_float::OrderedFloat(0.0))), Type::Float64 => add_const!(editor, Constant::Float64(ordered_float::OrderedFloat(0.0))), Type::Summation(ts) => { diff --git a/hercules_test/hercules_interpreter/src/value.rs b/hercules_test/hercules_interpreter/src/value.rs index 4a802f7a7b7c2c4090380929fcc93824b2c79244..dfc290b253666c8251370450f9f2893fe78d8830 100644 --- a/hercules_test/hercules_interpreter/src/value.rs +++ b/hercules_test/hercules_interpreter/src/value.rs @@ -841,6 +841,8 @@ impl<'a> InterpreterVal { Type::UnsignedInteger16 => todo!(), Type::UnsignedInteger32 => todo!(), Type::UnsignedInteger64 => Self::UnsignedInteger64(val.try_into().unwrap()), + Type::Float8 => todo!(), + Type::BFloat16 => todo!(), Type::Float32 => todo!(), Type::Float64 => todo!(), Type::Product(_) => todo!(), diff --git a/juno_frontend/src/lang.l b/juno_frontend/src/lang.l index 94a12002131d8fb439b8e8cabc4cae4245d00ead..4039ef0a32b89742254be2d69b5c0f15f395abfb 100644 --- a/juno_frontend/src/lang.l +++ b/juno_frontend/src/lang.l @@ -53,6 +53,8 @@ u16 "u16" u32 "u32" u64 "u64" usize "usize" +fp8 "fp8" +bf16 "bf16" f32 "f32" f64 "f64" void "void" diff --git a/juno_frontend/src/lang.y b/juno_frontend/src/lang.y index b47186ff24ab91f4650adf6ffaed10b2ba0dd251..13d7e292a366dca1ae8bc9f859ee592293d5da7d 100644 --- a/juno_frontend/src/lang.y +++ b/juno_frontend/src/lang.y @@ -143,6 +143,8 @@ PrimType -> Result<Primitive, ()> | 'i64' { Ok(Primitive::I64) } | 'u64' { Ok(Primitive::U64) } | 'usize' { Ok(Primitive::USize) } + | 'fp8' { Ok(Primitive::FP8) } + | 'bf16' { Ok(Primitive::BF16) } | 'f32' { Ok(Primitive::F32) } | 'f64' { Ok(Primitive::F64) } | 'void' { Ok(Primitive::Void) } @@ -621,7 +623,7 @@ pub type ImportName = (PackageName, Option<Span>); // option is the wildcard * #[derive(Debug, Copy, Clone)] pub enum Kind { Type, USize, Number, Integer, Float } #[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub enum Primitive { Bool, I8, U8, I16, U16, I32, U32, I64, U64, USize, F32, F64, Void } +pub enum Primitive { Bool, I8, U8, I16, U16, I32, U32, I64, U64, USize, FP8, BF16, F32, F64, Void } #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum AssignOp { None, Add, Sub, Mul, Div, Mod, BitAnd, BitOr, Xor, LogAnd, LogOr, LShift, RShift } diff --git a/juno_frontend/src/semant.rs b/juno_frontend/src/semant.rs index 8668d1b45f7d68cff90ec90f7b2b55953b8d5a9c..8697912869cd0eb33e719958729b4eb16ade6d9c 100644 --- a/juno_frontend/src/semant.rs +++ b/juno_frontend/src/semant.rs @@ -5057,6 +5057,8 @@ fn convert_primitive(prim: parser::Primitive) -> types::Primitive { parser::Primitive::I64 => types::Primitive::I64, parser::Primitive::U64 => types::Primitive::U64, parser::Primitive::USize => types::Primitive::U64, + parser::Primitive::FP8 => types::Primitive::FP8, + parser::Primitive::BF16 => types::Primitive::BF16, parser::Primitive::F32 => types::Primitive::F32, parser::Primitive::F64 => types::Primitive::F64, parser::Primitive::Void => types::Primitive::Unit, diff --git a/juno_frontend/src/types.rs b/juno_frontend/src/types.rs index d4d8b23326fcf624af05eb3c8dde89e8159eeba8..edb51db5c03eaef1ff1b6aa162a7767d38adc161 100644 --- a/juno_frontend/src/types.rs +++ b/juno_frontend/src/types.rs @@ -82,6 +82,8 @@ pub enum Primitive { I32, U64, I64, + FP8, + BF16, F32, F64, Unit, @@ -102,6 +104,8 @@ impl Primitive { | Primitive::I32 | Primitive::U64 | Primitive::I64 + | Primitive::FP8 + | Primitive::BF16 | Primitive::F32 | Primitive::F64 => true, _ => false, @@ -118,7 +122,7 @@ impl Primitive { _ => false, }, parser::Kind::Float => match self { - Primitive::F32 | Primitive::F64 => true, + Primitive::FP8 | Primitive::BF16 | Primitive::F32 | Primitive::F64 => true, _ => false, }, } @@ -135,6 +139,8 @@ impl Primitive { Primitive::U32 => "u32".to_string(), Primitive::I64 => "i64".to_string(), Primitive::U64 => "u64".to_string(), + Primitive::FP8 => "fp8".to_string(), + Primitive::BF16 => "bf16".to_string(), Primitive::F32 => "f32".to_string(), Primitive::F64 => "f64".to_string(), Primitive::Unit => "()".to_string(), @@ -1005,6 +1011,10 @@ impl TypeSolverInst<'_> { Primitive::U32 } else if type_id == builder.create_type_u64() { Primitive::U64 + } else if type_id == builder.create_type_fp8() { + Primitive::FP8 + } else if type_id == builder.create_type_bf16() { + Primitive::BF16 } else if type_id == builder.create_type_f32() { Primitive::F32 } else if type_id == builder.create_type_f64() { @@ -1025,6 +1035,8 @@ impl TypeSolverInst<'_> { Primitive::U16 => builder.create_type_u16(), Primitive::U32 => builder.create_type_u32(), Primitive::U64 => builder.create_type_u64(), + Primitive::FP8 => builder.create_type_fp8(), + Primitive::BF16 => builder.create_type_bf16(), Primitive::F32 => builder.create_type_f32(), Primitive::F64 => builder.create_type_f64(), Primitive::Unit => builder.create_type_prod(vec![].into()), diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index 159fac94ee83c1dc7aa886545959e6cec2ed5c7a..117cf37e2e2ac0c5ebf6a908253dc85b0e31029f 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -54,7 +54,6 @@ simplify-cfg(auto.test7); dce(auto.test7); let fission = fork-fission-bufferize[test8@loop, test8@bufferize1](auto.test8); -xdot[true](*); dce(auto.test8); ccp(auto.test8); dce(auto.test8); diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs index 2892cd3473251b60b0a2d3d381ae0848c8b7fadb..98b6c7776089b2dee6e67ea11f14d7914cf82025 100644 --- a/juno_samples/matmul/src/main.rs +++ b/juno_samples/matmul/src/main.rs @@ -13,8 +13,8 @@ fn main() { const I: usize = 256; const J: usize = 64; const K: usize = 128; - let mut a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect(); - let mut b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect(); + let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect(); + let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect(); let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect(); for i in 0..I { for k in 0..K { @@ -40,8 +40,8 @@ fn main() { } #[cfg(feature = "cuda")] { - let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a)); - let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b)); + let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&a)); + let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&b)); let mut r = runner!(matmul); let c = r .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()) diff --git a/juno_samples/multi_device/src/main.rs b/juno_samples/multi_device/src/main.rs index f62de56cb216448c27253305473d224cf1a26a0b..c989bd7b1fa0b26bda9d65244585f84a261b0f8c 100644 --- a/juno_samples/multi_device/src/main.rs +++ b/juno_samples/multi_device/src/main.rs @@ -6,8 +6,8 @@ use hercules_rt::runner; #[cfg(feature = "cuda")] juno_build::juno!("multi_device"); -#[cfg(feature = "cuda")] fn main() { + #[cfg(feature = "cuda")] async_std::task::block_on(async { let mut r = runner!(multi_device_1); let out = r.run(42).await;