From 85ce14bde0ca675bfa967fb35dc746eea343ca86 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sun, 16 Feb 2025 10:15:36 -0600
Subject: [PATCH] enough rt backend stuff for multi-threaded dot

---
 hercules_cg/src/rt.rs        | 69 ++++++++++++++++++++++++++++++++----
 hercules_opt/src/gcm.rs      |  9 +++--
 juno_samples/dot/src/cpu.sch |  5 +--
 juno_samples/dot/src/main.rs |  2 +-
 4 files changed, 73 insertions(+), 12 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 8c5775d8..4d9a6cf6 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -139,7 +139,10 @@ struct RTContext<'a> {
 struct RustBlock {
     prologue: String,
     data: String,
+    phi_tmp_assignments: String,
+    phi_assignments: String,
     epilogue: String,
+    join_epilogue: String,
 }
 
 impl<'a> RTContext<'a> {
@@ -251,7 +254,28 @@ impl<'a> RTContext<'a> {
         // fork and join nodes open and close environments, respectively.
         for id in rev_po.iter() {
             let block = &blocks[id];
-            write!(w, "{}{}{}", block.prologue, block.data, block.epilogue)?;
+            if func.nodes[id.idx()].is_join() {
+                write!(
+                    w,
+                    "{}{}{}{}{}{}",
+                    block.prologue,
+                    block.data,
+                    block.epilogue,
+                    block.phi_tmp_assignments,
+                    block.phi_assignments,
+                    block.join_epilogue
+                )?;
+            } else {
+                write!(
+                    w,
+                    "{}{}{}{}{}",
+                    block.prologue,
+                    block.data,
+                    block.phi_tmp_assignments,
+                    block.phi_assignments,
+                    block.epilogue
+                )?;
+            }
         }
 
         // Close the root environment.
@@ -367,7 +391,10 @@ impl<'a> RTContext<'a> {
 
                 // Close the branch inside the async closure.
                 let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue;
-                write!(epilogue, "return;}}")?;
+                write!(
+                    epilogue,
+                    "::std::sync::atomic::fence(::std::sync::atomic::Ordering::Release);return;}}"
+                )?;
 
                 // Close the fork's environment.
                 self.codegen_close_environment(epilogue)?;
@@ -405,9 +432,10 @@ impl<'a> RTContext<'a> {
                     }
                 }
 
+                let join_epilogue = &mut blocks.get_mut(&id).unwrap().join_epilogue;
                 // Branch to the successor control node in the surrounding
                 // context, and close the branch for the join.
-                write!(epilogue, "control_token = {};}}", succ.idx())?;
+                write!(join_epilogue, "control_token = {};}}", succ.idx())?;
             }
             _ => panic!("PANIC: Can't lower {:?}.", func.nodes[id.idx()]),
         }
@@ -481,15 +509,39 @@ impl<'a> RTContext<'a> {
                 write!(block, ";")?;
             }
             Node::ThreadID { control, dimension } => {
+                assert_eq!(control, bb);
                 let block = &mut blocks.get_mut(&bb).unwrap().data;
                 write!(
                     block,
                     "{} = tid_{}_{};",
                     self.get_value(id, bb, true),
-                    control.idx(),
+                    bb.idx(),
                     dimension
                 )?;
             }
