From ed06da2d1c9811b28a4a78d6aa5e79acf3877825 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 23 Dec 2024 16:57:46 -0800
Subject: [PATCH] Refactor device placement out

---
 hercules_cg/src/device.rs | 24 ++++++++++++++++++++++++
 hercules_cg/src/lib.rs    |  2 ++
 hercules_ir/src/ir.rs     |  5 ++---
 hercules_opt/src/pass.rs  | 37 ++++++++++++++++++-------------------
 4 files changed, 46 insertions(+), 22 deletions(-)
 create mode 100644 hercules_cg/src/device.rs

diff --git a/hercules_cg/src/device.rs b/hercules_cg/src/device.rs
new file mode 100644
index 00000000..7dbeeeda
--- /dev/null
+++ b/hercules_cg/src/device.rs
@@ -0,0 +1,24 @@
+extern crate hercules_ir;
+
+use self::hercules_ir::*;
+
+/*
+ * Top level function that picks a device for every function. A placement
+ * already stored on a function is kept; otherwise entry functions and
+ * functions whose callee count in the call graph is non-zero are placed on
+ * AsyncRust, and all others on LLVM. Result i is the device of function i.
+ */
+pub fn device_placement(functions: &Vec<Function>, callgraph: &CallGraph) -> Vec<Device> {
+    functions
+        .iter()
+        .enumerate()
+        .map(|(idx, function)| match function.device {
+            // An explicitly stored placement always wins.
+            Some(device) => device,
+            None if function.entry || callgraph.num_callees(FunctionID::new(idx)) != 0 => {
+                Device::AsyncRust
+            }
+            None => Device::LLVM,
+        })
+        .collect()
+}
diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs
index 9013eff7..952ce368 100644
--- a/hercules_cg/src/lib.rs
+++ b/hercules_cg/src/lib.rs
@@ -1,9 +1,11 @@
 #![feature(if_let_guard, let_chains)]
 
 pub mod cpu;
+pub mod device;
 pub mod mem;
 pub mod rt;
 
 pub use crate::cpu::*;
+pub use crate::device::*;
 pub use crate::mem::*;
 pub use crate::rt::*;
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index f6356dfd..c0faec59 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -333,9 +333,8 @@ pub enum Schedule {
 pub enum Device {
     LLVM,
     NVVM,
-    // Internal nodes in the call graph are lowered to async Rust code that
-    // calls device functions (leaf nodes in the call graph), possibly
-    // concurrently.
+    // Entry functions are lowered to async Rust code that calls device
+    // functions (leaf nodes in the call graph), possibly concurrently.
     AsyncRust,
 }
 
diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs
index ce53916b..b85835cf 100644
--- a/hercules_opt/src/pass.rs
+++ b/hercules_opt/src/pass.rs
@@ -864,13 +864,25 @@ impl PassManager {
                     let memory_objects_mutable =
                         memory_objects_mutability(&self.module, &callgraph, &memory_objects);
 
+                    let devices = device_placement(&self.module.functions, &callgraph);
+
                     let mut rust_rt = String::new();
                     let mut llvm_ir = String::new();
                     for idx in 0..self.module.functions.len() {
-                        if self.module.functions[idx].entry
-                            || callgraph.num_callees(FunctionID::new(idx)) != 0
-                        {
-                            rt_codegen(
+                        match devices[idx] {
+                            Device::LLVM => cpu_codegen(
+                                &self.module.functions[idx],
+                                &self.module.types,
+                                &self.module.constants,
+                                &self.module.dynamic_constants,
+                                &reverse_postorders[idx],
+                                &typing[idx],
+                                &control_subgraphs[idx],
+                                &bbs[idx],
+                                &mut llvm_ir,
+                            )
+                            .unwrap(),
+                            Device::AsyncRust => rt_codegen(
                                 FunctionID::new(idx),
                                 &self.module,
                                 &reverse_postorders[idx],
@@ -882,21 +894,8 @@ impl PassManager {
                                 &memory_objects_mutable,
                                 &mut rust_rt,
                             )
-                            .unwrap();
-                        } else {
-                            // TODO: determine which backend to use for function.
-                            cpu_codegen(
-                                &self.module.functions[idx],
-                                &self.module.types,
-                                &self.module.constants,
-                                &self.module.dynamic_constants,
-                                &reverse_postorders[idx],
-                                &typing[idx],
-                                &control_subgraphs[idx],
-                                &bbs[idx],
-                                &mut llvm_ir,
-                            )
-                            .unwrap();
+                            .unwrap(),
+                            _ => todo!(),
                         }
                     }
                     println!("{}", llvm_ir);
-- 
GitLab