diff --git a/Cargo.lock b/Cargo.lock index 43eaa90984783995b495e7e059e5414506e5ef92..c438e846d87f12fd7eea2c27d73a8fc7e043836d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23,6 +23,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1" +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "anstream" version = "0.6.18" @@ -621,6 +627,27 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" +[[package]] +name = "egg" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "abb749745461743bb477fba3ef87c663d5965876155c676c9489cfe0963de5ab" +dependencies = [ + "env_logger", + "hashbrown", + "indexmap", + "log", + "num-bigint", + "num-traits", + "quanta", + "rustc-hash", + "saturating", + "smallvec", + "symbol_table", + "symbolic_expressions", + "thiserror", +] + [[package]] name = "either" version = "1.13.0" @@ -639,6 +666,15 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d" +[[package]] +name = "env_logger" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a12e6657c4c97ebab115a42dcee77225f7f482cdd841cf7088c657a42e9e00e7" +dependencies = [ + "log", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -752,6 +788,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" + [[package]] name = "funty" version = "1.1.0" @@ -882,6 +924,11 @@ name = "hashbrown" version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", +] [[package]] name = "heapless" @@ -955,6 +1002,7 @@ version = "0.1.0" dependencies = [ "bimap", "bitvec 1.0.1", + "egg", "either", "hercules_cg", "hercules_ir", @@ -1214,7 +1262,7 @@ dependencies = [ "async-std", "hercules_rt", "juno_build", - "rand 0.8.5", + "rand 0.9.0", "with_builtin_macros", ] @@ -1970,6 +2018,21 @@ dependencies = [ "bytemuck", ] +[[package]] +name = "quanta" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3bd1fe6824cea6538803de3ff1bc0cf3949024db3d43c9643024bfb33a807c0e" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi 0.11.0+wasi-snapshot-preview1", + "web-sys", + "winapi", +] + [[package]] name = "quick-error" version = "2.0.1" @@ -2110,6 +2173,15 @@ dependencies = [ "rgb", ] +[[package]] +name = "raw-cpuid" +version = "11.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "529468c1335c1c03919960dfefdb1b3648858c20d7ec2d0663e728e4a717efbc" +dependencies = [ + "bitflags 2.8.0", +] + [[package]] name = "rayon" version = "1.10.0" @@ -2168,6 +2240,12 @@ version = "0.8.50" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a" +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + [[package]] name = "rustc_version" version = "0.4.1" @@ -2202,6 +2280,12 @@ version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" +[[package]] +name = "saturating" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ece8e78b2f38ec51c51f5d475df0a7187ba5111b2a28bdc761ee05b075d40a71" + [[package]] name = "scopeguard" version = "1.2.0" @@ -2333,6 +2417,23 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "symbol_table" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f19bffd69fb182e684d14e3c71d04c0ef33d1641ac0b9e81c712c734e83703bc" +dependencies = [ + "crossbeam-utils", + "foldhash", + "hashbrown", +] + +[[package]] +name = "symbolic_expressions" +version = "5.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c68d531d83ec6c531150584c42a4290911964d5f0d79132b193b67252a23b71" + [[package]] name = "syn" version = "1.0.109" @@ -2697,6 +2798,28 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows" version = "0.59.0" diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 4d9a6cf63605362cd1ddb96a8486974431cc791c..5edddd86df2901c29109148e2b107c0b923dbd0e 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -597,6 +597,59 @@ impl<'a> RTContext<'a> { } write!(block, "){};", postfix)?; } + Node::LibraryCall { + library_function, + ref args, + ty, + device, + } => match library_function { + LibraryFunction::GEMM => { + assert_eq!(args.len(), 3); + assert_eq!(self.typing[args[0].idx()], ty); + let c_ty = &self.module.types[self.typing[args[0].idx()].idx()]; + let a_ty = &self.module.types[self.typing[args[1].idx()].idx()]; + let b_ty = &self.module.types[self.typing[args[2].idx()].idx()]; + let ( + Type::Array(c_elem, c_dims), + Type::Array(a_elem, a_dims), + Type::Array(b_elem, b_dims), + ) = (c_ty, a_ty, b_ty) + else { + panic!(); + }; + assert_eq!(a_elem, b_elem); + assert_eq!(a_elem, c_elem); + assert_eq!(c_dims.len(), 2); + assert_eq!(a_dims.len(), 2); + assert_eq!(b_dims.len(), 2); + assert_eq!(a_dims[1], b_dims[0]); + assert_eq!(a_dims[0], c_dims[0]); + assert_eq!(b_dims[1], c_dims[1]); + + let block = &mut blocks.get_mut(&bb).unwrap().data; + let prim_ty = self.library_prim_ty(*a_elem); + write!(block, "::hercules_rt::__library_{}_gemm(", device.name())?; + self.codegen_dynamic_constant(a_dims[0], block)?; + write!(block, ", ")?; + self.codegen_dynamic_constant(a_dims[1], block)?; + write!(block, ", ")?; + self.codegen_dynamic_constant(b_dims[1], block)?; + write!( + block, + ", {}.0, {}.0, {}.0, {});", + self.get_value(args[0], bb, false), + self.get_value(args[1], bb, false), + self.get_value(args[2], bb, false), + prim_ty + )?; + write!( + block, + "{} = {};", + self.get_value(id, bb, true), + self.get_value(args[0], bb, false) + )?; + } + }, Node::Unary { op, input } => { let block = &mut blocks.get_mut(&bb).unwrap().data; match op { @@ -1316,6 +1369,25 @@ impl<'a> RTContext<'a> { fn get_type(&self, id: TypeID) -> &'static str { convert_type(&self.module.types[id.idx()]) } + + fn library_prim_ty(&self, id: TypeID) -> &'static str { + match self.module.types[id.idx()] { + Type::Boolean => "::hercules_rt::PrimTy::Bool", + Type::Integer8 => "::hercules_rt::PrimTy::I8", + Type::Integer16 => "::hercules_rt::PrimTy::I16", + Type::Integer32 => "::hercules_rt::PrimTy::I32", + Type::Integer64 => "::hercules_rt::PrimTy::I64", + Type::UnsignedInteger8 => "::hercules_rt::PrimTy::U8", + Type::UnsignedInteger16 => "::hercules_rt::PrimTy::U16", + Type::UnsignedInteger32 => "::hercules_rt::PrimTy::U32", + Type::UnsignedInteger64 => "::hercules_rt::PrimTy::U64", + Type::Float8 => "::hercules_rt::PrimTy::F8", + Type::BFloat16 => "::hercules_rt::PrimTy::BF16", + Type::Float32 => "::hercules_rt::PrimTy::F32", + Type::Float64 => "::hercules_rt::PrimTy::F64", + _ => panic!(), + } + } } fn convert_type(ty: &Type) -> &'static str { diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index f3474ae069c98c7d7cd2cd46bdbcc7a2719c0c56..6b631519d69cf2a548164eaad58cb2574c6b70c3 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -218,6 +218,8 @@ pub fn collection_objects( // - Constant: may originate an object. // - Call: may originate an object and may return an object passed in as // a parameter. + // - LibraryCall: may return an object passed in as a parameter, but may + // not originate an object. // - Read: may extract a smaller object from the input - this is // considered to be the same object as the input, as no copy takes // place. @@ -288,6 +290,14 @@ pub fn collection_objects( } CollectionObjectLattice { objs } } + Node::LibraryCall { + library_function, + args: _, + ty: _, + device: _, + } => match library_function { + LibraryFunction::GEMM => inputs[0].clone(), + }, Node::Undef { ty: _ } => { let obj = origins .iter() @@ -332,7 +342,13 @@ pub fn collection_objects( for object in objects_per_node[idx].iter() { mutated[object.idx()].push(NodeID::new(idx)); } - } else if let Some((_, callee, _, args)) = node.try_call() { + } else if let Node::Call { + control: _, + function: callee, + dynamic_constants: _, + args, + } = node + { let fco = &collection_objects[&callee]; for (param_idx, arg) in args.into_iter().enumerate() { // If this parameter corresponds to an object and it's @@ -347,6 +363,20 @@ pub fn collection_objects( } } } + } else if let Node::LibraryCall { + library_function, + args, + ty: _, + device: _, + } = node + { + match library_function { + LibraryFunction::GEMM => { + for object in objects_per_node[args[0].idx()].iter() { + mutated[object.idx()].push(NodeID::new(idx)); + } + } + } } } diff --git a/hercules_ir/src/def_use.rs b/hercules_ir/src/def_use.rs index d60f7fd5a84605cc31d22347cf7797b22e55a3cd..ff0e08edc8c15f76ea38243eebea7cba1940cd19 100644 --- a/hercules_ir/src/def_use.rs +++ b/hercules_ir/src/def_use.rs @@ -178,7 +178,13 @@ pub fn get_uses(node: &Node) -> NodeUses { uses.extend(args); NodeUses::Variable(uses.into_boxed_slice()) } - Node::IntrinsicCall { intrinsic: _, args } => NodeUses::Variable(args.clone()), + Node::IntrinsicCall { intrinsic: _, args } + | Node::LibraryCall { + library_function: _, + args, + ty: _, + device: _, + } => NodeUses::Variable(args.clone()), Node::Read { collect, indices } => { let mut uses = vec![]; for index in indices.iter() { @@ -276,9 +282,13 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> { uses.extend(args); NodeUsesMut::Variable(uses.into_boxed_slice()) } - Node::IntrinsicCall { intrinsic: _, args } => { - NodeUsesMut::Variable(args.iter_mut().collect()) - } + Node::IntrinsicCall { intrinsic: _, args } + | Node::LibraryCall { + library_function: _, + args, + ty: _, + device: _, + } => NodeUsesMut::Variable(args.iter_mut().collect()), Node::Read { collect, indices } => { let mut uses = vec![]; for index in indices.iter_mut() { diff --git a/hercules_ir/src/device.rs b/hercules_ir/src/device.rs index cbf8d6349857b6110cc4a9e03f574b358646b29a..c4a5454bc9ea3daaadc602c0c7517d8310c7be93 100644 --- a/hercules_ir/src/device.rs +++ b/hercules_ir/src/device.rs @@ -23,118 +23,3 @@ pub fn device_placement(functions: &Vec<Function>, callgraph: &CallGraph) -> Vec devices } - -pub type FunctionObjectDeviceDemands = Vec<BTreeSet<Device>>; -pub type ObjectDeviceDemands = Vec<FunctionObjectDeviceDemands>; - -/* - * This analysis figures out which device each collection object may be on. At - * first, an object may need to be on different devices at different times. This - * is fine during optimization. - */ -pub fn object_device_demands( - functions: &Vec<Function>, - types: &Vec<Type>, - typing: &ModuleTyping, - callgraph: &CallGraph, - objects: &CollectionObjects, - devices: &Vec<Device>, -) -> ObjectDeviceDemands { - // An object is "demanded" on a device when: - // 1. The object is used by a primitive read node or write node in a device - // function. This includes objects on the `data` input to write nodes. - // Non-primitive reads don't demand an object on a device since they are - // lowered to pointer math and no actual memory transfers. - // 2. The object is a constant / undef defined in a device function. - // 3. The object is passed as input to a call node where the corresponding - // object in the callee is demanded on a device. - // 4. The object is returned from a call node where the corresponding object - // in the callee is demanded on a device. - // Note that reads and writes in a RT function don't induce a device demand. - // This is because RT functions can call device functions as necessary to - // arbitrarily move data onto / off of devices (though this may be slow). - // Traverse the functions in a module in reverse topological order, since - // the analysis of a function depends on all functions it calls. - let mut demands: ObjectDeviceDemands = vec![vec![]; functions.len()]; - let topo = callgraph.topo(); - - for func_id in topo { - let function = &functions[func_id.idx()]; - let typing = &typing[func_id.idx()]; - let device = devices[func_id.idx()]; - - demands[func_id.idx()].resize(objects[&func_id].num_objects(), BTreeSet::new()); - match device { - Device::LLVM | Device::CUDA => { - for (idx, node) in function.nodes.iter().enumerate() { - match node { - // Condition #1. - Node::Read { - collect, - indices: _, - } if types[typing[idx].idx()].is_primitive() => { - for object in objects[&func_id].objects(*collect) { - demands[func_id.idx()][object.idx()].insert(device); - } - } - Node::Write { - collect, - data, - indices: _, - } => { - for object in objects[&func_id] - .objects(*collect) - .into_iter() - .chain(objects[&func_id].objects(*data).into_iter()) - { - demands[func_id.idx()][object.idx()].insert(device); - } - } - // Condition #2. - Node::Constant { id: _ } | Node::Undef { ty: _ } => { - for object in objects[&func_id].objects(NodeID::new(idx)) { - demands[func_id.idx()][object.idx()].insert(device); - } - } - _ => {} - } - } - } - Device::AsyncRust => { - for (idx, node) in function.nodes.iter().enumerate() { - if let Node::Call { - control: _, - function: callee, - dynamic_constants: _, - args, - } = node - { - // Condition #3. - for (param_idx, arg) in args.into_iter().enumerate() { - if let Some(callee_obj) = objects[callee].param_to_object(param_idx) { - let callee_demands = - take(&mut demands[callee.idx()][callee_obj.idx()]); - for object in objects[&func_id].objects(*arg) { - demands[func_id.idx()][object.idx()] - .extend(callee_demands.iter()); - } - demands[callee.idx()][callee_obj.idx()] = callee_demands; - } - } - - // Condition #4. - for callee_obj in objects[callee].returned_objects() { - let callee_demands = take(&mut demands[callee.idx()][callee_obj.idx()]); - for object in objects[&func_id].objects(NodeID::new(idx)) { - demands[func_id.idx()][object.idx()].extend(callee_demands.iter()); - } - demands[callee.idx()][callee_obj.idx()] = callee_demands; - } - } - } - } - } - } - - demands -} diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs index 0e0840853bdcb83caa954991f40397b33e1fd5eb..921a813d7c9d680defe2e9dd8fa750ddf960d0cb 100644 --- a/hercules_ir/src/dot.rs +++ b/hercules_ir/src/dot.rs @@ -318,6 +318,12 @@ fn write_node<W: Write>( Node::IntrinsicCall { intrinsic, args: _ } => { write!(&mut suffix, "{}", intrinsic.lower_case_name())? } + Node::LibraryCall { + library_function, + args: _, + ty: _, + device, + } => write!(&mut suffix, "{:?} on {:?}", library_function, device)?, Node::Read { collect: _, indices, diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index f96cc10c003315aaab022d69aaad32838ec53cf3..bf9698b32125832b047ee9a94668bd0a09b93ac9 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -222,6 +222,12 @@ pub enum Node { intrinsic: Intrinsic, args: Box<[NodeID]>, }, + LibraryCall { + library_function: LibraryFunction, + args: Box<[NodeID]>, + ty: TypeID, + device: Device, + }, Read { collect: NodeID, indices: Box<[Index]>, @@ -336,7 +342,7 @@ pub enum Schedule { * The authoritative enumeration of supported backends. Multiple backends may * correspond to the same kind of hardware. */ -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] pub enum Device { LLVM, CUDA, @@ -345,6 +351,14 @@ pub enum Device { AsyncRust, } +/* + * The authoritative enumeration of supported library calls. + */ +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +pub enum LibraryFunction { + GEMM, +} + /* * A single node may have multiple schedules. */ @@ -1531,6 +1545,12 @@ impl Node { intrinsic: _, args: _, } => "Intrinsic", + Node::LibraryCall { + library_function: _, + args: _, + ty: _, + device: _, + } => "Library", Node::Read { collect: _, indices: _, @@ -1604,6 +1624,12 @@ impl Node { intrinsic: _, args: _, } => "intrinsic", + Node::LibraryCall { + library_function: _, + args: _, + ty: _, + device: _, + } => "library", Node::Read { collect: _, indices: _, diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index d01a5c58d8de15ba291a75a6f1b0528433e36d2c..1ff890db4ff4ed3678be0338106f4add2b829aba 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -935,6 +935,12 @@ fn typeflow( } } } + Node::LibraryCall { + library_function: _, + args: _, + ty, + device: _, + } => Concrete(*ty), Node::Read { collect: _, indices, diff --git a/hercules_opt/Cargo.toml b/hercules_opt/Cargo.toml index 0b250d28d9460bd0b19f574c27916e82802664c8..92e0533938eab0cc76fe4b69a205f4dceae2f244 100644 --- a/hercules_opt/Cargo.toml +++ b/hercules_opt/Cargo.toml @@ -21,4 +21,5 @@ serde = { version = "*", features = ["derive"] } hercules_cg = { path = "../hercules_cg" } hercules_ir = { path = "../hercules_ir" } nestify = "*" -bimap = "*" \ No newline at end of file +bimap = "*" +egg = "*" diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs index 1969430a95ea020e9d1fb87d4165ea219d3f96a9..b626148c936e384b6bc0a9aaf951c35a9c4b4736 100644 --- a/hercules_opt/src/ccp.rs +++ b/hercules_opt/src/ccp.rs @@ -933,6 +933,17 @@ fn ccp_flow_function( constant: new_constant, } } + Node::LibraryCall { + library_function: _, + args, + ty: _, + device: _, + } => CCPLattice { + reachability: args.iter().fold(ReachabilityLattice::bottom(), |val, id| { + ReachabilityLattice::join(&val, &inputs[id.idx()].reachability) + }), + constant: ConstantLattice::bottom(), + }, Node::Read { collect, indices } => { let mut reachability = inputs[collect.idx()].reachability.clone(); for index in indices.iter() { diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index 821d02ea886fdcbd5ffa35ad9a298128a92dbb07..446b31849b20fd3fa3d5361a74c15f0514d1dbad 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -655,6 +655,14 @@ fn terminating_reads<'a>( None } })), + Node::LibraryCall { + library_function, + ref args, + ty: _, + device: _, + } => match library_function { + LibraryFunction::GEMM => Box::new(once(args[1]).chain(once(args[2]))), + }, _ => Box::new(empty()), } } @@ -728,6 +736,16 @@ fn mutating_objects<'a>( }) .flatten(), ), + Node::LibraryCall { + library_function, + ref args, + ty: _, + device: _, + } => match library_function { + LibraryFunction::GEMM => { + Box::new(objects[&func_id].objects(args[0]).into_iter().map(|id| *id)) + } + }, _ => Box::new(empty()), } } @@ -757,6 +775,14 @@ fn mutating_writes<'a>( None } })), + Node::LibraryCall { + library_function, + ref args, + ty: _, + device: _, + } => match library_function { + LibraryFunction::GEMM => Box::new(once(args[0])), + }, _ => Box::new(empty()), } } @@ -1311,6 +1337,17 @@ fn color_nodes( } } } + Node::LibraryCall { + library_function: _, + ref args, + ty: _, + device, + } => { + for arg in args { + equations.push((UTerm::Node(*arg), UTerm::Device(device))); + } + equations.push((UTerm::Node(id), UTerm::Device(device))); + } _ => {} } } diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs index b56f94086204444b49672ee6baefccdaa0b0cb8b..b25449e7bd143e82c6d8b2fb39f2d8e238bd5cfa 100644 --- a/hercules_opt/src/lib.rs +++ b/hercules_opt/src/lib.rs @@ -21,6 +21,7 @@ pub mod outline; pub mod phi_elim; pub mod pred; pub mod reuse_products; +pub mod rewrite_math_expressions; pub mod schedule; pub mod simplify_cfg; pub mod slf; @@ -49,6 +50,7 @@ pub use crate::outline::*; pub use crate::phi_elim::*; pub use crate::pred::*; pub use crate::reuse_products::*; +pub use crate::rewrite_math_expressions::*; pub use crate::schedule::*; pub use crate::simplify_cfg::*; pub use crate::slf::*; diff --git a/hercules_opt/src/rewrite_math_expressions.rs b/hercules_opt/src/rewrite_math_expressions.rs new file mode 100644 index 0000000000000000000000000000000000000000..6f52dc58ee32751b96190b42a0add01d42ce3b1e --- /dev/null +++ b/hercules_opt/src/rewrite_math_expressions.rs @@ -0,0 +1,166 @@ +use std::collections::{HashMap, HashSet}; +use std::fmt::{Error, Write}; + +use hercules_ir::*; + +use egg::*; + +use crate::*; + +define_language! { + enum MathLanguage { + "zero" = Zero, + "one" = One, + + ForkDim(i64), + "tid" = ThreadID(Id), + "sum" = SumReduction(Box<[Id]>), + "array" = Comprehension(Box<[Id]>), + + "read" = Read(Box<[Id]>), + + "+" = Add([Id; 2]), + "*" = Mul([Id; 2]), + + "library_gemm" = LibraryGemm([Id; 2]), + + Opaque(Symbol), + } +} + +fn make_rules() -> Vec<Rewrite<MathLanguage, ()>> { + vec![ + rewrite!("add-zero"; "(+ zero ?a)" => "?a"), + rewrite!("mul-zero"; "(* zero ?a)" => "zero"), + rewrite!("mul-one"; "(* one ?a)" => "?a"), + rewrite!("library-gemm"; "(array ?i ?k (sum ?j (* (read ?i ?j ?A) (read ?j ?k ?B))))" => "(library_gemm ?A ?B)"), + ] +} + +pub fn rewrite_math_expressions( + editor: &mut FunctionEditor, + device: Device, + typing: &Vec<TypeID>, + fork_join_map: &HashMap<NodeID, NodeID>, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, + reduce_einsums: &(MathEnv, HashMap<NodeID, MathID>), +) { + let join_fork_map: HashMap<_, _> = fork_join_map + .into_iter() + .map(|(fork, join)| (*join, *fork)) + .collect(); + + // Step 1: figure out how many fork-joins each reduce is in. We want to + // rewrite outer reductions before inner ones. + let mut depth: HashMap<NodeID, u32> = HashMap::new(); + for (_, inside) in nodes_in_fork_joins { + for id in inside { + if editor.func().nodes[id.idx()].is_reduce() { + *depth.entry(*id).or_default() += 1; + } + } + } + let mut reduces: Vec<(u32, NodeID)> = + depth.into_iter().map(|(id, depth)| (depth, id)).collect(); + reduces.sort(); + + for (_, reduce) in reduces { + let (join, init, _) = editor.func().nodes[reduce.idx()].try_reduce().unwrap(); + let fork = join_fork_map[&join]; + + // Step 2: convert the reduce to an expression in egg and rewrite it. + let mut s = String::new(); + egg_print_math_expr(reduce_einsums.1[&reduce], &reduce_einsums.0, &mut s).unwrap(); + + let expr: RecExpr<MathLanguage> = s.parse().unwrap(); + let egraph = Runner::default().with_expr(&expr).run(&make_rules()).egraph; + + // Step 3: match the smallest expression against patterns for known + // library functions. + let gemm_pattern: Pattern<MathLanguage> = "(library_gemm ?A ?B)".parse().unwrap(); + let mut matches = gemm_pattern.search(&egraph); + if matches.len() > 0 + && let m = matches.remove(0) + && m.substs.len() > 0 + { + let left_id = get_node_ids_from_subst("?A", &m.substs[0], &egraph); + let right_id = get_node_ids_from_subst("?B", &m.substs[0], &egraph); + let call = Node::LibraryCall { + library_function: LibraryFunction::GEMM, + args: vec![init, left_id, right_id].into_boxed_slice(), + ty: typing[reduce.idx()], + device, + }; + let success = editor.edit(|mut edit| { + let call = edit.add_node(call); + edit.replace_all_uses_where(reduce, call, |id| { + !nodes_in_fork_joins[&fork].contains(id) + }) + }); + if success { + return; + } + } + } +} + +fn get_node_ids_from_subst(var: &str, subst: &Subst, egraph: &EGraph<MathLanguage, ()>) -> NodeID { + let id = *subst.get(var.parse::<Var>().unwrap()).unwrap(); + let expr = egraph.id_to_expr(id); + let MathLanguage::Opaque(sym) = expr.last().unwrap() else { + todo!(); + }; + let sym = sym.as_str(); + assert!(sym.chars().nth(0).unwrap() == 'n'); + NodeID::new(sym[1..].parse().unwrap()) +} + +fn egg_print_math_expr<W: Write>(id: MathID, env: &MathEnv, w: &mut W) -> Result<(), Error> { + match env[id.idx()] { + MathExpr::Zero(_) => write!(w, "zero"), + MathExpr::One(_) => write!(w, "one"), + MathExpr::OpaqueNode(id) => write!(w, "n{}", id.idx()), + MathExpr::ThreadID(dim) => write!(w, "{}", dim.0), + MathExpr::SumReduction(id, ref dims) => { + write!(w, "(sum")?; + for dim in dims { + write!(w, " {}", dim.0)?; + } + write!(w, " ")?; + egg_print_math_expr(id, env, w)?; + write!(w, ")") + } + MathExpr::Comprehension(id, ref dims) => { + write!(w, "(array")?; + for dim in dims { + write!(w, " {}", dim.0)?; + } + write!(w, " ")?; + egg_print_math_expr(id, env, w)?; + write!(w, ")") + } + MathExpr::Read(id, ref pos) => { + write!(w, "(read")?; + for pos in pos { + write!(w, " ")?; + egg_print_math_expr(*pos, env, w)?; + } + write!(w, " ")?; + egg_print_math_expr(id, env, w)?; + write!(w, ")") + } + MathExpr::Binary(op, left, right) => { + write!(w, "(")?; + match op { + BinaryOperator::Add => write!(w, "+ "), + BinaryOperator::Mul => write!(w, "* "), + _ => Err(Error::default()), + }?; + egg_print_math_expr(left, env, w)?; + write!(w, " ")?; + egg_print_math_expr(right, env, w)?; + write!(w, ")") + } + _ => Err(Error::default()), + } +} diff --git a/hercules_opt/src/simplify_cfg.rs b/hercules_opt/src/simplify_cfg.rs index 14a152dcb3dae335e5f9724bdb1f467cca5e1c19..cf39db2b35283da7eb641101478df68f858cb8c3 100644 --- a/hercules_opt/src/simplify_cfg.rs +++ b/hercules_opt/src/simplify_cfg.rs @@ -126,9 +126,7 @@ fn remove_useless_fork_joins( // Third, get rid of fork-joins. for (fork, join) in fork_join_map { - if editor.get_users(*join).len() == 1 { - assert_eq!(editor.get_users(*fork).len(), 1); - + if editor.get_users(*fork).len() == 1 && editor.get_users(*join).len() == 1 { let fork_use = get_uses(&editor.func().nodes[fork.idx()]).as_ref()[0]; let join_use = get_uses(&editor.func().nodes[join.idx()]).as_ref()[0]; diff --git a/hercules_rt/build.rs b/hercules_rt/build.rs index 2a1538d600e1632e9311ef9464c38a83fdf44415..ab9dda2e0948e40d3c0bf226a855d56e273289c6 100644 --- a/hercules_rt/build.rs +++ b/hercules_rt/build.rs @@ -28,6 +28,7 @@ fn main() { println!("cargo::rustc-link-search=native=/opt/cuda/lib/"); println!("cargo::rustc-link-lib=static=rtdefs"); println!("cargo::rustc-link-lib=cudart"); + println!("cargo::rustc-link-lib=cublas"); println!("cargo::rerun-if-changed=src/rtdefs.cu"); } } diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs index 79b5cbac72422e82e53ac2e74dc15798968ea155..419a760fa49647c28295625bd7df9db0e87707c1 100644 --- a/hercules_rt/src/lib.rs +++ b/hercules_rt/src/lib.rs @@ -29,7 +29,10 @@ pub unsafe fn __cpu_dealloc(ptr: *mut u8, size: usize) { eprintln!("__cpu_dealloc: {:?}, {}", ptr, size); assert!(!ptr.is_null() || size == 0); } - dealloc(ptr, Layout::from_size_align(size, LARGEST_ALIGNMENT).unwrap()) + dealloc( + ptr, + Layout::from_size_align(size, LARGEST_ALIGNMENT).unwrap(), + ) } pub unsafe fn __cpu_zero_mem(ptr: *mut u8, size: usize) { @@ -103,14 +106,48 @@ pub unsafe fn __copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize) { ___copy_cuda_to_cuda(dst, src, size); } +#[derive(Debug, Copy, Clone)] +pub enum PrimTy { + Bool, + U8, + U16, + U32, + U64, + I8, + I16, + I32, + I64, + F8, + BF16, + F32, + F64, +} + +#[cfg(feature = "cuda")] +pub unsafe fn __library_cuda_gemm( + i: u64, + j: u64, + k: u64, + c: *mut u8, + a: *const u8, + b: *const u8, + ty: PrimTy, +) { + match ty { + PrimTy::F32 => ___cublas_sgemm(i, j, k, c, a, b), + _ => todo!(), + } +} + #[cfg(feature = "cuda")] extern "C" { - pub fn ___cuda_alloc(size: usize) -> *mut u8; - pub fn ___cuda_dealloc(ptr: *mut u8, size: usize); - pub fn ___cuda_zero_mem(ptr: *mut u8, size: usize); - pub fn ___copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize); - pub fn ___copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize); - pub fn ___copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize); + fn ___cuda_alloc(size: usize) -> *mut u8; + fn ___cuda_dealloc(ptr: *mut u8, size: usize); + fn ___cuda_zero_mem(ptr: *mut u8, size: usize); + fn ___copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize); + fn ___copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize); + fn ___copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize); + fn ___cublas_sgemm(i: u64, j: u64, k: u64, c: *mut u8, a: *const u8, b: *const u8); } #[derive(Clone, Debug)] diff --git a/hercules_rt/src/rtdefs.cu b/hercules_rt/src/rtdefs.cu index 50e11fa6748d8569b511b226ca460f5ba1855e8d..26e698218b92b88cbae013daeb3bea3db8f4f5ce 100644 --- a/hercules_rt/src/rtdefs.cu +++ b/hercules_rt/src/rtdefs.cu @@ -1,31 +1,49 @@ -extern "C" { - void *___cuda_alloc(size_t size) { - void *ptr = NULL; - cudaError_t res = cudaMalloc(&ptr, size); - if (res != cudaSuccess) { - ptr = NULL; - } - return ptr; - } +#include <stdint.h> +#include <cublas_v2.h> - void ___cuda_dealloc(void *ptr, size_t size) { - (void) size; - cudaFree(ptr); - } - - void ___cuda_zero_mem(void *ptr, size_t size) { - cudaMemset(ptr, 0, size); - } +static cublasHandle_t cublas_handle = 0; - void ___copy_cpu_to_cuda(void *dst, void *src, size_t size) { - cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice); - } - - void ___copy_cuda_to_cpu(void *dst, void *src, size_t size) { - cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost); +extern "C" { + void *___cuda_alloc(size_t size) { + void *ptr = NULL; + cudaError_t res = cudaMalloc(&ptr, size); + if (res != cudaSuccess) { + ptr = NULL; } - - void ___copy_cuda_to_cuda(void *dst, void *src, size_t size) { - cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice); + return ptr; + } + + void ___cuda_dealloc(void *ptr, size_t size) { + (void) size; + cudaFree(ptr); + } + + void ___cuda_zero_mem(void *ptr, size_t size) { + cudaMemset(ptr, 0, size); + } + + void ___copy_cpu_to_cuda(void *dst, void *src, size_t size) { + cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice); + } + + void ___copy_cuda_to_cpu(void *dst, void *src, size_t size) { + cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost); + } + + void ___copy_cuda_to_cuda(void *dst, void *src, size_t size) { + cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice); + } + + void ___cublas_sgemm(uint64_t i, uint64_t j, uint64_t k, float *c, float *a, float *b) { + if (!cublas_handle) { + cublasCreate(&cublas_handle); } + float alf = 1.0; + float beta = 0.0; + cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, + k, i, j, + &alf, b, k, a, j, + &beta, c, k); + cudaDeviceSynchronize(); + } } diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch index edb83d74a54bd72c0207923b82485eff8905cbb6..76808149e7b99dce97e43f0e936536d3d13b7417 100644 --- a/juno_samples/matmul/src/gpu.sch +++ b/juno_samples/matmul/src/gpu.sch @@ -12,9 +12,12 @@ fixpoint { fork-coalesce(*); infer-schedules(*); dce(*); +rewrite(*); +fixpoint { + simplify-cfg(*); + dce(*); +} -let out = auto-outline(*); -gpu(out.matmul); ip-sroa(*); sroa(*); dce(*); diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs index 3cb7d7f0f35e847b7c88bf4ee8dc2b6536d30647..c0e228daa04704b90156a592d27b761aeba6591c 100644 --- a/juno_samples/matmul/src/main.rs +++ b/juno_samples/matmul/src/main.rs @@ -1,4 +1,5 @@ #![feature(concat_idents)] +use std::iter::zip; use rand::random; @@ -13,9 +14,9 @@ fn main() { const I: usize = 256; const J: usize = 64; const K: usize = 128; - let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect(); - let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect(); - let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect(); + let a: Box<[f32]> = (0..I * J).map(|_| random::<f32>()).collect(); + let b: Box<[f32]> = (0..J * K).map(|_| random::<f32>()).collect(); + let mut correct_c: Box<[f32]> = (0..I * K).map(|_| 0.0).collect(); for i in 0..I { for k in 0..K { for j in 0..J { @@ -27,7 +28,8 @@ fn main() { { let mut r = runner!(matmul); let c = r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await; - assert_eq!(c.as_slice::<i32>(), &*correct_c); + let c = c.as_slice::<f32>(); + assert_eq!(c, &*correct_c); } #[cfg(feature = "cuda")] { @@ -37,9 +39,9 @@ fn main() { let c = r .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()) .await; - let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice(); + let mut c_cpu: Box<[f32]> = vec![0.0; correct_c.len()].into_boxed_slice(); c.to_cpu_ref(&mut c_cpu); - assert_eq!(&*c_cpu, &*correct_c); + assert!(zip(c_cpu, correct_c).all(|(calc, correct)| (calc - correct).abs() < 0.00001)); } }); } diff --git a/juno_samples/matmul/src/matmul.jn b/juno_samples/matmul/src/matmul.jn index e36d94e209a2385fc28e54b1cfe4c432564d3706..460ce41c3742486100062b52942870fb35379081 100644 --- a/juno_samples/matmul/src/matmul.jn +++ b/juno_samples/matmul/src/matmul.jn @@ -1,6 +1,6 @@ #[entry] -fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] { - let res : i32[n, l]; +fn matmul<n : usize, m : usize, l : usize>(a : f32[n, m], b : f32[m, l]) -> f32[n, l] { + let res : f32[n, l]; @outer for i = 0 to n { @middle for j = 0 to l { diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 43871c908fbb809e9795d727125454fa4a3f80ff..e9132fd20650d1c06a9b13f8a8f0815f16f83337 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -141,6 +141,9 @@ impl FromStr for Appliable { "reduce-slf" => Ok(Appliable::Pass(ir::Pass::ReduceSLF)), "rename" => Ok(Appliable::Pass(ir::Pass::Rename)), "reuse-products" => Ok(Appliable::Pass(ir::Pass::ReuseProducts)), + "rewrite" | "rewrite-math" | "rewrite-math-expressions" => { + Ok(Appliable::Pass(ir::Pass::RewriteMathExpressions)) + } "simplify-cfg" => Ok(Appliable::Pass(ir::Pass::SimplifyCFG)), "slf" | "store-load-forward" => Ok(Appliable::Pass(ir::Pass::SLF)), "sroa" => Ok(Appliable::Pass(ir::Pass::SROA)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 378e1c49b67e31d1c83a876489b4880ed62b1121..4ac5a732d12f87b6eb76c2771c2c5addfa42cb50 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -34,6 +34,7 @@ pub enum Pass { ReduceSLF, Rename, ReuseProducts, + RewriteMathExpressions, SLF, SROA, Serialize, diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index aa5a8bf13f04b9d8dfb98ce7f4bbc39e1b86a920..d5c0af27e280d523683cec2b39da387a3f2c3f45 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2279,6 +2279,40 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } + Pass::RewriteMathExpressions => { + assert!(args.is_empty()); + pm.make_typing(); + pm.make_fork_join_maps(); + pm.make_nodes_in_fork_joins(); + pm.make_reduce_einsums(); + let typing = pm.typing.take().unwrap(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); + let reduce_einsums = pm.reduce_einsums.take().unwrap(); + for ((((func, typing), fork_join_map), nodes_in_fork_joins), reduce_einsums) in + build_selection(pm, selection, false) + .into_iter() + .zip(typing.iter()) + .zip(fork_join_maps.iter()) + .zip(nodes_in_fork_joins.iter()) + .zip(reduce_einsums.iter()) + { + let Some(mut func) = func else { + continue; + }; + rewrite_math_expressions( + &mut func, + Device::CUDA, + typing, + fork_join_map, + nodes_in_fork_joins, + reduce_einsums, + ); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::SLF => { assert!(args.is_empty()); pm.make_reverse_postorders();