Fix parallel code gen in RT backend

Merged rarbore2 requested to merge fix_task_parallelism into main
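This change reworks how the RT backend lowers `AsyncCall` results and fork-join task parallelism: the `bounded(1)` channel per async call becomes an `Arc<Mutex<...>>` over a `hercules_rt::__FutureSlotWrapper` holding the spawned future, the `Arc`s used inside a fork-join body are cloned into scope before the body is spawned, and the join node now drains and awaits the fork futures and issues an `Acquire` fence before the reduce variables are read.

Below is a minimal sketch of the slot-over-a-future pattern the generated code relies on. It is not the real `hercules_rt::__FutureSlotWrapper` (only its `empty()`, `new()`, and `inspect()` entry points are visible in this diff); `FutureSlot`, `example`, and `callee` are made-up names for illustration.

```rust
use async_std::task::JoinHandle;

// Hypothetical, simplified stand-in for hercules_rt::__FutureSlotWrapper.
enum FutureSlot<T> {
    // No call has been spawned into the slot yet.
    Empty,
    // The call was spawned; its handle has not been awaited yet.
    Pending(JoinHandle<T>),
    // The call finished; the result is cached for repeated uses.
    Ready(T),
}

impl<T: Clone> FutureSlot<T> {
    fn empty() -> Self {
        FutureSlot::Empty
    }

    fn new(handle: JoinHandle<T>) -> Self {
        FutureSlot::Pending(handle)
    }

    // Await the spawned task the first time the result is needed, then cache it.
    async fn inspect(&mut self) -> T {
        let value = match std::mem::replace(self, FutureSlot::Empty) {
            FutureSlot::Pending(handle) => handle.await,
            FutureSlot::Ready(value) => value,
            FutureSlot::Empty => panic!("inspect() called before the call was spawned"),
        };
        *self = FutureSlot::Ready(value.clone());
        value
    }
}

// Rough shape of what the backend now emits around one AsyncCall (simplified).
async fn example() -> u64 {
    // Declaration emitted up front, replacing the old bounded(1) channel pair.
    let async_call_0 = std::sync::Arc::new(async_std::sync::Mutex::new(FutureSlot::empty()));

    // At the call site: spawn the call and park its handle in the slot.
    *async_call_0.lock().await =
        FutureSlot::new(async_std::task::spawn(async move { callee().await }));

    // At each use of the result: lock the slot and inspect (await or reuse) it.
    async_call_0.lock().await.inspect().await
}

async fn callee() -> u64 {
    42
}
```

One practical difference from the channel version: a clone of the `Arc` can be handed to any task (which is all `clone_arc` has to emit now), and `inspect()` can be called any number of times, whereas a `bounded(1)` channel yields its value to a single `recv()`.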
3 files changed: +38 −30
@@ -359,6 +359,16 @@ impl<'a> RTContext<'a> {
write!(prologue, " {{")?;
}
+// Emit clones of arcs used inside the fork-join.
+for other_id in (0..func.nodes.len()).map(NodeID::new) {
+if self.def_use.get_users(other_id).into_iter().any(|user_id| {
+self.nodes_in_fork_joins[&id].contains(&self.bbs.0[user_id.idx()])
+}) && let Some(arc) = self.clone_arc(other_id)
+{
+write!(prologue, "{}", arc)?;
+}
+}
// Spawn an async closure and push its future to a Vec.
write!(
prologue,
@@ -416,6 +426,14 @@ impl<'a> RTContext<'a> {
let succ = self.control_subgraph.succs(id).next().unwrap();
write!(epilogue, "{} => {{", id.idx())?;
+// Await the empty futures for the fork-joins, waiting for them
+// to complete.
+write!(
+epilogue,
+"for fut in fork_{}.drain(..) {{ fut.await; }}; ::std::sync::atomic::fence(::std::sync::atomic::Ordering::Acquire);",
+fork.idx(),
+)?;
// Emit the assignments to the reduce variables in the
// surrounding context. It's very unfortunate that we have to do
// it while lowering the join node (rather than the reduce nodes
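For reference, the await-plus-fence string emitted in the hunk above unpacks to roughly the following; `await_fork_join` is a made-up wrapper name, and `fork_0` stands for the Vec that the fork's spawned bodies were pushed into earlier.

```rust
use std::sync::atomic::{fence, Ordering};

// Illustrative unpacking of the emitted string at a join point.
async fn await_fork_join(mut fork_0: Vec<async_std::task::JoinHandle<()>>) {
    // Drain and await every future spawned for the fork bodies.
    for fut in fork_0.drain(..) {
        fut.await;
    }
    // Acquire fence so writes made by the forked tasks (e.g. into the buffers
    // backing the reduce variables) are ordered before the reads that follow
    // the join in the surrounding context.
    fence(Ordering::Acquire);
}
```

Moving this into the join lowering is what lets `get_value` drop its old "await the futures before using a reduce outside the fork-join" special case (removed in the last hunk): by the time a reduce value is read outside the fork-join, the join has already awaited everything.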
@@ -558,11 +576,12 @@ impl<'a> RTContext<'a> {
assert_eq!(control, bb);
// The device backends ensure that device functions have the
// same interface as AsyncRust functions.
-let block = &mut blocks.get_mut(&bb).unwrap().data;
+let block = &mut blocks.get_mut(&bb).unwrap();
+let block = &mut block.data;
let is_async = func.schedules[id.idx()].contains(&Schedule::AsyncCall);
if is_async {
for arg in args {
-if let Some(arc) = self.clone_arc(*arg, false) {
+if let Some(arc) = self.clone_arc(*arg) {
write!(block, "{}", arc)?;
}
}
@@ -572,17 +591,19 @@ impl<'a> RTContext<'a> {
(Device::AsyncRust, false) | (_, false) => {
format!("{} = ", self.get_value(id, bb, true))
}
-(_, true) => format!(
-"{}::async_std::task::spawn(async move {{ async_call_sender_{}.send(",
-self.clone_arc(id, true).unwrap(),
-id.idx()
-),
+(_, true) => {
+write!(block, "{}", self.clone_arc(id).unwrap())?;
+format!(
+"*async_call_{}.lock().await = ::hercules_rt::__FutureSlotWrapper::new(::async_std::task::spawn(async move {{ ",
+id.idx(),
+)
+}
};
let postfix = match (device, is_async) {
(Device::AsyncRust, false) => ".await",
(_, false) => "",
(Device::AsyncRust, true) => ".await).await})",
(_, true) => ").await})",
(Device::AsyncRust, true) => ".await}))",
(_, true) => "}))",
};
write!(
block,
@@ -1079,14 +1100,14 @@ impl<'a> RTContext<'a> {
}
// If the node is a call with an AsyncCall schedule, it should be
-// lowered to a channel.
+// lowered to an Arc<Mutex<>> over the future.
let is_async_call =
func.nodes[idx].is_call() && func.schedules[idx].contains(&Schedule::AsyncCall);
if is_async_call {
write!(
w,
"let mut async_call_channel_{} = ::async_std::channel::bounded(1);let async_call_sender_{} = ::std::sync::Arc::new(async_call_channel_{}.0);let async_call_receiver_{} = ::std::sync::Arc::new(async_call_channel_{}.1);",
idx, idx, idx, idx, idx
"let mut async_call_{} = ::std::sync::Arc::new(::async_std::sync::Mutex::new(::hercules_rt::__FutureSlotWrapper::empty()));",
idx,
)?;
} else {
write!(
@@ -1353,40 +1374,27 @@ impl<'a> RTContext<'a> {
fn get_value(&self, id: NodeID, bb: NodeID, lhs: bool) -> String {
let func = self.get_func();
if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce()
-&& control == bb
+&& (control == bb
+|| !self.nodes_in_fork_joins[&self.join_fork_map[&control]].contains(&bb))
{
format!("reduce_{}", id.idx())
-} else if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce()
-&& let fork = self.join_fork_map[&control]
-&& !self.nodes_in_fork_joins[&fork].contains(&bb)
-{
-// Before using the value of a reduction outside the fork-join,
-// await the futures.
-format!(
-"{{for fut in fork_{}.drain(..) {{ fut.await; }}; ::std::sync::atomic::fence(::std::sync::atomic::Ordering::Acquire); reduce_{}}}",
-fork.idx(),
-id.idx()
-)
} else if func.nodes[id.idx()].is_call()
&& func.schedules[id.idx()].contains(&Schedule::AsyncCall)
{
assert!(!lhs);
format!("async_call_receiver_{}.recv().await.unwrap()", id.idx(),)
format!("async_call_{}.lock().await.inspect().await", id.idx(),)
} else {
format!("node_{}", id.idx())
}
}
-fn clone_arc(&self, id: NodeID, lhs: bool) -> Option<String> {
+fn clone_arc(&self, id: NodeID) -> Option<String> {
let func = self.get_func();
if func.nodes[id.idx()].is_call() && func.schedules[id.idx()].contains(&Schedule::AsyncCall)
{
-let kind = if lhs { "sender" } else { "receiver" };
Some(format!(
"let async_call_{}_{} = async_call_{}_{}.clone();",
kind,
"let async_call_{} = async_call_{}.clone();",
id.idx(),
-kind,
id.idx()
))
} else {