From 14f46b85ec739cd1f9b954da6672dbc3548e31f7 Mon Sep 17 00:00:00 2001 From: rarbore2 <rarbore2@illinois.edu> Date: Tue, 28 Jan 2025 13:47:07 -0600 Subject: [PATCH] Allocate all memories up-front --- hercules_cg/src/device.rs | 22 - hercules_cg/src/lib.rs | 23 +- hercules_cg/src/rt.rs | 618 ++++++++++-------- hercules_ir/src/device.rs | 133 ++++ hercules_ir/src/dom.rs | 54 ++ hercules_ir/src/dot.rs | 19 +- hercules_ir/src/fork_join_analysis.rs | 54 -- hercules_ir/src/ir.rs | 12 +- hercules_ir/src/lib.rs | 2 + hercules_opt/src/device_placement.rs | 3 + hercules_opt/src/editor.rs | 8 + hercules_opt/src/gcm.rs | 302 ++++++++- hercules_opt/src/lib.rs | 2 + hercules_rt/src/lib.rs | 423 +++++------- hercules_rt/src/rtdefs.cu | 23 +- hercules_samples/call/src/main.rs | 10 +- hercules_samples/ccp/src/main.rs | 7 +- hercules_samples/dot/src/main.rs | 11 +- hercules_samples/fac/src/main.rs | 7 +- hercules_samples/matmul/src/cpu.sch | 3 + hercules_samples/matmul/src/gpu.sch | 3 + hercules_samples/matmul/src/main.rs | 11 +- juno_samples/antideps/src/main.rs | 25 +- juno_samples/casts_and_intrinsics/src/main.rs | 7 +- juno_samples/cava/src/main.rs | 21 +- juno_samples/concat/src/main.rs | 7 +- juno_samples/implicit_clone/src/main.rs | 28 +- juno_samples/matmul/src/main.rs | 20 +- juno_samples/nested_ccp/src/main.rs | 19 +- juno_samples/schedule_test/src/main.rs | 15 +- juno_samples/simple3/src/main.rs | 15 +- juno_scheduler/src/pm.rs | 156 +++-- 32 files changed, 1266 insertions(+), 797 deletions(-) delete mode 100644 hercules_cg/src/device.rs create mode 100644 hercules_ir/src/device.rs create mode 100644 hercules_opt/src/device_placement.rs diff --git a/hercules_cg/src/device.rs b/hercules_cg/src/device.rs deleted file mode 100644 index 866fa6ad..00000000 --- a/hercules_cg/src/device.rs +++ /dev/null @@ -1,22 +0,0 @@ -use hercules_ir::*; - -/* - * Top level function to definitively place functions onto devices. A function - * may store a device placement, but only optionally - this function assigns - * devices to the rest of the functions. - */ -pub fn device_placement(functions: &Vec<Function>, callgraph: &CallGraph) -> Vec<Device> { - let mut devices = vec![]; - - for (idx, function) in functions.into_iter().enumerate() { - if let Some(device) = function.device { - devices.push(device); - } else if function.entry || callgraph.num_callees(FunctionID::new(idx)) != 0 { - devices.push(Device::AsyncRust); - } else { - devices.push(Device::LLVM); - } - } - - devices -} diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs index 8aaab214..6a12901f 100644 --- a/hercules_cg/src/lib.rs +++ b/hercules_cg/src/lib.rs @@ -1,15 +1,17 @@ #![feature(if_let_guard, let_chains)] pub mod cpu; -pub mod device; pub mod rt; pub use crate::cpu::*; -pub use crate::device::*; pub use crate::rt::*; +use std::collections::BTreeMap; + use hercules_ir::*; +pub const LARGEST_ALIGNMENT: usize = 8; + /* * The alignment of a type does not depend on dynamic constants. */ @@ -28,3 +30,20 @@ pub fn get_type_alignment(types: &Vec<Type>, ty: TypeID) -> usize { Type::Array(elem, _) => get_type_alignment(types, elem), } } + +/* + * Nodes producing collection values are "colored" with what device their + * underlying memory lives on. + */ +pub type FunctionNodeColors = BTreeMap<NodeID, Device>; +pub type NodeColors = Vec<FunctionNodeColors>; + +/* + * The allocation information of each function is a size of the backing memory + * needed and offsets into that backing memory per constant object and call node + * in the function. + */ +pub type FunctionBackingAllocation = + BTreeMap<Device, (DynamicConstantID, BTreeMap<NodeID, DynamicConstantID>)>; +pub type BackingAllocations = BTreeMap<FunctionID, FunctionBackingAllocation>; +pub const BACKED_DEVICES: [Device; 2] = [Device::LLVM, Device::CUDA]; diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index d093b2b0..445647ef 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -10,16 +10,80 @@ use crate::*; * Entry Hercules functions are lowered to async Rust code to achieve easy task * level parallelism. This Rust is generated textually, and is included via a * procedural macro in the user's Rust code. + * + * Generating Rust that properly handles memory across devices is tricky. In + * particular, the following elements are challenges: + * + * 1. RT functions may return objects that were not first passed in as + * parameters, via object constants. + * 2. We want to allocate as much memory upfront as possible. Our goal is for + * each call to a RT function from host Rust code, there is at most one + * memory allocation per device. + * 3. We want to statically determine when cross-device communication is + * necessary - this should be a separate concern from allocation. + * 4. At the boundary between host Rust / RT functions, we want to encode + * lifetime rules so that the Hercules API is memory safe. + * 5. We want to support efficient composition of Hercules RT functions in both + * synchronous and asynchronous contexts. + * + * Challenges #1 and #2 require that RT functions themselves do not allocate + * memory. Instead, for every entry point function, a "runner" type will be + * generated. The host Rust code must instantiate a runner object to invoke an + * entry point function. This runner object contains a "backing" memory that is + * the single allocation of memories for this function. The runner object can be + * used to call the same entry point multiple times, and to the extent possible + * the backing memory will be re-used. The size of the backing memory depends on + * the dynamic constants passed in to the entry point, so it's lazily allocated + * on calls to the entry point to the needed size. To address challenge #4, any + * returned objects will be lifetime-bound to the runner object instance, + * ensuring that the reference cannot be used after the runner object has de- + * allocated the backing memory. This also ensures that the runner can't be run + * again while a returned object from a previous iteration is still live, since + * the entry point method requires an exclusive reference to the runner. + * + * Addressing challenge #3 requires we determine what objects are live on what + * devices at what times. This can be done fairly easily by coloring nodes by + * what device they produce their result on and inserting inter-device transfers + * along edges connecting nodes of different colors. Nodes can only have a + * single color - this is enforced by the GCM pass. + * + * Addressing challenge #5 requires runner objects for entry points accept and + * return objects that are not in their own backing memory and potentially on + * any device. For this reason, parameter and return nodes are not necessarily + * CPU colored. Instead, runners take and return Hercules reference objects that + * refer to memory on some device which have unknown origin. Hercules reference + * objects have a lifetime parameter, and when a runner may return a Hercules + * reference that refers to its backing memory, the lifetime of the Hercules + * reference is the same as the lifetime of the mutable reference of the runner + * used in the entry point signature. In other words, the RT backend infers the + * proper lifetime bounds on parameter and returned Hercules reference objects + * in relation to the runner's self reference using the collection objects + * analysis. There are the following kinds of Hercules reference objects: + * + * - HerculesCPURef + * - HerculesCPURefMut + * - HerculesCUDARef + * - HerculesCUDARefMut + * + * Essentially, there are types for each device, one for immutable references + * and one for exclusive references. Mutable references can decay into immutable + * references, and immutable references can be cloned. The CPU reference types + * can be created from normal Rust references. The CUDA reference types can't be + * created from normal Rust references - for that purpose, an additional type is + * given, CUDABox, which essentially allows the user to manually allocate and + * set some CUDA memory - the user can then take a CUDA reference to that box. */ pub fn rt_codegen<W: Write>( func_id: FunctionID, module: &Module, typing: &Vec<TypeID>, control_subgraph: &Subgraph, - bbs: &BasicBlocks, collection_objects: &CollectionObjects, callgraph: &CallGraph, devices: &Vec<Device>, + bbs: &BasicBlocks, + node_colors: &FunctionNodeColors, + backing_allocation: &FunctionBackingAllocation, w: &mut W, ) -> Result<(), Error> { let ctx = RTContext { @@ -27,10 +91,12 @@ pub fn rt_codegen<W: Write>( module, typing, control_subgraph, - bbs, collection_objects, callgraph, devices, + bbs, + node_colors, + backing_allocation, }; ctx.codegen_function(w) } @@ -40,24 +106,40 @@ struct RTContext<'a> { module: &'a Module, typing: &'a Vec<TypeID>, control_subgraph: &'a Subgraph, - bbs: &'a BasicBlocks, collection_objects: &'a CollectionObjects, callgraph: &'a CallGraph, devices: &'a Vec<Device>, + bbs: &'a BasicBlocks, + node_colors: &'a FunctionNodeColors, + backing_allocation: &'a FunctionBackingAllocation, } impl<'a> RTContext<'a> { fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> { + // If this is an entry function, generate a corresponding runner object + // type definition. let func = &self.get_func(); + if func.entry { + self.codegen_runner_object(w)?; + } // Dump the function signature. write!( w, - "#[allow(unused_variables,unused_mut,unused_parens,unused_unsafe)]\nasync fn {}<'a>(", + "#[allow(unused_variables,unused_mut,unused_parens,unused_unsafe,non_snake_case)]\nasync unsafe fn {}(", func.name )?; let mut first_param = true; - // The first set of parameters are dynamic constants. + // The first set of parameters are pointers to backing memories. + for (device, _) in self.backing_allocation { + if first_param { + first_param = false; + } else { + write!(w, ", ")?; + } + write!(w, "backing_{}: *mut u8", device.name())?; + } + // The second set of parameters are dynamic constants. for idx in 0..func.num_dynamic_constants { if first_param { first_param = false; @@ -66,35 +148,17 @@ impl<'a> RTContext<'a> { } write!(w, "dc_p{}: u64", idx)?; } - // The second set of parameters are normal parameters. + // The third set of parameters are normal parameters. for idx in 0..func.param_types.len() { if first_param { first_param = false; } else { write!(w, ", ")?; } - if !self.module.types[func.param_types[idx].idx()].is_primitive() { - write!(w, "mut ")?; - } write!(w, "p{}: {}", idx, self.get_type(func.param_types[idx]))?; } write!(w, ") -> {} {{\n", self.get_type(func.return_type))?; - // Allocate collection constants. - for object in self.collection_objects[&self.func_id].iter_objects() { - if let CollectionObjectOrigin::Constant(id) = - self.collection_objects[&self.func_id].origin(object) - { - let size = self.codegen_type_size(self.typing[id.idx()]); - write!( - w, - " let mut obj{}: ::hercules_rt::HerculesBox = unsafe {{ ::hercules_rt::HerculesBox::__zeros({}) }};\n", - object.idx(), - size - )? - } - } - // Dump signatures for called device functions. write!(w, " extern \"C\" {{\n")?; for callee in self.callgraph.get_callees(self.func_id) { @@ -118,9 +182,9 @@ impl<'a> RTContext<'a> { } else { write!(w, ", ")?; } - write!(w, "p{}: {}", idx, self.device_get_type(*ty))?; + write!(w, "p{}: {}", idx, self.get_type(*ty))?; } - write!(w, ") -> {};\n", self.device_get_type(callee.return_type))?; + write!(w, ") -> {};\n", self.get_type(callee.return_type))?; } write!(w, " }}\n")?; @@ -139,7 +203,7 @@ impl<'a> RTContext<'a> { } else if self.module.types[self.typing[idx].idx()].is_float() { "0.0" } else { - "unsafe { ::hercules_rt::HerculesBox::__null() }" + "::core::ptr::null_mut()" } )?; } @@ -149,7 +213,7 @@ impl<'a> RTContext<'a> { // blocks to drive execution. write!( w, - " let mut control_token: i8 = 0;\n let return_value = loop {{\n match control_token {{\n", + " let mut control_token: i8 = 0;\n loop {{\n match control_token {{\n", )?; let mut blocks: BTreeMap<_, _> = (0..func.nodes.len()) @@ -183,39 +247,7 @@ impl<'a> RTContext<'a> { } // Close the match and loop. - write!(w, " _ => panic!()\n }}\n }};\n")?; - - // Emit the epilogue of the function. - write!(w, " unsafe {{\n")?; - for idx in 0..func.param_types.len() { - if !self.module.types[func.param_types[idx].idx()].is_primitive() { - write!(w, " p{}.__forget();\n", idx)?; - } - } - if !self.module.types[func.return_type.idx()].is_primitive() { - for object in self.collection_objects[&self.func_id].iter_objects() { - if let CollectionObjectOrigin::Constant(_) = - self.collection_objects[&self.func_id].origin(object) - { - write!( - w, - " if obj{}.__cmp_ids(&return_value) {{\n", - object.idx() - )?; - write!(w, " obj{}.__forget();\n", object.idx())?; - write!(w, " }}\n")?; - } - } - } - for idx in 0..func.nodes.len() { - if !func.nodes[idx].is_control() - && !self.module.types[self.typing[idx].idx()].is_primitive() - { - write!(w, " node_{}.__forget();\n", idx)?; - } - } - write!(w, " }}\n")?; - write!(w, " return_value\n")?; + write!(w, " _ => panic!()\n }}\n }}\n")?; write!(w, "}}\n")?; Ok(()) } @@ -263,15 +295,7 @@ impl<'a> RTContext<'a> { } Node::Return { control: _, data } => { let block = &mut blocks.get_mut(&id).unwrap(); - if self.module.types[self.typing[data.idx()].idx()].is_primitive() { - write!(block, " break {};\n", self.get_value(data))? - } else { - write!( - block, - " break unsafe {{ {}.__clone() }};\n", - self.get_value(data) - )? - } + write!(block, " return {};\n", self.get_value(data))? } _ => panic!("PANIC: Can't lower {:?}.", func.nodes[id.idx()]), } @@ -290,25 +314,17 @@ impl<'a> RTContext<'a> { match func.nodes[id.idx()] { Node::Parameter { index } => { let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); - if self.module.types[self.typing[id.idx()].idx()].is_primitive() { - write!( - block, - " {} = p{};\n", - self.get_value(id), - index - )? - } else { - write!( - block, - " {} = unsafe {{ p{}.__clone() }};\n", - self.get_value(id), - index - )? - } + write!( + block, + " {} = p{};\n", + self.get_value(id), + index + )? } Node::Constant { id: cons_id } => { let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); write!(block, " {} = ", self.get_value(id))?; + let mut size = None; match self.module.constants[cons_id.idx()] { Constant::Boolean(val) => write!(block, "{}bool", val)?, Constant::Integer8(val) => write!(block, "{}i8", val)?, @@ -321,14 +337,32 @@ impl<'a> RTContext<'a> { Constant::UnsignedInteger64(val) => write!(block, "{}u64", val)?, Constant::Float32(val) => write!(block, "{}f32", val)?, Constant::Float64(val) => write!(block, "{}f64", val)?, - Constant::Product(_, _) | Constant::Summation(_, _, _) | Constant::Array(_) => { - let objects = self.collection_objects[&self.func_id].objects(id); - assert_eq!(objects.len(), 1); - let object = objects[0]; - write!(block, "unsafe {{ obj{}.__clone() }}", object.idx())? + Constant::Product(ty, _) + | Constant::Summation(ty, _, _) + | Constant::Array(ty) => { + let (device, offset) = self + .backing_allocation + .into_iter() + .filter_map(|(device, (_, offsets))| { + offsets.get(&id).map(|id| (*device, *id)) + }) + .next() + .unwrap(); + write!(block, "backing_{}.byte_add(", device.name())?; + self.codegen_dynamic_constant(offset, block)?; + write!(block, " as usize)")?; + size = Some(self.codegen_type_size(ty)); } } - write!(block, ";\n")? + write!(block, ";\n")?; + if let Some(size) = size { + write!( + block, + " ::core::ptr::write_bytes({}, 0, {} as usize);\n", + self.get_value(id), + size + )?; + } } Node::Call { control: _, @@ -336,123 +370,36 @@ impl<'a> RTContext<'a> { ref dynamic_constants, ref args, } => { + // The device backends ensure that device functions have the + // same interface as AsyncRust functions. + let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + write!( + block, + " {} = {}(", + self.get_value(id), + self.module.functions[callee_id.idx()].name + )?; + for (device, offset) in self + .backing_allocation + .into_iter() + .filter_map(|(device, (_, offsets))| offsets.get(&id).map(|id| (*device, *id))) + { + write!(block, "backing_{}.byte_add(", device.name())?; + self.codegen_dynamic_constant(offset, block)?; + write!(block, ")")? + } + for dc in dynamic_constants { + self.codegen_dynamic_constant(*dc, block)?; + write!(block, ", ")?; + } + for arg in args { + write!(block, "{}, ", self.get_value(*arg))?; + } let device = self.devices[callee_id.idx()]; - match device { - // The device backends ensure that device functions have the - // same C interface. - Device::LLVM | Device::CUDA => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); - - let device = match device { - Device::LLVM => "cpu", - Device::CUDA => "cuda", - _ => panic!(), - }; - - // First, get the raw pointers to collections that the - // device function takes as input. - let callee_objs = &self.collection_objects[&callee_id]; - for (idx, arg) in args.into_iter().enumerate() { - if let Some(obj) = callee_objs.param_to_object(idx) { - // Extract a raw pointer from the HerculesBox. - if callee_objs.is_mutated(obj) { - write!( - block, - " let arg_tmp{} = unsafe {{ {}.__{}_ptr_mut() }};\n", - idx, - self.get_value(*arg), - device - )?; - } else { - write!( - block, - " let arg_tmp{} = unsafe {{ {}.__{}_ptr() }};\n", - idx, - self.get_value(*arg), - device - )?; - } - } else { - write!( - block, - " let arg_tmp{} = {};\n", - idx, - self.get_value(*arg) - )?; - } - } - - // Emit the call. - write!( - block, - " let call_tmp = unsafe {{ {}(", - self.module.functions[callee_id.idx()].name - )?; - for dc in dynamic_constants { - self.codegen_dynamic_constant(*dc, block)?; - write!(block, ", ")?; - } - for idx in 0..args.len() { - write!(block, "arg_tmp{}, ", idx)?; - } - write!(block, ") }};\n")?; - - // When a device function is called that returns a - // collection object, that object must have come from - // one of its parameters. Dynamically figure out which - // one it came from, so that we can move it to the slot - // of the output object. - let caller_objects = self.collection_objects[&self.func_id].objects(id); - if !caller_objects.is_empty() { - for (idx, arg) in args.into_iter().enumerate() { - if idx != 0 { - write!(block, " else\n")?; - } - write!( - block, - " if call_tmp == arg_tmp{} {{\n", - idx - )?; - write!( - block, - " {} = unsafe {{ {}.__clone() }};\n", - self.get_value(id), - self.get_value(*arg) - )?; - write!(block, " }}")?; - } - write!(block, " else {{\n")?; - write!(block, " panic!(\"HERCULES PANIC: Pointer returned from device function doesn't match an argument pointer.\");\n")?; - write!(block, " }}\n")?; - } else { - write!( - block, - " {} = call_tmp;\n", - self.get_value(id) - )?; - } - } - Device::AsyncRust => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); - write!( - block, - " {} = {}(", - self.get_value(id), - self.module.functions[callee_id.idx()].name - )?; - for dc in dynamic_constants { - self.codegen_dynamic_constant(*dc, block)?; - write!(block, ", ")?; - } - for arg in args { - if self.module.types[self.typing[arg.idx()].idx()].is_primitive() { - write!(block, "{}, ", self.get_value(*arg))?; - } else { - write!(block, "unsafe {{ {}.__clone() }}, ", self.get_value(*arg))?; - } - } - write!(block, ").await;\n")?; - } + if device == Device::AsyncRust { + write!(block, ").await;\n")?; + } else { + write!(block, ");\n")?; } } Node::Read { @@ -463,33 +410,7 @@ impl<'a> RTContext<'a> { let collect_ty = self.typing[collect.idx()]; let out_size = self.codegen_type_size(self.typing[id.idx()]); let offset = self.codegen_index_math(collect_ty, indices)?; - write!( - block, - " let mut read_offset_obj = unsafe {{ {}.__clone() }};\n", - self.get_value(collect) - )?; - write!( - block, - " unsafe {{ read_offset_obj.__offset({}, {}) }};\n", - offset, out_size, - )?; - if self.module.types[self.typing[id.idx()].idx()].is_primitive() { - write!( - block, - " {} = unsafe {{ *(read_offset_obj.__cpu_ptr() as *const _) }};\n", - self.get_value(id) - )?; - write!( - block, - " unsafe {{ read_offset_obj.__forget() }};\n", - )?; - } else { - write!( - block, - " {} = read_offset_obj;\n", - self.get_value(id) - )?; - } + todo!(); } Node::Write { collect, @@ -500,31 +421,7 @@ impl<'a> RTContext<'a> { let collect_ty = self.typing[collect.idx()]; let data_size = self.codegen_type_size(self.typing[data.idx()]); let offset = self.codegen_index_math(collect_ty, indices)?; - write!( - block, - " let mut write_offset_obj = unsafe {{ {}.__clone() }};\n", - self.get_value(collect) - )?; - write!(block, " let write_offset_ptr = unsafe {{ write_offset_obj.__cpu_ptr_mut().byte_add({}) }};\n", offset)?; - if self.module.types[self.typing[data.idx()].idx()].is_primitive() { - write!( - block, - " unsafe {{ *(write_offset_ptr as *mut _) = {} }};\n", - self.get_value(data) - )?; - } else { - write!( - block, - " unsafe {{ ::core::ptr::copy_nonoverlapping({}.__cpu_ptr(), write_offset_ptr as *mut _, {} as usize) }};\n", - self.get_value(data), - data_size, - )?; - } - write!( - block, - " {} = write_offset_obj;\n", - self.get_value(id), - )?; + todo!(); } _ => panic!( "PANIC: Can't lower {:?} in {}.", @@ -721,7 +618,7 @@ impl<'a> RTContext<'a> { acc_size = format!("::core::cmp::max({}, {})", acc_size, variant); } - // No alignment is necessary for the 1 byte discriminant. + // No alignment is necessary before the 1 byte discriminant. let total_align = get_type_alignment(&self.module.types, ty); format!( "(({} + 1 + {}) & !{})", @@ -743,6 +640,183 @@ impl<'a> RTContext<'a> { } } + /* + * Generate a runner object for this function. + */ + fn codegen_runner_object<W: Write>(&self, w: &mut W) -> Result<(), Error> { + // Figure out the devices for the parameters and the return value if + // they are collections and whether they should be immutable or mutable + // references. + let func = self.get_func(); + let mut param_devices = vec![None; func.param_types.len()]; + let mut return_device = None; + for idx in 0..func.nodes.len() { + match func.nodes[idx] { + Node::Parameter { index } => { + let device = self.node_colors.get(&NodeID::new(idx)); + assert!(param_devices[index].is_none() || param_devices[index] == device); + param_devices[index] = device; + } + Node::Return { control: _, data } => { + let device = self.node_colors.get(&data); + assert!(return_device.is_none() || return_device == device); + return_device = device; + } + _ => {} + } + } + let mut param_muts = vec![false; func.param_types.len()]; + let mut return_mut = true; + let objects = &self.collection_objects[&self.func_id]; + for idx in 0..func.param_types.len() { + if let Some(object) = objects.param_to_object(idx) + && objects.is_mutated(object) + { + param_muts[idx] = true; + } + } + for object in objects.returned_objects() { + if let Some(idx) = objects.origin(*object).try_parameter() + && !param_muts[idx] + { + return_mut = false; + } + } + + // Emit the type definition. A runner object owns its backing memory. + write!( + w, + "#[allow(non_camel_case_types)]\nstruct HerculesRunner_{} {{\n", + func.name + )?; + for (device, _) in self.backing_allocation { + write!(w, " backing_ptr_{}: *mut u8,\n", device.name(),)?; + write!(w, " backing_size_{}: usize,\n", device.name(),)?; + } + write!(w, "}}\n")?; + write!( + w, + "impl HerculesRunner_{} {{\n fn new() -> Self {{\n Self {{\n", + func.name + )?; + for (device, _) in self.backing_allocation { + write!( + w, + " backing_ptr_{}: ::core::ptr::null_mut(),\n backing_size_{}: 0,\n", + device.name(), + device.name() + )?; + } + write!(w, " }}\n }}\n")?; + write!(w, " async fn run<'a>(&'a mut self")?; + for idx in 0..func.num_dynamic_constants { + write!(w, ", dc_p{}: u64", idx)?; + } + for idx in 0..func.param_types.len() { + if self.module.types[func.param_types[idx].idx()].is_primitive() { + write!(w, ", p{}: {}", idx, self.get_type(func.param_types[idx]))?; + } else { + let device = match param_devices[idx] { + Some(Device::LLVM) => "CPU", + Some(Device::CUDA) => "CUDA", + // For parameters that are unused, it doesn't really matter + // what device is required, so just pick CPU for now. + None => "CPU", + _ => panic!(), + }; + let mutability = if param_muts[idx] { "Mut" } else { "" }; + write!( + w, + ", p{}: ::hercules_rt::Hercules{}Ref{}<'a>", + idx, device, mutability + )?; + } + } + if self.module.types[func.return_type.idx()].is_primitive() { + write!(w, ") -> {} {{\n", self.get_type(func.return_type))?; + } else { + let device = match return_device { + Some(Device::LLVM) => "CPU", + Some(Device::CUDA) => "CUDA", + _ => panic!(), + }; + let mutability = if return_mut { "Mut" } else { "" }; + write!( + w, + ") -> ::hercules_rt::Hercules{}Ref{}<'a> {{\n", + device, mutability + )?; + } + write!(w, " unsafe {{\n")?; + for (device, (total, _)) in self.backing_allocation { + write!(w, " let size = ")?; + self.codegen_dynamic_constant(*total, w)?; + write!( + w, + " as usize;\n if self.backing_size_{} < size {{\n", + device.name() + )?; + write!(w, " ::hercules_rt::__{}_dealloc(self.backing_ptr_{}, self.backing_size_{});\n", device.name(), device.name(), device.name())?; + write!( + w, + " self.backing_size_{} = size;\n", + device.name() + )?; + write!(w, " self.backing_ptr_{} = ::hercules_rt::__{}_alloc(self.backing_size_{});\n", device.name(), device.name(), device.name())?; + write!(w, " }}\n")?; + } + for idx in 0..func.param_types.len() { + if !self.module.types[func.param_types[idx].idx()].is_primitive() { + write!(w, " let p{} = p{}.__ptr();\n", idx, idx)?; + } + } + write!(w, " let ret = {}(", func.name)?; + for (device, _) in self.backing_allocation { + write!(w, "self.backing_ptr_{}, ", device.name())?; + } + for idx in 0..func.num_dynamic_constants { + write!(w, "dc_p{}, ", idx)?; + } + for idx in 0..func.param_types.len() { + write!(w, "p{}, ", idx)?; + } + write!(w, ").await;\n")?; + if self.module.types[func.return_type.idx()].is_primitive() { + write!(w, " ret\n")?; + } else { + let device = match return_device { + Some(Device::LLVM) => "CPU", + Some(Device::CUDA) => "CUDA", + _ => panic!(), + }; + let mutability = if return_mut { "Mut" } else { "" }; + write!( + w, + " ::hercules_rt::Hercules{}Ref{}::__from_parts(ret, {} as usize)\n", + device, + mutability, + self.codegen_type_size(func.return_type) + )?; + } + write!(w, " }}\n }}\n")?; + write!( + w, + "}}\nimpl Drop for HerculesRunner_{} {{\n #[allow(unused_unsafe)]\n fn drop(&mut self) {{\n unsafe {{\n", + func.name + )?; + for (device, _) in self.backing_allocation { + write!( + w, + " ::hercules_rt::__{}_dealloc(self.backing_ptr_{}, self.backing_size_{});\n", + device.name(), + device.name(), + device.name() + )?; + } + write!(w, " }}\n }}\n}}\n")?; + Ok(()) + } + fn get_func(&self) -> &Function { &self.module.functions[self.func_id.idx()] } @@ -754,33 +828,9 @@ impl<'a> RTContext<'a> { fn get_type(&self, id: TypeID) -> &'static str { convert_type(&self.module.types[id.idx()]) } - - fn device_get_type(&self, id: TypeID) -> &'static str { - device_convert_type(&self.module.types[id.idx()]) - } } fn convert_type(ty: &Type) -> &'static str { - match ty { - Type::Boolean => "bool", - Type::Integer8 => "i8", - Type::Integer16 => "i16", - Type::Integer32 => "i32", - Type::Integer64 => "i64", - Type::UnsignedInteger8 => "u8", - Type::UnsignedInteger16 => "u16", - Type::UnsignedInteger32 => "u32", - Type::UnsignedInteger64 => "u64", - Type::Float32 => "f32", - Type::Float64 => "f64", - Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => { - "::hercules_rt::HerculesBox<'a>" - } - _ => panic!(), - } -} - -fn device_convert_type(ty: &Type) -> &'static str { match ty { Type::Boolean => "bool", Type::Integer8 => "i8", diff --git a/hercules_ir/src/device.rs b/hercules_ir/src/device.rs new file mode 100644 index 00000000..46a1af02 --- /dev/null +++ b/hercules_ir/src/device.rs @@ -0,0 +1,133 @@ +use std::collections::BTreeSet; +use std::mem::take; + +use crate::*; + +/* + * Top level function to definitively place functions onto devices. A function + * may store a device placement, but only optionally - this function assigns + * devices to the rest of the functions. + */ +pub fn device_placement(functions: &Vec<Function>, callgraph: &CallGraph) -> Vec<Device> { + let mut devices = vec![]; + + for (idx, function) in functions.into_iter().enumerate() { + if let Some(device) = function.device { + devices.push(device); + } else if function.entry || callgraph.num_callees(FunctionID::new(idx)) != 0 { + devices.push(Device::AsyncRust); + } else { + devices.push(Device::LLVM); + } + } + + devices +} + +pub type FunctionObjectDeviceDemands = Vec<BTreeSet<Device>>; +pub type ObjectDeviceDemands = Vec<FunctionObjectDeviceDemands>; + +/* + * This analysis figures out which device each collection object may be on. At + * first, an object may need to be on different devices at different times. This + * is fine during optimization. + */ +pub fn object_device_demands( + functions: &Vec<Function>, + types: &Vec<Type>, + typing: &ModuleTyping, + callgraph: &CallGraph, + objects: &CollectionObjects, + devices: &Vec<Device>, +) -> ObjectDeviceDemands { + // An object is "demanded" on a device when: + // 1. The object is used by a primitive read node or write node in a device + // function. This includes objects on the `data` input to write nodes. + // Non-primitive reads don't demand an object on a device since they are + // lowered to pointer math and no actual memory transfers. + // 2. The object is passed as input to a call node where the corresponding + // object in the callee is demanded on a device. + // 3. The object is returned from a call node where the corresponding object + // in the callee is demanded on a device. + // Note that reads and writes in a RT function don't induce a device demand. + // This is because RT functions can call device functions as necessary to + // arbitrarily move data onto / off of devices (though this may be slow). + // Traverse the functions in a module in reverse topological order, since + // the analysis of a function depends on all functions it calls. + let mut demands: ObjectDeviceDemands = vec![vec![]; functions.len()]; + let topo = callgraph.topo(); + + for func_id in topo { + let function = &functions[func_id.idx()]; + let typing = &typing[func_id.idx()]; + let device = devices[func_id.idx()]; + + demands[func_id.idx()].resize(objects[&func_id].num_objects(), BTreeSet::new()); + match device { + Device::LLVM | Device::CUDA => { + for (idx, node) in function.nodes.iter().enumerate() { + // Condition #1. + match node { + Node::Read { + collect, + indices: _, + } if types[typing[idx].idx()].is_primitive() => { + for object in objects[&func_id].objects(*collect) { + demands[func_id.idx()][object.idx()].insert(device); + } + } + Node::Write { + collect, + data, + indices: _, + } => { + for object in objects[&func_id] + .objects(*collect) + .into_iter() + .chain(objects[&func_id].objects(*data).into_iter()) + { + demands[func_id.idx()][object.idx()].insert(device); + } + } + _ => {} + } + } + } + Device::AsyncRust => { + for (idx, node) in function.nodes.iter().enumerate() { + if let Node::Call { + control: _, + function: callee, + dynamic_constants: _, + args, + } = node + { + // Condition #2. + for (param_idx, arg) in args.into_iter().enumerate() { + if let Some(callee_obj) = objects[callee].param_to_object(param_idx) { + let callee_demands = + take(&mut demands[callee.idx()][callee_obj.idx()]); + for object in objects[&func_id].objects(*arg) { + demands[func_id.idx()][object.idx()] + .extend(callee_demands.iter()); + } + demands[callee.idx()][callee_obj.idx()] = callee_demands; + } + } + + // Condition #3. + for callee_obj in objects[callee].returned_objects() { + let callee_demands = take(&mut demands[callee.idx()][callee_obj.idx()]); + for object in objects[&func_id].objects(NodeID::new(idx)) { + demands[func_id.idx()][object.idx()].extend(callee_demands.iter()); + } + demands[callee.idx()][callee_obj.idx()] = callee_demands; + } + } + } + } + } + } + + demands +} diff --git a/hercules_ir/src/dom.rs b/hercules_ir/src/dom.rs index 2c0f085d..2cfd3b09 100644 --- a/hercules_ir/src/dom.rs +++ b/hercules_ir/src/dom.rs @@ -1,5 +1,7 @@ use std::collections::HashMap; +use bitvec::prelude::*; + use crate::*; /* @@ -304,3 +306,55 @@ pub fn postdominator(subgraph: Subgraph, fake_root: NodeID) -> DomTree { // root as the root of the dominator analysis. dominator(&reversed_subgraph, fake_root) } + +/* + * Check if a data node dominates a control node. This involves checking all + * immediate control uses to see if they dominate the queried control node. + */ +pub fn does_data_dom_control( + function: &Function, + data: NodeID, + control: NodeID, + dom: &DomTree, +) -> bool { + let mut stack = vec![data]; + let mut visited = bitvec![u8, Lsb0; 0; function.nodes.len()]; + visited.set(data.idx(), true); + + while let Some(pop) = stack.pop() { + let node = &function.nodes[pop.idx()]; + + let imm_control = match node { + Node::Phi { control, data: _ } + | Node::Reduce { + control, + init: _, + reduct: _, + } + | Node::Call { + control, + function: _, + dynamic_constants: _, + args: _, + } => Some(*control), + _ if node.is_control() => Some(pop), + _ => { + for u in get_uses(node).as_ref() { + if !visited[u.idx()] { + visited.set(u.idx(), true); + stack.push(*u); + } + } + None + } + }; + + if let Some(imm_control) = imm_control + && !dom.does_dom(imm_control, control) + { + return false; + } + } + + true +} diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs index 22cd0beb..8efabd7a 100644 --- a/hercules_ir/src/dot.rs +++ b/hercules_ir/src/dot.rs @@ -18,6 +18,7 @@ pub fn xdot_module( reverse_postorders: &Vec<Vec<NodeID>>, doms: Option<&Vec<DomTree>>, fork_join_maps: Option<&Vec<HashMap<NodeID, NodeID>>>, + devices: Option<&Vec<Device>>, bbs: Option<&Vec<BasicBlocks>>, ) { let mut tmp_path = temp_dir(); @@ -31,6 +32,7 @@ pub fn xdot_module( &reverse_postorders, doms, fork_join_maps, + devices, bbs, &mut contents, ) @@ -53,6 +55,7 @@ pub fn write_dot<W: Write>( reverse_postorders: &Vec<Vec<NodeID>>, doms: Option<&Vec<DomTree>>, fork_join_maps: Option<&Vec<HashMap<NodeID, NodeID>>>, + devices: Option<&Vec<Device>>, bbs: Option<&Vec<BasicBlocks>>, w: &mut W, ) -> std::fmt::Result { @@ -65,7 +68,12 @@ pub fn write_dot<W: Write>( for (idx, id) in reverse_postorder.iter().enumerate() { reverse_postorder_node_numbers[id.idx()] = idx; } - write_subgraph_header(function_id, module, w)?; + write_subgraph_header( + function_id, + module, + devices.map(|devices| devices[function_id.idx()]), + w, + )?; // Step 1: draw IR graph itself. This includes all IR nodes and all edges // between IR nodes. @@ -168,7 +176,7 @@ pub fn write_dot<W: Write>( } } - // Step 4: draw basic block edges in indigo. + // Step 4: draw basic block edges in blue. if let Some(bbs) = bbs { let bbs = &bbs[function_id.idx()].0; for (idx, bb) in bbs.into_iter().enumerate() { @@ -179,7 +187,7 @@ pub fn write_dot<W: Write>( *bb, function_id, true, - "indigo", + "lightslateblue", "dotted", &module, w, @@ -204,6 +212,7 @@ fn write_digraph_header<W: Write>(w: &mut W) -> std::fmt::Result { fn write_subgraph_header<W: Write>( function_id: FunctionID, module: &Module, + device: Option<Device>, w: &mut W, ) -> std::fmt::Result { let function = &module.functions[function_id.idx()]; @@ -219,8 +228,8 @@ fn write_subgraph_header<W: Write>( } else { write!(w, "label=\"{}\"\n", function.name)?; } - let color = match function.device { - Some(Device::LLVM) => "paleturquoise1", + let color = match device.or(function.device) { + Some(Device::LLVM) => "slategray1", Some(Device::CUDA) => "darkseagreen1", Some(Device::AsyncRust) => "peachpuff1", None => "ivory2", diff --git a/hercules_ir/src/fork_join_analysis.rs b/hercules_ir/src/fork_join_analysis.rs index 130bc2ed..1d089d76 100644 --- a/hercules_ir/src/fork_join_analysis.rs +++ b/hercules_ir/src/fork_join_analysis.rs @@ -1,7 +1,5 @@ use std::collections::{HashMap, HashSet}; -use bitvec::prelude::*; - use crate::*; /* @@ -75,55 +73,3 @@ pub fn compute_fork_join_nesting( }) .collect() } - -/* - * Check if a data node dominates a control node. This involves checking all - * immediate control uses to see if they dominate the queried control node. - */ -pub fn does_data_dom_control( - function: &Function, - data: NodeID, - control: NodeID, - dom: &DomTree, -) -> bool { - let mut stack = vec![data]; - let mut visited = bitvec![u8, Lsb0; 0; function.nodes.len()]; - visited.set(data.idx(), true); - - while let Some(pop) = stack.pop() { - let node = &function.nodes[pop.idx()]; - - let imm_control = match node { - Node::Phi { control, data: _ } - | Node::Reduce { - control, - init: _, - reduct: _, - } - | Node::Call { - control, - function: _, - dynamic_constants: _, - args: _, - } => Some(*control), - _ if node.is_control() => Some(pop), - _ => { - for u in get_uses(node).as_ref() { - if !visited[u.idx()] { - visited.set(u.idx(), true); - stack.push(*u); - } - } - None - } - }; - - if let Some(imm_control) = imm_control - && !dom.does_dom(imm_control, control) - { - return false; - } - } - - true -} diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 46d35f25..d8a124e2 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -332,7 +332,7 @@ pub enum Schedule { * The authoritative enumeration of supported backends. Multiple backends may * correspond to the same kind of hardware. */ -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub enum Device { LLVM, CUDA, @@ -1710,6 +1710,16 @@ impl Intrinsic { } } +impl Device { + pub fn name(&self) -> &'static str { + match self { + Device::LLVM => "cpu", + Device::CUDA => "cuda", + Device::AsyncRust => "rt", + } + } +} + /* * Rust things to make newtyped IDs usable. */ diff --git a/hercules_ir/src/lib.rs b/hercules_ir/src/lib.rs index 32bbf631..85dc277f 100644 --- a/hercules_ir/src/lib.rs +++ b/hercules_ir/src/lib.rs @@ -11,6 +11,7 @@ pub mod callgraph; pub mod collections; pub mod dataflow; pub mod def_use; +pub mod device; pub mod dom; pub mod dot; pub mod fork_join_analysis; @@ -26,6 +27,7 @@ pub use crate::callgraph::*; pub use crate::collections::*; pub use crate::dataflow::*; pub use crate::def_use::*; +pub use crate::device::*; pub use crate::dom::*; pub use crate::dot::*; pub use crate::fork_join_analysis::*; diff --git a/hercules_opt/src/device_placement.rs b/hercules_opt/src/device_placement.rs new file mode 100644 index 00000000..2badd69d --- /dev/null +++ b/hercules_opt/src/device_placement.rs @@ -0,0 +1,3 @@ +use hercules_ir::ir::*; + +use crate::*; diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index 60745f21..8b90710e 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -335,6 +335,14 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { self.function_id } + pub fn get_types(&self) -> Ref<'_, Vec<Type>> { + self.types.borrow() + } + + pub fn get_constants(&self) -> Ref<'_, Vec<Constant>> { + self.constants.borrow() + } + pub fn get_dynamic_constants(&self) -> Ref<'_, Vec<DynamicConstant>> { self.dynamic_constants.borrow() } diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index 1323d5a0..9929f6d6 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -1,3 +1,4 @@ +use std::cell::Ref; use std::collections::{BTreeMap, BTreeSet, HashMap, VecDeque}; use std::iter::{empty, once, zip, FromIterator}; @@ -5,6 +6,7 @@ use bitvec::prelude::*; use either::Either; use union_find::{QuickFindUf, UnionBySize, UnionFind}; +use hercules_cg::*; use hercules_ir::*; use crate::*; @@ -35,8 +37,35 @@ use crate::*; * liveness analysis result, so every spill restarts the process of checking for * spills. Once no more spills are found, the process terminates. When a spill * is found, the basic block assignments, and all the other analyses, are not - * necessarily valid anymore, so this function is called in a loop in pass.rs - * until no more spills are found. + * necessarily valid anymore, so this function is called in a loop in the pass + * manager until no more spills are found. + * + * GCM is additionally complicated by the need to generate code that references + * objects across multiple devices. In particular, GCM makes sure that every + * object lives on exactly one device, so that references to that object always + * live on a single device. Additionally, GCM makes sure that the objects that a + * node may produce are all on the same device, so that a pointer produced by, + * for example, a select node can only refer to memory on a single device. Extra + * collection constants and potentially inter-device copies are inserted as + * necessary to make sure this is true - an inter-device copy is represented by + * a write where the `collect` and `data` inputs are on different devices. This + * is only valid in RT functions - it is asserted that this isn't necessary in + * device functions. This process "colors" the nodes in the function. + * + * GCM has one final responsibility - object allocation. Each Hercules function + * receives a pointer to a "backing" memory where collection constants live. The + * backing memory a function receives is for the constants in that function and + * the constants of every called function. Concretely, a function will pass a + * sub-regions of its backing memory to a callee, which during the call is that + * function's backing memory. Object allocation consists of finding the required + * sizes of all collection constants and functions in terms of dynamic constants + * (dynamic constant math is expressive enough to represent sizes of types, + * which is very convenient) and determining the concrete offsets into the + * backing memory where constants and callee sub-regions live. When two users of + * backing memory are never live at once, they may share backing memory. This is + * done after nodes are given a single device color, since we need to know what + * values are on what devices before we can allocate them to backing memory, + * since there are separate backing memories per-device. */ pub fn gcm( editor: &mut FunctionEditor, @@ -48,7 +77,10 @@ pub fn gcm( fork_join_map: &HashMap<NodeID, NodeID>, loops: &LoopTree, objects: &CollectionObjects, -) -> Option<BasicBlocks> { + devices: &Vec<Device>, + object_device_demands: &FunctionObjectDeviceDemands, + backing_allocations: &BackingAllocations, +) -> Option<(BasicBlocks, FunctionNodeColors, FunctionBackingAllocation)> { let bbs = basic_blocks( editor.func(), editor.func_id(), @@ -59,11 +91,69 @@ pub fn gcm( fork_join_map, objects, ); - if spill_clones(editor, typing, control_subgraph, objects, &bbs) { - None - } else { - Some(bbs) + + let liveness = liveness_dataflow( + editor.func(), + editor.func_id(), + control_subgraph, + objects, + &bbs, + ); + + if spill_clones(editor, typing, control_subgraph, objects, &bbs, &liveness) { + return None; } + + let func_id = editor.func_id(); + let Some(node_colors) = color_nodes( + editor, + reverse_postorder, + &objects[&func_id], + &object_device_demands, + ) else { + return None; + }; + + let device = devices[func_id.idx()]; + match device { + Device::LLVM | Device::CUDA => { + // Check that every object that has a demand in this function are + // only demanded on this device. + for demands in object_device_demands { + assert!(demands.is_empty() || (demands.len() == 1 && demands.contains(&device))) + } + } + Device::AsyncRust => { + // Check that every object that has a demand in this function only + // has a demand from one device. + for demands in object_device_demands { + assert!(demands.len() <= 1); + } + } + } + + let mut alignments = vec![]; + Ref::map(editor.get_types(), |types| { + for idx in 0..types.len() { + if types[idx].is_control() { + alignments.push(0); + } else { + alignments.push(get_type_alignment(types, TypeID::new(idx))); + } + } + &() + }); + + let backing_allocation = object_allocation( + editor, + typing, + &node_colors, + &alignments, + &liveness, + backing_allocations, + ); + + Some((bbs, node_colors, backing_allocation)) } /* @@ -580,8 +670,6 @@ fn mutating_writes<'a>( } } -type Liveness = BTreeMap<NodeID, Vec<BTreeSet<NodeID>>>; - /* * Top level function to find implicit clones that need to be spilled. Returns * whether a clone was spilled, in which case the whole scheduling process must @@ -593,19 +681,9 @@ fn spill_clones( control_subgraph: &Subgraph, objects: &CollectionObjects, bbs: &BasicBlocks, + liveness: &Liveness, ) -> bool { - // Step 1: compute a liveness analysis of collection values in the IR. This - // requires a dataflow analysis over the scheduled IR, which is not a common - // need in Hercules, so just hardcode the analysis. - let liveness = liveness_dataflow( - editor.func(), - editor.func_id(), - control_subgraph, - objects, - bbs, - ); - - // Step 2: compute an interference graph from the liveness result. This + // Step 1: compute an interference graph from the liveness result. This // graph contains a vertex per node ID producing a collection value and an // edge per pair of node IDs that interfere. Nodes A and B interfere if node // A is defined right above a point where node B is live and A != B. Extra @@ -652,7 +730,7 @@ fn spill_clones( } } - // Step 3: filter edges (A, B) to just see edges where A uses B and A + // Step 2: filter edges (A, B) to just see edges where A uses B and A // mutates B. These are the edges that may require a spill. let mut spill_edges = edges.into_iter().filter(|(a, b)| { mutating_writes(editor.func(), *a, objects).any(|id| id == *b) @@ -664,7 +742,7 @@ fn spill_clones( || editor.func().nodes[a.idx()].is_reduce())) }); - // Step 4: if there is a spill edge, spill it and return true. Otherwise, + // Step 3: if there is a spill edge, spill it and return true. Otherwise, // return false. if let Some((user, obj)) = spill_edges.next() { // Figure out the most immediate dominating region for every basic @@ -818,6 +896,8 @@ fn spill_clones( } } +type Liveness = BTreeMap<NodeID, Vec<BTreeSet<NodeID>>>; + /* * Liveness dataflow analysis on scheduled Hercules IR. Just look at nodes that * involve collections. @@ -938,3 +1018,179 @@ fn liveness_dataflow( } } } + +/* + * Determine what device each node produces a collection onto. Insert inter- + * device clones when a single node may potentially be on different devices. + */ +fn color_nodes( + editor: &mut FunctionEditor, + reverse_postorder: &Vec<NodeID>, + objects: &FunctionCollectionObjects, + object_device_demands: &FunctionObjectDeviceDemands, +) -> Option<FunctionNodeColors> { + // First, try to give each node a single color. + let mut colors = BTreeMap::new(); + let mut bad_node = None; + 'nodes: for id in reverse_postorder { + let mut device = None; + for object in objects.objects(*id) { + for demand in object_device_demands[object.idx()].iter() { + if let Some(device) = device + && device != *demand + { + bad_node = Some(id); + break 'nodes; + } + device = Some(*demand); + } + } + if let Some(device) = device { + colors.insert(*id, device); + } else { + assert!(objects.objects(*id).is_empty(), "PANIC: Found an object with no device demands. This is technically possible and is easily supported by just picking an arbitrary device for this object. This assert exists because I'm curious to see where this will be needed first, and if that use is frivolous or not."); + } + } + if bad_node.is_some() { + todo!("Deal with inter-device demands.") + } + + Some(colors) +} + +fn align(edit: &mut FunctionEdit, mut acc: DynamicConstantID, align: usize) -> DynamicConstantID { + assert_ne!(align, 0); + if align != 1 { + let align_dc = edit.add_dynamic_constant(DynamicConstant::Constant(align)); + let align_m1_dc = edit.add_dynamic_constant(DynamicConstant::Constant(align - 1)); + acc = edit.add_dynamic_constant(DynamicConstant::Add(acc, align_m1_dc)); + acc = edit.add_dynamic_constant(DynamicConstant::Div(acc, align_dc)); + acc = edit.add_dynamic_constant(DynamicConstant::Mul(acc, align_dc)); + } + acc +} + +/* + * Determine the size of a type in terms of dynamic constants. + */ +fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) -> DynamicConstantID { + let ty = edit.get_type(ty_id).clone(); + let size = match ty { + Type::Control => panic!(), + Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 => { + edit.add_dynamic_constant(DynamicConstant::Constant(1)) + } + Type::Integer16 | Type::UnsignedInteger16 => { + edit.add_dynamic_constant(DynamicConstant::Constant(2)) + } + Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => { + edit.add_dynamic_constant(DynamicConstant::Constant(4)) + } + Type::Integer64 | Type::UnsignedInteger64 | Type::Float64 => { + edit.add_dynamic_constant(DynamicConstant::Constant(8)) + } + Type::Product(fields) => { + // The layout of product types is like the C-style layout. + let mut acc_size = edit.add_dynamic_constant(DynamicConstant::Constant(0)); + for field in fields { + // Round up to the alignment of the field, then add the size of + // the field. + let field_size = type_size(edit, field, alignments); + acc_size = align(edit, acc_size, alignments[field.idx()]); + acc_size = edit.add_dynamic_constant(DynamicConstant::Add(acc_size, field_size)); + } + // Finally, round up to the alignment of the whole product, since + // the size needs to be a multiple of the alignment. + acc_size = align(edit, acc_size, alignments[ty_id.idx()]); + acc_size + } + Type::Summation(variants) => { + // A summation holds every variant in the same memory. + let mut acc_size = edit.add_dynamic_constant(DynamicConstant::Constant(0)); + for variant in variants { + // Pick the size of the largest variant, since that's the most + // memory we would need. + let variant_size = type_size(edit, variant, alignments); + acc_size = edit.add_dynamic_constant(DynamicConstant::Max(acc_size, variant_size)); + } + // Add one byte for the discriminant and align the whole summation. + let one = edit.add_dynamic_constant(DynamicConstant::Constant(1)); + acc_size = edit.add_dynamic_constant(DynamicConstant::Add(acc_size, one)); + acc_size = align(edit, acc_size, alignments[ty_id.idx()]); + acc_size + } + Type::Array(elem, bounds) => { + // The layout of an array is row-major linear in memory. + let mut acc_size = type_size(edit, elem, alignments); + for bound in bounds { + acc_size = edit.add_dynamic_constant(DynamicConstant::Mul(acc_size, bound)); + } + acc_size + } + }; + size +} + +/* + * Allocate objects in a function. Relies on the allocations of all called + * functions. + */ +fn object_allocation( + editor: &mut FunctionEditor, + typing: &Vec<TypeID>, + node_colors: &FunctionNodeColors, + alignments: &Vec<usize>, + liveness: &Liveness, + backing_allocations: &BackingAllocations, +) -> FunctionBackingAllocation { + let mut fba = BTreeMap::new(); + + let node_ids = editor.node_ids(); + editor.edit(|mut edit| { + // For now, just allocate each object to its own slot. + let zero = edit.add_dynamic_constant(DynamicConstant::Constant(0)); + for id in node_ids { + match *edit.get_node(id) { + Node::Constant { id: _ } => { + if !edit.get_type(typing[id.idx()]).is_primitive() { + let device = node_colors[&id]; + let (total, offsets) = + fba.entry(device).or_insert_with(|| (zero, BTreeMap::new())); + *total = align(&mut edit, *total, alignments[typing[id.idx()].idx()]); + offsets.insert(id, *total); + let type_size = type_size(&mut edit, typing[id.idx()], alignments); + *total = edit.add_dynamic_constant(DynamicConstant::Add(*total, type_size)); + } + } + Node::Call { + control: _, + function: callee, + dynamic_constants: _, + args: _, + } => { + for device in BACKED_DEVICES { + if let Some(callee_backing_size) = backing_allocations[&callee] + .get(&device) + .map(|(callee_total, _)| *callee_total) + { + let (total, offsets) = + fba.entry(device).or_insert_with(|| (zero, BTreeMap::new())); + // We don't know the alignment requirement of the memory + // in the callee, so just assume the largest alignment. + *total = align(&mut edit, *total, LARGEST_ALIGNMENT); + offsets.insert(id, *total); + *total = edit.add_dynamic_constant(DynamicConstant::Add( + *total, + callee_backing_size, + )); + } + } + } + _ => {} + } + } + Ok(edit) + }); + + fba +} diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs index 0b10bdae..4a90f698 100644 --- a/hercules_opt/src/lib.rs +++ b/hercules_opt/src/lib.rs @@ -4,6 +4,7 @@ pub mod ccp; pub mod crc; pub mod dce; pub mod delete_uncalled; +pub mod device_placement; pub mod editor; pub mod float_collections; pub mod fork_concat_split; @@ -27,6 +28,7 @@ pub use crate::ccp::*; pub use crate::crc::*; pub use crate::dce::*; pub use crate::delete_uncalled::*; +pub use crate::device_placement::*; pub use crate::editor::*; pub use crate::float_collections::*; pub use crate::fork_concat_split::*; diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs index 60d3470e..db2dee77 100644 --- a/hercules_rt/src/lib.rs +++ b/hercules_rt/src/lib.rs @@ -1,337 +1,252 @@ -use std::alloc::{alloc, alloc_zeroed, dealloc, Layout}; +use std::alloc::{alloc, dealloc, Layout}; use std::marker::PhantomData; use std::ptr::{copy_nonoverlapping, NonNull}; -use std::slice::from_raw_parts; -use std::sync::atomic::{AtomicUsize, Ordering}; - -#[cfg(feature = "cuda")] -extern "C" { - fn cuda_alloc(size: usize) -> *mut u8; - fn cuda_alloc_zeroed(size: usize) -> *mut u8; - fn cuda_dealloc(ptr: *mut u8); - fn copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize); - fn copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize); - fn copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize); -} +use std::slice::{from_raw_parts, from_raw_parts_mut}; /* - * Each object needs to get assigned a unique ID. + * Define supporting types, functions, and macros for Hercules RT functions. For + * a more in-depth discussion of the design of these utilities, see hercules_cg/ + * src/rt.rs (the RT backend). */ -static NUM_OBJECTS: AtomicUsize = AtomicUsize::new(1); -/* - * An in-memory collection object that can be used by functions compiled by the - * Hercules compiler. Memory objects can be in these states: - * - * 1. Shared CPU - the object has a shared reference to some CPU memory, usually - * from the programmer using the Hercules RT API. - * 2. Exclusive CPU - the object has an exclusive reference to some CPU memory, - * usually from the programmer using the Hercules RT API. - * 3. Owned CPU - the object owns some allocated CPU memory. - * 4. Owned GPU - the object owns some allocated GPU memory. - * - * A single object can be in some combination of these objects at the same time. - * Only some combinations are valid, because only some combinations are - * reachable. Under this assumption, we can model an object's placement as a - * state machine, where states are combinations of the aforementioned states, - * and actions are requests on the CPU or GPU, immutably or mutably. Here's the - * state transition table: - * - * Shared CPU = CS - * Exclusive CPU = CE - * Owned CPU = CO - * Owned GPU = GO - * - * CPU Mut CPU GPU Mut GPU - * *--------------------------------------- - * CS | CS CO CS,GO GO - * CE | CE CE CE,GO GO - * CO | CO CO CO,GO GO - * GO | CO CO GO GO - * CS,GO | CS,GO CO CS,GO GO - * CE,GO | CE,GO CE CE,GO GO - * CO,GO | CO,GO CO CO,GO GO - * | - * - * A HerculesBox cannot be cloned, because it may have be a mutable reference to - * some CPU memory. - */ -#[derive(Debug)] -pub struct HerculesBox<'a> { - cpu_shared: Option<NonOwned<'a>>, - cpu_exclusive: Option<NonOwned<'a>>, - cpu_owned: Option<Owned>, +pub unsafe fn __cpu_alloc(size: usize) -> *mut u8 { + alloc(Layout::from_size_align(size, 16).unwrap()) +} - #[cfg(feature = "cuda")] - cuda_owned: Option<Owned>, +pub unsafe fn __cpu_dealloc(ptr: *mut u8, size: usize) { + dealloc(ptr, Layout::from_size_align(size, 16).unwrap()) +} - size: usize, - id: usize, +pub unsafe fn __copy_cpu_to_cpu(dst: *mut u8, src: *mut u8, size: usize) { + copy_nonoverlapping(src, dst, size); +} + +#[cfg(feature = "cuda")] +extern "C" { + pub fn __cuda_alloc(size: usize) -> *mut u8; + pub fn __cuda_dealloc(ptr: *mut u8, size: usize); + pub fn __copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize); + pub fn __copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize); + pub fn __copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize); } #[derive(Clone, Debug)] -struct NonOwned<'a> { +pub struct HerculesCPURef<'a> { ptr: NonNull<u8>, - offset: usize, + size: usize, _phantom: PhantomData<&'a u8>, } +#[derive(Debug)] +pub struct HerculesCPURefMut<'a> { + ptr: NonNull<u8>, + size: usize, + _phantom: PhantomData<&'a u8>, +} + +#[cfg(feature = "cuda")] #[derive(Clone, Debug)] -struct Owned { +pub struct HerculesCUDARef<'a> { ptr: NonNull<u8>, - alloc_size: usize, - offset: usize, + size: usize, + _phantom: PhantomData<&'a u8>, } -impl<'b, 'a: 'b> HerculesBox<'a> { - pub fn from_slice<T>(slice: &'a [T]) -> Self { - let ptr = unsafe { NonNull::new_unchecked(slice.as_ptr() as *mut u8) }; - let size = slice.len() * size_of::<T>(); - let id = NUM_OBJECTS.fetch_add(1, Ordering::Relaxed); - HerculesBox { - cpu_shared: Some(NonOwned { - ptr, - offset: 0, - _phantom: PhantomData, - }), - cpu_exclusive: None, - cpu_owned: None, - - #[cfg(feature = "cuda")] - cuda_owned: None, +#[cfg(feature = "cuda")] +#[derive(Debug)] +pub struct HerculesCUDARefMut<'a> { + ptr: NonNull<u8>, + size: usize, + _phantom: PhantomData<&'a u8>, +} - size, - id, - } - } +#[cfg(feature = "cuda")] +#[derive(Debug)] +pub struct CUDABox { + ptr: NonNull<u8>, + size: usize, +} - pub fn from_slice_mut<T>(slice: &'a mut [T]) -> Self { +impl<'a> HerculesCPURef<'a> { + pub fn from_slice<T>(slice: &'a [T]) -> Self { let ptr = unsafe { NonNull::new_unchecked(slice.as_ptr() as *mut u8) }; let size = slice.len() * size_of::<T>(); - let id = NUM_OBJECTS.fetch_add(1, Ordering::Relaxed); - HerculesBox { - cpu_shared: None, - cpu_exclusive: Some(NonOwned { - ptr, - offset: 0, - _phantom: PhantomData, - }), - cpu_owned: None, - - #[cfg(feature = "cuda")] - cuda_owned: None, - + Self { + ptr, size, - id, + _phantom: PhantomData, } } - pub fn as_slice<T>(&'b mut self) -> &'b [T] { + pub fn as_slice<T>(self) -> &'a [T] { + let ptr = self.ptr.as_ptr() as *const T; assert_eq!(self.size % size_of::<T>(), 0); - unsafe { from_raw_parts(self.__cpu_ptr() as *const T, self.size / size_of::<T>()) } + assert!(ptr.is_aligned()); + unsafe { from_raw_parts(ptr, self.size / size_of::<T>()) } } - unsafe fn get_cpu_ptr(&self) -> Option<NonNull<u8>> { - self.cpu_owned - .as_ref() - .map(|obj| obj.ptr.byte_add(obj.offset)) - .or(self - .cpu_exclusive - .as_ref() - .map(|obj| obj.ptr.byte_add(obj.offset))) - .or(self - .cpu_shared - .as_ref() - .map(|obj| obj.ptr.byte_add(obj.offset))) + pub unsafe fn __ptr(&self) -> *mut u8 { + self.ptr.as_ptr() as *mut u8 } - #[cfg(feature = "cuda")] - unsafe fn get_cuda_ptr(&self) -> Option<NonNull<u8>> { - self.cuda_owned - .as_ref() - .map(|obj| obj.ptr.byte_add(obj.offset)) + pub unsafe fn __size(&self) -> usize { + self.size } - unsafe fn allocate_cpu(&mut self) -> NonNull<u8> { - if let Some(obj) = self.cpu_owned.as_ref() { - obj.ptr - } else { - let ptr = - NonNull::new(alloc(Layout::from_size_align_unchecked(self.size, 16))).unwrap(); - self.cpu_owned = Some(Owned { - ptr, - alloc_size: self.size, - offset: 0, - }); - ptr + pub unsafe fn __from_parts(ptr: *mut u8, size: usize) -> Self { + Self { + ptr: NonNull::new(ptr).unwrap(), + size, + _phantom: PhantomData, } } +} - #[cfg(feature = "cuda")] - unsafe fn allocate_cuda(&mut self) -> NonNull<u8> { - if let Some(obj) = self.cuda_owned.as_ref() { - obj.ptr - } else { - let ptr = NonNull::new(cuda_alloc(self.size)).unwrap(); - self.cuda_owned = Some(Owned { - ptr, - alloc_size: self.size, - offset: 0, - }); - ptr +impl<'a> HerculesCPURefMut<'a> { + pub fn from_slice<T>(slice: &'a mut [T]) -> Self { + let ptr = unsafe { NonNull::new_unchecked(slice.as_ptr() as *mut u8) }; + let size = slice.len() * size_of::<T>(); + Self { + ptr, + size, + _phantom: PhantomData, } } - unsafe fn deallocate_cpu(&mut self) { - if let Some(obj) = self.cpu_owned.take() { - dealloc( - obj.ptr.as_ptr(), - Layout::from_size_align_unchecked(obj.alloc_size, 16), - ); - } + pub fn as_slice<T>(self) -> &'a mut [T] { + let ptr = self.ptr.as_ptr() as *mut T; + assert_eq!(self.size % size_of::<T>(), 0); + assert!(ptr.is_aligned()); + unsafe { from_raw_parts_mut(ptr, self.size / size_of::<T>()) } } - #[cfg(feature = "cuda")] - unsafe fn deallocate_cuda(&mut self) { - if let Some(obj) = self.cuda_owned.take() { - cuda_dealloc(obj.ptr.as_ptr()); + pub fn as_ref(self) -> HerculesCPURef<'a> { + HerculesCPURef { + ptr: self.ptr, + size: self.size, + _phantom: PhantomData, } } - pub unsafe fn __zeros(size: u64) -> Self { - let size = size as usize; - let id = NUM_OBJECTS.fetch_add(1, Ordering::Relaxed); - HerculesBox { - cpu_shared: None, - cpu_exclusive: None, - cpu_owned: Some(Owned { - ptr: NonNull::new(alloc_zeroed(Layout::from_size_align_unchecked(size, 16))) - .unwrap(), - alloc_size: size, - offset: 0, - }), - - #[cfg(feature = "cuda")] - cuda_owned: None, + pub unsafe fn __ptr(&self) -> *mut u8 { + self.ptr.as_ptr() + } + + pub unsafe fn __size(&self) -> usize { + self.size + } + pub unsafe fn __from_parts(ptr: *mut u8, size: usize) -> Self { + Self { + ptr: NonNull::new(ptr).unwrap(), size, - id, + _phantom: PhantomData, } } +} - pub unsafe fn __null() -> Self { - HerculesBox { - cpu_shared: None, - cpu_exclusive: None, - cpu_owned: None, - - #[cfg(feature = "cuda")] - cuda_owned: None, +#[cfg(feature = "cuda")] +impl<'a> HerculesCUDARef<'a> { + pub unsafe fn __ptr(&self) -> *mut u8 { + self.ptr.as_ptr() + } - size: 0, - id: 0, - } + pub unsafe fn __size(&self) -> usize { + self.size } - pub unsafe fn __cpu_ptr(&mut self) -> *mut u8 { - if let Some(ptr) = self.get_cpu_ptr() { - return ptr.as_ptr(); - } - #[cfg(feature = "cuda")] - { - let cuda_ptr = self.get_cuda_ptr().unwrap(); - let cpu_ptr = self.allocate_cpu(); - copy_cuda_to_cpu(cpu_ptr.as_ptr(), cuda_ptr.as_ptr(), self.size); - return cpu_ptr.as_ptr(); + pub unsafe fn __from_parts(ptr: *mut u8, size: usize) -> Self { + Self { + ptr: NonNull::new(ptr).unwrap(), + size, + _phantom: PhantomData, } - panic!() } +} - pub unsafe fn __cpu_ptr_mut(&mut self) -> *mut u8 { - let cpu_ptr = self.__cpu_ptr(); - if Some(cpu_ptr) == self.cpu_shared.as_ref().map(|obj| obj.ptr.as_ptr()) { - self.allocate_cpu(); - copy_nonoverlapping( - cpu_ptr, - self.cpu_owned.as_ref().unwrap().ptr.as_ptr(), - self.size, - ); +#[cfg(feature = "cuda")] +impl<'a> HerculesCUDARefMut<'a> { + pub fn as_ref(self) -> HerculesCUDARef<'a> { + HerculesCUDARef { + ptr: self.ptr, + size: self.size, + _phantom: PhantomData, } - self.cpu_shared = None; - #[cfg(feature = "cuda")] - self.deallocate_cuda(); - cpu_ptr } - #[cfg(feature = "cuda")] - pub unsafe fn __cuda_ptr(&mut self) -> *mut u8 { - if let Some(ptr) = self.get_cuda_ptr() { - ptr.as_ptr() - } else { - let cpu_ptr = self.get_cpu_ptr().unwrap(); - let cuda_ptr = self.allocate_cuda(); - copy_cpu_to_cuda(cuda_ptr.as_ptr(), cpu_ptr.as_ptr(), self.size); - cuda_ptr.as_ptr() - } + pub unsafe fn __ptr(&self) -> *mut u8 { + self.ptr.as_ptr() } - #[cfg(feature = "cuda")] - pub unsafe fn __cuda_ptr_mut(&mut self) -> *mut u8 { - let cuda_ptr = self.__cuda_ptr(); - self.cpu_shared = None; - self.cpu_exclusive = None; - self.deallocate_cpu(); - cuda_ptr + pub unsafe fn __size(&self) -> usize { + self.size } - pub unsafe fn __clone(&self) -> Self { + pub unsafe fn __from_parts(ptr: *mut u8, size: usize) -> Self { Self { - cpu_shared: self.cpu_shared.clone(), - cpu_exclusive: self.cpu_exclusive.clone(), - cpu_owned: self.cpu_owned.clone(), - #[cfg(feature = "cuda")] - cuda_owned: self.cuda_owned.clone(), - size: self.size, - id: self.id, + ptr: NonNull::new(ptr).unwrap(), + size, + _phantom: PhantomData, } } +} - pub unsafe fn __forget(&mut self) { - self.cpu_owned = None; - #[cfg(feature = "cuda")] - { - self.cuda_owned = None; +#[cfg(feature = "cuda")] +impl CUDABox { + pub fn from_cpu_ref(cpu_ref: HerculesCPURef) -> Self { + unsafe { + let size = cpu_ref.size; + let ptr = NonNull::new(__cuda_alloc(size)).unwrap(); + __copy_cpu_to_cuda(ptr.as_ptr(), cpu_ref.ptr.as_ptr(), size); + Self { ptr, size } } } - pub unsafe fn __offset(&mut self, offset: u64, size: u64) { - if let Some(obj) = self.cpu_shared.as_mut() { - obj.offset += offset as usize; - } - if let Some(obj) = self.cpu_exclusive.as_mut() { - obj.offset += offset as usize; + pub fn from_cuda_ref(cuda_ref: HerculesCUDARef) -> Self { + unsafe { + let size = cuda_ref.size; + let ptr = NonNull::new(__cuda_alloc(size)).unwrap(); + __copy_cuda_to_cuda(ptr.as_ptr(), cuda_ref.ptr.as_ptr(), size); + Self { ptr, size } } - if let Some(obj) = self.cpu_owned.as_mut() { - obj.offset += offset as usize; + } + + pub fn get_ref<'a>(&'a self) -> HerculesCUDARef<'a> { + HerculesCUDARef { + ptr: self.ptr, + size: self.size, + _phantom: PhantomData, } - #[cfg(feature = "cuda")] - if let Some(obj) = self.cuda_owned.as_mut() { - obj.offset += offset as usize; + } + + pub fn get_ref_mut<'a>(&'a mut self) -> HerculesCUDARefMut<'a> { + HerculesCUDARefMut { + ptr: self.ptr, + size: self.size, + _phantom: PhantomData, } - self.size = size as usize; } +} - pub unsafe fn __cmp_ids(&self, other: &HerculesBox<'_>) -> bool { - self.id == other.id +#[cfg(feature = "cuda")] +impl Clone for CUDABox { + fn clone(&self) -> Self { + Self::from_cuda_ref(self.get_ref()) } } -impl<'a> Drop for HerculesBox<'a> { +#[cfg(feature = "cuda")] +impl Drop for CUDABox { fn drop(&mut self) { unsafe { - self.deallocate_cpu(); - #[cfg(feature = "cuda")] - self.deallocate_cuda(); + __cuda_dealloc(self.ptr.as_ptr(), self.size); } } } + +#[macro_export] +macro_rules! runner { + ($x: ident) => { + <concat_idents!(HerculesRunner_, $x)>::new() + }; +} diff --git a/hercules_rt/src/rtdefs.cu b/hercules_rt/src/rtdefs.cu index b7378d81..ab67ec98 100644 --- a/hercules_rt/src/rtdefs.cu +++ b/hercules_rt/src/rtdefs.cu @@ -1,5 +1,5 @@ extern "C" { - void *cuda_alloc(size_t size) { + void *__cuda_alloc(size_t size) { void *ptr = NULL; cudaError_t res = cudaMalloc(&ptr, size); if (res != cudaSuccess) { @@ -8,31 +8,20 @@ extern "C" { return ptr; } - void *cuda_alloc_zeroed(size_t size) { - void *ptr = cuda_alloc(size); - if (!ptr) { - return NULL; - } - cudaError_t res = cudaMemset(ptr, 0, size); - if (res != cudaSuccess) { - return NULL; - } - return ptr; - } - - void cuda_dealloc(void *ptr) { + void __cuda_dealloc(void *ptr, size_t size) { + (void) size; cudaFree(ptr); } - void copy_cpu_to_cuda(void *dst, void *src, size_t size) { + void __copy_cpu_to_cuda(void *dst, void *src, size_t size) { cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice); } - void copy_cuda_to_cpu(void *dst, void *src, size_t size) { + void __copy_cuda_to_cpu(void *dst, void *src, size_t size) { cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost); } - void copy_cuda_to_cuda(void *dst, void *src, size_t size) { + void __copy_cuda_to_cuda(void *dst, void *src, size_t size) { cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice); } } diff --git a/hercules_samples/call/src/main.rs b/hercules_samples/call/src/main.rs index 0b657dd8..ff4b6f4a 100644 --- a/hercules_samples/call/src/main.rs +++ b/hercules_samples/call/src/main.rs @@ -1,11 +1,15 @@ -#![feature(box_as_ptr, let_chains)] +#![feature(concat_idents)] + +use hercules_rt::runner; juno_build::juno!("call"); fn main() { async_std::task::block_on(async { - let x = myfunc(7).await; - let y = add(10, 2, 18).await; + let mut r = runner!(myfunc); + let x = r.run(7).await; + let mut r = runner!(add); + let y = r.run(10, 2, 18).await; assert_eq!(x, y); }); } diff --git a/hercules_samples/ccp/src/main.rs b/hercules_samples/ccp/src/main.rs index 7f6459a0..ecf37973 100644 --- a/hercules_samples/ccp/src/main.rs +++ b/hercules_samples/ccp/src/main.rs @@ -1,10 +1,13 @@ -#![feature(box_as_ptr, let_chains)] +#![feature(concat_idents)] + +use hercules_rt::runner; juno_build::juno!("ccp"); fn main() { async_std::task::block_on(async { - let x = tricky(7).await; + let mut r = runner!(tricky); + let x = r.run(7).await; assert_eq!(x, 1); }); } diff --git a/hercules_samples/dot/src/main.rs b/hercules_samples/dot/src/main.rs index 0b5c6a93..335e8909 100644 --- a/hercules_samples/dot/src/main.rs +++ b/hercules_samples/dot/src/main.rs @@ -1,6 +1,6 @@ -#![feature(box_as_ptr, let_chains)] +#![feature(concat_idents)] -use hercules_rt::HerculesBox; +use hercules_rt::{runner, HerculesCPURef}; juno_build::juno!("dot"); @@ -8,9 +8,10 @@ fn main() { async_std::task::block_on(async { let a: [f32; 8] = [0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0]; let b: [f32; 8] = [0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0]; - let a = HerculesBox::from_slice(&a); - let b = HerculesBox::from_slice(&b); - let c = dot(8, a, b).await; + let a = HerculesCPURef::from_slice(&a); + let b = HerculesCPURef::from_slice(&b); + let mut r = runner!(dot); + let c = r.run(8, a, b).await; println!("{}", c); assert_eq!(c, 70.0); }); diff --git a/hercules_samples/fac/src/main.rs b/hercules_samples/fac/src/main.rs index b6e0257b..40180d44 100644 --- a/hercules_samples/fac/src/main.rs +++ b/hercules_samples/fac/src/main.rs @@ -1,8 +1,13 @@ +#![feature(concat_idents)] + +use hercules_rt::runner; + juno_build::juno!("fac"); fn main() { async_std::task::block_on(async { - let f = fac(8).await; + let mut r = runner!(fac); + let f = r.run(8).await; println!("{}", f); assert_eq!(f, 40320); }); diff --git a/hercules_samples/matmul/src/cpu.sch b/hercules_samples/matmul/src/cpu.sch index 42dda6e3..f7891b9b 100644 --- a/hercules_samples/matmul/src/cpu.sch +++ b/hercules_samples/matmul/src/cpu.sch @@ -10,5 +10,8 @@ fork-split(*); unforkify(*); dce(*); float-collections(*); +gvn(*); +phi-elim(*); +dce(*); gcm(*); diff --git a/hercules_samples/matmul/src/gpu.sch b/hercules_samples/matmul/src/gpu.sch index 9067a190..99ac21a6 100644 --- a/hercules_samples/matmul/src/gpu.sch +++ b/hercules_samples/matmul/src/gpu.sch @@ -10,6 +10,9 @@ ip-sroa(*); sroa(*); dce(*); float-collections(*); +gvn(*); +phi-elim(*); +dce(*); gcm(*); xdot[true](*); diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs index 767fda07..8757a0fd 100644 --- a/hercules_samples/matmul/src/main.rs +++ b/hercules_samples/matmul/src/main.rs @@ -1,8 +1,8 @@ -#![feature(box_as_ptr, let_chains)] +#![feature(concat_idents)] use rand::random; -use hercules_rt::HerculesBox; +use hercules_rt::{runner, HerculesCPURef}; juno_build::juno!("matmul"); @@ -21,9 +21,10 @@ fn main() { } } } - let a = HerculesBox::from_slice_mut(&mut a); - let b = HerculesBox::from_slice_mut(&mut b); - let mut c = matmul(I as u64, J as u64, K as u64, a, b).await; + let a = HerculesCPURef::from_slice(&mut a); + let b = HerculesCPURef::from_slice(&mut b); + let mut r = runner!(matmul); + let c = r.run(I as u64, J as u64, K as u64, a, b).await; assert_eq!(c.as_slice::<i32>(), &*correct_c); }); } diff --git a/juno_samples/antideps/src/main.rs b/juno_samples/antideps/src/main.rs index 2f1e8efc..9c37bd01 100644 --- a/juno_samples/antideps/src/main.rs +++ b/juno_samples/antideps/src/main.rs @@ -1,34 +1,43 @@ -#![feature(future_join, box_as_ptr)] +#![feature(concat_idents)] + +use hercules_rt::runner; juno_build::juno!("antideps"); fn main() { async_std::task::block_on(async { - let output = simple_antideps(1, 1).await; + let mut r = runner!(simple_antideps); + let output = r.run(1, 1).await; println!("{}", output); assert_eq!(output, 5); - let output = loop_antideps(11).await; + let mut r = runner!(loop_antideps); + let output = r.run(11).await; println!("{}", output); assert_eq!(output, 5); - let output = complex_antideps1(9).await; + let mut r = runner!(complex_antideps1); + let output = r.run(9).await; println!("{}", output); assert_eq!(output, 20); - let output = complex_antideps2(44).await; + let mut r = runner!(complex_antideps2); + let output = r.run(44).await; println!("{}", output); assert_eq!(output, 226); - let output = very_complex_antideps(3).await; + let mut r = runner!(very_complex_antideps); + let output = r.run(3).await; println!("{}", output); assert_eq!(output, 144); - let output = read_chains(2).await; + let mut r = runner!(read_chains); + let output = r.run(2).await; println!("{}", output); assert_eq!(output, 14); - let output = array_of_structs(2).await; + let mut r = runner!(array_of_structs); + let output = r.run(2).await; println!("{}", output); assert_eq!(output, 14); }); diff --git a/juno_samples/casts_and_intrinsics/src/main.rs b/juno_samples/casts_and_intrinsics/src/main.rs index 8ee509bf..6b27c60c 100644 --- a/juno_samples/casts_and_intrinsics/src/main.rs +++ b/juno_samples/casts_and_intrinsics/src/main.rs @@ -1,10 +1,13 @@ -#![feature(future_join)] +#![feature(concat_idents)] + +use hercules_rt::runner; juno_build::juno!("casts_and_intrinsics"); fn main() { async_std::task::block_on(async { - let output = casts_and_intrinsics(16.0).await; + let mut r = runner!(casts_and_intrinsics); + let output = r.run(16.0).await; println!("{}", output); assert_eq!(output, 4); }); diff --git a/juno_samples/cava/src/main.rs b/juno_samples/cava/src/main.rs index 9c2f99a8..73a75a94 100644 --- a/juno_samples/cava/src/main.rs +++ b/juno_samples/cava/src/main.rs @@ -1,4 +1,4 @@ -#![feature(future_join, box_as_ptr, let_chains)] +#![feature(concat_idents)] mod camera_model; mod cava_rust; @@ -8,7 +8,7 @@ use self::camera_model::*; use self::cava_rust::CHAN; use self::image_proc::*; -use hercules_rt::HerculesBox; +use hercules_rt::{runner, HerculesCPURef}; use image::ImageError; @@ -28,25 +28,26 @@ fn run_cava( tonemap: &[f32], ) -> Box<[u8]> { assert_eq!(image.len(), CHAN * rows * cols); - let image = HerculesBox::from_slice(image); + let image = HerculesCPURef::from_slice(image); assert_eq!(tstw.len(), CHAN * CHAN); - let tstw = HerculesBox::from_slice(tstw); + let tstw = HerculesCPURef::from_slice(tstw); assert_eq!(ctrl_pts.len(), num_ctrl_pts * CHAN); - let ctrl_pts = HerculesBox::from_slice(ctrl_pts); + let ctrl_pts = HerculesCPURef::from_slice(ctrl_pts); assert_eq!(weights.len(), num_ctrl_pts * CHAN); - let weights = HerculesBox::from_slice(weights); + let weights = HerculesCPURef::from_slice(weights); assert_eq!(coefs.len(), 4 * CHAN); - let coefs = HerculesBox::from_slice(coefs); + let coefs = HerculesCPURef::from_slice(coefs); assert_eq!(tonemap.len(), 256 * CHAN); - let tonemap = HerculesBox::from_slice(tonemap); + let tonemap = HerculesCPURef::from_slice(tonemap); + let mut r = runner!(cava); async_std::task::block_on(async { - cava( + r.run( rows as u64, cols as u64, num_ctrl_pts as u64, @@ -58,7 +59,7 @@ fn run_cava( tonemap, ) .await - }).as_slice::<u8>().into() + }).as_slice::<u8>().to_vec().into_boxed_slice() } enum Error { diff --git a/juno_samples/concat/src/main.rs b/juno_samples/concat/src/main.rs index 17a0ab96..db3f37fd 100644 --- a/juno_samples/concat/src/main.rs +++ b/juno_samples/concat/src/main.rs @@ -1,10 +1,13 @@ -#![feature(future_join, box_as_ptr)] +#![feature(concat_idents)] + +use hercules_rt::runner; juno_build::juno!("concat"); fn main() { async_std::task::block_on(async { - let output = concat_entry(7).await; + let mut r = runner!(concat_entry); + let output = r.run(7).await; println!("{}", output); assert_eq!(output, 42); }); diff --git a/juno_samples/implicit_clone/src/main.rs b/juno_samples/implicit_clone/src/main.rs index bc687ed3..1e94ff89 100644 --- a/juno_samples/implicit_clone/src/main.rs +++ b/juno_samples/implicit_clone/src/main.rs @@ -1,38 +1,48 @@ -#![feature(future_join, box_as_ptr)] +#![feature(concat_idents)] + +use hercules_rt::runner; juno_build::juno!("implicit_clone"); fn main() { async_std::task::block_on(async { - let output = simple_implicit_clone(3).await; + let mut r = runner!(simple_implicit_clone); + let output = r.run(3).await; println!("{}", output); assert_eq!(output, 11); - let output = loop_implicit_clone(100).await; + let mut r = runner!(loop_implicit_clone); + let output = r.run(100).await; println!("{}", output); assert_eq!(output, 7); - let output = double_loop_implicit_clone(3).await; + let mut r = runner!(double_loop_implicit_clone); + let output = r.run(3).await; println!("{}", output); assert_eq!(output, 42); - let output = tricky_loop_implicit_clone(2, 2).await; + let mut r = runner!(tricky_loop_implicit_clone); + let output = r.run(2, 2).await; println!("{}", output); assert_eq!(output, 130); - let output = tricky2_loop_implicit_clone(2, 3).await; + let mut r = runner!(tricky2_loop_implicit_clone); + let output = r.run(2, 3).await; println!("{}", output); assert_eq!(output, 39); - let output = tricky3_loop_implicit_clone(5, 7).await; + let mut r = runner!(tricky3_loop_implicit_clone); + let output = r.run(5, 7).await; println!("{}", output); assert_eq!(output, 7); - let output = no_implicit_clone(4).await; + let mut r = runner!(no_implicit_clone); + let output = r.run(4).await; println!("{}", output); assert_eq!(output, 13); - let output = mirage_implicit_clone(73).await; + let mut r = runner!(mirage_implicit_clone); + let output = r.run(73).await; println!("{}", output); assert_eq!(output, 843); }); diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs index bace3765..fa5d1f04 100644 --- a/juno_samples/matmul/src/main.rs +++ b/juno_samples/matmul/src/main.rs @@ -1,8 +1,8 @@ -#![feature(box_as_ptr, let_chains)] +#![feature(concat_idents)] use rand::random; -use hercules_rt::HerculesBox; +use hercules_rt::{runner, HerculesCPURef}; juno_build::juno!("matmul"); @@ -21,17 +21,13 @@ fn main() { } } } - let mut c = { - let a = HerculesBox::from_slice(&a); - let b = HerculesBox::from_slice(&b); - matmul(I as u64, J as u64, K as u64, a, b).await - }; + let a = HerculesCPURef::from_slice(&a); + let b = HerculesCPURef::from_slice(&b); + let mut r = runner!(matmul); + let c = r.run(I as u64, J as u64, K as u64, a.clone(), b.clone()).await; assert_eq!(c.as_slice::<i32>(), &*correct_c); - let mut tiled_c = { - let a = HerculesBox::from_slice(&a); - let b = HerculesBox::from_slice(&b); - tiled_64_matmul(I as u64, J as u64, K as u64, a, b).await - }; + let mut r = runner!(tiled_64_matmul); + let tiled_c = r.run(I as u64, J as u64, K as u64, a.clone(), b.clone()).await; assert_eq!(tiled_c.as_slice::<i32>(), &*correct_c); }); } diff --git a/juno_samples/nested_ccp/src/main.rs b/juno_samples/nested_ccp/src/main.rs index f49171ce..423b66fb 100644 --- a/juno_samples/nested_ccp/src/main.rs +++ b/juno_samples/nested_ccp/src/main.rs @@ -1,18 +1,21 @@ -#![feature(box_as_ptr, let_chains)] +#![feature(concat_idents)] -use hercules_rt::HerculesBox; +use hercules_rt::{runner, HerculesCPURef, HerculesCPURefMut}; juno_build::juno!("nested_ccp"); fn main() { async_std::task::block_on(async { - let mut a: Box<[f32]> = Box::new([17.0, 18.0, 19.0]); + let a: Box<[f32]> = Box::new([17.0, 18.0, 19.0]); let mut b: Box<[i32]> = Box::new([12, 16, 4, 18, 23, 56, 93, 22, 14]); - let a = HerculesBox::from_slice_mut(&mut a); - let b = HerculesBox::from_slice_mut(&mut b); - let output_example = ccp_example(a).await; - let output_median = median_array(9, b).await; - let out_no_underflow = no_underflow().await; + let a = HerculesCPURef::from_slice(&a); + let b = HerculesCPURefMut::from_slice(&mut b); + let mut r = runner!(ccp_example); + let output_example = r.run(a).await; + let mut r = runner!(median_array); + let output_median = r.run(9, b).await; + let mut r = runner!(no_underflow); + let out_no_underflow = r.run().await; println!("{}", output_example); println!("{}", output_median); println!("{}", out_no_underflow); diff --git a/juno_samples/schedule_test/src/main.rs b/juno_samples/schedule_test/src/main.rs index a64cd16f..2e63babf 100644 --- a/juno_samples/schedule_test/src/main.rs +++ b/juno_samples/schedule_test/src/main.rs @@ -1,8 +1,8 @@ -#![feature(box_as_ptr, let_chains)] +#![feature(concat_idents)] use rand::random; -use hercules_rt::HerculesBox; +use hercules_rt::{runner, HerculesCPURef}; juno_build::juno!("code"); @@ -26,12 +26,11 @@ fn main() { } } - let mut res = { - let a = HerculesBox::from_slice(&a); - let b = HerculesBox::from_slice(&b); - let c = HerculesBox::from_slice(&c); - test(N as u64, M as u64, K as u64, a, b, c).await - }; + let a = HerculesCPURef::from_slice(&a); + let b = HerculesCPURef::from_slice(&b); + let c = HerculesCPURef::from_slice(&c); + let mut r = runner!(test); + let res = r.run(N as u64, M as u64, K as u64, a, b, c).await; assert_eq!(res.as_slice::<i32>(), &*correct_res); }); } diff --git a/juno_samples/simple3/src/main.rs b/juno_samples/simple3/src/main.rs index 1f6e213c..4f9fe6a7 100644 --- a/juno_samples/simple3/src/main.rs +++ b/juno_samples/simple3/src/main.rs @@ -1,16 +1,17 @@ -#![feature(box_as_ptr, let_chains)] +#![feature(concat_idents)] -use hercules_rt::HerculesBox; +use hercules_rt::{runner, HerculesCPURef}; juno_build::juno!("simple3"); fn main() { async_std::task::block_on(async { - let mut a: Box<[u32]> = Box::new([1, 2, 3, 4, 5, 6, 7, 8]); - let mut b: Box<[u32]> = Box::new([8, 7, 6, 5, 4, 3, 2, 1]); - let a = HerculesBox::from_slice_mut(&mut a); - let b = HerculesBox::from_slice_mut(&mut b); - let c = simple3(8, a, b).await; + let a: Box<[u32]> = Box::new([1, 2, 3, 4, 5, 6, 7, 8]); + let b: Box<[u32]> = Box::new([8, 7, 6, 5, 4, 3, 2, 1]); + let a = HerculesCPURef::from_slice(&a); + let b = HerculesCPURef::from_slice(&b); + let mut r = runner!(simple3); + let c = r.run(8, a, b).await; println!("{}", c); assert_eq!(c, 120); }); diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 452c1995..aa540064 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2,13 +2,7 @@ use crate::ir::*; use crate::labels::*; use hercules_cg::*; use hercules_ir::*; -use hercules_opt::FunctionEditor; -use hercules_opt::{ - ccp, collapse_returns, crc, dce, dumb_outline, ensure_between_control_flow, float_collections, - fork_split, gcm, gvn, infer_parallel_fork, infer_parallel_reduce, infer_tight_associative, - infer_vectorizable, inline, interprocedural_sroa, lift_dc_math, outline, phi_elim, predication, - slf, sroa, unforkify, write_predication, -}; +use hercules_opt::*; use tempfile::TempDir; @@ -16,8 +10,7 @@ use juno_utils::env::Env; use juno_utils::stringtab::StringTable; use std::cell::RefCell; -use std::collections::{BTreeSet, HashMap, HashSet}; -use std::env::temp_dir; +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::fmt; use std::fs::File; use std::io::Write; @@ -186,9 +179,13 @@ struct PassManager { pub loops: Option<Vec<LoopTree>>, pub reduce_cycles: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, - pub bbs: Option<Vec<BasicBlocks>>, pub collection_objects: Option<CollectionObjects>, pub callgraph: Option<CallGraph>, + pub devices: Option<Vec<Device>>, + pub object_device_demands: Option<ObjectDeviceDemands>, + pub bbs: Option<Vec<BasicBlocks>>, + pub node_colors: Option<NodeColors>, + pub backing_allocations: Option<BackingAllocations>, } impl PassManager { @@ -217,9 +214,13 @@ impl PassManager { loops: None, reduce_cycles: None, data_nodes_in_fork_joins: None, - bbs: None, collection_objects: None, callgraph: None, + devices: None, + object_device_demands: None, + bbs: None, + node_colors: None, + backing_allocations: None, } } @@ -406,6 +407,35 @@ impl PassManager { } } + pub fn make_devices(&mut self) { + if self.devices.is_none() { + self.make_callgraph(); + let callgraph = self.callgraph.as_ref().unwrap(); + self.devices = Some(device_placement(&self.functions, callgraph)); + } + } + + pub fn make_object_device_demands(&mut self) { + if self.object_device_demands.is_none() { + self.make_typing(); + self.make_callgraph(); + self.make_collection_objects(); + self.make_devices(); + let typing = self.typing.as_ref().unwrap(); + let callgraph = self.callgraph.as_ref().unwrap(); + let collection_objects = self.collection_objects.as_ref().unwrap(); + let devices = self.devices.as_ref().unwrap(); + self.object_device_demands = Some(object_device_demands( + &self.functions, + &self.types.borrow(), + typing, + callgraph, + collection_objects, + devices, + )); + } + } + pub fn delete_gravestones(&mut self) { for func in self.functions.iter_mut() { func.delete_gravestones(); @@ -424,9 +454,13 @@ impl PassManager { self.loops = None; self.reduce_cycles = None; self.data_nodes_in_fork_joins = None; - self.bbs = None; self.collection_objects = None; self.callgraph = None; + self.devices = None; + self.object_device_demands = None; + self.bbs = None; + self.node_colors = None; + self.backing_allocations = None; } fn with_mod<B, F>(&mut self, mut f: F) -> B @@ -464,6 +498,7 @@ impl PassManager { self.make_control_subgraphs(); self.make_collection_objects(); self.make_callgraph(); + self.make_devices(); let PassManager { functions, @@ -473,15 +508,18 @@ impl PassManager { labels, typing: Some(typing), control_subgraphs: Some(control_subgraphs), - bbs: Some(bbs), collection_objects: Some(collection_objects), callgraph: Some(callgraph), + devices: Some(devices), + bbs: Some(bbs), + node_colors: Some(node_colors), + backing_allocations: Some(backing_allocations), .. } = self else { return Err(SchedulerError::PassError { pass: "codegen".to_string(), - error: "Missing basic blocks".to_string(), + error: "Missing basic blocks or backing allocations".to_string(), }); }; @@ -493,8 +531,6 @@ impl PassManager { labels: labels.into_inner(), }; - let devices = device_placement(&module.functions, &callgraph); - let mut rust_rt = String::new(); let mut llvm_ir = String::new(); for idx in 0..module.functions.len() { @@ -518,10 +554,12 @@ impl PassManager { &module, &typing[idx], &control_subgraphs[idx], - &bbs[idx], &collection_objects, &callgraph, &devices, + &bbs[idx], + &node_colors[idx], + &backing_allocations[&FunctionID::new(idx)], &mut rust_rt, ) .map_err(|e| SchedulerError::PassError { @@ -1178,10 +1216,10 @@ fn run_pass( pm.make_typing(); pm.make_callgraph(); + pm.make_devices(); let typing = pm.typing.take().unwrap(); let callgraph = pm.callgraph.take().unwrap(); - - let devices = device_placement(&pm.functions, &callgraph); + let devices = pm.devices.take().unwrap(); let mut editors = build_editors(pm); float_collections(&mut editors, &typing, &callgraph, &devices); @@ -1228,6 +1266,13 @@ fn run_pass( }); } + // Iterate functions in reverse topological order, since inter- + // device copies introduced in a callee may affect demands in a + // caller, and the object allocation of a callee affects the object + // allocation of its callers. + pm.make_callgraph(); + let callgraph = pm.callgraph.take().unwrap(); + let topo = callgraph.topo(); loop { pm.make_def_uses(); pm.make_reverse_postorders(); @@ -1237,6 +1282,8 @@ fn run_pass( pm.make_fork_join_maps(); pm.make_loops(); pm.make_collection_objects(); + pm.make_devices(); + pm.make_object_device_demands(); let def_uses = pm.def_uses.take().unwrap(); let reverse_postorders = pm.reverse_postorders.take().unwrap(); @@ -1246,47 +1293,47 @@ fn run_pass( let loops = pm.loops.take().unwrap(); let control_subgraphs = pm.control_subgraphs.take().unwrap(); let collection_objects = pm.collection_objects.take().unwrap(); - - let mut bbs = vec![]; - - for ( - ( - ( - ((((mut func, def_use), reverse_postorder), typing), control_subgraph), - doms, - ), - fork_join_map, - ), - loops, - ) in build_editors(pm) - .into_iter() - .zip(def_uses.iter()) - .zip(reverse_postorders.iter()) - .zip(typing.iter()) - .zip(control_subgraphs.iter()) - .zip(doms.iter()) - .zip(fork_join_maps.iter()) - .zip(loops.iter()) - { - if let Some(bb) = gcm( - &mut func, - def_use, - reverse_postorder, - typing, - control_subgraph, - doms, - fork_join_map, - loops, + let devices = pm.devices.take().unwrap(); + let object_device_demands = pm.object_device_demands.take().unwrap(); + + let mut bbs = vec![(vec![], vec![]); topo.len()]; + let mut node_colors = vec![BTreeMap::new(); topo.len()]; + let mut backing_allocations = BTreeMap::new(); + let mut editors = build_editors(pm); + let mut any_failed = false; + for id in topo.iter() { + let editor = &mut editors[id.idx()]; + if let Some((bb, function_node_colors, backing_allocation)) = gcm( + editor, + &def_uses[id.idx()], + &reverse_postorders[id.idx()], + &typing[id.idx()], + &control_subgraphs[id.idx()], + &doms[id.idx()], + &fork_join_maps[id.idx()], + &loops[id.idx()], &collection_objects, + &devices, + &object_device_demands[id.idx()], + &backing_allocations, ) { - bbs.push(bb); + bbs[id.idx()] = bb; + node_colors[id.idx()] = function_node_colors; + backing_allocations.insert(*id, backing_allocation); + } else { + any_failed = true; + } + changed |= editor.modified(); + if any_failed { + break; } - changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); - if bbs.len() == pm.functions.len() { + if !any_failed { pm.bbs = Some(bbs); + pm.node_colors = Some(node_colors); + pm.backing_allocations = Some(backing_allocations); break; } } @@ -1584,11 +1631,13 @@ fn run_pass( if force_analyses { pm.make_doms(); pm.make_fork_join_maps(); + pm.make_devices(); } let reverse_postorders = pm.reverse_postorders.take().unwrap(); let doms = pm.doms.take(); let fork_join_maps = pm.fork_join_maps.take(); + let devices = pm.devices.take(); let bbs = pm.bbs.take(); pm.with_mod(|module| { xdot_module( @@ -1596,6 +1645,7 @@ fn run_pass( &reverse_postorders, doms.as_ref(), fork_join_maps.as_ref(), + devices.as_ref(), bbs.as_ref(), ) }); -- GitLab