diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index 6c62ed76392dabc10c0fc48bf20d56d8577ad99e..d7a6d258358920d8a42aea1530843c08e223022a 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -288,7 +288,8 @@ impl GPUContext<'_> {
                 self.get_root_forks_and_num_blocks(self.fork_tree);
             let (thread_root_root_fork, thread_root_forks) = self.get_thread_root_forks(&root_forks, self.fork_tree, is_block_parallel);
             let (fork_thread_quota_map, num_threads) = self.get_thread_quotas(self.fork_tree, thread_root_root_fork);
-            // TODO: Uncomment and adjust once we know logic of extra dim
+            // TODO: Uncomment and adjust once we know the logic for the extra dim. This will
+            // affect constant collections, reads, and writes.
             // let extra_dim_collects = self.get_extra_dim_collects(&fork_control_map, &fork_thread_quota_map);
             let extra_dim_collects = HashSet::new();
 
@@ -749,30 +750,9 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
         fork_control_map: &HashMap<NodeID, HashSet<NodeID>>,
         fork_thread_quota_map: &HashMap<NodeID, (usize, usize, usize)>,
     ) -> HashSet<TypeID> {
-        // Get all constant collection creations
-        let collect_consts: HashSet<NodeID> = (0..self.function.nodes.len())
-            .filter(|idx| self.function.nodes[*idx].is_constant() && !self.types[self.typing[*idx].idx()].is_primitive())
-            .map(|idx| NodeID::new(idx))
-            .collect();
-        // Reverse fork_control_map
-        let control_fork_map: HashMap<NodeID, NodeID> = fork_control_map.iter()
-            .flat_map(|(fork, controls)| {
-                controls.iter().map(move |control| (*control, *fork))
-            })
-            .collect();
-        // Get all uses of each collection, map each use to basic block, then map each basic block to fork
-        let collect_fork_users: HashMap<NodeID, HashSet<NodeID>> = collect_consts.iter()
-            .map(|collect_const| {
-                (*collect_const, self.def_use_map.get_users(*collect_const))
-            })
-            .map(|(collect_const, users)| {
-                (collect_const, users.iter().map(|user| control_fork_map[&self.bbs.0[user.idx()]]).collect())
-            })
-            .collect();
-        collect_fork_users.iter()
-            .filter(|(_, fork_users)| !fork_thread_quota_map.contains_key(fork_users.iter().next().unwrap()))
-            .map(|(collect_const, _)| self.typing[collect_const.idx()])
-            .collect()
+        // Determine which fork each constant collection is used in, and check
+        // whether that fork is parallelized via the fork_thread_quota_map.
+        todo!()
     }
 
     fn codegen_data_control_no_forks(
@@ -1237,7 +1217,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             }
             // Read of primitive requires load after pointer math.
             Node::Read { collect, indices } => {
-                let collect_with_indices = self.codegen_collect(*collect, indices, extra_dim_collects.contains(&self.typing[collect.idx()]));
+                let collect_with_indices = self.codegen_collect(*collect, indices, extra_dim_collects);
                 let data_type_id = self.typing[id.idx()];
                 if self.types[data_type_id.idx()].is_primitive() {
                     let type_name = self.get_type(data_type_id, true);
@@ -1253,7 +1233,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 data,
                 indices,
             } => {
-                let collect_with_indices = self.codegen_collect(*collect, indices, extra_dim_collects.contains(&self.typing[collect.idx()]));
+                let collect_with_indices = self.codegen_collect(*collect, indices, extra_dim_collects);
                 let data_variable = self.get_value(*data, false, false);
                 let data_type_id = self.typing[data.idx()];
                 let cg_tile = match state {
@@ -1452,27 +1432,31 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
      * This function emits collection name + pointer math for the provided indices.
      * All collection types use char pointers.
      */
-    fn codegen_collect(&self, collect: NodeID, indices: &[Index], has_extra_dim: bool) -> String {
+    fn codegen_collect(&self, collect: NodeID, indices: &[Index], extra_dim_collects: &HashSet<TypeID>) -> String {
         let mut index_ptr = "0".to_string();
         let type_id = self.typing[collect.idx()];
         for index in indices {
             match index {
                 Index::Field(field) => {
-                    self.get_size(type_id, Some(*field), None);
+                    self.get_size(type_id, Some(*field), Some(extra_dim_collects));
                 }
                 // Variants of summations have zero offset
                 Index::Variant(_) => {}
                 // Convert multi-d array index to 1-d index, and optionally
                 // convert to single-byte index by multiplying by element size
                 Index::Position(array_indices) => {
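+                    // Collections with an extra dim contribute no position offset for now (extra dim logic is still TODO).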
+                    if extra_dim_collects.contains(&type_id) {
+                        continue;
+                    }
                     let Type::Array(element_type, extents) =
                         &self.types[self.typing[collect.idx()].idx()]
                     else {
                         panic!("Expected array type")
                     };
                     let mut cumulative_offset = multiply_dcs(&extents[array_indices.len()..]);
-                    let max_left_array_index = array_indices.len() - 1 - if has_extra_dim { 1 } else { 0 };
-                    for (i, index) in array_indices.iter().skip(if has_extra_dim { 1 } else { 0 }).rev().enumerate() {
+                    let max_left_array_index = array_indices.len() - 1;
+                    for (i, index) in array_indices.iter().rev().enumerate() {
                         cumulative_offset = format!(
                             "{} * ({}{}",
                             cumulative_offset,
@@ -1487,9 +1471,9 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                     index_ptr.push_str(&format!(
                         " + {}{}",
                         cumulative_offset,
-                        ")".repeat(array_indices.len() - if has_extra_dim { 1 } else { 0 })
+                        ")".repeat(array_indices.len())
                     ));
-                    let element_size = self.get_size(*element_type, None, None);
+                    let element_size = self.get_size(*element_type, None, Some(extra_dim_collects));
                     index_ptr.push_str(&format!(" * {}", element_size));
                 }
             }
@@ -1556,7 +1540,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                             format!("*reinterpret_cast<{}>({}+{})", field_type, name, offset),
                             constant_fields[i],
                             false,
-                            extra_dim_collects,
+                            None, // nested field constants get no extra-dim handling
                             dynamic_shared_offset,
                             w,
                             num_tabs,
@@ -1619,7 +1603,8 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
     fn get_size(&self, type_id: TypeID, num_fields: Option<usize>, extra_dim_collects: Option<&HashSet<TypeID>>) -> String {
         match &self.types[type_id.idx()] {
             Type::Array(element_type, extents) => {
-                let array_size = multiply_dcs(if extra_dim_collects.is_some() && extra_dim_collects.unwrap().contains(&type_id) { &extents[1..] } else { extents });
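+                // For now, arrays with an extra dim report the size of a single element.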
+                let array_size = if extra_dim_collects.is_some_and(|c| c.contains(&type_id)) { "1".to_string() } else { multiply_dcs(extents) };
                 format!("{} * {}", self.get_alignment(*element_type), array_size)
             }
             Type::Product(fields) => {