From 17ba1e7b13dce45cb772d3ab8ea8314da329914b Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Fri, 7 Feb 2025 10:43:26 -0600
Subject: [PATCH 01/13] multi-device test

---
 Cargo.lock                                    | 10 ++++++
 Cargo.toml                                    |  1 +
 juno_samples/multi_device/Cargo.toml          | 21 ++++++++++++
 juno_samples/multi_device/build.rs            | 14 ++++++++
 juno_samples/multi_device/src/main.rs         | 22 +++++++++++++
 juno_samples/multi_device/src/multi_device.jn | 13 ++++++++
 .../multi_device/src/multi_device.sch         | 32 +++++++++++++++++++
 7 files changed, 113 insertions(+)
 create mode 100644 juno_samples/multi_device/Cargo.toml
 create mode 100644 juno_samples/multi_device/build.rs
 create mode 100644 juno_samples/multi_device/src/main.rs
 create mode 100644 juno_samples/multi_device/src/multi_device.jn
 create mode 100644 juno_samples/multi_device/src/multi_device.sch

diff --git a/Cargo.lock b/Cargo.lock
index 06ee00ff..4a9b8891 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1189,6 +1189,16 @@ dependencies = [
  "with_builtin_macros",
 ]
 
+[[package]]
+name = "juno_multi_device"
+version = "0.1.0"
+dependencies = [
+ "async-std",
+ "hercules_rt",
+ "juno_build",
+ "with_builtin_macros",
+]
+
 [[package]]
 name = "juno_patterns"
 version = "0.1.0"
diff --git a/Cargo.toml b/Cargo.toml
index 54cfc512..eeb5e69d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -32,4 +32,5 @@ members = [
 	"juno_samples/schedule_test",
 	"juno_samples/edge_detection",
 	"juno_samples/fork_join_tests",
+	"juno_samples/multi_device",
 ]
diff --git a/juno_samples/multi_device/Cargo.toml b/juno_samples/multi_device/Cargo.toml
new file mode 100644
index 00000000..e87f6dd1
--- /dev/null
+++ b/juno_samples/multi_device/Cargo.toml
@@ -0,0 +1,21 @@
+[package]
+name = "juno_multi_device"
+version = "0.1.0"
+authors = ["Russel Arbore <rarbore2@illinois.edu>"]
+edition = "2021"
+
+[[bin]]
+name = "juno_multi_device"
+path = "src/main.rs"
+
+[features]
+cuda = ["juno_build/cuda", "hercules_rt/cuda"]
+
+[build-dependencies]
+juno_build = { path = "../../juno_build" }
+
+[dependencies]
+juno_build = { path = "../../juno_build" }
+hercules_rt = { path = "../../hercules_rt" }
+with_builtin_macros = "0.1.0"
+async-std = "*"
diff --git a/juno_samples/multi_device/build.rs b/juno_samples/multi_device/build.rs
new file mode 100644
index 00000000..f3753192
--- /dev/null
+++ b/juno_samples/multi_device/build.rs
@@ -0,0 +1,14 @@
+use juno_build::JunoCompiler;
+
+fn main() {
+    #[cfg(feature = "cuda")]
+    {
+        JunoCompiler::new()
+            .file_in_src("multi_device.jn")
+            .unwrap()
+            .schedule_in_src("multi_device.sch")
+            .unwrap()
+            .build()
+            .unwrap();
+    }
+}
diff --git a/juno_samples/multi_device/src/main.rs b/juno_samples/multi_device/src/main.rs
new file mode 100644
index 00000000..f62de56c
--- /dev/null
+++ b/juno_samples/multi_device/src/main.rs
@@ -0,0 +1,22 @@
+#![feature(concat_idents)]
+
+#[cfg(feature = "cuda")]
+use hercules_rt::runner;
+
+#[cfg(feature = "cuda")]
+juno_build::juno!("multi_device");
+
+#[cfg(feature = "cuda")]
+fn main() {
+    async_std::task::block_on(async {
+        let mut r = runner!(multi_device_1);
+        let out = r.run(42).await;
+        assert_eq!(out, 42 * 16);
+    });
+}
+
+#[test]
+fn multi_device_test() {
+    #[cfg(feature = "cuda")]
+    main();
+}
diff --git a/juno_samples/multi_device/src/multi_device.jn b/juno_samples/multi_device/src/multi_device.jn
new file mode 100644
index 00000000..4713363f
--- /dev/null
+++ b/juno_samples/multi_device/src/multi_device.jn
@@ -0,0 +1,13 @@
+#[entry]
+fn multi_device_1(input : i32) -> i32 {
+  @loop1 @cons let arr : i32[16];
+  @loop1 for i = 0 to 16 {
+    arr[i] = input;
+  }
+
+  @loop2 let sum : i32;
+  @loop2 for i = 0 to 16 {
+    sum += arr[i];
+  }
+  return sum;
+}
diff --git a/juno_samples/multi_device/src/multi_device.sch b/juno_samples/multi_device/src/multi_device.sch
new file mode 100644
index 00000000..e5029a10
--- /dev/null
+++ b/juno_samples/multi_device/src/multi_device.sch
@@ -0,0 +1,32 @@
+gvn(*);
+phi-elim(*);
+dce(*);
+ip-sroa(*);
+sroa(*);
+dce(*);
+gvn(*);
+dce(*);
+phi-elim(*);
+dce(*);
+ccp(*);
+dce(*);
+simplify-cfg(*);
+dce(*);
+forkify(*);
+fork-guard-elim(*);
+fork-coalesce(*);
+dce(*);
+
+no-memset(multi_device_1@cons);
+let l1 = outline(multi_device_1@loop1);
+let l2 = outline(multi_device_1@loop2);
+gpu(l1);
+cpu(l2);
+ip-sroa(*);
+sroa(*);
+dce(*);
+
+infer-schedules(*);
+xdot[true](*);
+
+gcm(*);
-- 
GitLab


From ca719435eca2c6ded8af3a735a63875cc48ede5e Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Fri, 7 Feb 2025 17:18:38 -0600
Subject: [PATCH 02/13] Remove object_device_demands, get 1/2 of analysis for
 node colors done

---
 hercules_cg/src/lib.rs   |  16 ++++-
 hercules_cg/src/rt.rs    |  34 +++------
 hercules_opt/src/gcm.rs  | 145 ++++++++++++++++++++++++++-------------
 juno_scheduler/src/pm.rs |  34 ++-------
 4 files changed, 123 insertions(+), 106 deletions(-)

diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs
index dab4dbac..52bf6177 100644
--- a/hercules_cg/src/lib.rs
+++ b/hercules_cg/src/lib.rs
@@ -39,10 +39,11 @@ pub fn get_type_alignment(types: &Vec<Type>, ty: TypeID) -> usize {
 
 /*
  * Nodes producing collection values are "colored" with what device their
- * underlying memory lives on.
+ * underlying memory lives on. Also explicitly store the device of the
+ * parameters and return of each function.
  */
-pub type FunctionNodeColors = BTreeMap<NodeID, Device>;
-pub type NodeColors = Vec<FunctionNodeColors>;
+pub type FunctionNodeColors = (BTreeMap<NodeID, Device>, Vec<Device>, Device);
+pub type NodeColors = BTreeMap<FunctionID, FunctionNodeColors>;
 
 /*
  * The allocation information of each function is a size of the backing memory
@@ -53,3 +54,12 @@ pub type FunctionBackingAllocation =
     BTreeMap<Device, (DynamicConstantID, BTreeMap<NodeID, DynamicConstantID>)>;
 pub type BackingAllocations = BTreeMap<FunctionID, FunctionBackingAllocation>;
 pub const BACKED_DEVICES: [Device; 2] = [Device::LLVM, Device::CUDA];
+
+pub fn backing_device(device: Device) -> Device {
+    match device {
+        Device::LLVM => Device::LLVM,
+        Device::CUDA => Device::CUDA,
+        // Memory loads and stores in AsyncRust code execute on the CPU.
+        Device::AsyncRust => Device::LLVM,
+    }
+}
diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 2c5f7c35..19db630d 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -941,23 +941,8 @@ impl<'a> RTContext<'a> {
         // they are collections and whether they should be immutable or mutable
         // references.
         let func = self.get_func();
-        let mut param_devices = vec![None; func.param_types.len()];
-        let mut return_device = None;
-        for idx in 0..func.nodes.len() {
-            match func.nodes[idx] {
-                Node::Parameter { index } => {
-                    let device = self.node_colors.get(&NodeID::new(idx));
-                    assert!(param_devices[index].is_none() || param_devices[index] == device);
-                    param_devices[index] = device;
-                }
-                Node::Return { control: _, data } => {
-                    let device = self.node_colors.get(&data);
-                    assert!(return_device.is_none() || return_device == device);
-                    return_device = device;
-                }
-                _ => {}
-            }
-        }
+        let mut param_devices = &self.node_colors.1;
+        let mut return_device = self.node_colors.2;
         let mut param_muts = vec![false; func.param_types.len()];
         let mut return_mut = true;
         let objects = &self.collection_objects[&self.func_id];
@@ -1010,11 +995,8 @@ impl<'a> RTContext<'a> {
                 write!(w, ", p{}: {}", idx, self.get_type(func.param_types[idx]))?;
             } else {
                 let device = match param_devices[idx] {
-                    Some(Device::LLVM) => "CPU",
-                    Some(Device::CUDA) => "CUDA",
-                    // For parameters that are unused, it doesn't really matter
-                    // what device is required, so just pick CPU for now.
-                    None => "CPU",
+                    Device::LLVM => "CPU",
+                    Device::CUDA => "CUDA",
                     _ => panic!(),
                 };
                 let mutability = if param_muts[idx] { "Mut" } else { "" };
@@ -1029,8 +1011,8 @@ impl<'a> RTContext<'a> {
             write!(w, ") -> {} {{", self.get_type(func.return_type))?;
         } else {
             let device = match return_device {
-                Some(Device::LLVM) => "CPU",
-                Some(Device::CUDA) => "CUDA",
+                Device::LLVM => "CPU",
+                Device::CUDA => "CUDA",
                 _ => panic!(),
             };
             let mutability = if return_mut { "Mut" } else { "" };
@@ -1094,8 +1076,8 @@ impl<'a> RTContext<'a> {
             write!(w, "        ret")?;
         } else {
             let device = match return_device {
-                Some(Device::LLVM) => "CPU",
-                Some(Device::CUDA) => "CUDA",
+                Device::LLVM => "CPU",
+                Device::CUDA => "CUDA",
                 _ => panic!(),
             };
             let mutability = if return_mut { "Mut" } else { "" };
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index ae8c813d..a4f76b75 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -86,7 +86,7 @@ pub fn gcm(
     reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
     objects: &CollectionObjects,
     devices: &Vec<Device>,
-    object_device_demands: &FunctionObjectDeviceDemands,
+    node_colors: &NodeColors,
     backing_allocations: &BackingAllocations,
 ) -> Option<(BasicBlocks, FunctionNodeColors, FunctionBackingAllocation)> {
     if preliminary_fixups(editor, fork_join_map, loops, reduce_cycles) {
@@ -121,31 +121,16 @@ pub fn gcm(
     let func_id = editor.func_id();
     let Some(node_colors) = color_nodes(
         editor,
+        def_use,
         reverse_postorder,
+        typing,
         &objects[&func_id],
-        &object_device_demands,
+        &devices,
+        node_colors,
     ) else {
         return None;
     };
 
-    let device = devices[func_id.idx()];
-    match device {
-        Device::LLVM | Device::CUDA => {
-            // Check that every object that has a demand in this function are
-            // only demanded on this device.
-            for demands in object_device_demands {
-                assert!(demands.is_empty() || (demands.len() == 1 && demands.contains(&device)))
-            }
-        }
-        Device::AsyncRust => {
-            // Check that every object that has a demand in this function only
-            // has a demand from one device.
-            for demands in object_device_demands {
-                assert!(demands.len() <= 1);
-            }
-        }
-    }
-
     let mut alignments = vec![];
     Ref::map(editor.get_types(), |types| {
         for idx in 0..types.len() {
@@ -1148,43 +1133,109 @@ fn liveness_dataflow(
     }
 }
 
+#[derive(Debug, Clone, PartialEq, Eq)]
+struct Colors {
+    colors: BTreeSet<Device>,
+}
+
+impl Semilattice for Colors {
+    fn meet(a: &Self, b: &Self) -> Self {
+        Colors {
+            colors: a.colors.union(&b.colors).map(|device| *device).collect(),
+        }
+    }
+
+    fn bottom() -> Self {
+        Colors {
+            colors: BACKED_DEVICES.into_iter().collect(),
+        }
+    }
+
+    fn top() -> Self {
+        Colors {
+            colors: BTreeSet::new(),
+        }
+    }
+}
+
 /*
  * Determine what device each node produces a collection onto. Insert inter-
  * device clones when a single node may potentially be on different devices.
  */
 fn color_nodes(
-    _editor: &mut FunctionEditor,
+    editor: &mut FunctionEditor,
+    def_use: &ImmutableDefUseMap,
     reverse_postorder: &Vec<NodeID>,
+    typing: &Vec<TypeID>,
     objects: &FunctionCollectionObjects,
-    object_device_demands: &FunctionObjectDeviceDemands,
+    devices: &Vec<Device>,
+    node_colors: &NodeColors,
 ) -> Option<FunctionNodeColors> {
-    // First, try to give each node a single color.
-    let mut colors = BTreeMap::new();
-    let mut bad_node = None;
-    'nodes: for id in reverse_postorder {
-        let mut device = None;
-        for object in objects.objects(*id) {
-            for demand in object_device_demands[object.idx()].iter() {
-                if let Some(device) = device
-                    && device != *demand
-                {
-                    bad_node = Some(id);
-                    break 'nodes;
+    // First, do a backward dataflow analysis to figure out which device each
+    // node is "demanded" on - these are the devices the collection will be used
+    // on by users.
+    let nodes = &editor.func().nodes;
+    let func_id = editor.func_id();
+    let demands = backward_dataflow(editor.func(), def_use, reverse_postorder, |inputs, id| {
+        let mut demands = Colors {
+            colors: BTreeSet::new(),
+        };
+        for user in editor.get_users(id) {
+            match nodes[user.idx()] {
+                Node::Read {
+                    collect,
+                    indices: _,
+                } if user == collect && editor.get_type(typing[user.idx()]).is_primitive() => {
+                    let device = backing_device(devices[func_id.idx()]);
+                    demands.colors.insert(device);
                 }
-                device = Some(*demand);
-            }
-        }
-        if let Some(device) = device {
-            colors.insert(*id, device);
-        } else {
-            assert!(objects.objects(*id).is_empty(), "PANIC: Found an object with no device demands. This is technically possible and is easily supported by just picking an arbitrary device for this object. This assert exists because I'm curious to see where this will be needed first, and if that use is frivolous or not.");
+                Node::Write {
+                    collect,
+                    data,
+                    indices: _,
+                } if user == collect && editor.get_type(typing[data.idx()]).is_primitive() => {
+                    let device = backing_device(devices[func_id.idx()]);
+                    demands.colors.insert(device);
+                }
+                Node::Call {
+                    control: _,
+                    function: callee,
+                    dynamic_constants: _,
+                    ref args,
+                } => {
+                    for (idx, arg) in args.into_iter().enumerate() {
+                        if *arg == user {
+                            let device = node_colors[&callee].1[idx];
+                            demands.colors.insert(device);
+                        }
+                    }
+                }
+                _ if !editor.get_type(typing[id.idx()]).is_primitive() => {
+                    for (pos, other_user) in def_use.get_users(id).into_iter().enumerate() {
+                        if user == *other_user {
+                            demands = Colors::meet(&demands, inputs[pos]);
+                        }
+                    }
+                }
+                _ => {}
+            };
         }
-    }
-    if bad_node.is_some() {
-        todo!("Deal with inter-device demands.")
-    }
+        demands
+    });
 
-    Some(colors)
+    // Second, do a forward dataflow analysis to figure out which device each
+    // node is "produced" on - these are the devices the collection will be put
+    // on by uses.
+
+    // Third, reconcile nodes that are produced on or demanded on multiple
+    // devices. There are the following cases (checked in this order):
+    // 1. This is a phi, reduce, or select node where there are multiple
+    //    producer devices, but each individual input has at most one input
+    //    device. Pick a demanded device (or if there is none, just the CPU) and
+    //    add copies on the inputs that don't produce on that device.
+    // 2. This is a node with multiple demanded devices, but a single produced
+    //    device. Replace uses of this node with copies to other devices.
+    todo!()
 }
 
 fn align(edit: &mut FunctionEdit, mut acc: DynamicConstantID, align: usize) -> DynamicConstantID {
@@ -1282,7 +1333,7 @@ fn object_allocation(
             match *edit.get_node(id) {
                 Node::Constant { id: _ } => {
                     if !edit.get_type(typing[id.idx()]).is_primitive() {
-                        let device = node_colors[&id];
+                        let device = node_colors.0[&id];
                         let (total, offsets) =
                             fba.entry(device).or_insert_with(|| (zero, BTreeMap::new()));
                         *total = align(&mut edit, *total, alignments[typing[id.idx()].idx()]);
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 5f445217..f7dd102d 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -188,7 +188,6 @@ pub struct PassManager {
     pub collection_objects: Option<CollectionObjects>,
     pub callgraph: Option<CallGraph>,
     pub devices: Option<Vec<Device>>,
-    pub object_device_demands: Option<ObjectDeviceDemands>,
     pub bbs: Option<Vec<BasicBlocks>>,
     pub node_colors: Option<NodeColors>,
     pub backing_allocations: Option<BackingAllocations>,
@@ -227,7 +226,6 @@ impl PassManager {
             collection_objects: None,
             callgraph: None,
             devices: None,
-            object_device_demands: None,
             bbs: None,
             node_colors: None,
             backing_allocations: None,
@@ -539,27 +537,6 @@ impl PassManager {
         }
     }
 
-    pub fn make_object_device_demands(&mut self) {
-        if self.object_device_demands.is_none() {
-            self.make_typing();
-            self.make_callgraph();
-            self.make_collection_objects();
-            self.make_devices();
-            let typing = self.typing.as_ref().unwrap();
-            let callgraph = self.callgraph.as_ref().unwrap();
-            let collection_objects = self.collection_objects.as_ref().unwrap();
-            let devices = self.devices.as_ref().unwrap();
-            self.object_device_demands = Some(object_device_demands(
-                &self.functions,
-                &self.types.borrow(),
-                typing,
-                callgraph,
-                collection_objects,
-                devices,
-            ));
-        }
-    }
-
     pub fn delete_gravestones(&mut self) {
         for func in self.functions.iter_mut() {
             func.delete_gravestones();
@@ -585,7 +562,6 @@ impl PassManager {
         self.collection_objects = None;
         self.callgraph = None;
         self.devices = None;
-        self.object_device_demands = None;
         self.bbs = None;
         self.node_colors = None;
         self.backing_allocations = None;
@@ -748,7 +724,7 @@ impl PassManager {
                     &callgraph,
                     &devices,
                     &bbs[idx],
-                    &node_colors[idx],
+                    &node_colors[&FunctionID::new(idx)],
                     &backing_allocations,
                     &mut rust_rt,
                 )
@@ -1702,7 +1678,6 @@ fn run_pass(
                 pm.make_reduce_cycles();
                 pm.make_collection_objects();
                 pm.make_devices();
-                pm.make_object_device_demands();
 
                 let def_uses = pm.def_uses.take().unwrap();
                 let reverse_postorders = pm.reverse_postorders.take().unwrap();
@@ -1714,10 +1689,9 @@ fn run_pass(
                 let control_subgraphs = pm.control_subgraphs.take().unwrap();
                 let collection_objects = pm.collection_objects.take().unwrap();
                 let devices = pm.devices.take().unwrap();
-                let object_device_demands = pm.object_device_demands.take().unwrap();
 
                 let mut bbs = vec![(vec![], vec![]); topo.len()];
-                let mut node_colors = vec![BTreeMap::new(); topo.len()];
+                let mut node_colors = BTreeMap::new();
                 let mut backing_allocations = BTreeMap::new();
                 let mut editors = build_editors(pm);
                 let mut any_failed = false;
@@ -1735,11 +1709,11 @@ fn run_pass(
                         &reduce_cycles[id.idx()],
                         &collection_objects,
                         &devices,
-                        &object_device_demands[id.idx()],
+                        &node_colors,
                         &backing_allocations,
                     ) {
                         bbs[id.idx()] = bb;
-                        node_colors[id.idx()] = function_node_colors;
+                        node_colors.insert(*id, function_node_colors);
                         backing_allocations.insert(*id, backing_allocation);
                     } else {
                         any_failed = true;
-- 
GitLab


From ba414f26a4851ae537dfdc873158514bbf07a3e1 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Fri, 7 Feb 2025 20:35:03 -0600
Subject: [PATCH 03/13] produces analysis for node coloring

---
 hercules_opt/src/gcm.rs | 40 ++++++++++++++++++++++++++++++++++------
 1 file changed, 34 insertions(+), 6 deletions(-)

diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index a4f76b75..ede71c7c 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -1177,15 +1177,13 @@ fn color_nodes(
     let nodes = &editor.func().nodes;
     let func_id = editor.func_id();
     let demands = backward_dataflow(editor.func(), def_use, reverse_postorder, |inputs, id| {
-        let mut demands = Colors {
-            colors: BTreeSet::new(),
-        };
+        let mut demands = Colors::top();
         for user in editor.get_users(id) {
             match nodes[user.idx()] {
                 Node::Read {
                     collect,
                     indices: _,
-                } if user == collect && editor.get_type(typing[user.idx()]).is_primitive() => {
+                } if id == collect && editor.get_type(typing[user.idx()]).is_primitive() => {
                     let device = backing_device(devices[func_id.idx()]);
                     demands.colors.insert(device);
                 }
@@ -1193,7 +1191,7 @@ fn color_nodes(
                     collect,
                     data,
                     indices: _,
-                } if user == collect && editor.get_type(typing[data.idx()]).is_primitive() => {
+                } if id == collect && editor.get_type(typing[data.idx()]).is_primitive() => {
                     let device = backing_device(devices[func_id.idx()]);
                     demands.colors.insert(device);
                 }
@@ -1210,7 +1208,9 @@ fn color_nodes(
                         }
                     }
                 }
-                _ if !editor.get_type(typing[id.idx()]).is_primitive() => {
+                _ if !editor.get_type(typing[id.idx()]).is_primitive()
+                    && !editor.get_type(typing[id.idx()]).is_control() =>
+                {
                     for (pos, other_user) in def_use.get_users(id).into_iter().enumerate() {
                         if user == *other_user {
                             demands = Colors::meet(&demands, inputs[pos]);
@@ -1226,6 +1226,33 @@ fn color_nodes(
     // Second, do a forward dataflow analysis to figure out which device each
     // node is "produced" on - these are the devices the collection will be put
     // on by uses.
+    let produces = forward_dataflow(editor.func(), reverse_postorder, |inputs, id| {
+        match nodes[id.idx()] {
+            Node::Constant { id: _ } | Node::Undef { ty: _ }
+                if !editor.get_type(typing[id.idx()]).is_primitive() =>
+            {
+                Colors {
+                    colors: once(devices[func_id.idx()]).collect(),
+                }
+            }
+            Node::Call {
+                control: _,
+                function: callee,
+                dynamic_constants: _,
+                args: _,
+            } if !editor.get_type(typing[id.idx()]).is_primitive() => Colors {
+                colors: once(node_colors[&callee].2).collect(),
+            },
+            _ if !editor.get_type(typing[id.idx()]).is_primitive()
+                && !editor.get_type(typing[id.idx()]).is_control() =>
+            {
+                inputs
+                    .into_iter()
+                    .fold(Colors::top(), |acc, input| Colors::meet(&acc, input))
+            }
+            _ => Colors::top(),
+        }
+    });
 
     // Third, reconcile nodes that are produced on or demanded on multiple
     // devices. There are the following cases (checked in this order):
@@ -1235,6 +1262,7 @@ fn color_nodes(
     //    add copies on the inputs that don't produce on that device.
     // 2. This is a node with multiple demanded devices, but a single produced
     //    device. Replace uses of this node with copies to other devices.
+    println!("{}:\n{:#?}\n{:#?}", editor.func().name, demands, produces);
     todo!()
 }
 
-- 
GitLab


From 443085ca8760adf3a632b0201f4fac2e992023df Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 8 Feb 2025 11:48:46 -0600
Subject: [PATCH 04/13] why didn't i think of this before

---
 hercules_cg/src/lib.rs  |   6 +-
 hercules_cg/src/rt.rs   |   6 +-
 hercules_opt/src/gcm.rs | 209 +++++++++++++++++++++-------------------
 3 files changed, 117 insertions(+), 104 deletions(-)

diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs
index 52bf6177..217f5879 100644
--- a/hercules_cg/src/lib.rs
+++ b/hercules_cg/src/lib.rs
@@ -42,7 +42,11 @@ pub fn get_type_alignment(types: &Vec<Type>, ty: TypeID) -> usize {
  * underlying memory lives on. Also explicitly store the device of the
  * parameters and return of each function.
  */
-pub type FunctionNodeColors = (BTreeMap<NodeID, Device>, Vec<Device>, Device);
+pub type FunctionNodeColors = (
+    BTreeMap<NodeID, Device>,
+    Vec<Option<Device>>,
+    Option<Device>,
+);
 pub type NodeColors = BTreeMap<FunctionID, FunctionNodeColors>;
 
 /*
diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 19db630d..e7fb1b60 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -994,7 +994,7 @@ impl<'a> RTContext<'a> {
             if self.module.types[func.param_types[idx].idx()].is_primitive() {
                 write!(w, ", p{}: {}", idx, self.get_type(func.param_types[idx]))?;
             } else {
-                let device = match param_devices[idx] {
+                let device = match param_devices[idx].unwrap() {
                     Device::LLVM => "CPU",
                     Device::CUDA => "CUDA",
                     _ => panic!(),
@@ -1010,7 +1010,7 @@ impl<'a> RTContext<'a> {
         if self.module.types[func.return_type.idx()].is_primitive() {
             write!(w, ") -> {} {{", self.get_type(func.return_type))?;
         } else {
-            let device = match return_device {
+            let device = match return_device.unwrap() {
                 Device::LLVM => "CPU",
                 Device::CUDA => "CUDA",
                 _ => panic!(),
@@ -1075,7 +1075,7 @@ impl<'a> RTContext<'a> {
         if self.module.types[func.return_type.idx()].is_primitive() {
             write!(w, "        ret")?;
         } else {
-            let device = match return_device {
+            let device = match return_device.unwrap() {
                 Device::LLVM => "CPU",
                 Device::CUDA => "CUDA",
                 _ => panic!(),
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index ede71c7c..8ec61ab2 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -1134,28 +1134,13 @@ fn liveness_dataflow(
 }
 
 #[derive(Debug, Clone, PartialEq, Eq)]
-struct Colors {
-    colors: BTreeSet<Device>,
+enum UnificationTerm {
+    Node(NodeID),
+    Device(Device),
 }
 
-impl Semilattice for Colors {
-    fn meet(a: &Self, b: &Self) -> Self {
-        Colors {
-            colors: a.colors.union(&b.colors).map(|device| *device).collect(),
-        }
-    }
-
-    fn bottom() -> Self {
-        Colors {
-            colors: BACKED_DEVICES.into_iter().collect(),
-        }
-    }
-
-    fn top() -> Self {
-        Colors {
-            colors: BTreeSet::new(),
-        }
-    }
+fn unify(equations: &Vec<(UnificationTerm, UnificationTerm)>) -> BTreeMap<NodeID, Device> {
+    todo!()
 }
 
 /*
@@ -1171,99 +1156,123 @@ fn color_nodes(
     devices: &Vec<Device>,
     node_colors: &NodeColors,
 ) -> Option<FunctionNodeColors> {
-    // First, do a backward dataflow analysis to figure out which device each
-    // node is "demanded" on - these are the devices the collection will be used
-    // on by users.
     let nodes = &editor.func().nodes;
     let func_id = editor.func_id();
-    let demands = backward_dataflow(editor.func(), def_use, reverse_postorder, |inputs, id| {
-        let mut demands = Colors::top();
-        for user in editor.get_users(id) {
-            match nodes[user.idx()] {
-                Node::Read {
-                    collect,
-                    indices: _,
-                } if id == collect && editor.get_type(typing[user.idx()]).is_primitive() => {
-                    let device = backing_device(devices[func_id.idx()]);
-                    demands.colors.insert(device);
-                }
-                Node::Write {
-                    collect,
-                    data,
-                    indices: _,
-                } if id == collect && editor.get_type(typing[data.idx()]).is_primitive() => {
-                    let device = backing_device(devices[func_id.idx()]);
-                    demands.colors.insert(device);
-                }
-                Node::Call {
-                    control: _,
-                    function: callee,
-                    dynamic_constants: _,
-                    ref args,
-                } => {
-                    for (idx, arg) in args.into_iter().enumerate() {
-                        if *arg == user {
-                            let device = node_colors[&callee].1[idx];
-                            demands.colors.insert(device);
-                        }
-                    }
-                }
-                _ if !editor.get_type(typing[id.idx()]).is_primitive()
-                    && !editor.get_type(typing[id.idx()]).is_control() =>
-                {
-                    for (pos, other_user) in def_use.get_users(id).into_iter().enumerate() {
-                        if user == *other_user {
-                            demands = Colors::meet(&demands, inputs[pos]);
-                        }
-                    }
-                }
-                _ => {}
-            };
-        }
-        demands
-    });
+    let func_device = devices[func_id.idx()];
+    let mut func_colors = (
+        BTreeMap::new(),
+        vec![None; editor.func().param_types.len()],
+        None,
+    );
 
-    // Second, do a forward dataflow analysis to figure out which device each
-    // node is "produced" on - these are the devices the collection will be put
-    // on by uses.
-    let produces = forward_dataflow(editor.func(), reverse_postorder, |inputs, id| {
+    // Assigning nodes to devices is tricky due to function calls. Technically,
+    // all the information is there to decide what device to place nodes on, but
+    // coherently expressing the constraints and deriving the devices is not
+    // obvious. Express this as a unification problem, where we need to assign
+    // types (devices) to each node. Each unification term is either a node ID
+    // or a concrete device. Assemble a list of unification equations to solve.
+    let mut equations = vec![];
+    for id in editor.node_ids() {
         match nodes[id.idx()] {
-            Node::Constant { id: _ } | Node::Undef { ty: _ }
-                if !editor.get_type(typing[id.idx()]).is_primitive() =>
+            Node::Phi {
+                control: _,
+                ref data,
+            } if !editor.get_type(typing[id.idx()]).is_primitive() => {
+                // Every input to a phi needs to be on the same device. The
+                // phi itself is also on this device.
+                for (l, r) in zip(data.into_iter(), data.into_iter().skip(1).chain(once(&id))) {
+                    equations.push((UnificationTerm::Node(*l), UnificationTerm::Node(*r)));
+                }
+            }
+            Node::Reduce {
+                control: _,
+                init: first,
+                reduct: second,
+            }
+            | Node::Ternary {
+                op: TernaryOperator::Select,
+                first: _,
+                second: first,
+                third: second,
+            } if !editor.get_type(typing[id.idx()]).is_primitive() => {
+                // Every input to the reduce, and the reduce itself, are on
+                // the same device.
+                equations.push((UnificationTerm::Node(first), UnificationTerm::Node(second)));
+                equations.push((UnificationTerm::Node(second), UnificationTerm::Node(id)));
+            }
+            Node::Constant { id: _ }
+                if !editor.get_type(typing[id.idx()]).is_primitive()
+                    && func_device != Device::AsyncRust =>
             {
-                Colors {
-                    colors: once(devices[func_id.idx()]).collect(),
+                // Constants inside device functions are allocated on that
+                // device.
+                equations.push((
+                    UnificationTerm::Node(id),
+                    UnificationTerm::Device(func_device),
+                ));
+            }
+            Node::Read {
+                collect,
+                indices: _,
+            } => {
+                if editor.get_type(typing[id.idx()]).is_primitive() {
+                    // If this reads a primitive, then the collection needs to
+                    // be on the device of this function.
+                    equations.push((
+                        UnificationTerm::Node(collect),
+                        UnificationTerm::Device(backing_device(func_device)),
+                    ));
+                } else {
+                    // If this read just reads a sub-collection, then `collect`
+                    // and the read itself need to be on the same device.
+                    equations.push((UnificationTerm::Node(collect), UnificationTerm::Node(id)));
+                }
+            }
+            Node::Write {
+                collect,
+                data,
+                indices: _,
+            } => {
+                if editor.get_type(typing[data.idx()]).is_primitive() {
+                    // If this writes a primitive, then the collection needs to
+                    // be on the device of this function. Since we can do inter-
+                    // device copies, no constraint is needed with respect to
+                    // the device of `data`.
+                    equations.push((
+                        UnificationTerm::Node(collect),
+                        UnificationTerm::Device(backing_device(func_device)),
+                    ));
                 }
+                equations.push((UnificationTerm::Node(collect), UnificationTerm::Node(id)));
             }
             Node::Call {
                 control: _,
                 function: callee,
                 dynamic_constants: _,
-                args: _,
-            } if !editor.get_type(typing[id.idx()]).is_primitive() => Colors {
-                colors: once(node_colors[&callee].2).collect(),
-            },
-            _ if !editor.get_type(typing[id.idx()]).is_primitive()
-                && !editor.get_type(typing[id.idx()]).is_control() =>
-            {
-                inputs
-                    .into_iter()
-                    .fold(Colors::top(), |acc, input| Colors::meet(&acc, input))
-            }
-            _ => Colors::top(),
+                ref args,
+            } => {}
+            _ => {}
         }
-    });
+    }
 
-    // Third, reconcile nodes that are produced on or demanded on multiple
-    // devices. There are the following cases (checked in this order):
-    // 1. This is a phi, reduce, or select node where there are multiple
-    //    producer devices, but each individual input has at most one input
-    //    device. Pick a demanded device (or if there is none, just the CPU) and
-    //    add copies on the inputs that don't produce on that device.
-    // 2. This is a node with multiple demanded devices, but a single produced
-    //    device. Replace uses of this node with copies to other devices.
-    println!("{}:\n{:#?}\n{:#?}", editor.func().name, demands, produces);
-    todo!()
+    // Solve the unification problem. I couldn't find a simple enough crate for
+    // this, and the problems are usually pretty small, so just use a hand-
+    // rolled implementation for now.
+    println!("{:?}", equations);
+    let solve = unify(&equations);
+    func_colors.0 = solve;
+    for id in editor.node_ids() {
+        if let Node::Parameter { index } = nodes[id.idx()]
+            && let Some(device) = func_colors.0.get(&id)
+        {
+            func_colors.1[index] = Some(*device);
+        } else if let Node::Return { control: _, data } = nodes[id.idx()]
+            && let Some(device) = func_colors.0.get(&data)
+        {
+            func_colors.2 = Some(*device);
+        }
+    }
+    Some(func_colors)
 }
 
 fn align(edit: &mut FunctionEdit, mut acc: DynamicConstantID, align: usize) -> DynamicConstantID {
-- 
GitLab


From d28b293aee44ac3c5db1b9d7b7cd65fbaee29755 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 8 Feb 2025 14:46:59 -0600
Subject: [PATCH 05/13] do call nodes

---
 hercules_opt/src/gcm.rs | 31 ++++++++++++++++++++++++++++---
 1 file changed, 28 insertions(+), 3 deletions(-)

diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index 8ec61ab2..fc60d0c8 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -124,7 +124,7 @@ pub fn gcm(
         def_use,
         reverse_postorder,
         typing,
-        &objects[&func_id],
+        &objects,
         &devices,
         node_colors,
     ) else {
@@ -1152,7 +1152,7 @@ fn color_nodes(
     def_use: &ImmutableDefUseMap,
     reverse_postorder: &Vec<NodeID>,
     typing: &Vec<TypeID>,
-    objects: &FunctionCollectionObjects,
+    objects: &CollectionObjects,
     devices: &Vec<Device>,
     node_colors: &NodeColors,
 ) -> Option<FunctionNodeColors> {
@@ -1250,7 +1250,32 @@ fn color_nodes(
                 function: callee,
                 dynamic_constants: _,
                 ref args,
-            } => {}
+            } => {
+                // If the callee has a definite device for a parameter, add an
+                // equation for the corresponding argument.
+                for (idx, arg) in args.into_iter().enumerate() {
+                    if let Some(device) = node_colors[&callee].1[idx] {
+                        equations
+                            .push((UnificationTerm::Node(*arg), UnificationTerm::Device(device)));
+                    }
+                }
+
+                // If the callee has a definite device for the returned value,
+                // add an equation for the call node itself.
+                if let Some(device) = node_colors[&callee].2 {
+                    equations.push((UnificationTerm::Node(id), UnificationTerm::Device(device)));
+                }
+
+                // For any object that may be returned by the callee that
+                // originates as a parameter in the callee, the device of the
+                // corresponding argument and call node itself must be equal.
+                for ret in objects[&callee].returned_objects() {
+                    if let Some(idx) = objects[&callee].origin(*ret).try_parameter() {
+                        equations
+                            .push((UnificationTerm::Node(args[idx]), UnificationTerm::Node(id)));
+                    }
+                }
+            }
             _ => {}
         }
     }
-- 
GitLab


From 75599c982d3e5bf7d5d74fce554152b85fd25b68 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 8 Feb 2025 17:57:26 -0600
Subject: [PATCH 06/13] Unification algo to assign devices

---
 hercules_opt/src/gcm.rs | 107 +++++++++++++++++++++++++++-------------
 1 file changed, 72 insertions(+), 35 deletions(-)

diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index fc60d0c8..d3075a04 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -1133,14 +1133,50 @@ fn liveness_dataflow(
     }
 }
 
-#[derive(Debug, Clone, PartialEq, Eq)]
-enum UnificationTerm {
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum UTerm {
     Node(NodeID),
     Device(Device),
 }
 
-fn unify(equations: &Vec<(UnificationTerm, UnificationTerm)>) -> BTreeMap<NodeID, Device> {
-    todo!()
+fn unify(
+    mut equations: VecDeque<(UTerm, UTerm)>,
+) -> Result<BTreeMap<NodeID, Device>, BTreeMap<NodeID, Device>> {
+    let mut theta = BTreeMap::new();
+
+    let mut no_progress_iters = 0;
+    while no_progress_iters <= equations.len()
+        && let Some((l, r)) = equations.pop_front()
+    {
+        match (l, r) {
+            (UTerm::Node(_), UTerm::Node(_)) => {
+                if l != r {
+                    equations.push_back((l, r));
+                }
+                no_progress_iters += 1;
+            }
+            (UTerm::Node(n), UTerm::Device(d)) | (UTerm::Device(d), UTerm::Node(n)) => {
+                theta.insert(n, d);
+                for (l, r) in equations.iter_mut() {
+                    if *l == UTerm::Node(n) {
+                        *l = UTerm::Device(d);
+                    }
+                    if *r == UTerm::Node(n) {
+                        *r = UTerm::Device(d);
+                    }
+                }
+                no_progress_iters = 0;
+            }
+            (UTerm::Device(d1), UTerm::Device(d2)) if d1 == d2 => {
+                no_progress_iters = 0;
+            }
+            _ => {
+                return Err(theta);
+            }
+        }
+    }
+
+    Ok(theta)
 }
 
 /*
@@ -1181,7 +1217,7 @@ fn color_nodes(
                 // Every input to a phi needs to be on the same device. The
                 // phi itself is also on this device.
                 for (l, r) in zip(data.into_iter(), data.into_iter().skip(1).chain(once(&id))) {
-                    equations.push((UnificationTerm::Node(*l), UnificationTerm::Node(*r)));
+                    equations.push((UTerm::Node(*l), UTerm::Node(*r)));
                 }
             }
             Node::Reduce {
@@ -1197,8 +1233,8 @@ fn color_nodes(
             } if !editor.get_type(typing[id.idx()]).is_primitive() => {
                 // Every input to the reduce, and the reduce itself, are on
                 // the same device.
-                equations.push((UnificationTerm::Node(first), UnificationTerm::Node(second)));
-                equations.push((UnificationTerm::Node(second), UnificationTerm::Node(id)));
+                equations.push((UTerm::Node(first), UTerm::Node(second)));
+                equations.push((UTerm::Node(second), UTerm::Node(id)));
             }
             Node::Constant { id: _ }
                 if !editor.get_type(typing[id.idx()]).is_primitive()
@@ -1206,10 +1242,7 @@ fn color_nodes(
             {
                 // Constants inside device functions are allocated on that
                 // device.
-                equations.push((
-                    UnificationTerm::Node(id),
-                    UnificationTerm::Device(func_device),
-                ));
+                equations.push((UTerm::Node(id), UTerm::Device(func_device)));
             }
             Node::Read {
                 collect,
@@ -1219,13 +1252,13 @@ fn color_nodes(
                     // If this reads a primitive, then the collection needs to
                     // be on the device of this function.
                     equations.push((
-                        UnificationTerm::Node(collect),
-                        UnificationTerm::Device(backing_device(func_device)),
+                        UTerm::Node(collect),
+                        UTerm::Device(backing_device(func_device)),
                     ));
                 } else {
                     // If this read just reads a sub-collection, then `collect`
                     // and the read itself need to be on the same device.
-                    equations.push((UnificationTerm::Node(collect), UnificationTerm::Node(id)));
+                    equations.push((UTerm::Node(collect), UTerm::Node(id)));
                 }
             }
             Node::Write {
@@ -1239,11 +1272,11 @@ fn color_nodes(
                     // device copies, no constraint is needed with respect to
                     // the device of `data`.
                     equations.push((
-                        UnificationTerm::Node(collect),
-                        UnificationTerm::Device(backing_device(func_device)),
+                        UTerm::Node(collect),
+                        UTerm::Device(backing_device(func_device)),
                     ));
                 }
-                equations.push((UnificationTerm::Node(collect), UnificationTerm::Node(id)));
+                equations.push((UTerm::Node(collect), UTerm::Node(id)));
             }
             Node::Call {
                 control: _,
@@ -1255,15 +1288,14 @@ fn color_nodes(
                 // equation for the corresponding argument.
                 for (idx, arg) in args.into_iter().enumerate() {
                     if let Some(device) = node_colors[&callee].1[idx] {
-                        equations
-                            .push((UnificationTerm::Node(*arg), UnificationTerm::Device(device)));
+                        equations.push((UTerm::Node(*arg), UTerm::Device(device)));
                     }
                 }
 
                 // If the callee has a definite device for the returned value,
                 // add an equation for the call node itself.
                 if let Some(device) = node_colors[&callee].2 {
-                    equations.push((UnificationTerm::Node(id), UnificationTerm::Device(device)));
+                    equations.push((UTerm::Node(id), UTerm::Device(device)));
                 }
 
                 // For any object that may be returned by the callee that
@@ -1271,8 +1303,7 @@ fn color_nodes(
                 // corresponding argument and call node itself must be equal.
                 for ret in objects[&callee].returned_objects() {
                     if let Some(idx) = objects[&callee].origin(*ret).try_parameter() {
-                        equations
-                            .push((UnificationTerm::Node(args[idx]), UnificationTerm::Node(id)));
+                        equations.push((UTerm::Node(args[idx]), UTerm::Node(id)));
                     }
                 }
             }
@@ -1283,21 +1314,27 @@ fn color_nodes(
     // Solve the unification problem. I couldn't find a simple enough crate for
     // this, and the problems are usually pretty small, so just use a hand-
     // rolled implementation for now.
-    println!("{:?}", equations);
-    let solve = unify(&equations);
-    func_colors.0 = solve;
-    for id in editor.node_ids() {
-        if let Node::Parameter { index } = nodes[id.idx()]
-            && let Some(device) = func_colors.0.get(&id)
-        {
-            func_colors.1[index] = Some(*device);
-        } else if let Node::Return { control: _, data } = nodes[id.idx()]
-            && let Some(device) = func_colors.0.get(&data)
-        {
-            func_colors.2 = Some(*device);
+    match unify(VecDeque::from(equations.clone())) {
+        Ok(solve) => {
+            func_colors.0 = solve;
+            for id in editor.node_ids() {
+                if let Node::Parameter { index } = nodes[id.idx()]
+                    && let Some(device) = func_colors.0.get(&id)
+                {
+                    func_colors.1[index] = Some(*device);
+                } else if let Node::Return { control: _, data } = nodes[id.idx()]
+                    && let Some(device) = func_colors.0.get(&data)
+                {
+                    func_colors.2 = Some(*device);
+                }
+            }
+            Some(func_colors)
+        }
+        Err(progress) => {
+            println!("{:?}", progress);
+            todo!()
         }
     }
-    Some(func_colors)
 }
 
 fn align(edit: &mut FunctionEdit, mut acc: DynamicConstantID, align: usize) -> DynamicConstantID {
-- 
GitLab


From f1abcf98b2e5670e15905d782acfef62af27d4c7 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 8 Feb 2025 18:15:06 -0600
Subject: [PATCH 07/13] Add in inter-device copy

---
 hercules_opt/src/gcm.rs                       | 30 +++++++++++++++++--
 .../multi_device/src/multi_device.sch         |  1 +
 2 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index d3075a04..a7f2da44 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -1314,25 +1314,49 @@ fn color_nodes(
     // Solve the unification problem. I couldn't find a simple enough crate for
     // this, and the problems are usually pretty small, so just use a hand-
     // rolled implementation for now.
-    match unify(VecDeque::from(equations.clone())) {
+    match unify(VecDeque::from(equations)) {
         Ok(solve) => {
             func_colors.0 = solve;
+            // Look at parameter and return nodes to get the device signature of
+            // the function.
             for id in editor.node_ids() {
                 if let Node::Parameter { index } = nodes[id.idx()]
                     && let Some(device) = func_colors.0.get(&id)
                 {
+                    assert!(func_colors.1[index].is_none(), "PANIC: Found multiple parameter nodes for the same index in GCM. Please just run GVN first.");
                     func_colors.1[index] = Some(*device);
                 } else if let Node::Return { control: _, data } = nodes[id.idx()]
                     && let Some(device) = func_colors.0.get(&data)
                 {
+                    assert!(func_colors.2.is_none(), "PANIC: Found multiple return nodes in GCM. Contact Russel if you see this, it's an easy fix.");
                     func_colors.2 = Some(*device);
                 }
             }
             Some(func_colors)
         }
         Err(progress) => {
-            println!("{:?}", progress);
-            todo!()
+            // If unification failed, then there's some node using a node in
+            // `progress` that's expecting a different type than what it got.
+            // Pick one and add potentially inter-device copies on each def-use
+            // edge. We'll clean these up later.
+            let (id, _) = progress.into_iter().next().unwrap();
+            let users: Vec<_> = editor.get_users(id).collect();
+            let success = editor.edit(|mut edit| {
+                let cons = edit.add_zero_constant(typing[id.idx()]);
+                for user in users {
+                    let cons = edit.add_node(Node::Constant { id: cons });
+                    edit = edit.add_schedule(cons, Schedule::NoResetConstant)?;
+                    let copy = edit.add_node(Node::Write {
+                        collect: cons,
+                        data: id,
+                        indices: Box::new([]),
+                    });
+                    edit = edit.replace_all_uses_where(id, copy, |id| *id == user)?;
+                }
+                Ok(edit)
+            });
+            assert!(success);
+            None
         }
     }
 }
diff --git a/juno_samples/multi_device/src/multi_device.sch b/juno_samples/multi_device/src/multi_device.sch
index e5029a10..12ca53e4 100644
--- a/juno_samples/multi_device/src/multi_device.sch
+++ b/juno_samples/multi_device/src/multi_device.sch
@@ -30,3 +30,4 @@ infer-schedules(*);
 xdot[true](*);
 
 gcm(*);
+xdot[true](*);
-- 
GitLab


From da4d44729f9000c9136cc5cc2d62da83f43833cf Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 8 Feb 2025 18:29:29 -0600
Subject: [PATCH 08/13] it works

---
 hercules_cg/src/rt.rs                         | 29 ++++++++++++++++++-
 hercules_opt/src/gcm.rs                       | 12 +-------
 .../multi_device/src/multi_device.sch         |  5 ++--
 3 files changed, 32 insertions(+), 14 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index e7fb1b60..711496bf 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -619,7 +619,34 @@ impl<'a> RTContext<'a> {
                 let collect_ty = self.typing[collect.idx()];
                 let data_size = self.codegen_type_size(self.typing[data.idx()]);
                 let offset = self.codegen_index_math(collect_ty, indices, bb)?;
-                todo!();
+                let data_ty = self.typing[data.idx()];
+                if self.module.types[data_ty.idx()].is_primitive() {
+                    todo!();
+                } else {
+                    // If the data item being written is not a primitive type,
+                    // then perform a memcpy from the data collection to the
+                    // destination collection. Look at the colors of the
+                    // `collect` and `data` inputs, since this may be an inter-
+                    // device copy.
+                    let src_device = self.node_colors.0[&data];
+                    let dst_device = self.node_colors.0[&collect];
+                    write!(
+                        block,
+                        "::hercules_rt::__copy_{}_to_{}({}.byte_add({} as usize).0, {}.0, {});",
+                        src_device.name(),
+                        dst_device.name(),
+                        self.get_value(collect, bb),
+                        offset,
+                        self.get_value(data, bb),
+                        data_size,
+                    )?;
+                }
+                write!(
+                    block,
+                    "{} = {};",
+                    self.get_value(id, bb),
+                    self.get_value(collect, bb)
+                )?;
             }
             _ => panic!(
                 "PANIC: Can't lower {:?} in {}.",
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index a7f2da44..3f71568f 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -119,15 +119,7 @@ pub fn gcm(
     }
 
     let func_id = editor.func_id();
-    let Some(node_colors) = color_nodes(
-        editor,
-        def_use,
-        reverse_postorder,
-        typing,
-        &objects,
-        &devices,
-        node_colors,
-    ) else {
+    let Some(node_colors) = color_nodes(editor, typing, &objects, &devices, node_colors) else {
         return None;
     };
 
@@ -1185,8 +1177,6 @@ fn unify(
  */
 fn color_nodes(
     editor: &mut FunctionEditor,
-    def_use: &ImmutableDefUseMap,
-    reverse_postorder: &Vec<NodeID>,
     typing: &Vec<TypeID>,
     objects: &CollectionObjects,
     devices: &Vec<Device>,
diff --git a/juno_samples/multi_device/src/multi_device.sch b/juno_samples/multi_device/src/multi_device.sch
index 12ca53e4..e7f9cae7 100644
--- a/juno_samples/multi_device/src/multi_device.sch
+++ b/juno_samples/multi_device/src/multi_device.sch
@@ -22,12 +22,13 @@ let l1 = outline(multi_device_1@loop1);
 let l2 = outline(multi_device_1@loop2);
 gpu(l1);
 cpu(l2);
+unforkify(l2);
 ip-sroa(*);
 sroa(*);
+ccp(*);
+gvn(*);
 dce(*);
 
 infer-schedules(*);
-xdot[true](*);
 
 gcm(*);
-xdot[true](*);
-- 
GitLab


From 9ff5f47ddc6a2dcf1945bde4f38674d73fbb4d36 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 8 Feb 2025 20:16:04 -0600
Subject: [PATCH 09/13] fix

---
 hercules_cg/src/rt.rs | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 711496bf..23227482 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -968,8 +968,8 @@ impl<'a> RTContext<'a> {
         // they are collections and whether they should be immutable or mutable
         // references.
         let func = self.get_func();
-        let mut param_devices = &self.node_colors.1;
-        let mut return_device = self.node_colors.2;
+        let param_devices = &self.node_colors.1;
+        let return_device = self.node_colors.2;
         let mut param_muts = vec![false; func.param_types.len()];
         let mut return_mut = true;
         let objects = &self.collection_objects[&self.func_id];
@@ -1021,9 +1021,9 @@ impl<'a> RTContext<'a> {
             if self.module.types[func.param_types[idx].idx()].is_primitive() {
                 write!(w, ", p{}: {}", idx, self.get_type(func.param_types[idx]))?;
             } else {
-                let device = match param_devices[idx].unwrap() {
-                    Device::LLVM => "CPU",
-                    Device::CUDA => "CUDA",
+                let device = match param_devices[idx] {
+                    Some(Device::LLVM) | None => "CPU",
+                    Some(Device::CUDA) => "CUDA",
                     _ => panic!(),
                 };
                 let mutability = if param_muts[idx] { "Mut" } else { "" };
@@ -1037,9 +1037,9 @@ impl<'a> RTContext<'a> {
         if self.module.types[func.return_type.idx()].is_primitive() {
             write!(w, ") -> {} {{", self.get_type(func.return_type))?;
         } else {
-            let device = match return_device.unwrap() {
-                Device::LLVM => "CPU",
-                Device::CUDA => "CUDA",
+            let device = match return_device {
+                Some(Device::LLVM) | None => "CPU",
+                Some(Device::CUDA) => "CUDA",
                 _ => panic!(),
             };
             let mutability = if return_mut { "Mut" } else { "" };
@@ -1102,9 +1102,9 @@ impl<'a> RTContext<'a> {
         if self.module.types[func.return_type.idx()].is_primitive() {
             write!(w, "        ret")?;
         } else {
-            let device = match return_device.unwrap() {
-                Device::LLVM => "CPU",
-                Device::CUDA => "CUDA",
+            let device = match return_device {
+                Some(Device::LLVM) | None => "CPU",
+                Some(Device::CUDA) => "CUDA",
                 _ => panic!(),
             };
             let mutability = if return_mut { "Mut" } else { "" };
-- 
GitLab


From 780a07678ebb0244c449345dd0ce2cf6f24e80c8 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 8 Feb 2025 20:27:19 -0600
Subject: [PATCH 10/13] fix write constraints

---
 hercules_opt/src/gcm.rs | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index 3f71568f..55888e6a 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -118,7 +118,6 @@ pub fn gcm(
         return None;
     }
 
-    let func_id = editor.func_id();
     let Some(node_colors) = color_nodes(editor, typing, &objects, &devices, node_colors) else {
         return None;
     };
@@ -1256,15 +1255,22 @@ fn color_nodes(
                 data,
                 indices: _,
             } => {
-                if editor.get_type(typing[data.idx()]).is_primitive() {
-                    // If this writes a primitive, then the collection needs to
-                    // be on the device of this function. Since we can do inter-
-                    // device copies, no constraint is needed with respect to
-                    // the device of `data`.
+                if func_device != Device::AsyncRust
+                    || editor.get_type(typing[data.idx()]).is_primitive()
+                {
+                    // If this writes a primitive or this is in a device
+                    // function, then the collection needs to be on the backing
+                    // device of this function.
                     equations.push((
                         UTerm::Node(collect),
                         UTerm::Device(backing_device(func_device)),
                     ));
+
+                    if func_device != Device::AsyncRust {
+                        // We can only do inter-device copies in AsyncRust
+                        // functions.
+                        equations.push((UTerm::Node(collect), UTerm::Node(data)));
+                    }
                 }
                 equations.push((UTerm::Node(collect), UTerm::Node(id)));
             }
-- 
GitLab


From 5c44ea7d3549793248d69f3741f821e4b95f654a Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 8 Feb 2025 20:29:09 -0600
Subject: [PATCH 11/13] Slots can be noresetconstant

---
 hercules_opt/src/gcm.rs | 1 +
 1 file changed, 1 insertion(+)

diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index 55888e6a..16d300ba 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -862,6 +862,7 @@ fn spill_clones(
             // ed throughout the entire function.
             let cons_id = edit.add_zero_constant(typing[obj.idx()]);
             let slot_id = edit.add_node(Node::Constant { id: cons_id });
+            edit = edit.add_schedule(slot_id, Schedule::NoResetConstant)?;
 
             // Allocate IDs for phis that move the spill slot throughout the
             // function without implicit clones. These are dummy phis, since
-- 
GitLab


From b7f3dc8b782277a48a84c1eb5e599c8610533596 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 8 Feb 2025 20:40:31 -0600
Subject: [PATCH 12/13] .

---
 .gitlab-ci.yml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 4a2a883a..af867eb8 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -1,9 +1,9 @@
 test-cpu:
   stage: test
   script:
-    - cargo test --features=opencv
+    - cargo test --features=opencv -vv
 
 test-gpu:
   stage: test
   script:
-    - cargo test --features=cuda,opencv
+    - cargo test --features=cuda,opencv -vv
-- 
GitLab


From 532f5392721f3557cc9352866f1c300646711c7e Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Sat, 8 Feb 2025 20:57:46 -0600
Subject: [PATCH 13/13] fix

---
 hercules_opt/src/gcm.rs | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index 16d300ba..79bd2851 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -1267,7 +1267,9 @@ fn color_nodes(
                         UTerm::Device(backing_device(func_device)),
                     ));
 
-                    if func_device != Device::AsyncRust {
+                    if func_device != Device::AsyncRust
+                        && !editor.get_type(typing[data.idx()]).is_primitive()
+                    {
                         // We can only do inter-device copies in AsyncRust
                         // functions.
                         equations.push((UTerm::Node(collect), UTerm::Node(data)));
-- 
GitLab