diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index e6b540aed6218f56e0a1fa33b0e4b1e1db09e9b0..f3b38878b9739eb7cb2a3a14fbbb6606a731877e 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -1761,23 +1761,39 @@ extern \"C\" {} {}(",
         extra_dim_collects: &HashSet<TypeID>,
     ) -> String {
         let mut index_ptr = "0".to_string();
-        let type_id = self.typing[collect.idx()];
+        let mut type_id = self.typing[collect.idx()];
         for index in indices {
             match index {
                 Index::Field(field) => {
-                    self.get_size(type_id, Some(*field), Some(extra_dim_collects));
+                    index_ptr.push_str(&format!(
+                        " + ({})",
+                        self.get_size(type_id, Some(*field), Some(extra_dim_collects))));
+                    type_id =
+                        if let Type::Product(fields) =
+                            &self.types[type_id.idx()] {
+                            fields[*field]
+                        } else {
+                            panic!("Expected product type")
+                        };
                 }
                 // Variants of summations have zero offset
-                Index::Variant(_) => {}
+                Index::Variant(index) => {
+                    type_id =
+                        if let Type::Summation(variants) = &self.types[type_id.idx()] {
+                            variants[*index]
+                        } else {
+                            panic!("Expected summation type")
+                        };
+                }
                 // Convert multi-d array index to 1-d index, and optionally
                 // convert to single-byte index by multiplying by element size
                 Index::Position(array_indices) => {
-                    let has_extra_dim = extra_dim_collects.contains(&self.typing[collect.idx()]);
+                    let has_extra_dim = extra_dim_collects.contains(&type_id);
                     if has_extra_dim {
                         continue;
                     }
                     let Type::Array(element_type, extents) =
-                        &self.types[self.typing[collect.idx()].idx()]
+                        &self.types[type_id.idx()]
                     else {
                         panic!("Expected array type")
                     };
@@ -1801,7 +1817,8 @@ extern \"C\" {} {}(",
                         ")".repeat(array_indices.len())
                     ));
                     let element_size = self.get_size(*element_type, None, Some(extra_dim_collects));
-                    index_ptr.push_str(&format!(" * {}", element_size));
+                    index_ptr.push_str(&format!(" * ({})", element_size));
+                    type_id = *element_type;
                 }
             }
         }
@@ -1987,6 +2004,7 @@ extern \"C\" {} {}(",
     ) -> String {
         match &self.types[type_id.idx()] {
             Type::Array(element_type, extents) => {
+                assert!(num_fields.is_none());
                 let array_size = if extra_dim_collects.is_some()
                     && extra_dim_collects.unwrap().contains(&type_id)
                 {
@@ -1997,12 +2015,11 @@ extern \"C\" {} {}(",
                 format!("{} * {}", self.get_alignment(*element_type), array_size)
             }
             Type::Product(fields) => {
-                let num_fields = &num_fields.unwrap_or(fields.len());
-                let with_field = fields
+                let num_fields = num_fields.unwrap_or(fields.len());
+                fields
                     .iter()
-                    .enumerate()
-                    .filter(|(i, _)| i < num_fields)
-                    .map(|(_, id)| {
+                    .take(num_fields)
+                    .map(|id| {
                         (
                             self.get_size(*id, None, extra_dim_collects),
                             self.get_alignment(*id),
@@ -2017,18 +2034,10 @@ extern \"C\" {} {}(",
                                 acc, align, align, align, size
                             )
                         }
-                    });
-                if num_fields < &fields.len() {
-                    format!(
-                        "{} - {}",
-                        with_field,
-                        self.get_size(fields[*num_fields], None, extra_dim_collects)
-                    )
-                } else {
-                    with_field
-                }
+                    })
             }
             Type::Summation(variants) => {
+                assert!(num_fields.is_none());
                 // The argmax variant by size is not guaranteed to be same as
                 // argmax variant by alignment, eg product of 3 4-byte primitives
                 // vs 1 8-byte primitive, so we need to calculate both.
@@ -2052,7 +2061,10 @@ extern \"C\" {} {}(",
                     max_size, max_alignment, max_alignment, max_alignment
                 )
             }
-            _ => format!("{}", self.get_alignment(type_id)),
+            _ => {
+                assert!(num_fields.is_none());
+                format!("{}", self.get_alignment(type_id))
+            }
         }
     }