From ed716a965b3edc51dd355c3c50fb42dc70428040 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Thu, 31 Oct 2024 16:46:20 -0500 Subject: [PATCH] fix gcm --- hercules_ir/src/gcm.rs | 162 ++++++++++++++++++++++++----------------- 1 file changed, 97 insertions(+), 65 deletions(-) diff --git a/hercules_ir/src/gcm.rs b/hercules_ir/src/gcm.rs index 27939931..d6eb9b87 100644 --- a/hercules_ir/src/gcm.rs +++ b/hercules_ir/src/gcm.rs @@ -1,7 +1,7 @@ extern crate bitvec; -use std::collections::{HashMap, HashSet}; -use std::iter::zip; +use std::collections::{HashMap, HashSet, VecDeque}; +use std::iter::{zip, FromIterator}; use crate::*; @@ -47,37 +47,50 @@ pub fn gcm( } } - // Step 2: schedule early. Place nodes in reverse postorder in the earliest - // position they could go. + // Step 2: schedule early. Place nodes in the earliest position they could + // go - use worklist to iterate nodes. let mut schedule_early = bbs.clone(); let mut antideps_uses = HashMap::<NodeID, Vec<NodeID>>::new(); for (read, write) in antideps { antideps_uses.entry(*write).or_default().push(*read); } - for id in reverse_postorder { + let mut worklist = VecDeque::from(reverse_postorder.clone()); + while let Some(id) = worklist.pop_front() { if schedule_early[id.idx()].is_some() { continue; } // For every use, check what block is its "schedule early" block. This // node goes in the lowest block amongst those blocks. - let lowest = dom.lowest_amongst( - get_uses(&function.nodes[id.idx()]) - .as_ref() - .into_iter() - .map(|id| *id) - // Include "uses" from anti-dependencies. - .chain(antideps_uses.remove(&id).unwrap_or_default().into_iter()) - .map(|id| schedule_early[id.idx()].unwrap()), - ); - schedule_early[id.idx()] = Some(lowest); + let use_places: Option<Vec<NodeID>> = get_uses(&function.nodes[id.idx()]) + .as_ref() + .into_iter() + .map(|id| *id) + // Include "uses" from anti-dependencies. + .chain( + antideps_uses + .get(&id) + .unwrap_or(&vec![]) + .into_iter() + .map(|id| *id), + ) + .map(|id| schedule_early[id.idx()]) + .collect(); + if let Some(use_places) = use_places { + // If every use has been placed, we can place this node as the + // lowest place in the domtree that dominates all of the use places. + let lowest = dom.lowest_amongst(use_places.into_iter()); + schedule_early[id.idx()] = Some(lowest); + } else { + // If not, then just push this node back on the worklist. + worklist.push_back(id); + } } // Step 3: schedule late and pick each nodes final position. Since the late // schedule of each node depends on the final positions of its users, these - // two steps must be fused. Place nodes in postorder. Compute their latest - // position, then use the control dependent + shallow loop heuristic to - // actually place them. + // two steps must be fused. Compute their latest position, then use the + // control dependent + shallow loop heuristic to actually place them. let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map .into_iter() .map(|(fork, join)| (*join, *fork)) @@ -86,66 +99,85 @@ pub fn gcm( for (read, write) in antideps { antideps_users.entry(*read).or_default().push(*write); } - for id in reverse_postorder.into_iter().rev() { + let mut worklist = VecDeque::from_iter(reverse_postorder.into_iter().map(|id| *id).rev()); + while let Some(id) = worklist.pop_front() { if bbs[id.idx()].is_some() { continue; } // Calculate the least common ancestor of user blocks, a.k.a. the "late" // schedule. - let mut lca = None; - // Helper to incrementally update the LCA. - let mut update_lca = |a| { - if let Some(acc) = lca { - lca = Some(dom.least_common_ancestor(acc, a)); - } else { - lca = Some(a); - } - }; + let calculate_lca = || -> Option<_> { + let mut lca = None; + // Helper to incrementally update the LCA. + let mut update_lca = |a| { + if let Some(acc) = lca { + lca = Some(dom.least_common_ancestor(acc, a)); + } else { + lca = Some(a); + } + }; - // For every user, consider where we need to be to directly dominate the - // user. - for user in def_use - .get_users(*id) - .as_ref() - .into_iter() - .map(|id| *id) - // Include "users" from anti-dependencies. - .chain(antideps_users.remove(&id).unwrap_or_default().into_iter()) - { - if let Node::Phi { control, data } = &function.nodes[user.idx()] { - // For phis, we need to dominate the block jumping to the phi in - // the slot that corresponds to our use. - for (control, data) in zip(get_uses(&function.nodes[control.idx()]).as_ref(), data) + // For every user, consider where we need to be to directly dominate the + // user. + for user in def_use + .get_users(id) + .as_ref() + .into_iter() + .map(|id| *id) + // Include "users" from anti-dependencies. + .chain( + antideps_users + .get(&id) + .unwrap_or(&vec![]) + .into_iter() + .map(|id| *id), + ) + { + if let Node::Phi { control, data } = &function.nodes[user.idx()] { + // For phis, we need to dominate the block jumping to the phi in + // the slot that corresponds to our use. + for (control, data) in + zip(get_uses(&function.nodes[control.idx()]).as_ref(), data) + { + if id == *data { + update_lca(*control); + } + } + } else if let Node::Reduce { + control, + init, + reduct, + } = &function.nodes[user.idx()] { - if *id == *data { + // For reduces, we need to either dominate the block right + // before the fork if we're the init input, or we need to + // dominate the join if we're the reduct input. + if id == *init { + let before_fork = function.nodes[join_fork_map[control].idx()] + .try_fork() + .unwrap() + .0; + update_lca(before_fork); + } else { + assert_eq!(id, *reduct); update_lca(*control); } - } - } else if let Node::Reduce { - control, - init, - reduct, - } = &function.nodes[user.idx()] - { - // For reduces, we need to either dominate the block right - // before the fork if we're the init input, or we need to - // dominate the join if we're the reduct input. - if *id == *init { - let before_fork = function.nodes[join_fork_map[control].idx()] - .try_fork() - .unwrap() - .0; - update_lca(before_fork); } else { - assert_eq!(*id, *reduct); - update_lca(*control); + // For everything else, we just need to dominate the user. + update_lca(bbs[user.idx()]?); } - } else { - // For everything else, we just need to dominate the user. - update_lca(bbs[user.idx()].unwrap()); } - } + + Some(lca) + }; + + // Check if all users have been placed. If one of them hasn't, then add + // this node back on to the worklist. + let Some(lca) = calculate_lca() else { + worklist.push_back(id); + continue; + }; // Look between the LCA and the schedule early location to place the // node. -- GitLab