diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index bd1520745c72f9300ba750cf48e2c86c4bee978e..d3013239f5f78ce8e63181c181c0d5a8cbf77f81 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -359,6 +359,16 @@ impl<'a> RTContext<'a> {
             write!(prologue, " {{")?;
         }
 
+        // Emit clones of arcs used inside the fork-join.
+        for other_id in (0..func.nodes.len()).map(NodeID::new) {
+            if self.def_use.get_users(other_id).into_iter().any(|user_id| {
+                self.nodes_in_fork_joins[&id].contains(&self.bbs.0[user_id.idx()])
+            }) && let Some(arc) = self.clone_arc(other_id)
+            {
+                write!(prologue, "{}", arc)?;
+            }
+        }
+
         // Spawn an async closure and push its future to a Vec.
         write!(
             prologue,
@@ -416,6 +426,14 @@ impl<'a> RTContext<'a> {
                 let succ = self.control_subgraph.succs(id).next().unwrap();
                 write!(epilogue, "{} => {{", id.idx())?;
 
+                // Await the empty futures for the fork-joins, waiting for them
+                // to complete.
+                write!(
+                    epilogue,
+                    "for fut in fork_{}.drain(..) {{ fut.await; }}; ::std::sync::atomic::fence(::std::sync::atomic::Ordering::Acquire);",
+                    fork.idx(),
+                )?;
+
                 // Emit the assignments to the reduce variables in the
                 // surrounding context. It's very unfortunate that we have to do
                 // it while lowering the join node (rather than the reduce nodes
@@ -558,11 +576,12 @@ impl<'a> RTContext<'a> {
                 assert_eq!(control, bb);
                 // The device backends ensure that device functions have the
                 // same interface as AsyncRust functions.
-                let block = &mut blocks.get_mut(&bb).unwrap().data;
+                let block = &mut blocks.get_mut(&bb).unwrap();
+                let block = &mut block.data;
                 let is_async = func.schedules[id.idx()].contains(&Schedule::AsyncCall);
                 if is_async {
                     for arg in args {
-                        if let Some(arc) = self.clone_arc(*arg, false) {
+                        if let Some(arc) = self.clone_arc(*arg) {
                             write!(block, "{}", arc)?;
                         }
                     }
@@ -572,17 +591,19 @@ impl<'a> RTContext<'a> {
                     (Device::AsyncRust, false) | (_, false) => {
                         format!("{} = ", self.get_value(id, bb, true))
                     }
-                    (_, true) => format!(
-                        "{}::async_std::task::spawn(async move {{ async_call_sender_{}.send(",
-                        self.clone_arc(id, true).unwrap(),
-                        id.idx()
-                    ),
+                    (_, true) => {
+                        write!(block, "{}", self.clone_arc(id).unwrap())?;
+                        format!(
+                            "*async_call_{}.lock().await = ::hercules_rt::__FutureSlotWrapper::new(::async_std::task::spawn(async move {{ ",
+                            id.idx(),
+                        )
+                    }
                 };
                 let postfix = match (device, is_async) {
                     (Device::AsyncRust, false) => ".await",
                     (_, false) => "",
-                    (Device::AsyncRust, true) => ".await).await})",
-                    (_, true) => ").await})",
+                    (Device::AsyncRust, true) => ".await}))",
+                    (_, true) => "}))",
                 };
                 write!(
                     block,
@@ -1079,14 +1100,14 @@ impl<'a> RTContext<'a> {
         }
 
         // If the node is a call with an AsyncCall schedule, it should be
-        // lowered to a channel.
+        // lowered to an Arc<Mutex<>> over the future.
         let is_async_call = func.nodes[idx].is_call()
             && func.schedules[idx].contains(&Schedule::AsyncCall);
         if is_async_call {
             write!(
                 w,
-                "let mut async_call_channel_{} = ::async_std::channel::bounded(1);let async_call_sender_{} = ::std::sync::Arc::new(async_call_channel_{}.0);let async_call_receiver_{} = ::std::sync::Arc::new(async_call_channel_{}.1);",
-                idx, idx, idx, idx, idx
+                "let mut async_call_{} = ::std::sync::Arc::new(::async_std::sync::Mutex::new(::hercules_rt::__FutureSlotWrapper::empty()));",
+                idx,
             )?;
         } else {
             write!(
@@ -1353,40 +1374,27 @@ impl<'a> RTContext<'a> {
     fn get_value(&self, id: NodeID, bb: NodeID, lhs: bool) -> String {
         let func = self.get_func();
         if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce()
-            && control == bb
+            && (control == bb
+                || !self.nodes_in_fork_joins[&self.join_fork_map[&control]].contains(&bb))
         {
             format!("reduce_{}", id.idx())
-        } else if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce()
-            && let fork = self.join_fork_map[&control]
-            && !self.nodes_in_fork_joins[&fork].contains(&bb)
-        {
-            // Before using the value of a reduction outside the fork-join,
-            // await the futures.
-            format!(
-                "{{for fut in fork_{}.drain(..) {{ fut.await; }}; ::std::sync::atomic::fence(::std::sync::atomic::Ordering::Acquire); reduce_{}}}",
-                fork.idx(),
-                id.idx()
-            )
         } else if func.nodes[id.idx()].is_call()
             && func.schedules[id.idx()].contains(&Schedule::AsyncCall)
         {
             assert!(!lhs);
-            format!("async_call_receiver_{}.recv().await.unwrap()", id.idx(),)
+            format!("async_call_{}.lock().await.inspect().await", id.idx(),)
         } else {
             format!("node_{}", id.idx())
         }
     }
 
-    fn clone_arc(&self, id: NodeID, lhs: bool) -> Option<String> {
+    fn clone_arc(&self, id: NodeID) -> Option<String> {
         let func = self.get_func();
         if func.nodes[id.idx()].is_call()
             && func.schedules[id.idx()].contains(&Schedule::AsyncCall)
         {
-            let kind = if lhs { "sender" } else { "receiver" };
             Some(format!(
-                "let async_call_{}_{} = async_call_{}_{}.clone();",
-                kind,
+                "let async_call_{} = async_call_{}.clone();",
                 id.idx(),
-                kind,
                 id.idx()
             ))
         } else {
diff --git a/hercules_opt/src/inline.rs b/hercules_opt/src/inline.rs
index f01b2366e8da12be265b65ce29f6cb977ea7c618..6e308274c65e0fbf14bb938e297a188f8ebaf0f3 100644
--- a/hercules_opt/src/inline.rs
+++ b/hercules_opt/src/inline.rs
@@ -210,7 +210,11 @@ fn inline_func(
     }
 
     // Finally, delete the call node.
-    edit = edit.replace_all_uses(id, old_id_to_new_id(called_return_data))?;
+    if let Node::Parameter { index } = called_func.nodes[called_return_data.idx()] {
+        edit = edit.replace_all_uses(id, args[index])?;
+    } else {
+        edit = edit.replace_all_uses(id, old_id_to_new_id(called_return_data))?;
+    }
     edit = edit.delete_node(control)?;
     edit = edit.delete_node(id)?;
 
diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs
index 090a38a02cbbcd46253452f76a0b71681363c833..d19a0a5a16e4438e8746ca312428759ea28ed556 100644
--- a/hercules_rt/src/lib.rs
+++ b/hercules_rt/src/lib.rs
@@ -1,10 +1,10 @@
 #![feature(once_cell_try)]
 
 use std::alloc::{alloc, dealloc, GlobalAlloc, Layout, System};
+use std::future::Future;
 use std::marker::PhantomData;
 use std::ptr::{copy_nonoverlapping, write_bytes, NonNull};
 use std::slice::{from_raw_parts, from_raw_parts_mut};
-
 use std::sync::OnceLock;
 
 /*
@@ -426,6 +426,46 @@ impl __RawPtrSendSync {
 unsafe impl Send for __RawPtrSendSync {}
 unsafe impl Sync for __RawPtrSendSync {}
 
+#[derive(Clone, Debug)]
+pub struct __FutureSlotWrapper<T, F>
+where
+    T: Copy,
+    F: Future<Output = T>,
+{
+    future: Option<F>,
+    slot: Option<T>,
+}
+
+impl<T, F> __FutureSlotWrapper<T, F>
+where
+    T: Copy,
+    F: Future<Output = T>,
+{
+    pub fn empty() -> Self {
+        Self {
+            future: None,
+            slot: None,
+        }
+    }
+
+    pub fn new(f: F) -> Self {
+        Self {
+            future: Some(f),
+            slot: None,
+        }
+    }
+
+    pub async fn inspect(&mut self) -> T {
+        if let Some(slot) = self.slot {
+            slot
+        } else {
+            let result = self.future.take().unwrap().await;
+            self.slot = Some(result);
+            result
+        }
+    }
+}
+
 /*
  * A HerculesBox holds memory that can be on any device and provides a common interface to moving
  * data where it is needed.
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index c71aec111f42aad90c383f8b42f622037efac60c..76dcbdf68ce0ac88f8a2a22bda364f60a88e78bb 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -1,5 +1,3 @@
-no-memset(test6@const);
-
 ccp(*);
 gvn(*);
 phi-elim(*);
@@ -63,7 +61,9 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
-fork-tile[32, 0, false, true](test6@loop);
+async-call(test6@call);
+no-memset(test6@const);
+fork-tile[2, 0, false, false](test6@loop);
 let out = fork-split(test6@loop);
 let out = outline(out.test6.fj1);
 cpu(out);
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index 51115f1576edd1d555395717bdc1dcb4e82a2529..bfb5564be29e444d065c3caaa40afb760d678aa7 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -73,11 +73,16 @@ fn test5(input : i32) -> i32[4] {
   return arr1;
 }
 
+fn test6_helper(input: i32) -> i32 {
+  return input;
+}
+
 #[entry]
 fn test6(input: i32) -> i32[1024] {
+  @call let x = test6_helper(input);
   @const let arr : i32[1024];
   @loop for i = 0 to 1024 {
-    arr[i] = i as i32 + input;
+    arr[i] = i as i32 + x;
   }
   return arr;
 }
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index 91bd6c79103eeb03468ea39a38995a9222836a20..364673cd2aedf323ca7727dc1ea8515188373a8d 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -52,6 +52,7 @@ slf(auto.test2);
 infer-schedules(auto.test2);
 fork-interchange[0, 1](auto.test2);
 
+inline(test6);
 fork-tile[32, 0, false, true](test6@loop);
 let out = fork-split(test6@loop);
 let out = auto-outline(test6);
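
For context, a minimal sketch (not part of the patch) of the code shape rt.rs now emits for a call with the AsyncCall schedule, written against the __FutureSlotWrapper added above. It assumes async-std and this hercules_rt are available as dependencies; the names expensive, caller, and async_call_0 are illustrative stand-ins for generated identifiers, not literal codegen output.

use std::sync::Arc;

use async_std::sync::Mutex;
use async_std::task;

use hercules_rt::__FutureSlotWrapper;

// Stand-in for an AsyncCall callee.
async fn expensive(x: i32) -> i32 {
    x + 1
}

async fn caller(x: i32) -> i32 {
    // At the call's definition site: park the spawned future in a shared slot.
    let async_call_0 = Arc::new(Mutex::new(__FutureSlotWrapper::empty()));
    {
        // clone_arc() emits a clone of the Arc for each closure that captures it.
        let async_call_0 = async_call_0.clone();
        *async_call_0.lock().await =
            __FutureSlotWrapper::new(task::spawn(async move { expensive(x).await }));
    }
    // At each use of the call's result: lock the slot and await inspect(), which
    // memoizes the output so the result can be read more than once.
    let a = async_call_0.lock().await.inspect().await;
    let b = async_call_0.lock().await.inspect().await; // second read hits the cached slot
    a + b
}

fn main() {
    assert_eq!(task::block_on(caller(1)), 4);
}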