Skip to content
Snippets Groups Projects

Fix GPU addressing bug

Merged Aaron Councilman requested to merge gpu-addressing-bug into main
All threads resolved!
8 files
+ 139
25
Compare changes
  • Side-by-side
  • Inline
Files
8
+ 37
25
@@ -1761,24 +1761,36 @@ extern \"C\" {} {}(",
extra_dim_collects: &HashSet<TypeID>,
) -> String {
let mut index_ptr = "0".to_string();
let type_id = self.typing[collect.idx()];
let mut type_id = self.typing[collect.idx()];
for index in indices {
match index {
Index::Field(field) => {
self.get_size(type_id, Some(*field), Some(extra_dim_collects));
index_ptr.push_str(&format!(
" + ({})",
self.get_size(type_id, Some(*field), Some(extra_dim_collects))
));
type_id = if let Type::Product(fields) = &self.types[type_id.idx()] {
fields[*field]
} else {
panic!("Expected product type")
};
}
// Variants of summations have zero offset
Index::Variant(_) => {}
Index::Variant(index) => {
type_id = if let Type::Summation(variants) = &self.types[type_id.idx()] {
variants[*index]
} else {
panic!("Expected summation type")
};
}
// Convert multi-d array index to 1-d index, and optionally
// convert to single-byte index by multiplying by element size
Index::Position(array_indices) => {
let has_extra_dim = extra_dim_collects.contains(&self.typing[collect.idx()]);
let has_extra_dim = extra_dim_collects.contains(&type_id);
if has_extra_dim {
continue;
}
let Type::Array(element_type, extents) =
&self.types[self.typing[collect.idx()].idx()]
else {
let Type::Array(element_type, extents) = &self.types[type_id.idx()] else {
panic!("Expected array type")
};
let mut cumulative_offset = multiply_dcs(&extents[array_indices.len()..]);
@@ -1801,7 +1813,8 @@ extern \"C\" {} {}(",
")".repeat(array_indices.len())
));
let element_size = self.get_size(*element_type, None, Some(extra_dim_collects));
index_ptr.push_str(&format!(" * {}", element_size));
index_ptr.push_str(&format!(" * ({})", element_size));
type_id = *element_type;
}
}
}
@@ -1987,6 +2000,7 @@ extern \"C\" {} {}(",
) -> String {
match &self.types[type_id.idx()] {
Type::Array(element_type, extents) => {
assert!(num_fields.is_none());
let array_size = if extra_dim_collects.is_some()
&& extra_dim_collects.unwrap().contains(&type_id)
{
@@ -1994,15 +2008,18 @@ extern \"C\" {} {}(",
} else {
multiply_dcs(extents)
};
format!("{} * {}", self.get_alignment(*element_type), array_size)
format!(
"{} * {}",
self.get_size(*element_type, None, extra_dim_collects),
array_size
)
}
Type::Product(fields) => {
let num_fields = &num_fields.unwrap_or(fields.len());
let with_field = fields
let num_fields = num_fields.unwrap_or(fields.len());
fields
.iter()
.enumerate()
.filter(|(i, _)| i < num_fields)
.map(|(_, id)| {
.take(num_fields)
.map(|id| {
(
self.get_size(*id, None, extra_dim_collects),
self.get_alignment(*id),
@@ -2017,18 +2034,10 @@ extern \"C\" {} {}(",
acc, align, align, align, size
)
}
});
if num_fields < &fields.len() {
format!(
"{} - {}",
with_field,
self.get_size(fields[*num_fields], None, extra_dim_collects)
)
} else {
with_field
}
})
}
Type::Summation(variants) => {
assert!(num_fields.is_none());
// The variant with the largest size is not guaranteed to be the same as
// the variant with the largest alignment, e.g. a product of three 4-byte
// primitives vs. one 8-byte primitive, so we need to calculate both.
@@ -2052,7 +2061,10 @@ extern \"C\" {} {}(",
max_size, max_alignment, max_alignment, max_alignment
)
}
_ => format!("{}", self.get_alignment(type_id)),
_ => {
assert!(num_fields.is_none());
format!("{}", self.get_alignment(type_id))
}
}
}
Loading