From 328e0d2c645c1de891b2b75c8ff39f72e5aa68c5 Mon Sep 17 00:00:00 2001 From: rarbore2 <rarbore2@illinois.edu> Date: Tue, 11 Feb 2025 11:19:52 -0600 Subject: [PATCH] Float8 / BFloat16 --- hercules_cg/src/cpu.rs | 6 ++++-- hercules_cg/src/gpu.rs | 7 +++++-- hercules_cg/src/lib.rs | 4 ++-- hercules_cg/src/rt.rs | 6 ++++-- hercules_ir/src/build.rs | 9 +++++++++ hercules_ir/src/collections.rs | 2 +- hercules_ir/src/ir.rs | 6 ++++++ hercules_ir/src/parse.rs | 1 + hercules_ir/src/typecheck.rs | 2 ++ hercules_opt/src/editor.rs | 2 ++ hercules_opt/src/gcm.rs | 4 ++-- hercules_opt/src/interprocedural_sroa.rs | 2 -- hercules_opt/src/sroa.rs | 1 + hercules_test/hercules_interpreter/src/value.rs | 2 ++ juno_frontend/src/lang.l | 2 ++ juno_frontend/src/lang.y | 4 +++- juno_frontend/src/semant.rs | 2 ++ juno_frontend/src/types.rs | 14 +++++++++++++- juno_samples/fork_join_tests/src/gpu.sch | 1 - juno_samples/matmul/src/main.rs | 8 ++++---- juno_samples/multi_device/src/main.rs | 2 +- 21 files changed, 66 insertions(+), 21 deletions(-) diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs index 20a3e6cb..ba78e8e2 100644 --- a/hercules_cg/src/cpu.rs +++ b/hercules_cg/src/cpu.rs @@ -811,8 +811,10 @@ impl<'a> CPUContext<'a> { fn codegen_type_size(&self, ty: TypeID, body: &mut String) -> Result<String, Error> { match self.types[ty.idx()] { Type::Control => panic!(), - Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 => Ok("1".to_string()), - Type::Integer16 | Type::UnsignedInteger16 => Ok("2".to_string()), + Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => { + Ok("1".to_string()) + } + Type::Integer16 | Type::UnsignedInteger16 | Type::BFloat16 => Ok("2".to_string()), Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => Ok("4".to_string()), Type::Integer64 | Type::UnsignedInteger64 | Type::Float64 => Ok("8".to_string()), Type::Product(ref fields) => { diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index f20bfeb7..e6b540ae 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -359,6 +359,7 @@ impl GPUContext<'_> { #include <mma.h> #include <cooperative_groups.h> #include <cooperative_groups/reduce.h> +#include <cuda_bf16.h> namespace cg = cooperative_groups; #define uabs(a) (a) @@ -2063,8 +2064,8 @@ extern \"C\" {} {}(", .map(|field| self.get_alignment(*field)) .max() .unwrap_or(0), - Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 => 1, - Type::Integer16 | Type::UnsignedInteger16 => 2, + Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => 1, + Type::Integer16 | Type::UnsignedInteger16 | Type::BFloat16 => 2, Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => 4, Type::Integer64 | Type::UnsignedInteger64 | Type::Float64 => 8, _ => panic!("Unsupported type for alignment"), @@ -2286,6 +2287,8 @@ fn convert_type(ty: &Type, make_pointer: bool) -> String { Type::UnsignedInteger32 => "unsigned int".to_string(), Type::Integer64 => "long long".to_string(), Type::UnsignedInteger64 => "unsigned long long".to_string(), + Type::Float8 => "__nv_fp8_e4m3".to_string(), + Type::BFloat16 => "nv_bfloat16".to_string(), Type::Float32 => "float".to_string(), Type::Float64 => "double".to_string(), _ => panic!("Unsupported type"), diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs index 217f5879..15946f72 100644 --- a/hercules_cg/src/lib.rs +++ b/hercules_cg/src/lib.rs @@ -24,8 +24,8 @@ pub const LARGEST_ALIGNMENT: usize = 8; pub fn get_type_alignment(types: &Vec<Type>, ty: TypeID) -> usize { match types[ty.idx()] { Type::Control => panic!(), - Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 => 1, - Type::Integer16 | Type::UnsignedInteger16 => 2, + Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => 1, + Type::Integer16 | Type::UnsignedInteger16 | Type::BFloat16 => 2, Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => 4, Type::Integer64 | Type::UnsignedInteger64 | Type::Float64 => 8, Type::Product(ref members) | Type::Summation(ref members) => members diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index b79e4953..62f683ce 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -832,8 +832,10 @@ impl<'a> RTContext<'a> { fn codegen_type_size(&self, ty: TypeID) -> String { match self.module.types[ty.idx()] { Type::Control => panic!(), - Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 => "1".to_string(), - Type::Integer16 | Type::UnsignedInteger16 => "2".to_string(), + Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => { + "1".to_string() + } + Type::Integer16 | Type::UnsignedInteger16 | Type::BFloat16 => "2".to_string(), Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => "4".to_string(), Type::Integer64 | Type::UnsignedInteger64 | Type::Float64 => "8".to_string(), Type::Product(ref fields) => { diff --git a/hercules_ir/src/build.rs b/hercules_ir/src/build.rs index b8044045..40538cef 100644 --- a/hercules_ir/src/build.rs +++ b/hercules_ir/src/build.rs @@ -143,6 +143,14 @@ impl<'a> Builder<'a> { self.intern_type(Type::UnsignedInteger64) } + pub fn create_type_fp8(&mut self) -> TypeID { + self.intern_type(Type::Float8) + } + + pub fn create_type_bf16(&mut self) -> TypeID { + self.intern_type(Type::BFloat16) + } + pub fn create_type_f32(&mut self) -> TypeID { self.intern_type(Type::Float32) } @@ -393,6 +401,7 @@ impl<'a> Builder<'a> { Type::UnsignedInteger16 => self.create_constant_u16(0), Type::UnsignedInteger32 => self.create_constant_u32(0), Type::UnsignedInteger64 => self.create_constant_u64(0), + Type::Float8 | Type::BFloat16 => panic!(), Type::Float32 => self.create_constant_f32(0.0), Type::Float64 => self.create_constant_f64(0.0), Type::Product(fs) => { diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index d236d5b5..f3474ae0 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -1,4 +1,4 @@ -use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; +use std::collections::{BTreeMap, BTreeSet, HashMap}; use std::iter::{once, repeat, zip}; use either::Either; diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index c2ebd86c..1dce5cfc 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -70,6 +70,8 @@ pub enum Type { UnsignedInteger16, UnsignedInteger32, UnsignedInteger64, + Float8, + BFloat16, Float32, Float64, Product(Box<[TypeID]>), @@ -384,6 +386,8 @@ impl Module { Type::UnsignedInteger16 => write!(w, "UnsignedInteger16"), Type::UnsignedInteger32 => write!(w, "UnsignedInteger32"), Type::UnsignedInteger64 => write!(w, "UnsignedInteger64"), + Type::Float8 => write!(w, "Float8"), + Type::BFloat16 => write!(w, "BFloat16"), Type::Float32 => write!(w, "Float32"), Type::Float64 => write!(w, "Float64"), Type::Product(fields) => { @@ -874,6 +878,8 @@ impl Type { pub fn is_float(&self) -> bool { match self { + Type::Float8 => true, + Type::BFloat16 => true, Type::Float32 => true, Type::Float64 => true, _ => false, diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index 257dd4d9..f1f4153a 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -1004,6 +1004,7 @@ fn parse_constant<'a>( Type::UnsignedInteger16 => parse_unsigned_integer16(ir_text)?, Type::UnsignedInteger32 => parse_unsigned_integer32(ir_text)?, Type::UnsignedInteger64 => parse_unsigned_integer64(ir_text)?, + Type::Float8 | Type::BFloat16 => panic!(), Type::Float32 => parse_float32(ir_text)?, Type::Float64 => parse_float64(ir_text)?, Type::Product(ref tys) => { diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index f7ea397e..d01a5c58 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -1144,6 +1144,8 @@ impl<'a> DCSubst<'a> { | Type::UnsignedInteger16 | Type::UnsignedInteger32 | Type::UnsignedInteger64 + | Type::Float8 + | Type::BFloat16 | Type::Float32 | Type::Float64 => typ, Type::Product(ts) => { diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index 9f2b7ef4..8c339d72 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -753,6 +753,7 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { Type::UnsignedInteger16 => Constant::UnsignedInteger16(0), Type::UnsignedInteger32 => Constant::UnsignedInteger32(0), Type::UnsignedInteger64 => Constant::UnsignedInteger64(0), + Type::Float8 | Type::BFloat16 => panic!(), Type::Float32 => Constant::Float32(ordered_float::OrderedFloat(0.0)), Type::Float64 => Constant::Float64(ordered_float::OrderedFloat(0.0)), Type::Control => panic!("PANIC: Can't create zero constant for the control type."), @@ -779,6 +780,7 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { Type::UnsignedInteger16 => Constant::UnsignedInteger16(1), Type::UnsignedInteger32 => Constant::UnsignedInteger32(1), Type::UnsignedInteger64 => Constant::UnsignedInteger64(1), + Type::Float8 | Type::BFloat16 => panic!(), Type::Float32 => Constant::Float32(ordered_float::OrderedFloat(1.0)), Type::Float64 => Constant::Float64(ordered_float::OrderedFloat(1.0)), Type::Control => panic!("PANIC: Can't create one constant for the control type."), diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index 79bd2851..99c44d52 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -1379,10 +1379,10 @@ fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) -> let ty = edit.get_type(ty_id).clone(); let size = match ty { Type::Control => panic!(), - Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 => { + Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => { edit.add_dynamic_constant(DynamicConstant::Constant(1)) } - Type::Integer16 | Type::UnsignedInteger16 => { + Type::Integer16 | Type::UnsignedInteger16 | Type::BFloat16 => { edit.add_dynamic_constant(DynamicConstant::Constant(2)) } Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => { diff --git a/hercules_opt/src/interprocedural_sroa.rs b/hercules_opt/src/interprocedural_sroa.rs index f22c1fe8..b345f9bc 100644 --- a/hercules_opt/src/interprocedural_sroa.rs +++ b/hercules_opt/src/interprocedural_sroa.rs @@ -322,8 +322,6 @@ fn compress_return_products(editors: &mut Vec<FunctionEditor>, all_callsites_edi .collect::<HashMap<_, _>>(); let edit_successful = editor.edit(|mut edit| { - let mut substituted = old_return_type_ids[function_id.idx()]; - let substituted = substitute_dynamic_constants_in_type( &substs, old_return_type_ids[function_id.idx()], diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index 66d11d69..3210094d 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -1007,6 +1007,7 @@ fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID { Type::UnsignedInteger16 => add_const!(editor, Constant::UnsignedInteger16(0)), Type::UnsignedInteger32 => add_const!(editor, Constant::UnsignedInteger32(0)), Type::UnsignedInteger64 => add_const!(editor, Constant::UnsignedInteger64(0)), + Type::Float8 | Type::BFloat16 => panic!(), Type::Float32 => add_const!(editor, Constant::Float32(ordered_float::OrderedFloat(0.0))), Type::Float64 => add_const!(editor, Constant::Float64(ordered_float::OrderedFloat(0.0))), Type::Summation(ts) => { diff --git a/hercules_test/hercules_interpreter/src/value.rs b/hercules_test/hercules_interpreter/src/value.rs index 4a802f7a..dfc290b2 100644 --- a/hercules_test/hercules_interpreter/src/value.rs +++ b/hercules_test/hercules_interpreter/src/value.rs @@ -841,6 +841,8 @@ impl<'a> InterpreterVal { Type::UnsignedInteger16 => todo!(), Type::UnsignedInteger32 => todo!(), Type::UnsignedInteger64 => Self::UnsignedInteger64(val.try_into().unwrap()), + Type::Float8 => todo!(), + Type::BFloat16 => todo!(), Type::Float32 => todo!(), Type::Float64 => todo!(), Type::Product(_) => todo!(), diff --git a/juno_frontend/src/lang.l b/juno_frontend/src/lang.l index 94a12002..4039ef0a 100644 --- a/juno_frontend/src/lang.l +++ b/juno_frontend/src/lang.l @@ -53,6 +53,8 @@ u16 "u16" u32 "u32" u64 "u64" usize "usize" +fp8 "fp8" +bf16 "bf16" f32 "f32" f64 "f64" void "void" diff --git a/juno_frontend/src/lang.y b/juno_frontend/src/lang.y index b47186ff..13d7e292 100644 --- a/juno_frontend/src/lang.y +++ b/juno_frontend/src/lang.y @@ -143,6 +143,8 @@ PrimType -> Result<Primitive, ()> | 'i64' { Ok(Primitive::I64) } | 'u64' { Ok(Primitive::U64) } | 'usize' { Ok(Primitive::USize) } + | 'fp8' { Ok(Primitive::FP8) } + | 'bf16' { Ok(Primitive::BF16) } | 'f32' { Ok(Primitive::F32) } | 'f64' { Ok(Primitive::F64) } | 'void' { Ok(Primitive::Void) } @@ -621,7 +623,7 @@ pub type ImportName = (PackageName, Option<Span>); // option is the wildcard * #[derive(Debug, Copy, Clone)] pub enum Kind { Type, USize, Number, Integer, Float } #[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub enum Primitive { Bool, I8, U8, I16, U16, I32, U32, I64, U64, USize, F32, F64, Void } +pub enum Primitive { Bool, I8, U8, I16, U16, I32, U32, I64, U64, USize, FP8, BF16, F32, F64, Void } #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum AssignOp { None, Add, Sub, Mul, Div, Mod, BitAnd, BitOr, Xor, LogAnd, LogOr, LShift, RShift } diff --git a/juno_frontend/src/semant.rs b/juno_frontend/src/semant.rs index 8668d1b4..86979128 100644 --- a/juno_frontend/src/semant.rs +++ b/juno_frontend/src/semant.rs @@ -5057,6 +5057,8 @@ fn convert_primitive(prim: parser::Primitive) -> types::Primitive { parser::Primitive::I64 => types::Primitive::I64, parser::Primitive::U64 => types::Primitive::U64, parser::Primitive::USize => types::Primitive::U64, + parser::Primitive::FP8 => types::Primitive::FP8, + parser::Primitive::BF16 => types::Primitive::BF16, parser::Primitive::F32 => types::Primitive::F32, parser::Primitive::F64 => types::Primitive::F64, parser::Primitive::Void => types::Primitive::Unit, diff --git a/juno_frontend/src/types.rs b/juno_frontend/src/types.rs index d4d8b233..edb51db5 100644 --- a/juno_frontend/src/types.rs +++ b/juno_frontend/src/types.rs @@ -82,6 +82,8 @@ pub enum Primitive { I32, U64, I64, + FP8, + BF16, F32, F64, Unit, @@ -102,6 +104,8 @@ impl Primitive { | Primitive::I32 | Primitive::U64 | Primitive::I64 + | Primitive::FP8 + | Primitive::BF16 | Primitive::F32 | Primitive::F64 => true, _ => false, @@ -118,7 +122,7 @@ impl Primitive { _ => false, }, parser::Kind::Float => match self { - Primitive::F32 | Primitive::F64 => true, + Primitive::FP8 | Primitive::BF16 | Primitive::F32 | Primitive::F64 => true, _ => false, }, } @@ -135,6 +139,8 @@ impl Primitive { Primitive::U32 => "u32".to_string(), Primitive::I64 => "i64".to_string(), Primitive::U64 => "u64".to_string(), + Primitive::FP8 => "fp8".to_string(), + Primitive::BF16 => "bf16".to_string(), Primitive::F32 => "f32".to_string(), Primitive::F64 => "f64".to_string(), Primitive::Unit => "()".to_string(), @@ -1005,6 +1011,10 @@ impl TypeSolverInst<'_> { Primitive::U32 } else if type_id == builder.create_type_u64() { Primitive::U64 + } else if type_id == builder.create_type_fp8() { + Primitive::FP8 + } else if type_id == builder.create_type_bf16() { + Primitive::BF16 } else if type_id == builder.create_type_f32() { Primitive::F32 } else if type_id == builder.create_type_f64() { @@ -1025,6 +1035,8 @@ impl TypeSolverInst<'_> { Primitive::U16 => builder.create_type_u16(), Primitive::U32 => builder.create_type_u32(), Primitive::U64 => builder.create_type_u64(), + Primitive::FP8 => builder.create_type_fp8(), + Primitive::BF16 => builder.create_type_bf16(), Primitive::F32 => builder.create_type_f32(), Primitive::F64 => builder.create_type_f64(), Primitive::Unit => builder.create_type_prod(vec![].into()), diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index 159fac94..117cf37e 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -54,7 +54,6 @@ simplify-cfg(auto.test7); dce(auto.test7); let fission = fork-fission-bufferize[test8@loop, test8@bufferize1](auto.test8); -xdot[true](*); dce(auto.test8); ccp(auto.test8); dce(auto.test8); diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs index 2892cd34..98b6c777 100644 --- a/juno_samples/matmul/src/main.rs +++ b/juno_samples/matmul/src/main.rs @@ -13,8 +13,8 @@ fn main() { const I: usize = 256; const J: usize = 64; const K: usize = 128; - let mut a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect(); - let mut b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect(); + let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect(); + let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect(); let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect(); for i in 0..I { for k in 0..K { @@ -40,8 +40,8 @@ fn main() { } #[cfg(feature = "cuda")] { - let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a)); - let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b)); + let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&a)); + let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&b)); let mut r = runner!(matmul); let c = r .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()) diff --git a/juno_samples/multi_device/src/main.rs b/juno_samples/multi_device/src/main.rs index f62de56c..c989bd7b 100644 --- a/juno_samples/multi_device/src/main.rs +++ b/juno_samples/multi_device/src/main.rs @@ -6,8 +6,8 @@ use hercules_rt::runner; #[cfg(feature = "cuda")] juno_build::juno!("multi_device"); -#[cfg(feature = "cuda")] fn main() { + #[cfg(feature = "cuda")] async_std::task::block_on(async { let mut r = runner!(multi_device_1); let out = r.run(42).await; -- GitLab