From ed716a965b3edc51dd355c3c50fb42dc70428040 Mon Sep 17 00:00:00 2001
From: Russel Arbore <russel.jma@gmail.com>
Date: Thu, 31 Oct 2024 16:46:20 -0500
Subject: [PATCH] fix gcm

---
 hercules_ir/src/gcm.rs | 162 ++++++++++++++++++++++++-----------------
 1 file changed, 97 insertions(+), 65 deletions(-)

diff --git a/hercules_ir/src/gcm.rs b/hercules_ir/src/gcm.rs
index 27939931..d6eb9b87 100644
--- a/hercules_ir/src/gcm.rs
+++ b/hercules_ir/src/gcm.rs
@@ -1,7 +1,7 @@
 extern crate bitvec;
 
-use std::collections::{HashMap, HashSet};
-use std::iter::zip;
+use std::collections::{HashMap, HashSet, VecDeque};
+use std::iter::{zip, FromIterator};
 
 use crate::*;
 
@@ -47,37 +47,50 @@ pub fn gcm(
         }
     }
 
-    // Step 2: schedule early. Place nodes in reverse postorder in the earliest
-    // position they could go.
+    // Step 2: schedule early. Place nodes in the earliest position they could
+    // go - use worklist to iterate nodes.
     let mut schedule_early = bbs.clone();
     let mut antideps_uses = HashMap::<NodeID, Vec<NodeID>>::new();
     for (read, write) in antideps {
         antideps_uses.entry(*write).or_default().push(*read);
     }
