From 76338dee4402c64b9b8530898ff0864768b4cc13 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Thu, 20 Feb 2025 20:54:58 -0600 Subject: [PATCH] The reduction tree works! --- hercules_opt/src/outline.rs | 12 ++++++++++-- .../edge_detection/src/edge_detection.jn | 2 +- juno_samples/edge_detection/src/gpu.sch | 17 ++++++++++++++--- juno_scheduler/src/pm.rs | 2 +- 4 files changed, 26 insertions(+), 7 deletions(-) diff --git a/hercules_opt/src/outline.rs b/hercules_opt/src/outline.rs index 874e75e7..c6693336 100644 --- a/hercules_opt/src/outline.rs +++ b/hercules_opt/src/outline.rs @@ -23,7 +23,7 @@ pub fn outline( typing: &Vec<TypeID>, control_subgraph: &Subgraph, dom: &DomTree, - partition: &BTreeSet<NodeID>, + mut partition: BTreeSet<NodeID>, to_be_function_id: FunctionID, ) -> Option<Function> { // Step 1: do a whole bunch of analysis on the partition. @@ -34,6 +34,14 @@ pub fn outline( .any(|id| nodes[id.idx()].is_start() || nodes[id.idx()].is_parameter() || nodes[id.idx()].is_return()), "PANIC: Can't outline a partition containing the start node, parameter nodes, or return nodes." ); + for (idx, node) in nodes.into_iter().enumerate() { + if let Node::Constant { id } = node + && editor.get_constant(*id).is_scalar() + { + // Usually, you don't want to explicitly outline scalar constants. + partition.remove(&NodeID::new(idx)); + } + } let mut top_nodes = partition.iter().filter(|id| { nodes[id.idx()].is_control() && control_subgraph @@ -611,7 +619,7 @@ pub fn dumb_outline( typing, control_subgraph, dom, - &partition, + partition, to_be_function_id, ) } diff --git a/juno_samples/edge_detection/src/edge_detection.jn b/juno_samples/edge_detection/src/edge_detection.jn index 3bc5bbfb..e1413488 100644 --- a/juno_samples/edge_detection/src/edge_detection.jn +++ b/juno_samples/edge_detection/src/edge_detection.jn @@ -189,7 +189,7 @@ fn gradient<n, m, sb: usize>( } fn max_gradient<n, m: usize>(gradient: f32[n, m]) -> f32 { - let max = gradient[0, 0]; + let max = -1.0; for i = 0 to n { for j = 0 to m { diff --git a/juno_samples/edge_detection/src/gpu.sch b/juno_samples/edge_detection/src/gpu.sch index a3c804d5..7ee2904f 100644 --- a/juno_samples/edge_detection/src/gpu.sch +++ b/juno_samples/edge_detection/src/gpu.sch @@ -8,7 +8,7 @@ macro simpl!(X) { infer-schedules(X); } -gpu(gaussian_smoothing, laplacian_estimate, zero_crossings, gradient, max_gradient, reject_zero_crossings); +gpu(gaussian_smoothing, laplacian_estimate, zero_crossings, gradient, reject_zero_crossings); simpl!(*); @@ -63,9 +63,21 @@ simpl!(max_gradient); fork-dim-merge(max_gradient); simpl!(max_gradient); fork-tile[32, 0, false, true](max_gradient); -fork-split(max_gradient); +let out = fork-split(max_gradient); clean-monoid-reduces(max_gradient); simpl!(max_gradient); +let fission = fork-fission[out._4_max_gradient.fj0](max_gradient); +simpl!(max_gradient); +fork-tile[32, 0, false, true](fission._4_max_gradient.fj_bottom); +let out = fork-split(fission._4_max_gradient.fj_bottom); +clean-monoid-reduces(max_gradient); +simpl!(max_gradient); +let top = outline(fission._4_max_gradient.fj_top); +let bottom = outline(out._4_max_gradient.fj0); +gpu(top, bottom); +ip-sroa(*); +sroa(*); +simpl!(*); no-memset(reject_zero_crossings@res); fixpoint { @@ -82,4 +94,3 @@ simpl!(*); delete-uncalled(*); gcm(*); - diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 392273d3..44b14257 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2158,7 +2158,7 @@ fn run_pass( &typing[func.idx()], &control_subgraphs[func.idx()], &doms[func.idx()], - &nodes, + nodes, new_func_id, ); let Some(new_func) = new_func else { -- GitLab