From 204084e2830561b675c77eca1c4c714590c5121e Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Mon, 17 Feb 2025 15:00:57 -0600 Subject: [PATCH 01/13] Rewrite skeleton --- hercules_opt/src/lib.rs | 2 ++ hercules_opt/src/rewrite_math_expressions.rs | 17 ++++++++++++++++ juno_samples/matmul/src/gpu.sch | 2 ++ juno_scheduler/src/compile.rs | 3 +++ juno_scheduler/src/ir.rs | 1 + juno_scheduler/src/pm.rs | 21 ++++++++++++++++++++ 6 files changed, 46 insertions(+) create mode 100644 hercules_opt/src/rewrite_math_expressions.rs diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs index b56f9408..b25449e7 100644 --- a/hercules_opt/src/lib.rs +++ b/hercules_opt/src/lib.rs @@ -21,6 +21,7 @@ pub mod outline; pub mod phi_elim; pub mod pred; pub mod reuse_products; +pub mod rewrite_math_expressions; pub mod schedule; pub mod simplify_cfg; pub mod slf; @@ -49,6 +50,7 @@ pub use crate::outline::*; pub use crate::phi_elim::*; pub use crate::pred::*; pub use crate::reuse_products::*; +pub use crate::rewrite_math_expressions::*; pub use crate::schedule::*; pub use crate::simplify_cfg::*; pub use crate::slf::*; diff --git a/hercules_opt/src/rewrite_math_expressions.rs b/hercules_opt/src/rewrite_math_expressions.rs new file mode 100644 index 00000000..32161a79 --- /dev/null +++ b/hercules_opt/src/rewrite_math_expressions.rs @@ -0,0 +1,17 @@ +use std::collections::{HashMap, HashSet}; + +use hercules_ir::*; + +use crate::*; + +pub fn rewrite_math_expressions( + editor: &mut FunctionEditor, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, + reduce_einsums: &(MathEnv, HashMap<NodeID, MathID>), +) { + for (reduce, einsum) in reduce_einsums.1.iter() { + print!("{:?}: ", reduce); + debug_print_math_expr(*einsum, &reduce_einsums.0); + println!(""); + } +} diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch index edb83d74..c785fd5e 100644 --- a/juno_samples/matmul/src/gpu.sch +++ b/juno_samples/matmul/src/gpu.sch @@ -12,6 +12,8 @@ fixpoint { fork-coalesce(*); infer-schedules(*); dce(*); +rewrite(*); +xdot[true](*); let out = auto-outline(*); gpu(out.matmul); diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 43871c90..e9132fd2 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -141,6 +141,9 @@ impl FromStr for Appliable { "reduce-slf" => Ok(Appliable::Pass(ir::Pass::ReduceSLF)), "rename" => Ok(Appliable::Pass(ir::Pass::Rename)), "reuse-products" => Ok(Appliable::Pass(ir::Pass::ReuseProducts)), + "rewrite" | "rewrite-math" | "rewrite-math-expressions" => { + Ok(Appliable::Pass(ir::Pass::RewriteMathExpressions)) + } "simplify-cfg" => Ok(Appliable::Pass(ir::Pass::SimplifyCFG)), "slf" | "store-load-forward" => Ok(Appliable::Pass(ir::Pass::SLF)), "sroa" => Ok(Appliable::Pass(ir::Pass::SROA)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 5b6bd297..25cc5ef8 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -34,6 +34,7 @@ pub enum Pass { ReduceSLF, Rename, ReuseProducts, + RewriteMathExpressions, SLF, SROA, Serialize, diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index d83ff0bb..461bc645 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2279,6 +2279,27 @@ fn run_pass( pm.delete_gravestones(); pm.clear_analyses(); } + Pass::RewriteMathExpressions => { + assert!(args.is_empty()); + pm.make_nodes_in_fork_joins(); + pm.make_reduce_einsums(); + let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); + let reduce_einsums = pm.reduce_einsums.take().unwrap(); + for ((func, nodes_in_fork_joins), reduce_einsums) in + build_selection(pm, selection, false) + .into_iter() + .zip(nodes_in_fork_joins.iter()) + .zip(reduce_einsums.iter()) + { + let Some(mut func) = func else { + continue; + }; + rewrite_math_expressions(&mut func, nodes_in_fork_joins, reduce_einsums); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::SLF => { assert!(args.is_empty()); pm.make_reverse_postorders(); -- GitLab From 6e2acddfc47ef27fb048e84a42d34bce046e831d Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Mon, 17 Feb 2025 15:43:22 -0600 Subject: [PATCH 02/13] Construct egg expr --- Cargo.lock | 125 ++++++++++++++++++- hercules_opt/Cargo.toml | 3 +- hercules_opt/src/rewrite_math_expressions.rs | 92 +++++++++++++- 3 files changed, 216 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 81c37d79..b6ca23d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23,6 +23,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1" +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "anstream" version = "0.6.18" @@ -621,6 +627,27 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" +[[package]] +name = "egg" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "abb749745461743bb477fba3ef87c663d5965876155c676c9489cfe0963de5ab" +dependencies = [ + "env_logger", + "hashbrown", + "indexmap", + "log", + "num-bigint", + "num-traits", + "quanta", + "rustc-hash", + "saturating", + "smallvec", + "symbol_table", + "symbolic_expressions", + "thiserror", +] + [[package]] name = "either" version = "1.13.0" @@ -639,6 +666,15 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d" +[[package]] +name = "env_logger" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a12e6657c4c97ebab115a42dcee77225f7f482cdd841cf7088c657a42e9e00e7" +dependencies = [ + "log", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -752,6 +788,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" + [[package]] name = "funty" version = "1.1.0" @@ -882,6 +924,11 @@ name = "hashbrown" version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", +] [[package]] name = "heapless" @@ -955,6 +1002,7 @@ version = "0.1.0" dependencies = [ "bimap", "bitvec 1.0.1", + "egg", "either", "hercules_cg", "hercules_ir", @@ -1177,7 +1225,7 @@ dependencies = [ "async-std", "hercules_rt", "juno_build", - "rand 0.8.5", + "rand 0.9.0", "with_builtin_macros", ] @@ -1921,6 +1969,21 @@ dependencies = [ "bytemuck", ] +[[package]] +name = "quanta" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3bd1fe6824cea6538803de3ff1bc0cf3949024db3d43c9643024bfb33a807c0e" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi 0.11.0+wasi-snapshot-preview1", + "web-sys", + "winapi", +] + [[package]] name = "quick-error" version = "2.0.1" @@ -2061,6 +2124,15 @@ dependencies = [ "rgb", ] +[[package]] +name = "raw-cpuid" +version = "11.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "529468c1335c1c03919960dfefdb1b3648858c20d7ec2d0663e728e4a717efbc" +dependencies = [ + "bitflags 2.8.0", +] + [[package]] name = "rayon" version = "1.10.0" @@ -2119,6 +2191,12 @@ version = "0.8.50" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a" +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + [[package]] name = "rustc_version" version = "0.4.1" @@ -2153,6 +2231,12 @@ version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" +[[package]] +name = "saturating" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ece8e78b2f38ec51c51f5d475df0a7187ba5111b2a28bdc761ee05b075d40a71" + [[package]] name = "scopeguard" version = "1.2.0" @@ -2284,6 +2368,23 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "symbol_table" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f19bffd69fb182e684d14e3c71d04c0ef33d1641ac0b9e81c712c734e83703bc" +dependencies = [ + "crossbeam-utils", + "foldhash", + "hashbrown", +] + +[[package]] +name = "symbolic_expressions" +version = "5.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c68d531d83ec6c531150584c42a4290911964d5f0d79132b193b67252a23b71" + [[package]] name = "syn" version = "1.0.109" @@ -2648,6 +2749,28 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows" version = "0.59.0" diff --git a/hercules_opt/Cargo.toml b/hercules_opt/Cargo.toml index 0b250d28..92e05339 100644 --- a/hercules_opt/Cargo.toml +++ b/hercules_opt/Cargo.toml @@ -21,4 +21,5 @@ serde = { version = "*", features = ["derive"] } hercules_cg = { path = "../hercules_cg" } hercules_ir = { path = "../hercules_ir" } nestify = "*" -bimap = "*" \ No newline at end of file +bimap = "*" +egg = "*" diff --git a/hercules_opt/src/rewrite_math_expressions.rs b/hercules_opt/src/rewrite_math_expressions.rs index 32161a79..53526a95 100644 --- a/hercules_opt/src/rewrite_math_expressions.rs +++ b/hercules_opt/src/rewrite_math_expressions.rs @@ -1,9 +1,39 @@ use std::collections::{HashMap, HashSet}; +use std::fmt::{Error, Write}; use hercules_ir::*; +use egg::*; + use crate::*; +define_language! { + enum MathLanguage { + "zero" = Zero, + "one" = One, + + ForkDim(i64), + "tid" = ThreadID(Id), + "sum" = SumReduction(Box<[Id]>), + "array" = Comprehension(Box<[Id]>), + + "read" = Read(Box<[Id]>), + + "+" = Add([Id; 2]), + "*" = Mul([Id; 2]), + + Opaque(Symbol), + } +} + +fn make_rules() -> Vec<Rewrite<MathLanguage, ()>> { + vec![ + rewrite!("add-zero"; "(+ zero ?a)" => "?a"), + rewrite!("mul-zero"; "(* zero ?a)" => "zero"), + rewrite!("mul-one"; "(* one ?a)" => "?a"), + ] +} + pub fn rewrite_math_expressions( editor: &mut FunctionEditor, nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, @@ -11,7 +41,65 @@ pub fn rewrite_math_expressions( ) { for (reduce, einsum) in reduce_einsums.1.iter() { print!("{:?}: ", reduce); - debug_print_math_expr(*einsum, &reduce_einsums.0); - println!(""); + let mut s = String::new(); + egg_print_math_expr(*einsum, &reduce_einsums.0, &mut s).unwrap(); + println!("{}", s); + + let expr: RecExpr<MathLanguage> = s.parse().unwrap(); + let runner = Runner::default().with_expr(&expr).run(&make_rules()); + let root = runner.roots[0]; + let extractor = Extractor::new(&runner.egraph, AstSize); + let (best_cost, best) = extractor.find_best(root); + println!("Simplified {} to {} with cost {}", expr, best, best_cost); + } +} + +pub fn egg_print_math_expr<W: Write>(id: MathID, env: &MathEnv, w: &mut W) -> Result<(), Error> { + match env[id.idx()] { + MathExpr::Zero(_) => write!(w, "zero"), + MathExpr::One(_) => write!(w, "one"), + MathExpr::OpaqueNode(id) => write!(w, "n{}", id.idx()), + MathExpr::ThreadID(dim) => write!(w, "{}", dim.0), + MathExpr::SumReduction(id, ref dims) => { + write!(w, "(sum")?; + for dim in dims { + write!(w, " {}", dim.0)?; + } + write!(w, " ")?; + egg_print_math_expr(id, env, w)?; + write!(w, ")") + } + MathExpr::Comprehension(id, ref dims) => { + write!(w, "(array")?; + for dim in dims { + write!(w, " {}", dim.0)?; + } + write!(w, " ")?; + egg_print_math_expr(id, env, w)?; + write!(w, ")") + } + MathExpr::Read(id, ref pos) => { + write!(w, "(read")?; + for pos in pos { + write!(w, " ")?; + egg_print_math_expr(*pos, env, w)?; + } + write!(w, " ")?; + egg_print_math_expr(id, env, w)?; + write!(w, ")") + } + MathExpr::Binary(op, left, right) => { + write!(w, "(")?; + match op { + BinaryOperator::Add => write!(w, "+ "), + BinaryOperator::Mul => write!(w, "* "), + _ => Err(Error::default()), + }?; + egg_print_math_expr(left, env, w)?; + write!(w, " ")?; + egg_print_math_expr(right, env, w)?; + write!(w, ")") + } + _ => Err(Error::default()), } } -- GitLab From 0c2289d49779500adeb9c33d0ec135cdd0bfc504 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Mon, 17 Feb 2025 15:55:54 -0600 Subject: [PATCH 03/13] Identify matmul --- hercules_opt/src/rewrite_math_expressions.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/hercules_opt/src/rewrite_math_expressions.rs b/hercules_opt/src/rewrite_math_expressions.rs index 53526a95..a0493e54 100644 --- a/hercules_opt/src/rewrite_math_expressions.rs +++ b/hercules_opt/src/rewrite_math_expressions.rs @@ -22,6 +22,8 @@ define_language! { "+" = Add([Id; 2]), "*" = Mul([Id; 2]), + "library_matmul" = LibraryMatmul([Id; 2]), + Opaque(Symbol), } } @@ -31,6 +33,7 @@ fn make_rules() -> Vec<Rewrite<MathLanguage, ()>> { rewrite!("add-zero"; "(+ zero ?a)" => "?a"), rewrite!("mul-zero"; "(* zero ?a)" => "zero"), rewrite!("mul-one"; "(* one ?a)" => "?a"), + rewrite!("library-matmul"; "(array ?i ?k (sum ?j (* (read ?i ?j ?A) (read ?j ?k ?B))))" => "(library_matmul ?A ?B)"), ] } @@ -40,10 +43,9 @@ pub fn rewrite_math_expressions( reduce_einsums: &(MathEnv, HashMap<NodeID, MathID>), ) { for (reduce, einsum) in reduce_einsums.1.iter() { - print!("{:?}: ", reduce); + println!("{:?}", reduce); let mut s = String::new(); egg_print_math_expr(*einsum, &reduce_einsums.0, &mut s).unwrap(); - println!("{}", s); let expr: RecExpr<MathLanguage> = s.parse().unwrap(); let runner = Runner::default().with_expr(&expr).run(&make_rules()); -- GitLab From f74b4e644602b3e0dbd48f3a7c2562df355c6aeb Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Mon, 17 Feb 2025 16:13:35 -0600 Subject: [PATCH 04/13] Get outer reduces first --- hercules_opt/src/rewrite_math_expressions.rs | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/hercules_opt/src/rewrite_math_expressions.rs b/hercules_opt/src/rewrite_math_expressions.rs index a0493e54..2fd4bbf7 100644 --- a/hercules_opt/src/rewrite_math_expressions.rs +++ b/hercules_opt/src/rewrite_math_expressions.rs @@ -42,10 +42,26 @@ pub fn rewrite_math_expressions( nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, reduce_einsums: &(MathEnv, HashMap<NodeID, MathID>), ) { - for (reduce, einsum) in reduce_einsums.1.iter() { + // Step 1: figure out how many fork-joins each reduce is in. We want to + // rewrite outer reductions before inner ones. + let nodes = &editor.func().nodes; + let mut depth: HashMap<NodeID, u32> = HashMap::new(); + for (_, inside) in nodes_in_fork_joins { + for id in inside { + if nodes[id.idx()].is_reduce() { + *depth.entry(*id).or_default() += 1; + } + } + } + let mut reduces: Vec<(u32, NodeID)> = + depth.into_iter().map(|(id, depth)| (depth, id)).collect(); + reduces.sort(); + + for (_, reduce) in reduces { + // Step 2: convert the reduce to an expression in egg and rewrite it. println!("{:?}", reduce); let mut s = String::new(); - egg_print_math_expr(*einsum, &reduce_einsums.0, &mut s).unwrap(); + egg_print_math_expr(reduce_einsums.1[&reduce], &reduce_einsums.0, &mut s).unwrap(); let expr: RecExpr<MathLanguage> = s.parse().unwrap(); let runner = Runner::default().with_expr(&expr).run(&make_rules()); -- GitLab From 88ed34d67fd97589c926be9db48fbb6563cf8114 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Mon, 17 Feb 2025 17:31:37 -0600 Subject: [PATCH 05/13] Get the node ID arguments for matmul --- hercules_opt/src/rewrite_math_expressions.rs | 35 ++++++++++++++++---- juno_scheduler/src/pm.rs | 7 +++- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/hercules_opt/src/rewrite_math_expressions.rs b/hercules_opt/src/rewrite_math_expressions.rs index 2fd4bbf7..e1680e43 100644 --- a/hercules_opt/src/rewrite_math_expressions.rs +++ b/hercules_opt/src/rewrite_math_expressions.rs @@ -39,6 +39,7 @@ fn make_rules() -> Vec<Rewrite<MathLanguage, ()>> { pub fn rewrite_math_expressions( editor: &mut FunctionEditor, + device: Device, nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, reduce_einsums: &(MathEnv, HashMap<NodeID, MathID>), ) { @@ -59,20 +60,40 @@ pub fn rewrite_math_expressions( for (_, reduce) in reduces { // Step 2: convert the reduce to an expression in egg and rewrite it. - println!("{:?}", reduce); let mut s = String::new(); egg_print_math_expr(reduce_einsums.1[&reduce], &reduce_einsums.0, &mut s).unwrap(); let expr: RecExpr<MathLanguage> = s.parse().unwrap(); - let runner = Runner::default().with_expr(&expr).run(&make_rules()); - let root = runner.roots[0]; - let extractor = Extractor::new(&runner.egraph, AstSize); - let (best_cost, best) = extractor.find_best(root); - println!("Simplified {} to {} with cost {}", expr, best, best_cost); + let egraph = Runner::default().with_expr(&expr).run(&make_rules()).egraph; + + // Step 3: match the smallest expression against patterns for known + // library function. + let matmul_pattern: Pattern<MathLanguage> = "(library_matmul ?A ?B)".parse().unwrap(); + let mut matches = matmul_pattern.search(&egraph); + if matches.len() > 0 + && let m = matches.remove(0) + && m.substs.len() > 0 + { + let left_id = get_node_ids_from_subst("?A", &m.substs[0], &egraph); + let right_id = get_node_ids_from_subst("?B", &m.substs[0], &egraph); + println!("{:?} {:?}", left_id, right_id); + return; + } } } -pub fn egg_print_math_expr<W: Write>(id: MathID, env: &MathEnv, w: &mut W) -> Result<(), Error> { +fn get_node_ids_from_subst(var: &str, subst: &Subst, egraph: &EGraph<MathLanguage, ()>) -> NodeID { + let id = *subst.get(var.parse::<Var>().unwrap()).unwrap(); + let expr = egraph.id_to_expr(id); + let MathLanguage::Opaque(sym) = expr.last().unwrap() else { + todo!(); + }; + let sym = sym.as_str(); + assert!(sym.chars().nth(0).unwrap() == 'n'); + NodeID::new(sym[1..].parse().unwrap()) +} + +fn egg_print_math_expr<W: Write>(id: MathID, env: &MathEnv, w: &mut W) -> Result<(), Error> { match env[id.idx()] { MathExpr::Zero(_) => write!(w, "zero"), MathExpr::One(_) => write!(w, "one"), diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 461bc645..fbecfc1f 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2294,7 +2294,12 @@ fn run_pass( let Some(mut func) = func else { continue; }; - rewrite_math_expressions(&mut func, nodes_in_fork_joins, reduce_einsums); + rewrite_math_expressions( + &mut func, + Device::LLVM, + nodes_in_fork_joins, + reduce_einsums, + ); changed |= func.modified(); } pm.delete_gravestones(); -- GitLab From b4fc0252499d01ce9ce571a8e1021afb374347f4 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Mon, 17 Feb 2025 17:45:52 -0600 Subject: [PATCH 06/13] LibraryCall node --- hercules_ir/src/def_use.rs | 16 ++++++++++++---- hercules_ir/src/ir.rs | 15 +++++++++++++++ hercules_ir/src/typecheck.rs | 5 +++++ hercules_opt/src/ccp.rs | 10 ++++++++++ 4 files changed, 42 insertions(+), 4 deletions(-) diff --git a/hercules_ir/src/def_use.rs b/hercules_ir/src/def_use.rs index d60f7fd5..77ef83fb 100644 --- a/hercules_ir/src/def_use.rs +++ b/hercules_ir/src/def_use.rs @@ -178,7 +178,12 @@ pub fn get_uses(node: &Node) -> NodeUses { uses.extend(args); NodeUses::Variable(uses.into_boxed_slice()) } - Node::IntrinsicCall { intrinsic: _, args } => NodeUses::Variable(args.clone()), + Node::IntrinsicCall { intrinsic: _, args } + | Node::LibraryCall { + library_function: _, + args, + ty: _, + } => NodeUses::Variable(args.clone()), Node::Read { collect, indices } => { let mut uses = vec![]; for index in indices.iter() { @@ -276,9 +281,12 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> { uses.extend(args); NodeUsesMut::Variable(uses.into_boxed_slice()) } - Node::IntrinsicCall { intrinsic: _, args } => { - NodeUsesMut::Variable(args.iter_mut().collect()) - } + Node::IntrinsicCall { intrinsic: _, args } + | Node::LibraryCall { + library_function: _, + args, + ty: _, + } => NodeUsesMut::Variable(args.iter_mut().collect()), Node::Read { collect, indices } => { let mut uses = vec![]; for index in indices.iter_mut() { diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index f96cc10c..5ba48629 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -222,6 +222,11 @@ pub enum Node { intrinsic: Intrinsic, args: Box<[NodeID]>, }, + LibraryCall { + library_function: String, + args: Box<[NodeID]>, + ty: TypeID, + }, Read { collect: NodeID, indices: Box<[Index]>, @@ -1531,6 +1536,11 @@ impl Node { intrinsic: _, args: _, } => "Intrinsic", + Node::LibraryCall { + library_function: _, + args: _, + ty: _, + } => "Library", Node::Read { collect: _, indices: _, @@ -1604,6 +1614,11 @@ impl Node { intrinsic: _, args: _, } => "intrinsic", + Node::LibraryCall { + library_function: _, + args: _, + ty: _, + } => "library", Node::Read { collect: _, indices: _, diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index d01a5c58..fca258a7 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -935,6 +935,11 @@ fn typeflow( } } } + Node::LibraryCall { + library_function: _, + args: _, + ty, + } => Concrete(*ty), Node::Read { collect: _, indices, diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs index 1969430a..a7df71f5 100644 --- a/hercules_opt/src/ccp.rs +++ b/hercules_opt/src/ccp.rs @@ -933,6 +933,16 @@ fn ccp_flow_function( constant: new_constant, } } + Node::LibraryCall { + library_function: _, + args, + ty: _, + } => CCPLattice { + reachability: args.iter().fold(ReachabilityLattice::bottom(), |val, id| { + ReachabilityLattice::join(&val, &inputs[id.idx()].reachability) + }), + constant: ConstantLattice::bottom(), + }, Node::Read { collect, indices } => { let mut reachability = inputs[collect.idx()].reachability.clone(); for index in indices.iter() { -- GitLab From 2e3bd4d76078dff45c606ce80ad08ed912fac319 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Mon, 17 Feb 2025 18:02:02 -0600 Subject: [PATCH 07/13] Replace reduce with library call node --- hercules_opt/src/rewrite_math_expressions.rs | 39 +++++++++++++++++--- hercules_opt/src/simplify_cfg.rs | 4 +- juno_samples/matmul/src/gpu.sch | 5 ++- juno_scheduler/src/pm.rs | 10 ++++- 4 files changed, 48 insertions(+), 10 deletions(-) diff --git a/hercules_opt/src/rewrite_math_expressions.rs b/hercules_opt/src/rewrite_math_expressions.rs index e1680e43..1b968243 100644 --- a/hercules_opt/src/rewrite_math_expressions.rs +++ b/hercules_opt/src/rewrite_math_expressions.rs @@ -40,16 +40,22 @@ fn make_rules() -> Vec<Rewrite<MathLanguage, ()>> { pub fn rewrite_math_expressions( editor: &mut FunctionEditor, device: Device, + typing: &Vec<TypeID>, + fork_join_map: &HashMap<NodeID, NodeID>, nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, reduce_einsums: &(MathEnv, HashMap<NodeID, MathID>), ) { + let join_fork_map: HashMap<_, _> = fork_join_map + .into_iter() + .map(|(fork, join)| (*join, *fork)) + .collect(); + // Step 1: figure out how many fork-joins each reduce is in. We want to // rewrite outer reductions before inner ones. - let nodes = &editor.func().nodes; let mut depth: HashMap<NodeID, u32> = HashMap::new(); for (_, inside) in nodes_in_fork_joins { for id in inside { - if nodes[id.idx()].is_reduce() { + if editor.func().nodes[id.idx()].is_reduce() { *depth.entry(*id).or_default() += 1; } } @@ -59,6 +65,9 @@ pub fn rewrite_math_expressions( reduces.sort(); for (_, reduce) in reduces { + let join = editor.func().nodes[reduce.idx()].try_reduce().unwrap().0; + let fork = join_fork_map[&join]; + // Step 2: convert the reduce to an expression in egg and rewrite it. let mut s = String::new(); egg_print_math_expr(reduce_einsums.1[&reduce], &reduce_einsums.0, &mut s).unwrap(); @@ -67,7 +76,7 @@ pub fn rewrite_math_expressions( let egraph = Runner::default().with_expr(&expr).run(&make_rules()).egraph; // Step 3: match the smallest expression against patterns for known - // library function. + // library functions. let matmul_pattern: Pattern<MathLanguage> = "(library_matmul ?A ?B)".parse().unwrap(); let mut matches = matmul_pattern.search(&egraph); if matches.len() > 0 @@ -76,8 +85,28 @@ pub fn rewrite_math_expressions( { let left_id = get_node_ids_from_subst("?A", &m.substs[0], &egraph); let right_id = get_node_ids_from_subst("?B", &m.substs[0], &egraph); - println!("{:?} {:?}", left_id, right_id); - return; + let library_function = match device { + Device::LLVM => "blas_gemm", + Device::CUDA => todo!(), + _ => panic!(), + } + .to_string(); + let args = vec![left_id, right_id].into_boxed_slice(); + let ty = typing[reduce.idx()]; + let call = Node::LibraryCall { + library_function, + args, + ty, + }; + let success = editor.edit(|mut edit| { + let call = edit.add_node(call); + edit.replace_all_uses_where(reduce, call, |id| { + !nodes_in_fork_joins[&fork].contains(id) + }) + }); + if success { + return; + } } } } diff --git a/hercules_opt/src/simplify_cfg.rs b/hercules_opt/src/simplify_cfg.rs index 14a152dc..cf39db2b 100644 --- a/hercules_opt/src/simplify_cfg.rs +++ b/hercules_opt/src/simplify_cfg.rs @@ -126,9 +126,7 @@ fn remove_useless_fork_joins( // Third, get rid of fork-joins. for (fork, join) in fork_join_map { - if editor.get_users(*join).len() == 1 { - assert_eq!(editor.get_users(*fork).len(), 1); - + if editor.get_users(*fork).len() == 1 && editor.get_users(*join).len() == 1 { let fork_use = get_uses(&editor.func().nodes[fork.idx()]).as_ref()[0]; let join_use = get_uses(&editor.func().nodes[join.idx()]).as_ref()[0]; diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch index c785fd5e..39139cdb 100644 --- a/juno_samples/matmul/src/gpu.sch +++ b/juno_samples/matmul/src/gpu.sch @@ -13,7 +13,10 @@ fork-coalesce(*); infer-schedules(*); dce(*); rewrite(*); -xdot[true](*); +fixpoint { + simplify-cfg(*); + dce(*); +} let out = auto-outline(*); gpu(out.matmul); diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index fbecfc1f..2a270070 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2281,13 +2281,19 @@ fn run_pass( } Pass::RewriteMathExpressions => { assert!(args.is_empty()); + pm.make_typing(); + pm.make_fork_join_maps(); pm.make_nodes_in_fork_joins(); pm.make_reduce_einsums(); + let typing = pm.typing.take().unwrap(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); let reduce_einsums = pm.reduce_einsums.take().unwrap(); - for ((func, nodes_in_fork_joins), reduce_einsums) in + for ((((func, typing), fork_join_map), nodes_in_fork_joins), reduce_einsums) in build_selection(pm, selection, false) .into_iter() + .zip(typing.iter()) + .zip(fork_join_maps.iter()) .zip(nodes_in_fork_joins.iter()) .zip(reduce_einsums.iter()) { @@ -2297,6 +2303,8 @@ fn run_pass( rewrite_math_expressions( &mut func, Device::LLVM, + typing, + fork_join_map, nodes_in_fork_joins, reduce_einsums, ); -- GitLab From 1b67a3db5b1a6d7b175156edc448d1ddcc38ce46 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Mon, 17 Feb 2025 20:43:43 -0600 Subject: [PATCH 08/13] Add device member to LibraryCall --- hercules_ir/src/def_use.rs | 2 ++ hercules_ir/src/dot.rs | 6 ++++++ hercules_ir/src/ir.rs | 5 ++++- hercules_ir/src/typecheck.rs | 1 + hercules_opt/src/ccp.rs | 1 + hercules_opt/src/rewrite_math_expressions.rs | 8 ++------ juno_samples/matmul/src/gpu.sch | 3 +-- juno_scheduler/src/pm.rs | 2 +- 8 files changed, 18 insertions(+), 10 deletions(-) diff --git a/hercules_ir/src/def_use.rs b/hercules_ir/src/def_use.rs index 77ef83fb..ff0e08ed 100644 --- a/hercules_ir/src/def_use.rs +++ b/hercules_ir/src/def_use.rs @@ -183,6 +183,7 @@ pub fn get_uses(node: &Node) -> NodeUses { library_function: _, args, ty: _, + device: _, } => NodeUses::Variable(args.clone()), Node::Read { collect, indices } => { let mut uses = vec![]; @@ -286,6 +287,7 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> { library_function: _, args, ty: _, + device: _, } => NodeUsesMut::Variable(args.iter_mut().collect()), Node::Read { collect, indices } => { let mut uses = vec![]; diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs index 0e084085..1773f7dd 100644 --- a/hercules_ir/src/dot.rs +++ b/hercules_ir/src/dot.rs @@ -318,6 +318,12 @@ fn write_node<W: Write>( Node::IntrinsicCall { intrinsic, args: _ } => { write!(&mut suffix, "{}", intrinsic.lower_case_name())? } + Node::LibraryCall { + library_function, + args: _, + ty: _, + device, + } => write!(&mut suffix, "{} on {:?}", library_function, device)?, Node::Read { collect: _, indices, diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 5ba48629..5abc9adf 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -226,6 +226,7 @@ pub enum Node { library_function: String, args: Box<[NodeID]>, ty: TypeID, + device: Device, }, Read { collect: NodeID, @@ -341,7 +342,7 @@ pub enum Schedule { * The authoritative enumeration of supported backends. Multiple backends may * correspond to the same kind of hardware. */ -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] pub enum Device { LLVM, CUDA, @@ -1540,6 +1541,7 @@ impl Node { library_function: _, args: _, ty: _, + device: _, } => "Library", Node::Read { collect: _, @@ -1618,6 +1620,7 @@ impl Node { library_function: _, args: _, ty: _, + device: _, } => "library", Node::Read { collect: _, diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index fca258a7..1ff890db 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -939,6 +939,7 @@ fn typeflow( library_function: _, args: _, ty, + device: _, } => Concrete(*ty), Node::Read { collect: _, diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs index a7df71f5..b626148c 100644 --- a/hercules_opt/src/ccp.rs +++ b/hercules_opt/src/ccp.rs @@ -937,6 +937,7 @@ fn ccp_flow_function( library_function: _, args, ty: _, + device: _, } => CCPLattice { reachability: args.iter().fold(ReachabilityLattice::bottom(), |val, id| { ReachabilityLattice::join(&val, &inputs[id.idx()].reachability) diff --git a/hercules_opt/src/rewrite_math_expressions.rs b/hercules_opt/src/rewrite_math_expressions.rs index 1b968243..4d76c608 100644 --- a/hercules_opt/src/rewrite_math_expressions.rs +++ b/hercules_opt/src/rewrite_math_expressions.rs @@ -85,18 +85,14 @@ pub fn rewrite_math_expressions( { let left_id = get_node_ids_from_subst("?A", &m.substs[0], &egraph); let right_id = get_node_ids_from_subst("?B", &m.substs[0], &egraph); - let library_function = match device { - Device::LLVM => "blas_gemm", - Device::CUDA => todo!(), - _ => panic!(), - } - .to_string(); + let library_function = "gemm".to_string(); let args = vec![left_id, right_id].into_boxed_slice(); let ty = typing[reduce.idx()]; let call = Node::LibraryCall { library_function, args, ty, + device, }; let success = editor.edit(|mut edit| { let call = edit.add_node(call); diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch index 39139cdb..35ed1e84 100644 --- a/juno_samples/matmul/src/gpu.sch +++ b/juno_samples/matmul/src/gpu.sch @@ -18,11 +18,10 @@ fixpoint { dce(*); } -let out = auto-outline(*); -gpu(out.matmul); ip-sroa(*); sroa(*); dce(*); +xdot[true](*); float-collections(*); gcm(*); diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 2a270070..e8899fae 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2302,7 +2302,7 @@ fn run_pass( }; rewrite_math_expressions( &mut func, - Device::LLVM, + Device::CUDA, typing, fork_join_map, nodes_in_fork_joins, -- GitLab From a8452c2f86002512eb9f956881eaf3483bfaf4ad Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Mon, 17 Feb 2025 20:56:25 -0600 Subject: [PATCH 09/13] Use enum for library call kinds --- hercules_ir/src/dot.rs | 2 +- hercules_ir/src/ir.rs | 10 +++++++++- hercules_opt/src/rewrite_math_expressions.rs | 19 ++++++++----------- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs index 1773f7dd..921a813d 100644 --- a/hercules_ir/src/dot.rs +++ b/hercules_ir/src/dot.rs @@ -323,7 +323,7 @@ fn write_node<W: Write>( args: _, ty: _, device, - } => write!(&mut suffix, "{} on {:?}", library_function, device)?, + } => write!(&mut suffix, "{:?} on {:?}", library_function, device)?, Node::Read { collect: _, indices, diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 5abc9adf..bf9698b3 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -223,7 +223,7 @@ pub enum Node { args: Box<[NodeID]>, }, LibraryCall { - library_function: String, + library_function: LibraryFunction, args: Box<[NodeID]>, ty: TypeID, device: Device, @@ -351,6 +351,14 @@ pub enum Device { AsyncRust, } +/* + * The authoritative enumeration of supported library calls. + */ +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +pub enum LibraryFunction { + GEMM, +} + /* * A single node may have multiple schedules. */ diff --git a/hercules_opt/src/rewrite_math_expressions.rs b/hercules_opt/src/rewrite_math_expressions.rs index 4d76c608..6f52dc58 100644 --- a/hercules_opt/src/rewrite_math_expressions.rs +++ b/hercules_opt/src/rewrite_math_expressions.rs @@ -22,7 +22,7 @@ define_language! { "+" = Add([Id; 2]), "*" = Mul([Id; 2]), - "library_matmul" = LibraryMatmul([Id; 2]), + "library_gemm" = LibraryGemm([Id; 2]), Opaque(Symbol), } @@ -33,7 +33,7 @@ fn make_rules() -> Vec<Rewrite<MathLanguage, ()>> { rewrite!("add-zero"; "(+ zero ?a)" => "?a"), rewrite!("mul-zero"; "(* zero ?a)" => "zero"), rewrite!("mul-one"; "(* one ?a)" => "?a"), - rewrite!("library-matmul"; "(array ?i ?k (sum ?j (* (read ?i ?j ?A) (read ?j ?k ?B))))" => "(library_matmul ?A ?B)"), + rewrite!("library-gemm"; "(array ?i ?k (sum ?j (* (read ?i ?j ?A) (read ?j ?k ?B))))" => "(library_gemm ?A ?B)"), ] } @@ -65,7 +65,7 @@ pub fn rewrite_math_expressions( reduces.sort(); for (_, reduce) in reduces { - let join = editor.func().nodes[reduce.idx()].try_reduce().unwrap().0; + let (join, init, _) = editor.func().nodes[reduce.idx()].try_reduce().unwrap(); let fork = join_fork_map[&join]; // Step 2: convert the reduce to an expression in egg and rewrite it. @@ -77,21 +77,18 @@ pub fn rewrite_math_expressions( // Step 3: match the smallest expression against patterns for known // library functions. - let matmul_pattern: Pattern<MathLanguage> = "(library_matmul ?A ?B)".parse().unwrap(); - let mut matches = matmul_pattern.search(&egraph); + let gemm_pattern: Pattern<MathLanguage> = "(library_gemm ?A ?B)".parse().unwrap(); + let mut matches = gemm_pattern.search(&egraph); if matches.len() > 0 && let m = matches.remove(0) && m.substs.len() > 0 { let left_id = get_node_ids_from_subst("?A", &m.substs[0], &egraph); let right_id = get_node_ids_from_subst("?B", &m.substs[0], &egraph); - let library_function = "gemm".to_string(); - let args = vec![left_id, right_id].into_boxed_slice(); - let ty = typing[reduce.idx()]; let call = Node::LibraryCall { - library_function, - args, - ty, + library_function: LibraryFunction::GEMM, + args: vec![init, left_id, right_id].into_boxed_slice(), + ty: typing[reduce.idx()], device, }; let success = editor.edit(|mut edit| { -- GitLab From 8f806736e67dd0661578464c681ed76aa60a4d99 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Mon, 17 Feb 2025 21:14:47 -0600 Subject: [PATCH 10/13] Add all the stuff related to collections for LibraryCall nodes --- hercules_ir/src/collections.rs | 32 ++++++++- hercules_ir/src/device.rs | 115 -------------------------------- hercules_opt/src/gcm.rs | 37 ++++++++++ juno_samples/matmul/src/gpu.sch | 2 +- 4 files changed, 69 insertions(+), 117 deletions(-) diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index f3474ae0..6b631519 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -218,6 +218,8 @@ pub fn collection_objects( // - Constant: may originate an object. // - Call: may originate an object and may return an object passed in as // a parameter. + // - LibraryCall: may return an object passed in as a parameter, but may + // not originate an object. // - Read: may extract a smaller object from the input - this is // considered to be the same object as the input, as no copy takes // place. @@ -288,6 +290,14 @@ pub fn collection_objects( } CollectionObjectLattice { objs } } + Node::LibraryCall { + library_function, + args: _, + ty: _, + device: _, + } => match library_function { + LibraryFunction::GEMM => inputs[0].clone(), + }, Node::Undef { ty: _ } => { let obj = origins .iter() @@ -332,7 +342,13 @@ pub fn collection_objects( for object in objects_per_node[idx].iter() { mutated[object.idx()].push(NodeID::new(idx)); } - } else if let Some((_, callee, _, args)) = node.try_call() { + } else if let Node::Call { + control: _, + function: callee, + dynamic_constants: _, + args, + } = node + { let fco = &collection_objects[&callee]; for (param_idx, arg) in args.into_iter().enumerate() { // If this parameter corresponds to an object and it's @@ -347,6 +363,20 @@ pub fn collection_objects( } } } + } else if let Node::LibraryCall { + library_function, + args, + ty: _, + device: _, + } = node + { + match library_function { + LibraryFunction::GEMM => { + for object in objects_per_node[args[0].idx()].iter() { + mutated[object.idx()].push(NodeID::new(idx)); + } + } + } } } diff --git a/hercules_ir/src/device.rs b/hercules_ir/src/device.rs index cbf8d634..c4a5454b 100644 --- a/hercules_ir/src/device.rs +++ b/hercules_ir/src/device.rs @@ -23,118 +23,3 @@ pub fn device_placement(functions: &Vec<Function>, callgraph: &CallGraph) -> Vec devices } - -pub type FunctionObjectDeviceDemands = Vec<BTreeSet<Device>>; -pub type ObjectDeviceDemands = Vec<FunctionObjectDeviceDemands>; - -/* - * This analysis figures out which device each collection object may be on. At - * first, an object may need to be on different devices at different times. This - * is fine during optimization. - */ -pub fn object_device_demands( - functions: &Vec<Function>, - types: &Vec<Type>, - typing: &ModuleTyping, - callgraph: &CallGraph, - objects: &CollectionObjects, - devices: &Vec<Device>, -) -> ObjectDeviceDemands { - // An object is "demanded" on a device when: - // 1. The object is used by a primitive read node or write node in a device - // function. This includes objects on the `data` input to write nodes. - // Non-primitive reads don't demand an object on a device since they are - // lowered to pointer math and no actual memory transfers. - // 2. The object is a constant / undef defined in a device function. - // 3. The object is passed as input to a call node where the corresponding - // object in the callee is demanded on a device. - // 4. The object is returned from a call node where the corresponding object - // in the callee is demanded on a device. - // Note that reads and writes in a RT function don't induce a device demand. - // This is because RT functions can call device functions as necessary to - // arbitrarily move data onto / off of devices (though this may be slow). - // Traverse the functions in a module in reverse topological order, since - // the analysis of a function depends on all functions it calls. - let mut demands: ObjectDeviceDemands = vec![vec![]; functions.len()]; - let topo = callgraph.topo(); - - for func_id in topo { - let function = &functions[func_id.idx()]; - let typing = &typing[func_id.idx()]; - let device = devices[func_id.idx()]; - - demands[func_id.idx()].resize(objects[&func_id].num_objects(), BTreeSet::new()); - match device { - Device::LLVM | Device::CUDA => { - for (idx, node) in function.nodes.iter().enumerate() { - match node { - // Condition #1. - Node::Read { - collect, - indices: _, - } if types[typing[idx].idx()].is_primitive() => { - for object in objects[&func_id].objects(*collect) { - demands[func_id.idx()][object.idx()].insert(device); - } - } - Node::Write { - collect, - data, - indices: _, - } => { - for object in objects[&func_id] - .objects(*collect) - .into_iter() - .chain(objects[&func_id].objects(*data).into_iter()) - { - demands[func_id.idx()][object.idx()].insert(device); - } - } - // Condition #2. - Node::Constant { id: _ } | Node::Undef { ty: _ } => { - for object in objects[&func_id].objects(NodeID::new(idx)) { - demands[func_id.idx()][object.idx()].insert(device); - } - } - _ => {} - } - } - } - Device::AsyncRust => { - for (idx, node) in function.nodes.iter().enumerate() { - if let Node::Call { - control: _, - function: callee, - dynamic_constants: _, - args, - } = node - { - // Condition #3. - for (param_idx, arg) in args.into_iter().enumerate() { - if let Some(callee_obj) = objects[callee].param_to_object(param_idx) { - let callee_demands = - take(&mut demands[callee.idx()][callee_obj.idx()]); - for object in objects[&func_id].objects(*arg) { - demands[func_id.idx()][object.idx()] - .extend(callee_demands.iter()); - } - demands[callee.idx()][callee_obj.idx()] = callee_demands; - } - } - - // Condition #4. - for callee_obj in objects[callee].returned_objects() { - let callee_demands = take(&mut demands[callee.idx()][callee_obj.idx()]); - for object in objects[&func_id].objects(NodeID::new(idx)) { - demands[func_id.idx()][object.idx()].extend(callee_demands.iter()); - } - demands[callee.idx()][callee_obj.idx()] = callee_demands; - } - } - } - } - } - } - - demands -} diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index 821d02ea..446b3184 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -655,6 +655,14 @@ fn terminating_reads<'a>( None } })), + Node::LibraryCall { + library_function, + ref args, + ty: _, + device: _, + } => match library_function { + LibraryFunction::GEMM => Box::new(once(args[1]).chain(once(args[2]))), + }, _ => Box::new(empty()), } } @@ -728,6 +736,16 @@ fn mutating_objects<'a>( }) .flatten(), ), + Node::LibraryCall { + library_function, + ref args, + ty: _, + device: _, + } => match library_function { + LibraryFunction::GEMM => { + Box::new(objects[&func_id].objects(args[0]).into_iter().map(|id| *id)) + } + }, _ => Box::new(empty()), } } @@ -757,6 +775,14 @@ fn mutating_writes<'a>( None } })), + Node::LibraryCall { + library_function, + ref args, + ty: _, + device: _, + } => match library_function { + LibraryFunction::GEMM => Box::new(once(args[0])), + }, _ => Box::new(empty()), } } @@ -1311,6 +1337,17 @@ fn color_nodes( } } } + Node::LibraryCall { + library_function: _, + ref args, + ty: _, + device, + } => { + for arg in args { + equations.push((UTerm::Node(*arg), UTerm::Device(device))); + } + equations.push((UTerm::Node(id), UTerm::Device(device))); + } _ => {} } } diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch index 35ed1e84..76159ef7 100644 --- a/juno_samples/matmul/src/gpu.sch +++ b/juno_samples/matmul/src/gpu.sch @@ -21,7 +21,7 @@ fixpoint { ip-sroa(*); sroa(*); dce(*); -xdot[true](*); float-collections(*); gcm(*); +xdot[true](*); -- GitLab From 3409decf7a8a04121678d2f78bcb043c255dfeb7 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Tue, 18 Feb 2025 09:50:43 -0600 Subject: [PATCH 11/13] Plumb things --- hercules_cg/src/rt.rs | 72 ++++++++++++++++++++++++++++++++++++++++++ hercules_rt/src/lib.rs | 48 ++++++++++++++++++++++++---- 2 files changed, 113 insertions(+), 7 deletions(-) diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 4d9a6cf6..5edddd86 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -597,6 +597,59 @@ impl<'a> RTContext<'a> { } write!(block, "){};", postfix)?; } + Node::LibraryCall { + library_function, + ref args, + ty, + device, + } => match library_function { + LibraryFunction::GEMM => { + assert_eq!(args.len(), 3); + assert_eq!(self.typing[args[0].idx()], ty); + let c_ty = &self.module.types[self.typing[args[0].idx()].idx()]; + let a_ty = &self.module.types[self.typing[args[1].idx()].idx()]; + let b_ty = &self.module.types[self.typing[args[2].idx()].idx()]; + let ( + Type::Array(c_elem, c_dims), + Type::Array(a_elem, a_dims), + Type::Array(b_elem, b_dims), + ) = (c_ty, a_ty, b_ty) + else { + panic!(); + }; + assert_eq!(a_elem, b_elem); + assert_eq!(a_elem, c_elem); + assert_eq!(c_dims.len(), 2); + assert_eq!(a_dims.len(), 2); + assert_eq!(b_dims.len(), 2); + assert_eq!(a_dims[1], b_dims[0]); + assert_eq!(a_dims[0], c_dims[0]); + assert_eq!(b_dims[1], c_dims[1]); + + let block = &mut blocks.get_mut(&bb).unwrap().data; + let prim_ty = self.library_prim_ty(*a_elem); + write!(block, "::hercules_rt::__library_{}_gemm(", device.name())?; + self.codegen_dynamic_constant(a_dims[0], block)?; + write!(block, ", ")?; + self.codegen_dynamic_constant(a_dims[1], block)?; + write!(block, ", ")?; + self.codegen_dynamic_constant(b_dims[1], block)?; + write!( + block, + ", {}.0, {}.0, {}.0, {});", + self.get_value(args[0], bb, false), + self.get_value(args[1], bb, false), + self.get_value(args[2], bb, false), + prim_ty + )?; + write!( + block, + "{} = {};", + self.get_value(id, bb, true), + self.get_value(args[0], bb, false) + )?; + } + }, Node::Unary { op, input } => { let block = &mut blocks.get_mut(&bb).unwrap().data; match op { @@ -1316,6 +1369,25 @@ impl<'a> RTContext<'a> { fn get_type(&self, id: TypeID) -> &'static str { convert_type(&self.module.types[id.idx()]) } + + fn library_prim_ty(&self, id: TypeID) -> &'static str { + match self.module.types[id.idx()] { + Type::Boolean => "::hercules_rt::PrimTy::Bool", + Type::Integer8 => "::hercules_rt::PrimTy::I8", + Type::Integer16 => "::hercules_rt::PrimTy::I16", + Type::Integer32 => "::hercules_rt::PrimTy::I32", + Type::Integer64 => "::hercules_rt::PrimTy::I64", + Type::UnsignedInteger8 => "::hercules_rt::PrimTy::U8", + Type::UnsignedInteger16 => "::hercules_rt::PrimTy::U16", + Type::UnsignedInteger32 => "::hercules_rt::PrimTy::U32", + Type::UnsignedInteger64 => "::hercules_rt::PrimTy::U64", + Type::Float8 => "::hercules_rt::PrimTy::F8", + Type::BFloat16 => "::hercules_rt::PrimTy::BF16", + Type::Float32 => "::hercules_rt::PrimTy::F32", + Type::Float64 => "::hercules_rt::PrimTy::F64", + _ => panic!(), + } + } } fn convert_type(ty: &Type) -> &'static str { diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs index 841c6f44..e360076e 100644 --- a/hercules_rt/src/lib.rs +++ b/hercules_rt/src/lib.rs @@ -29,7 +29,10 @@ pub unsafe fn __cpu_dealloc(ptr: *mut u8, size: usize) { eprintln!("__cpu_dealloc: {:?}, {}", ptr, size); assert!(!ptr.is_null() || size == 0); } - dealloc(ptr, Layout::from_size_align(size, LARGEST_ALIGNMENT).unwrap()) + dealloc( + ptr, + Layout::from_size_align(size, LARGEST_ALIGNMENT).unwrap(), + ) } pub unsafe fn __cpu_zero_mem(ptr: *mut u8, size: usize) { @@ -103,14 +106,45 @@ pub unsafe fn __copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize) { ___copy_cuda_to_cuda(dst, src, size); } +#[repr(u8)] +#[derive(Debug, Copy, Clone)] +pub enum PrimTy { + Bool, + U8, + U16, + U32, + U64, + I8, + I16, + I32, + I64, + F8, + BF16, + F32, + F64, +} + +#[cfg(feature = "cuda")] +pub unsafe fn __library_cuda_gemm( + i: u64, + j: u64, + k: u64, + c: *mut u8, + a: *const u8, + b: *const u8, + ty: PrimTy, +) { + panic!("{} {} {} {:?} {:?} {:?} {:?}", i, j, k, c, a, b, ty); +} + #[cfg(feature = "cuda")] extern "C" { - pub fn ___cuda_alloc(size: usize) -> *mut u8; - pub fn ___cuda_dealloc(ptr: *mut u8, size: usize); - pub fn ___cuda_zero_mem(ptr: *mut u8, size: usize); - pub fn ___copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize); - pub fn ___copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize); - pub fn ___copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize); + fn ___cuda_alloc(size: usize) -> *mut u8; + fn ___cuda_dealloc(ptr: *mut u8, size: usize); + fn ___cuda_zero_mem(ptr: *mut u8, size: usize); + fn ___copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize); + fn ___copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize); + fn ___copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize); } #[derive(Clone, Debug)] -- GitLab From 84dc38986eb6af2815adf526fa14fbab968b0000 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Tue, 18 Feb 2025 10:23:49 -0600 Subject: [PATCH 12/13] e2e matmul cublas works --- hercules_rt/build.rs | 1 + hercules_rt/src/lib.rs | 7 +++- hercules_rt/src/rtdefs.cu | 70 +++++++++++++++++++------------ juno_samples/matmul/src/main.rs | 14 ++++--- juno_samples/matmul/src/matmul.jn | 4 +- 5 files changed, 60 insertions(+), 36 deletions(-) diff --git a/hercules_rt/build.rs b/hercules_rt/build.rs index 2a1538d6..ab9dda2e 100644 --- a/hercules_rt/build.rs +++ b/hercules_rt/build.rs @@ -28,6 +28,7 @@ fn main() { println!("cargo::rustc-link-search=native=/opt/cuda/lib/"); println!("cargo::rustc-link-lib=static=rtdefs"); println!("cargo::rustc-link-lib=cudart"); + println!("cargo::rustc-link-lib=cublas"); println!("cargo::rerun-if-changed=src/rtdefs.cu"); } } diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs index e360076e..714ac7a1 100644 --- a/hercules_rt/src/lib.rs +++ b/hercules_rt/src/lib.rs @@ -106,7 +106,6 @@ pub unsafe fn __copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize) { ___copy_cuda_to_cuda(dst, src, size); } -#[repr(u8)] #[derive(Debug, Copy, Clone)] pub enum PrimTy { Bool, @@ -134,7 +133,10 @@ pub unsafe fn __library_cuda_gemm( b: *const u8, ty: PrimTy, ) { - panic!("{} {} {} {:?} {:?} {:?} {:?}", i, j, k, c, a, b, ty); + match ty { + PrimTy::F32 => ___cublas_sgemm(i, j, k, c, a, b), + _ => todo!(), + } } #[cfg(feature = "cuda")] @@ -145,6 +147,7 @@ extern "C" { fn ___copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize); fn ___copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize); fn ___copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize); + fn ___cublas_sgemm(i: u64, j: u64, k: u64, c: *mut u8, a: *const u8, b: *const u8); } #[derive(Clone, Debug)] diff --git a/hercules_rt/src/rtdefs.cu b/hercules_rt/src/rtdefs.cu index 50e11fa6..26e69821 100644 --- a/hercules_rt/src/rtdefs.cu +++ b/hercules_rt/src/rtdefs.cu @@ -1,31 +1,49 @@ -extern "C" { - void *___cuda_alloc(size_t size) { - void *ptr = NULL; - cudaError_t res = cudaMalloc(&ptr, size); - if (res != cudaSuccess) { - ptr = NULL; - } - return ptr; - } +#include <stdint.h> +#include <cublas_v2.h> - void ___cuda_dealloc(void *ptr, size_t size) { - (void) size; - cudaFree(ptr); - } - - void ___cuda_zero_mem(void *ptr, size_t size) { - cudaMemset(ptr, 0, size); - } +static cublasHandle_t cublas_handle = 0; - void ___copy_cpu_to_cuda(void *dst, void *src, size_t size) { - cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice); - } - - void ___copy_cuda_to_cpu(void *dst, void *src, size_t size) { - cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost); +extern "C" { + void *___cuda_alloc(size_t size) { + void *ptr = NULL; + cudaError_t res = cudaMalloc(&ptr, size); + if (res != cudaSuccess) { + ptr = NULL; } - - void ___copy_cuda_to_cuda(void *dst, void *src, size_t size) { - cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice); + return ptr; + } + + void ___cuda_dealloc(void *ptr, size_t size) { + (void) size; + cudaFree(ptr); + } + + void ___cuda_zero_mem(void *ptr, size_t size) { + cudaMemset(ptr, 0, size); + } + + void ___copy_cpu_to_cuda(void *dst, void *src, size_t size) { + cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice); + } + + void ___copy_cuda_to_cpu(void *dst, void *src, size_t size) { + cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost); + } + + void ___copy_cuda_to_cuda(void *dst, void *src, size_t size) { + cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice); + } + + void ___cublas_sgemm(uint64_t i, uint64_t j, uint64_t k, float *c, float *a, float *b) { + if (!cublas_handle) { + cublasCreate(&cublas_handle); } + float alf = 1.0; + float beta = 0.0; + cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, + k, i, j, + &alf, b, k, a, j, + &beta, c, k); + cudaDeviceSynchronize(); + } } diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs index 3cb7d7f0..c0e228da 100644 --- a/juno_samples/matmul/src/main.rs +++ b/juno_samples/matmul/src/main.rs @@ -1,4 +1,5 @@ #![feature(concat_idents)] +use std::iter::zip; use rand::random; @@ -13,9 +14,9 @@ fn main() { const I: usize = 256; const J: usize = 64; const K: usize = 128; - let a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect(); - let b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect(); - let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect(); + let a: Box<[f32]> = (0..I * J).map(|_| random::<f32>()).collect(); + let b: Box<[f32]> = (0..J * K).map(|_| random::<f32>()).collect(); + let mut correct_c: Box<[f32]> = (0..I * K).map(|_| 0.0).collect(); for i in 0..I { for k in 0..K { for j in 0..J { @@ -27,7 +28,8 @@ fn main() { { let mut r = runner!(matmul); let c = r.run(I as u64, J as u64, K as u64, a.to(), b.to()).await; - assert_eq!(c.as_slice::<i32>(), &*correct_c); + let c = c.as_slice::<f32>(); + assert_eq!(c, &*correct_c); } #[cfg(feature = "cuda")] { @@ -37,9 +39,9 @@ fn main() { let c = r .run(I as u64, J as u64, K as u64, a.get_ref(), b.get_ref()) .await; - let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice(); + let mut c_cpu: Box<[f32]> = vec![0.0; correct_c.len()].into_boxed_slice(); c.to_cpu_ref(&mut c_cpu); - assert_eq!(&*c_cpu, &*correct_c); + assert!(zip(c_cpu, correct_c).all(|(calc, correct)| (calc - correct).abs() < 0.00001)); } }); } diff --git a/juno_samples/matmul/src/matmul.jn b/juno_samples/matmul/src/matmul.jn index e36d94e2..460ce41c 100644 --- a/juno_samples/matmul/src/matmul.jn +++ b/juno_samples/matmul/src/matmul.jn @@ -1,6 +1,6 @@ #[entry] -fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] { - let res : i32[n, l]; +fn matmul<n : usize, m : usize, l : usize>(a : f32[n, m], b : f32[m, l]) -> f32[n, l] { + let res : f32[n, l]; @outer for i = 0 to n { @middle for j = 0 to l { -- GitLab From 82cd5671a37cd180159e2ca19ce69869016328d1 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Tue, 18 Feb 2025 10:33:05 -0600 Subject: [PATCH 13/13] remove xdot --- juno_samples/matmul/src/gpu.sch | 1 - 1 file changed, 1 deletion(-) diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch index 76159ef7..76808149 100644 --- a/juno_samples/matmul/src/gpu.sch +++ b/juno_samples/matmul/src/gpu.sch @@ -24,4 +24,3 @@ dce(*); float-collections(*); gcm(*); -xdot[true](*); -- GitLab