-    for id in reverse_postorder {
+    let mut worklist = VecDeque::from(reverse_postorder.clone());
+    while let Some(id) = worklist.pop_front() {
         if schedule_early[id.idx()].is_some() {
             continue;
         }
 
         // For every use, check what block is its "schedule early" block. This
         // node goes in the lowest block amongst those blocks.
-        let lowest = dom.lowest_amongst(
-            get_uses(&function.nodes[id.idx()])
-                .as_ref()
-                .into_iter()
-                .map(|id| *id)
-                // Include "uses" from anti-dependencies.
-                .chain(antideps_uses.remove(&id).unwrap_or_default().into_iter())
-                .map(|id| schedule_early[id.idx()].unwrap()),
-        );
-        schedule_early[id.idx()] = Some(lowest);
+        let use_places: Option<Vec<NodeID>> = get_uses(&function.nodes[id.idx()])
+            .as_ref()
+            .into_iter()
+            .map(|id| *id)
+            // Include "uses" from anti-dependencies.
+            .chain(
+                antideps_uses
+                    .get(&id)
+                    .unwrap_or(&vec![])
+                    .into_iter()
+                    .map(|id| *id),
+            )
+            .map(|id| schedule_early[id.idx()])
+            .collect();
+        if let Some(use_places) = use_places {
+            // If every use has been placed, we can place this node as the
+            // lowest place in the domtree that dominates all of the use places.
+            let lowest = dom.lowest_amongst(use_places.into_iter());
+            schedule_early[id.idx()] = Some(lowest);
+        } else {
+            // If not, then just push this node back on the worklist.
+            worklist.push_back(id);
+        }
     }
 
     // Step 3: schedule late and pick each nodes final position. Since the late
     // schedule of each node depends on the final positions of its users, these
-    // two steps must be fused. Place nodes in postorder. Compute their latest
-    // position, then use the control dependent + shallow loop heuristic to
-    // actually place them.
+    // two steps must be fused. Compute their latest position, then use the
+    // control dependent + shallow loop heuristic to actually place them.
     let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
         .into_iter()
         .map(|(fork, join)| (*join, *fork))
@@ -86,66 +99,85 @@ pub fn gcm(
     for (read, write) in antideps {
         antideps_users.entry(*read).or_default().push(*write);
     }
-    for id in reverse_postorder.into_iter().rev() {
+    let mut worklist = VecDeque::from_iter(reverse_postorder.into_iter().map(|id| *id).rev());
+    while let Some(id) = worklist.pop_front() {
         if bbs[id.idx()].is_some() {
             continue;
         }
 
         // Calculate the least common ancestor of user blocks, a.k.a. the "late"
         // schedule.
-        let mut lca = None;
-        // Helper to incrementally update the LCA.
-        let mut update_lca = |a| {
-            if let Some(acc) = lca {
-                lca = Some(dom.least_common_ancestor(acc, a));
-            } else {
-                lca = Some(a);
-            }
-        };
+        let calculate_lca = || -> Option<_> {
+            let mut lca = None;
+            // Helper to incrementally update the LCA.
+            let mut update_lca = |a| {
+                if let Some(acc) = lca {
+                    lca = Some(dom.least_common_ancestor(acc, a));
+                } else {
+                    lca = Some(a);
+                }
+            };
 
-        // For every user, consider where we need to be to directly dominate the
-        // user.
-        for user in def_use
-            .get_users(*id)
-            .as_ref()
-            .into_iter()
-            .map(|id| *id)
-            // Include "users" from anti-dependencies.
-            .chain(antideps_users.remove(&id).unwrap_or_default().into_iter())
-        {
-            if let Node::Phi { control, data } = &function.nodes[user.idx()] {
-                // For phis, we need to dominate the block jumping to the phi in
-                // the slot that corresponds to our use.
-                for (control, data) in zip(get_uses(&function.nodes[control.idx()]).as_ref(), data)
+            // For every user, consider where we need to be to directly dominate the
+            // user.
+            for user in def_use
+                .get_users(id)
+                .as_ref()
+                .into_iter()
+                .map(|id| *id)
+                // Include "users" from anti-dependencies.
+                .chain(
+                    antideps_users
+                        .get(&id)
+                        .unwrap_or(&vec![])
+                        .into_iter()
+                        .map(|id| *id),
+                )
+            {
+                if let Node::Phi { control, data } = &function.nodes[user.idx()] {
+                    // For phis, we need to dominate the block jumping to the phi in
+                    // the slot that corresponds to our use.
+                    for (control, data) in
+                        zip(get_uses(&function.nodes[control.idx()]).as_ref(), data)
+                    {
+                        if id == *data {
+                            update_lca(*control);
+                        }
+                    }
+                } else if let Node::Reduce {
+                    control,
+                    init,
+                    reduct,
+                } = &function.nodes[user.idx()]
                 {
-                    if *id == *data {
+                    // For reduces, we need to either dominate the block right
+                    // before the fork if we're the init input, or we need to
+                    // dominate the join if we're the reduct input.
+                    if id == *init {
+                        let before_fork = function.nodes[join_fork_map[control].idx()]
+                            .try_fork()
+                            .unwrap()
+                            .0;
+                        update_lca(before_fork);
+                    } else {
+                        assert_eq!(id, *reduct);
                         update_lca(*control);
                     }
-                }
-            } else if let Node::Reduce {
-                control,
-                init,
-                reduct,
-            } = &function.nodes[user.idx()]
-            {
-                // For reduces, we need to either dominate the block right
-                // before the fork if we're the init input, or we need to
-                // dominate the join if we're the reduct input.
-                if *id == *init {
-                    let before_fork = function.nodes[join_fork_map[control].idx()]
-                        .try_fork()
-                        .unwrap()
-                        .0;
-                    update_lca(before_fork);
                 } else {
-                    assert_eq!(*id, *reduct);
-                    update_lca(*control);
+                    // For everything else, we just need to dominate the user.
+                    update_lca(bbs[user.idx()]?);
                 }
-            } else {
-                // For everything else, we just need to dominate the user.
-                update_lca(bbs[user.idx()].unwrap());
             }
-        }
+
+            Some(lca)
+        };
+
+        // Check if all users have been placed. If one of them hasn't, then add
+        // this node back on to the worklist.
+        let Some(lca) = calculate_lca() else {
+            worklist.push_back(id);
+            continue;
+        };
 
         // Look between the LCA and the schedule early location to place the
         // node.
-- 
GitLab