From 78d39cf9046f75a85b31fd6b9a175e9781ef23b3 Mon Sep 17 00:00:00 2001
From: rarbore2 <rarbore2@illinois.edu>
Date: Thu, 9 Jan 2025 15:49:19 -0600
Subject: [PATCH] Clone elimination pass

---
 hercules_opt/src/clone_elim.rs    | 33 +++++++++++++++++++++++++++++++
 hercules_opt/src/lib.rs           |  2 ++
 hercules_opt/src/pass.rs          | 28 ++++++++++++++++++++++++++
 juno_frontend/src/lib.rs          |  2 ++
 juno_samples/matmul/src/matmul.jn | 13 +++++++++---
 5 files changed, 75 insertions(+), 3 deletions(-)
 create mode 100644 hercules_opt/src/clone_elim.rs

diff --git a/hercules_opt/src/clone_elim.rs b/hercules_opt/src/clone_elim.rs
new file mode 100644
index 00000000..59507bac
--- /dev/null
+++ b/hercules_opt/src/clone_elim.rs
@@ -0,0 +1,33 @@
+extern crate hercules_ir;
+
+use std::collections::BTreeSet;
+
+use self::hercules_ir::ir::*;
+
+use crate::*;
+
+/*
+ * Top level function to run clone elimination.
+ */
+pub fn clone_elim(editor: &mut FunctionEditor) {
+    // Create workset (starts as all nodes).
+    let mut workset: BTreeSet<NodeID> = (0..editor.func().nodes.len()).map(NodeID::new).collect();
+
+    while let Some(work) = workset.pop_first() {
+        // Look for Write nodes with identical `collect` and `data` inputs.
+        let nodes = &editor.func().nodes;
+        if let Node::Write {
+            collect,
+            data,
+            ref indices,
+        } = nodes[work.idx()]
+            && nodes[collect.idx()] == nodes[data.idx()]
+        {
+            assert!(indices.is_empty());
+            editor.edit(|edit| edit.replace_all_uses(work, collect)?.delete_node(work));
+
+            // Removing this write may affect downstream writes.
+            workset.extend(editor.get_users(work));
+        }
+    }
+}
diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs
index a69ca539..ed658b22 100644
--- a/hercules_opt/src/lib.rs
+++ b/hercules_opt/src/lib.rs
@@ -1,6 +1,7 @@
 #![feature(let_chains)]
 
 pub mod ccp;
+pub mod clone_elim;
 pub mod dce;
 pub mod delete_uncalled;
 pub mod editor;
@@ -21,6 +22,7 @@ pub mod unforkify;
 pub mod utils;
 
 pub use crate::ccp::*;
+pub use crate::clone_elim::*;
 pub use crate::dce::*;
 pub use crate::delete_uncalled::*;
 pub use crate::editor::*;
diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs
index baeaae87..b136fef9 100644
--- a/hercules_opt/src/pass.rs
+++ b/hercules_opt/src/pass.rs
@@ -40,6 +40,7 @@ pub enum Pass {
     Unforkify,
     InferSchedules,
     LegalizeReferenceSemantics,
+    CloneElim,
     Verify,
     // Parameterized over whether analyses that aid visualization are necessary.
     // Useful to set to false if displaying a potentially broken module.
@@ -807,6 +808,33 @@ impl PassManager {
                         break;
                     }
                 },
+                Pass::CloneElim => {
+                    self.make_def_uses();
+                    let def_uses = self.def_uses.as_ref().unwrap();
+                    for idx in 0..self.module.functions.len() {
+                        let constants_ref =
+                            RefCell::new(std::mem::take(&mut self.module.constants));
+                        let dynamic_constants_ref =
+                            RefCell::new(std::mem::take(&mut self.module.dynamic_constants));
+                        let types_ref = RefCell::new(std::mem::take(&mut self.module.types));
+                        let mut editor = FunctionEditor::new(
+                            &mut self.module.functions[idx],
+                            FunctionID::new(idx),
+                            &constants_ref,
+                            &dynamic_constants_ref,
+                            &types_ref,
+                            &def_uses[idx],
+                        );
+                        clone_elim(&mut editor);
+
+                        self.module.constants = constants_ref.take();
+                        self.module.dynamic_constants = dynamic_constants_ref.take();
+                        self.module.types = types_ref.take();
+
+                        self.module.functions[idx].delete_gravestones();
+                    }
+                    self.clear_analyses();
+                }
                 Pass::InferSchedules => {
                     self.make_def_uses();
                     self.make_fork_join_maps();
diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs
index 075c7a45..ae6a7421 100644
--- a/juno_frontend/src/lib.rs
+++ b/juno_frontend/src/lib.rs
@@ -192,6 +192,8 @@ pub fn compile_ir(
     add_pass!(pm, verify, GVN);
     add_verified_pass!(pm, verify, DCE);
     add_pass!(pm, verify, LegalizeReferenceSemantics);
+    add_pass!(pm, verify, CloneElim);
+    add_pass!(pm, verify, DCE);
     add_pass!(pm, verify, Outline);
     add_pass!(pm, verify, InterproceduralSROA);
     add_pass!(pm, verify, SROA);
diff --git a/juno_samples/matmul/src/matmul.jn b/juno_samples/matmul/src/matmul.jn
index f59cc215..ca9be73a 100644
--- a/juno_samples/matmul/src/matmul.jn
+++ b/juno_samples/matmul/src/matmul.jn
@@ -17,12 +17,19 @@ fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[
 #[entry]
 fn tiled_64_matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] {
   let res : i32[n, l];
+  let atile : i32[64, 64];
+  let btile : i32[64, 64];
+  let ctile : i32[64, 64];
   
   for bi = 0 to n / 64 {
     for bk = 0 to l / 64 {
-      let atile : i32[64, 64];
-      let btile : i32[64, 64];
-      let ctile : i32[64, 64];
+      for ti = 0 to 64 {
+        for tk = 0 to 64 {
+	  atile[ti, tk] = 0;
+	  btile[ti, tk] = 0;
+	  ctile[ti, tk] = 0;
+	}
+      }
 
       for tile_idx = 0 to m / 64 {
         for ti = 0 to 64 {
-- 
GitLab