From af3175773528d77fa06b280e027029943e2539da Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Tue, 25 Feb 2025 12:34:37 -0600 Subject: [PATCH] Use different allocations in parallel calls --- hercules_cg/src/cpu.rs | 2 +- hercules_cg/src/gpu.rs | 2 +- hercules_cg/src/lib.rs | 11 +++++++--- hercules_cg/src/rt.rs | 29 +++++++++++++++++++------ hercules_opt/src/gcm.rs | 6 ++--- juno_samples/edge_detection/src/cpu.sch | 4 ++-- juno_scheduler/src/pm.rs | 4 +++- 7 files changed, 40 insertions(+), 18 deletions(-) diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs index 6ad38fc0..552dc3a3 100644 --- a/hercules_cg/src/cpu.rs +++ b/hercules_cg/src/cpu.rs @@ -334,7 +334,7 @@ impl<'a> CPUContext<'a> { } } else { let (_, offsets) = &self.backing_allocation[&Device::LLVM]; - let offset = offsets[&id]; + let offset = offsets[&id].0; write!( body, " {} = getelementptr i8, ptr %backing, i64 %dc{}\n", diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index 76aba7e0..5f2feedd 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -1290,7 +1290,7 @@ namespace cg = cooperative_groups; if !is_primitive && state == KernelState::OutBlock { assert!(self.function.schedules[id.idx()].contains(&Schedule::NoResetConstant), "PANIC: The CUDA backend cannot lower a global memory constant that has to be reset to zero. This is because we cannot efficiently implement a memset to the underlying memory of the constant due to the need for a grid level sync. Consider floating this collection outside the CUDA function and into an AsyncRust function, or attaching the NoResetConstant schedule to indicate that no memset is semantically necessary."); let (_, offsets) = &self.backing_allocation[&Device::CUDA]; - let offset = offsets[&id]; + let offset = offsets[&id].0; write!( w, "{}{} = backing + dc{};\n", diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs index 98f91e1f..9866400c 100644 --- a/hercules_cg/src/lib.rs +++ b/hercules_cg/src/lib.rs @@ -53,10 +53,15 @@ pub type NodeColors = BTreeMap<FunctionID, FunctionNodeColors>; /* * The allocation information of each function is a size of the backing memory * needed and offsets into that backing memory per constant object and call node - * in the function. + * in the function (as well as their individual sizes). */ -pub type FunctionBackingAllocation = - BTreeMap<Device, (DynamicConstantID, BTreeMap<NodeID, DynamicConstantID>)>; +pub type FunctionBackingAllocation = BTreeMap< + Device, + ( + DynamicConstantID, + BTreeMap<NodeID, (DynamicConstantID, DynamicConstantID)>, + ), +>; pub type BackingAllocations = BTreeMap<FunctionID, FunctionBackingAllocation>; pub const BACKED_DEVICES: [Device; 2] = [Device::LLVM, Device::CUDA]; diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index ddfa9503..3db0f16f 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -80,7 +80,7 @@ pub fn rt_codegen<W: Write>( typing: &Vec<TypeID>, control_subgraph: &Subgraph, fork_join_map: &HashMap<NodeID, NodeID>, - fork_control_map: &HashMap<NodeID, HashSet<NodeID>>, + fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, fork_tree: &HashMap<NodeID, HashSet<NodeID>>, nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, collection_objects: &CollectionObjects, @@ -103,7 +103,7 @@ pub fn rt_codegen<W: Write>( control_subgraph, fork_join_map, join_fork_map: &join_fork_map, - fork_control_map, + fork_join_nest, fork_tree, nodes_in_fork_joins, collection_objects, @@ -124,7 +124,7 @@ struct RTContext<'a> { control_subgraph: &'a Subgraph, fork_join_map: &'a HashMap<NodeID, NodeID>, join_fork_map: &'a HashMap<NodeID, NodeID>, - fork_control_map: &'a HashMap<NodeID, HashSet<NodeID>>, + fork_join_nest: &'a HashMap<NodeID, Vec<NodeID>>, fork_tree: &'a HashMap<NodeID, HashSet<NodeID>>, nodes_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>, collection_objects: &'a CollectionObjects, @@ -559,7 +559,7 @@ impl<'a> RTContext<'a> { Constant::Product(ty, _) | Constant::Summation(ty, _, _) | Constant::Array(ty) => { - let (device, offset) = self.backing_allocations[&self.func_id] + let (device, (offset, _)) = self.backing_allocations[&self.func_id] .iter() .filter_map(|(device, (_, offsets))| { offsets.get(&id).map(|id| (*device, *id)) @@ -676,13 +676,28 @@ impl<'a> RTContext<'a> { prefix, self.module.functions[callee_id.idx()].name )?; - for (device, offset) in self.backing_allocations[&self.func_id] + for (device, (offset, size)) in self.backing_allocations[&self.func_id] .iter() .filter_map(|(device, (_, offsets))| offsets.get(&id).map(|id| (*device, *id))) { - write!(block, "backing_{}.byte_add(", device.name())?; + write!(block, "backing_{}.byte_add(((", device.name())?; self.codegen_dynamic_constant(offset, block)?; - write!(block, " as usize), ")? + let forks = &self.fork_join_nest[&bb]; + if !forks.is_empty() { + write!(block, ") + ")?; + let mut linear_thread = "0".to_string(); + for fork in forks { + let factors = func.nodes[fork.idx()].try_fork().unwrap().1; + for (factor_idx, factor) in factors.into_iter().enumerate() { + linear_thread = format!("({} *", linear_thread); + self.codegen_dynamic_constant(*factor, &mut linear_thread)?; + write!(linear_thread, " + tid_{}_{})", fork.idx(), factor_idx)?; + } + } + write!(block, "{} * (", linear_thread)?; + self.codegen_dynamic_constant(size, block)?; + } + write!(block, ")) as usize), ")? } for dc in dynamic_constants { self.codegen_dynamic_constant(*dc, block)?; diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index b415371f..c612acac 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -1647,7 +1647,7 @@ fn object_allocation( _liveness: &Liveness, backing_allocations: &BackingAllocations, ) -> FunctionBackingAllocation { - let mut fba = BTreeMap::new(); + let mut fba = FunctionBackingAllocation::new(); let node_ids = editor.node_ids(); editor.edit(|mut edit| { @@ -1661,8 +1661,8 @@ fn object_allocation( let (total, offsets) = fba.entry(device).or_insert_with(|| (zero, BTreeMap::new())); *total = align(&mut edit, *total, alignments[typing[id.idx()].idx()]); - offsets.insert(id, *total); let type_size = type_size(&mut edit, typing[id.idx()], alignments); + offsets.insert(id, (*total, type_size)); *total = edit.add_dynamic_constant(DynamicConstant::add(*total, type_size)); } } @@ -1689,7 +1689,6 @@ fn object_allocation( // We don't know the alignment requirement of the memory // in the callee, so just assume the largest alignment. *total = align(&mut edit, *total, LARGEST_ALIGNMENT); - offsets.insert(id, *total); // Substitute the dynamic constant parameters in the // callee's backing size. callee_backing_size = substitute_dynamic_constants( @@ -1697,6 +1696,7 @@ fn object_allocation( callee_backing_size, &mut edit, ); + offsets.insert(id, (*total, callee_backing_size)); // Multiply the backing allocation size of the // callee by the number of parallel threads that // will call the function. diff --git a/juno_samples/edge_detection/src/cpu.sch b/juno_samples/edge_detection/src/cpu.sch index 6f1ee14b..4bd3254b 100644 --- a/juno_samples/edge_detection/src/cpu.sch +++ b/juno_samples/edge_detection/src/cpu.sch @@ -107,8 +107,8 @@ simpl!(reject_zero_crossings); async-call(edge_detection@le, edge_detection@zc); -fork-split(gaussian_smoothing_body, laplacian_estimate, laplacian_estimate_body, zero_crossings, zero_crossings_body, gradient, reject_zero_crossings); -unforkify(gaussian_smoothing_body, laplacian_estimate, laplacian_estimate_body, zero_crossings, zero_crossings_body, gradient, reject_zero_crossings); +fork-split(gaussian_smoothing_body, laplacian_estimate_body, zero_crossings_body, gradient, reject_zero_crossings); +unforkify(gaussian_smoothing_body, laplacian_estimate_body, zero_crossings_body, gradient, reject_zero_crossings); simpl!(*); diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 77437a61..5f2fa4cc 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -900,6 +900,7 @@ impl PassManager { self.make_typing(); self.make_control_subgraphs(); self.make_fork_join_maps(); + self.make_fork_join_nests(); self.make_fork_control_maps(); self.make_fork_trees(); self.make_nodes_in_fork_joins(); @@ -917,6 +918,7 @@ impl PassManager { typing: Some(typing), control_subgraphs: Some(control_subgraphs), fork_join_maps: Some(fork_join_maps), + fork_join_nests: Some(fork_join_nests), fork_control_maps: Some(fork_control_maps), fork_trees: Some(fork_trees), nodes_in_fork_joins: Some(nodes_in_fork_joins), @@ -990,7 +992,7 @@ impl PassManager { &typing[idx], &control_subgraphs[idx], &fork_join_maps[idx], - &fork_control_maps[idx], + &fork_join_nests[idx], &fork_trees[idx], &nodes_in_fork_joins[idx], &collection_objects, -- GitLab