From 8f806736e67dd0661578464c681ed76aa60a4d99 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Mon, 17 Feb 2025 21:14:47 -0600
Subject: [PATCH] Add all the stuff related to collections for LibraryCall
 nodes

---
 hercules_ir/src/collections.rs  |  32 ++++++++-
 hercules_ir/src/device.rs       | 115 --------------------------------
 hercules_opt/src/gcm.rs         |  37 ++++++++++
 juno_samples/matmul/src/gpu.sch |   2 +-
 4 files changed, 69 insertions(+), 117 deletions(-)

diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs
index f3474ae0..6b631519 100644
--- a/hercules_ir/src/collections.rs
+++ b/hercules_ir/src/collections.rs
@@ -218,6 +218,8 @@ pub fn collection_objects(
         // - Constant: may originate an object.
         // - Call: may originate an object and may return an object passed in as
         //   a parameter.
+        // - LibraryCall: may return an object passed in as a parameter, but may
+        //   not originate an object.
         // - Read: may extract a smaller object from the input - this is
         //   considered to be the same object as the input, as no copy takes
         //   place.
@@ -288,6 +290,14 @@ pub fn collection_objects(
                     }
                     CollectionObjectLattice { objs }
                 }
+                Node::LibraryCall {
+                    library_function,
+                    args: _,
+                    ty: _,
+                    device: _,
+                } => match library_function {
+                    LibraryFunction::GEMM => inputs[0].clone(),
+                },
                 Node::Undef { ty: _ } => {
                     let obj = origins
                         .iter()
@@ -332,7 +342,13 @@ pub fn collection_objects(
                 for object in objects_per_node[idx].iter() {
                     mutated[object.idx()].push(NodeID::new(idx));
                 }
-            } else if let Some((_, callee, _, args)) = node.try_call() {
+            } else if let Node::Call {
+                control: _,
+                function: callee,
+                dynamic_constants: _,
+                args,
+            } = node
+            {
                 let fco = &collection_objects[&callee];
                 for (param_idx, arg) in args.into_iter().enumerate() {
                     // If this parameter corresponds to an object and it's
@@ -347,6 +363,20 @@ pub fn collection_objects(
                         }
                     }
                 }
+            } else if let Node::LibraryCall {
+                library_function,
+                args,
+                ty: _,
+                device: _,
+            } = node
+            {
+                match library_function {
+                    LibraryFunction::GEMM => {
+                        for object in objects_per_node[args[0].idx()].iter() {
+                            mutated[object.idx()].push(NodeID::new(idx));
+                        }
+                    }
+                }
             }
         }
 
diff --git a/hercules_ir/src/device.rs b/hercules_ir/src/device.rs
index cbf8d634..c4a5454b 100644
--- a/hercules_ir/src/device.rs
+++ b/hercules_ir/src/device.rs
@@ -23,118 +23,3 @@ pub fn device_placement(functions: &Vec<Function>, callgraph: &CallGraph) -> Vec
 
     devices
 }
