From 7a9bdc262e76dbedb1849b3e3d318829a8c1fe3a Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 10 Feb 2025 09:39:16 -0600
Subject: [PATCH] gpu

---
 hercules_opt/src/fork_transforms.rs                | 14 ++++++--------
 .../fork_join_tests/src/fork_join_tests.jn         |  6 +++---
 juno_samples/fork_join_tests/src/gpu.sch           | 13 ++++++++++++-
 juno_scheduler/src/pm.rs                           |  3 ---
 4 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 80ff7b8f..b5d3bb28 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -322,8 +322,6 @@ pub fn fork_bufferize_fission_helper<'a, 'b>(
 where
     'a: 'b,
 {
-    // Returns the two forks that it generates.
-
     if bufferized_edges.is_empty() {
         return None;
     }
@@ -384,7 +382,6 @@ where
                 .replace_all_uses_where(old_node, map[&old_node], |node| *node == outside_user)?;
         }
 
-        let new_join = map[&join];
         let new_fork = map[&fork];
 
         // FIXME: Do this as part of copy subgraph?
@@ -422,12 +419,14 @@ where
             ));
             let empty_buffer = edit.add_zero_constant(empty_buffer);
             let empty_buffer = edit.add_node(Node::Constant { id: empty_buffer });
+            edit = edit.add_schedule(empty_buffer, Schedule::NoResetConstant)?;
             let reduce = Node::Reduce {
                 control: join,
                 init: empty_buffer,
                 reduct: write,
             };
             let reduce = edit.add_node(reduce);
+            edit = edit.add_schedule(reduce, Schedule::ParallelReduce)?;
 
             // Fix write node
             edit = edit.replace_all_uses_where(NodeID::new(0), reduce, |usee| *usee == write)?;
@@ -449,12 +448,11 @@ where
         Ok(edit)
     });
 
-    if edit_result == false {
-        todo!();
-        return None;
+    if edit_result {
+        Some((fork, new_fork_id))
+    } else {
+        None
     }
-
-    Some((fork, new_fork_id))
 }
 
 /** Split a 1D fork into a separate fork for each reduction. */
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index a765726f..886ab13b 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -97,8 +97,8 @@ fn test7(input: i32) -> i32 {
 
 #[entry]
 fn test8(input : i32) -> i32[8] {
-  let arr : i32[8];
-  let out : i32[8];
+  @const1 let arr : i32[8];
+  @const2 let out : i32[8];
 
   for i = 0 to 8 {
     arr[i] = i as i32;
@@ -116,4 +116,4 @@ fn test8(input : i32) -> i32[8] {
     out[i] = d;
   }
   return out;
-}
\ No newline at end of file
+}
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index bd1fd1d8..159fac94 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -5,18 +5,21 @@ no-memset(test3@const1);
 no-memset(test3@const2);
 no-memset(test3@const3);
 no-memset(test6@const);
+no-memset(test8@const1);
+no-memset(test8@const2);
 
 gvn(*);
 phi-elim(*);
 dce(*);
 
-let auto = auto-outline(test1, test2, test3, test4, test5, test7);
+let auto = auto-outline(test1, test2, test3, test4, test5, test7, test8);
 gpu(auto.test1);
 gpu(auto.test2);
 gpu(auto.test3);
 gpu(auto.test4);
 gpu(auto.test5);
 gpu(auto.test7);
+gpu(auto.test8);
 
 ip-sroa(*);
 sroa(*);
@@ -50,6 +53,14 @@ dce(auto.test7);
 simplify-cfg(auto.test7);
 dce(auto.test7);
 
+let fission = fork-fission-bufferize[test8@loop, test8@bufferize1](auto.test8);
+xdot[true](*);
+dce(auto.test8);
+ccp(auto.test8);
+dce(auto.test8);
+simplify-cfg(auto.test8);
+dce(auto.test8);
+
 ip-sroa(*);
 sroa(*);
 dce(*);
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 58ff399f..20825c54 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -2191,9 +2191,6 @@ fn run_pass(
             result = Value::Record {
                 fields: new_fork_joins,
             };
-
-            println!("result: {:?}", result);
-
         }
         Pass::ForkDimMerge => {
             assert!(args.is_empty());
-- 
GitLab