Commit 32ad1e82 authored by Praneet Rathi

untested

parent 6af2e9ec
Merge request !115: GPU backend
Pipeline #201084 passed
@@ -80,17 +80,17 @@ pub fn gpu_codegen<W: Write>(
         .collect();
-    let fork_join_map = &fork_join_map(function, control_subgraph);
-    let join_fork_map: &HashMap<NodeID, NodeID> = &fork_join_map
-        .into_iter()
+    let fork_join_map = fork_join_map(function, control_subgraph);
+    let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
+        .iter()
         .map(|(fork, join)| (*join, *fork))
         .collect();
     // Fork Reduce map should have all reduces contained in some key
-    let fork_reduce_map: &mut HashMap<NodeID, Vec<NodeID>> = &mut HashMap::new();
+    let mut fork_reduce_map: HashMap<NodeID, Vec<NodeID>> = HashMap::new();
     // Reduct Reduce map should have all non-parallel and non-associative reduces
     // contained in some key. Unlike Fork, Reduct is not involved in any assertions.
     // It's placed here for convenience but can be moved.
-    let reduct_reduce_map: &mut HashMap<NodeID, Vec<NodeID>> = &mut HashMap::new();
+    let mut reduct_reduce_map: HashMap<NodeID, Vec<NodeID>> = HashMap::new();
     for reduce_node in &reduce_nodes {
         if let Node::Reduce {
             control,
@@ -124,11 +124,13 @@ pub fn gpu_codegen<W: Write>(
         }
     }
     for idx in 0..function.nodes.len() {
-        if function.nodes[idx].is_fork()
-            && fork_reduce_map
-                .get(&NodeID::new(idx)).is_none_or(|reduces| reduces.is_empty())
-        {
-            panic!("Fork node {} has no reduce nodes", idx);
+        if function.nodes[idx].is_fork() {
+            assert!(fork_reduce_map
+                .get(&NodeID::new(idx))
+                .is_some_and(|reduces| !reduces.is_empty()),
+                "Fork node {} has no reduce nodes",
+                idx
+            );
         }
     }
@@ -155,7 +157,7 @@ pub fn gpu_codegen<W: Write>(
         (NodeID::new(pos), *data)
     };
-    let return_type_id = &typing[data_node_id.idx()];
+    let return_type_id = typing[data_node_id.idx()];
     let return_type = &types[return_type_id.idx()];
     let return_param_idx = if !return_type.is_primitive() {
         let objects = &collection_objects.objects(data_node_id);
@@ -186,7 +188,7 @@ pub fn gpu_codegen<W: Write>(
     // Map from control to pairs of data to update phi
     // For each phi, we go to its region and get region's controls
-    let control_data_phi_map: &mut HashMap<NodeID, Vec<(NodeID, NodeID)>> = &mut HashMap::new();
+    let mut control_data_phi_map: HashMap<NodeID, Vec<(NodeID, NodeID)>> = HashMap::new();
     for (idx, node) in function.nodes.iter().enumerate() {
         if let Node::Phi { control, data } = node {
             let Node::Region { preds } = &function.nodes[control.idx()] else {
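For context, control_data_phi_map associates each region's control predecessors with the data feeding each phi, so the generated kernel can assign the phi's backing variable on every incoming path before control reaches the region. A hedged sketch of the emitted shape, with all names invented (this is not output from this commit):

    // Each (pred, data) pair recorded in control_data_phi_map becomes an
    // assignment to the phi's variable on that predecessor's path.
    __global__ void phi_example(int cond, int v_then, int v_else, int* out) {
        int phi_0;
        if (cond) {
            phi_0 = v_then;   // (pred 0, data) pair for this region
        } else {
            phi_0 = v_else;   // (pred 1, data) pair for this region
        }
        *out = phi_0;         // uses of the Phi read phi_0 inside the region
    }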
@@ -237,12 +239,12 @@ struct GPUContext<'a> {
     bbs: &'a BasicBlocks,
     kernel_params: &'a GPUKernelParams,
     def_use_map: &'a ImmutableDefUseMap,
-    fork_join_map: &'a HashMap<NodeID, NodeID>,
-    join_fork_map: &'a HashMap<NodeID, NodeID>,
-    fork_reduce_map: &'a HashMap<NodeID, Vec<NodeID>>,
-    reduct_reduce_map: &'a HashMap<NodeID, Vec<NodeID>>,
-    control_data_phi_map: &'a HashMap<NodeID, Vec<(NodeID, NodeID)>>,
-    return_type_id: &'a TypeID,
+    fork_join_map: HashMap<NodeID, NodeID>,
+    join_fork_map: HashMap<NodeID, NodeID>,
+    fork_reduce_map: HashMap<NodeID, Vec<NodeID>>,
+    reduct_reduce_map: HashMap<NodeID, Vec<NodeID>>,
+    control_data_phi_map: HashMap<NodeID, Vec<(NodeID, NodeID)>>,
+    return_type_id: TypeID,
     return_param_idx: Option<usize>,
 }
@@ -318,7 +320,7 @@ impl GPUContext<'_> {
             (1, 1)
         } else {
             // Create structures and determine block and thread parallelization strategy
-            let (fork_tree, fork_control_map) = self.make_fork_structures(self.fork_join_map);
+            let (fork_tree, fork_control_map) = self.make_fork_structures(&self.fork_join_map);
             let (root_forks, num_blocks) =
                 self.get_root_forks_and_num_blocks(&fork_tree, self.kernel_params.max_num_blocks);
             let (thread_root_root_fork, thread_root_forks) = self.get_thread_root_forks(&root_forks, &fork_tree, num_blocks);
@@ -422,7 +424,7 @@ namespace cg = cooperative_groups;
             write!(
                 w,
                 "{} __restrict__ ret",
-                self.get_type(*self.return_type_id, true)
+                self.get_type(self.return_type_id, true)
             )?;
         }
@@ -536,7 +538,7 @@ namespace cg = cooperative_groups;
         // need to pass arguments to kernel, so we keep track of the arguments here.
         let mut pass_args = String::new();
         let ret_primitive = self.types[self.return_type_id.idx()].is_primitive();
-        let ret_type = self.get_type(*self.return_type_id, false);
+        let ret_type = self.get_type(self.return_type_id, false);
         write!(w, "
extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
         // The first set of parameters are dynamic constants.
@@ -566,7 +568,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
         write!(w, ") {{\n")?;
         // Pull primitive return as pointer parameter for kernel
         if ret_primitive {
-            let ret_type_pnt = self.get_type(*self.return_type_id, true);
+            let ret_type_pnt = self.get_type(self.return_type_id, true);
             write!(w, "\t{} ret;\n", ret_type_pnt)?;
             write!(w, "\tcudaMalloc((void**)&ret, sizeof({}));\n", ret_type)?;
             if !first_param {
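Instantiating these format strings, the host wrapper is an extern "C" function that cudaMallocs a device slot for a primitive return and passes it to the kernel as an extra pointer parameter. A hedged sketch; the function name, the dc0 parameter, and the launch/copy-back tail are assumptions, since only the signature shape and the cudaMalloc line appear in this hunk:

    #include <cuda_runtime.h>

    extern "C" float example_func(unsigned long long dc0 /* dynamic constant, invented */) {
        float* ret;
        cudaMalloc((void**)&ret, sizeof(float));
        // ... emitted kernel launch goes here, passing `ret` as the extra
        // pointer parameter pulled out of the primitive return type ...
        float host_ret;
        cudaMemcpy(&host_ret, ret, sizeof(float), cudaMemcpyDeviceToHost);
        cudaFree(ret);
        return host_ret;
    }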
@@ -1267,16 +1269,11 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             // If we read collection, distribute elements among threads with cg
             // sync after. If we read primitive, copy read on all threads.
             Node::Read { collect, indices } => {
-                let is_char = self.is_char(self.typing[collect.idx()]);
-                let collect_with_indices = self.codegen_collect(*collect, indices, is_char, extra_dim_collects.contains(&self.typing[collect.idx()]));
+                let collect_with_indices = self.codegen_collect(*collect, indices, extra_dim_collects.contains(&self.typing[collect.idx()]));
                 let data_type_id = self.typing[id.idx()];
                 if self.types[data_type_id.idx()].is_primitive() {
-                    if is_char {
-                        let type_name = self.get_type(data_type_id, true);
-                        write!(w, "{}{} = *reinterpret_cast<{}>({});\n", tabs, define_variable, type_name, collect_with_indices)?;
-                    } else {
-                        write!(w, "{}{} = *({});\n", tabs, define_variable, collect_with_indices)?;
-                    }
+                    let type_name = self.get_type(data_type_id, true);
+                    write!(w, "{}{} = *reinterpret_cast<{}>({});\n", tabs, define_variable, type_name, collect_with_indices)?;
                 } else {
                     if KernelState::OutBlock == state && num_blocks.unwrap() > 1 {
                         panic!("GPU can't guarantee correctness for multi-block collection reads");
@@ -1287,13 +1284,12 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                     };
                     // Divide up "elements", which are collection size divided
                     // by element size, among threads.
-                    let data_size = self.get_size(data_type_id, None, Some(extra_dim_collects), Some(true));
-                    let num_elements = format!("({})", data_size);
-                    write!(w, "{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{\n", tabs, cg_tile, num_elements, cg_tile)?;
+                    let data_size = self.get_size(data_type_id, None, Some(extra_dim_collects));
+                    write!(w, "{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{\n", tabs, cg_tile, data_size, cg_tile)?;
                     write!(w, "{}\t*({} + i) = *({} + i);\n", tabs, define_variable, collect_with_indices)?;
                     write!(w, "{}}}\n", tabs)?;
-                    write!(w, "{}if ({}.thread_rank() < {} % {}.size()) {{\n", tabs, cg_tile, num_elements, cg_tile)?;
-                    write!(w, "{}\t*({} + {}.size() * ({} / {}.size()) + {}.thread_rank()) = *({} + {}.size() * ({} / {}.size()) + {}.thread_rank());\n", tabs, define_variable, cg_tile, num_elements, cg_tile, cg_tile, collect_with_indices, cg_tile, num_elements, cg_tile, cg_tile)?;
+                    write!(w, "{}if ({}.thread_rank() < {} % {}.size()) {{\n", tabs, cg_tile, data_size, cg_tile)?;
+                    write!(w, "{}\t*({} + {}.size() * ({} / {}.size()) + {}.thread_rank()) = *({} + {}.size() * ({} / {}.size()) + {}.thread_rank());\n", tabs, define_variable, cg_tile, data_size, cg_tile, cg_tile, collect_with_indices, cg_tile, data_size, cg_tile, cg_tile)?;
                     write!(w, "{}}}\n", tabs)?;
                     write!(w, "{}{}.sync();\n", tabs, cg_tile)?;
                 }
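Instantiating those format strings, the emitted collection read is a tile-strided byte copy with a remainder block and a sync. A hedged reconstruction with invented names (dst, src, N, the 32-wide tile) standing in for define_variable, collect_with_indices, data_size, and cg_tile:

    #include <cooperative_groups.h>
    namespace cg = cooperative_groups;

    __global__ void copy_collection(char* dst, const char* src, size_t N) {
        // cg_tile in the generator; the tile width here is an assumption
        cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cg::this_thread_block());
        // strided loop over the collection's bytes (the generator emits `int i` verbatim)
        for (int i = tile.thread_rank(); i < N; i += tile.size()) {
            *(dst + i) = *(src + i);
        }
        // remainder block, as emitted
        if (tile.thread_rank() < N % tile.size()) {
            *(dst + tile.size() * (N / tile.size()) + tile.thread_rank()) =
                *(src + tile.size() * (N / tile.size()) + tile.thread_rank());
        }
        tile.sync();
    }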
@@ -1305,8 +1301,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 data,
                 indices,
             } => {
-                let is_char = self.is_char(self.typing[collect.idx()]);
-                let collect_with_indices = self.codegen_collect(*collect, indices, is_char, extra_dim_collects.contains(&self.typing[collect.idx()]));
+                let collect_with_indices = self.codegen_collect(*collect, indices, extra_dim_collects.contains(&self.typing[collect.idx()]));
                 let data_variable = self.get_value(*data, false, false);
                 let data_type_id = self.typing[data.idx()];
                 if KernelState::OutBlock == state && num_blocks.unwrap() > 1 {
@@ -1318,21 +1313,16 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 };
                 if self.types[data_type_id.idx()].is_primitive() {
                     write!(w, "{}if ({}.thread_rank() == 0) {{\n", tabs, cg_tile)?;
-                    if is_char {
-                        let type_name = self.get_type(data_type_id, true);
-                        write!(w, "{}\t*reinterpret_cast<{}>({}) = {};\n", tabs, type_name, collect_with_indices, data_variable)?;
-                    } else {
-                        write!(w, "{}\t*({}) = {};\n", tabs, collect_with_indices, data_variable)?;
-                    }
+                    let type_name = self.get_type(data_type_id, true);
+                    write!(w, "{}\t*reinterpret_cast<{}>({}) = {};\n", tabs, type_name, collect_with_indices, data_variable)?;
                     write!(w, "{}}}\n", tabs)?;
                 } else {
-                    let data_size = self.get_size(data_type_id, None, Some(extra_dim_collects), Some(true));
-                    let num_elements = format!("({})", data_size);
-                    write!(w, "{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{\n", tabs, cg_tile, num_elements, cg_tile)?;
+                    let data_size = self.get_size(data_type_id, None, Some(extra_dim_collects));
+                    write!(w, "{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{\n", tabs, cg_tile, data_size, cg_tile)?;
                     write!(w, "{}\t*({} + i) = *({} + i);\n", tabs, collect_with_indices, data_variable)?;
                     write!(w, "{}}}\n", tabs)?;
-                    write!(w, "{}if ({}.thread_rank() < {} % {}.size()) {{\n", tabs, cg_tile, num_elements, cg_tile)?;
-                    write!(w, "{}\t*({} + {}.size() * ({} / {}.size()) + {}.thread_rank()) = *({} + {}.size() * ({} / {}.size()) + {}.thread_rank());\n", tabs, collect_with_indices, cg_tile, num_elements, cg_tile, cg_tile, data_variable, cg_tile, num_elements, cg_tile, cg_tile)?;
+                    write!(w, "{}if ({}.thread_rank() < {} % {}.size()) {{\n", tabs, cg_tile, data_size, cg_tile)?;
+                    write!(w, "{}\t*({} + {}.size() * ({} / {}.size()) + {}.thread_rank()) = *({} + {}.size() * ({} / {}.size()) + {}.thread_rank());\n", tabs, collect_with_indices, cg_tile, data_size, cg_tile, cg_tile, data_variable, cg_tile, data_size, cg_tile, cg_tile)?;
                     write!(w, "{}}}\n", tabs)?;
                 }
                 write!(w, "{}{}.sync();\n", tabs, cg_tile)?;
@@ -1508,18 +1498,15 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
     /*
      * This function emits collection name + pointer math for the provided indices.
-     * One nuance is whether the collection is represented as char pointer or
-     * the original primitive pointer. For Field, it's always char, for Variant,
-     * it doesn't matter here, and for Array, it depends- so we may need to tack
-     * on the element size to the index math.
+     * All collection types use char pointers.
      */
-    fn codegen_collect(&self, collect: NodeID, indices: &[Index], is_char: bool, has_extra_dim: bool) -> String {
+    fn codegen_collect(&self, collect: NodeID, indices: &[Index], has_extra_dim: bool) -> String {
         let mut index_ptr = "0".to_string();
         let type_id = self.typing[collect.idx()];
         for index in indices {
             match index {
                 Index::Field(field) => {
-                    self.get_size(type_id, Some(*field), None, None);
+                    self.get_size(type_id, Some(*field), None);
                 }
                 // Variants of summations have zero offset
                 Index::Variant(_) => {}
@@ -1550,10 +1537,8 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                         cumulative_offset,
                         ")".repeat(array_indices.len() - if has_extra_dim { 1 } else { 0 })
                     ));
-                    if is_char {
-                        let element_size = self.get_size(*element_type, None, None, None);
-                        index_ptr.push_str(&format!(" * {}", element_size));
-                    }
+                    let element_size = self.get_size(*element_type, None, None);
+                    index_ptr.push_str(&format!(" * {}", element_size));
                 }
             }
         }
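With is_char gone, every collection base is a char*, so array index math always multiplies the flattened index by the element size. A hedged sketch of the resulting pointer arithmetic for a 2-D read, with invented names (the dc1 extent and 4-byte float element are assumptions):

    __device__ float read_array_element(char* arr, int i, int j, unsigned long long dc1) {
        // index_ptr starts at "0"; array indices flatten to (i * dc1 + j),
        // then the element size is tacked on because the base is char*
        char* elem = arr + (0 + (i * dc1 + j) * 4);
        return *reinterpret_cast<float*>(elem);
    }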
@@ -1600,7 +1585,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             Constant::Product(type_id, constant_fields) => {
                 if allow_allocate {
                     let alignment = self.get_alignment(*type_id);
-                    let size = self.get_size(*type_id, None, extra_dim_collects, None);
+                    let size = self.get_size(*type_id, None, extra_dim_collects);
                     *dynamic_shared_offset = format!("(({} + {} - 1) / {}) * {}", dynamic_shared_offset, alignment, alignment, alignment);
                     write!(w, "{}dynamic_shared_offset = {};\n", tabs, dynamic_shared_offset)?;
                     write!(w, "{}{} = dynamic_shared + dynamic_shared_offset;\n", tabs, name)?;
@@ -1612,7 +1597,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 for i in 0..constant_fields.len() {
                     // For each field update offset and issue recursive call
                     let field_type = self.get_type(type_fields[i], true);
-                    let offset = self.get_size(type_fields[i], Some(i), extra_dim_collects, None);
+                    let offset = self.get_size(type_fields[i], Some(i), extra_dim_collects);
                     let field_constant = &self.constants[constant_fields[i].idx()];
                     if field_constant.is_scalar() {
                         self.codegen_constant(
@@ -1632,7 +1617,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             Constant::Summation(type_id, variant, field) => {
                 if allow_allocate {
                     let alignment = self.get_alignment(*type_id);
-                    let size = self.get_size(*type_id, None, extra_dim_collects, None);
+                    let size = self.get_size(*type_id, None, extra_dim_collects);
                     *dynamic_shared_offset = format!("(({} + {} - 1) / {}) * {}", dynamic_shared_offset, alignment, alignment, alignment);
                     write!(w, "{}dynamic_shared_offset = {};\n", tabs, dynamic_shared_offset)?;
                     write!(w, "{}{} = dynamic_shared + dynamic_shared_offset;\n", tabs, name)?;
@@ -1660,18 +1645,14 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 };
             }
             Constant::Array(type_id) => {
-                let Type::Array(element_type, _) = &self.types[type_id.idx()] else {
-                    panic!("Expected array type")
-                };
                 if !allow_allocate {
                     panic!("Nested array constant should not be re-allocated");
                 }
                 let alignment = self.get_alignment(*type_id);
-                let size = self.get_size(*type_id, None, extra_dim_collects, None);
-                let element_type = self.get_type(*element_type, true);
+                let size = self.get_size(*type_id, None, extra_dim_collects);
                 *dynamic_shared_offset = format!("(({} + {} - 1) / {}) * {}", dynamic_shared_offset, alignment, alignment, alignment);
                 write!(w, "{}dynamic_shared_offset = {};\n", tabs, dynamic_shared_offset)?;
-                write!(w, "{}{} = reinterpret_cast<{}>(dynamic_shared + dynamic_shared_offset);\n", tabs, name, element_type)?;
+                write!(w, "{}{} = dynamic_shared + dynamic_shared_offset;\n", tabs, name)?;
                 *dynamic_shared_offset = format!("{} + {}", dynamic_shared_offset, size);
             }
         }
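The emitted allocation rounds the running shared-memory offset up to the constant's alignment, points the constant's name at that slot, and advances the offset by its size; with collections now plain char*, the Array arm no longer needs a reinterpret_cast. A hedged sketch of the emitted pattern; the extern __shared__ declaration and the offset's integer type are assumptions not shown in this hunk:

    extern __shared__ char dynamic_shared[];

    __device__ char* alloc_array_constant(unsigned long long* dynamic_shared_offset,
                                          unsigned long long size,
                                          unsigned long long alignment) {
        // round the running offset up to the constant's alignment
        *dynamic_shared_offset = ((*dynamic_shared_offset + alignment - 1) / alignment) * alignment;
        // the constant's name is now a plain char*, no reinterpret_cast needed
        char* name = dynamic_shared + *dynamic_shared_offset;
        // advance past the constant's byte size for the next allocation
        *dynamic_shared_offset = *dynamic_shared_offset + size;
        return name;
    }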
@@ -1684,15 +1665,11 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
      * and offset to 2nd field. This is useful for constant initialization and read/write
      * index math.
      */
-    fn get_size(&self, type_id: TypeID, num_fields: Option<usize>, extra_dim_collects: Option<&HashSet<TypeID>>, exclude_element_size: Option<bool>) -> String {
+    fn get_size(&self, type_id: TypeID, num_fields: Option<usize>, extra_dim_collects: Option<&HashSet<TypeID>>) -> String {
         match &self.types[type_id.idx()] {
             Type::Array(element_type, extents) => {
                 let array_size = multiply_dcs(if extra_dim_collects.is_some() && extra_dim_collects.unwrap().contains(&type_id) { &extents[1..] } else { extents });
-                if exclude_element_size.unwrap_or(false) {
-                    array_size
-                } else {
-                    format!("{} * {}", self.get_alignment(*element_type), array_size)
-                }
+                format!("{} * {}", self.get_alignment(*element_type), array_size)
             }
             Type::Product(fields) => {
                 let num_fields = &num_fields.unwrap_or(fields.len());
@@ -1700,7 +1677,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                     .iter()
                     .enumerate()
                     .filter(|(i, _)| i < num_fields)
-                    .map(|(_, id)| (self.get_size(*id, None, extra_dim_collects, None), self.get_alignment(*id)))
+                    .map(|(_, id)| (self.get_size(*id, None, extra_dim_collects), self.get_alignment(*id)))
                     .fold(String::from("0"), |acc, (size, align)| {
                         if acc == "0" {
                             size
@@ -1715,7 +1692,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                     format!(
                         "{} - {}",
                         with_field,
-                        self.get_size(fields[*num_fields], None, extra_dim_collects, None)
+                        self.get_size(fields[*num_fields], None, extra_dim_collects)
                     )
                 } else {
                     with_field
@@ -1725,7 +1702,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 // The argmax variant by size is not guaranteed to be same as
                 // argmax variant by alignment, eg product of 3 4-byte primitives
                 // vs 1 8-byte primitive, so we need to calculate both.
-                let max_size = variants.iter().map(|id| self.get_size(*id, None, extra_dim_collects, None)).fold(
+                let max_size = variants.iter().map(|id| self.get_size(*id, None, extra_dim_collects)).fold(
                     String::from("0"),
                     |acc, x| {
                         if acc == "0" {
@@ -1880,16 +1857,6 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
         func_name.to_string()
     }

-    // Check if a type should be represented as char*. Must be a product,
-    // summation, or array of product/summation types.
-    fn is_char(&self, type_id: TypeID) -> bool {
-        match &self.types[type_id.idx()] {
-            Type::Product(_) | Type::Summation(_) => true,
-            Type::Array(element_type, _) => self.is_char(*element_type),
-            _ => false,
-        }
-    }
-
     fn get_cg_tile(&self, fork: NodeID, cg_type: CGType) -> String {
         format!("cg_{}{}", self.get_value(fork, false, false), if cg_type == CGType::Use { "_use" } else if cg_type == CGType::Available { "_available" } else { "" })
     }
@@ -1938,12 +1905,10 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
     }

     fn get_type(&self, id: TypeID, make_pointer: bool) -> String {
-        match &self.types[id.idx()] {
-            // Product and summation collections are char* for 1 byte-addressability
-            // since we can have variable type fields
-            Type::Product(_) | Type::Summation(_) => "char*".to_string(),
-            Type::Array(element_type, _) => self.get_type(*element_type, true),
-            _ => convert_type(&self.types[id.idx()], make_pointer),
+        if self.types[id.idx()].is_primitive() {
+            convert_type(&self.types[id.idx()], make_pointer)
+        } else {
+            "char*".to_string()
         }
     }
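get_type now splits purely on is_primitive: primitives go through convert_type, and every product, summation, or array collection becomes char*. A hedged illustration of the resulting parameter types; the exact primitive spellings come from convert_type, which is not in this diff, so they are assumptions:

    __global__ void type_examples(float x, float* __restrict__ ret, char* collection) {
        // float   : primitive, make_pointer = false
        // float*  : primitive, make_pointer = true (e.g. the __restrict__ ret param)
        // char*   : any product, summation, or array collection
        *ret = x;
        (void)collection;
    }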
...