diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs index 5374385288c4aa4940eafd293c58b0beeabbc5e3..ea326f8a0310fa082c240b5f52000f9c79e0be57 100644 --- a/hercules_cg/src/cpu.rs +++ b/hercules_cg/src/cpu.rs @@ -22,6 +22,7 @@ pub fn cpu_codegen<W: Write>( typing: &Vec<TypeID>, control_subgraph: &Subgraph, bbs: &BasicBlocks, + backing_allocation: &FunctionBackingAllocation, w: &mut W, ) -> Result<(), Error> { let ctx = CPUContext { @@ -32,6 +33,7 @@ pub fn cpu_codegen<W: Write>( typing, control_subgraph, bbs, + backing_allocation, }; ctx.codegen_function(w) } @@ -44,6 +46,7 @@ struct CPUContext<'a> { typing: &'a Vec<TypeID>, control_subgraph: &'a Subgraph, bbs: &'a BasicBlocks, + backing_allocation: &'a FunctionBackingAllocation, } #[derive(Default, Debug)] @@ -72,7 +75,13 @@ impl<'a> CPUContext<'a> { )?; } let mut first_param = true; - // The first set of parameters are dynamic constants. + // The first parameter is a pointer to CPU backing memory, if it's + // needed. + if self.backing_allocation.contains_key(&Device::LLVM) { + first_param = false; + write!(w, "ptr %backing")?; + } + // The second set of parameters are dynamic constants. for idx in 0..self.function.num_dynamic_constants { if first_param { first_param = false; @@ -81,7 +90,7 @@ impl<'a> CPUContext<'a> { } write!(w, "i64 %dc_p{}", idx)?; } - // The second set of parameters are normal parameters. + // The third set of parameters are normal parameters. 
for (idx, ty) in self.function.param_types.iter().enumerate() { if first_param { first_param = false; @@ -242,32 +251,50 @@ impl<'a> CPUContext<'a> { } Node::Constant { id: cons_id } => { let body = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap().body; - write!(body, " {} = bitcast ", self.get_value(id, false))?; - match self.constants[cons_id.idx()] { - Constant::Boolean(val) => write!(body, "i1 {} to i1\n", val)?, - Constant::Integer8(val) => write!(body, "i8 {} to i8\n", val)?, - Constant::Integer16(val) => write!(body, "i16 {} to i16\n", val)?, - Constant::Integer32(val) => write!(body, "i32 {} to i32\n", val)?, - Constant::Integer64(val) => write!(body, "i64 {} to i64\n", val)?, - Constant::UnsignedInteger8(val) => write!(body, "i8 {} to i8\n", val)?, - Constant::UnsignedInteger16(val) => write!(body, "i16 {} to i16\n", val)?, - Constant::UnsignedInteger32(val) => write!(body, "i32 {} to i32\n", val)?, - Constant::UnsignedInteger64(val) => write!(body, "i64 {} to i64\n", val)?, - Constant::Float32(val) => { - if val.fract() == 0.0 { - write!(body, "float {}.0 to float\n", val)? - } else { - write!(body, "float {} to float\n", val)? 
+ if self.constants[cons_id.idx()].is_scalar() { + write!(body, " {} = bitcast ", self.get_value(id, false))?; + match self.constants[cons_id.idx()] { + Constant::Boolean(val) => write!(body, "i1 {} to i1\n", val)?, + Constant::Integer8(val) => write!(body, "i8 {} to i8\n", val)?, + Constant::Integer16(val) => write!(body, "i16 {} to i16\n", val)?, + Constant::Integer32(val) => write!(body, "i32 {} to i32\n", val)?, + Constant::Integer64(val) => write!(body, "i64 {} to i64\n", val)?, + Constant::UnsignedInteger8(val) => write!(body, "i8 {} to i8\n", val)?, + Constant::UnsignedInteger16(val) => write!(body, "i16 {} to i16\n", val)?, + Constant::UnsignedInteger32(val) => write!(body, "i32 {} to i32\n", val)?, + Constant::UnsignedInteger64(val) => write!(body, "i64 {} to i64\n", val)?, + Constant::Float32(val) => { + if val.fract() == 0.0 { + write!(body, "float {}.0 to float\n", val)? + } else { + write!(body, "float {} to float\n", val)? + } } - } - Constant::Float64(val) => { - if val.fract() == 0.0 { - write!(body, "double {}.0 to double", val)? - } else { - write!(body, "double {} to double", val)? + Constant::Float64(val) => { + if val.fract() == 0.0 { + write!(body, "double {}.0 to double\n", val)? + } else { + write!(body, "double {} to double\n", val)? 
+ } } + _ => unreachable!(), } - _ => panic!("PANIC: Can't dynamically allocate memory for an aggregate type within a CPU function ({:?} in {}).", id, self.function.name), + } else { + let (_, offsets) = &self.backing_allocation[&Device::LLVM]; + let offset = offsets[&id]; + write!( + body, + " {} = getelementptr i8, ptr %backing, i64 %dc{}\n", + self.get_value(id, false), + offset.idx() + )?; + let data_size = self.codegen_type_size(self.typing[id.idx()], body)?; + write!( + body, + " call void @llvm.memset.p0.i64({}, i8 0, i64 {}, i1 false)\n", + self.get_value(id, true), + data_size, + )?; } } Node::DynamicConstant { id: dc_id } => { diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 1bfc01fef2885dbb2b6b6c5fdbda97f594e32010..f97180ea24d3bf1b69810d6e79cf68fcb292e9fb 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -83,7 +83,7 @@ pub fn rt_codegen<W: Write>( devices: &Vec<Device>, bbs: &BasicBlocks, node_colors: &FunctionNodeColors, - backing_allocation: &FunctionBackingAllocation, + backing_allocations: &BackingAllocations, w: &mut W, ) -> Result<(), Error> { let ctx = RTContext { @@ -96,7 +96,7 @@ pub fn rt_codegen<W: Write>( devices, bbs, node_colors, - backing_allocation, + backing_allocations, }; ctx.codegen_function(w) } @@ -111,7 +111,7 @@ struct RTContext<'a> { devices: &'a Vec<Device>, bbs: &'a BasicBlocks, node_colors: &'a FunctionNodeColors, - backing_allocation: &'a FunctionBackingAllocation, + backing_allocations: &'a BackingAllocations, } impl<'a> RTContext<'a> { @@ -131,7 +131,7 @@ impl<'a> RTContext<'a> { )?; let mut first_param = true; // The first set of parameters are pointers to backing memories. - for (device, _) in self.backing_allocation { + for (device, _) in self.backing_allocations[&self.func_id].iter() { if first_param { first_param = false; } else { @@ -161,13 +161,17 @@ impl<'a> RTContext<'a> { // Dump signatures for called device functions. 
write!(w, " extern \"C\" {{\n")?; - for callee in self.callgraph.get_callees(self.func_id) { - if self.devices[callee.idx()] == Device::AsyncRust { + for callee_id in self.callgraph.get_callees(self.func_id) { + if self.devices[callee_id.idx()] == Device::AsyncRust { continue; } - let callee = &self.module.functions[callee.idx()]; + let callee = &self.module.functions[callee_id.idx()]; write!(w, " fn {}(", callee.name)?; let mut first_param = true; + if self.backing_allocations[&callee_id].contains_key(&self.devices[callee_id.idx()]) { + first_param = false; + write!(w, "backing: *mut u8")?; + } for idx in 0..callee.num_dynamic_constants { if first_param { first_param = false; @@ -340,9 +344,8 @@ impl<'a> RTContext<'a> { Constant::Product(ty, _) | Constant::Summation(ty, _, _) | Constant::Array(ty) => { - let (device, offset) = self - .backing_allocation - .into_iter() + let (device, offset) = self.backing_allocations[&self.func_id] + .iter() .filter_map(|(device, (_, offsets))| { offsets.get(&id).map(|id| (*device, *id)) }) @@ -380,14 +383,13 @@ impl<'a> RTContext<'a> { self.get_value(id), self.module.functions[callee_id.idx()].name )?; - for (device, offset) in self - .backing_allocation - .into_iter() + for (device, offset) in self.backing_allocations[&self.func_id] + .iter() .filter_map(|(device, (_, offsets))| offsets.get(&id).map(|id| (*device, *id))) { write!(block, "backing_{}.byte_add(", device.name())?; self.codegen_dynamic_constant(offset, block)?; - write!(block, ")")? + write!(block, "), ")? 
} for dc in dynamic_constants { self.codegen_dynamic_constant(*dc, block)?; @@ -690,7 +692,7 @@ impl<'a> RTContext<'a> { "#[allow(non_camel_case_types)]\nstruct HerculesRunner_{} {{\n", func.name )?; - for (device, _) in self.backing_allocation { + for (device, _) in self.backing_allocations[&self.func_id].iter() { write!(w, " backing_ptr_{}: *mut u8,\n", device.name(),)?; write!(w, " backing_size_{}: usize,\n", device.name(),)?; } @@ -700,7 +702,7 @@ impl<'a> RTContext<'a> { "impl HerculesRunner_{} {{\n fn new() -> Self {{\n Self {{\n", func.name )?; - for (device, _) in self.backing_allocation { + for (device, _) in self.backing_allocations[&self.func_id].iter() { write!( w, " backing_ptr_{}: ::core::ptr::null_mut(),\n backing_size_{}: 0,\n", @@ -749,7 +751,7 @@ impl<'a> RTContext<'a> { )?; } write!(w, " unsafe {{\n")?; - for (device, (total, _)) in self.backing_allocation { + for (device, (total, _)) in self.backing_allocations[&self.func_id].iter() { write!(w, " let size = ")?; self.codegen_dynamic_constant(*total, w)?; write!( @@ -772,7 +774,7 @@ impl<'a> RTContext<'a> { } } write!(w, " let ret = {}(", func.name)?; - for (device, _) in self.backing_allocation { + for (device, _) in self.backing_allocations[&self.func_id].iter() { write!(w, "self.backing_ptr_{}, ", device.name())?; } for idx in 0..func.num_dynamic_constants { @@ -805,7 +807,7 @@ impl<'a> RTContext<'a> { "}}\nimpl Drop for HerculesRunner_{} {{\n #[allow(unused_unsafe)]\n fn drop(&mut self) {{\n unsafe {{\n", func.name )?; - for (device, _) in self.backing_allocation { + for (device, _) in self.backing_allocations[&self.func_id].iter() { write!( w, " ::hercules_rt::__{}_dealloc(self.backing_ptr_{}, self.backing_size_{});\n", diff --git a/hercules_ir/src/device.rs b/hercules_ir/src/device.rs index 46a1af0256286c02cec74eb09e282bdb3a845b4b..cbf8d6349857b6110cc4a9e03f574b358646b29a 100644 --- a/hercules_ir/src/device.rs +++ b/hercules_ir/src/device.rs @@ -45,9 +45,10 @@ pub fn object_device_demands( 
// function. This includes objects on the `data` input to write nodes. // Non-primitive reads don't demand an object on a device since they are // lowered to pointer math and no actual memory transfers. - // 2. The object is passed as input to a call node where the corresponding + // 2. The object is a constant / undef defined in a device function. + // 3. The object is passed as input to a call node where the corresponding // object in the callee is demanded on a device. - // 3. The object is returned from a call node where the corresponding object + // 4. The object is returned from a call node where the corresponding object // in the callee is demanded on a device. // Note that reads and writes in a RT function don't induce a device demand. // This is because RT functions can call device functions as necessary to @@ -66,8 +67,8 @@ pub fn object_device_demands( match device { Device::LLVM | Device::CUDA => { for (idx, node) in function.nodes.iter().enumerate() { - // Condition #1. match node { + // Condition #1. Node::Read { collect, indices: _, @@ -89,6 +90,12 @@ pub fn object_device_demands( demands[func_id.idx()][object.idx()].insert(device); } } + // Condition #2. + Node::Constant { id: _ } | Node::Undef { ty: _ } => { + for object in objects[&func_id].objects(NodeID::new(idx)) { + demands[func_id.idx()][object.idx()].insert(device); + } + } _ => {} } } @@ -102,7 +109,7 @@ pub fn object_device_demands( args, } = node { - // Condition #2. + // Condition #3. for (param_idx, arg) in args.into_iter().enumerate() { if let Some(callee_obj) = objects[callee].param_to_object(param_idx) { let callee_demands = @@ -115,7 +122,7 @@ pub fn object_device_demands( } } - // Condition #3. + // Condition #4. 
for callee_obj in objects[callee].returned_objects() { let callee_demands = take(&mut demands[callee.idx()][callee_obj.idx()]); for object in objects[&func_id].objects(NodeID::new(idx)) { diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index 462d10871565bfed9b351d2627c44af8ca778ffc..6d36e8ac2cd6c8fdff46f4e46b25a95e5b15db51 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -421,6 +421,9 @@ fn basic_blocks( // If the next node further up the dominator tree is in a shallower // loop nest or if we can get out of a reduce loop when we don't // need to be in one, place this data node in a higher-up location. + // Only do this if the node isn't a constant or undef. + let is_constant_or_undef = + function.nodes[id.idx()].is_constant() || function.nodes[id.idx()].is_undef(); let old_nest = loops .header_of(location) .map(|header| loops.nesting(header).unwrap()); @@ -441,7 +444,7 @@ fn basic_blocks( // loop use the reduce node forming the loop, so the dominator chain // will consist of one block, and this loop won't ever iterate.
let currently_at_join = function.nodes[location.idx()].is_join(); - if shallower_nest || currently_at_join { + if !is_constant_or_undef && (shallower_nest || currently_at_join) { location = control_node; } } diff --git a/juno_scheduler/src/default.rs b/juno_scheduler/src/default.rs index fd45a3713dfa48ac425a0c453707e133198e1d79..d1f139db58c944733d6cbbf43d0c52f651d72f23 100644 --- a/juno_scheduler/src/default.rs +++ b/juno_scheduler/src/default.rs @@ -81,8 +81,5 @@ pub fn default_schedule() -> ScheduleStmt { InferSchedules, DCE, GCM, - DCE, - FloatCollections, - GCM, ] } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 81c466568d80b152d39c8b1899580df272810f01..584867e5ebc4dcf1771295717315e45af52ff141 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -622,6 +622,7 @@ impl PassManager { &typing[idx], &control_subgraphs[idx], &bbs[idx], + &backing_allocations[&FunctionID::new(idx)], &mut llvm_ir, ) .map_err(|e| SchedulerError::PassError { @@ -638,7 +639,7 @@ impl PassManager { &devices, &bbs[idx], &node_colors[idx], - &backing_allocations[&FunctionID::new(idx)], + &backing_allocations, &mut rust_rt, ) .map_err(|e| SchedulerError::PassError {