From 8734113e254ac8764a70237f0ce09293b2d62bfb Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Wed, 19 Feb 2025 10:43:42 -0600
Subject: [PATCH] Use channels to send values between async calls, wrap in arcs
 to share

---
 hercules_cg/src/rt.rs | 58 ++++++++++++++++++++++++++++++++-----------
 1 file changed, 43 insertions(+), 15 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 5edddd86..bd152074 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -560,23 +560,33 @@ impl<'a> RTContext<'a> {
                 // same interface as AsyncRust functions.
                 let block = &mut blocks.get_mut(&bb).unwrap().data;
                 let is_async = func.schedules[id.idx()].contains(&Schedule::AsyncCall);
+                if is_async {
+                    for arg in args {
+                        if let Some(arc) = self.clone_arc(*arg, false) {
+                            write!(block, "{}", arc)?;
+                        }
+                    }
+                }
                 let device = self.devices[callee_id.idx()];
                 let prefix = match (device, is_async) {
-                    (Device::AsyncRust, false) => "",
-                    (_, false) => "",
-                    (Device::AsyncRust, true) => "Some(::async_std::task::spawn(",
-                    (_, true) => "Some(::async_std::task::spawn(async move {",
+                    (Device::AsyncRust, false) | (_, false) => {
+                        format!("{} = ", self.get_value(id, bb, true))
+                    }
+                    (_, true) => format!(
+                        "{}::async_std::task::spawn(async move {{ async_call_sender_{}.send(",
+                        self.clone_arc(id, true).unwrap(),
+                        id.idx()
+                    ),
                 };
                 let postfix = match (device, is_async) {
                     (Device::AsyncRust, false) => ".await",
                     (_, false) => "",
-                    (Device::AsyncRust, true) => "))",
-                    (_, true) => "}))",
+                    (Device::AsyncRust, true) => ".await).await})",
+                    (_, true) => ").await})",
                 };
                 write!(
                     block,
-                    "{} = {}{}(",
-                    self.get_value(id, bb, true),
+                    "{}{}(",
                     prefix,
                     self.module.functions[callee_id.idx()].name
                 )?;
@@ -1069,11 +1079,15 @@ impl<'a> RTContext<'a> {
             }
 
             // If the node is a call with an AsyncCall schedule, it should be
-            // spawned as a task and awaited later.
+            // lowered to a channel.
             let is_async_call =
                 func.nodes[idx].is_call() && func.schedules[idx].contains(&Schedule::AsyncCall);
             if is_async_call {
-                write!(w, "let mut async_call_{} = None;", idx)?;
+                write!(
+                    w,
+                    "let mut async_call_channel_{} = ::async_std::channel::bounded(1);let async_call_sender_{} = ::std::sync::Arc::new(async_call_channel_{}.0);let async_call_receiver_{} = ::std::sync::Arc::new(async_call_channel_{}.1);",
+                    idx, idx, idx, idx, idx
+                )?;
             } else {
                 write!(
                     w,
@@ -1356,16 +1370,30 @@ impl<'a> RTContext<'a> {
         } else if func.nodes[id.idx()].is_call()
             && func.schedules[id.idx()].contains(&Schedule::AsyncCall)
         {
-            format!(
-                "async_call_{}{}",
-                id.idx(),
-                if lhs { "" } else { ".unwrap().await" }
-            )
+            assert!(!lhs);
+            format!("async_call_receiver_{}.recv().await.unwrap()", id.idx(),)
         } else {
             format!("node_{}", id.idx())
         }
     }
 
+    fn clone_arc(&self, id: NodeID, lhs: bool) -> Option<String> {
+        let func = self.get_func();
+        if func.nodes[id.idx()].is_call() && func.schedules[id.idx()].contains(&Schedule::AsyncCall)
+        {
+            let kind = if lhs { "sender" } else { "receiver" };
+            Some(format!(
+                "let async_call_{}_{} = async_call_{}_{}.clone();",
+                kind,
+                id.idx(),
+                kind,
+                id.idx()
+            ))
+        } else {
+            None
+        }
+    }
+
     fn get_type(&self, id: TypeID) -> &'static str {
         convert_type(&self.module.types[id.idx()])
     }
-- 
GitLab