From 4d51c830e46d188f44e3816fe27824bc81f579af Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Wed, 19 Feb 2025 16:09:12 -0600
Subject: [PATCH] Fix task parallelism

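Rework how the RT backend lowers AsyncCall nodes and fork-join reductions:

- An AsyncCall node was previously lowered to a bounded(1) channel: the
  spawned task sent its result through async_call_sender_N, and every use
  of the call's value did a recv() on async_call_receiver_N. It is now
  lowered to an Arc<Mutex<Option<_>>> (async_call_N) that holds the spawned
  JoinHandle; each use locks the slot and awaits the handle in place.

- The futures spawned for a fork-join are now awaited, followed by an
  Acquire fence, when the join node itself is lowered, instead of at the
  first read of a reduce value outside the fork-join.

Roughly, the generated code for an async call now has the following shape.
This is a minimal sketch, not emitted verbatim: it assumes async-std's
task::spawn and sync::Mutex exactly as in the templates below, and the node
number, result type, and names like sketch/async_call_0_clone are
illustrative only.

    use std::sync::Arc;

    use async_std::sync::Mutex;
    use async_std::task;

    async fn sketch() -> i32 {
        // One shared slot per AsyncCall node, initially empty.
        let async_call_0: Arc<Mutex<Option<task::JoinHandle<i32>>>> =
            Arc::new(Mutex::new(None));

        // Call site: clone the Arc, spawn the task, stash its handle.
        let async_call_0_clone = async_call_0.clone();
        *async_call_0_clone.lock().await = Some(task::spawn(async move { 40 + 2 }));

        // Use site: lock the slot and await the stored handle.
        async_call_0.lock().await.as_mut().unwrap().await
    }

    fn main() {
        assert_eq!(task::block_on(sketch()), 42);
    }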
---
 hercules_cg/src/rt.rs | 50 +++++++++++++++++++++----------------------
 1 file changed, 24 insertions(+), 26 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index bd152074..88e8487c 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -416,6 +416,14 @@ impl<'a> RTContext<'a> {
                 let succ = self.control_subgraph.succs(id).next().unwrap();
                 write!(epilogue, "{} => {{", id.idx())?;
 
+                // Await the empty futures for the fork-joins, waiting for them
+                // to complete.
+                write!(
+                    epilogue,
+                    "for fut in fork_{}.drain(..) {{ fut.await; }}; ::std::sync::atomic::fence(::std::sync::atomic::Ordering::Acquire);",
+                    fork.idx(),
+                )?;
+
                 // Emit the assignments to the reduce variables in the
                 // surrounding context. It's very unfortunate that we have to do
                 // it while lowering the join node (rather than the reduce nodes
@@ -562,7 +570,7 @@ impl<'a> RTContext<'a> {
                 let is_async = func.schedules[id.idx()].contains(&Schedule::AsyncCall);
                 if is_async {
                     for arg in args {
-                        if let Some(arc) = self.clone_arc(*arg, false) {
+                        if let Some(arc) = self.clone_arc(*arg) {
                             write!(block, "{}", arc)?;
                         }
                     }
@@ -573,16 +581,16 @@ impl<'a> RTContext<'a> {
                         format!("{} = ", self.get_value(id, bb, true))
                     }
                     (_, true) => format!(
-                        "{}::async_std::task::spawn(async move {{ async_call_sender_{}.send(",
-                        self.clone_arc(id, true).unwrap(),
-                        id.idx()
+                        "{}*async_call_{}.lock().await = Some(::async_std::task::spawn(async move {{ ",
+                        self.clone_arc(id).unwrap(),
+                        id.idx(),
                     ),
                 };
                 let postfix = match (device, is_async) {
                     (Device::AsyncRust, false) => ".await",
                     (_, false) => "",
-                    (Device::AsyncRust, true) => ".await).await})",
-                    (_, true) => ").await})",
+                    (Device::AsyncRust, true) => ".await}))",
+                    (_, true) => "}))",
                 };
                 write!(
                     block,
@@ -1079,14 +1087,14 @@ impl<'a> RTContext<'a> {
             }
 
             // If the node is a call with an AsyncCall schedule, it should be
-            // lowered to a channel.
+            // lowered to an Arc<Mutex<>> wrapping the future.
             let is_async_call =
                 func.nodes[idx].is_call() && func.schedules[idx].contains(&Schedule::AsyncCall);
             if is_async_call {
                 write!(
                     w,
-                    "let mut async_call_channel_{} = ::async_std::channel::bounded(1);let async_call_sender_{} = ::std::sync::Arc::new(async_call_channel_{}.0);let async_call_receiver_{} = ::std::sync::Arc::new(async_call_channel_{}.1);",
-                    idx, idx, idx, idx, idx
+                    "let mut async_call_{} = ::std::sync::Arc::new(::async_std::sync::Mutex::new(None));",
+                    idx,
                 )?;
             } else {
                 write!(
@@ -1353,40 +1361,30 @@ impl<'a> RTContext<'a> {
     fn get_value(&self, id: NodeID, bb: NodeID, lhs: bool) -> String {
         let func = self.get_func();
         if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce()
-            && control == bb
-        {
-            format!("reduce_{}", id.idx())
-        } else if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce()
             && let fork = self.join_fork_map[&control]
             && !self.nodes_in_fork_joins[&fork].contains(&bb)
         {
-            // Before using the value of a reduction outside the fork-join,
-            // await the futures.
-            format!(
-                "{{for fut in fork_{}.drain(..) {{ fut.await; }}; ::std::sync::atomic::fence(::std::sync::atomic::Ordering::Acquire); reduce_{}}}",
-                fork.idx(),
-                id.idx()
-            )
+            format!("reduce_{}", id.idx())
         } else if func.nodes[id.idx()].is_call()
             && func.schedules[id.idx()].contains(&Schedule::AsyncCall)
         {
             assert!(!lhs);
-            format!("async_call_receiver_{}.recv().await.unwrap()", id.idx(),)
+            format!(
+                "async_call_{}.lock().await.as_mut().unwrap().await",
+                id.idx(),
+            )
         } else {
             format!("node_{}", id.idx())
         }
     }
 
-    fn clone_arc(&self, id: NodeID, lhs: bool) -> Option<String> {
+    fn clone_arc(&self, id: NodeID) -> Option<String> {
         let func = self.get_func();
         if func.nodes[id.idx()].is_call() && func.schedules[id.idx()].contains(&Schedule::AsyncCall)
         {
-            let kind = if lhs { "sender" } else { "receiver" };
             Some(format!(
-                "let async_call_{}_{} = async_call_{}_{}.clone();",
-                kind,
+                "let async_call_{} = async_call_{}.clone();",
                 id.idx(),
-                kind,
                 id.idx()
             ))
         } else {
-- 
GitLab