From bd0425ef1453533c0dd8fe78ed548066ee1958d3 Mon Sep 17 00:00:00 2001 From: rarbore2 <rarbore2@illinois.edu> Date: Wed, 12 Feb 2025 16:46:38 -0600 Subject: [PATCH] More flexible task parallelism --- hercules_cg/src/rt.rs | 134 +++++++++++------- hercules_ir/src/ir.rs | 2 + hercules_samples/call/build.rs | 25 +--- hercules_samples/call/src/call.hir | 10 +- hercules_samples/call/src/main.rs | 3 +- .../call/src/{gpu.sch => sched.sch} | 4 +- juno_scheduler/src/compile.rs | 1 + 7 files changed, 103 insertions(+), 76 deletions(-) rename hercules_samples/call/src/{gpu.sch => sched.sch} (73%) diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 62f683ce..090253d4 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -300,7 +300,7 @@ impl<'a> RTContext<'a> { write!( epilogue, "control_token = if {} {{{}}} else {{{}}};}}", - self.get_value(cond, id), + self.get_value(cond, id, false), if succ1_is_true { succ1 } else { succ2 }.idx(), if succ1_is_true { succ2 } else { succ1 }.idx(), )?; @@ -309,7 +309,7 @@ impl<'a> RTContext<'a> { let prologue = &mut blocks.get_mut(&id).unwrap().prologue; write!(prologue, "{} => {{", id.idx())?; let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue; - write!(epilogue, "return {};}}", self.get_value(data, id))?; + write!(epilogue, "return {};}}", self.get_value(data, id, false))?; } // Fork nodes open a new environment for defining an async closure. Node::Fork { @@ -399,8 +399,8 @@ impl<'a> RTContext<'a> { write!( epilogue, "{} = {};", - self.get_value(*user, id), - self.get_value(init, id) + self.get_value(*user, id, true), + self.get_value(init, id, false) )?; } } @@ -427,11 +427,11 @@ impl<'a> RTContext<'a> { match func.nodes[id.idx()] { Node::Parameter { index } => { let block = &mut blocks.get_mut(&bb).unwrap().data; - write!(block, "{} = p{};", self.get_value(id, bb), index)? + write!(block, "{} = p{};", self.get_value(id, bb, true), index)? } Node::Constant { id: cons_id } => { let block = &mut blocks.get_mut(&bb).unwrap().data; - write!(block, "{} = ", self.get_value(id, bb))?; + write!(block, "{} = ", self.get_value(id, bb, true))?; let mut size_and_device = None; match self.module.constants[cons_id.idx()] { Constant::Boolean(val) => write!(block, "{}bool", val)?, @@ -468,18 +468,24 @@ impl<'a> RTContext<'a> { block, "::hercules_rt::__{}_zero_mem({}.0, {} as usize);", device.name(), - self.get_value(id, bb), + self.get_value(id, bb, false), size )?; } } } + Node::DynamicConstant { id: dc_id } => { + let block = &mut blocks.get_mut(&bb).unwrap().data; + write!(block, "{} = ", self.get_value(id, bb, true))?; + self.codegen_dynamic_constant(dc_id, block)?; + write!(block, ";")?; + } Node::ThreadID { control, dimension } => { let block = &mut blocks.get_mut(&bb).unwrap().data; write!( block, "{} = tid_{}_{};", - self.get_value(id, bb), + self.get_value(id, bb, true), control.idx(), dimension )?; @@ -500,10 +506,25 @@ impl<'a> RTContext<'a> { // The device backends ensure that device functions have the // same interface as AsyncRust functions. 
let block = &mut blocks.get_mut(&bb).unwrap().data; + let is_async = func.schedules[id.idx()].contains(&Schedule::AsyncCall); + let device = self.devices[callee_id.idx()]; + let prefix = match (device, is_async) { + (Device::AsyncRust, false) => "", + (_, false) => "", + (Device::AsyncRust, true) => "Some(::async_std::task::spawn(", + (_, true) => "Some(::async_std::task::spawn(async move {", + }; + let postfix = match (device, is_async) { + (Device::AsyncRust, false) => ".await", + (_, false) => "", + (Device::AsyncRust, true) => "))", + (_, true) => "}))", + }; write!( block, - "{} = {}(", - self.get_value(id, bb), + "{} = {}{}(", + self.get_value(id, bb, true), + prefix, self.module.functions[callee_id.idx()].name )?; for (device, offset) in self.backing_allocations[&self.func_id] @@ -519,14 +540,9 @@ impl<'a> RTContext<'a> { write!(block, ", ")?; } for arg in args { - write!(block, "{}, ", self.get_value(*arg, bb))?; - } - let device = self.devices[callee_id.idx()]; - if device == Device::AsyncRust { - write!(block, ").await;")?; - } else { - write!(block, ");")?; + write!(block, "{}, ", self.get_value(*arg, bb, false))?; } + write!(block, "){};", postfix)?; } Node::Unary { op, input } => { let block = &mut blocks.get_mut(&bb).unwrap().data; @@ -534,20 +550,20 @@ impl<'a> RTContext<'a> { UnaryOperator::Not => write!( block, "{} = !{};", - self.get_value(id, bb), - self.get_value(input, bb) + self.get_value(id, bb, true), + self.get_value(input, bb, false) )?, UnaryOperator::Neg => write!( block, "{} = -{};", - self.get_value(id, bb), - self.get_value(input, bb) + self.get_value(id, bb, true), + self.get_value(input, bb, false) )?, UnaryOperator::Cast(ty) => write!( block, "{} = {} as {};", - self.get_value(id, bb), - self.get_value(input, bb), + self.get_value(id, bb, true), + self.get_value(input, bb, false), self.get_type(ty) )?, }; @@ -576,10 +592,10 @@ impl<'a> RTContext<'a> { write!( block, "{} = {} {} {};", - self.get_value(id, bb), - self.get_value(left, bb), + self.get_value(id, bb, true), + self.get_value(left, bb, false), op, - self.get_value(right, bb) + self.get_value(right, bb, false) )?; } Node::Ternary { @@ -593,10 +609,10 @@ impl<'a> RTContext<'a> { TernaryOperator::Select => write!( block, "{} = if {} {{{}}} else {{{}}};", - self.get_value(id, bb), - self.get_value(first, bb), - self.get_value(second, bb), - self.get_value(third, bb), + self.get_value(id, bb, true), + self.get_value(first, bb, false), + self.get_value(second, bb, false), + self.get_value(third, bb, false), )?, }; } @@ -635,17 +651,17 @@ impl<'a> RTContext<'a> { "::hercules_rt::__copy_{}_to_{}({}.byte_add({} as usize).0, {}.0, {});", src_device.name(), dst_device.name(), - self.get_value(collect, bb), + self.get_value(collect, bb, false), offset, - self.get_value(data, bb), + self.get_value(data, bb, false), data_size, )?; } write!( block, "{} = {};", - self.get_value(id, bb), - self.get_value(collect, bb) + self.get_value(id, bb, true), + self.get_value(collect, bb, false) )?; } _ => panic!( @@ -811,7 +827,7 @@ impl<'a> RTContext<'a> { // ((0 * s1 + p1) * s2 + p2) * s3 + p3 ... 
let elem_size = self.codegen_type_size(elem);
         for (p, s) in zip(pos, dims) {
-            let p = self.get_value(*p, bb);
+            let p = self.get_value(*p, bb, false);
             acc_offset = format!("{} * ", acc_offset);
             self.codegen_dynamic_constant(*s, &mut acc_offset)?;
             acc_offset = format!("({} + {})", acc_offset, p);
         }
@@ -922,23 +938,31 @@ impl<'a> RTContext<'a> {
                 continue;
             }
 
-            write!(
-                w,
-                "let mut {}_{}: {} = {};",
-                if is_reduce_on_child { "reduce" } else { "node" },
-                idx,
-                self.get_type(self.typing[idx]),
-                if self.module.types[self.typing[idx].idx()].is_integer() {
-                    "0"
-                } else if self.module.types[self.typing[idx].idx()].is_float() {
-                    "0.0"
-                } else {
-                    "::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())"
-                }
-            )?;
+            // If the node is a call with an AsyncCall schedule, it should be
+            // spawned as a task and awaited later.
+            let is_async_call =
+                func.nodes[idx].is_call() && func.schedules[idx].contains(&Schedule::AsyncCall);
+            if is_async_call {
+                write!(w, "let mut async_call_{} = None;", idx)?;
+            } else {
+                write!(
+                    w,
+                    "let mut {}_{}: {} = {};",
+                    if is_reduce_on_child { "reduce" } else { "node" },
+                    idx,
+                    self.get_type(self.typing[idx]),
+                    if self.module.types[self.typing[idx].idx()].is_integer() {
+                        "0"
+                    } else if self.module.types[self.typing[idx].idx()].is_float() {
+                        "0.0"
+                    } else {
+                        "::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())"
+                    }
+                )?;
+            }
         }
 
-        // Declare Vec for storing futures of fork-joins.
+        // Declare Vecs for storing futures of fork-joins.
         for fork in self.fork_tree[&root].iter() {
             write!(w, "let mut fork_{} = vec![];", fork.idx())?;
         }
@@ -1180,7 +1204,7 @@ impl<'a> RTContext<'a> {
         &self.module.functions[self.func_id.idx()]
     }
 
-    fn get_value(&self, id: NodeID, bb: NodeID) -> String {
+    fn get_value(&self, id: NodeID, bb: NodeID, lhs: bool) -> String {
         let func = self.get_func();
         if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce()
             && control == bb
@@ -1197,6 +1221,14 @@ impl<'a> RTContext<'a> {
                 fork.idx(),
                 id.idx()
             )
+        } else if func.nodes[id.idx()].is_call()
+            && func.schedules[id.idx()].contains(&Schedule::AsyncCall)
+        {
+            format!(
+                "async_call_{}{}",
+                id.idx(),
+                if lhs { "" } else { ".unwrap().await" }
+            )
         } else {
             format!("node_{}", id.idx())
         }
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 1dce5cfc..eb008904 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -331,6 +331,8 @@ pub enum Schedule {
     TightAssociative,
     // This constant node doesn't need to be memset to zero.
     NoResetConstant,
+    // This call should be executed in a spawned future.
+ AsyncCall, } /* diff --git a/hercules_samples/call/build.rs b/hercules_samples/call/build.rs index e7b6dee9..3ecbb221 100644 --- a/hercules_samples/call/build.rs +++ b/hercules_samples/call/build.rs @@ -1,22 +1,11 @@ use juno_build::JunoCompiler; fn main() { - #[cfg(not(feature = "cuda"))] - { - JunoCompiler::new() - .ir_in_src("call.hir") - .unwrap() - .build() - .unwrap(); - } - #[cfg(feature = "cuda")] - { - JunoCompiler::new() - .ir_in_src("call.hir") - .unwrap() - .schedule_in_src("gpu.sch") - .unwrap() - .build() - .unwrap(); - } + JunoCompiler::new() + .ir_in_src("call.hir") + .unwrap() + .schedule_in_src("sched.sch") + .unwrap() + .build() + .unwrap(); } diff --git a/hercules_samples/call/src/call.hir b/hercules_samples/call/src/call.hir index 937ce1ef..cecee343 100644 --- a/hercules_samples/call/src/call.hir +++ b/hercules_samples/call/src/call.hir @@ -1,7 +1,11 @@ fn myfunc(x: u64) -> u64 - cr = region(start) - y = call<16>(add, cr, x, x) - r = return(cr, y) + cr1 = region(start) + cr2 = region(cr1) + c = constant(u64, 24) + y = call<16>(add, cr1, x, x) + z = call<10>(add, cr2, x, c) + w = add(y, z) + r = return(cr2, w) fn add<1>(x: u64, y: u64) -> u64 w = add(x, y) diff --git a/hercules_samples/call/src/main.rs b/hercules_samples/call/src/main.rs index ff4b6f4a..ea83a1df 100644 --- a/hercules_samples/call/src/main.rs +++ b/hercules_samples/call/src/main.rs @@ -8,9 +8,10 @@ fn main() { async_std::task::block_on(async { let mut r = runner!(myfunc); let x = r.run(7).await; + assert_eq!(x, 71); let mut r = runner!(add); let y = r.run(10, 2, 18).await; - assert_eq!(x, y); + assert_eq!(y, 30); }); } diff --git a/hercules_samples/call/src/gpu.sch b/hercules_samples/call/src/sched.sch similarity index 73% rename from hercules_samples/call/src/gpu.sch rename to hercules_samples/call/src/sched.sch index cc4ef88f..7d4172fb 100644 --- a/hercules_samples/call/src/gpu.sch +++ b/hercules_samples/call/src/sched.sch @@ -2,9 +2,6 @@ gvn(*); phi-elim(*); dce(*); -let out = auto-outline(*); -gpu(out.add); - ip-sroa(*); sroa(*); dce(*); @@ -13,5 +10,6 @@ phi-elim(*); dce(*); infer-schedules(*); +async-call(myfunc@y); gcm(*); diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 4ea8dfb5..1aaa10cd 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -137,6 +137,7 @@ impl FromStr for Appliable { "parallel-reduce" => Ok(Appliable::Schedule(Schedule::ParallelReduce)), "vectorize" => Ok(Appliable::Schedule(Schedule::Vectorizable)), "no-memset" | "no-reset" => Ok(Appliable::Schedule(Schedule::NoResetConstant)), + "task-parallel" | "async-call" => Ok(Appliable::Schedule(Schedule::AsyncCall)), _ => Err(s.to_string()), } -- GitLab
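
Reviewer note: below is a minimal, hand-written sketch of the Rust that the new RT codegen path is intended to emit for a call node carrying the AsyncCall schedule. The names (async_call_y, async_call_z, the plain-Rust add) are illustrative, not actual compiler output; in particular, the real add in call.hir also folds in its dynamic constant, and backing allocations are omitted. The sketch only demonstrates the declare / spawn / await shape produced by the new prefix/postfix strings and the lhs flag on get_value.

use async_std::task;

// Stand-in for the `add` device function (simplified: no dynamic constant).
async fn add(x: u64, y: u64) -> u64 {
    x + y
}

// The assignment-over-None mirrors the codegen's separate declaration and
// definition phases, hence the allow.
#[allow(unused_assignments)]
async fn myfunc(x: u64) -> u64 {
    // Declaration phase: an AsyncCall node becomes an Option instead of a
    // zero-initialized value of its type.
    let mut async_call_y = None;
    let mut async_call_z = None;
    // Definition (get_value with lhs == true): the call is wrapped as
    // Some(task::spawn(...)), so the two callees can run concurrently.
    async_call_y = Some(task::spawn(add(x, x)));
    async_call_z = Some(task::spawn(add(x, 24)));
    // Use (get_value with lhs == false): each read becomes .unwrap().await.
    async_call_y.unwrap().await + async_call_z.unwrap().await
}

fn main() {
    task::block_on(async {
        // 7 + 7 = 14 and 7 + 24 = 31 in this simplified version.
        assert_eq!(myfunc(7).await, 45);
    });
}

For a callee on a non-AsyncRust device, the (_, true) arms additionally wrap the synchronous call in async move { ... } before spawning, which is why the prefix/postfix pairs differ by device.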