From a43251a4b952ac1ffc2c139fe8a316ba101e06e7 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Wed, 12 Feb 2025 15:09:45 -0600
Subject: [PATCH] AsyncCall schedule, lower DynamicConstant in RT backend,
 simplify call test

---
 hercules_cg/src/rt.rs                         |  6 +++++
 hercules_ir/src/ir.rs                         |  2 ++
 hercules_samples/call/build.rs                | 25 ++++++-------------
 hercules_samples/call/src/call.hir            | 10 +++++---
 hercules_samples/call/src/main.rs             |  3 ++-
 .../call/src/{gpu.sch => sched.sch}           |  4 +--
 juno_scheduler/src/compile.rs                 |  1 +
 7 files changed, 26 insertions(+), 25 deletions(-)
 rename hercules_samples/call/src/{gpu.sch => sched.sch} (73%)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 62f683ce..fd44d777 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -474,6 +474,12 @@ impl<'a> RTContext<'a> {
                     }
                 }
             }
+            Node::DynamicConstant { id: dc_id } => {
+                let block = &mut blocks.get_mut(&bb).unwrap().data;
+                write!(block, "{} = ", self.get_value(id, bb))?;
+                self.codegen_dynamic_constant(dc_id, block)?;
+                write!(block, ";")?;
+            }
             Node::ThreadID { control, dimension } => {
                 let block = &mut blocks.get_mut(&bb).unwrap().data;
                 write!(
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 1dce5cfc..eb008904 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -331,6 +331,8 @@ pub enum Schedule {
     TightAssociative,
     // This constant node doesn't need to be memset to zero.
     NoResetConstant,
+    // This call should be executed asynchronously in a spawned future.
+    AsyncCall,
 }
 
 /*
diff --git a/hercules_samples/call/build.rs b/hercules_samples/call/build.rs
index e7b6dee9..3ecbb221 100644
--- a/hercules_samples/call/build.rs
+++ b/hercules_samples/call/build.rs
@@ -1,22 +1,11 @@
 use juno_build::JunoCompiler;
 
 fn main() {
-    #[cfg(not(feature = "cuda"))]
-    {
-        JunoCompiler::new()
-            .ir_in_src("call.hir")
-            .unwrap()
-            .build()
-            .unwrap();
-    }
-    #[cfg(feature = "cuda")]
-    {
-        JunoCompiler::new()
-            .ir_in_src("call.hir")
-            .unwrap()
-            .schedule_in_src("gpu.sch")
-            .unwrap()
-            .build()
-            .unwrap();
-    }
+    JunoCompiler::new()
+        .ir_in_src("call.hir")
+        .unwrap()
+        .schedule_in_src("sched.sch")
+        .unwrap()
+        .build()
+        .unwrap();
 }
diff --git a/hercules_samples/call/src/call.hir b/hercules_samples/call/src/call.hir
index 937ce1ef..cecee343 100644
--- a/hercules_samples/call/src/call.hir
+++ b/hercules_samples/call/src/call.hir
@@ -1,7 +1,11 @@
 fn myfunc(x: u64) -> u64
-  cr = region(start)
-  y = call<16>(add, cr, x, x)
-  r = return(cr, y)
+  cr1 = region(start)
+  cr2 = region(cr1)
+  c = constant(u64, 24)
+  y = call<16>(add, cr1, x, x)
+  z = call<10>(add, cr2, x, c)
+  w = add(y, z)
+  r = return(cr2, w)
 
 fn add<1>(x: u64, y: u64) -> u64
   w = add(x, y)
diff --git a/hercules_samples/call/src/main.rs b/hercules_samples/call/src/main.rs
index ff4b6f4a..ea83a1df 100644
--- a/hercules_samples/call/src/main.rs
+++ b/hercules_samples/call/src/main.rs
@@ -8,9 +8,10 @@ fn main() {
     async_std::task::block_on(async {
         let mut r = runner!(myfunc);
         let x = r.run(7).await;
+        assert_eq!(x, 71);
         let mut r = runner!(add);
         let y = r.run(10, 2, 18).await;
-        assert_eq!(x, y);
+        assert_eq!(y, 30);
     });
 }
 
diff --git a/hercules_samples/call/src/gpu.sch b/hercules_samples/call/src/sched.sch
similarity index 73%
rename from hercules_samples/call/src/gpu.sch
rename to hercules_samples/call/src/sched.sch
index cc4ef88f..7d4172fb 100644
--- a/hercules_samples/call/src/gpu.sch
+++ b/hercules_samples/call/src/sched.sch
@@ -2,9 +2,6 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
-let out = auto-outline(*);
-gpu(out.add);
-
 ip-sroa(*);
 sroa(*);
 dce(*);
@@ -13,5 +10,6 @@ phi-elim(*);
 dce(*);
 
 infer-schedules(*);
+async-call(myfunc@y);
 
 gcm(*);
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 4ea8dfb5..1aaa10cd 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -137,6 +137,7 @@ impl FromStr for Appliable {
             "parallel-reduce" => Ok(Appliable::Schedule(Schedule::ParallelReduce)),
             "vectorize" => Ok(Appliable::Schedule(Schedule::Vectorizable)),
             "no-memset" | "no-reset" => Ok(Appliable::Schedule(Schedule::NoResetConstant)),
+            "task-parallel" | "async-call" => Ok(Appliable::Schedule(Schedule::AsyncCall)),
 
             _ => Err(s.to_string()),
         }
-- 
GitLab