From 8734113e254ac8764a70237f0ce09293b2d62bfb Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Wed, 19 Feb 2025 10:43:42 -0600 Subject: [PATCH] Use channels to send values between async calls, wrap in arcs to share --- hercules_cg/src/rt.rs | 58 ++++++++++++++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 15 deletions(-) diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 5edddd86..bd152074 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -560,23 +560,33 @@ impl<'a> RTContext<'a> { // same interface as AsyncRust functions. let block = &mut blocks.get_mut(&bb).unwrap().data; let is_async = func.schedules[id.idx()].contains(&Schedule::AsyncCall); + if is_async { + for arg in args { + if let Some(arc) = self.clone_arc(*arg, false) { + write!(block, "{}", arc)?; + } + } + } let device = self.devices[callee_id.idx()]; let prefix = match (device, is_async) { - (Device::AsyncRust, false) => "", - (_, false) => "", - (Device::AsyncRust, true) => "Some(::async_std::task::spawn(", - (_, true) => "Some(::async_std::task::spawn(async move {", + (Device::AsyncRust, false) | (_, false) => { + format!("{} = ", self.get_value(id, bb, true)) + } + (_, true) => format!( + "{}::async_std::task::spawn(async move {{ async_call_sender_{}.send(", + self.clone_arc(id, true).unwrap(), + id.idx() + ), }; let postfix = match (device, is_async) { (Device::AsyncRust, false) => ".await", (_, false) => "", - (Device::AsyncRust, true) => "))", - (_, true) => "}))", + (Device::AsyncRust, true) => ".await).await})", + (_, true) => ").await})", }; write!( block, - "{} = {}{}(", - self.get_value(id, bb, true), + "{}{}(", prefix, self.module.functions[callee_id.idx()].name )?; @@ -1069,11 +1079,15 @@ impl<'a> RTContext<'a> { } // If the node is a call with an AsyncCall schedule, it should be - // spawned as a task and awaited later. + // lowered to a channel. let is_async_call = func.nodes[idx].is_call() && func.schedules[idx].contains(&Schedule::AsyncCall); if is_async_call { - write!(w, "let mut async_call_{} = None;", idx)?; + write!( + w, + "let mut async_call_channel_{} = ::async_std::channel::bounded(1);let async_call_sender_{} = ::std::sync::Arc::new(async_call_channel_{}.0);let async_call_receiver_{} = ::std::sync::Arc::new(async_call_channel_{}.1);", + idx, idx, idx, idx, idx + )?; } else { write!( w, @@ -1356,16 +1370,30 @@ impl<'a> RTContext<'a> { } else if func.nodes[id.idx()].is_call() && func.schedules[id.idx()].contains(&Schedule::AsyncCall) { - format!( - "async_call_{}{}", - id.idx(), - if lhs { "" } else { ".unwrap().await" } - ) + assert!(!lhs); + format!("async_call_receiver_{}.recv().await.unwrap()", id.idx(),) } else { format!("node_{}", id.idx()) } } + fn clone_arc(&self, id: NodeID, lhs: bool) -> Option<String> { + let func = self.get_func(); + if func.nodes[id.idx()].is_call() && func.schedules[id.idx()].contains(&Schedule::AsyncCall) + { + let kind = if lhs { "sender" } else { "receiver" }; + Some(format!( + "let async_call_{}_{} = async_call_{}_{}.clone();", + kind, + id.idx(), + kind, + id.idx() + )) + } else { + None + } + } + fn get_type(&self, id: TypeID) -> &'static str { convert_type(&self.module.types[id.idx()]) } -- GitLab