From 567778044140dc675dc7d2a27dc5abf6214a8652 Mon Sep 17 00:00:00 2001
From: Russel Arbore <rarbore2@illinois.edu>
Date: Tue, 5 Mar 2024 14:32:45 -0600
Subject: [PATCH] iterate loops bottom-up for forkify

---
 hercules_ir/src/loops.rs    | 24 ++++++++++++++++++++++++
 hercules_opt/src/forkify.rs | 25 +++++++++++++------------
 2 files changed, 37 insertions(+), 12 deletions(-)

diff --git a/hercules_ir/src/loops.rs b/hercules_ir/src/loops.rs
index f966c0ad..c657572f 100644
--- a/hercules_ir/src/loops.rs
+++ b/hercules_ir/src/loops.rs
@@ -2,6 +2,7 @@ extern crate bitvec;
 
 use std::collections::hash_map;
 use std::collections::HashMap;
+use std::collections::VecDeque;
 
 use self::bitvec::prelude::*;
 
@@ -34,6 +35,29 @@ impl LoopTree {
     pub fn loops(&self) -> hash_map::Iter<'_, NodeID, (BitVec<u8, Lsb0>, NodeID)> {
         self.loops.iter()
     }
+
+    /*
+     * Lists every loop header with its contents bit vector, ordered so that a
+     * loop is emitted only after every loop naming it as parent was emitted.
+     */
+    pub fn bottom_up_loops(&self) -> Vec<(NodeID, &BitVec<u8, Lsb0>)> {
+        // Per header (plus the root): how many child loops are not yet emitted.
+        let mut pending: HashMap<NodeID, u32> = self.loops.keys().map(|id| (*id, 0)).collect();
+        pending.insert(self.root, 0);
+        self.loops.values().for_each(|(_, parent)| *pending.get_mut(parent).unwrap() += 1);
+        let mut order = Vec::new();
+        let mut queue: VecDeque<_> = self.loops.iter().map(|(id, info)| (*id, &info.0)).collect();
+        while let Some((header, body)) = queue.pop_front() {
+            if pending[&header] != 0 {
+                // A child loop is still queued; revisit this header later.
+                queue.push_back((header, body));
+                continue;
+            }
+            *pending.get_mut(&self.loops[&header].1).unwrap() -= 1;
+            order.push((header, body));
+        }
+        order
+    }
 }
 
 /*
diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs
index 22496b20..18dfe36a 100644
--- a/hercules_opt/src/forkify.rs
+++ b/hercules_opt/src/forkify.rs
@@ -19,14 +19,15 @@ pub fn forkify(
 ) {
     // Ignore loops that are already fork-joins.
     let natural_loops = loops
-        .loops()
+        .bottom_up_loops()
+        .into_iter()
         .filter(|(k, _)| function.nodes[k.idx()].is_region());
 
     // Detect loops that have a simple loop induction variable. TODO: proper
     // affine analysis to recognize other cases of linear induction variables.
     let affine_loops: Vec<_> = natural_loops
         .into_iter()
-        .filter_map(|(header, (contents, _))| {
+        .filter_map(|(header, contents)| {
             // Get the single loop contained predecessor of the loop header.
             let header_uses = get_uses(&function.nodes[header.idx()]);
             let mut pred_loop = header_uses.as_ref().iter().filter(|id| contents[id.idx()]);
@@ -43,7 +44,7 @@ pub fn forkify(
             let (should_be_header, pred_datas) = function.nodes[phi.idx()].try_phi()?;
             let one_c_id = function.nodes[one.idx()].try_constant()?;
 
-            if should_be_header != *header || !constants[one_c_id.idx()].is_one() {
+            if should_be_header != header || !constants[one_c_id.idx()].is_one() {
                 return None;
             }
 
@@ -108,12 +109,12 @@ pub fn forkify(
         let header_uses: Vec<_> = header_uses.as_ref().into_iter().map(|x| *x).collect();
 
         // Get the control portions of the loop that need to be grafted.
-        let loop_pred = *header_uses
+        let loop_pred = header_uses
             .iter()
             .filter(|id| !contents[id.idx()])
             .next()
             .unwrap();
-        let loop_true_read = *header_uses
+        let loop_true_read = header_uses
             .iter()
             .filter(|id| contents[id.idx()])
             .next()
@@ -137,14 +138,14 @@ pub fn forkify(
 
         // Create fork and join nodes.
         let fork = Node::Fork {
-            control: loop_pred,
+            control: *loop_pred,
             factor: dc_id,
         };
         let fork_id = NodeID::new(function.nodes.len());
         function.nodes.push(fork);
 
         let join = Node::Join {
-            control: if *header == get_uses(&function.nodes[loop_end.idx()]).as_ref()[0] {
+            control: if header == get_uses(&function.nodes[loop_end.idx()]).as_ref()[0] {
                 fork_id
             } else {
                 function.nodes[loop_end.idx()].try_if().unwrap().0
@@ -158,7 +159,7 @@ pub fn forkify(
 
         // Convert reducing phi nodes to reduce nodes.
         let reduction_phis: Vec<_> = def_use
-            .get_users(*header)
+            .get_users(header)
             .iter()
             .filter(|id| **id != idx_phi && function.nodes[id.idx()].is_phi())
             .collect();
@@ -172,7 +173,7 @@ pub fn forkify(
                     .1
                     .iter(),
             )
-            .filter(|(c, _)| **c == loop_pred)
+            .filter(|(c, _)| **c == *loop_pred)
             .next()
             .unwrap()
             .1;
@@ -186,7 +187,7 @@ pub fn forkify(
                     .1
                     .iter(),
             )
-            .filter(|(c, _)| **c == loop_true_read)
+            .filter(|(c, _)| **c == *loop_true_read)
             .next()
             .unwrap()
             .1;
@@ -224,8 +225,8 @@ pub fn forkify(
         function.nodes[idx_phi.idx()] = Node::Start;
 
         // Delete old loop control nodes;
-        for user in def_use.get_users(*header) {
-            get_uses_mut(&mut function.nodes[user.idx()]).map(*header, fork_id);
+        for user in def_use.get_users(header) {
+            get_uses_mut(&mut function.nodes[user.idx()]).map(header, fork_id);
         }
         function.nodes[header.idx()] = Node::Start;
         function.nodes[loop_end.idx()] = Node::Start;
-- 
GitLab