diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index e6b540aed6218f56e0a1fa33b0e4b1e1db09e9b0..f3b38878b9739eb7cb2a3a14fbbb6606a731877e 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -1761,23 +1761,39 @@ extern \"C\" {} {}(", extra_dim_collects: &HashSet<TypeID>, ) -> String { let mut index_ptr = "0".to_string(); - let type_id = self.typing[collect.idx()]; + let mut type_id = self.typing[collect.idx()]; for index in indices { match index { Index::Field(field) => { - self.get_size(type_id, Some(*field), Some(extra_dim_collects)); + index_ptr.push_str(&format!( + " + ({})", + self.get_size(type_id, Some(*field), Some(extra_dim_collects)))); + type_id = + if let Type::Product(fields) = + &self.types[type_id.idx()] { + fields[*field] + } else { + panic!("Expected product type") + }; } // Variants of summations have zero offset - Index::Variant(_) => {} + Index::Variant(index) => { + type_id = + if let Type::Summation(variants) = &self.types[type_id.idx()] { + variants[*index] + } else { + panic!("Expected summation type") + }; + } // Convert multi-d array index to 1-d index, and optionally // convert to single-byte index by multiplying by element size Index::Position(array_indices) => { - let has_extra_dim = extra_dim_collects.contains(&self.typing[collect.idx()]); + let has_extra_dim = extra_dim_collects.contains(&type_id); if has_extra_dim { continue; } let Type::Array(element_type, extents) = - &self.types[self.typing[collect.idx()].idx()] + &self.types[type_id.idx()] else { panic!("Expected array type") }; @@ -1801,7 +1817,8 @@ extern \"C\" {} {}(", ")".repeat(array_indices.len()) )); let element_size = self.get_size(*element_type, None, Some(extra_dim_collects)); - index_ptr.push_str(&format!(" * {}", element_size)); + index_ptr.push_str(&format!(" * ({})", element_size)); + type_id = *element_type; } } } @@ -1987,6 +2004,7 @@ extern \"C\" {} {}(", ) -> String { match &self.types[type_id.idx()] { Type::Array(element_type, extents) => { + assert!(num_fields.is_none()); let array_size = if extra_dim_collects.is_some() && extra_dim_collects.unwrap().contains(&type_id) { @@ -1997,12 +2015,11 @@ extern \"C\" {} {}(", format!("{} * {}", self.get_alignment(*element_type), array_size) } Type::Product(fields) => { - let num_fields = &num_fields.unwrap_or(fields.len()); - let with_field = fields + let num_fields = num_fields.unwrap_or(fields.len()); + fields .iter() - .enumerate() - .filter(|(i, _)| i < num_fields) - .map(|(_, id)| { + .take(num_fields) + .map(|id| { ( self.get_size(*id, None, extra_dim_collects), self.get_alignment(*id), @@ -2017,18 +2034,10 @@ extern \"C\" {} {}(", acc, align, align, align, size ) } - }); - if num_fields < &fields.len() { - format!( - "{} - {}", - with_field, - self.get_size(fields[*num_fields], None, extra_dim_collects) - ) - } else { - with_field - } + }) } Type::Summation(variants) => { + assert!(num_fields.is_none()); // The argmax variant by size is not guaranteed to be same as // argmax variant by alignment, eg product of 3 4-byte primitives // vs 1 8-byte primitive, so we need to calculate both. @@ -2052,7 +2061,10 @@ extern \"C\" {} {}(", max_size, max_alignment, max_alignment, max_alignment ) } - _ => format!("{}", self.get_alignment(type_id)), + _ => { + assert!(num_fields.is_none()); + format!("{}", self.get_alignment(type_id)) + } } }