From ae8f17eb21ca53bbd1d9318518c4c2d62a3e1437 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Tue, 25 Feb 2025 10:53:22 -0600 Subject: [PATCH] Look for arrays that are used in parallel --- hercules_cg/src/lib.rs | 2 +- hercules_ir/src/collections.rs | 3 +- hercules_opt/src/gcm.rs | 94 ++++++++++++++++++++++++++++++++++ hercules_rt/src/lib.rs | 2 +- juno_scheduler/src/pm.rs | 3 ++ 5 files changed, 101 insertions(+), 3 deletions(-) diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs index 446231de..98f91e1f 100644 --- a/hercules_cg/src/lib.rs +++ b/hercules_cg/src/lib.rs @@ -16,7 +16,7 @@ use std::collections::BTreeMap; use hercules_ir::*; -pub const LARGEST_ALIGNMENT: usize = 32; +pub const LARGEST_ALIGNMENT: usize = 64; /* * The alignment of a type does not depend on dynamic constants. diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index 60f4fb1c..a8fc3c67 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -202,7 +202,7 @@ pub fn collection_objects( let fco = &collection_objects[&callee]; if fco.returned[*selection] .iter() - .any(|returned| fco.origins[returned.idx()].try_parameter().is_some()) + .any(|returned| fco.origins[returned.idx()].try_parameter().is_none()) { // If the callee may return a new collection object, then // this data projection node originates a single collection object. The @@ -412,6 +412,7 @@ pub fn collection_objects( } } + assert_eq!(objects_per_node.len(), func.nodes.len()); let fco = FunctionCollectionObjects { objects_per_node, mutated, diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index d3119705..ae8801f6 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -82,6 +82,7 @@ pub fn gcm( control_subgraph: &Subgraph, dom: &DomTree, fork_join_map: &HashMap<NodeID, NodeID>, + fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, loops: &LoopTree, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, objects: &CollectionObjects, @@ -120,6 +121,18 @@ pub fn gcm( return None; } + if add_extra_collection_dims( + editor, + typing, + fork_join_map, + fork_join_nest, + objects, + devices, + &bbs, + ) { + return None; + } + let Some(node_colors) = color_nodes(editor, typing, &objects, &devices, node_colors) else { return None; }; @@ -1027,6 +1040,87 @@ fn spill_clones( } } +/* + * Look for mutated collections placed inside fork-joins in AsyncRust functions. + * These collections should be duplicated across the size of the fork-join. + */ +fn add_extra_collection_dims( + editor: &mut FunctionEditor, + typing: &Vec<TypeID>, + fork_join_map: &HashMap<NodeID, NodeID>, + fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, + objects: &CollectionObjects, + devices: &Vec<Device>, + bbs: &BasicBlocks, +) -> bool { + if devices[editor.func_id().idx()] == Device::AsyncRust + && editor.func().name == "_1_laplacian_estimate" + { + // Look for collection constant nodes inside fork-joins that are mutated + // inside the fork-join, aren't involved in any of the reduces of the + // fork-join, and have a user that isn't a direct read based on all of + // the thread IDs. + let nodes = &editor.func().nodes; + let fco = &objects[&editor.func_id()]; + for id in editor.node_ids().filter(|id| { + nodes[id.idx()].is_constant() && !editor.get_type(typing[id.idx()]).is_primitive() + }) { + // Check all of the above conditions. 
+ if editor.get_users(id).len() != 1 { + continue; + } + let forks = &fork_join_nest[&bbs.0[id.idx()]]; + if forks.is_empty() { + continue; + } + let object = fco.objects(id)[0]; + let mutated_inside = fco + .mutators(object) + .into_iter() + .any(|id| &fork_join_nest[&bbs.0[id.idx()]] == forks); + if !mutated_inside { + continue; + } + let in_reduce = forks.into_iter().any(|id| { + let join = fork_join_map[id]; + let mut reduces = editor + .get_users(join) + .filter(|id| nodes[id.idx()].is_reduce()); + reduces.any(|id| fco.objects(id).contains(&object)) + }); + if in_reduce { + continue; + } + if let Node::Read { + collect: _, + ref indices, + } = nodes[editor.get_users(id).next().unwrap().idx()] + && let Index::Position(ref pos) = indices[0] + && { + let tid_pos: BTreeSet<(NodeID, usize)> = pos + .into_iter() + .filter_map(|id| nodes[id.idx()].try_thread_id()) + .collect(); + let reference: BTreeSet<(NodeID, usize)> = forks + .into_iter() + .flat_map(|id| { + (0..nodes[id.idx()].try_fork().unwrap().1.len()).map(|dim| (*id, dim)) + }) + .collect(); + tid_pos == reference + } + { + continue; + } + + // We know that this collection needs to be replicated across the + // fork-join dimensions, so do that. + todo!() + } + } + false +} + type Liveness = BTreeMap<NodeID, Vec<BTreeSet<NodeID>>>; /* diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs index df53a0e9..a5954ca0 100644 --- a/hercules_rt/src/lib.rs +++ b/hercules_rt/src/lib.rs @@ -13,7 +13,7 @@ use std::sync::OnceLock; * src/rt.rs (the RT backend). */ -pub const LARGEST_ALIGNMENT: usize = 32; +pub const LARGEST_ALIGNMENT: usize = 64; pub unsafe fn __cpu_alloc(size: usize) -> *mut u8 { let ptr = alloc(Layout::from_size_align(size, LARGEST_ALIGNMENT).unwrap()); diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 4656d841..77437a61 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2105,6 +2105,7 @@ fn run_pass( pm.make_control_subgraphs(); pm.make_doms(); pm.make_fork_join_maps(); + pm.make_fork_join_nests(); pm.make_loops(); pm.make_reduce_cycles(); pm.make_collection_objects(); @@ -2115,6 +2116,7 @@ fn run_pass( let typing = pm.typing.take().unwrap(); let doms = pm.doms.take().unwrap(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let fork_join_nests = pm.fork_join_nests.take().unwrap(); let loops = pm.loops.take().unwrap(); let reduce_cycles = pm.reduce_cycles.take().unwrap(); let control_subgraphs = pm.control_subgraphs.take().unwrap(); @@ -2136,6 +2138,7 @@ fn run_pass( &control_subgraphs[id.idx()], &doms[id.idx()], &fork_join_maps[id.idx()], + &fork_join_nests[id.idx()], &loops[id.idx()], &reduce_cycles[id.idx()], &collection_objects, -- GitLab
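A note on the todo!() above: it marks where the replication itself will eventually happen. As a rough standalone sketch of that step, assuming the intent is to prepend one outermost array dimension per enclosing fork dimension so each thread gets a private copy (the types and names below are illustrative stand-ins, not the real Hercules IR or FunctionEditor API):

// Hypothetical, self-contained sketch of the replication step the
// todo!() will perform. These types are illustrative only; they are
// not the Hercules IR.

#[derive(Clone, Debug, PartialEq)]
enum Ty {
    Prim(&'static str),
    // An array type: element type plus a list of dimension sizes.
    Array(Box<Ty>, Vec<u64>),
}

/// Given a collection's type and the factor of each enclosing fork
/// dimension, produce the replicated type: one extra outermost array
/// dimension per fork dimension.
fn replicate_across_forks(ty: &Ty, fork_factors: &[u64]) -> Ty {
    match ty {
        // An existing array gains the fork dimensions on the outside.
        Ty::Array(elem, dims) => {
            let mut new_dims = fork_factors.to_vec();
            new_dims.extend_from_slice(dims);
            Ty::Array(elem.clone(), new_dims)
        }
        // A non-array collection becomes an array over the fork space.
        other => Ty::Array(Box::new(other.clone()), fork_factors.to_vec()),
    }
}

fn main() {
    // A 16x16 f32 buffer mutated inside a 2D fork with factors 4 and 8
    // becomes a 4x8x16x16 buffer.
    let buf = Ty::Array(Box::new(Ty::Prim("f32")), vec![16, 16]);
    let replicated = replicate_across_forks(&buf, &[4, 8]);
    assert_eq!(
        replicated,
        Ty::Array(Box::new(Ty::Prim("f32")), vec![4, 8, 16, 16])
    );
}

Reads and writes of the replicated collection would then have their position indices prefixed with the corresponding thread IDs, which is the pattern the tid_pos == reference check above already skips when it is present.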
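Separately, the LARGEST_ALIGNMENT bump from 32 to 64 in hercules_cg and hercules_rt matches the common 64-byte cache line size; a plausible (but unstated in the commit) motivation is to keep per-thread copies of a replicated collection from sharing a line. A minimal check of the allocation behavior that __cpu_alloc relies on:

use std::alloc::{alloc, dealloc, Layout};

// Allocate with the new 64-byte alignment, as __cpu_alloc does after
// this patch, and confirm the returned pointer is cache-line aligned.
fn main() {
    let layout = Layout::from_size_align(1024, 64).unwrap();
    unsafe {
        let ptr = alloc(layout);
        assert!(!ptr.is_null());
        assert_eq!(ptr as usize % 64, 0);
        dealloc(ptr, layout);
    }
}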