+            Node::Phi { control, ref data } => {
+                assert_eq!(control, bb);
+                // Phis aren't executed in their own basic block; instead, each
+                // predecessor block assigns the phi's incoming value. Assign
+                // temporaries first and the phis themselves afterwards, since
+                // simultaneous phis may depend on each other's old values.
+                for (data, pred) in zip(data.into_iter(), self.control_subgraph.preds(bb)) {
+                    let block = &mut blocks.get_mut(&pred).unwrap().phi_tmp_assignments;
+                    write!(
+                        block,
+                        "let {}_tmp = {};",
+                        self.get_value(id, pred, true),
+                        self.get_value(*data, pred, false),
+                    )?;
+                    let block = &mut blocks.get_mut(&pred).unwrap().phi_assignments;
+                    write!(
+                        block,
+                        "{} = {}_tmp;",
+                        self.get_value(id, pred, true),
+                        self.get_value(id, pred, false),
+                    )?;
+                }
+            }
             Node::Reduce {
                 control: _,
                 init: _,
@@ -498,11 +550,12 @@ impl<'a> RTContext<'a> {
                 assert!(func.schedules[id.idx()].contains(&Schedule::ParallelReduce));
             }
             Node::Call {
-                control: _,
+                control,
                 function: callee_id,
                 ref dynamic_constants,
                 ref args,
             } => {
+                assert_eq!(control, bb);
                 // The device backends ensure that device functions have the
                 // same interface as AsyncRust functions.
                 let block = &mut blocks.get_mut(&bb).unwrap().data;
@@ -975,7 +1028,9 @@ impl<'a> RTContext<'a> {
                     if is_reduce_on_child { "reduce" } else { "node" },
                     idx,
                     self.get_type(self.typing[idx]),
-                    if self.module.types[self.typing[idx].idx()].is_integer() {
+                    if self.module.types[self.typing[idx].idx()].is_bool() {
+                        "false"
+                    } else if self.module.types[self.typing[idx].idx()].is_integer() {
                         "0"
                     } else if self.module.types[self.typing[idx].idx()].is_float() {
                         "0.0"
@@ -1241,7 +1296,7 @@ impl<'a> RTContext<'a> {
             // Before using the value of a reduction outside the fork-join,
             // await the futures.
             format!(
-                "{{for fut in fork_{}.drain(..) {{ fut.await; }}; reduce_{}}}",
+                "{{for fut in fork_{}.drain(..) {{ fut.await; }}; ::std::sync::atomic::fence(::std::sync::atomic::Ordering::Acquire); reduce_{}}}",
                 fork.idx(),
                 id.idx()
             )
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index 99c44d52..821d02ea 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -95,9 +95,11 @@ pub fn gcm(
 
     let bbs = basic_blocks(
         editor.func(),
+        editor.get_types(),
         editor.func_id(),
         def_use,
         reverse_postorder,
+        typing,
         dom,
         loops,
         reduce_cycles,
@@ -218,9 +220,11 @@ fn preliminary_fixups(
  */
 fn basic_blocks(
     function: &Function,
+    types: Ref<Vec<Type>>,
     func_id: FunctionID,
     def_use: &ImmutableDefUseMap,
     reverse_postorder: &Vec<NodeID>,
+    typing: &Vec<TypeID>,
     dom: &DomTree,
     loops: &LoopTree,
     reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
@@ -498,8 +502,9 @@ fn basic_blocks(
                 // control dependent as possible, even inside loops. In GPU
                 // functions specifically, lift constants that may be returned
                 // outside fork-joins.
-                let is_constant_or_undef =
-                    function.nodes[id.idx()].is_constant() || function.nodes[id.idx()].is_undef();
+                let is_constant_or_undef = (function.nodes[id.idx()].is_constant()
+                    || function.nodes[id.idx()].is_undef())
+                    && !types[typing[id.idx()].idx()].is_primitive();
                 let is_gpu_returned = devices[func_id.idx()] == Device::CUDA
                     && objects[&func_id]
                         .objects(id)
diff --git a/juno_samples/dot/src/cpu.sch b/juno_samples/dot/src/cpu.sch
index aa87972e..1f8953d9 100644
--- a/juno_samples/dot/src/cpu.sch
+++ b/juno_samples/dot/src/cpu.sch
@@ -8,7 +8,7 @@ fork-guard-elim(dot);
 dce(dot);
 
 fork-tile[8, 0, false, true](dot);
-fork-tile[32, 0, false, false](dot);
+fork-tile[8, 0, false, false](dot);
 let split_out = fork-split(dot);
 infer-schedules(*);
 clean-monoid-reduces(*);
@@ -29,8 +29,9 @@ ccp(dot);
 simplify-cfg(dot);
 gvn(dot);
 dce(dot);
+infer-schedules(dot);
 
-unforkify-one(out);
+unforkify(out);
 ccp(out);
 simplify-cfg(out);
 gvn(out);
diff --git a/juno_samples/dot/src/main.rs b/juno_samples/dot/src/main.rs
index b73f8710..bd887194 100644
--- a/juno_samples/dot/src/main.rs
+++ b/juno_samples/dot/src/main.rs
@@ -9,7 +9,7 @@ juno_build::juno!("dot");
 
 fn main() {
     async_std::task::block_on(async {
-        const N: u64 = 1024 * 1024;
+        const N: u64 = 1024 * 8;
         let a: Box<[i64]> = (0..N).map(|_| random::<i64>() % 100).collect();
         let b: Box<[i64]> = (0..N).map(|_| random::<i64>() % 100).collect();
         let a_herc = HerculesImmBox::from(&a as &[i64]);
-- 
GitLab