From 423235ed094c4530228faf3e1c4985188ea5c030 Mon Sep 17 00:00:00 2001
From: rarbore2 <rarbore2@illinois.edu>
Date: Wed, 1 Jan 2025 19:45:14 -0600
Subject: [PATCH] Incorporate anti-dependencies into GCM

---
 hercules_ir/src/antideps.rs             | 54 +++++++++++--------
 hercules_ir/src/gcm.rs                  | 38 +++++++++++--
 hercules_ir/src/subgraph.rs             | 71 ++++++++++++++++++++++++-
 hercules_opt/src/materialize_clones.rs  |  6 ++-
 hercules_opt/src/pass.rs                | 24 +++++++--
 juno_samples/antideps/src/antideps.jn   | 58 ++++++++++++++++++--
 juno_samples/antideps/src/main.rs       |  6 ++-
 juno_samples/implicit_clone/src/main.rs |  2 +-
 juno_samples/matmul/src/main.rs         | 19 ++++++-
 juno_samples/matmul/src/matmul.jn       |  2 -
 10 files changed, 241 insertions(+), 39 deletions(-)

diff --git a/hercules_ir/src/antideps.rs b/hercules_ir/src/antideps.rs
index 8aa9f7d9..af949708 100644
--- a/hercules_ir/src/antideps.rs
+++ b/hercules_ir/src/antideps.rs
@@ -43,24 +43,14 @@ pub fn antideps(
     objects: &FunctionCollectionObjects,
 ) -> Vec<(NodeID, NodeID)> {
     // First, we analyze "generations" of collections as they are mutated.
-    // Originating nodes and mutating nodes start a new generation of a
+    // Originating, mutating, phi, and reduce nodes start a new generation of a
     // collection. Generations are not ordered due to loops, but are rather just
-    // node IDs of the originating or mutating node (parameter, constant, call,
-    // undef, write). Other nodes operating on collections mean reads / writes
-    // can operate on potentially different generations of multiple collections
-    // (phi, reduce, select).
+    // IDs of such nodes (parameter, constant, call, undef, write, phi, reduce).
+    // Other nodes operating on collections mean reads / writes can operate on
+    // potentially different generations of multiple collections (select).
     let lattice = forward_dataflow(function, reverse_postorder, |inputs, id| {
         match function.nodes[id.idx()] {
-            Node::Phi {
-                control: _,
-                data: _,
-            }
-            | Node::Reduce {
-                control: _,
-                init: _,
-                reduct: _,
-            }
-            | Node::Ternary {
+            Node::Ternary {
                 op: TernaryOperator::Select,
                 first: _,
                 second: _,
@@ -121,12 +111,21 @@ pub fn antideps(
                 collect: _,
                 indices: _,
             } => inputs[0].clone(),
-            Node::Write {
+            Node::Phi {
+                control: _,
+                data: _,
+            }
+            | Node::Reduce {
+                control: _,
+                init: _,
+                reduct: _,
+            }
+            | Node::Write {
                 collect: _,
                 data: _,
                 indices: _,
             } => {
-                // Writes update the generation to the write.
+                // Phis, reduces, and writes update the generation to themselves.
                 let objs = inputs[0].objs.iter().map(|(obj, _)| (*obj, id)).collect();
                 GenerationLattice { objs }
             }
@@ -169,7 +168,7 @@ pub fn antideps(
             }
             Node::Write {
                 collect,
-                data: _,
+                data,
                 indices: _,
             } => {
                 for pair in lattice[collect.idx()].objs.iter() {
@@ -179,6 +178,19 @@ pub fn antideps(
                         .1
                         .push(id);
                 }
+
+                // When a write takes a collection on its `data` input, it
+                // memcpys that collection into the mutated collection. This is
+                // a read.
+                if !objects.objects(*data).is_empty() {
+                    for pair in lattice[collect.idx()].objs.iter() {
+                        reads_writes_calls_mut_calls_per_pair
+                            .entry(*pair)
+                            .or_default()
+                            .0
+                            .push(id);
+                    }
+                }
             }
             Node::Call {
                 control: _,
@@ -219,7 +231,7 @@ pub fn antideps(
         // Case 1:
         for read in reads.iter() {
             for write in writes.iter() {
-                if *write != gen {
+                if *write != gen && *read != *write {
                     antideps.push((*read, *write));
                 }
             }
@@ -228,7 +240,7 @@ pub fn antideps(
         // Case 2:
         for read in reads.iter() {
             for mut_call in mut_calls.iter() {
-                if *mut_call != gen {
+                if *mut_call != gen && *read != *mut_call {
                     antideps.push((*read, *mut_call));
                 }
             }
@@ -237,7 +249,7 @@ pub fn antideps(
         // Case 3:
         for call in calls.iter().chain(mut_calls.iter()) {
             for write in writes.iter() {
-                if *write != gen {
+                if *write != gen && *call != *write {
                     antideps.push((*call, *write));
                 }
             }
diff --git a/hercules_ir/src/gcm.rs b/hercules_ir/src/gcm.rs
index a8fd0a21..3718df9b 100644
--- a/hercules_ir/src/gcm.rs
+++ b/hercules_ir/src/gcm.rs
@@ -30,12 +30,18 @@ pub fn gcm(
     function: &Function,
     def_use: &ImmutableDefUseMap,
     reverse_postorder: &Vec<NodeID>,
+    control_subgraph: &Subgraph,
     dom: &DomTree,
     antideps: &Vec<(NodeID, NodeID)>,
     loops: &LoopTree,
     fork_join_map: &HashMap<NodeID, NodeID>,
 ) -> BasicBlocks {
     let mut bbs: Vec<Option<NodeID>> = vec![None; function.nodes.len()];
+    let back_edges = control_subgraph.back_edges(NodeID::new(0));
+    let no_loop_reachability =
+        control_subgraph.pairwise_reachability(|src, dst| !back_edges.contains(&(src, dst)));
+    let antideps_users = map_antideps(antideps);
+    let antideps_uses = flip_antideps(antideps);
 
     // Step 1: assign the basic block locations of all nodes that must be in a
     // specific block. This includes control nodes as well as some special data
@@ -103,7 +109,7 @@ pub fn gcm(
         .map(|(fork, join)| (*join, *fork))
         .collect();
     let mut worklist = VecDeque::from_iter(reverse_postorder.into_iter().map(|id| *id).rev());
-    while let Some(id) = worklist.pop_front() {
+    'worklist: while let Some(id) = worklist.pop_front() {
         if bbs[id.idx()].is_some() {
             continue;
         }
@@ -169,13 +175,38 @@ pub fn gcm(
             continue;
         };
 
+        // Check if all anti-dependency users have been placed. If one of them
+        // hasn't, then add this node back on to the worklist. We need to know
+        // where the anti-dependency users are, so that we can place this
+        // read "above" mutators that anti-depend on it. The condition for a
+        // potential placement location is that in the CFG *without loop back-
+        // edges* the mutator cannot reach the read. TODO: document why this
+        // works (ask Russel for the argument in the meantime).
+        let mut antidep_user_locations = vec![];
+        for antidep_user in antideps_users.get(&id).unwrap_or(&vec![]) {
+            if let Some(location) = bbs[antidep_user.idx()] {
+                antidep_user_locations.push(location);
+            } else {
+                worklist.push_back(id);
+                continue 'worklist;
+            }
+        }
+
         // Look between the LCA and the schedule early location to place the
         // node.
         let schedule_early = schedule_early[id.idx()].unwrap();
         let mut chain = dom
             // If the node has no users, then it doesn't really matter where we
             // place it - just place it at the early placement.
-            .chain(lca.unwrap_or(schedule_early), schedule_early);
+            .chain(lca.unwrap_or(schedule_early), schedule_early)
+            // Only allow locations that don't violate the anti-dependence property
+            // listed above.
+            .filter(|location| {
+                !antidep_user_locations.iter().any(|antidep_user_location| {
+                    antidep_user_location != location
+                        && no_loop_reachability[antidep_user_location.idx()][location.idx()]
+                })
+            });
         let mut location = chain.next().unwrap();
         while let Some(control_node) = chain.next() {
             // If the next node further up the dominator tree is in a shallower
@@ -211,7 +242,6 @@ pub fn gcm(
             .filter(|id| !function.nodes[id.idx()].is_control()),
     );
     let mut visited = bitvec![u8, Lsb0; 0; function.nodes.len()];
-    let antideps = flip_antideps(&antideps);
     while let Some(id) = worklist.pop_front() {
         let node = &function.nodes[id.idx()];
         if node.is_phi()
@@ -219,7 +249,7 @@ pub fn gcm(
             || get_uses(node)
                 .as_ref()
                 .into_iter()
-                .chain(antideps.get(&id).into_iter().flatten())
+                .chain(antideps_uses.get(&id).into_iter().flatten())
                 .all(|u| {
                     function.nodes[u.idx()].is_control()
                         || bbs[u.idx()] != bbs[id.idx()]
diff --git a/hercules_ir/src/subgraph.rs b/hercules_ir/src/subgraph.rs
index de342579..89e8bcc6 100644
--- a/hercules_ir/src/subgraph.rs
+++ b/hercules_ir/src/subgraph.rs
@@ -1,6 +1,11 @@
-use crate::*;
+extern crate bitvec;
+
+use std::collections::{HashMap, HashSet};
+use std::mem::take;
+
+use self::bitvec::prelude::*;
 
-use std::collections::HashMap;
+use crate::*;
 
 /*
  * In various parts of the compiler, we want to consider a subset of a complete
@@ -166,6 +171,68 @@ impl Subgraph {
             original_num_nodes,
         }
     }
+
+    pub fn back_edges(&self, root: NodeID) -> HashSet<(NodeID, NodeID)> {
+        let mut edges = HashSet::new();
+        let mut stack = vec![];
+        let mut visited = bitvec![u8, Lsb0; 0; self.original_num_nodes as usize];
+        let mut in_stack = bitvec![u8, Lsb0; 0; self.original_num_nodes as usize];
+
+        stack.push(root);
+        visited.set(root.idx(), true);
+        in_stack.set(root.idx(), true);
+
+        while let Some(pop) = stack.pop() {
+            if self.succs(pop).any(|succ| !visited[succ.idx()]) {
+                stack.push(pop);
+            } else {
+                in_stack.set(pop.idx(), false);
+            }
+            for succ in self.succs(pop) {
+                if in_stack[succ.idx()] {
+                    edges.insert((pop, succ));
+                } else if !visited[succ.idx()] {
+                    visited.set(succ.idx(), true);
+                    stack.push(succ);
+                    in_stack.set(succ.idx(), true);
+                    break;
+                }
+            }
+        }
+
+        edges
+    }
+
+    pub fn pairwise_reachability<P>(&self, p: P) -> Vec<BitVec<u8, Lsb0>>
+    where
+        P: Fn(NodeID, NodeID) -> bool,
+    {
+        let mut reachability = vec![
+            bitvec![u8, Lsb0; 0; self.original_num_nodes as usize];
+            self.original_num_nodes as usize
+        ];
+
+        // Every node in the subgraph can reach itself.
+        for node in self.iter() {
+            reachability[node.idx()].set(node.idx(), true);
+        }
+
+        // Iterate N times: simple paths have at most N - 1 edges, so this reaches a fixed point.
+        for _ in 0..self.original_num_nodes {
+            for node in self.iter() {
+                for succ in self.succs(*node) {
+                    assert_ne!(*node, succ);
+                    if p(*node, succ) {
+                        let mut node_reachability = take(&mut reachability[node.idx()]);
+                        node_reachability |= &reachability[succ.idx()];
+                        reachability[node.idx()] = node_reachability;
+                    }
+                }
+            }
+        }
+
+        reachability
+    }
 }
 
 /*
diff --git a/hercules_opt/src/materialize_clones.rs b/hercules_opt/src/materialize_clones.rs
index 79002d15..687ac10c 100644
--- a/hercules_opt/src/materialize_clones.rs
+++ b/hercules_opt/src/materialize_clones.rs
@@ -12,6 +12,10 @@ use crate::*;
  * in the backend that collections have reference, rather than value, semantics.
  * The pass calling this function is mandatory for correctness.
  */
-pub fn materialize_clones(editor: &mut FunctionEditor, objects: &FunctionCollectionObjects) {
+pub fn materialize_clones(
+    editor: &mut FunctionEditor,
+    objects: &FunctionCollectionObjects,
+    bbs: &BasicBlocks,
+) {
     todo!()
 }
diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs
index c02e967e..bcaba374 100644
--- a/hercules_opt/src/pass.rs
+++ b/hercules_opt/src/pass.rs
@@ -284,12 +284,14 @@ impl PassManager {
         if self.bbs.is_none() {
             self.make_def_uses();
             self.make_reverse_postorders();
+            self.make_control_subgraphs();
             self.make_doms();
             self.make_antideps();
             self.make_loops();
             self.make_fork_join_maps();
             let def_uses = self.def_uses.as_ref().unwrap().iter();
             let reverse_postorders = self.reverse_postorders.as_ref().unwrap().iter();
+            let control_subgraphs = self.control_subgraphs.as_ref().unwrap().iter();
             let doms = self.doms.as_ref().unwrap().iter();
             let antideps = self.antideps.as_ref().unwrap().iter();
             let loops = self.loops.as_ref().unwrap().iter();
@@ -301,19 +303,29 @@ impl PassManager {
                         def_uses,
                         zip(
                             reverse_postorders,
-                            zip(doms, zip(antideps, zip(loops, fork_join_maps))),
+                            zip(
+                                control_subgraphs,
+                                zip(doms, zip(antideps, zip(loops, fork_join_maps))),
+                            ),
                         ),
                     ),
                 )
                 .map(
                     |(
                         function,
-                        (def_use, (reverse_postorder, (dom, (antideps, (loops, fork_join_map))))),
+                        (
+                            def_use,
+                            (
+                                reverse_postorder,
+                                (control_subgraph, (dom, (antideps, (loops, fork_join_map)))),
+                            ),
+                        ),
                     )| {
                         gcm(
                             function,
                             def_use,
                             reverse_postorder,
+                            control_subgraph,
                             dom,
                             antideps,
                             loops,
@@ -802,8 +814,10 @@ impl PassManager {
                 Pass::MaterializeClones => {
                     self.make_def_uses();
                     self.make_collection_objects();
+                    self.make_bbs();
                     let def_uses = self.def_uses.as_ref().unwrap();
                     let collection_objects = self.collection_objects.as_ref().unwrap();
+                    let bbs = self.bbs.as_ref().unwrap();
                     for idx in 0..self.module.functions.len() {
                         let constants_ref =
                             RefCell::new(std::mem::take(&mut self.module.constants));
@@ -817,7 +831,11 @@ impl PassManager {
                             &types_ref,
                             &def_uses[idx],
                         );
-                        materialize_clones(&mut editor, &collection_objects[&FunctionID::new(idx)]);
+                        materialize_clones(
+                            &mut editor,
+                            &collection_objects[&FunctionID::new(idx)],
+                            &bbs[idx],
+                        );
 
                         self.module.constants = constants_ref.take();
                         self.module.dynamic_constants = dynamic_constants_ref.take();
diff --git a/juno_samples/antideps/src/antideps.jn b/juno_samples/antideps/src/antideps.jn
index f54b8dcf..5949c91a 100644
--- a/juno_samples/antideps/src/antideps.jn
+++ b/juno_samples/antideps/src/antideps.jn
@@ -9,16 +9,68 @@ fn simple_antideps(a : usize, b : usize) -> i32 {
 #[entry]
 fn complex_antideps(x : i32) -> i32 {
   let arr : i32[4];
-  arr[1] = 7;
-  let r = 0;
+  let arr2 : i32[12];
+  arr[1] = 7 + arr2[0];
+  let r = arr2[1];
+  arr2[1] = 99;
   while x > 6 {
+    arr2[10] = 98;
     if x > 5 {
-      r = arr[1];
+      r = arr[1] + arr2[(x + 1) as usize];
     } else {
+      arr2[11] = 97;
       arr[1] = 8;
       r = arr[1];
     }
+    arr2[x as usize] = 13;
     x -= 1;
   }
   return r;
 }
+
+#[entry]
+fn very_complex_antideps(x: usize) -> usize {
+  let arr1 : usize[203];
+  let arr2 : usize[202];
+  let arr3 : usize[201];
+  let arr4 : usize[200];
+  let y = x;
+  while y < 10 {
+    arr1[y] = y;
+    arr2[y] = arr1[x];
+    y += 4;
+    continue;
+  }
+  let z = arr1[x] + arr1[x + 4] + arr2[x + 4];
+  if z == 13 {
+    z = 12;
+  }
+  while true {
+    if arr1[x] == 3 {
+      z = 11;
+      break;
+    } else {
+      z = 10;
+      continue;
+    }
+  }
+  if z == 11 {
+    arr3[z - 10] = 99;
+    arr4[arr3[z - 10] - 98] = 4;
+    arr4[arr4[1]] = 42;
+  } else {
+    arr1[203] = 99999999;
+    arr1[204] = 99999999;
+    arr1[1000] = 99999999;
+  }
+  let w = arr4[4] + arr2[7];
+  while z > 0 {
+    arr2[w] = arr3[1];
+    arr4[w] = arr2[w];
+    z += z;
+    if z > 100 {
+      break;
+    }
+  }
+  return arr4[w] + w;
+}
diff --git a/juno_samples/antideps/src/main.rs b/juno_samples/antideps/src/main.rs
index e613f5ff..b0a99163 100644
--- a/juno_samples/antideps/src/main.rs
+++ b/juno_samples/antideps/src/main.rs
@@ -13,7 +13,11 @@ fn main() {
 
         let output = complex_antideps(9).await;
         println!("{}", output);
-        assert_eq!(output, 7);
+        assert_eq!(output, 20);
+
+        let output = very_complex_antideps(3).await;
+        println!("{}", output);
+        assert_eq!(output, 144);
     });
 }
 
diff --git a/juno_samples/implicit_clone/src/main.rs b/juno_samples/implicit_clone/src/main.rs
index 5ff952a3..ca7ddeb1 100644
--- a/juno_samples/implicit_clone/src/main.rs
+++ b/juno_samples/implicit_clone/src/main.rs
@@ -9,7 +9,7 @@ fn main() {
     async_std::task::block_on(async {
         let output = implicit_clone(3).await;
         println!("{}", output);
-        assert_eq!(output, 9);
+        assert_eq!(output, 11);
     });
 }
 
diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs
index 865beaf5..948459dd 100644
--- a/juno_samples/matmul/src/main.rs
+++ b/juno_samples/matmul/src/main.rs
@@ -31,7 +31,14 @@ fn main() {
                 J * K * 4,
             );
         };
-        let c_bytes = matmul(I as u64, J as u64, K as u64, a_bytes, b_bytes).await;
+        let c_bytes = matmul(
+            I as u64,
+            J as u64,
+            K as u64,
+            a_bytes.clone(),
+            b_bytes.clone(),
+        )
+        .await;
         let mut c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
         unsafe {
             copy_nonoverlapping(
@@ -40,6 +47,15 @@ fn main() {
                 I * K * 4,
             );
         };
+        let tiled_c_bytes = matmul(I as u64, J as u64, K as u64, a_bytes, b_bytes).await;
+        let mut tiled_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
+        unsafe {
+            copy_nonoverlapping(
+                Box::as_ptr(&tiled_c_bytes) as *const u8,
+                Box::as_mut_ptr(&mut tiled_c) as *mut u8,
+                I * K * 4,
+            );
+        };
         let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
         for i in 0..I {
             for k in 0..K {
@@ -49,6 +65,7 @@ fn main() {
             }
         }
         assert_eq!(c, correct_c);
+        assert_eq!(tiled_c, correct_c);
     });
 }
 
diff --git a/juno_samples/matmul/src/matmul.jn b/juno_samples/matmul/src/matmul.jn
index bcfa7afb..775bb382 100644
--- a/juno_samples/matmul/src/matmul.jn
+++ b/juno_samples/matmul/src/matmul.jn
@@ -14,7 +14,6 @@ fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[
   return res;
 }
 
-/*
 #[entry]
 fn tiled_64_matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] {
   let res : i32[n, l];
@@ -60,4 +59,3 @@ fn tiled_64_matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]
 
   return res;
 }
-*/
\ No newline at end of file
-- 
GitLab