From 5ab36921092a378d02f1efb6791944e6347b6085 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Thu, 20 Feb 2025 11:43:05 -0600 Subject: [PATCH] Speed up max gradient with tiling + cooperative groups --- hercules_cg/src/gpu.rs | 6 +++--- juno_samples/edge_detection/src/gpu.sch | 10 ++++++++-- juno_samples/edge_detection/src/lib.rs | 5 +++++ 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index 17f0f893..73dcf528 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -1383,15 +1383,15 @@ extern \"C\" {} {}(", let cg_tile = self.get_cg_tile(nesting_fork.unwrap(), CGType::Use); #[allow(unreachable_patterns)] let cg_op = match intrinsic { - Intrinsic::Max => "max", - Intrinsic::Min => "min", + Intrinsic::Max => "greater", + Intrinsic::Min => "less", _ => unreachable!(), }; let id_type_name = self.get_type(id_type, false); write!( w, "{}{} = cg::reduce({}, {}, cg::{}<{}>());\n", - tabs, define_variable, non_reduce_arg, cg_tile, cg_op, id_type_name + tabs, define_variable, cg_tile, non_reduce_arg, cg_op, id_type_name )?; } else { let ty = &self.types[id_type.idx()]; diff --git a/juno_samples/edge_detection/src/gpu.sch b/juno_samples/edge_detection/src/gpu.sch index 3da40fd3..ad3ec65c 100644 --- a/juno_samples/edge_detection/src/gpu.sch +++ b/juno_samples/edge_detection/src/gpu.sch @@ -8,6 +8,8 @@ macro simpl!(X) { infer-schedules(X); } +gpu(gaussian_smoothing, laplacian_estimate, zero_crossings, gradient, max_gradient, reject_zero_crossings); + simpl!(*); ip-sroa(*); @@ -58,6 +60,12 @@ fixpoint { fork-coalesce(max_gradient); } simpl!(max_gradient); +fork-dim-merge(max_gradient); +simpl!(max_gradient); +fork-tile[32, 0, false, true](max_gradient); +simpl!(max_gradient); +fork-split(max_gradient); +simpl!(max_gradient); no-memset(reject_zero_crossings@res); fixpoint { @@ -70,8 +78,6 @@ simpl!(reject_zero_crossings); async-call(edge_detection@le, edge_detection@zc); -gpu(gaussian_smoothing, laplacian_estimate, zero_crossings, gradient, max_gradient, reject_zero_crossings); - simpl!(*); delete-uncalled(*); diff --git a/juno_samples/edge_detection/src/lib.rs b/juno_samples/edge_detection/src/lib.rs index 6c2a15bd..dab84cf6 100644 --- a/juno_samples/edge_detection/src/lib.rs +++ b/juno_samples/edge_detection/src/lib.rs @@ -143,6 +143,11 @@ pub fn edge_detection_harness(args: EdgeDetectionInputs) { num_frames }; + println!( + "Running edge with {} rows, {} columns, {} gs, {} sz, and {} sb.", + height, width, gs, sz, sb, + ); + let mut r = runner!(edge_detection); let mut output = output.map(|filename| { -- GitLab