diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs
index 06dfd18d173eb85a355f806f5c2111bba2ca5268..3750c4f6abbac3a774269c729eaded8afcc204c3 100644
--- a/hercules_cg/src/cpu.rs
+++ b/hercules_cg/src/cpu.rs
@@ -638,7 +638,7 @@ impl<'a> CPUContext<'a> {
     fn codegen_index_math(
         &self,
         collect_name: &str,
-        collect_ty: TypeID,
+        mut collect_ty: TypeID,
         indices: &[Index],
         body: &mut String,
     ) -> Result<String, Error> {
@@ -666,10 +666,15 @@ impl<'a> CPUContext<'a> {
                         body,
                     )?;
                     acc_ptr = Self::gep(&acc_ptr, &acc_offset, body)?;
+                    collect_ty = fields[*idx];
                 }
-                Index::Variant(_) => {
+                Index::Variant(idx) => {
                     // The tag of a summation is at the end of the summation, so
-                    // the variant pointer is just the base pointer. Do nothing.
+                    // the variant pointer is just the base pointer. Only the type changes.
+                    let Type::Summation(ref variants) = self.types[collect_ty.idx()] else {
+                        panic!()
+                    };
+                    collect_ty = variants[*idx];
                 }
                 Index::Position(ref pos) => {
                     let Type::Array(elem, ref dims) = self.types[collect_ty.idx()] else {
@@ -691,6 +696,7 @@ impl<'a> CPUContext<'a> {
                     // Convert offset in # elements -> # bytes.
                     acc_offset = Self::multiply(&acc_offset, &elem_size, body)?;
                     acc_ptr = Self::gep(&acc_ptr, &acc_offset, body)?;
+                    collect_ty = elem;
                 }
             }
         }
diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs
index ac0d0fd41ffb92c1b96fb917b32dcfece9c72a21..d093b2b0f937d796e8d43c707693ee976e673081 100644
--- a/hercules_cg/src/rt.rs
+++ b/hercules_cg/src/rt.rs
@@ -602,7 +602,11 @@ impl<'a> RTContext<'a> {
     /*
-     * Emit logic to index into an collection.
+     * Emit logic to index into a collection.
      */
-    fn codegen_index_math(&self, collect_ty: TypeID, indices: &[Index]) -> Result<String, Error> {
+    fn codegen_index_math(
+        &self,
+        mut collect_ty: TypeID,
+        indices: &[Index],
+    ) -> Result<String, Error> {
         let mut acc_offset = "0".to_string();
         for index in indices {
             match index {
@@ -632,10 +636,15 @@ impl<'a> RTContext<'a> {
                         last_align - 1,
                         last_align - 1
                     );
+                    collect_ty = fields[*idx];
                 }
-                Index::Variant(_) => {
+                Index::Variant(idx) => {
                     // The tag of a summation is at the end of the summation, so
-                    // the variant pointer is just the base pointer. Do nothing.
+                    // the variant pointer is just the base pointer. Only the type changes.
+                    let Type::Summation(ref variants) = self.module.types[collect_ty.idx()] else {
+                        panic!()
+                    };
+                    collect_ty = variants[*idx];
                 }
                 Index::Position(ref pos) => {
                     let Type::Array(elem, ref dims) = self.module.types[collect_ty.idx()] else {
@@ -655,6 +664,7 @@ impl<'a> RTContext<'a> {
 
                     // Convert offset in # elements -> # bytes.
                     acc_offset = format!("({} * {})", acc_offset, elem_size);
+                    collect_ty = elem;
                 }
             }
         }
diff --git a/hercules_opt/src/crc.rs b/hercules_opt/src/crc.rs
new file mode 100644
index 0000000000000000000000000000000000000000..be80e61bea8d6522a40b35542bb6b3c39edeb4db
--- /dev/null
+++ b/hercules_opt/src/crc.rs
@@ -0,0 +1,45 @@
+use hercules_ir::*;
+
+use crate::*;
+
+/*
+ * Top level function to collapse read chains in a function. A read whose
+ * collection is produced by another read is replaced with a single read of
+ * the upper read's collection, using the upper indices followed by the
+ * lower indices. Sweeps over the nodes repeat until a sweep makes no
+ * successful edit.
+ */
+pub fn crc(editor: &mut FunctionEditor) {
+    let mut changed = true;
+    while changed {
+        changed = false;
+        for id in editor.node_ids() {
+            if let Node::Read {
+                collect: lower_collect,
+                indices: ref lower_indices,
+            } = editor.func().nodes[id.idx()]
+                && let Node::Read {
+                    collect: upper_collect,
+                    indices: ref upper_indices,
+                } = editor.func().nodes[lower_collect.idx()]
+            {
+                // Only the lower read is replaced and deleted here; the upper
+                // read node itself is not modified.
+                let collapsed_read = Node::Read {
+                    collect: upper_collect,
+                    indices: upper_indices
+                        .iter()
+                        .chain(lower_indices.iter())
+                        .cloned()
+                        .collect(),
+                };
+                let success = editor.edit(|mut edit| {
+                    let new_id = edit.add_node(collapsed_read);
+                    let edit = edit.replace_all_uses(id, new_id)?;
+                    edit.delete_node(id)
+                });
+                changed |= success;
+            }
+        }
+    }
+}
diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs
index 9935703e839b5a4dd705af03f40a456008fe0f12..e351deba58c9061fe19bba4932509177d83af626 100644
--- a/hercules_opt/src/lib.rs
+++ b/hercules_opt/src/lib.rs
@@ -1,6 +1,7 @@
 #![feature(let_chains)]
 
 pub mod ccp;
+pub mod crc;
 pub mod dce;
 pub mod delete_uncalled;
 pub mod editor;
@@ -23,6 +24,7 @@ pub mod unforkify;
 pub mod utils;
 
 pub use crate::ccp::*;
+pub use crate::crc::*;
 pub use crate::dce::*;
 pub use crate::delete_uncalled::*;
 pub use crate::editor::*;
diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs
index 4f44d1d1a833b808a23ef3e56338097a1b874a02..e528b35de7a2e94fd94a8014008073c15548cbd7 100644
--- a/hercules_opt/src/pass.rs
+++ b/hercules_opt/src/pass.rs
@@ -25,6 +25,7 @@ pub enum Pass {
     PhiElim,
     Forkify,
     ForkGuardElim,
+    CRC,
     SLF,
     WritePredication,
     Predication,
@@ -471,6 +472,33 @@ impl PassManager {
                     }
                     self.clear_analyses();
                 }
+                Pass::CRC => {
+                    self.make_def_uses();
+                    let def_uses = self.def_uses.as_ref().unwrap();
+                    for idx in 0..self.module.functions.len() {
+                        let constants_ref =
+                            RefCell::new(std::mem::take(&mut self.module.constants));
+                        let dynamic_constants_ref =
+                            RefCell::new(std::mem::take(&mut self.module.dynamic_constants));
+                        let types_ref = RefCell::new(std::mem::take(&mut self.module.types));
+                        let mut editor = FunctionEditor::new(
+                            &mut self.module.functions[idx],
+                            FunctionID::new(idx),
+                            &constants_ref,
+                            &dynamic_constants_ref,
+                            &types_ref,
+                            &def_uses[idx],
+                        );
+                        crc(&mut editor);
+
+                        self.module.constants = constants_ref.take();
+                        self.module.dynamic_constants = dynamic_constants_ref.take();
+                        self.module.types = types_ref.take();
+
+                        self.module.functions[idx].delete_gravestones();
+                    }
+                    self.clear_analyses();
+                }
                 Pass::SLF => {
                     self.make_def_uses();
                     self.make_reverse_postorders();
@@ -498,7 +526,6 @@ impl PassManager {
                         self.module.dynamic_constants = dynamic_constants_ref.take();
                         self.module.types = types_ref.take();
 
-                        println!("{}", self.module.functions[idx].name);
                         self.module.functions[idx].delete_gravestones();
                     }
                     self.clear_analyses();
diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs
index eeb09eb6859fef91c10f2026bbb8553274195cae..60d3470edc8cf42dd105cb4a2835a2a86d00fe3f 100644
--- a/hercules_rt/src/lib.rs
+++ b/hercules_rt/src/lib.rs
@@ -199,7 +199,6 @@ impl<'b, 'a: 'b> HerculesBox<'a> {
     }
 
     pub unsafe fn __zeros(size: u64) -> Self {
-        assert_ne!(size, 0);
         let size = size as usize;
         let id = NUM_OBJECTS.fetch_add(1, Ordering::Relaxed);
         HerculesBox {
diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs
index 46d3489156d34f5152388d08e6e66a600600b4f8..d9a59a38567f24dda3d94e2bcc742655615110fc 100644
--- a/juno_frontend/src/lib.rs
+++ b/juno_frontend/src/lib.rs
@@ -149,8 +149,13 @@ pub fn compile_ir(
         pm.add_pass(hercules_opt::pass::Pass::Verify);
     }
     add_verified_pass!(pm, verify, GVN);
+    add_pass!(pm, verify, DCE);
     add_verified_pass!(pm, verify, PhiElim);
     add_pass!(pm, verify, DCE);
+    add_pass!(pm, verify, CRC);
+    add_pass!(pm, verify, DCE);
+    add_pass!(pm, verify, SLF);
+    add_pass!(pm, verify, DCE);
     if x_dot {
         pm.add_pass(hercules_opt::pass::Pass::Xdot(true));
     }
@@ -182,6 +187,8 @@ pub fn compile_ir(
     add_pass!(pm, verify, WritePredication);
     add_pass!(pm, verify, PhiElim);
     add_pass!(pm, verify, DCE);
+    add_pass!(pm, verify, CRC);
+    add_pass!(pm, verify, DCE);
     add_pass!(pm, verify, SLF);
     add_pass!(pm, verify, DCE);
     add_pass!(pm, verify, Predication);
diff --git a/juno_samples/antideps/src/antideps.jn b/juno_samples/antideps/src/antideps.jn
index 6886741b6de9cde99d0372b5b9e885e08f9b95e0..738ee6da775baaed88d99b7b7901bcd81f2515e4 100644
--- a/juno_samples/antideps/src/antideps.jn
+++ b/juno_samples/antideps/src/antideps.jn
@@ -110,3 +110,15 @@ fn very_complex_antideps(x: usize) -> usize {
   }
   return arr4[w] + w;
 }
+
+#[entry]
+fn read_chains(input : i32) -> i32 {
+  let arrs : (i32[2], i32[2]);
+  let sub = arrs.0;
+  sub[1] = input + 7;
+  arrs.0[1] = input + 3;
+  let result = sub[1] + arrs.0[1];
+  sub[1] = 99;
+  arrs.0[1] = 99;
+  return result + sub[1] - arrs.0[1];
+}
diff --git a/juno_samples/antideps/src/main.rs b/juno_samples/antideps/src/main.rs
index 0b065cbaa6e6cffcaf9ff7b3fbc5a2c882dc7248..6e5ed7a3108261d64c29f92341bc72e3fd3ae786 100644
--- a/juno_samples/antideps/src/main.rs
+++ b/juno_samples/antideps/src/main.rs
@@ -23,6 +23,10 @@ fn main() {
         let output = very_complex_antideps(3).await;
         println!("{}", output);
         assert_eq!(output, 144);
+
+        let output = read_chains(2).await;
+        println!("{}", output);
+        assert_eq!(output, 14);
     });
 }
 