From 46d62811a72ae590bea2dc7dc382f42e21e88a97 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Wed, 19 Feb 2025 15:13:59 -0600
Subject: [PATCH] Pass manager fixes and example

---
 Cargo.lock                                    | 10 ++++
 Cargo.toml                                    |  1 +
 juno_samples/multi_return/Cargo.toml          | 21 ++++++++
 juno_samples/multi_return/build.rs            | 15 ++++++
 juno_samples/multi_return/src/cpu.sch         | 31 +++++++++++
 juno_samples/multi_return/src/gpu.sch         | 26 ++++++++++
 juno_samples/multi_return/src/main.rs         | 22 ++++++++
 juno_samples/multi_return/src/multi_return.jn | 32 ++++++++++++
 juno_scheduler/src/pm.rs                      | 52 +++++++++++--------
 9 files changed, 189 insertions(+), 21 deletions(-)
 create mode 100644 juno_samples/multi_return/Cargo.toml
 create mode 100644 juno_samples/multi_return/build.rs
 create mode 100644 juno_samples/multi_return/src/cpu.sch
 create mode 100644 juno_samples/multi_return/src/gpu.sch
 create mode 100644 juno_samples/multi_return/src/main.rs
 create mode 100644 juno_samples/multi_return/src/multi_return.jn

diff --git a/Cargo.lock b/Cargo.lock
index c438e846..32dc6a0e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1347,6 +1347,16 @@ dependencies = [
  "with_builtin_macros",
 ]
 
+[[package]]
+name = "juno_multi_return"
+version = "0.1.0"
+dependencies = [
+ "async-std",
+ "hercules_rt",
+ "juno_build",
+ "with_builtin_macros",
+]
+
 [[package]]
 name = "juno_patterns"
 version = "0.1.0"
