From 8f806736e67dd0661578464c681ed76aa60a4d99 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Mon, 17 Feb 2025 21:14:47 -0600 Subject: [PATCH] Add all the stuff related to collections for LibraryCall nodes --- hercules_ir/src/collections.rs | 32 ++++++++- hercules_ir/src/device.rs | 115 -------------------------------- hercules_opt/src/gcm.rs | 37 ++++++++++ juno_samples/matmul/src/gpu.sch | 2 +- 4 files changed, 69 insertions(+), 117 deletions(-) diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index f3474ae0..6b631519 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -218,6 +218,8 @@ pub fn collection_objects( // - Constant: may originate an object. // - Call: may originate an object and may return an object passed in as // a parameter. + // - LibraryCall: may return an object passed in as a parameter, but may + // not originate an object. // - Read: may extract a smaller object from the input - this is // considered to be the same object as the input, as no copy takes // place. @@ -288,6 +290,14 @@ pub fn collection_objects( } CollectionObjectLattice { objs } } + Node::LibraryCall { + library_function, + args: _, + ty: _, + device: _, + } => match library_function { + LibraryFunction::GEMM => inputs[0].clone(), + }, Node::Undef { ty: _ } => { let obj = origins .iter() @@ -332,7 +342,13 @@ pub fn collection_objects( for object in objects_per_node[idx].iter() { mutated[object.idx()].push(NodeID::new(idx)); } - } else if let Some((_, callee, _, args)) = node.try_call() { + } else if let Node::Call { + control: _, + function: callee, + dynamic_constants: _, + args, + } = node + { let fco = &collection_objects[&callee]; for (param_idx, arg) in args.into_iter().enumerate() { // If this parameter corresponds to an object and it's @@ -347,6 +363,20 @@ pub fn collection_objects( } } } + } else if let Node::LibraryCall { + library_function, + args, + ty: _, + device: _, + } = node + { + match library_function { + LibraryFunction::GEMM => { + for object in objects_per_node[args[0].idx()].iter() { + mutated[object.idx()].push(NodeID::new(idx)); + } + } + } } } diff --git a/hercules_ir/src/device.rs b/hercules_ir/src/device.rs index cbf8d634..c4a5454b 100644 --- a/hercules_ir/src/device.rs +++ b/hercules_ir/src/device.rs @@ -23,118 +23,3 @@ pub fn device_placement(functions: &Vec<Function>, callgraph: &CallGraph) -> Vec devices } - -pub type FunctionObjectDeviceDemands = Vec<BTreeSet<Device>>; -pub type ObjectDeviceDemands = Vec<FunctionObjectDeviceDemands>; - -/* - * This analysis figures out which device each collection object may be on. At - * first, an object may need to be on different devices at different times. This - * is fine during optimization. - */ -pub fn object_device_demands( - functions: &Vec<Function>, - types: &Vec<Type>, - typing: &ModuleTyping, - callgraph: &CallGraph, - objects: &CollectionObjects, - devices: &Vec<Device>, -) -> ObjectDeviceDemands { - // An object is "demanded" on a device when: - // 1. The object is used by a primitive read node or write node in a device - // function. This includes objects on the `data` input to write nodes. - // Non-primitive reads don't demand an object on a device since they are - // lowered to pointer math and no actual memory transfers. - // 2. The object is a constant / undef defined in a device function. - // 3. The object is passed as input to a call node where the corresponding - // object in the callee is demanded on a device. - // 4. The object is returned from a call node where the corresponding object - // in the callee is demanded on a device. - // Note that reads and writes in a RT function don't induce a device demand. - // This is because RT functions can call device functions as necessary to - // arbitrarily move data onto / off of devices (though this may be slow). - // Traverse the functions in a module in reverse topological order, since - // the analysis of a function depends on all functions it calls. - let mut demands: ObjectDeviceDemands = vec![vec![]; functions.len()]; - let topo = callgraph.topo(); - - for func_id in topo { - let function = &functions[func_id.idx()]; - let typing = &typing[func_id.idx()]; - let device = devices[func_id.idx()]; - - demands[func_id.idx()].resize(objects[&func_id].num_objects(), BTreeSet::new()); - match device { - Device::LLVM | Device::CUDA => { - for (idx, node) in function.nodes.iter().enumerate() { - match node { - // Condition #1. - Node::Read { - collect, - indices: _, - } if types[typing[idx].idx()].is_primitive() => { - for object in objects[&func_id].objects(*collect) { - demands[func_id.idx()][object.idx()].insert(device); - } - } - Node::Write { - collect, - data, - indices: _, - } => { - for object in objects[&func_id] - .objects(*collect) - .into_iter() - .chain(objects[&func_id].objects(*data).into_iter()) - { - demands[func_id.idx()][object.idx()].insert(device); - } - } - // Condition #2. - Node::Constant { id: _ } | Node::Undef { ty: _ } => { - for object in objects[&func_id].objects(NodeID::new(idx)) { - demands[func_id.idx()][object.idx()].insert(device); - } - } - _ => {} - } - } - } - Device::AsyncRust => { - for (idx, node) in function.nodes.iter().enumerate() { - if let Node::Call { - control: _, - function: callee, - dynamic_constants: _, - args, - } = node - { - // Condition #3. - for (param_idx, arg) in args.into_iter().enumerate() { - if let Some(callee_obj) = objects[callee].param_to_object(param_idx) { - let callee_demands = - take(&mut demands[callee.idx()][callee_obj.idx()]); - for object in objects[&func_id].objects(*arg) { - demands[func_id.idx()][object.idx()] - .extend(callee_demands.iter()); - } - demands[callee.idx()][callee_obj.idx()] = callee_demands; - } - } - - // Condition #4. - for callee_obj in objects[callee].returned_objects() { - let callee_demands = take(&mut demands[callee.idx()][callee_obj.idx()]); - for object in objects[&func_id].objects(NodeID::new(idx)) { - demands[func_id.idx()][object.idx()].extend(callee_demands.iter()); - } - demands[callee.idx()][callee_obj.idx()] = callee_demands; - } - } - } - } - } - } - - demands -} diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index 821d02ea..446b3184 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -655,6 +655,14 @@ fn terminating_reads<'a>( None } })), + Node::LibraryCall { + library_function, + ref args, + ty: _, + device: _, + } => match library_function { + LibraryFunction::GEMM => Box::new(once(args[1]).chain(once(args[2]))), + }, _ => Box::new(empty()), } } @@ -728,6 +736,16 @@ fn mutating_objects<'a>( }) .flatten(), ), + Node::LibraryCall { + library_function, + ref args, + ty: _, + device: _, + } => match library_function { + LibraryFunction::GEMM => { + Box::new(objects[&func_id].objects(args[0]).into_iter().map(|id| *id)) + } + }, _ => Box::new(empty()), } } @@ -757,6 +775,14 @@ fn mutating_writes<'a>( None } })), + Node::LibraryCall { + library_function, + ref args, + ty: _, + device: _, + } => match library_function { + LibraryFunction::GEMM => Box::new(once(args[0])), + }, _ => Box::new(empty()), } } @@ -1311,6 +1337,17 @@ fn color_nodes( } } } + Node::LibraryCall { + library_function: _, + ref args, + ty: _, + device, + } => { + for arg in args { + equations.push((UTerm::Node(*arg), UTerm::Device(device))); + } + equations.push((UTerm::Node(id), UTerm::Device(device))); + } _ => {} } } diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch index 35ed1e84..76159ef7 100644 --- a/juno_samples/matmul/src/gpu.sch +++ b/juno_samples/matmul/src/gpu.sch @@ -21,7 +21,7 @@ fixpoint { ip-sroa(*); sroa(*); dce(*); -xdot[true](*); float-collections(*); gcm(*); +xdot[true](*); -- GitLab