From 2878e698334761642ef92365af1f4f99f66d09d6 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Wed, 12 Feb 2025 16:42:45 -0600 Subject: [PATCH] Lower AsyncCall to task parallel code --- hercules_cg/src/rt.rs | 130 +++++++++++++++++++++++++----------------- 1 file changed, 78 insertions(+), 52 deletions(-) diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index fd44d777..090253d4 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -300,7 +300,7 @@ impl<'a> RTContext<'a> { write!( epilogue, "control_token = if {} {{{}}} else {{{}}};}}", - self.get_value(cond, id), + self.get_value(cond, id, false), if succ1_is_true { succ1 } else { succ2 }.idx(), if succ1_is_true { succ2 } else { succ1 }.idx(), )?; @@ -309,7 +309,7 @@ impl<'a> RTContext<'a> { let prologue = &mut blocks.get_mut(&id).unwrap().prologue; write!(prologue, "{} => {{", id.idx())?; let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue; - write!(epilogue, "return {};}}", self.get_value(data, id))?; + write!(epilogue, "return {};}}", self.get_value(data, id, false))?; } // Fork nodes open a new environment for defining an async closure. Node::Fork { @@ -399,8 +399,8 @@ impl<'a> RTContext<'a> { write!( epilogue, "{} = {};", - self.get_value(*user, id), - self.get_value(init, id) + self.get_value(*user, id, true), + self.get_value(init, id, false) )?; } } @@ -427,11 +427,11 @@ impl<'a> RTContext<'a> { match func.nodes[id.idx()] { Node::Parameter { index } => { let block = &mut blocks.get_mut(&bb).unwrap().data; - write!(block, "{} = p{};", self.get_value(id, bb), index)? + write!(block, "{} = p{};", self.get_value(id, bb, true), index)? } Node::Constant { id: cons_id } => { let block = &mut blocks.get_mut(&bb).unwrap().data; - write!(block, "{} = ", self.get_value(id, bb))?; + write!(block, "{} = ", self.get_value(id, bb, true))?; let mut size_and_device = None; match self.module.constants[cons_id.idx()] { Constant::Boolean(val) => write!(block, "{}bool", val)?, @@ -468,7 +468,7 @@ impl<'a> RTContext<'a> { block, "::hercules_rt::__{}_zero_mem({}.0, {} as usize);", device.name(), - self.get_value(id, bb), + self.get_value(id, bb, false), size )?; } @@ -476,7 +476,7 @@ impl<'a> RTContext<'a> { } Node::DynamicConstant { id: dc_id } => { let block = &mut blocks.get_mut(&bb).unwrap().data; - write!(block, "{} = ", self.get_value(id, bb))?; + write!(block, "{} = ", self.get_value(id, bb, true))?; self.codegen_dynamic_constant(dc_id, block)?; write!(block, ";")?; } @@ -485,7 +485,7 @@ impl<'a> RTContext<'a> { write!( block, "{} = tid_{}_{};", - self.get_value(id, bb), + self.get_value(id, bb, true), control.idx(), dimension )?; @@ -506,10 +506,25 @@ impl<'a> RTContext<'a> { // The device backends ensure that device functions have the // same interface as AsyncRust functions. let block = &mut blocks.get_mut(&bb).unwrap().data; + let is_async = func.schedules[id.idx()].contains(&Schedule::AsyncCall); + let device = self.devices[callee_id.idx()]; + let prefix = match (device, is_async) { + (Device::AsyncRust, false) => "", + (_, false) => "", + (Device::AsyncRust, true) => "Some(::async_std::task::spawn(", + (_, true) => "Some(::async_std::task::spawn(async move {", + }; + let postfix = match (device, is_async) { + (Device::AsyncRust, false) => ".await", + (_, false) => "", + (Device::AsyncRust, true) => "))", + (_, true) => "}))", + }; write!( block, - "{} = {}(", - self.get_value(id, bb), + "{} = {}{}(", + self.get_value(id, bb, true), + prefix, self.module.functions[callee_id.idx()].name )?; for (device, offset) in self.backing_allocations[&self.func_id] @@ -525,14 +540,9 @@ impl<'a> RTContext<'a> { write!(block, ", ")?; } for arg in args { - write!(block, "{}, ", self.get_value(*arg, bb))?; - } - let device = self.devices[callee_id.idx()]; - if device == Device::AsyncRust { - write!(block, ").await;")?; - } else { - write!(block, ");")?; + write!(block, "{}, ", self.get_value(*arg, bb, false))?; } + write!(block, "){};", postfix)?; } Node::Unary { op, input } => { let block = &mut blocks.get_mut(&bb).unwrap().data; @@ -540,20 +550,20 @@ impl<'a> RTContext<'a> { UnaryOperator::Not => write!( block, "{} = !{};", - self.get_value(id, bb), - self.get_value(input, bb) + self.get_value(id, bb, true), + self.get_value(input, bb, false) )?, UnaryOperator::Neg => write!( block, "{} = -{};", - self.get_value(id, bb), - self.get_value(input, bb) + self.get_value(id, bb, true), + self.get_value(input, bb, false) )?, UnaryOperator::Cast(ty) => write!( block, "{} = {} as {};", - self.get_value(id, bb), - self.get_value(input, bb), + self.get_value(id, bb, true), + self.get_value(input, bb, false), self.get_type(ty) )?, }; @@ -582,10 +592,10 @@ impl<'a> RTContext<'a> { write!( block, "{} = {} {} {};", - self.get_value(id, bb), - self.get_value(left, bb), + self.get_value(id, bb, true), + self.get_value(left, bb, false), op, - self.get_value(right, bb) + self.get_value(right, bb, false) )?; } Node::Ternary { @@ -599,10 +609,10 @@ impl<'a> RTContext<'a> { TernaryOperator::Select => write!( block, "{} = if {} {{{}}} else {{{}}};", - self.get_value(id, bb), - self.get_value(first, bb), - self.get_value(second, bb), - self.get_value(third, bb), + self.get_value(id, bb, true), + self.get_value(first, bb, false), + self.get_value(second, bb, false), + self.get_value(third, bb, false), )?, }; } @@ -641,17 +651,17 @@ impl<'a> RTContext<'a> { "::hercules_rt::__copy_{}_to_{}({}.byte_add({} as usize).0, {}.0, {});", src_device.name(), dst_device.name(), - self.get_value(collect, bb), + self.get_value(collect, bb, false), offset, - self.get_value(data, bb), + self.get_value(data, bb, false), data_size, )?; } write!( block, "{} = {};", - self.get_value(id, bb), - self.get_value(collect, bb) + self.get_value(id, bb, true), + self.get_value(collect, bb, false) )?; } _ => panic!( @@ -817,7 +827,7 @@ impl<'a> RTContext<'a> { // ((0 * s1 + p1) * s2 + p2) * s3 + p3 ... let elem_size = self.codegen_type_size(elem); for (p, s) in zip(pos, dims) { - let p = self.get_value(*p, bb); + let p = self.get_value(*p, bb, false); acc_offset = format!("{} * ", acc_offset); self.codegen_dynamic_constant(*s, &mut acc_offset)?; acc_offset = format!("({} + {})", acc_offset, p); @@ -928,23 +938,31 @@ impl<'a> RTContext<'a> { continue; } - write!( - w, - "let mut {}_{}: {} = {};", - if is_reduce_on_child { "reduce" } else { "node" }, - idx, - self.get_type(self.typing[idx]), - if self.module.types[self.typing[idx].idx()].is_integer() { - "0" - } else if self.module.types[self.typing[idx].idx()].is_float() { - "0.0" - } else { - "::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())" - } - )?; + // If the node is a call with an AsyncCall schedule, it should be + // spawned as a task and awaited later. + let is_async_call = + func.nodes[idx].is_call() && func.schedules[idx].contains(&Schedule::AsyncCall); + if is_async_call { + write!(w, "let mut async_call_{} = None;", idx)?; + } else { + write!( + w, + "let mut {}_{}: {} = {};", + if is_reduce_on_child { "reduce" } else { "node" }, + idx, + self.get_type(self.typing[idx]), + if self.module.types[self.typing[idx].idx()].is_integer() { + "0" + } else if self.module.types[self.typing[idx].idx()].is_float() { + "0.0" + } else { + "::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())" + } + )?; + } } - // Declare Vec for storing futures of fork-joins. + // Declare Vecs for storing futures of fork-joins. for fork in self.fork_tree[&root].iter() { write!(w, "let mut fork_{} = vec![];", fork.idx())?; } @@ -1186,7 +1204,7 @@ impl<'a> RTContext<'a> { &self.module.functions[self.func_id.idx()] } - fn get_value(&self, id: NodeID, bb: NodeID) -> String { + fn get_value(&self, id: NodeID, bb: NodeID, lhs: bool) -> String { let func = self.get_func(); if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce() && control == bb @@ -1203,6 +1221,14 @@ impl<'a> RTContext<'a> { fork.idx(), id.idx() ) + } else if func.nodes[id.idx()].is_call() + && func.schedules[id.idx()].contains(&Schedule::AsyncCall) + { + format!( + "async_call_{}{}", + id.idx(), + if lhs { "" } else { ".unwrap().await" } + ) } else { format!("node_{}", id.idx()) } -- GitLab