
Fix parallel code gen in RT backend

Merged: rarbore2 requested to merge fix_task_parallelism into main (+36 −32)
@@ -138,6 +138,7 @@ struct RTContext<'a> {
 #[derive(Debug, Clone, Default)]
 struct RustBlock {
     prologue: String,
+    arc_clones: String,
     data: String,
     phi_tmp_assignments: String,
     phi_assignments: String,
@@ -268,8 +269,9 @@ impl<'a> RTContext<'a> {
         } else {
             write!(
                 w,
-                "{}{}{}{}{}",
+                "{}{}{}{}{}{}",
                 block.prologue,
+                block.arc_clones,
                 block.data,
                 block.phi_tmp_assignments,
                 block.phi_assignments,
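
A side note on why the clones get their own section between the prologue and the data: `async move` takes ownership of everything it captures, so each Arc handle must be cloned before the data section that moves it into a spawned task. A minimal standalone sketch of the pattern, with hypothetical names rather than the generated ones:

use std::sync::Arc;
use async_std::sync::Mutex;

async fn spawn_with_clone(slot: Arc<Mutex<Option<u64>>>) {
    // What arc_clones holds: a clone made before the spawn...
    let slot_clone = slot.clone();
    // ...so the data section can move the clone into the task while the
    // original handle stays usable in the surrounding block.
    let task = async_std::task::spawn(async move {
        *slot_clone.lock().await = Some(42);
    });
    task.await;
    assert_eq!(*slot.lock().await, Some(42));
}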
@@ -416,6 +418,14 @@ impl<'a> RTContext<'a> {
                 let succ = self.control_subgraph.succs(id).next().unwrap();
                 write!(epilogue, "{} => {{", id.idx())?;
+                // Await the empty futures for the fork-joins, waiting for them
+                // to complete.
+                write!(
+                    epilogue,
+                    "for fut in fork_{}.drain(..) {{ fut.await; }}; ::std::sync::atomic::fence(::std::sync::atomic::Ordering::Acquire);",
+                    fork.idx(),
+                )?;
                 // Emit the assignments to the reduce variables in the
                 // surrounding context. It's very unfortunate that we have to do
                 // it while lowering the join node (rather than the reduce nodes
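
In the generated code, each fork-join keeps the JoinHandles of its spawned tasks in a fork_N vector. A self-contained sketch of the epilogue pattern emitted above, assuming async_std and a hypothetical fork_0 vector:

use std::sync::atomic::{fence, Ordering};

async fn join_epilogue(fork_0: &mut Vec<async_std::task::JoinHandle<()>>) {
    // Await every task spawned for the fork-join.
    for fut in fork_0.drain(..) {
        fut.await;
    }
    // Make the tasks' non-atomic writes visible to code after the join.
    fence(Ordering::Acquire);
}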
@@ -558,12 +568,14 @@ impl<'a> RTContext<'a> {
                 assert_eq!(control, bb);
                 // The device backends ensure that device functions have the
                 // same interface as AsyncRust functions.
-                let block = &mut blocks.get_mut(&bb).unwrap().data;
+                let block = &mut blocks.get_mut(&bb).unwrap();
+                let arc_clones = &mut block.arc_clones;
+                let block = &mut block.data;
                 let is_async = func.schedules[id.idx()].contains(&Schedule::AsyncCall);
                 if is_async {
                     for arg in args {
-                        if let Some(arc) = self.clone_arc(*arg, false) {
-                            write!(block, "{}", arc)?;
+                        if let Some(arc) = self.clone_arc(*arg) {
+                            write!(arc_clones, "{}", arc)?;
                         }
                     }
                 }
@@ -572,17 +584,19 @@ impl<'a> RTContext<'a> {
                     (Device::AsyncRust, false) | (_, false) => {
                         format!("{} = ", self.get_value(id, bb, true))
                     }
-                    (_, true) => format!(
-                        "{}::async_std::task::spawn(async move {{ async_call_sender_{}.send(",
-                        self.clone_arc(id, true).unwrap(),
-                        id.idx()
-                    ),
+                    (_, true) => {
+                        write!(arc_clones, "{}", self.clone_arc(id).unwrap())?;
+                        format!(
+                            "*async_call_{}.lock().await = Some(::async_std::task::spawn(async move {{ ",
+                            id.idx(),
+                        )
+                    }
                 };
                 let postfix = match (device, is_async) {
                     (Device::AsyncRust, false) => ".await",
                     (_, false) => "",
-                    (Device::AsyncRust, true) => ".await).await})",
-                    (_, true) => ").await})",
+                    (Device::AsyncRust, true) => ".await}))",
+                    (_, true) => "}))",
                 };
                 write!(
                     block,
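
Concatenating the new prefix, the call expression, and the postfix, an async call on an AsyncRust callee now lowers to roughly the following generated line (node index 2 and callee(arg) are hypothetical stand-ins), versus the old channel-based form:

// New: spawn the call and park the JoinHandle in the shared slot.
*async_call_2.lock().await = Some(::async_std::task::spawn(async move { callee(arg).await }));

// Old: spawn a task that sends the result over a bounded channel.
::async_std::task::spawn(async move { async_call_sender_2.send(callee(arg).await).await});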
@@ -1079,14 +1093,14 @@ impl<'a> RTContext<'a> {
             }
             // If the node is a call with an AsyncCall schedule, it should be
-            // lowered to a channel.
+            // lowered to an Arc<Mutex<>> over the future.
             let is_async_call =
                 func.nodes[idx].is_call() && func.schedules[idx].contains(&Schedule::AsyncCall);
             if is_async_call {
                 write!(
                     w,
-                    "let mut async_call_channel_{} = ::async_std::channel::bounded(1);let async_call_sender_{} = ::std::sync::Arc::new(async_call_channel_{}.0);let async_call_receiver_{} = ::std::sync::Arc::new(async_call_channel_{}.1);",
-                    idx, idx, idx, idx, idx
+                    "let mut async_call_{} = ::std::sync::Arc::new(::async_std::sync::Mutex::new(None));",
+                    idx,
                 )?;
             } else {
                 write!(
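
End to end, the Arc<Mutex<>> lowering behaves like this self-contained sketch (async_call_0 and callee are stand-ins for the generated names, and the clone is given a distinct name here for clarity):

use std::sync::Arc;
use async_std::sync::Mutex;
use async_std::task::JoinHandle;

async fn callee() -> u64 { 42 }

async fn async_call_lowering() -> u64 {
    // Declaration emitted here: a shared slot for the future.
    let async_call_0: Arc<Mutex<Option<JoinHandle<u64>>>> = Arc::new(Mutex::new(None));
    // clone_arc output: the spawning block captures its own handle.
    let async_call_0_spawn = async_call_0.clone();
    *async_call_0_spawn.lock().await = Some(async_std::task::spawn(async move { callee().await }));
    // get_value output: lock the slot and await the stored JoinHandle.
    async_call_0.lock().await.as_mut().unwrap().await
}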
@@ -1353,40 +1367,30 @@ impl<'a> RTContext<'a> {
     fn get_value(&self, id: NodeID, bb: NodeID, lhs: bool) -> String {
         let func = self.get_func();
         if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce()
-            && control == bb
+            && (control == bb
+                || !self.nodes_in_fork_joins[&self.join_fork_map[&control]].contains(&bb))
         {
             format!("reduce_{}", id.idx())
-        } else if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce()
-            && let fork = self.join_fork_map[&control]
-            && !self.nodes_in_fork_joins[&fork].contains(&bb)
-        {
-            // Before using the value of a reduction outside the fork-join,
-            // await the futures.
-            format!(
-                "{{for fut in fork_{}.drain(..) {{ fut.await; }}; ::std::sync::atomic::fence(::std::sync::atomic::Ordering::Acquire); reduce_{}}}",
-                fork.idx(),
-                id.idx()
-            )
         } else if func.nodes[id.idx()].is_call()
             && func.schedules[id.idx()].contains(&Schedule::AsyncCall)
         {
             assert!(!lhs);
-            format!("async_call_receiver_{}.recv().await.unwrap()", id.idx(),)
+            format!(
+                "async_call_{}.lock().await.as_mut().unwrap().await",
+                id.idx(),
+            )
         } else {
             format!("node_{}", id.idx())
         }
     }
 
-    fn clone_arc(&self, id: NodeID, lhs: bool) -> Option<String> {
+    fn clone_arc(&self, id: NodeID) -> Option<String> {
         let func = self.get_func();
         if func.nodes[id.idx()].is_call() && func.schedules[id.idx()].contains(&Schedule::AsyncCall)
         {
-            let kind = if lhs { "sender" } else { "receiver" };
             Some(format!(
-                "let async_call_{}_{} = async_call_{}_{}.clone();",
-                kind,
-                id.idx(),
-                kind,
-                id.idx()
+                "let async_call_{} = async_call_{}.clone();",
+                id.idx(),
+                id.idx()
             ))
         } else {
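
One detail worth noting in the new get_value arm: awaiting through .as_mut() works because &mut F implements Future whenever F: Future + Unpin, so the JoinHandle can be polled in place without being moved out of the Option. A minimal sketch under that assumption (await_in_place and slot are hypothetical names):

use async_std::task;

async fn await_in_place() -> u32 {
    let mut slot: Option<task::JoinHandle<u32>> = Some(task::spawn(async { 7 }));
    // Poll the handle through a mutable borrow; it stays in the slot.
    slot.as_mut().unwrap().await
}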