-
-pub type FunctionObjectDeviceDemands = Vec<BTreeSet<Device>>;
-pub type ObjectDeviceDemands = Vec<FunctionObjectDeviceDemands>;
-
-/*
- * This analysis figures out which device each collection object may be on. At
- * first, an object may need to be on different devices at different times. This
- * is fine during optimization.
- */
-pub fn object_device_demands(
-    functions: &Vec<Function>,
-    types: &Vec<Type>,
-    typing: &ModuleTyping,
-    callgraph: &CallGraph,
-    objects: &CollectionObjects,
-    devices: &Vec<Device>,
-) -> ObjectDeviceDemands {
-    // An object is "demanded" on a device when:
-    // 1. The object is used by a primitive read node or write node in a device
-    //    function. This includes objects on the `data` input to write nodes.
-    //    Non-primitive reads don't demand an object on a device since they are
-    //    lowered to pointer math and no actual memory transfers.
-    // 2. The object is a constant / undef defined in a device function.
-    // 3. The object is passed as input to a call node where the corresponding
-    //    object in the callee is demanded on a device.
-    // 4. The object is returned from a call node where the corresponding object
-    //    in the callee is demanded on a device.
-    // Note that reads and writes in a RT function don't induce a device demand.
-    // This is because RT functions can  call device functions as necessary to
-    // arbitrarily move data onto / off of devices (though this may be slow).
-    // Traverse the functions in a module in reverse topological order, since
-    // the analysis of a function depends on all functions it calls.
-    let mut demands: ObjectDeviceDemands = vec![vec![]; functions.len()];
-    let topo = callgraph.topo();
-
-    for func_id in topo {
-        let function = &functions[func_id.idx()];
-        let typing = &typing[func_id.idx()];
-        let device = devices[func_id.idx()];
-
-        demands[func_id.idx()].resize(objects[&func_id].num_objects(), BTreeSet::new());
-        match device {
-            Device::LLVM | Device::CUDA => {
-                for (idx, node) in function.nodes.iter().enumerate() {
-                    match node {
-                        // Condition #1.
-                        Node::Read {
-                            collect,
-                            indices: _,
-                        } if types[typing[idx].idx()].is_primitive() => {
-                            for object in objects[&func_id].objects(*collect) {
-                                demands[func_id.idx()][object.idx()].insert(device);
-                            }
-                        }
-                        Node::Write {
-                            collect,
-                            data,
-                            indices: _,
-                        } => {
-                            for object in objects[&func_id]
-                                .objects(*collect)
-                                .into_iter()
-                                .chain(objects[&func_id].objects(*data).into_iter())
-                            {
-                                demands[func_id.idx()][object.idx()].insert(device);
-                            }
-                        }
-                        // Condition #2.
-                        Node::Constant { id: _ } | Node::Undef { ty: _ } => {
-                            for object in objects[&func_id].objects(NodeID::new(idx)) {
-                                demands[func_id.idx()][object.idx()].insert(device);
-                            }
-                        }
-                        _ => {}
-                    }
-                }
-            }
-            Device::AsyncRust => {
-                for (idx, node) in function.nodes.iter().enumerate() {
-                    if let Node::Call {
-                        control: _,
-                        function: callee,
-                        dynamic_constants: _,
-                        args,
-                    } = node
-                    {
-                        // Condition #3.
-                        for (param_idx, arg) in args.into_iter().enumerate() {
-                            if let Some(callee_obj) = objects[callee].param_to_object(param_idx) {
-                                let callee_demands =
-                                    take(&mut demands[callee.idx()][callee_obj.idx()]);
-                                for object in objects[&func_id].objects(*arg) {
-                                    demands[func_id.idx()][object.idx()]
-                                        .extend(callee_demands.iter());
-                                }
-                                demands[callee.idx()][callee_obj.idx()] = callee_demands;
-                            }
-                        }
-
-                        // Condition #4.
-                        for callee_obj in objects[callee].returned_objects() {
-                            let callee_demands = take(&mut demands[callee.idx()][callee_obj.idx()]);
-                            for object in objects[&func_id].objects(NodeID::new(idx)) {
-                                demands[func_id.idx()][object.idx()].extend(callee_demands.iter());
-                            }
-                            demands[callee.idx()][callee_obj.idx()] = callee_demands;
-                        }
-                    }
-                }
-            }
-        }
-    }
-
-    demands
-}
diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs
index 821d02ea..446b3184 100644
--- a/hercules_opt/src/gcm.rs
+++ b/hercules_opt/src/gcm.rs
@@ -655,6 +655,14 @@ fn terminating_reads<'a>(
                 None
             }
         })),
+        Node::LibraryCall {
+            library_function,
+            ref args,
+            ty: _,
+            device: _,
+        } => match library_function {
+            LibraryFunction::GEMM => Box::new(once(args[1]).chain(once(args[2]))),
+        },
         _ => Box::new(empty()),
     }
 }
@@ -728,6 +736,16 @@ fn mutating_objects<'a>(
                 })
                 .flatten(),
         ),
+        Node::LibraryCall {
+            library_function,
+            ref args,
+            ty: _,
+            device: _,
+        } => match library_function {
+            LibraryFunction::GEMM => {
+                Box::new(objects[&func_id].objects(args[0]).into_iter().map(|id| *id))
+            }
+        },
         _ => Box::new(empty()),
     }
 }
@@ -757,6 +775,14 @@ fn mutating_writes<'a>(
                 None
             }
         })),
+        Node::LibraryCall {
+            library_function,
+            ref args,
+            ty: _,
+            device: _,
+        } => match library_function {
+            LibraryFunction::GEMM => Box::new(once(args[0])),
+        },
         _ => Box::new(empty()),
     }
 }
@@ -1311,6 +1337,17 @@ fn color_nodes(
                     }
                 }
             }
+            Node::LibraryCall {
+                library_function: _,
+                ref args,
+                ty: _,
+                device,
+            } => {
+                for arg in args {
+                    equations.push((UTerm::Node(*arg), UTerm::Device(device)));
+                }
+                equations.push((UTerm::Node(id), UTerm::Device(device)));
+            }
             _ => {}
         }
     }
diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch
index 35ed1e84..76159ef7 100644
--- a/juno_samples/matmul/src/gpu.sch
+++ b/juno_samples/matmul/src/gpu.sch
@@ -21,7 +21,7 @@ fixpoint {
 ip-sroa(*);
 sroa(*);
 dce(*);
-xdot[true](*);
 
 float-collections(*);
 gcm(*);
+xdot[true](*);
-- 
GitLab