From 78d39cf9046f75a85b31fd6b9a175e9781ef23b3 Mon Sep 17 00:00:00 2001
From: rarbore2 <rarbore2@illinois.edu>
Date: Thu, 9 Jan 2025 15:49:19 -0600
Subject: [PATCH] Clone elimination pass

---
Review note: hercules_opt/src/clone_elim.rs below was amended in review (the
users of a removed write are collected before the edit, and the header comment
was expanded). The commit id in the `From` line and the blob id in that file's
`index` line predate the amendment.

 hercules_opt/src/clone_elim.rs    | 41 +++++++++++++++++++++++++++++++
 hercules_opt/src/lib.rs           |  2 ++
 hercules_opt/src/pass.rs          | 28 +++++++++++++++++++++
 juno_frontend/src/lib.rs          |  2 ++
 juno_samples/matmul/src/matmul.jn | 13 +++++++---
 5 files changed, 83 insertions(+), 3 deletions(-)
 create mode 100644 hercules_opt/src/clone_elim.rs

diff --git a/hercules_opt/src/clone_elim.rs b/hercules_opt/src/clone_elim.rs
new file mode 100644
index 00000000..59507bac
--- /dev/null
+++ b/hercules_opt/src/clone_elim.rs
@@ -0,0 +1,41 @@
+extern crate hercules_ir;
+
+use std::collections::BTreeSet;
+
+use self::hercules_ir::ir::*;
+
+use crate::*;
+
+/*
+ * Top level function to run clone elimination. Every Write node whose `collect`
+ * and `data` inputs are equal nodes has all of its uses replaced with its
+ * `collect` input and is then deleted. Panics if such a write has a non-empty
+ * index list.
+ */
+pub fn clone_elim(editor: &mut FunctionEditor) {
+    // Create workset (starts as all nodes).
+    let mut workset: BTreeSet<NodeID> = (0..editor.func().nodes.len()).map(NodeID::new).collect();
+
+    while let Some(work) = workset.pop_first() {
+        // Look for Write nodes with identical `collect` and `data` inputs.
+        let nodes = &editor.func().nodes;
+        if let Node::Write {
+            collect,
+            data,
+            ref indices,
+        } = nodes[work.idx()]
+            && nodes[collect.idx()] == nodes[data.idx()]
+        {
+            assert!(indices.is_empty());
+
+            // Snapshot the users of this write before editing. The edit reroutes them to
+            // `collect` and deletes `work`, so reading them afterwards would miss them
+            // (NOTE(review): assumed from `replace_all_uses` + `delete_node`; confirm).
+            let users: Vec<NodeID> = editor.get_users(work).collect();
+            editor.edit(|edit| edit.replace_all_uses(work, collect)?.delete_node(work));
+
+            // Removing this write may affect downstream writes.
+            workset.extend(users);
+        }
+    }
+}
diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs
index a69ca539..ed658b22 100644
--- a/hercules_opt/src/lib.rs
+++ b/hercules_opt/src/lib.rs
@@ -1,6 +1,7 @@
 #![feature(let_chains)]
 
 pub mod ccp;
+pub mod clone_elim;
 pub mod dce;
 pub mod delete_uncalled;
 pub mod editor;
@@ -21,6 +22,7 @@ pub mod unforkify;
 pub mod utils;
 
 pub use crate::ccp::*;
+pub use crate::clone_elim::*;
 pub use crate::dce::*;
 pub use crate::delete_uncalled::*;
 pub use crate::editor::*;
diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs
index baeaae87..b136fef9 100644
--- a/hercules_opt/src/pass.rs
+++ b/hercules_opt/src/pass.rs
@@ -40,6 +40,7 @@ pub enum Pass {
     Unforkify,
     InferSchedules,
     LegalizeReferenceSemantics,
+    CloneElim,
     Verify,
     // Parameterized over whether analyses that aid visualization are necessary.
     // Useful to set to false if displaying a potentially broken module.
@@ -807,6 +808,33 @@ impl PassManager {
                         break;
                     }
                 },
+                Pass::CloneElim => {
+                    self.make_def_uses();
+                    let def_uses = self.def_uses.as_ref().unwrap();
+                    for idx in 0..self.module.functions.len() {
+                        let constants_ref =
+                            RefCell::new(std::mem::take(&mut self.module.constants));
+                        let dynamic_constants_ref =
+                            RefCell::new(std::mem::take(&mut self.module.dynamic_constants));
+                        let types_ref = RefCell::new(std::mem::take(&mut self.module.types));
+                        let mut editor = FunctionEditor::new(
+                            &mut self.module.functions[idx],
+                            FunctionID::new(idx),
+                            &constants_ref,
+                            &dynamic_constants_ref,
+                            &types_ref,
+                            &def_uses[idx],
+                        );
+                        clone_elim(&mut editor);
+
+                        self.module.constants = constants_ref.take();
+                        self.module.dynamic_constants = dynamic_constants_ref.take();
+                        self.module.types = types_ref.take();
+
+                        self.module.functions[idx].delete_gravestones();
+                    }
+                    self.clear_analyses();
+                }
                 Pass::InferSchedules => {
                     self.make_def_uses();
                     self.make_fork_join_maps();
diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs
index 075c7a45..ae6a7421 100644
--- a/juno_frontend/src/lib.rs
+++ b/juno_frontend/src/lib.rs
@@ -192,6 +192,8 @@ pub fn compile_ir(
     add_pass!(pm, verify, GVN);
     add_verified_pass!(pm, verify, DCE);
     add_pass!(pm, verify, LegalizeReferenceSemantics);
+    add_pass!(pm, verify, CloneElim);
+    add_pass!(pm, verify, DCE);
     add_pass!(pm, verify, Outline);
     add_pass!(pm, verify, InterproceduralSROA);
     add_pass!(pm, verify, SROA);
diff --git a/juno_samples/matmul/src/matmul.jn b/juno_samples/matmul/src/matmul.jn
index f59cc215..ca9be73a 100644
--- a/juno_samples/matmul/src/matmul.jn
+++ b/juno_samples/matmul/src/matmul.jn
@@ -17,12 +17,19 @@ fn matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[
 #[entry]
 fn tiled_64_matmul<n : usize, m : usize, l : usize>(a : i32[n, m], b : i32[m, l]) -> i32[n, l] {
   let res : i32[n, l];
+  let atile : i32[64, 64];
+  let btile : i32[64, 64];
+  let ctile : i32[64, 64];
 
   for bi = 0 to n / 64 {
     for bk = 0 to l / 64 {
-      let atile : i32[64, 64];
-      let btile : i32[64, 64];
-      let ctile : i32[64, 64];
+      for ti = 0 to 64 {
+        for tk = 0 to 64 {
+          atile[ti, tk] = 0;
+          btile[ti, tk] = 0;
+          ctile[ti, tk] = 0;
+        }
+      }
 
       for tile_idx = 0 to m / 64 {
         for ti = 0 to 64 {
-- 
GitLab