Skip to content
Snippets Groups Projects
Commit 2878e698 authored by Russel Arbore's avatar Russel Arbore
Browse files

Lower AsyncCall to task parallel code

parent a43251a4
No related branches found
No related tags found
1 merge request!164More flexible task parallelism
Pipeline #201567 passed
...@@ -300,7 +300,7 @@ impl<'a> RTContext<'a> { ...@@ -300,7 +300,7 @@ impl<'a> RTContext<'a> {
write!( write!(
epilogue, epilogue,
"control_token = if {} {{{}}} else {{{}}};}}", "control_token = if {} {{{}}} else {{{}}};}}",
self.get_value(cond, id), self.get_value(cond, id, false),
if succ1_is_true { succ1 } else { succ2 }.idx(), if succ1_is_true { succ1 } else { succ2 }.idx(),
if succ1_is_true { succ2 } else { succ1 }.idx(), if succ1_is_true { succ2 } else { succ1 }.idx(),
)?; )?;
...@@ -309,7 +309,7 @@ impl<'a> RTContext<'a> { ...@@ -309,7 +309,7 @@ impl<'a> RTContext<'a> {
let prologue = &mut blocks.get_mut(&id).unwrap().prologue; let prologue = &mut blocks.get_mut(&id).unwrap().prologue;
write!(prologue, "{} => {{", id.idx())?; write!(prologue, "{} => {{", id.idx())?;
let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue; let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue;
write!(epilogue, "return {};}}", self.get_value(data, id))?; write!(epilogue, "return {};}}", self.get_value(data, id, false))?;
} }
// Fork nodes open a new environment for defining an async closure. // Fork nodes open a new environment for defining an async closure.
Node::Fork { Node::Fork {
...@@ -399,8 +399,8 @@ impl<'a> RTContext<'a> { ...@@ -399,8 +399,8 @@ impl<'a> RTContext<'a> {
write!( write!(
epilogue, epilogue,
"{} = {};", "{} = {};",
self.get_value(*user, id), self.get_value(*user, id, true),
self.get_value(init, id) self.get_value(init, id, false)
)?; )?;
} }
} }
...@@ -427,11 +427,11 @@ impl<'a> RTContext<'a> { ...@@ -427,11 +427,11 @@ impl<'a> RTContext<'a> {
match func.nodes[id.idx()] { match func.nodes[id.idx()] {
Node::Parameter { index } => { Node::Parameter { index } => {
let block = &mut blocks.get_mut(&bb).unwrap().data; let block = &mut blocks.get_mut(&bb).unwrap().data;
write!(block, "{} = p{};", self.get_value(id, bb), index)? write!(block, "{} = p{};", self.get_value(id, bb, true), index)?
} }
Node::Constant { id: cons_id } => { Node::Constant { id: cons_id } => {
let block = &mut blocks.get_mut(&bb).unwrap().data; let block = &mut blocks.get_mut(&bb).unwrap().data;
write!(block, "{} = ", self.get_value(id, bb))?; write!(block, "{} = ", self.get_value(id, bb, true))?;
let mut size_and_device = None; let mut size_and_device = None;
match self.module.constants[cons_id.idx()] { match self.module.constants[cons_id.idx()] {
Constant::Boolean(val) => write!(block, "{}bool", val)?, Constant::Boolean(val) => write!(block, "{}bool", val)?,
...@@ -468,7 +468,7 @@ impl<'a> RTContext<'a> { ...@@ -468,7 +468,7 @@ impl<'a> RTContext<'a> {
block, block,
"::hercules_rt::__{}_zero_mem({}.0, {} as usize);", "::hercules_rt::__{}_zero_mem({}.0, {} as usize);",
device.name(), device.name(),
self.get_value(id, bb), self.get_value(id, bb, false),
size size
)?; )?;
} }
...@@ -476,7 +476,7 @@ impl<'a> RTContext<'a> { ...@@ -476,7 +476,7 @@ impl<'a> RTContext<'a> {
} }
Node::DynamicConstant { id: dc_id } => { Node::DynamicConstant { id: dc_id } => {
let block = &mut blocks.get_mut(&bb).unwrap().data; let block = &mut blocks.get_mut(&bb).unwrap().data;
write!(block, "{} = ", self.get_value(id, bb))?; write!(block, "{} = ", self.get_value(id, bb, true))?;
self.codegen_dynamic_constant(dc_id, block)?; self.codegen_dynamic_constant(dc_id, block)?;
write!(block, ";")?; write!(block, ";")?;
} }
...@@ -485,7 +485,7 @@ impl<'a> RTContext<'a> { ...@@ -485,7 +485,7 @@ impl<'a> RTContext<'a> {
write!( write!(
block, block,
"{} = tid_{}_{};", "{} = tid_{}_{};",
self.get_value(id, bb), self.get_value(id, bb, true),
control.idx(), control.idx(),
dimension dimension
)?; )?;
...@@ -506,10 +506,25 @@ impl<'a> RTContext<'a> { ...@@ -506,10 +506,25 @@ impl<'a> RTContext<'a> {
// The device backends ensure that device functions have the // The device backends ensure that device functions have the
// same interface as AsyncRust functions. // same interface as AsyncRust functions.
let block = &mut blocks.get_mut(&bb).unwrap().data; let block = &mut blocks.get_mut(&bb).unwrap().data;
let is_async = func.schedules[id.idx()].contains(&Schedule::AsyncCall);
let device = self.devices[callee_id.idx()];
let prefix = match (device, is_async) {
(Device::AsyncRust, false) => "",
(_, false) => "",
(Device::AsyncRust, true) => "Some(::async_std::task::spawn(",
(_, true) => "Some(::async_std::task::spawn(async move {",
};
let postfix = match (device, is_async) {
(Device::AsyncRust, false) => ".await",
(_, false) => "",
(Device::AsyncRust, true) => "))",
(_, true) => "}))",
};
write!( write!(
block, block,
"{} = {}(", "{} = {}{}(",
self.get_value(id, bb), self.get_value(id, bb, true),
prefix,
self.module.functions[callee_id.idx()].name self.module.functions[callee_id.idx()].name
)?; )?;
for (device, offset) in self.backing_allocations[&self.func_id] for (device, offset) in self.backing_allocations[&self.func_id]
...@@ -525,14 +540,9 @@ impl<'a> RTContext<'a> { ...@@ -525,14 +540,9 @@ impl<'a> RTContext<'a> {
write!(block, ", ")?; write!(block, ", ")?;
} }
for arg in args { for arg in args {
write!(block, "{}, ", self.get_value(*arg, bb))?; write!(block, "{}, ", self.get_value(*arg, bb, false))?;
}
let device = self.devices[callee_id.idx()];
if device == Device::AsyncRust {
write!(block, ").await;")?;
} else {
write!(block, ");")?;
} }
write!(block, "){};", postfix)?;
} }
Node::Unary { op, input } => { Node::Unary { op, input } => {
let block = &mut blocks.get_mut(&bb).unwrap().data; let block = &mut blocks.get_mut(&bb).unwrap().data;
...@@ -540,20 +550,20 @@ impl<'a> RTContext<'a> { ...@@ -540,20 +550,20 @@ impl<'a> RTContext<'a> {
UnaryOperator::Not => write!( UnaryOperator::Not => write!(
block, block,
"{} = !{};", "{} = !{};",
self.get_value(id, bb), self.get_value(id, bb, true),
self.get_value(input, bb) self.get_value(input, bb, false)
)?, )?,
UnaryOperator::Neg => write!( UnaryOperator::Neg => write!(
block, block,
"{} = -{};", "{} = -{};",
self.get_value(id, bb), self.get_value(id, bb, true),
self.get_value(input, bb) self.get_value(input, bb, false)
)?, )?,
UnaryOperator::Cast(ty) => write!( UnaryOperator::Cast(ty) => write!(
block, block,
"{} = {} as {};", "{} = {} as {};",
self.get_value(id, bb), self.get_value(id, bb, true),
self.get_value(input, bb), self.get_value(input, bb, false),
self.get_type(ty) self.get_type(ty)
)?, )?,
}; };
...@@ -582,10 +592,10 @@ impl<'a> RTContext<'a> { ...@@ -582,10 +592,10 @@ impl<'a> RTContext<'a> {
write!( write!(
block, block,
"{} = {} {} {};", "{} = {} {} {};",
self.get_value(id, bb), self.get_value(id, bb, true),
self.get_value(left, bb), self.get_value(left, bb, false),
op, op,
self.get_value(right, bb) self.get_value(right, bb, false)
)?; )?;
} }
Node::Ternary { Node::Ternary {
...@@ -599,10 +609,10 @@ impl<'a> RTContext<'a> { ...@@ -599,10 +609,10 @@ impl<'a> RTContext<'a> {
TernaryOperator::Select => write!( TernaryOperator::Select => write!(
block, block,
"{} = if {} {{{}}} else {{{}}};", "{} = if {} {{{}}} else {{{}}};",
self.get_value(id, bb), self.get_value(id, bb, true),
self.get_value(first, bb), self.get_value(first, bb, false),
self.get_value(second, bb), self.get_value(second, bb, false),
self.get_value(third, bb), self.get_value(third, bb, false),
)?, )?,
}; };
} }
...@@ -641,17 +651,17 @@ impl<'a> RTContext<'a> { ...@@ -641,17 +651,17 @@ impl<'a> RTContext<'a> {
"::hercules_rt::__copy_{}_to_{}({}.byte_add({} as usize).0, {}.0, {});", "::hercules_rt::__copy_{}_to_{}({}.byte_add({} as usize).0, {}.0, {});",
src_device.name(), src_device.name(),
dst_device.name(), dst_device.name(),
self.get_value(collect, bb), self.get_value(collect, bb, false),
offset, offset,
self.get_value(data, bb), self.get_value(data, bb, false),
data_size, data_size,
)?; )?;
} }
write!( write!(
block, block,
"{} = {};", "{} = {};",
self.get_value(id, bb), self.get_value(id, bb, true),
self.get_value(collect, bb) self.get_value(collect, bb, false)
)?; )?;
} }
_ => panic!( _ => panic!(
...@@ -817,7 +827,7 @@ impl<'a> RTContext<'a> { ...@@ -817,7 +827,7 @@ impl<'a> RTContext<'a> {
// ((0 * s1 + p1) * s2 + p2) * s3 + p3 ... // ((0 * s1 + p1) * s2 + p2) * s3 + p3 ...
let elem_size = self.codegen_type_size(elem); let elem_size = self.codegen_type_size(elem);
for (p, s) in zip(pos, dims) { for (p, s) in zip(pos, dims) {
let p = self.get_value(*p, bb); let p = self.get_value(*p, bb, false);
acc_offset = format!("{} * ", acc_offset); acc_offset = format!("{} * ", acc_offset);
self.codegen_dynamic_constant(*s, &mut acc_offset)?; self.codegen_dynamic_constant(*s, &mut acc_offset)?;
acc_offset = format!("({} + {})", acc_offset, p); acc_offset = format!("({} + {})", acc_offset, p);
...@@ -928,23 +938,31 @@ impl<'a> RTContext<'a> { ...@@ -928,23 +938,31 @@ impl<'a> RTContext<'a> {
continue; continue;
} }
write!( // If the node is a call with an AsyncCall schedule, it should be
w, // spawned as a task and awaited later.
"let mut {}_{}: {} = {};", let is_async_call =
if is_reduce_on_child { "reduce" } else { "node" }, func.nodes[idx].is_call() && func.schedules[idx].contains(&Schedule::AsyncCall);
idx, if is_async_call {
self.get_type(self.typing[idx]), write!(w, "let mut async_call_{} = None;", idx)?;
if self.module.types[self.typing[idx].idx()].is_integer() { } else {
"0" write!(
} else if self.module.types[self.typing[idx].idx()].is_float() { w,
"0.0" "let mut {}_{}: {} = {};",
} else { if is_reduce_on_child { "reduce" } else { "node" },
"::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())" idx,
} self.get_type(self.typing[idx]),
)?; if self.module.types[self.typing[idx].idx()].is_integer() {
"0"
} else if self.module.types[self.typing[idx].idx()].is_float() {
"0.0"
} else {
"::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())"
}
)?;
}
} }
// Declare Vec for storing futures of fork-joins. // Declare Vecs for storing futures of fork-joins.
for fork in self.fork_tree[&root].iter() { for fork in self.fork_tree[&root].iter() {
write!(w, "let mut fork_{} = vec![];", fork.idx())?; write!(w, "let mut fork_{} = vec![];", fork.idx())?;
} }
...@@ -1186,7 +1204,7 @@ impl<'a> RTContext<'a> { ...@@ -1186,7 +1204,7 @@ impl<'a> RTContext<'a> {
&self.module.functions[self.func_id.idx()] &self.module.functions[self.func_id.idx()]
} }
fn get_value(&self, id: NodeID, bb: NodeID) -> String { fn get_value(&self, id: NodeID, bb: NodeID, lhs: bool) -> String {
let func = self.get_func(); let func = self.get_func();
if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce() if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce()
&& control == bb && control == bb
...@@ -1203,6 +1221,14 @@ impl<'a> RTContext<'a> { ...@@ -1203,6 +1221,14 @@ impl<'a> RTContext<'a> {
fork.idx(), fork.idx(),
id.idx() id.idx()
) )
} else if func.nodes[id.idx()].is_call()
&& func.schedules[id.idx()].contains(&Schedule::AsyncCall)
{
format!(
"async_call_{}{}",
id.idx(),
if lhs { "" } else { ".unwrap().await" }
)
} else { } else {
format!("node_{}", id.idx()) format!("node_{}", id.idx())
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment