From a8452c2f86002512eb9f956881eaf3483bfaf4ad Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Mon, 17 Feb 2025 20:56:25 -0600 Subject: [PATCH] Use enum for library call kinds --- hercules_ir/src/dot.rs | 2 +- hercules_ir/src/ir.rs | 10 +++++++++- hercules_opt/src/rewrite_math_expressions.rs | 19 ++++++++----------- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs index 1773f7dd..921a813d 100644 --- a/hercules_ir/src/dot.rs +++ b/hercules_ir/src/dot.rs @@ -323,7 +323,7 @@ fn write_node<W: Write>( args: _, ty: _, device, - } => write!(&mut suffix, "{} on {:?}", library_function, device)?, + } => write!(&mut suffix, "{:?} on {:?}", library_function, device)?, Node::Read { collect: _, indices, diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 5abc9adf..bf9698b3 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -223,7 +223,7 @@ pub enum Node { args: Box<[NodeID]>, }, LibraryCall { - library_function: String, + library_function: LibraryFunction, args: Box<[NodeID]>, ty: TypeID, device: Device, @@ -351,6 +351,14 @@ pub enum Device { AsyncRust, } +/* + * The authoritative enumeration of supported library calls. + */ +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +pub enum LibraryFunction { + GEMM, +} + /* * A single node may have multiple schedules. */ diff --git a/hercules_opt/src/rewrite_math_expressions.rs b/hercules_opt/src/rewrite_math_expressions.rs index 4d76c608..6f52dc58 100644 --- a/hercules_opt/src/rewrite_math_expressions.rs +++ b/hercules_opt/src/rewrite_math_expressions.rs @@ -22,7 +22,7 @@ define_language! { "+" = Add([Id; 2]), "*" = Mul([Id; 2]), - "library_matmul" = LibraryMatmul([Id; 2]), + "library_gemm" = LibraryGemm([Id; 2]), Opaque(Symbol), } @@ -33,7 +33,7 @@ fn make_rules() -> Vec<Rewrite<MathLanguage, ()>> { rewrite!("add-zero"; "(+ zero ?a)" => "?a"), rewrite!("mul-zero"; "(* zero ?a)" => "zero"), rewrite!("mul-one"; "(* one ?a)" => "?a"), - rewrite!("library-matmul"; "(array ?i ?k (sum ?j (* (read ?i ?j ?A) (read ?j ?k ?B))))" => "(library_matmul ?A ?B)"), + rewrite!("library-gemm"; "(array ?i ?k (sum ?j (* (read ?i ?j ?A) (read ?j ?k ?B))))" => "(library_gemm ?A ?B)"), ] } @@ -65,7 +65,7 @@ pub fn rewrite_math_expressions( reduces.sort(); for (_, reduce) in reduces { - let join = editor.func().nodes[reduce.idx()].try_reduce().unwrap().0; + let (join, init, _) = editor.func().nodes[reduce.idx()].try_reduce().unwrap(); let fork = join_fork_map[&join]; // Step 2: convert the reduce to an expression in egg and rewrite it. @@ -77,21 +77,18 @@ pub fn rewrite_math_expressions( // Step 3: match the smallest expression against patterns for known // library functions. - let matmul_pattern: Pattern<MathLanguage> = "(library_matmul ?A ?B)".parse().unwrap(); - let mut matches = matmul_pattern.search(&egraph); + let gemm_pattern: Pattern<MathLanguage> = "(library_gemm ?A ?B)".parse().unwrap(); + let mut matches = gemm_pattern.search(&egraph); if matches.len() > 0 && let m = matches.remove(0) && m.substs.len() > 0 { let left_id = get_node_ids_from_subst("?A", &m.substs[0], &egraph); let right_id = get_node_ids_from_subst("?B", &m.substs[0], &egraph); - let library_function = "gemm".to_string(); - let args = vec![left_id, right_id].into_boxed_slice(); - let ty = typing[reduce.idx()]; let call = Node::LibraryCall { - library_function, - args, - ty, + library_function: LibraryFunction::GEMM, + args: vec![init, left_id, right_id].into_boxed_slice(), + ty: typing[reduce.idx()], device, }; let success = editor.edit(|mut edit| { -- GitLab