Commit 4d51c830 authored by Russel Arbore

Fix task parallelism

parent 7ab06905
Merge request !193: Fix parallel code gen in RT backend
Pipeline #201784 failed
@@ -416,6 +416,14 @@ impl<'a> RTContext<'a> {
                 let succ = self.control_subgraph.succs(id).next().unwrap();
                 write!(epilogue, "{} => {{", id.idx())?;
+                // Await the empty futures for the fork-joins, waiting for them
+                // to complete.
+                write!(
+                    epilogue,
+                    "for fut in fork_{}.drain(..) {{ fut.await; }}; ::std::sync::atomic::fence(::std::sync::atomic::Ordering::Acquire);",
+                    fork.idx(),
+                )?;
                 // Emit the assignments to the reduce variables in the
                 // surrounding context. It's very unfortunate that we have to do
                 // it while lowering the join node (rather than the reduce nodes
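
For orientation, here is a minimal sketch (illustrative names, not the backend's literal output) of the Rust shape the join epilogue now emits: every future spawned for the fork is drained and awaited, then an acquire fence publishes the tasks' writes before the reduce variables are read.

    use std::sync::atomic::{fence, Ordering};

    // `fork_0` stands in for the per-fork future list the generated code keeps.
    async fn join_epilogue(mut fork_0: Vec<async_std::task::JoinHandle<()>>) {
        // Wait for every spawned task of this fork-join to finish.
        for fut in fork_0.drain(..) {
            fut.await;
        }
        // Make the tasks' non-atomic writes visible to this thread.
        fence(Ordering::Acquire);
        // ...reduce variables can be read safely past this point...
    }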
@@ -562,7 +570,7 @@ impl<'a> RTContext<'a> {
         let is_async = func.schedules[id.idx()].contains(&Schedule::AsyncCall);
         if is_async {
             for arg in args {
-                if let Some(arc) = self.clone_arc(*arg, false) {
+                if let Some(arc) = self.clone_arc(*arg) {
                     write!(block, "{}", arc)?;
                 }
             }
@@ -573,16 +581,16 @@ impl<'a> RTContext<'a> {
                 format!("{} = ", self.get_value(id, bb, true))
             }
             (_, true) => format!(
-                "{}::async_std::task::spawn(async move {{ async_call_sender_{}.send(",
-                self.clone_arc(id, true).unwrap(),
-                id.idx()
+                "{}*async_call_{}.lock().await = Some(::async_std::task::spawn(async move {{ ",
+                self.clone_arc(id).unwrap(),
+                id.idx(),
             ),
         };
         let postfix = match (device, is_async) {
             (Device::AsyncRust, false) => ".await",
             (_, false) => "",
-            (Device::AsyncRust, true) => ".await).await})",
-            (_, true) => ").await})",
+            (Device::AsyncRust, true) => ".await}))",
+            (_, true) => "}))",
         };
         write!(
             block,
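
Assembled, the new prefix and postfix bracket the lowered call so that the spawned task's handle is parked in a shared slot instead of its result being sent over a channel. A minimal sketch of the emitted shape, assuming a hypothetical async call node 3 whose callee returns u64:

    use std::sync::Arc;
    use async_std::sync::Mutex;
    use async_std::task::JoinHandle;

    // Hypothetical stand-in for the lowered callee.
    async fn callee(x: u64) -> u64 {
        x + 1
    }

    async fn emitted_shape(async_call_3: Arc<Mutex<Option<JoinHandle<u64>>>>) {
        // Prefix (from clone_arc): clone the Arc so it can move into the task.
        let async_call_3 = async_call_3.clone();
        // Prefix + call + postfix: spawn the callee and store its handle.
        *async_call_3.lock().await = Some(async_std::task::spawn(async move {
            callee(7).await
        }));
    }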
@@ -1079,14 +1087,14 @@ impl<'a> RTContext<'a> {
             }
             // If the node is a call with an AsyncCall schedule, it should be
-            // lowered to a channel.
+            // lowered to a Arc<Mutex<>> over the future.
             let is_async_call =
                 func.nodes[idx].is_call() && func.schedules[idx].contains(&Schedule::AsyncCall);
             if is_async_call {
                 write!(
                     w,
-                    "let mut async_call_channel_{} = ::async_std::channel::bounded(1);let async_call_sender_{} = ::std::sync::Arc::new(async_call_channel_{}.0);let async_call_receiver_{} = ::std::sync::Arc::new(async_call_channel_{}.1);",
-                    idx, idx, idx, idx, idx
+                    "let mut async_call_{} = ::std::sync::Arc::new(::async_std::sync::Mutex::new(None));",
+                    idx,
                 )?;
             } else {
                 write!(
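
Concretely, where each AsyncCall previously got a bounded(1) channel with Arc-wrapped sender and receiver, it now gets a single shared slot for the join handle. A sketch of the two declaration shapes for a hypothetical node 7 returning u64:

    use std::sync::Arc;
    use async_std::sync::Mutex;
    use async_std::task::JoinHandle;

    fn declaration_shapes() {
        // Old shape: the result was sent through a one-slot channel.
        let channel = async_std::channel::bounded::<u64>(1);
        let _async_call_sender_7 = Arc::new(channel.0);
        let _async_call_receiver_7 = Arc::new(channel.1);

        // New shape: the spawned future's handle lives in a shared Option.
        let _async_call_7: Arc<Mutex<Option<JoinHandle<u64>>>> =
            Arc::new(Mutex::new(None));
    }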
@@ -1353,40 +1361,30 @@ impl<'a> RTContext<'a> {
     fn get_value(&self, id: NodeID, bb: NodeID, lhs: bool) -> String {
         let func = self.get_func();
         if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce()
-            && control == bb
-        {
-            format!("reduce_{}", id.idx())
-        } else if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce()
             && let fork = self.join_fork_map[&control]
             && !self.nodes_in_fork_joins[&fork].contains(&bb)
         {
-            // Before using the value of a reduction outside the fork-join,
-            // await the futures.
-            format!(
-                "{{for fut in fork_{}.drain(..) {{ fut.await; }}; ::std::sync::atomic::fence(::std::sync::atomic::Ordering::Acquire); reduce_{}}}",
-                fork.idx(),
-                id.idx()
-            )
+            format!("reduce_{}", id.idx())
         } else if func.nodes[id.idx()].is_call()
             && func.schedules[id.idx()].contains(&Schedule::AsyncCall)
         {
             assert!(!lhs);
-            format!("async_call_receiver_{}.recv().await.unwrap()", id.idx(),)
+            format!(
+                "async_call_{}.lock().await.as_mut().unwrap().await",
+                id.idx(),
+            )
         } else {
             format!("node_{}", id.idx())
         }
     }
 
-    fn clone_arc(&self, id: NodeID, lhs: bool) -> Option<String> {
+    fn clone_arc(&self, id: NodeID) -> Option<String> {
         let func = self.get_func();
         if func.nodes[id.idx()].is_call() && func.schedules[id.idx()].contains(&Schedule::AsyncCall)
         {
-            let kind = if lhs { "sender" } else { "receiver" };
             Some(format!(
-                "let async_call_{}_{} = async_call_{}_{}.clone();",
-                kind,
+                "let async_call_{} = async_call_{}.clone();",
                 id.idx(),
-                kind,
                 id.idx()
             ))
         } else {
...
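
On the consuming side, get_value now locks the slot and awaits the stored handle rather than calling recv() on a channel; this works because async-std's JoinHandle is Unpin, so a mutable borrow of it is itself awaitable. A minimal sketch, again with a hypothetical u64-valued call node 3:

    use std::sync::Arc;
    use async_std::sync::Mutex;
    use async_std::task::JoinHandle;

    async fn consume(async_call_3: Arc<Mutex<Option<JoinHandle<u64>>>>) -> u64 {
        // Lock the slot and await the parked handle in place; the mutex
        // guard is held across the await, serializing consumers.
        async_call_3.lock().await.as_mut().unwrap().await
    }

Note that this awaits the handle by mutable reference, so the (now finished) handle stays in the slot rather than being taken out of the Option.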