From bd0425ef1453533c0dd8fe78ed548066ee1958d3 Mon Sep 17 00:00:00 2001
From: rarbore2 <rarbore2@illinois.edu>
Date: Wed, 12 Feb 2025 16:46:38 -0600
Subject: [PATCH] More flexible task parallelism

---
 hercules_cg/src/rt.rs                         | 134 +++++++++++-------
 hercules_ir/src/ir.rs                         |   2 +
 hercules_samples/call/build.rs                |  25 +---
 hercules_samples/call/src/call.hir            |  10 +-
 hercules_samples/call/src/main.rs             |   3 +-
 .../call/src/{gpu.sch => sched.sch}           |   4 +-
 juno_scheduler/src/compile.rs                 |   1 +
 7 files changed, 103 insertions(+), 76 deletions(-)
 rename hercules_samples/call/src/{gpu.sch => sched.sch} (73%)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 62f683ce..090253d4 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -300,7 +300,7 @@ impl<'a> RTContext<'a> {
                 write!(
                     epilogue,
                     "control_token = if {} {{{}}} else {{{}}};}}",
-                    self.get_value(cond, id),
+                    self.get_value(cond, id, false),
                     if succ1_is_true { succ1 } else { succ2 }.idx(),
                     if succ1_is_true { succ2 } else { succ1 }.idx(),
                 )?;
@@ -309,7 +309,7 @@ impl<'a> RTContext<'a> {
                 let prologue = &mut blocks.get_mut(&id).unwrap().prologue;
                 write!(prologue, "{} => {{", id.idx())?;
                 let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue;
-                write!(epilogue, "return {};}}", self.get_value(data, id))?;
+                write!(epilogue, "return {};}}", self.get_value(data, id, false))?;
             }
             // Fork nodes open a new environment for defining an async closure.
             Node::Fork {
@@ -399,8 +399,8 @@ impl<'a> RTContext<'a> {
                         write!(
                             epilogue,
                             "{} = {};",
-                            self.get_value(*user, id),
-                            self.get_value(init, id)
+                            self.get_value(*user, id, true),
+                            self.get_value(init, id, false)
                         )?;
                     }
                 }
@@ -427,11 +427,11 @@ impl<'a> RTContext<'a> {
         match func.nodes[id.idx()] {
             Node::Parameter { index } => {
                 let block = &mut blocks.get_mut(&bb).unwrap().data;
-                write!(block, "{} = p{};", self.get_value(id, bb), index)?
+                write!(block, "{} = p{};", self.get_value(id, bb, true), index)?
             }
             Node::Constant { id: cons_id } => {
                 let block = &mut blocks.get_mut(&bb).unwrap().data;
-                write!(block, "{} = ", self.get_value(id, bb))?;
+                write!(block, "{} = ", self.get_value(id, bb, true))?;
                 let mut size_and_device = None;
                 match self.module.constants[cons_id.idx()] {
                     Constant::Boolean(val) => write!(block, "{}bool", val)?,
@@ -468,18 +468,24 @@ impl<'a> RTContext<'a> {
                             block,
                             "::hercules_rt::__{}_zero_mem({}.0, {} as usize);",
                             device.name(),
-                            self.get_value(id, bb),
+                            self.get_value(id, bb, false),
                             size
                         )?;
                     }
                 }
             }
+            Node::DynamicConstant { id: dc_id } => {
+                let block = &mut blocks.get_mut(&bb).unwrap().data;
+                write!(block, "{} = ", self.get_value(id, bb, true))?;
+                self.codegen_dynamic_constant(dc_id, block)?;
+                write!(block, ";")?;
+            }
             Node::ThreadID { control, dimension } => {
                 let block = &mut blocks.get_mut(&bb).unwrap().data;
                 write!(
                     block,
                     "{} = tid_{}_{};",
-                    self.get_value(id, bb),
+                    self.get_value(id, bb, true),
                     control.idx(),
                     dimension
                 )?;
@@ -500,10 +506,25 @@ impl<'a> RTContext<'a> {
                 // The device backends ensure that device functions have the
                 // same interface as AsyncRust functions.
                 let block = &mut blocks.get_mut(&bb).unwrap().data;
+                let is_async = func.schedules[id.idx()].contains(&Schedule::AsyncCall);
+                let device = self.devices[callee_id.idx()];
+                let prefix = match (device, is_async) {
+                    (Device::AsyncRust, false) => "",
+                    (_, false) => "",
+                    (Device::AsyncRust, true) => "Some(::async_std::task::spawn(",
+                    (_, true) => "Some(::async_std::task::spawn(async move {",
+                };
+                let postfix = match (device, is_async) {
+                    (Device::AsyncRust, false) => ".await",
+                    (_, false) => "",
+                    (Device::AsyncRust, true) => "))",
+                    (_, true) => "}))",
+                };
                 write!(
                     block,
-                    "{} = {}(",
-                    self.get_value(id, bb),
+                    "{} = {}{}(",
+                    self.get_value(id, bb, true),
+                    prefix,
                     self.module.functions[callee_id.idx()].name
                 )?;
                 for (device, offset) in self.backing_allocations[&self.func_id]
@@ -519,14 +540,9 @@ impl<'a> RTContext<'a> {
                     write!(block, ", ")?;
                 }
                 for arg in args {
-                    write!(block, "{}, ", self.get_value(*arg, bb))?;
-                }
-                let device = self.devices[callee_id.idx()];
-                if device == Device::AsyncRust {
-                    write!(block, ").await;")?;
-                } else {
-                    write!(block, ");")?;
+                    write!(block, "{}, ", self.get_value(*arg, bb, false))?;
                 }
+                write!(block, "){};", postfix)?;
             }
             Node::Unary { op, input } => {
                 let block = &mut blocks.get_mut(&bb).unwrap().data;
@@ -534,20 +550,20 @@ impl<'a> RTContext<'a> {
                     UnaryOperator::Not => write!(
                         block,
                         "{} = !{};",
-                        self.get_value(id, bb),
-                        self.get_value(input, bb)
+                        self.get_value(id, bb, true),
+                        self.get_value(input, bb, false)
                     )?,
                     UnaryOperator::Neg => write!(
                         block,
                         "{} = -{};",
-                        self.get_value(id, bb),
-                        self.get_value(input, bb)
+                        self.get_value(id, bb, true),
+                        self.get_value(input, bb, false)
                     )?,
                     UnaryOperator::Cast(ty) => write!(
                         block,
                         "{} = {} as {};",
-                        self.get_value(id, bb),
-                        self.get_value(input, bb),
+                        self.get_value(id, bb, true),
+                        self.get_value(input, bb, false),
                         self.get_type(ty)
                     )?,
                 };
@@ -576,10 +592,10 @@ impl<'a> RTContext<'a> {
                 write!(
                     block,
                     "{} = {} {} {};",
-                    self.get_value(id, bb),
-                    self.get_value(left, bb),
+                    self.get_value(id, bb, true),
+                    self.get_value(left, bb, false),
                     op,
-                    self.get_value(right, bb)
+                    self.get_value(right, bb, false)
                 )?;
             }
             Node::Ternary {
@@ -593,10 +609,10 @@ impl<'a> RTContext<'a> {
                     TernaryOperator::Select => write!(
                         block,
                         "{} = if {} {{{}}} else {{{}}};",
-                        self.get_value(id, bb),
-                        self.get_value(first, bb),
-                        self.get_value(second, bb),
-                        self.get_value(third, bb),
+                        self.get_value(id, bb, true),
+                        self.get_value(first, bb, false),
+                        self.get_value(second, bb, false),
+                        self.get_value(third, bb, false),
                     )?,
                 };
             }
@@ -635,17 +651,17 @@ impl<'a> RTContext<'a> {
                         "::hercules_rt::__copy_{}_to_{}({}.byte_add({} as usize).0, {}.0, {});",
                         src_device.name(),
                         dst_device.name(),
-                        self.get_value(collect, bb),
+                        self.get_value(collect, bb, false),
                         offset,
-                        self.get_value(data, bb),
+                        self.get_value(data, bb, false),
                         data_size,
                     )?;
                 }
                 write!(
                     block,
                     "{} = {};",
-                    self.get_value(id, bb),
-                    self.get_value(collect, bb)
+                    self.get_value(id, bb, true),
+                    self.get_value(collect, bb, false)
                 )?;
             }
             _ => panic!(
@@ -811,7 +827,7 @@ impl<'a> RTContext<'a> {
                     //     ((0 * s1 + p1) * s2 + p2) * s3 + p3 ...
                     let elem_size = self.codegen_type_size(elem);
                     for (p, s) in zip(pos, dims) {
-                        let p = self.get_value(*p, bb);
+                        let p = self.get_value(*p, bb, false);
                         acc_offset = format!("{} * ", acc_offset);
                         self.codegen_dynamic_constant(*s, &mut acc_offset)?;
                         acc_offset = format!("({} + {})", acc_offset, p);
@@ -922,23 +938,31 @@ impl<'a> RTContext<'a> {
                 continue;
             }
 
-            write!(
-                w,
-                "let mut {}_{}: {} = {};",
-                if is_reduce_on_child { "reduce" } else { "node" },
-                idx,
-                self.get_type(self.typing[idx]),
-                if self.module.types[self.typing[idx].idx()].is_integer() {
-                    "0"
-                } else if self.module.types[self.typing[idx].idx()].is_float() {
-                    "0.0"
-                } else {
-                    "::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())"
-                }
-            )?;
+            // Calls with an AsyncCall schedule are spawned as tasks, so declare
+            // an Option slot for the task handle; uses unwrap and await it.
+            let is_async_call =
+                func.nodes[idx].is_call() && func.schedules[idx].contains(&Schedule::AsyncCall);
+            if is_async_call {
+                write!(w, "let mut async_call_{} = None;", idx)?;
+            } else {
+                write!(
+                    w,
+                    "let mut {}_{}: {} = {};",
+                    if is_reduce_on_child { "reduce" } else { "node" },
+                    idx,
+                    self.get_type(self.typing[idx]),
+                    if self.module.types[self.typing[idx].idx()].is_integer() {
+                        "0"
+                    } else if self.module.types[self.typing[idx].idx()].is_float() {
+                        "0.0"
+                    } else {
+                        "::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())"
+                    }
+                )?;
+            }
         }
 
-        // Declare Vec for storing futures of fork-joins.
+        // Declare Vecs for storing futures of fork-joins.
         for fork in self.fork_tree[&root].iter() {
             write!(w, "let mut fork_{} = vec![];", fork.idx())?;
         }
@@ -1180,7 +1204,7 @@ impl<'a> RTContext<'a> {
         &self.module.functions[self.func_id.idx()]
     }
 
-    fn get_value(&self, id: NodeID, bb: NodeID) -> String {
+    fn get_value(&self, id: NodeID, bb: NodeID, lhs: bool) -> String {
         let func = self.get_func();
         if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce()
             && control == bb
@@ -1197,6 +1221,14 @@ impl<'a> RTContext<'a> {
                 fork.idx(),
                 id.idx()
             )
+        } else if func.nodes[id.idx()].is_call()
+            && func.schedules[id.idx()].contains(&Schedule::AsyncCall)
+        {
+            format!(
+                "async_call_{}{}",
+                id.idx(),
+                if lhs { "" } else { ".unwrap().await" }
+            )
         } else {
             format!("node_{}", id.idx())
         }
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 1dce5cfc..eb008904 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -331,6 +331,8 @@ pub enum Schedule {
     TightAssociative,
     // This constant node doesn't need to be memset to zero.
     NoResetConstant,
+    // This call should be spawned as an async task and awaited where its result is used.
+    AsyncCall,
 }
 
 /*
diff --git a/hercules_samples/call/build.rs b/hercules_samples/call/build.rs
index e7b6dee9..3ecbb221 100644
--- a/hercules_samples/call/build.rs
+++ b/hercules_samples/call/build.rs
@@ -1,22 +1,11 @@
 use juno_build::JunoCompiler;
 
 fn main() {
-    #[cfg(not(feature = "cuda"))]
-    {
-        JunoCompiler::new()
-            .ir_in_src("call.hir")
-            .unwrap()
-            .build()
-            .unwrap();
-    }
-    #[cfg(feature = "cuda")]
-    {
-        JunoCompiler::new()
-            .ir_in_src("call.hir")
-            .unwrap()
-            .schedule_in_src("gpu.sch")
-            .unwrap()
-            .build()
-            .unwrap();
-    }
+    JunoCompiler::new()
+        .ir_in_src("call.hir")
+        .unwrap()
+        .schedule_in_src("sched.sch")
+        .unwrap()
+        .build()
+        .unwrap();
 }
diff --git a/hercules_samples/call/src/call.hir b/hercules_samples/call/src/call.hir
index 937ce1ef..cecee343 100644
--- a/hercules_samples/call/src/call.hir
+++ b/hercules_samples/call/src/call.hir
@@ -1,7 +1,11 @@
 fn myfunc(x: u64) -> u64
-  cr = region(start)
-  y = call<16>(add, cr, x, x)
-  r = return(cr, y)
+  cr1 = region(start)
+  cr2 = region(cr1)
+  c = constant(u64, 24)
+  y = call<16>(add, cr1, x, x)
+  z = call<10>(add, cr2, x, c)
+  w = add(y, z)
+  r = return(cr2, w)
 
 fn add<1>(x: u64, y: u64) -> u64
   w = add(x, y)
diff --git a/hercules_samples/call/src/main.rs b/hercules_samples/call/src/main.rs
index ff4b6f4a..ea83a1df 100644
--- a/hercules_samples/call/src/main.rs
+++ b/hercules_samples/call/src/main.rs
@@ -8,9 +8,10 @@ fn main() {
     async_std::task::block_on(async {
         let mut r = runner!(myfunc);
         let x = r.run(7).await;
+        assert_eq!(x, 71);
         let mut r = runner!(add);
         let y = r.run(10, 2, 18).await;
-        assert_eq!(x, y);
+        assert_eq!(y, 30);
     });
 }
 
diff --git a/hercules_samples/call/src/gpu.sch b/hercules_samples/call/src/sched.sch
similarity index 73%
rename from hercules_samples/call/src/gpu.sch
rename to hercules_samples/call/src/sched.sch
index cc4ef88f..7d4172fb 100644
--- a/hercules_samples/call/src/gpu.sch
+++ b/hercules_samples/call/src/sched.sch
@@ -2,9 +2,6 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
-let out = auto-outline(*);
-gpu(out.add);
-
 ip-sroa(*);
 sroa(*);
 dce(*);
@@ -13,5 +10,6 @@ phi-elim(*);
 dce(*);
 
 infer-schedules(*);
+async-call(myfunc@y);
 
 gcm(*);
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 4ea8dfb5..1aaa10cd 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -137,6 +137,7 @@ impl FromStr for Appliable {
             "parallel-reduce" => Ok(Appliable::Schedule(Schedule::ParallelReduce)),
             "vectorize" => Ok(Appliable::Schedule(Schedule::Vectorizable)),
             "no-memset" | "no-reset" => Ok(Appliable::Schedule(Schedule::NoResetConstant)),
+            "task-parallel" | "async-call" => Ok(Appliable::Schedule(Schedule::AsyncCall)),
 
             _ => Err(s.to_string()),
         }
-- 
GitLab