From 4d51c830e46d188f44e3816fe27824bc81f579af Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Wed, 19 Feb 2025 16:09:12 -0600 Subject: [PATCH] Fix task parallelism --- hercules_cg/src/rt.rs | 50 +++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index bd152074..88e8487c 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -416,6 +416,14 @@ impl<'a> RTContext<'a> { let succ = self.control_subgraph.succs(id).next().unwrap(); write!(epilogue, "{} => {{", id.idx())?; + // Await the empty futures for the fork-joins, waiting for them + // to complete. + write!( + epilogue, + "for fut in fork_{}.drain(..) {{ fut.await; }}; ::std::sync::atomic::fence(::std::sync::atomic::Ordering::Acquire);", + fork.idx(), + )?; + // Emit the assignments to the reduce variables in the // surrounding context. It's very unfortunate that we have to do // it while lowering the join node (rather than the reduce nodes @@ -562,7 +570,7 @@ impl<'a> RTContext<'a> { let is_async = func.schedules[id.idx()].contains(&Schedule::AsyncCall); if is_async { for arg in args { - if let Some(arc) = self.clone_arc(*arg, false) { + if let Some(arc) = self.clone_arc(*arg) { write!(block, "{}", arc)?; } } @@ -573,16 +581,16 @@ impl<'a> RTContext<'a> { format!("{} = ", self.get_value(id, bb, true)) } (_, true) => format!( - "{}::async_std::task::spawn(async move {{ async_call_sender_{}.send(", - self.clone_arc(id, true).unwrap(), - id.idx() + "{}*async_call_{}.lock().await = Some(::async_std::task::spawn(async move {{ ", + self.clone_arc(id).unwrap(), + id.idx(), ), }; let postfix = match (device, is_async) { (Device::AsyncRust, false) => ".await", (_, false) => "", - (Device::AsyncRust, true) => ".await).await})", - (_, true) => ").await})", + (Device::AsyncRust, true) => ".await}))", + (_, true) => "}))", }; write!( block, @@ -1079,14 +1087,14 @@ impl<'a> RTContext<'a> { } 
// If the node is a call with an AsyncCall schedule, it should be - // lowered to a channel. + // lowered to an Arc<Mutex<>> over the future. let is_async_call = func.nodes[idx].is_call() && func.schedules[idx].contains(&Schedule::AsyncCall); if is_async_call { write!( w, - "let mut async_call_channel_{} = ::async_std::channel::bounded(1);let async_call_sender_{} = ::std::sync::Arc::new(async_call_channel_{}.0);let async_call_receiver_{} = ::std::sync::Arc::new(async_call_channel_{}.1);", - idx, idx, idx, idx, idx + "let mut async_call_{} = ::std::sync::Arc::new(::async_std::sync::Mutex::new(None));", + idx, )?; } else { write!( @@ -1353,40 +1361,30 @@ impl<'a> RTContext<'a> { fn get_value(&self, id: NodeID, bb: NodeID, lhs: bool) -> String { let func = self.get_func(); if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce() - && control == bb - { - format!("reduce_{}", id.idx()) - } else if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce() && let fork = self.join_fork_map[&control] && !self.nodes_in_fork_joins[&fork].contains(&bb) { - // Before using the value of a reduction outside the fork-join, - // await the futures. - format!( - "{{for fut in fork_{}.drain(..) 
{{ fut.await; }}; ::std::sync::atomic::fence(::std::sync::atomic::Ordering::Acquire); reduce_{}}}", - fork.idx(), - id.idx() - ) + format!("reduce_{}", id.idx()) } else if func.nodes[id.idx()].is_call() && func.schedules[id.idx()].contains(&Schedule::AsyncCall) { assert!(!lhs); - format!("async_call_receiver_{}.recv().await.unwrap()", id.idx(),) + format!( + "async_call_{}.lock().await.as_mut().unwrap().await", + id.idx(), + ) } else { format!("node_{}", id.idx()) } } - fn clone_arc(&self, id: NodeID, lhs: bool) -> Option<String> { + fn clone_arc(&self, id: NodeID) -> Option<String> { let func = self.get_func(); if func.nodes[id.idx()].is_call() && func.schedules[id.idx()].contains(&Schedule::AsyncCall) { - let kind = if lhs { "sender" } else { "receiver" }; Some(format!( - "let async_call_{}_{} = async_call_{}_{}.clone();", - kind, + "let async_call_{} = async_call_{}.clone();", id.idx(), - kind, id.idx() )) } else { -- GitLab