From 9a6f6023c239e6687dde56bc136bc494cdd5eae2 Mon Sep 17 00:00:00 2001 From: rarbore2 <rarbore2@illinois.edu> Date: Fri, 17 Jan 2025 19:07:40 -0600 Subject: [PATCH] Fixes for dynamic constant substitution in typechecking, interprocedural SROA, and fix inlining --- Cargo.lock | 10 +++++ Cargo.toml | 3 +- hercules_ir/src/typecheck.rs | 57 +++++++++++++----------- hercules_opt/src/inline.rs | 3 -- hercules_opt/src/interprocedural_sroa.rs | 34 ++++++++++++-- hercules_opt/src/pass.rs | 2 +- juno_samples/concat/Cargo.toml | 18 ++++++++ juno_samples/concat/build.rs | 9 ++++ juno_samples/concat/src/concat.jn | 32 +++++++++++++ juno_samples/concat/src/main.rs | 16 +++++++ 10 files changed, 150 insertions(+), 34 deletions(-) create mode 100644 juno_samples/concat/Cargo.toml create mode 100644 juno_samples/concat/build.rs create mode 100644 juno_samples/concat/src/concat.jn create mode 100644 juno_samples/concat/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 35238a04..a1eb77de 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1028,6 +1028,16 @@ dependencies = [ "with_builtin_macros", ] +[[package]] +name = "juno_concat" +version = "0.1.0" +dependencies = [ + "async-std", + "hercules_rt", + "juno_build", + "with_builtin_macros", +] + [[package]] name = "juno_frontend" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 867d88fb..c57125f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ members = [ "juno_samples/nested_ccp", "juno_samples/antideps", "juno_samples/implicit_clone", + "juno_samples/concat", - "juno_samples/cava", + "juno_samples/cava", ] diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index 244d8d19..79cbd403 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -516,12 +516,8 @@ fn typeflow( Constant::Float64(_) => { Concrete(get_type_id(Type::Float64, types, reverse_type_map)) } - // Product, summation, and array constants are exceptions. - // Technically, only summation constants need to explicitly - // store their type, but product and array constants also - // explicitly store their type specifically to make this code - // simpler (although their type could be derived from the - // constant itself). + // Product, summation, and array constants are exceptions. they + // all explicitly store their type. Constant::Product(id, _) => { if let Type::Product(_) = types[id.idx()] { Concrete(id) @@ -540,7 +536,6 @@ fn typeflow( )) } } - // Array typechecking also consists of validating the number of constant elements. Constant::Array(id) => { if let Type::Array(_, _) = &types[id.idx()] { Concrete(id) @@ -1135,37 +1130,49 @@ fn types_match( fn dyn_consts_match( dynamic_constants: &Vec<DynamicConstant>, dc_args: &Box<[DynamicConstantID]>, - param: DynamicConstantID, - input: DynamicConstantID, + left: DynamicConstantID, + right: DynamicConstantID, ) -> bool { // First, try evaluating the DCs and seeing if they're the same value. if let (Some(cons1), Some(cons2)) = ( - evaluate_dynamic_constant(param, dynamic_constants), - evaluate_dynamic_constant(input, dynamic_constants), + evaluate_dynamic_constant(left, dynamic_constants), + evaluate_dynamic_constant(right, dynamic_constants), ) { return cons1 == cons2; } match ( - &dynamic_constants[param.idx()], - &dynamic_constants[input.idx()], + &dynamic_constants[left.idx()], + &dynamic_constants[right.idx()], ) { (DynamicConstant::Constant(x), DynamicConstant::Constant(y)) => x == y, - (DynamicConstant::Parameter(i), _) => input == dc_args[*i], - (DynamicConstant::Add(pl, pr), DynamicConstant::Add(il, ir)) - | (DynamicConstant::Mul(pl, pr), DynamicConstant::Mul(il, ir)) - | (DynamicConstant::Min(pl, pr), DynamicConstant::Min(il, ir)) - | (DynamicConstant::Max(pl, pr), DynamicConstant::Max(il, ir)) => { + (DynamicConstant::Parameter(l), DynamicConstant::Parameter(r)) => l == r, + (DynamicConstant::Parameter(i), _) => dyn_consts_match( + dynamic_constants, + dc_args, + min(right, dc_args[*i]), + max(right, dc_args[*i]), + ), + (_, DynamicConstant::Parameter(i)) => dyn_consts_match( + dynamic_constants, + dc_args, + min(left, dc_args[*i]), + max(left, dc_args[*i]), + ), + (DynamicConstant::Add(ll, lr), DynamicConstant::Add(rl, rr)) + | (DynamicConstant::Mul(ll, lr), DynamicConstant::Mul(rl, rr)) + | (DynamicConstant::Min(ll, lr), DynamicConstant::Min(rl, rr)) + | (DynamicConstant::Max(ll, lr), DynamicConstant::Max(rl, rr)) => { // Normalize for associative ops by always looking at smaller DC ID // as left arm and larger DC ID as right arm. - dyn_consts_match(dynamic_constants, dc_args, min(*pl, *pr), min(*il, *ir)) - && dyn_consts_match(dynamic_constants, dc_args, max(*pl, *pr), max(*il, *ir)) + dyn_consts_match(dynamic_constants, dc_args, min(*ll, *lr), min(*rl, *rr)) + && dyn_consts_match(dynamic_constants, dc_args, max(*ll, *lr), max(*rl, *rr)) } - (DynamicConstant::Sub(pl, pr), DynamicConstant::Sub(il, ir)) - | (DynamicConstant::Div(pl, pr), DynamicConstant::Div(il, ir)) - | (DynamicConstant::Rem(pl, pr), DynamicConstant::Rem(il, ir)) => { - dyn_consts_match(dynamic_constants, dc_args, *pl, *il) - && dyn_consts_match(dynamic_constants, dc_args, *pr, *ir) + (DynamicConstant::Sub(ll, lr), DynamicConstant::Sub(rl, rr)) + | (DynamicConstant::Div(ll, lr), DynamicConstant::Div(rl, rr)) + | (DynamicConstant::Rem(ll, lr), DynamicConstant::Rem(rl, rr)) => { + dyn_consts_match(dynamic_constants, dc_args, *ll, *rl) + && dyn_consts_match(dynamic_constants, dc_args, *lr, *rr) } (_, _) => false, } diff --git a/hercules_opt/src/inline.rs b/hercules_opt/src/inline.rs index 63a05b0c..54af8582 100644 --- a/hercules_opt/src/inline.rs +++ b/hercules_opt/src/inline.rs @@ -43,9 +43,6 @@ pub fn inline(editors: &mut [FunctionEditor], callgraph: &CallGraph) { // Step 4: run inlining on each function individually. Iterate the functions // in topological order. for to_inline_id in topo { - if editors[to_inline_id.idx()].func().entry { - continue; - } // Since Rust cannot analyze the accesses into an array of mutable // references, we need to do some weirdness here to simultaneously get: // 1. A mutable reference to the function we're modifying. diff --git a/hercules_opt/src/interprocedural_sroa.rs b/hercules_opt/src/interprocedural_sroa.rs index 9edb4d02..49fbcbbd 100644 --- a/hercules_opt/src/interprocedural_sroa.rs +++ b/hercules_opt/src/interprocedural_sroa.rs @@ -319,7 +319,8 @@ fn compress_return_products(editors: &mut Vec<FunctionEditor>, all_callsites_edi let old_dcs = dc_param_idx_to_dc_id[..new_dcs.len()].to_vec().clone(); let mut substituted = old_return_type_ids[function_id.idx()]; - let first_dc = edit.num_dynamic_constants() + 1; + assert_eq!(old_dcs.len(), new_dcs.len()); + let first_dc = edit.num_dynamic_constants() + 100; for (dc_a, dc_n) in zip(old_dcs, first_dc..) { substituted = substitute_dynamic_constants_in_type( dc_a, @@ -416,12 +417,37 @@ fn remove_return_singletons(editors: &mut Vec<FunctionEditor>, all_callsites_edi .collect(); for call_node_id in call_node_ids { - let (_, function, _, _) = editor.func().nodes[call_node_id.idx()].try_call().unwrap(); + let (_, function, dc_args, _) = + editor.func().nodes[call_node_id.idx()].try_call().unwrap(); + let dc_args = dc_args.clone(); if singleton_removed[function.idx()] { let edit_successful = editor.edit(|mut edit| { - let empty_constant_id = - edit.add_zero_constant(old_return_type_ids[function.idx()]); + let mut substituted = old_return_type_ids[function.idx()]; + let first_dc = edit.num_dynamic_constants() + 100; + let dc_params: Vec<_> = (0..dc_args.len()) + .map(|param_idx| { + edit.add_dynamic_constant(DynamicConstant::Parameter(param_idx)) + }) + .collect(); + for (dc_a, dc_n) in zip(dc_params, first_dc..) { + substituted = substitute_dynamic_constants_in_type( + dc_a, + DynamicConstantID::new(dc_n), + substituted, + &mut edit, + ); + } + + for (dc_n, dc_b) in zip(first_dc.., dc_args.iter()) { + substituted = substitute_dynamic_constants_in_type( + DynamicConstantID::new(dc_n), + *dc_b, + substituted, + &mut edit, + ); + } + let empty_constant_id = edit.add_zero_constant(substituted); let empty_node_id = edit.add_node(Node::Constant { id: empty_constant_id, }); diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs index 1e7104ce..4f44d1d1 100644 --- a/hercules_opt/src/pass.rs +++ b/hercules_opt/src/pass.rs @@ -1067,7 +1067,7 @@ impl PassManager { .expect("PANIC: Unable to write output module file contents."); } } - println!("Ran pass: {:?}", pass); + eprintln!("Ran pass: {:?}", pass); } } diff --git a/juno_samples/concat/Cargo.toml b/juno_samples/concat/Cargo.toml new file mode 100644 index 00000000..24ba1acf --- /dev/null +++ b/juno_samples/concat/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "juno_concat" +version = "0.1.0" +authors = ["Russel Arbore <rarbore2@illinois.edu>"] +edition = "2021" + +[[bin]] +name = "juno_concat" +path = "src/main.rs" + +[build-dependencies] +juno_build = { path = "../../juno_build" } + +[dependencies] +juno_build = { path = "../../juno_build" } +hercules_rt = { path = "../../hercules_rt" } +with_builtin_macros = "0.1.0" +async-std = "*" diff --git a/juno_samples/concat/build.rs b/juno_samples/concat/build.rs new file mode 100644 index 00000000..f7784b99 --- /dev/null +++ b/juno_samples/concat/build.rs @@ -0,0 +1,9 @@ +use juno_build::JunoCompiler; + +fn main() { + JunoCompiler::new() + .file_in_src("concat.jn") + .unwrap() + .build() + .unwrap(); +} diff --git a/juno_samples/concat/src/concat.jn b/juno_samples/concat/src/concat.jn new file mode 100644 index 00000000..1cf36b9a --- /dev/null +++ b/juno_samples/concat/src/concat.jn @@ -0,0 +1,32 @@ +fn concat<t : number, a : usize, b : usize>(arr_a : t[a], arr_b : t[b]) -> t[a + b] { + let res : t[a + b]; + for i = 0 to a { + res[i] = arr_a[i]; + } + for i = 0 to b { + res[i + a] = arr_b[i]; + } + return res; +} + +fn sum<t : number, c : usize>(arr : t[c]) -> t { + let res : t; + for i = 0 to c { + res += arr[i]; + } + return res; +} + +#[entry] +fn concat_entry(a : i32) -> i32 { + let arr1 : i32[3]; + let arr2 : i32[6]; + arr1[0] = a; + arr1[1] = a; + arr2[0] = a; + arr2[1] = a; + arr2[4] = a; + arr2[5] = a; + let arr3 = concat::<i32, 3, 6>(arr1, arr2); + return sum::<i32, 3 + 6>(arr3); +} diff --git a/juno_samples/concat/src/main.rs b/juno_samples/concat/src/main.rs new file mode 100644 index 00000000..17a0ab96 --- /dev/null +++ b/juno_samples/concat/src/main.rs @@ -0,0 +1,16 @@ +#![feature(future_join, box_as_ptr)] + +juno_build::juno!("concat"); + +fn main() { + async_std::task::block_on(async { + let output = concat_entry(7).await; + println!("{}", output); + assert_eq!(output, 42); + }); +} + +#[test] +fn concat_test() { + main(); +} -- GitLab