Skip to content
Snippets Groups Projects
Commit 2fe49fc2 authored by rarbore2
Browse files

Merge branch 'asynccall' into 'main'

More flexible task parallelism

See merge request !164
parents 40f1eec2 bd0425ef
No related branches found
No related tags found
1 merge request: !164 More flexible task parallelism
Pipeline #201568 passed
......@@ -300,7 +300,7 @@ impl<'a> RTContext<'a> {
write!(
epilogue,
"control_token = if {} {{{}}} else {{{}}};}}",
self.get_value(cond, id),
self.get_value(cond, id, false),
if succ1_is_true { succ1 } else { succ2 }.idx(),
if succ1_is_true { succ2 } else { succ1 }.idx(),
)?;
......@@ -309,7 +309,7 @@ impl<'a> RTContext<'a> {
let prologue = &mut blocks.get_mut(&id).unwrap().prologue;
write!(prologue, "{} => {{", id.idx())?;
let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue;
write!(epilogue, "return {};}}", self.get_value(data, id))?;
write!(epilogue, "return {};}}", self.get_value(data, id, false))?;
}
// Fork nodes open a new environment for defining an async closure.
Node::Fork {
......@@ -399,8 +399,8 @@ impl<'a> RTContext<'a> {
write!(
epilogue,
"{} = {};",
self.get_value(*user, id),
self.get_value(init, id)
self.get_value(*user, id, true),
self.get_value(init, id, false)
)?;
}
}
......@@ -427,11 +427,11 @@ impl<'a> RTContext<'a> {
match func.nodes[id.idx()] {
Node::Parameter { index } => {
let block = &mut blocks.get_mut(&bb).unwrap().data;
write!(block, "{} = p{};", self.get_value(id, bb), index)?
write!(block, "{} = p{};", self.get_value(id, bb, true), index)?
}
Node::Constant { id: cons_id } => {
let block = &mut blocks.get_mut(&bb).unwrap().data;
write!(block, "{} = ", self.get_value(id, bb))?;
write!(block, "{} = ", self.get_value(id, bb, true))?;
let mut size_and_device = None;
match self.module.constants[cons_id.idx()] {
Constant::Boolean(val) => write!(block, "{}bool", val)?,
......@@ -468,18 +468,24 @@ impl<'a> RTContext<'a> {
block,
"::hercules_rt::__{}_zero_mem({}.0, {} as usize);",
device.name(),
self.get_value(id, bb),
self.get_value(id, bb, false),
size
)?;
}
}
}
Node::DynamicConstant { id: dc_id } => {
let block = &mut blocks.get_mut(&bb).unwrap().data;
write!(block, "{} = ", self.get_value(id, bb, true))?;
self.codegen_dynamic_constant(dc_id, block)?;
write!(block, ";")?;
}
Node::ThreadID { control, dimension } => {
let block = &mut blocks.get_mut(&bb).unwrap().data;
write!(
block,
"{} = tid_{}_{};",
self.get_value(id, bb),
self.get_value(id, bb, true),
control.idx(),
dimension
)?;
......@@ -500,10 +506,25 @@ impl<'a> RTContext<'a> {
// The device backends ensure that device functions have the
// same interface as AsyncRust functions.
let block = &mut blocks.get_mut(&bb).unwrap().data;
let is_async = func.schedules[id.idx()].contains(&Schedule::AsyncCall);
let device = self.devices[callee_id.idx()];
let prefix = match (device, is_async) {
(Device::AsyncRust, false) => "",
(_, false) => "",
(Device::AsyncRust, true) => "Some(::async_std::task::spawn(",
(_, true) => "Some(::async_std::task::spawn(async move {",
};
let postfix = match (device, is_async) {
(Device::AsyncRust, false) => ".await",
(_, false) => "",
(Device::AsyncRust, true) => "))",
(_, true) => "}))",
};
write!(
block,
"{} = {}(",
self.get_value(id, bb),
"{} = {}{}(",
self.get_value(id, bb, true),
prefix,
self.module.functions[callee_id.idx()].name
)?;
for (device, offset) in self.backing_allocations[&self.func_id]
......@@ -519,14 +540,9 @@ impl<'a> RTContext<'a> {
write!(block, ", ")?;
}
for arg in args {
write!(block, "{}, ", self.get_value(*arg, bb))?;
}
let device = self.devices[callee_id.idx()];
if device == Device::AsyncRust {
write!(block, ").await;")?;
} else {
write!(block, ");")?;
write!(block, "{}, ", self.get_value(*arg, bb, false))?;
}
write!(block, "){};", postfix)?;
}
Node::Unary { op, input } => {
let block = &mut blocks.get_mut(&bb).unwrap().data;
......@@ -534,20 +550,20 @@ impl<'a> RTContext<'a> {
UnaryOperator::Not => write!(
block,
"{} = !{};",
self.get_value(id, bb),
self.get_value(input, bb)
self.get_value(id, bb, true),
self.get_value(input, bb, false)
)?,
UnaryOperator::Neg => write!(
block,
"{} = -{};",
self.get_value(id, bb),
self.get_value(input, bb)
self.get_value(id, bb, true),
self.get_value(input, bb, false)
)?,
UnaryOperator::Cast(ty) => write!(
block,
"{} = {} as {};",
self.get_value(id, bb),
self.get_value(input, bb),
self.get_value(id, bb, true),
self.get_value(input, bb, false),
self.get_type(ty)
)?,
};
......@@ -576,10 +592,10 @@ impl<'a> RTContext<'a> {
write!(
block,
"{} = {} {} {};",
self.get_value(id, bb),
self.get_value(left, bb),
self.get_value(id, bb, true),
self.get_value(left, bb, false),
op,
self.get_value(right, bb)
self.get_value(right, bb, false)
)?;
}
Node::Ternary {
......@@ -593,10 +609,10 @@ impl<'a> RTContext<'a> {
TernaryOperator::Select => write!(
block,
"{} = if {} {{{}}} else {{{}}};",
self.get_value(id, bb),
self.get_value(first, bb),
self.get_value(second, bb),
self.get_value(third, bb),
self.get_value(id, bb, true),
self.get_value(first, bb, false),
self.get_value(second, bb, false),
self.get_value(third, bb, false),
)?,
};
}
......@@ -635,17 +651,17 @@ impl<'a> RTContext<'a> {
"::hercules_rt::__copy_{}_to_{}({}.byte_add({} as usize).0, {}.0, {});",
src_device.name(),
dst_device.name(),
self.get_value(collect, bb),
self.get_value(collect, bb, false),
offset,
self.get_value(data, bb),
self.get_value(data, bb, false),
data_size,
)?;
}
write!(
block,
"{} = {};",
self.get_value(id, bb),
self.get_value(collect, bb)
self.get_value(id, bb, true),
self.get_value(collect, bb, false)
)?;
}
_ => panic!(
......@@ -811,7 +827,7 @@ impl<'a> RTContext<'a> {
// ((0 * s1 + p1) * s2 + p2) * s3 + p3 ...
let elem_size = self.codegen_type_size(elem);
for (p, s) in zip(pos, dims) {
let p = self.get_value(*p, bb);
let p = self.get_value(*p, bb, false);
acc_offset = format!("{} * ", acc_offset);
self.codegen_dynamic_constant(*s, &mut acc_offset)?;
acc_offset = format!("({} + {})", acc_offset, p);
......@@ -922,23 +938,31 @@ impl<'a> RTContext<'a> {
continue;
}
write!(
w,
"let mut {}_{}: {} = {};",
if is_reduce_on_child { "reduce" } else { "node" },
idx,
self.get_type(self.typing[idx]),
if self.module.types[self.typing[idx].idx()].is_integer() {
"0"
} else if self.module.types[self.typing[idx].idx()].is_float() {
"0.0"
} else {
"::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())"
}
)?;
// If the node is a call with an AsyncCall schedule, it should be
// spawned as a task and awaited later.
let is_async_call =
func.nodes[idx].is_call() && func.schedules[idx].contains(&Schedule::AsyncCall);
if is_async_call {
write!(w, "let mut async_call_{} = None;", idx)?;
} else {
write!(
w,
"let mut {}_{}: {} = {};",
if is_reduce_on_child { "reduce" } else { "node" },
idx,
self.get_type(self.typing[idx]),
if self.module.types[self.typing[idx].idx()].is_integer() {
"0"
} else if self.module.types[self.typing[idx].idx()].is_float() {
"0.0"
} else {
"::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())"
}
)?;
}
}
// Declare Vec for storing futures of fork-joins.
// Declare Vecs for storing futures of fork-joins.
for fork in self.fork_tree[&root].iter() {
write!(w, "let mut fork_{} = vec![];", fork.idx())?;
}
......@@ -1180,7 +1204,7 @@ impl<'a> RTContext<'a> {
&self.module.functions[self.func_id.idx()]
}
fn get_value(&self, id: NodeID, bb: NodeID) -> String {
fn get_value(&self, id: NodeID, bb: NodeID, lhs: bool) -> String {
let func = self.get_func();
if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce()
&& control == bb
......@@ -1197,6 +1221,14 @@ impl<'a> RTContext<'a> {
fork.idx(),
id.idx()
)
} else if func.nodes[id.idx()].is_call()
&& func.schedules[id.idx()].contains(&Schedule::AsyncCall)
{
format!(
"async_call_{}{}",
id.idx(),
if lhs { "" } else { ".unwrap().await" }
)
} else {
format!("node_{}", id.idx())
}
......
......@@ -331,6 +331,8 @@ pub enum Schedule {
TightAssociative,
// This constant node doesn't need to be memset to zero.
NoResetConstant,
// This call should be called in a spawned future.
AsyncCall,
}
/*
......
use juno_build::JunoCompiler;
fn main() {
#[cfg(not(feature = "cuda"))]
{
JunoCompiler::new()
.ir_in_src("call.hir")
.unwrap()
.build()
.unwrap();
}
#[cfg(feature = "cuda")]
{
JunoCompiler::new()
.ir_in_src("call.hir")
.unwrap()
.schedule_in_src("gpu.sch")
.unwrap()
.build()
.unwrap();
}
JunoCompiler::new()
.ir_in_src("call.hir")
.unwrap()
.schedule_in_src("sched.sch")
.unwrap()
.build()
.unwrap();
}
fn myfunc(x: u64) -> u64
cr = region(start)
y = call<16>(add, cr, x, x)
r = return(cr, y)
cr1 = region(start)
cr2 = region(cr1)
c = constant(u64, 24)
y = call<16>(add, cr1, x, x)
z = call<10>(add, cr2, x, c)
w = add(y, z)
r = return(cr2, w)
fn add<1>(x: u64, y: u64) -> u64
w = add(x, y)
......
......@@ -8,9 +8,10 @@ fn main() {
async_std::task::block_on(async {
let mut r = runner!(myfunc);
let x = r.run(7).await;
assert_eq!(x, 71);
let mut r = runner!(add);
let y = r.run(10, 2, 18).await;
assert_eq!(x, y);
assert_eq!(y, 30);
});
}
......
......@@ -2,9 +2,6 @@ gvn(*);
phi-elim(*);
dce(*);
let out = auto-outline(*);
gpu(out.add);
ip-sroa(*);
sroa(*);
dce(*);
......@@ -13,5 +10,6 @@ phi-elim(*);
dce(*);
infer-schedules(*);
async-call(myfunc@y);
gcm(*);
......@@ -137,6 +137,7 @@ impl FromStr for Appliable {
"parallel-reduce" => Ok(Appliable::Schedule(Schedule::ParallelReduce)),
"vectorize" => Ok(Appliable::Schedule(Schedule::Vectorizable)),
"no-memset" | "no-reset" => Ok(Appliable::Schedule(Schedule::NoResetConstant)),
"task-parallel" | "async-call" => Ok(Appliable::Schedule(Schedule::AsyncCall)),
_ => Err(s.to_string()),
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment