From f1f319bf28aed5a1b988813aa166fb43b267227f Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Thu, 13 Feb 2025 14:26:08 -0600 Subject: [PATCH 01/15] multi-core scale --- juno_samples/cava/build.rs | 2 ++ juno_samples/cava/src/cava.jn | 4 ++-- juno_samples/cava/src/cpu.sch | 44 +++++++++++++++++++++++++++++++++++ juno_samples/cava/src/main.rs | 6 ++--- juno_scheduler/src/compile.rs | 2 ++ juno_scheduler/src/ir.rs | 6 ++--- juno_scheduler/src/lang.l | 4 ++-- juno_scheduler/src/pm.rs | 3 +++ 8 files changed, 60 insertions(+), 11 deletions(-) create mode 100644 juno_samples/cava/src/cpu.sch diff --git a/juno_samples/cava/build.rs b/juno_samples/cava/build.rs index 1b6dddf4..ff7e2b6b 100644 --- a/juno_samples/cava/build.rs +++ b/juno_samples/cava/build.rs @@ -13,6 +13,8 @@ fn main() { JunoCompiler::new() .file_in_src("cava.jn") .unwrap() + .schedule_in_src("cpu.sch") + .unwrap() .build() .unwrap(); } diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn index 95a73f5b..bb8afded 100644 --- a/juno_samples/cava/src/cava.jn +++ b/juno_samples/cava/src/cava.jn @@ -22,12 +22,12 @@ fn medianMatrix<a : number, rows, cols : usize>(m : a[rows, cols]) -> a { const CHAN : u64 = 3; fn scale<row : usize, col : usize>(input : u8[CHAN, row, col]) -> f32[CHAN, row, col] { - let res : f32[CHAN, row, col]; + @const let res : f32[CHAN, row, col]; for chan = 0 to CHAN { for r = 0 to row { for c = 0 to col { - res[chan, r, c] = input[chan, r, c] as f32 * 1.0 / 255; + res[chan, r, c] = input[chan, r, c] as f32 / 255; } } } diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch new file mode 100644 index 00000000..0913d4b3 --- /dev/null +++ b/juno_samples/cava/src/cpu.sch @@ -0,0 +1,44 @@ +macro simpl!(X) { + ccp(X); + gvn(X); + phi-elim(X); + dce(X); + infer-schedules(X); +} + +simpl!(*); + +inline(denoise); +cpu(scale, demosaic, denoise, transform, gamut, tone_map, descale); + +ip-sroa(*); +sroa(*); +simpl!(*); + +no-memset(scale@const); +fixpoint { + forkify(scale); + fork-guard-elim(scale); + fork-coalesce(scale); +} +simpl!(*); +fork-dim-merge(scale); +simpl!(*); +fork-tile[2048, 0, false](scale); +simpl!(*); +let out = fork-split(scale); +simpl!(*); +let out = outline(out._0_scale.fj1); +ip-sroa(*); +sroa(*); +simpl!(*); +host(scale); +unforkify(out); +xdot[true](scale, out); + +gcm(*); +fixpoint { + float-collections(*); + dce(*); + gcm(*); +} diff --git a/juno_samples/cava/src/main.rs b/juno_samples/cava/src/main.rs index b4a0f6fd..a940d6eb 100644 --- a/juno_samples/cava/src/main.rs +++ b/juno_samples/cava/src/main.rs @@ -159,6 +159,7 @@ fn cava_harness(args: CavaInputs) { tonemap, } = load_cam_model(cam_model, CHAN).expect("Error loading camera model"); + println!("Running cava with {} rows, {} columns, and {} control points.", rows, cols, num_ctrl_pts); let result = run_cava( rows, cols, @@ -227,10 +228,8 @@ fn cava_test_small() { }); } -// Disabling the larger test because of how long it takes -/* #[test] -fn cava_test() { +fn cava_test_full() { cava_harness(CavaInputs { input: "examples/raw_tulips.bin".to_string(), output: None, @@ -239,4 +238,3 @@ fn cava_test() { cam_model: "cam_models/NikonD7000".to_string(), }); } -*/ diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 7887b9b3..237ff3b9 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -130,6 +130,8 @@ impl FromStr for Appliable { "serialize" => Ok(Appliable::Pass(ir::Pass::Serialize)), "write-predication" => Ok(Appliable::Pass(ir::Pass::WritePredication)), + "print" => Ok(Appliable::Pass(ir::Pass::Print)), + "cpu" | "llvm" => Ok(Appliable::Device(Device::LLVM)), "gpu" | "cuda" | "nvidia" => Ok(Appliable::Device(Device::CUDA)), "host" | "rust" | "rust-async" => Ok(Appliable::Device(Device::AsyncRust)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index a9ee7956..4b88d6a2 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -27,6 +27,7 @@ pub enum Pass { Outline, PhiElim, Predication, + Print, ReduceSLF, ReuseProducts, SLF, @@ -42,10 +43,9 @@ pub enum Pass { impl Pass { pub fn num_args(&self) -> usize { match self { - Pass::Xdot => 1, + Pass::Xdot | Pass::Print => 1, + Pass::ForkFissionBufferize | Pass::ForkInterchange => 2, Pass::ForkChunk => 4, - Pass::ForkFissionBufferize => 2, - Pass::ForkInterchange => 2, _ => 0, } } diff --git a/juno_scheduler/src/lang.l b/juno_scheduler/src/lang.l index 9d4c34bf..2f34f01f 100644 --- a/juno_scheduler/src/lang.l +++ b/juno_scheduler/src/lang.l @@ -43,8 +43,8 @@ panic[\t \n\r]+after "panic_after" print[\t \n\r]+iter "print_iter" stop[\t \n\r]+after "stop_after" -[a-zA-Z][a-zA-Z0-9_\-]*! "MACRO" -[a-zA-Z][a-zA-Z0-9_\-]* "ID" +[a-zA-Z_][a-zA-Z0-9_\-]*! "MACRO" +[a-zA-Z_][a-zA-Z0-9_\-]* "ID" [0-9]+ "INT" . "UNMATCHED" diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index de725608..45b36fcc 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2423,6 +2423,9 @@ fn run_pass( // Put BasicBlocks back, since it's needed for Codegen. pm.bbs = bbs; } + Pass::Print => { + println!("{:?}", args.get(0)); + } } println!("Ran Pass: {:?}", pass); -- GitLab From aedcd82e7c8e65195cf92ed73c59941e886cd9b2 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Thu, 13 Feb 2025 22:11:27 -0600 Subject: [PATCH 02/15] Fuse scale and demosaic --- juno_samples/cava/src/cava.jn | 4 ++-- juno_samples/cava/src/cpu.sch | 33 ++++++++++++++------------------- 2 files changed, 16 insertions(+), 21 deletions(-) diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn index bb8afded..a06c8d7b 100644 --- a/juno_samples/cava/src/cava.jn +++ b/juno_samples/cava/src/cava.jn @@ -207,8 +207,8 @@ fn cava<r, c, num_ctrl_pts : usize>( coefs : f32[4, CHAN], tonemap : f32[256, CHAN], ) -> u8[CHAN, r, c] { - let scaled = scale::<r, c>(input); - let demosc = demosaic::<r, c>(scaled); + @fuse1 let scaled = scale::<r, c>(input); + @fuse1 let demosc = demosaic::<r, c>(scaled); let denosd = denoise::<r, c>(demosc); let transf = transform::<r, c>(denosd, TsTw); let gamutd = gamut::<r, c, num_ctrl_pts>(transf, ctrl_pts, weights, coefs); diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch index 0913d4b3..f090eab3 100644 --- a/juno_samples/cava/src/cpu.sch +++ b/juno_samples/cava/src/cpu.sch @@ -1,5 +1,6 @@ macro simpl!(X) { ccp(X); + simplify-cfg(X); gvn(X); phi-elim(X); dce(X); @@ -8,33 +9,27 @@ macro simpl!(X) { simpl!(*); +let fuse1 = outline(cava@fuse1); +inline(fuse1); + inline(denoise); -cpu(scale, demosaic, denoise, transform, gamut, tone_map, descale); +cpu(denoise, transform, gamut, tone_map, descale); ip-sroa(*); sroa(*); simpl!(*); -no-memset(scale@const); +no-memset(fuse1@const); fixpoint { - forkify(scale); - fork-guard-elim(scale); - fork-coalesce(scale); + forkify(fuse1); + fork-guard-elim(fuse1); + fork-coalesce(fuse1); } -simpl!(*); -fork-dim-merge(scale); -simpl!(*); -fork-tile[2048, 0, false](scale); -simpl!(*); -let out = fork-split(scale); -simpl!(*); -let out = outline(out._0_scale.fj1); -ip-sroa(*); -sroa(*); -simpl!(*); -host(scale); -unforkify(out); -xdot[true](scale, out); +simpl!(fuse1); +array-slf(fuse1); +simpl!(fuse1); +xdot[true](fuse1); +unforkify(fuse1); gcm(*); fixpoint { -- GitLab From 619fad446643b683f3cb6c1622821fce0500a9a2 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Thu, 13 Feb 2025 22:31:29 -0600 Subject: [PATCH 03/15] Start optimizing denoise, need forkify fixes + array2prod still --- juno_samples/cava/src/cava.jn | 8 ++++---- juno_samples/cava/src/cpu.sch | 22 +++++++++++++++++++--- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn index a06c8d7b..720629e7 100644 --- a/juno_samples/cava/src/cava.jn +++ b/juno_samples/cava/src/cava.jn @@ -1,7 +1,7 @@ fn medianMatrix<a : number, rows, cols : usize>(m : a[rows, cols]) -> a { const n : usize = rows * cols; - let tmp : a[rows * cols]; + @tmp let tmp : a[rows * cols]; for i = 0 to rows * cols { tmp[i] = m[i / cols, i % cols]; } @@ -102,13 +102,13 @@ fn demosaic<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, } fn denoise<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, row, col] { - let res : f32[CHAN, row, col]; + @res let res : f32[CHAN, row, col]; for chan = 0 to CHAN { for r = 0 to row { for c = 0 to col { if r >= 1 && r < row - 1 && c >= 1 && c < col - 1 { - let filter : f32[3][3]; // same as [3, 3] + @filter let filter : f32[3][3]; // same as [3, 3] for i = 0 to 3 by 1 { for j = 0 to 3 by 1 { filter[i, j] = input[chan, r + i - 1, c + j - 1]; @@ -209,7 +209,7 @@ fn cava<r, c, num_ctrl_pts : usize>( ) -> u8[CHAN, r, c] { @fuse1 let scaled = scale::<r, c>(input); @fuse1 let demosc = demosaic::<r, c>(scaled); - let denosd = denoise::<r, c>(demosc); + @fuse2 let denosd = denoise::<r, c>(demosc); let transf = transform::<r, c>(denosd, TsTw); let gamutd = gamut::<r, c, num_ctrl_pts>(transf, ctrl_pts, weights, coefs); let tonemd = tone_map::<r, c>(gamutd, tonemap); diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch index f090eab3..b0479f5c 100644 --- a/juno_samples/cava/src/cpu.sch +++ b/juno_samples/cava/src/cpu.sch @@ -1,6 +1,7 @@ macro simpl!(X) { ccp(X); simplify-cfg(X); + lift-dc-math(X); gvn(X); phi-elim(X); dce(X); @@ -12,8 +13,8 @@ simpl!(*); let fuse1 = outline(cava@fuse1); inline(fuse1); -inline(denoise); -cpu(denoise, transform, gamut, tone_map, descale); +let fuse2 = outline(cava@fuse2); +inline(fuse2); ip-sroa(*); sroa(*); @@ -28,9 +29,24 @@ fixpoint { simpl!(fuse1); array-slf(fuse1); simpl!(fuse1); -xdot[true](fuse1); unforkify(fuse1); +inline(fuse2); +no-memset(fuse2@res); +no-memset(fuse2@filter); +no-memset(fuse2@tmp); +fixpoint { + forkify(fuse2); + fork-guard-elim(fuse2); + fork-coalesce(fuse2); +} +simpl!(fuse2); +array-slf(fuse2); +simpl!(fuse2); +array-slf(fuse2); +simpl!(fuse2); +xdot[true](fuse2); + gcm(*); fixpoint { float-collections(*); -- GitLab From 8e8ed5ef8321f9312251262ce5636daf234275e5 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Fri, 14 Feb 2025 09:33:49 -0600 Subject: [PATCH 04/15] fixes --- hercules_cg/src/rt.rs | 2 +- juno_samples/cava/src/cava.jn | 3 +- juno_samples/cava/src/cpu.sch | 13 +++---- juno_samples/cava/src/main.rs | 66 +++++++++++++++++++++-------------- 4 files changed, 46 insertions(+), 38 deletions(-) diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 090253d4..f758fed3 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -533,7 +533,7 @@ impl<'a> RTContext<'a> { { write!(block, "backing_{}.byte_add(", device.name())?; self.codegen_dynamic_constant(offset, block)?; - write!(block, "), ")? + write!(block, " as usize), ")? } for dc in dynamic_constants { self.codegen_dynamic_constant(*dc, block)?; diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn index 720629e7..de21ba78 100644 --- a/juno_samples/cava/src/cava.jn +++ b/juno_samples/cava/src/cava.jn @@ -213,5 +213,6 @@ fn cava<r, c, num_ctrl_pts : usize>( let transf = transform::<r, c>(denosd, TsTw); let gamutd = gamut::<r, c, num_ctrl_pts>(transf, ctrl_pts, weights, coefs); let tonemd = tone_map::<r, c>(gamutd, tonemap); - return descale::<r, c>(tonemd); + let dscald = descale::<r, c>(tonemd); + return dscald; } diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch index b0479f5c..c75260ca 100644 --- a/juno_samples/cava/src/cpu.sch +++ b/juno_samples/cava/src/cpu.sch @@ -43,13 +43,8 @@ fixpoint { simpl!(fuse2); array-slf(fuse2); simpl!(fuse2); -array-slf(fuse2); -simpl!(fuse2); -xdot[true](fuse2); +unforkify(fuse2); -gcm(*); -fixpoint { - float-collections(*); - dce(*); - gcm(*); -} +delete-uncalled(*); + +gcm(*); \ No newline at end of file diff --git a/juno_samples/cava/src/main.rs b/juno_samples/cava/src/main.rs index a940d6eb..700b74e6 100644 --- a/juno_samples/cava/src/main.rs +++ b/juno_samples/cava/src/main.rs @@ -19,24 +19,31 @@ juno_build::juno!("cava"); // Individual lifetimes are not needed in this example but should probably be generated for // flexibility async fn safe_run<'a, 'b: 'a, 'c: 'a, 'd: 'a, 'e: 'a, 'f: 'a, 'g: 'a>( - runner: &'a mut HerculesRunner_cava, r: u64, c: u64, num_ctrl_pts: u64, - input: &'b HerculesImmBox<'b, u8>, tstw: &'c HerculesImmBox<'c, f32>, - ctrl_pts: &'d HerculesImmBox<'d, f32>, weights: &'e HerculesImmBox<'e, f32>, - coefs: &'f HerculesImmBox<'f, f32>, tonemap: &'g HerculesImmBox<'g, f32>, + runner: &'a mut HerculesRunner_cava, + r: u64, + c: u64, + num_ctrl_pts: u64, + input: &'b HerculesImmBox<'b, u8>, + tstw: &'c HerculesImmBox<'c, f32>, + ctrl_pts: &'d HerculesImmBox<'d, f32>, + weights: &'e HerculesImmBox<'e, f32>, + coefs: &'f HerculesImmBox<'f, f32>, + tonemap: &'g HerculesImmBox<'g, f32>, ) -> HerculesMutBox<'a, u8> { HerculesMutBox::from( - runner.run( - r, - c, - num_ctrl_pts, - input.to(), - tstw.to(), - ctrl_pts.to(), - weights.to(), - coefs.to(), - tonemap.to() - ) - .await + runner + .run( + r, + c, + num_ctrl_pts, + input.to(), + tstw.to(), + ctrl_pts.to(), + weights.to(), + coefs.to(), + tonemap.to(), + ) + .await, ) } @@ -68,16 +75,17 @@ fn run_cava( let mut r = runner!(cava); async_std::task::block_on(async { - safe_run(&mut r, - rows as u64, - cols as u64, - num_ctrl_pts as u64, - &image, - &tstw, - &ctrl_pts, - &weights, - &coefs, - &tonemap, + safe_run( + &mut r, + rows as u64, + cols as u64, + num_ctrl_pts as u64, + &image, + &tstw, + &ctrl_pts, + &weights, + &coefs, + &tonemap, ) .await }) @@ -159,7 +167,10 @@ fn cava_harness(args: CavaInputs) { tonemap, } = load_cam_model(cam_model, CHAN).expect("Error loading camera model"); - println!("Running cava with {} rows, {} columns, and {} control points.", rows, cols, num_ctrl_pts); + println!( + "Running cava with {} rows, {} columns, and {} control points.", + rows, cols, num_ctrl_pts + ); let result = run_cava( rows, cols, @@ -229,6 +240,7 @@ fn cava_test_small() { } #[test] +#[ignore] fn cava_test_full() { cava_harness(CavaInputs { input: "examples/raw_tulips.bin".to_string(), -- GitLab From 5ba0636be767b7421c6db89ce6773b71bc792b4c Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Fri, 14 Feb 2025 09:43:11 -0600 Subject: [PATCH 05/15] I am surprised these forkify --- juno_samples/cava/src/cava.jn | 14 +++++++------- juno_samples/cava/src/cpu.sch | 29 +++++++++++++++++++++++++++-- 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn index de21ba78..29fc4df5 100644 --- a/juno_samples/cava/src/cava.jn +++ b/juno_samples/cava/src/cava.jn @@ -22,7 +22,7 @@ fn medianMatrix<a : number, rows, cols : usize>(m : a[rows, cols]) -> a { const CHAN : u64 = 3; fn scale<row : usize, col : usize>(input : u8[CHAN, row, col]) -> f32[CHAN, row, col] { - @const let res : f32[CHAN, row, col]; + @res1 let res : f32[CHAN, row, col]; for chan = 0 to CHAN { for r = 0 to row { @@ -50,7 +50,7 @@ fn descale<row : usize, col : usize>(input : f32[CHAN, row, col]) -> u8[CHAN, ro } fn demosaic<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, row, col] { - let res : f32[CHAN, row, col]; + @res2 let res : f32[CHAN, row, col]; for r = 1 to row-1 { for c = 1 to col-1 { @@ -129,7 +129,7 @@ fn denoise<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, r fn transform<row : usize, col : usize> (input : f32[CHAN, row, col], tstw_trans : f32[CHAN, CHAN]) -> f32[CHAN, row, col] { - let result : f32[CHAN, row, col]; + @res let result : f32[CHAN, row, col]; for chan = 0 to CHAN { for r = 0 to row { @@ -152,11 +152,11 @@ fn gamut<row : usize, col : usize, num_ctrl_pts : usize>( weights : f32[num_ctrl_pts, CHAN], coefs : f32[4, CHAN] ) -> f32[CHAN, row, col] { - let result : f32[CHAN, row, col]; - let l2_dist : f32[num_ctrl_pts]; + @res let result : f32[CHAN, row, col]; for r = 0 to row { for c = 0 to col { + @l2 let l2_dist : f32[num_ctrl_pts]; for cp = 0 to num_ctrl_pts { let v1 = input[0, r, c] - ctrl_pts[cp, 0]; let v2 = input[1, r, c] - ctrl_pts[cp, 1]; @@ -210,8 +210,8 @@ fn cava<r, c, num_ctrl_pts : usize>( @fuse1 let scaled = scale::<r, c>(input); @fuse1 let demosc = demosaic::<r, c>(scaled); @fuse2 let denosd = denoise::<r, c>(demosc); - let transf = transform::<r, c>(denosd, TsTw); - let gamutd = gamut::<r, c, num_ctrl_pts>(transf, ctrl_pts, weights, coefs); + @fuse3 let transf = transform::<r, c>(denosd, TsTw); + @fuse4 let gamutd = gamut::<r, c, num_ctrl_pts>(transf, ctrl_pts, weights, coefs); let tonemd = tone_map::<r, c>(gamutd, tonemap); let dscald = descale::<r, c>(tonemd); return dscald; diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch index c75260ca..8099c0ba 100644 --- a/juno_samples/cava/src/cpu.sch +++ b/juno_samples/cava/src/cpu.sch @@ -16,11 +16,18 @@ inline(fuse1); let fuse2 = outline(cava@fuse2); inline(fuse2); +let fuse3 = outline(cava@fuse3); +inline(fuse3); + +let fuse4 = outline(cava@fuse4); +inline(fuse4); + ip-sroa(*); sroa(*); simpl!(*); -no-memset(fuse1@const); +no-memset(fuse1@res1); +no-memset(fuse1@res2); fixpoint { forkify(fuse1); fork-guard-elim(fuse1); @@ -45,6 +52,24 @@ array-slf(fuse2); simpl!(fuse2); unforkify(fuse2); -delete-uncalled(*); +no-memset(fuse3@res); +fixpoint { + forkify(fuse3); + fork-guard-elim(fuse3); + fork-coalesce(fuse3); +} +fork-split(fuse3); +unforkify(fuse3); +no-memset(fuse4@res); +no-memset(fuse4@l2); +fixpoint { + forkify(fuse4); + fork-guard-elim(fuse4); + fork-coalesce(fuse4); +} +fork-split(fuse4); +unforkify(fuse4); + +delete-uncalled(*); gcm(*); \ No newline at end of file -- GitLab From 1848d531851e3a815d2545a8e7471ef74711095c Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Fri, 14 Feb 2025 09:48:40 -0600 Subject: [PATCH 06/15] Chefs kiss --- juno_samples/cava/src/cava.jn | 34 +++++++++++++++++----------------- juno_samples/cava/src/cpu.sch | 19 +++++++++++++++++++ 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn index 29fc4df5..366792c3 100644 --- a/juno_samples/cava/src/cava.jn +++ b/juno_samples/cava/src/cava.jn @@ -35,20 +35,6 @@ fn scale<row : usize, col : usize>(input : u8[CHAN, row, col]) -> f32[CHAN, row, return res; } -fn descale<row : usize, col : usize>(input : f32[CHAN, row, col]) -> u8[CHAN, row, col] { - let res : u8[CHAN, row, col]; - - for chan = 0 to CHAN { - for r = 0 to row { - for c = 0 to col { - res[chan, r, c] = min!::<f32>(max!::<f32>(input[chan, r, c] * 255, 0), 255) as u8; - } - } - } - - return res; -} - fn demosaic<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, row, col] { @res2 let res : f32[CHAN, row, col]; @@ -184,7 +170,7 @@ fn gamut<row : usize, col : usize, num_ctrl_pts : usize>( fn tone_map<row : usize, col:usize> (input : f32[CHAN, row, col], tone_map : f32[256, CHAN]) -> f32[CHAN, row, col] { - let result : f32[CHAN, row, col]; + @res1 let result : f32[CHAN, row, col]; for chan = 0 to CHAN { for r = 0 to row { @@ -198,6 +184,20 @@ fn tone_map<row : usize, col:usize> return result; } +fn descale<row : usize, col : usize>(input : f32[CHAN, row, col]) -> u8[CHAN, row, col] { + @res2 let res : u8[CHAN, row, col]; + + for chan = 0 to CHAN { + for r = 0 to row { + for c = 0 to col { + res[chan, r, c] = min!::<f32>(max!::<f32>(input[chan, r, c] * 255, 0), 255) as u8; + } + } + } + + return res; +} + #[entry] fn cava<r, c, num_ctrl_pts : usize>( input : u8[CHAN, r, c], @@ -212,7 +212,7 @@ fn cava<r, c, num_ctrl_pts : usize>( @fuse2 let denosd = denoise::<r, c>(demosc); @fuse3 let transf = transform::<r, c>(denosd, TsTw); @fuse4 let gamutd = gamut::<r, c, num_ctrl_pts>(transf, ctrl_pts, weights, coefs); - let tonemd = tone_map::<r, c>(gamutd, tonemap); - let dscald = descale::<r, c>(tonemd); + @fuse5 let tonemd = tone_map::<r, c>(gamutd, tonemap); + @fuse5 let dscald = descale::<r, c>(tonemd); return dscald; } diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch index 8099c0ba..d51b8b94 100644 --- a/juno_samples/cava/src/cpu.sch +++ b/juno_samples/cava/src/cpu.sch @@ -22,6 +22,9 @@ inline(fuse3); let fuse4 = outline(cava@fuse4); inline(fuse4); +let fuse5 = outline(cava@fuse5); +inline(fuse5); + ip-sroa(*); sroa(*); simpl!(*); @@ -58,6 +61,7 @@ fixpoint { fork-guard-elim(fuse3); fork-coalesce(fuse3); } +simpl!(fuse3); fork-split(fuse3); unforkify(fuse3); @@ -68,8 +72,23 @@ fixpoint { fork-guard-elim(fuse4); fork-coalesce(fuse4); } +simpl!(fuse4); fork-split(fuse4); unforkify(fuse4); +no-memset(fuse5@res1); +no-memset(fuse5@res2); +fixpoint { + forkify(fuse5); + fork-guard-elim(fuse5); + fork-coalesce(fuse5); +} +simpl!(fuse5); +array-slf(fuse5); +simpl!(fuse5); +fork-split(fuse5); +unforkify(fuse5); + +simpl!(*); delete-uncalled(*); gcm(*); \ No newline at end of file -- GitLab From cb35f6a1b0b4748ceac295075f5ab138cbde5424 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Fri, 14 Feb 2025 14:17:06 -0600 Subject: [PATCH 07/15] Works! --- juno_samples/cava/src/cpu.sch | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch index 2ef7152f..246d10d4 100644 --- a/juno_samples/cava/src/cpu.sch +++ b/juno_samples/cava/src/cpu.sch @@ -75,7 +75,8 @@ inline(fuse2); ip-sroa(*); sroa(*); array-slf(fuse2); -xdot[true](fuse2); +write-predication(fuse2); +simpl!(fuse2); fork-split(fuse2); unforkify(fuse2); -- GitLab From c5d8d401362ca444ecb73afb47c52f021d360b61 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Fri, 14 Feb 2025 14:21:43 -0600 Subject: [PATCH 08/15] whoops --- juno_samples/cava/src/cava.jn | 2 +- juno_samples/cava/src/cpu.sch | 140 +++++++++++++++++----------------- 2 files changed, 71 insertions(+), 71 deletions(-) diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn index f95f1ebc..839e80e2 100644 --- a/juno_samples/cava/src/cava.jn +++ b/juno_samples/cava/src/cava.jn @@ -29,7 +29,7 @@ fn scale<row : usize, col : usize>(input : u8[CHAN, row, col]) -> f32[CHAN, row, for chan = 0 to CHAN { for r = 0 to row { for c = 0 to col { - res[chan, r, c] = input[chan, r, c] as f32 * (1.0 / 255.0); + res[chan, r, c] = input[chan, r, c] as f32 / 255.0; } } } diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch index 246d10d4..5c5428ae 100644 --- a/juno_samples/cava/src/cpu.sch +++ b/juno_samples/cava/src/cpu.sch @@ -27,92 +27,92 @@ inline(fuse5); ip-sroa(*); sroa(*); -simpl!(*); +//simpl!(*); no-memset(fuse1@res1); no-memset(fuse1@res2); -fixpoint { - forkify(fuse1); - fork-guard-elim(fuse1); - fork-coalesce(fuse1); -} -simpl!(fuse1); -array-slf(fuse1); -simpl!(fuse1); -unforkify(fuse1); +//fixpoint { +// forkify(fuse1); +// fork-guard-elim(fuse1); +// fork-coalesce(fuse1); +//} +//simpl!(fuse1); +//array-slf(fuse1); +//simpl!(fuse1); +//unforkify(fuse1); inline(fuse2); no-memset(fuse2@res); no-memset(fuse2@filter); no-memset(fuse2@tmp); -fixpoint { - forkify(fuse2); - fork-guard-elim(fuse2); - fork-coalesce(fuse2); -} -simpl!(fuse2); -predication(fuse2); -simpl!(fuse2); - -let median = outline(fuse2@median); -fork-unroll(median@medianOuter); -simpl!(median); -fixpoint { - forkify(median); - fork-guard-elim(median); -} -simpl!(median); -fixpoint { - fork-unroll(median); -} -ccp(median); -array-to-product(median); -sroa(median); -predication(median); -simpl!(median); - -inline(fuse2); -ip-sroa(*); -sroa(*); -array-slf(fuse2); -write-predication(fuse2); -simpl!(fuse2); -fork-split(fuse2); -unforkify(fuse2); +//fixpoint { +// forkify(fuse2); +// fork-guard-elim(fuse2); +// fork-coalesce(fuse2); +//} +//simpl!(fuse2); +//predication(fuse2); +//simpl!(fuse2); +// +//let median = outline(fuse2@median); +//fork-unroll(median@medianOuter); +//simpl!(median); +//fixpoint { +// forkify(median); +// fork-guard-elim(median); +//} +//simpl!(median); +//fixpoint { +// fork-unroll(median); +//} +//ccp(median); +//array-to-product(median); +//sroa(median); +//predication(median); +//simpl!(median); +// +//inline(fuse2); +//ip-sroa(*); +//sroa(*); +//array-slf(fuse2); +//write-predication(fuse2); +//simpl!(fuse2); +//fork-split(fuse2); +//unforkify(fuse2); no-memset(fuse3@res); -fixpoint { - forkify(fuse3); - fork-guard-elim(fuse3); - fork-coalesce(fuse3); -} -simpl!(fuse3); -fork-split(fuse3); -unforkify(fuse3); +//fixpoint { +// forkify(fuse3); +// fork-guard-elim(fuse3); +// fork-coalesce(fuse3); +//} +//simpl!(fuse3); +//fork-split(fuse3); +//unforkify(fuse3); no-memset(fuse4@res); no-memset(fuse4@l2); -fixpoint { - forkify(fuse4); - fork-guard-elim(fuse4); - fork-coalesce(fuse4); -} -simpl!(fuse4); -fork-split(fuse4); -unforkify(fuse4); +//fixpoint { +// forkify(fuse4); +// fork-guard-elim(fuse4); +// fork-coalesce(fuse4); +//} +//simpl!(fuse4); +//fork-split(fuse4); +//unforkify(fuse4); no-memset(fuse5@res1); no-memset(fuse5@res2); -fixpoint { - forkify(fuse5); - fork-guard-elim(fuse5); - fork-coalesce(fuse5); -} -simpl!(fuse5); -array-slf(fuse5); -simpl!(fuse5); -fork-split(fuse5); -unforkify(fuse5); +//fixpoint { +// forkify(fuse5); +// fork-guard-elim(fuse5); +// fork-coalesce(fuse5); +//} +//simpl!(fuse5); +//array-slf(fuse5); +//simpl!(fuse5); +//fork-split(fuse5); +//unforkify(fuse5); simpl!(*); delete-uncalled(*); -- GitLab From cd0f1635fc0c2beb088ab11b470f092e2b6949b0 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Fri, 14 Feb 2025 14:22:45 -0600 Subject: [PATCH 09/15] actually works --- juno_samples/cava/src/cpu.sch | 140 +++++++++++++++++----------------- 1 file changed, 70 insertions(+), 70 deletions(-) diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch index 5c5428ae..246d10d4 100644 --- a/juno_samples/cava/src/cpu.sch +++ b/juno_samples/cava/src/cpu.sch @@ -27,92 +27,92 @@ inline(fuse5); ip-sroa(*); sroa(*); -//simpl!(*); +simpl!(*); no-memset(fuse1@res1); no-memset(fuse1@res2); -//fixpoint { -// forkify(fuse1); -// fork-guard-elim(fuse1); -// fork-coalesce(fuse1); -//} -//simpl!(fuse1); -//array-slf(fuse1); -//simpl!(fuse1); -//unforkify(fuse1); +fixpoint { + forkify(fuse1); + fork-guard-elim(fuse1); + fork-coalesce(fuse1); +} +simpl!(fuse1); +array-slf(fuse1); +simpl!(fuse1); +unforkify(fuse1); inline(fuse2); no-memset(fuse2@res); no-memset(fuse2@filter); no-memset(fuse2@tmp); -//fixpoint { -// forkify(fuse2); -// fork-guard-elim(fuse2); -// fork-coalesce(fuse2); -//} -//simpl!(fuse2); -//predication(fuse2); -//simpl!(fuse2); -// -//let median = outline(fuse2@median); -//fork-unroll(median@medianOuter); -//simpl!(median); -//fixpoint { -// forkify(median); -// fork-guard-elim(median); -//} -//simpl!(median); -//fixpoint { -// fork-unroll(median); -//} -//ccp(median); -//array-to-product(median); -//sroa(median); -//predication(median); -//simpl!(median); -// -//inline(fuse2); -//ip-sroa(*); -//sroa(*); -//array-slf(fuse2); -//write-predication(fuse2); -//simpl!(fuse2); -//fork-split(fuse2); -//unforkify(fuse2); +fixpoint { + forkify(fuse2); + fork-guard-elim(fuse2); + fork-coalesce(fuse2); +} +simpl!(fuse2); +predication(fuse2); +simpl!(fuse2); + +let median = outline(fuse2@median); +fork-unroll(median@medianOuter); +simpl!(median); +fixpoint { + forkify(median); + fork-guard-elim(median); +} +simpl!(median); +fixpoint { + fork-unroll(median); +} +ccp(median); +array-to-product(median); +sroa(median); +predication(median); +simpl!(median); + +inline(fuse2); +ip-sroa(*); +sroa(*); +array-slf(fuse2); +write-predication(fuse2); +simpl!(fuse2); +fork-split(fuse2); +unforkify(fuse2); no-memset(fuse3@res); -//fixpoint { -// forkify(fuse3); -// fork-guard-elim(fuse3); -// fork-coalesce(fuse3); -//} -//simpl!(fuse3); -//fork-split(fuse3); -//unforkify(fuse3); +fixpoint { + forkify(fuse3); + fork-guard-elim(fuse3); + fork-coalesce(fuse3); +} +simpl!(fuse3); +fork-split(fuse3); +unforkify(fuse3); no-memset(fuse4@res); no-memset(fuse4@l2); -//fixpoint { -// forkify(fuse4); -// fork-guard-elim(fuse4); -// fork-coalesce(fuse4); -//} -//simpl!(fuse4); -//fork-split(fuse4); -//unforkify(fuse4); +fixpoint { + forkify(fuse4); + fork-guard-elim(fuse4); + fork-coalesce(fuse4); +} +simpl!(fuse4); +fork-split(fuse4); +unforkify(fuse4); no-memset(fuse5@res1); no-memset(fuse5@res2); -//fixpoint { -// forkify(fuse5); -// fork-guard-elim(fuse5); -// fork-coalesce(fuse5); -//} -//simpl!(fuse5); -//array-slf(fuse5); -//simpl!(fuse5); -//fork-split(fuse5); -//unforkify(fuse5); +fixpoint { + forkify(fuse5); + fork-guard-elim(fuse5); + fork-coalesce(fuse5); +} +simpl!(fuse5); +array-slf(fuse5); +simpl!(fuse5); +fork-split(fuse5); +unforkify(fuse5); simpl!(*); delete-uncalled(*); -- GitLab From 81f524ca76868a04d1642895fe974f6f58ca8484 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Fri, 14 Feb 2025 14:27:55 -0600 Subject: [PATCH 10/15] This eliminates most selects --- juno_samples/cava/src/cpu.sch | 1 + 1 file changed, 1 insertion(+) diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch index 246d10d4..e9e3e683 100644 --- a/juno_samples/cava/src/cpu.sch +++ b/juno_samples/cava/src/cpu.sch @@ -68,6 +68,7 @@ fixpoint { ccp(median); array-to-product(median); sroa(median); +phi-elim(median); predication(median); simpl!(median); -- GitLab From 0958995eb70b45d0b7d71994b5f2de47fb8828e3 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Fri, 14 Feb 2025 14:40:24 -0600 Subject: [PATCH 11/15] unroll channel loop in gamut with intention of fusion the ctrl pts loop across channels --- juno_samples/cava/src/cava.jn | 2 +- juno_samples/cava/src/cpu.sch | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn index 839e80e2..e6961faf 100644 --- a/juno_samples/cava/src/cava.jn +++ b/juno_samples/cava/src/cava.jn @@ -153,7 +153,7 @@ fn gamut<row : usize, col : usize, num_ctrl_pts : usize>( l2_dist[cp] = sqrt!::<f32>(v); } - for chan = 0 to CHAN { + @channel_loop for chan = 0 to CHAN { let chan_val : f32 = 0.0; for cp = 0 to num_ctrl_pts { chan_val += l2_dist[cp] * weights[cp, chan]; diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch index e9e3e683..3021f6a0 100644 --- a/juno_samples/cava/src/cpu.sch +++ b/juno_samples/cava/src/cpu.sch @@ -99,6 +99,9 @@ fixpoint { fork-coalesce(fuse4); } simpl!(fuse4); +fork-unroll(fuse4@channel_loop); +simpl!(fuse4); +//fork-fusion(fuse4@channel_loop); fork-split(fuse4); unforkify(fuse4); -- GitLab From 0b1a9eb3c1a023f695fad05fd070c62610cd2a22 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Fri, 14 Feb 2025 15:02:58 -0600 Subject: [PATCH 12/15] Add intrinsics to math exprs, array slf gets rid of first ctrl pts look in gamut --- hercules_ir/src/einsum.rs | 31 +++++++++++++++++++++++++++++++ hercules_opt/src/utils.rs | 7 +++++++ juno_samples/cava/src/cpu.sch | 3 +++ 3 files changed, 41 insertions(+) diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs index 5dd0fe5d..b222e1bc 100644 --- a/hercules_ir/src/einsum.rs +++ b/hercules_ir/src/einsum.rs @@ -34,6 +34,7 @@ pub enum MathExpr { Unary(UnaryOperator, MathID), Binary(BinaryOperator, MathID, MathID), Ternary(TernaryOperator, MathID, MathID, MathID), + IntrinsicFunc(Intrinsic, Box<[MathID]>), } pub type MathEnv = Vec<MathExpr>; @@ -224,6 +225,16 @@ impl<'a> EinsumContext<'a> { let third = self.compute_math_expr(third); MathExpr::Ternary(op, first, second, third) } + Node::IntrinsicCall { + intrinsic, + ref args, + } => { + let args = args + .into_iter() + .map(|id| self.compute_math_expr(*id)) + .collect(); + MathExpr::IntrinsicFunc(intrinsic, args) + } Node::Read { collect, ref indices, @@ -322,6 +333,14 @@ impl<'a> EinsumContext<'a> { let third = self.substitute_new_dims(third); self.intern_math_expr(MathExpr::Ternary(op, first, second, third)) } + MathExpr::IntrinsicFunc(intrinsic, ref args) => { + let args = args + .clone() + .iter() + .map(|id| self.substitute_new_dims(*id)) + .collect(); + self.intern_math_expr(MathExpr::IntrinsicFunc(intrinsic, args)) + } _ => id, } } @@ -355,6 +374,9 @@ pub fn opaque_nodes_in_expr(env: &MathEnv, id: MathID) -> HashSet<NodeID> { stack.push(second); stack.push(third); } + MathExpr::IntrinsicFunc(_, ref args) => { + stack.extend(args); + } } } set @@ -412,5 +434,14 @@ pub fn debug_print_math_expr(id: MathID, env: &MathEnv) { debug_print_math_expr(third, env); print!(")"); } + MathExpr::IntrinsicFunc(intrinsic, ref args) => { + print!("{}(", intrinsic.lower_case_name()); + debug_print_math_expr(id, env); + for arg in args { + print!(", "); + debug_print_math_expr(*arg, env); + } + print!(")"); + } } } diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs index c5e0d934..3f12ad7c 100644 --- a/hercules_opt/src/utils.rs +++ b/hercules_opt/src/utils.rs @@ -474,6 +474,13 @@ pub fn materialize_simple_einsum_expr( third, }) } + MathExpr::IntrinsicFunc(intrinsic, ref args) => { + let args = args + .into_iter() + .map(|id| materialize_simple_einsum_expr(edit, *id, env, dim_substs)) + .collect(); + edit.add_node(Node::IntrinsicCall { intrinsic, args }) + } _ => panic!(), } } diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch index 3021f6a0..1ae1dc13 100644 --- a/juno_samples/cava/src/cpu.sch +++ b/juno_samples/cava/src/cpu.sch @@ -102,6 +102,9 @@ simpl!(fuse4); fork-unroll(fuse4@channel_loop); simpl!(fuse4); //fork-fusion(fuse4@channel_loop); +simpl!(fuse4); +array-slf(fuse4); +simpl!(fuse4); fork-split(fuse4); unforkify(fuse4); -- GitLab From 9397f55db557497e0fbf6934af646625668a8bc3 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Fri, 14 Feb 2025 17:17:02 -0600 Subject: [PATCH 13/15] Fork fusion --- hercules_opt/src/fork_transforms.rs | 120 +++++++++++++++++++++++++--- juno_samples/cava/src/cava.jn | 2 +- juno_samples/cava/src/cpu.sch | 4 +- juno_scheduler/src/compile.rs | 1 + juno_scheduler/src/ir.rs | 1 + juno_scheduler/src/pm.rs | 21 +++++ 6 files changed, 135 insertions(+), 14 deletions(-) diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index 72b5716f..c32a517e 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -912,16 +912,16 @@ pub fn chunk_fork_unguarded( let outer = DynamicConstant::div(new_factors[dim_idx], tile_size); new_factors.insert(dim_idx + 1, tile_size); new_factors[dim_idx] = edit.add_dynamic_constant(outer); - + let new_fork = Node::Fork { control: old_control, factors: new_factors.into(), }; let new_fork = edit.add_node(new_fork); - + edit = edit.replace_all_uses(fork, new_fork)?; edit.sub_edit(fork, new_fork); - + for (tid, node) in fork_users { let Node::ThreadID { control: _, @@ -945,7 +945,7 @@ pub fn chunk_fork_unguarded( dimension: tid_dim + 1, }; let tile_tid = edit.add_node(tile_tid); - + let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size }); let mul = edit.add_node(Node::Binary { left: tid, @@ -965,23 +965,23 @@ pub fn chunk_fork_unguarded( edit = edit.delete_node(fork)?; Ok(edit) }); - }, + } TileOrder::TileOuter => { editor.edit(|mut edit| { let inner = DynamicConstant::div(new_factors[dim_idx], tile_size); new_factors.insert(dim_idx, tile_size); - let inner_dc_id = edit.add_dynamic_constant(inner); + let inner_dc_id = edit.add_dynamic_constant(inner); new_factors[dim_idx + 1] = inner_dc_id; - + let new_fork = Node::Fork { control: old_control, factors: new_factors.into(), }; let new_fork = edit.add_node(new_fork); - + edit = edit.replace_all_uses(fork, new_fork)?; edit.sub_edit(fork, new_fork); - + for (tid, node) in fork_users { let Node::ThreadID { control: _, @@ -1000,13 +1000,12 @@ pub fn chunk_fork_unguarded( edit.sub_edit(tid, new_tid); edit = edit.delete_node(tid)?; } else if tid_dim == dim_idx { - let tile_tid = Node::ThreadID { control: new_fork, dimension: tid_dim, }; let tile_tid = edit.add_node(tile_tid); - let inner_dc = edit.add_node(Node::DynamicConstant { id: inner_dc_id } ); + let inner_dc = edit.add_node(Node::DynamicConstant { id: inner_dc_id }); let mul = edit.add_node(Node::Binary { left: tid, right: inner_dc, @@ -1027,7 +1026,6 @@ pub fn chunk_fork_unguarded( }); } } - } pub fn merge_all_fork_dims(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) { @@ -1350,3 +1348,101 @@ pub fn fork_unroll( Ok(edit) }) } + +/* + * Looks for fork-joins that are next to each other, not inter-dependent, and + * have the same bounds. These fork-joins can be fused, pooling together all + * their reductions. + */ +pub fn fork_fusion_all_forks( + editor: &mut FunctionEditor, + fork_join_map: &HashMap<NodeID, NodeID>, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, +) { + for (fork, join) in fork_join_map { + if editor.is_mutable(*fork) + && fork_fusion(editor, *fork, *join, fork_join_map, nodes_in_fork_joins) + { + break; + } + } +} + +/* + * Tries to fuse a given fork join with the immediately following fork-join, if + * it exists. + */ +fn fork_fusion( + editor: &mut FunctionEditor, + top_fork: NodeID, + top_join: NodeID, + fork_join_map: &HashMap<NodeID, NodeID>, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, +) -> bool { + let nodes = &editor.func().nodes; + // Rust operator precedence is not such that these can be put in one big + // let-else statement. Sad! + let Some(bottom_fork) = editor + .get_users(top_join) + .filter(|id| nodes[id.idx()].is_control()) + .next() + else { + return false; + }; + let Some(bottom_join) = fork_join_map.get(&bottom_fork) else { + return false; + }; + let (_, top_factors) = nodes[top_fork.idx()].try_fork().unwrap(); + let (bottom_fork_pred, bottom_factors) = nodes[bottom_fork.idx()].try_fork().unwrap(); + assert_eq!(bottom_fork_pred, top_join); + let top_join_pred = nodes[top_join.idx()].try_join().unwrap(); + let bottom_join_pred = nodes[bottom_join.idx()].try_join().unwrap(); + + // The fork factors must be identical. + if top_factors != bottom_factors { + return false; + } + + // Check that no iterated users of the top's reduces are in the bottom fork- + // join (iteration stops at a phi or reduce outside the bottom fork-join). + for reduce in editor + .get_users(top_join) + .filter(|id| nodes[id.idx()].is_reduce()) + { + let mut visited = HashSet::new(); + visited.insert(reduce); + let mut workset = vec![reduce]; + while let Some(pop) = workset.pop() { + for u in editor.get_users(pop) { + if nodes_in_fork_joins[&bottom_fork].contains(&u) { + return false; + } else if (nodes[u.idx()].is_phi() || nodes[u.idx()].is_reduce()) + && !nodes_in_fork_joins[&top_fork].contains(&u) + { + } else if !visited.contains(&u) && !nodes_in_fork_joins[&top_fork].contains(&u) { + visited.insert(u); + workset.push(u); + } + } + } + } + + // Perform the fusion. + editor.edit(|mut edit| { + if bottom_join_pred != bottom_fork { + // If there is control flow in the bottom fork-join, stitch it into + // the top fork-join. + edit = edit.replace_all_uses_where(bottom_fork, top_join_pred, |id| { + nodes_in_fork_joins[&bottom_fork].contains(id) + })?; + edit = + edit.replace_all_uses_where(top_join_pred, bottom_join_pred, |id| *id == top_join)?; + } + // Replace the bottom fork and join with the top fork and join. + edit = edit.replace_all_uses(bottom_fork, top_fork)?; + edit = edit.replace_all_uses(*bottom_join, top_join)?; + edit = edit.delete_node(bottom_fork)?; + edit = edit.delete_node(*bottom_join)?; + Ok(edit) + }) +} diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn index e6961faf..8158bf0a 100644 --- a/juno_samples/cava/src/cava.jn +++ b/juno_samples/cava/src/cava.jn @@ -152,7 +152,7 @@ fn gamut<row : usize, col : usize, num_ctrl_pts : usize>( let v = v1 * v1 + v2 * v2 + v3 * v3; l2_dist[cp] = sqrt!::<f32>(v); } - + @channel_loop for chan = 0 to CHAN { let chan_val : f32 = 0.0; for cp = 0 to num_ctrl_pts { diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch index 1ae1dc13..0b8869d5 100644 --- a/juno_samples/cava/src/cpu.sch +++ b/juno_samples/cava/src/cpu.sch @@ -101,7 +101,9 @@ fixpoint { simpl!(fuse4); fork-unroll(fuse4@channel_loop); simpl!(fuse4); -//fork-fusion(fuse4@channel_loop); +fixpoint { + fork-fusion(fuse4@channel_loop); +} simpl!(fuse4); array-slf(fuse4); simpl!(fuse4); diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 912cc91f..7c92e00d 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -129,6 +129,7 @@ impl FromStr for Appliable { "fork-interchange" => Ok(Appliable::Pass(ir::Pass::ForkInterchange)), "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)), "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)), + "fork-fusion" | "fusion" => Ok(Appliable::Pass(ir::Pass::ForkFusion)), "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), "outline" => Ok(Appliable::Pass(ir::Pass::Outline)), "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)), diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 8ad92324..205cd70b 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -13,6 +13,7 @@ pub enum Pass { ForkCoalesce, ForkDimMerge, ForkFissionBufferize, + ForkFusion, ForkGuardElim, ForkInterchange, ForkSplit, diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index b26d1720..a4783a93 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2526,6 +2526,27 @@ fn run_pass( fields: new_fork_joins, }; } + Pass::ForkFusion => { + assert!(args.is_empty()); + pm.make_fork_join_maps(); + pm.make_nodes_in_fork_joins(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); + for ((func, fork_join_map), nodes_in_fork_joins) in + build_selection(pm, selection, false) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(nodes_in_fork_joins.iter()) + { + let Some(mut func) = func else { + continue; + }; + fork_fusion_all_forks(&mut func, fork_join_map, nodes_in_fork_joins); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::ForkDimMerge => { assert!(args.is_empty()); pm.make_fork_join_maps(); -- GitLab From bb50794136a039f29236fb70385ac1abbdd4d1c3 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Fri, 14 Feb 2025 17:31:03 -0600 Subject: [PATCH 14/15] Crop the small example image to round dimensions --- juno_samples/cava/src/image_proc.rs | 37 +++++++++++++++++++---------- juno_samples/cava/src/main.rs | 13 +++++++++- 2 files changed, 37 insertions(+), 13 deletions(-) diff --git a/juno_samples/cava/src/image_proc.rs b/juno_samples/cava/src/image_proc.rs index 94a7c14f..7115bac8 100644 --- a/juno_samples/cava/src/image_proc.rs +++ b/juno_samples/cava/src/image_proc.rs @@ -17,41 +17,54 @@ fn read_bin_u32(f: &mut File) -> Result<u32, Error> { Ok(u32::from_le_bytes(buffer)) } -pub fn read_raw<P: AsRef<Path>>(image: P) -> Result<RawImage, Error> { +pub fn read_raw<P: AsRef<Path>>( + image: P, + crop_rows: Option<usize>, + crop_cols: Option<usize>, +) -> Result<RawImage, Error> { let mut file = File::open(image)?; - let rows = read_bin_u32(&mut file)? as usize; - let cols = read_bin_u32(&mut file)? as usize; + let in_rows = read_bin_u32(&mut file)? as usize; + let in_cols = read_bin_u32(&mut file)? as usize; + let out_rows = crop_rows.unwrap_or(in_rows); + let out_cols = crop_cols.unwrap_or(in_cols); let chans = read_bin_u32(&mut file)? as usize; assert!( chans == CHAN, "Channel size read from the binary file doesn't match the default value" ); + assert!(in_rows >= out_rows); + assert!(in_cols >= out_cols); - let size: usize = rows * cols * CHAN; - let mut buffer: Vec<u8> = vec![0; size]; + let in_size: usize = in_rows * in_cols * CHAN; + let mut buffer: Vec<u8> = vec![0; in_size]; // The file has pixels is a 3D array with dimensions height, width, channel file.read_exact(&mut buffer)?; let chunked_channels = buffer.chunks(CHAN).collect::<Vec<_>>(); - let chunked = chunked_channels.chunks(cols).collect::<Vec<_>>(); + let chunked = chunked_channels.chunks(in_cols).collect::<Vec<_>>(); let input: &[&[&[u8]]] = chunked.as_slice(); // We need the image in a 3D array with dimensions channel, height, width - let mut pixels: Vec<u8> = vec![0; size]; - let mut pixels_columns = pixels.chunks_mut(cols).collect::<Vec<_>>(); - let mut pixels_chunked = pixels_columns.chunks_mut(rows).collect::<Vec<_>>(); + let out_size: usize = out_rows * out_cols * CHAN; + let mut pixels: Vec<u8> = vec![0; out_size]; + let mut pixels_columns = pixels.chunks_mut(out_cols).collect::<Vec<_>>(); + let mut pixels_chunked = pixels_columns.chunks_mut(out_rows).collect::<Vec<_>>(); let result: &mut [&mut [&mut [u8]]] = pixels_chunked.as_mut_slice(); - for row in 0..rows { - for col in 0..cols { + for row in 0..out_rows { + for col in 0..out_cols { for chan in 0..CHAN { result[chan][row][col] = input[row][col][chan]; } } } - Ok(RawImage { rows, cols, pixels }) + Ok(RawImage { + rows: out_rows, + cols: out_cols, + pixels, + }) } pub fn extern_image(rows: usize, cols: usize, image: &[u8]) -> RgbImage { diff --git a/juno_samples/cava/src/main.rs b/juno_samples/cava/src/main.rs index 700b74e6..142ed703 100644 --- a/juno_samples/cava/src/main.rs +++ b/juno_samples/cava/src/main.rs @@ -147,6 +147,10 @@ struct CavaInputs { #[clap(long = "output-verify", value_name = "PATH")] output_verify: Option<String>, cam_model: String, + #[clap(short, long)] + crop_rows: Option<usize>, + #[clap(short, long)] + crop_cols: Option<usize>, } fn cava_harness(args: CavaInputs) { @@ -156,8 +160,11 @@ fn cava_harness(args: CavaInputs) { verify, output_verify, cam_model, + crop_rows, + crop_cols, } = args; - let RawImage { rows, cols, pixels } = read_raw(input).expect("Error loading image"); + let RawImage { rows, cols, pixels } = + read_raw(input, crop_rows, crop_cols).expect("Error loading image"); let CamModel { tstw, num_ctrl_pts, @@ -236,6 +243,8 @@ fn cava_test_small() { verify: true, output_verify: None, cam_model: "cam_models/NikonD7000".to_string(), + crop_rows: Some(144), + crop_cols: Some(192), }); } @@ -248,5 +257,7 @@ fn cava_test_full() { verify: true, output_verify: None, cam_model: "cam_models/NikonD7000".to_string(), + crop_rows: None, + crop_cols: None, }); } -- GitLab From a3ec1d56780b9afc79bbf8fadc7268620cef6e9a Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Fri, 14 Feb 2025 17:35:49 -0600 Subject: [PATCH 15/15] Reorganize cpu schedule --- juno_samples/cava/src/cpu.sch | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch index 0b8869d5..623dcfbc 100644 --- a/juno_samples/cava/src/cpu.sch +++ b/juno_samples/cava/src/cpu.sch @@ -38,8 +38,6 @@ fixpoint { } simpl!(fuse1); array-slf(fuse1); -simpl!(fuse1); -unforkify(fuse1); inline(fuse2); no-memset(fuse2@res); @@ -78,8 +76,6 @@ sroa(*); array-slf(fuse2); write-predication(fuse2); simpl!(fuse2); -fork-split(fuse2); -unforkify(fuse2); no-memset(fuse3@res); fixpoint { @@ -88,8 +84,6 @@ fixpoint { fork-coalesce(fuse3); } simpl!(fuse3); -fork-split(fuse3); -unforkify(fuse3); no-memset(fuse4@res); no-memset(fuse4@l2); @@ -107,8 +101,6 @@ fixpoint { simpl!(fuse4); array-slf(fuse4); simpl!(fuse4); -fork-split(fuse4); -unforkify(fuse4); no-memset(fuse5@res1); no-memset(fuse5@res2); @@ -120,9 +112,23 @@ fixpoint { simpl!(fuse5); array-slf(fuse5); simpl!(fuse5); + +delete-uncalled(*); +simpl!(*); +xdot[true](*); + +simpl!(fuse1); +unforkify(fuse1); +fork-split(fuse2); +unforkify(fuse2); +fork-split(fuse3); +unforkify(fuse3); +fork-split(fuse4); +unforkify(fuse4); fork-split(fuse5); unforkify(fuse5); simpl!(*); + delete-uncalled(*); gcm(*); -- GitLab