diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index 6c62ed76392dabc10c0fc48bf20d56d8577ad99e..d7a6d258358920d8a42aea1530843c08e223022a 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -288,7 +288,8 @@ impl GPUContext<'_> { self.get_root_forks_and_num_blocks(self.fork_tree); let (thread_root_root_fork, thread_root_forks) = self.get_thread_root_forks(&root_forks, self.fork_tree, is_block_parallel); let (fork_thread_quota_map, num_threads) = self.get_thread_quotas(self.fork_tree, thread_root_root_fork); - // TODO: Uncomment and adjust once we know logic of extra dim + // TODO: Uncomment and adjust once we know logic of extra dim. This will affect constant + // collections, reads, and writes. // let extra_dim_collects = self.get_extra_dim_collects(&fork_control_map, &fork_thread_quota_map); let extra_dim_collects = HashSet::new(); @@ -749,30 +750,9 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; fork_control_map: &HashMap<NodeID, HashSet<NodeID>>, fork_thread_quota_map: &HashMap<NodeID, (usize, usize, usize)>, ) -> HashSet<TypeID> { - // Get all constant collection creations - let collect_consts: HashSet<NodeID> = (0..self.function.nodes.len()) - .filter(|idx| self.function.nodes[*idx].is_constant() && !self.types[self.typing[*idx].idx()].is_primitive()) - .map(|idx| NodeID::new(idx)) - .collect(); - // Reverse fork_control_map - let control_fork_map: HashMap<NodeID, NodeID> = fork_control_map.iter() - .flat_map(|(fork, controls)| { - controls.iter().map(move |control| (*control, *fork)) - }) - .collect(); - // Get all uses of each collection, map each use to basic block, then map each basic block to fork - let collect_fork_users: HashMap<NodeID, HashSet<NodeID>> = collect_consts.iter() - .map(|collect_const| { - (*collect_const, self.def_use_map.get_users(*collect_const)) - }) - .map(|(collect_const, users)| { - (collect_const, users.iter().map(|user| control_fork_map[&self.bbs.0[user.idx()]]).collect()) - }) - .collect(); - collect_fork_users.iter() - .filter(|(_, fork_users)| !fork_thread_quota_map.contains_key(fork_users.iter().next().unwrap())) - .map(|(collect_const, _)| self.typing[collect_const.idx()]) - .collect() + // Determine which fork each collection is used in, and check if it's + // parallelized via the fork_thread_quota_map. + todo!() } fn codegen_data_control_no_forks( @@ -1237,7 +1217,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; } // Read of primitive requires load after pointer math. Node::Read { collect, indices } => { - let collect_with_indices = self.codegen_collect(*collect, indices, extra_dim_collects.contains(&self.typing[collect.idx()])); + let collect_with_indices = self.codegen_collect(*collect, indices, extra_dim_collects); let data_type_id = self.typing[id.idx()]; if self.types[data_type_id.idx()].is_primitive() { let type_name = self.get_type(data_type_id, true); @@ -1253,7 +1233,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; data, indices, } => { - let collect_with_indices = self.codegen_collect(*collect, indices, extra_dim_collects.contains(&self.typing[collect.idx()])); + let collect_with_indices = self.codegen_collect(*collect, indices, extra_dim_collects); let data_variable = self.get_value(*data, false, false); let data_type_id = self.typing[data.idx()]; let cg_tile = match state { @@ -1452,27 +1432,31 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; * This function emits collection name + pointer math for the provided indices. * All collection types use char pointers. */ - fn codegen_collect(&self, collect: NodeID, indices: &[Index], has_extra_dim: bool) -> String { + fn codegen_collect(&self, collect: NodeID, indices: &[Index], extra_dim_collects: &HashSet<TypeID>) -> String { let mut index_ptr = "0".to_string(); let type_id = self.typing[collect.idx()]; for index in indices { match index { Index::Field(field) => { - self.get_size(type_id, Some(*field), None); + self.get_size(type_id, Some(*field), Some(extra_dim_collects)); } // Variants of summations have zero offset Index::Variant(_) => {} // Convert multi-d array index to 1-d index, and optionally // convert to single-byte index by multiplying by element size Index::Position(array_indices) => { + let has_extra_dim = extra_dim_collects.contains(&self.typing[collect.idx()]); + if has_extra_dim { + continue; + } let Type::Array(element_type, extents) = &self.types[self.typing[collect.idx()].idx()] else { panic!("Expected array type") }; let mut cumulative_offset = multiply_dcs(&extents[array_indices.len()..]); - let max_left_array_index = array_indices.len() - 1 - if has_extra_dim { 1 } else { 0 }; - for (i, index) in array_indices.iter().skip(if has_extra_dim { 1 } else { 0 }).rev().enumerate() { + let max_left_array_index = array_indices.len() - 1; + for (i, index) in array_indices.iter().rev().enumerate() { cumulative_offset = format!( "{} * ({}{}", cumulative_offset, @@ -1487,9 +1471,9 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; index_ptr.push_str(&format!( " + {}{}", cumulative_offset, - ")".repeat(array_indices.len() - if has_extra_dim { 1 } else { 0 }) + ")".repeat(array_indices.len()) )); - let element_size = self.get_size(*element_type, None, None); + let element_size = self.get_size(*element_type, None, Some(extra_dim_collects)); index_ptr.push_str(&format!(" * {}", element_size)); } } @@ -1556,7 +1540,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; format!("*reinterpret_cast<{}>({}+{})", field_type, name, offset), constant_fields[i], false, - extra_dim_collects, + None, dynamic_shared_offset, w, num_tabs, @@ -1619,7 +1603,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; fn get_size(&self, type_id: TypeID, num_fields: Option<usize>, extra_dim_collects: Option<&HashSet<TypeID>>) -> String { match &self.types[type_id.idx()] { Type::Array(element_type, extents) => { - let array_size = multiply_dcs(if extra_dim_collects.is_some() && extra_dim_collects.unwrap().contains(&type_id) { &extents[1..] } else { extents }); + let array_size = if extra_dim_collects.is_some() && extra_dim_collects.unwrap().contains(&type_id) { "1".to_string() } else { multiply_dcs(extents) }; format!("{} * {}", self.get_alignment(*element_type), array_size) } Type::Product(fields) => {