From 790b50d7cf518ea5698f7b00e173210a1d4fb700 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Tue, 25 Feb 2025 14:37:30 -0600
Subject: [PATCH] Properly align array elements

---
 hercules_cg/src/cpu.rs |  6 +++++-
 hercules_cg/src/gpu.rs | 16 ++++++++++++++--
 hercules_cg/src/rt.rs  | 16 +++++++++++++++-
 3 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs
index 552dc3a3..b15cf301 100644
--- a/hercules_cg/src/cpu.rs
+++ b/hercules_cg/src/cpu.rs
@@ -839,6 +839,8 @@ impl<'a> CPUContext<'a> {
                     //
                     //     ((0 * s1 + p1) * s2 + p2) * s3 + p3 ...
                     let elem_size = self.codegen_type_size(elem, body)?;
+                    let elem_align = get_type_alignment(&self.types, elem);
+                    let aligned_elem_size = Self::round_up_to(&elem_size, elem_align, body)?;
                     let mut acc_offset = "0".to_string();
                     for (p, s) in zip(pos, dims) {
                         let p = self.get_value(*p, false);
@@ -848,7 +850,7 @@ impl<'a> CPUContext<'a> {
                     }
 
                     // Convert offset in # elements -> # bytes.
-                    acc_offset = Self::multiply(&acc_offset, &elem_size, body)?;
+                    acc_offset = Self::multiply(&acc_offset, &aligned_elem_size, body)?;
                     acc_ptr = Self::gep(&acc_ptr, &acc_offset, body)?;
                     collect_ty = elem;
                 }
@@ -910,6 +912,8 @@ impl<'a> CPUContext<'a> {
                 // The size of an array is the size of the element multipled by
                 // the dynamic constant bounds.
                 let mut acc_size = self.codegen_type_size(elem, body)?;
+                acc_size =
+                    Self::round_up_to(&acc_size, get_type_alignment(&self.types, elem), body)?;
                 for dc in bounds {
                     acc_size = Self::multiply(&acc_size, &format!("%dc{}", dc.idx()), body)?;
                 }
diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs
index 14341756..25bbf1be 100644
--- a/hercules_cg/src/gpu.rs
+++ b/hercules_cg/src/gpu.rs
@@ -1877,7 +1877,11 @@ namespace cg = cooperative_groups;
                         ")".repeat(array_indices.len())
                     ));
                     let element_size = self.get_size(*element_type, None);
-                    index_ptr.push_str(&format!(" * ({})", element_size));
+                    let element_align = self.get_alignment(*element_type);
+                    index_ptr.push_str(&format!(
+                        " * (({} + {} - 1 / {} * {}))",
+                        element_size, element_align, element_align, element_align
+                    ));
                     type_id = *element_type;
                 }
             }
@@ -2049,7 +2053,15 @@ namespace cg = cooperative_groups;
             Type::Array(element_type, extents) => {
                 assert!(num_fields.is_none());
                 let array_size = multiply_dcs(extents);
-                format!("{} * {}", self.get_size(*element_type, None), array_size)
+                let elem_align = self.get_alignment(type_id);
+                format!(
+                    "(({} + {} - 1) / {} * {}) * {}",
+                    self.get_size(*element_type, None),
+                    elem_align,
+                    elem_align,
+                    elem_align,
+                    array_size
+                )
             }
             Type::Product(fields) => {
                 let num_fields = num_fields.unwrap_or(fields.len());
diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index 3db0f16f..8fa0c09e 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -1111,6 +1111,13 @@ impl<'a> RTContext<'a> {
                     //
                     //     ((0 * s1 + p1) * s2 + p2) * s3 + p3 ...
                     let elem_size = self.codegen_type_size(elem);
+                    let elem_align = get_type_alignment(&self.module.types, elem);
+                    let aligned_elem_size = format!(
+                        "(({} + {}) & !{})",
+                        elem_size,
+                        elem_align - 1,
+                        elem_align - 1
+                    );
                     for (p, s) in zip(pos, dims) {
                         let p = self.get_value(*p, bb, false);
                         acc_offset = format!("{} * ", acc_offset);
@@ -1119,7 +1126,7 @@ impl<'a> RTContext<'a> {
                     }
 
                     // Convert offset in # elements -> # bytes.
-                    acc_offset = format!("({} * {})", acc_offset, elem_size);
+                    acc_offset = format!("({} * {})", acc_offset, aligned_elem_size);
                     collect_ty = elem;
                 }
             }
@@ -1192,6 +1199,13 @@ impl<'a> RTContext<'a> {
                 // The size of an array is the size of the element multipled by
                 // the dynamic constant bounds.
                 let mut acc_size = self.codegen_type_size(elem);
+                let elem_align = get_type_alignment(&self.module.types, elem);
+                acc_size = format!(
+                    "(({} + {}) & !{})",
+                    acc_size,
+                    elem_align - 1,
+                    elem_align - 1
+                );
                 for dc in bounds {
                     acc_size = format!("{} * ", acc_size);
                     self.codegen_dynamic_constant(*dc, &mut acc_size).unwrap();
-- 
GitLab