Skip to content
Snippets Groups Projects
Commit 12b5e54f authored by prathi3's avatar prathi3
Browse files

clean extra dim

parent 7849656e
No related branches found
No related tags found
1 merge request!115GPU backend
Pipeline #201299 passed
......@@ -288,7 +288,8 @@ impl GPUContext<'_> {
self.get_root_forks_and_num_blocks(self.fork_tree);
let (thread_root_root_fork, thread_root_forks) = self.get_thread_root_forks(&root_forks, self.fork_tree, is_block_parallel);
let (fork_thread_quota_map, num_threads) = self.get_thread_quotas(self.fork_tree, thread_root_root_fork);
// TODO: Uncomment and adjust once we know logic of extra dim
// TODO: Uncomment and adjust once we know logic of extra dim. This will affect constant
// collections, reads, and writes.
// let extra_dim_collects = self.get_extra_dim_collects(&fork_control_map, &fork_thread_quota_map);
let extra_dim_collects = HashSet::new();
......@@ -749,30 +750,9 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
fork_control_map: &HashMap<NodeID, HashSet<NodeID>>,
fork_thread_quota_map: &HashMap<NodeID, (usize, usize, usize)>,
) -> HashSet<TypeID> {
// Get all constant collection creations
let collect_consts: HashSet<NodeID> = (0..self.function.nodes.len())
.filter(|idx| self.function.nodes[*idx].is_constant() && !self.types[self.typing[*idx].idx()].is_primitive())
.map(|idx| NodeID::new(idx))
.collect();
// Reverse fork_control_map
let control_fork_map: HashMap<NodeID, NodeID> = fork_control_map.iter()
.flat_map(|(fork, controls)| {
controls.iter().map(move |control| (*control, *fork))
})
.collect();
// Get all uses of each collection, map each use to basic block, then map each basic block to fork
let collect_fork_users: HashMap<NodeID, HashSet<NodeID>> = collect_consts.iter()
.map(|collect_const| {
(*collect_const, self.def_use_map.get_users(*collect_const))
})
.map(|(collect_const, users)| {
(collect_const, users.iter().map(|user| control_fork_map[&self.bbs.0[user.idx()]]).collect())
})
.collect();
collect_fork_users.iter()
.filter(|(_, fork_users)| !fork_thread_quota_map.contains_key(fork_users.iter().next().unwrap()))
.map(|(collect_const, _)| self.typing[collect_const.idx()])
.collect()
// Determine which fork each collection is used in, and check if it's
// parallelized via the fork_thread_quota_map.
todo!()
}
fn codegen_data_control_no_forks(
......@@ -1237,7 +1217,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
}
// Read of primitive requires load after pointer math.
Node::Read { collect, indices } => {
let collect_with_indices = self.codegen_collect(*collect, indices, extra_dim_collects.contains(&self.typing[collect.idx()]));
let collect_with_indices = self.codegen_collect(*collect, indices, extra_dim_collects);
let data_type_id = self.typing[id.idx()];
if self.types[data_type_id.idx()].is_primitive() {
let type_name = self.get_type(data_type_id, true);
......@@ -1253,7 +1233,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
data,
indices,
} => {
let collect_with_indices = self.codegen_collect(*collect, indices, extra_dim_collects.contains(&self.typing[collect.idx()]));
let collect_with_indices = self.codegen_collect(*collect, indices, extra_dim_collects);
let data_variable = self.get_value(*data, false, false);
let data_type_id = self.typing[data.idx()];
let cg_tile = match state {
......@@ -1452,27 +1432,31 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
* This function emits collection name + pointer math for the provided indices.
* All collection types use char pointers.
*/
fn codegen_collect(&self, collect: NodeID, indices: &[Index], has_extra_dim: bool) -> String {
fn codegen_collect(&self, collect: NodeID, indices: &[Index], extra_dim_collects: &HashSet<TypeID>) -> String {
let mut index_ptr = "0".to_string();
let type_id = self.typing[collect.idx()];
for index in indices {
match index {
Index::Field(field) => {
self.get_size(type_id, Some(*field), None);
self.get_size(type_id, Some(*field), Some(extra_dim_collects));
}
// Variants of summations have zero offset
Index::Variant(_) => {}
// Convert multi-d array index to 1-d index, and optionally
// convert to single-byte index by multiplying by element size
Index::Position(array_indices) => {
let has_extra_dim = extra_dim_collects.contains(&self.typing[collect.idx()]);
if has_extra_dim {
continue;
}
let Type::Array(element_type, extents) =
&self.types[self.typing[collect.idx()].idx()]
else {
panic!("Expected array type")
};
let mut cumulative_offset = multiply_dcs(&extents[array_indices.len()..]);
let max_left_array_index = array_indices.len() - 1 - if has_extra_dim { 1 } else { 0 };
for (i, index) in array_indices.iter().skip(if has_extra_dim { 1 } else { 0 }).rev().enumerate() {
let max_left_array_index = array_indices.len() - 1;
for (i, index) in array_indices.iter().rev().enumerate() {
cumulative_offset = format!(
"{} * ({}{}",
cumulative_offset,
......@@ -1487,9 +1471,9 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
index_ptr.push_str(&format!(
" + {}{}",
cumulative_offset,
")".repeat(array_indices.len() - if has_extra_dim { 1 } else { 0 })
")".repeat(array_indices.len())
));
let element_size = self.get_size(*element_type, None, None);
let element_size = self.get_size(*element_type, None, Some(extra_dim_collects));
index_ptr.push_str(&format!(" * {}", element_size));
}
}
......@@ -1556,7 +1540,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
format!("*reinterpret_cast<{}>({}+{})", field_type, name, offset),
constant_fields[i],
false,
extra_dim_collects,
None,
dynamic_shared_offset,
w,
num_tabs,
......@@ -1619,7 +1603,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
fn get_size(&self, type_id: TypeID, num_fields: Option<usize>, extra_dim_collects: Option<&HashSet<TypeID>>) -> String {
match &self.types[type_id.idx()] {
Type::Array(element_type, extents) => {
let array_size = multiply_dcs(if extra_dim_collects.is_some() && extra_dim_collects.unwrap().contains(&type_id) { &extents[1..] } else { extents });
let array_size = if extra_dim_collects.is_some() && extra_dim_collects.unwrap().contains(&type_id) { "1".to_string() } else { multiply_dcs(extents) };
format!("{} * {}", self.get_alignment(*element_type), array_size)
}
Type::Product(fields) => {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment