From 37567e75c3f7adbb96248a12084bb0ed51914f5b Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Tue, 25 Feb 2025 11:30:47 -0600 Subject: [PATCH] Add extra dimensions to arrays when they'd cause a race otherwise --- hercules_opt/src/gcm.rs | 54 +++++++++++++++++++++++++++++++++++------ 1 file changed, 46 insertions(+), 8 deletions(-) diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index ae8801f6..939f1502 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -1053,19 +1053,22 @@ fn add_extra_collection_dims( devices: &Vec<Device>, bbs: &BasicBlocks, ) -> bool { - if devices[editor.func_id().idx()] == Device::AsyncRust - && editor.func().name == "_1_laplacian_estimate" - { + if devices[editor.func_id().idx()] == Device::AsyncRust { // Look for collection constant nodes inside fork-joins that are mutated // inside the fork-join, aren't involved in any of the reduces of the // fork-join, and have a user that isn't a direct read based on all of // the thread IDs. - let nodes = &editor.func().nodes; let fco = &objects[&editor.func_id()]; - for id in editor.node_ids().filter(|id| { - nodes[id.idx()].is_constant() && !editor.get_type(typing[id.idx()]).is_primitive() - }) { + let candidates: Vec<_> = editor + .node_ids() + .filter(|id| { + editor.func().nodes[id.idx()].is_constant() + && !editor.get_type(typing[id.idx()]).is_primitive() + }) + .collect(); + for id in candidates { // Check all of the above conditions. + let nodes = &editor.func().nodes; if editor.get_users(id).len() != 1 { continue; } @@ -1115,7 +1118,42 @@ fn add_extra_collection_dims( // We know that this collection needs to be replicated across the // fork-join dimensions, so do that. - todo!() + let ty = typing[id.idx()]; + let num_dims: Vec<_> = forks + .into_iter() + .rev() + .map(|id| nodes[id.idx()].try_fork().unwrap().1.len()) + .collect(); + let factors = forks + .into_iter() + .rev() + .flat_map(|id| nodes[id.idx()].try_fork().unwrap().1.into_iter()) + .map(|dc| *dc) + .collect(); + let array_ty = Type::Array(ty, factors); + let success = editor.edit(|mut edit| { + let new_ty = edit.add_type(array_ty); + let new_cons = edit.add_zero_constant(new_ty); + let new_cons = edit.add_node(Node::Constant { id: new_cons }); + let mut tids = vec![]; + for (fork, num_dims) in forks.into_iter().rev().zip(num_dims) { + for dim in 0..num_dims { + tids.push(edit.add_node(Node::ThreadID { + control: *fork, + dimension: dim, + })); + } + } + let read = edit.add_node(Node::Read { + collect: new_cons, + indices: Box::new([Index::Position(tids.into_boxed_slice())]), + }); + edit = edit.replace_all_uses(id, read)?; + edit = edit.delete_node(id)?; + Ok(edit) + }); + assert!(success); + return true; } } false -- GitLab