diff --git a/Cargo.toml b/Cargo.toml
index 42d28135..01f8cc13 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -26,6 +26,7 @@ members = [
 	"juno_samples/matmul",
 	"juno_samples/median_window",
 	"juno_samples/multi_device",
+	"juno_samples/multi_return",
 	"juno_samples/patterns",
 	"juno_samples/products",
 	"juno_samples/rodinia/backprop",
diff --git a/juno_samples/multi_return/Cargo.toml b/juno_samples/multi_return/Cargo.toml
new file mode 100644
index 00000000..0fb3de94
--- /dev/null
+++ b/juno_samples/multi_return/Cargo.toml
@@ -0,0 +1,21 @@
+[package]
+name = "juno_multi_return"
+version = "0.1.0"
+authors = ["Aaron Councilman <aaronjc4@illinois.edu>"]
+edition = "2021"
+
+[[bin]]
+name = "juno_multi_return"
+path = "src/main.rs"
+
+[features]
+cuda = ["juno_build/cuda", "hercules_rt/cuda"]
+
+[build-dependencies]
+juno_build = { path = "../../juno_build" }
+
+[dependencies]
+juno_build = { path = "../../juno_build" }
+hercules_rt = { path = "../../hercules_rt" }
+with_builtin_macros = "0.1.0"
+async-std = "*"
diff --git a/juno_samples/multi_return/build.rs b/juno_samples/multi_return/build.rs
new file mode 100644
index 00000000..3a8f9b1c
--- /dev/null
+++ b/juno_samples/multi_return/build.rs
@@ -0,0 +1,15 @@
+use juno_build::JunoCompiler;
+
+fn main() {
+    JunoCompiler::new()
+        .file_in_src("multi_return.jn")
+        .unwrap()
+        .schedule_in_src(if cfg!(feature = "cuda") {
+            "gpu.sch"
+        } else {
+            "cpu.sch"
+        })
+        .unwrap()
+        .build()
+        .unwrap();
+}
diff --git a/juno_samples/multi_return/src/cpu.sch b/juno_samples/multi_return/src/cpu.sch
new file mode 100644
index 00000000..03fb2585
--- /dev/null
+++ b/juno_samples/multi_return/src/cpu.sch
@@ -0,0 +1,31 @@
+gvn(*);
+phi-elim(*);
+dce(*);
+
+ip-sroa(*);
+sroa(*);
+dce(*);
+
+forkify(*);
+fork-guard-elim(*);
+gvn(*);
+dce(*);
+
+inline(*);
+delete-uncalled(*);
+
+let out = auto-outline(*);
+cpu(out.rolling_sum_prod);
+
+fork-fusion(out.rolling_sum_prod);
+gvn(*);
+dce(*);
+
+float-collections(*);
+
+unforkify(*);
+gvn(*);
+ccp(*);
+dce(*);
+
+gcm(*);
diff --git a/juno_samples/multi_return/src/gpu.sch b/juno_samples/multi_return/src/gpu.sch
new file mode 100644
index 00000000..e733551d
--- /dev/null
+++ b/juno_samples/multi_return/src/gpu.sch
@@ -0,0 +1,26 @@
+gvn(*);
+phi-elim(*);
+dce(*);
+
+ip-sroa(*);
+sroa(*);
+dce(*);
+
+forkify(*);
+fork-guard-elim(*);
+gvn(*);
+dce(*);
+
+inline(*);
+delete-uncalled(*);
+
+let out = auto-outline(*);
+gpu(out.rolling_sum_prod);
+
+fork-fusion(out.rolling_sum_prod);
+gvn(*);
+dce(*);
+
+float-collections(*);
+
+gcm(*);
diff --git a/juno_samples/multi_return/src/main.rs b/juno_samples/multi_return/src/main.rs
new file mode 100644
index 00000000..63479dba
--- /dev/null
+++ b/juno_samples/multi_return/src/main.rs
@@ -0,0 +1,22 @@
+#![feature(concat_idents)]
+
+juno_build::juno!("median");
+
+use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo};
+
+fn main() {
+    let m = vec![
+        86, 72, 14, 5, 55, 25, 98, 89, 3, 66, 44, 81, 27, 3, 40, 18, 4, 57, 93, 34, 70, 50, 50, 18,
+        34,
+    ];
+    let m = HerculesImmBox::from(m.as_slice());
+
+    let mut r = runner!(median_window);
+    let res = async_std::task::block_on(async { r.run(m.to()).await });
+    assert_eq!(res, 57);
+}
+
+#[test]
+fn test_median_window() {
+    main()
+}
diff --git a/juno_samples/multi_return/src/multi_return.jn b/juno_samples/multi_return/src/multi_return.jn
new file mode 100644
index 00000000..a49df91c
--- /dev/null
+++ b/juno_samples/multi_return/src/multi_return.jn
@@ -0,0 +1,32 @@
+fn rolling_sum<t: number, n: usize>(x: t[n]) -> t, t[n + 1] {
+  let rolling_sum: t[n + 1];
+  let sum = 0;
+
+  for i in 0..n {
+    rolling_sum[i] = sum;
+    sum += x[i];
+  }
+  rolling_sum[n] = sum;
+
+  return (sum, rolling_sum);
+}
+
+fn rolling_prod<t: number, n: usize>(x: t[n]) -> t, t[n + 1] {
+  let rolling_prod: t[n + 1];
+  let prod = 1;
+
+  for i in 0..n {
+    rolling_prod[i] = prod;
+    prod *= x[i];
+  }
+  rolling_prod[n] = prod;
+
+  return prod, rolling_prod;
+}
+
+#[entry]
+fn rolling_sum_prod<n: usize>(x: f32[n]) -> f32[n + 1], f32[n + 1] {
+  let rsum = rolling_sum::<_, n>(x).1;
+  let _, rprod = rolling_prod::<_, n>(x);
+  return rsum, rprod;
+}
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index d5c0af27..94f90048 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2077,16 +2077,32 @@ fn run_pass(
             pm.clear_analyses();
         }
         Pass::InterproceduralSROA => {
-            assert!(args.is_empty());
-            if let Some(_) = selection {
-                return Err(SchedulerError::PassError {
-                    pass: "interproceduralSROA".to_string(),
-                    error: "must be applied to the entire module".to_string(),
-                });
-            }
+            let sroa_with_arrays = match args.get(0) {
+                Some(Value::Boolean { val }) => *val,
+                Some(_) => {
+                    return Err(SchedulerError::PassError {
+                        pass: "sroa".to_string(),
+                        error: "expected boolean argument".to_string(),
+                    });
+                }
+                None => false,
+            };
+
+            let selection = selection_of_functions(pm, selection)
+                .ok_or_else(|| {
+                    SchedulerError::PassError {
+                        pass: "xdot".to_string(),
+                        error: "expected coarse-grained selection (can't partially xdot a function)".to_string(),
+                    }
+                })?;
+            let mut bool_selection = vec![false; pm.functions.len()];
+            selection.into_iter().for_each(|func| bool_selection[func.idx()] = true);
+
+            pm.make_typing();
+            let typing = pm.typing.take().unwrap();
 
             let mut editors = build_editors(pm);
-            interprocedural_sroa(&mut editors);
+            interprocedural_sroa(&mut editors, &typing, &bool_selection, sroa_with_arrays);
 
             for func in editors {
                 changed |= func.modified();
@@ -2720,21 +2736,15 @@ fn run_pass(
                 None => true,
             };
 
-            let mut bool_selection = vec![];
-            if let Some(selection) = selection {
-                bool_selection = vec![false; pm.functions.len()];
-                for loc in selection {
-                    let CodeLocation::Function(id) = loc else {
-                        return Err(SchedulerError::PassError {
+            let selection = selection_of_functions(pm, selection)
+                .ok_or_else(|| {
+                    SchedulerError::PassError {
                         pass: "xdot".to_string(),
                         error: "expected coarse-grained selection (can't partially xdot a function)".to_string(),
-                    });
-                    };
-                    bool_selection[id.idx()] = true;
-                }
-            } else {
-                bool_selection = vec![true; pm.functions.len()];
-            }
+                    }
+                })?;
+            let mut bool_selection = vec![false; pm.functions.len()];
+            selection.into_iter().for_each(|func| bool_selection[func.idx()] = true);
 
             pm.make_reverse_postorders();
             if force_analyses {
-- 
GitLab