From 76338dee4402c64b9b8530898ff0864768b4cc13 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Thu, 20 Feb 2025 20:54:58 -0600
Subject: [PATCH] The reduction tree works!

---
 hercules_opt/src/outline.rs                     | 12 ++++++++++--
 .../edge_detection/src/edge_detection.jn        |  2 +-
 juno_samples/edge_detection/src/gpu.sch         | 17 ++++++++++++++---
 juno_scheduler/src/pm.rs                        |  2 +-
 4 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/hercules_opt/src/outline.rs b/hercules_opt/src/outline.rs
index 874e75e7..c6693336 100644
--- a/hercules_opt/src/outline.rs
+++ b/hercules_opt/src/outline.rs
@@ -23,7 +23,7 @@ pub fn outline(
     typing: &Vec<TypeID>,
     control_subgraph: &Subgraph,
     dom: &DomTree,
-    partition: &BTreeSet<NodeID>,
+    mut partition: BTreeSet<NodeID>,
     to_be_function_id: FunctionID,
 ) -> Option<Function> {
     // Step 1: do a whole bunch of analysis on the partition.
@@ -34,6 +34,14 @@ pub fn outline(
             .any(|id| nodes[id.idx()].is_start() || nodes[id.idx()].is_parameter() || nodes[id.idx()].is_return()),
         "PANIC: Can't outline a partition containing the start node, parameter nodes, or return nodes."
     );
+    for (idx, node) in nodes.into_iter().enumerate() {
+        if let Node::Constant { id } = node
+            && editor.get_constant(*id).is_scalar()
+        {
+            // Usually, you don't want to explicitly outline scalar constants.
+            partition.remove(&NodeID::new(idx));
+        }
+    }
     let mut top_nodes = partition.iter().filter(|id| {
         nodes[id.idx()].is_control()
             && control_subgraph
@@ -611,7 +619,7 @@ pub fn dumb_outline(
         typing,
         control_subgraph,
         dom,
-        &partition,
+        partition,
         to_be_function_id,
     )
 }
diff --git a/juno_samples/edge_detection/src/edge_detection.jn b/juno_samples/edge_detection/src/edge_detection.jn
index 3bc5bbfb..e1413488 100644
--- a/juno_samples/edge_detection/src/edge_detection.jn
+++ b/juno_samples/edge_detection/src/edge_detection.jn
@@ -189,7 +189,7 @@ fn gradient<n, m, sb: usize>(
 }
 
 fn max_gradient<n, m: usize>(gradient: f32[n, m]) -> f32 {
-  let max = gradient[0, 0];
+  let max = -1.0;
 
   for i = 0 to n {
     for j = 0 to m {
diff --git a/juno_samples/edge_detection/src/gpu.sch b/juno_samples/edge_detection/src/gpu.sch
index a3c804d5..7ee2904f 100644
--- a/juno_samples/edge_detection/src/gpu.sch
+++ b/juno_samples/edge_detection/src/gpu.sch
@@ -8,7 +8,7 @@ macro simpl!(X) {
   infer-schedules(X);
 }
 
-gpu(gaussian_smoothing, laplacian_estimate, zero_crossings, gradient, max_gradient, reject_zero_crossings);
+gpu(gaussian_smoothing, laplacian_estimate, zero_crossings, gradient, reject_zero_crossings);
 
 simpl!(*);
 
@@ -63,9 +63,21 @@ simpl!(max_gradient);
 fork-dim-merge(max_gradient);
 simpl!(max_gradient);
 fork-tile[32, 0, false, true](max_gradient);
-fork-split(max_gradient);
+let out = fork-split(max_gradient);
 clean-monoid-reduces(max_gradient);
 simpl!(max_gradient);
+let fission = fork-fission[out._4_max_gradient.fj0](max_gradient);
+simpl!(max_gradient);
+fork-tile[32, 0, false, true](fission._4_max_gradient.fj_bottom);
+let out = fork-split(fission._4_max_gradient.fj_bottom);
+clean-monoid-reduces(max_gradient);
+simpl!(max_gradient);
+let top = outline(fission._4_max_gradient.fj_top);
+let bottom = outline(out._4_max_gradient.fj0);
+gpu(top, bottom);
+ip-sroa(*);
+sroa(*);
+simpl!(*);
 
 no-memset(reject_zero_crossings@res);
 fixpoint {
@@ -82,4 +94,3 @@ simpl!(*);
 
 delete-uncalled(*);
 gcm(*);
-
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 392273d3..44b14257 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2158,7 +2158,7 @@ fn run_pass(
                 &typing[func.idx()],
                 &control_subgraphs[func.idx()],
                 &doms[func.idx()],
-                &nodes,
+                nodes,
                 new_func_id,
             );
             let Some(new_func) = new_func else {
-- 
GitLab