From 4d51c830e46d188f44e3816fe27824bc81f579af Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Wed, 19 Feb 2025 16:09:12 -0600
Subject: [PATCH 1/6] Fix task parallelism

---
 hercules_cg/src/rt.rs | 50 +++++++++++++++++++++----------------------
 1 file changed, 24 insertions(+), 26 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index bd152074..88e8487c 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -416,6 +416,14 @@ impl<'a> RTContext<'a> {
                 let succ = self.control_subgraph.succs(id).next().unwrap();
                 write!(epilogue, "{} => {{", id.idx())?;
 
+                // Await the empty futures for the fork-joins, waiting for them
+                // to complete.
+                write!(
+                    epilogue,
+                    "for fut in fork_{}.drain(..) {{ fut.await; }}; ::std::sync::atomic::fence(::std::sync::atomic::Ordering::Acquire);",
+                    fork.idx(),
+                )?;
+
                 // Emit the assignments to the reduce variables in the
                 // surrounding context. It's very unfortunate that we have to do
                 // it while lowering the join node (rather than the reduce nodes
@@ -562,7 +570,7 @@ impl<'a> RTContext<'a> {
                 let is_async = func.schedules[id.idx()].contains(&Schedule::AsyncCall);
                 if is_async {
                     for arg in args {
-                        if let Some(arc) = self.clone_arc(*arg, false) {
+                        if let Some(arc) = self.clone_arc(*arg) {
                             write!(block, "{}", arc)?;
                         }
                     }
@@ -573,16 +581,16 @@ impl<'a> RTContext<'a> {
                         format!("{} = ", self.get_value(id, bb, true))
                     }
                     (_, true) => format!(
-                        "{}::async_std::task::spawn(async move {{ async_call_sender_{}.send(",
-                        self.clone_arc(id, true).unwrap(),
-                        id.idx()
+                        "{}*async_call_{}.lock().await = Some(::async_std::task::spawn(async move {{ ",
+                        self.clone_arc(id).unwrap(),
+                        id.idx(),
                     ),
                 };
                 let postfix = match (device, is_async) {
                     (Device::AsyncRust, false) => ".await",
                     (_, false) => "",
-                    (Device::AsyncRust, true) => ".await).await})",
-                    (_, true) => ").await})",
+                    (Device::AsyncRust, true) => ".await}))",
+                    (_, true) => "}))",
                 };
                 write!(
                     block,
@@ -1079,14 +1087,14 @@ impl<'a> RTContext<'a> {
             }
 
             // If the node is a call with an AsyncCall schedule, it should be
-            // lowered to a channel.
+            // lowered to a Arc<Mutex<>> over the future.
             let is_async_call =
                 func.nodes[idx].is_call() && func.schedules[idx].contains(&Schedule::AsyncCall);
             if is_async_call {
                 write!(
                     w,
-                    "let mut async_call_channel_{} = ::async_std::channel::bounded(1);let async_call_sender_{} = ::std::sync::Arc::new(async_call_channel_{}.0);let async_call_receiver_{} = ::std::sync::Arc::new(async_call_channel_{}.1);",
-                    idx, idx, idx, idx, idx
+                    "let mut async_call_{} = ::std::sync::Arc::new(::async_std::sync::Mutex::new(None));",
+                    idx,
                 )?;
             } else {
                 write!(
@@ -1353,40 +1361,30 @@ impl<'a> RTContext<'a> {
     fn get_value(&self, id: NodeID, bb: NodeID, lhs: bool) -> String {
         let func = self.get_func();
         if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce()
-            && control == bb
-        {
-            format!("reduce_{}", id.idx())
-        } else if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce()
             && let fork = self.join_fork_map[&control]
             && !self.nodes_in_fork_joins[&fork].contains(&bb)
         {
-            // Before using the value of a reduction outside the fork-join,
-            // await the futures.
-            format!(
-                "{{for fut in fork_{}.drain(..) {{ fut.await; }}; ::std::sync::atomic::fence(::std::sync::atomic::Ordering::Acquire); reduce_{}}}",
-                fork.idx(),
-                id.idx()
-            )
+            format!("reduce_{}", id.idx())
         } else if func.nodes[id.idx()].is_call()
             && func.schedules[id.idx()].contains(&Schedule::AsyncCall)
         {
             assert!(!lhs);
-            format!("async_call_receiver_{}.recv().await.unwrap()", id.idx(),)
+            format!(
+                "async_call_{}.lock().await.as_mut().unwrap().await",
+                id.idx(),
+            )
         } else {
             format!("node_{}", id.idx())
         }
     }
 
-    fn clone_arc(&self, id: NodeID, lhs: bool) -> Option<String> {
+    fn clone_arc(&self, id: NodeID) -> Option<String> {
         let func = self.get_func();
         if func.nodes[id.idx()].is_call() && func.schedules[id.idx()].contains(&Schedule::AsyncCall)
         {
-            let kind = if lhs { "sender" } else { "receiver" };
             Some(format!(
-                "let async_call_{}_{} = async_call_{}_{}.clone();",
-                kind,
+                "let async_call_{} = async_call_{}.clone();",
                 id.idx(),
-                kind,
                 id.idx()
             ))
         } else {
-- 
GitLab


From 74df0c63781eb73525cccb68421367fe02bf02c9 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Wed, 19 Feb 2025 16:13:00 -0600
Subject: [PATCH 2/6] Fix emitting fork-joins in RT

---
 hercules_cg/src/rt.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 88e8487c..313f7c10 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -1361,8 +1361,8 @@ impl<'a> RTContext<'a> {
     fn get_value(&self, id: NodeID, bb: NodeID, lhs: bool) -> String {
         let func = self.get_func();
         if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce()
-            && let fork = self.join_fork_map[&control]
-            && !self.nodes_in_fork_joins[&fork].contains(&bb)
+            && (control == bb
+                || !self.nodes_in_fork_joins[&self.join_fork_map[&control]].contains(&bb))
         {
             format!("reduce_{}", id.idx())
         } else if func.nodes[id.idx()].is_call()
-- 
GitLab


From d9046f51a361fc0f070060a699573978c06d42c3 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Wed, 19 Feb 2025 16:23:43 -0600
Subject: [PATCH 3/6] Refactor arc clones into their own block

---
 hercules_cg/src/rt.rs | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 313f7c10..cbf70eec 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -138,6 +138,7 @@ struct RTContext<'a> {
 #[derive(Debug, Clone, Default)]
 struct RustBlock {
     prologue: String,
+    arc_clones: String,
     data: String,
     phi_tmp_assignments: String,
     phi_assignments: String,
@@ -268,8 +269,9 @@ impl<'a> RTContext<'a> {
             } else {
                 write!(
                     w,
-                    "{}{}{}{}{}",
+                    "{}{}{}{}{}{}",
                     block.prologue,
+                    block.arc_clones,
                     block.data,
                     block.phi_tmp_assignments,
                     block.phi_assignments,
@@ -566,12 +568,14 @@ impl<'a> RTContext<'a> {
                 assert_eq!(control, bb);
                 // The device backends ensure that device functions have the
                 // same interface as AsyncRust functions.
-                let block = &mut blocks.get_mut(&bb).unwrap().data;
+                let block = &mut blocks.get_mut(&bb).unwrap();
+                let arc_clones = &mut block.arc_clones;
+                let block = &mut block.data;
                 let is_async = func.schedules[id.idx()].contains(&Schedule::AsyncCall);
                 if is_async {
                     for arg in args {
                         if let Some(arc) = self.clone_arc(*arg) {
-                            write!(block, "{}", arc)?;
+                            write!(arc_clones, "{}", arc)?;
                         }
                     }
                 }
@@ -580,11 +584,13 @@ impl<'a> RTContext<'a> {
                     (Device::AsyncRust, false) | (_, false) => {
                         format!("{} = ", self.get_value(id, bb, true))
                     }
-                    (_, true) => format!(
-                        "{}*async_call_{}.lock().await = Some(::async_std::task::spawn(async move {{ ",
-                        self.clone_arc(id).unwrap(),
-                        id.idx(),
-                    ),
+                    (_, true) => {
+                        write!(arc_clones, "{}", self.clone_arc(id).unwrap())?;
+                        format!(
+                            "*async_call_{}.lock().await = Some(::async_std::task::spawn(async move {{ ",
+                            id.idx(),
+                        )
+                    }
                 };
                 let postfix = match (device, is_async) {
                     (Device::AsyncRust, false) => ".await",
-- 
GitLab


From 6f997c9403d37b6c8df6029b7da2265f9ee7e945 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Wed, 19 Feb 2025 21:11:04 -0600
Subject: [PATCH 4/6] Rust futures are :):(

---
 hercules_cg/src/rt.rs                         | 19 +++++++++++++------
 juno_samples/fork_join_tests/src/cpu.sch      |  6 +++---
 .../fork_join_tests/src/fork_join_tests.jn    |  7 ++++++-
 3 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index cbf70eec..beb83f51 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -138,7 +138,6 @@ struct RTContext<'a> {
 #[derive(Debug, Clone, Default)]
 struct RustBlock {
     prologue: String,
-    arc_clones: String,
     data: String,
     phi_tmp_assignments: String,
     phi_assignments: String,
@@ -269,9 +268,8 @@ impl<'a> RTContext<'a> {
             } else {
                 write!(
                     w,
-                    "{}{}{}{}{}{}",
+                    "{}{}{}{}{}",
                     block.prologue,
-                    block.arc_clones,
                     block.data,
                     block.phi_tmp_assignments,
                     block.phi_assignments,
@@ -361,6 +359,16 @@ impl<'a> RTContext<'a> {
                     write!(prologue, " {{")?;
                 }
 
+                // Emit clones of arcs used inside the fork-join.
+                for other_id in (0..func.nodes.len()).map(NodeID::new) {
+                    if self.def_use.get_users(other_id).into_iter().any(|user_id| {
+                        self.nodes_in_fork_joins[&id].contains(&self.bbs.0[user_id.idx()])
+                    }) && let Some(arc) = self.clone_arc(other_id)
+                    {
+                        write!(prologue, "{}", arc)?;
+                    }
+                }
+
                 // Spawn an async closure and push its future to a Vec.
                 write!(
                     prologue,
@@ -569,13 +577,12 @@ impl<'a> RTContext<'a> {
                 // The device backends ensure that device functions have the
                 // same interface as AsyncRust functions.
                 let block = &mut blocks.get_mut(&bb).unwrap();
-                let arc_clones = &mut block.arc_clones;
                 let block = &mut block.data;
                 let is_async = func.schedules[id.idx()].contains(&Schedule::AsyncCall);
                 if is_async {
                     for arg in args {
                         if let Some(arc) = self.clone_arc(*arg) {
-                            write!(arc_clones, "{}", arc)?;
+                            write!(block, "{}", arc)?;
                         }
                     }
                 }
@@ -585,7 +592,7 @@ impl<'a> RTContext<'a> {
                         format!("{} = ", self.get_value(id, bb, true))
                     }
                     (_, true) => {
-                        write!(arc_clones, "{}", self.clone_arc(id).unwrap())?;
+                        write!(block, "{}", self.clone_arc(id).unwrap())?;
                         format!(
                             "*async_call_{}.lock().await = Some(::async_std::task::spawn(async move {{ ",
                             id.idx(),
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index c71aec11..76dcbdf6 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -1,5 +1,3 @@
-no-memset(test6@const);
-
 ccp(*);
 gvn(*);
 phi-elim(*);
@@ -63,7 +61,9 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
-fork-tile[32, 0, false, true](test6@loop);
+async-call(test6@call);
+no-memset(test6@const);
+fork-tile[2, 0, false, false](test6@loop);
 let out = fork-split(test6@loop);
 let out = outline(out.test6.fj1);
 cpu(out);
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index 51115f15..bfb5564b 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -73,11 +73,16 @@ fn test5(input : i32) -> i32[4] {
   return arr1;
 }
 
+fn test6_helper(input: i32) -> i32 {
+  return input;
+}
+
 #[entry]
 fn test6(input: i32) -> i32[1024] {
+  @call let x = test6_helper(input);
   @const let arr : i32[1024];
   @loop for i = 0 to 1024 {
-    arr[i] = i as i32 + input;
+    arr[i] = i as i32 + x;
   }
   return arr;
 }
-- 
GitLab


From 875a17f032124202ee5fd542b6db560a55c87cf2 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Wed, 19 Feb 2025 21:22:12 -0600
Subject: [PATCH 5/6] Just add the type I want...

---
 hercules_cg/src/rt.rs  |  9 +++------
 hercules_rt/src/lib.rs | 42 +++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 44 insertions(+), 7 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index beb83f51..d3013239 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -594,7 +594,7 @@ impl<'a> RTContext<'a> {
                     (_, true) => {
                         write!(block, "{}", self.clone_arc(id).unwrap())?;
                         format!(
-                            "*async_call_{}.lock().await = Some(::async_std::task::spawn(async move {{ ",
+                            "*async_call_{}.lock().await = ::hercules_rt::__FutureSlotWrapper::new(::async_std::task::spawn(async move {{ ",
                             id.idx(),
                         )
                     }
@@ -1106,7 +1106,7 @@ impl<'a> RTContext<'a> {
             if is_async_call {
                 write!(
                     w,
-                    "let mut async_call_{} = ::std::sync::Arc::new(::async_std::sync::Mutex::new(None));",
+                    "let mut async_call_{} = ::std::sync::Arc::new(::async_std::sync::Mutex::new(::hercules_rt::__FutureSlotWrapper::empty()));",
                     idx,
                 )?;
             } else {
@@ -1382,10 +1382,7 @@ impl<'a> RTContext<'a> {
             && func.schedules[id.idx()].contains(&Schedule::AsyncCall)
         {
             assert!(!lhs);
-            format!(
-                "async_call_{}.lock().await.as_mut().unwrap().await",
-                id.idx(),
-            )
+            format!("async_call_{}.lock().await.inspect().await", id.idx(),)
         } else {
             format!("node_{}", id.idx())
         }
diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs
index 090a38a0..d19a0a5a 100644
--- a/hercules_rt/src/lib.rs
+++ b/hercules_rt/src/lib.rs
@@ -1,10 +1,10 @@
 #![feature(once_cell_try)]
 
 use std::alloc::{alloc, dealloc, GlobalAlloc, Layout, System};
+use std::future::Future;
 use std::marker::PhantomData;
 use std::ptr::{copy_nonoverlapping, write_bytes, NonNull};
 use std::slice::{from_raw_parts, from_raw_parts_mut};
-
 use std::sync::OnceLock;
 
 /*
@@ -426,6 +426,46 @@ impl __RawPtrSendSync {
 unsafe impl Send for __RawPtrSendSync {}
 unsafe impl Sync for __RawPtrSendSync {}
 
+#[derive(Clone, Debug)]
+pub struct __FutureSlotWrapper<T, F>
+where
+    T: Copy,
+    F: Future<Output = T>,
+{
+    future: Option<F>,
+    slot: Option<T>,
+}
+
+impl<T, F> __FutureSlotWrapper<T, F>
+where
+    T: Copy,
+    F: Future<Output = T>,
+{
+    pub fn empty() -> Self {
+        Self {
+            future: None,
+            slot: None,
+        }
+    }
+
+    pub fn new(f: F) -> Self {
+        Self {
+            future: Some(f),
+            slot: None,
+        }
+    }
+
+    pub async fn inspect(&mut self) -> T {
+        if let Some(slot) = self.slot {
+            slot
+        } else {
+            let result = self.future.take().unwrap().await;
+            self.slot = Some(result);
+            result
+        }
+    }
+}
+
 /*
  * A HerculesBox holds memory that can be on any device and provides a common interface to moving
  * data where it is needed.
-- 
GitLab


From c7aa94687b733c112aca66ca22cc09c035879d8a Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Wed, 19 Feb 2025 21:35:50 -0600
Subject: [PATCH 6/6] Fix inlining id functions

---
 hercules_opt/src/inline.rs               | 6 +++++-
 juno_samples/fork_join_tests/src/gpu.sch | 1 +
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/hercules_opt/src/inline.rs b/hercules_opt/src/inline.rs
index f01b2366..6e308274 100644
--- a/hercules_opt/src/inline.rs
+++ b/hercules_opt/src/inline.rs
@@ -210,7 +210,11 @@ fn inline_func(
             }
 
             // Finally, delete the call node.
-            edit = edit.replace_all_uses(id, old_id_to_new_id(called_return_data))?;
+            if let Node::Parameter { index } = called_func.nodes[called_return_data.idx()] {
+                edit = edit.replace_all_uses(id, args[index])?;
+            } else {
+                edit = edit.replace_all_uses(id, old_id_to_new_id(called_return_data))?;
+            }
             edit = edit.delete_node(control)?;
             edit = edit.delete_node(id)?;
 
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index 91bd6c79..364673cd 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -52,6 +52,7 @@ slf(auto.test2);
 infer-schedules(auto.test2);
 fork-interchange[0, 1](auto.test2);
 
+inline(test6);
 fork-tile[32, 0, false, true](test6@loop);
 let out = fork-split(test6@loop);
 let out = auto-outline(test6);
-- 
GitLab