diff --git a/juno_frontend/examples/cava.jn b/juno_frontend/examples/cava.jn index 6c5bdfe9085c77585c34ae5c3fdc4a988ebd8729..ee46e47610bdc0ad324cbbf2c34d482fdbe0d25b 100644 --- a/juno_frontend/examples/cava.jn +++ b/juno_frontend/examples/cava.jn @@ -120,9 +120,9 @@ fn denoise<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, r for c = 0 to col { if r >= 1 && r < row - 1 && c >= 1 && c < col - 1 { let filter : f32[3][3]; // same as [3, 3] - for i = -1 to 2 by 1 { - for j = -1 to 2 by 1 { - filter[i+1, j+1] = input[chan, r + i, c + j]; + for i = 0 to 3 by 1 { + for j = 0 to 3 by 1 { + filter[i, j] = input[chan, r + i - 1, c + j - 1]; } } res[chan, r, c] = medianMatrix::<f32, 3, 3>(filter); diff --git a/juno_frontend/examples/test1.jn b/juno_frontend/examples/test1.jn new file mode 100644 index 0000000000000000000000000000000000000000..b0d26c02962e3e133e94836af489df650119250e --- /dev/null +++ b/juno_frontend/examples/test1.jn @@ -0,0 +1,8 @@ +fn test<x, y : usize>(a : i32[x, y]) -> i32[x, y] { + return a; +} + +fn main<x, y, z : usize>() -> i32[y, z] { + let n : i32[y, z]; + return test::<y, z>(n); +} diff --git a/juno_frontend/src/dynconst.rs b/juno_frontend/src/dynconst.rs index 2fcd637fb8aee2fbb7feebe334085301d4cd8380..72a9d0ff7a9f1395552d5f0a9a374fb44484fa4b 100644 --- a/juno_frontend/src/dynconst.rs +++ b/juno_frontend/src/dynconst.rs @@ -1,6 +1,6 @@ /* A data structure for normalizing and performing computation over dynamic constant expressions */ use std::collections::HashMap; -use std::fmt; +use std::{fmt, iter}; use hercules_ir::{Builder, DynamicConstantID}; use num_rational::Ratio; @@ -251,50 +251,88 @@ impl DynConst { // Builds a dynamic constant in the IR pub fn build(&self, builder : &mut Builder) -> DynamicConstantID { - // Identify the terms with non-zero coefficients - let non_zero_coeff = self.terms.iter().filter(|(_, c)| !c.is_zero()) - .collect::<Vec<_>>(); - if non_zero_coeff.len() != 1 { - // Once the IR supports dynamic constant expressions, we'll need to - // sort the terms with non-zero coefficients and generate in a - // standardized manner to ensure that equivalent expressions always - // generate the same dynamic constant expression in the IR - todo!("Dynamic constant expression generation: {:?}", self) - } else { - let (dterm, dcoeff) = non_zero_coeff[0]; - // If the term with non-zero coefficient has a coefficient that is - // not 1, then this (currently) must just be a constant expression - // which must be a non-negative integer - if !dcoeff.is_one() { - if dterm.iter().all(|p| *p == 0) { - if !dcoeff.is_integer() { - panic!("Dynamic constant is a non-integer constant") - } else { - let val : i64 = dcoeff.to_integer(); - if val < 0 { - panic!("Dynamic constant is a negative constant") - } else { - builder.create_dynamic_constant_constant( - dcoeff.to_integer() as usize) - } - } - } else { - todo!("Dynamic constant expression generation: {:?}", self) - } - } else { - if dterm.iter().all(|p| *p == 0) { - // Constant value 1 - builder.create_dynamic_constant_constant(1) + // Identify the terms with non-zero coefficients, based on the powers + let mut non_zero_coeff = self.terms.iter().filter(|(_, c)| !c.is_zero()) + .collect::<Vec<_>>(); + non_zero_coeff.sort_by(|(d1, _), (d2, _)| d1.cmp(d2)); + + non_zero_coeff.iter().map(|(d, c)| self.build_mono(builder, d, c)) + .collect::<Vec<_>>().into_iter() + .reduce(|x, y| builder.create_dynamic_constant_add(x, y)) + // If there are no terms, this dynamic constant is 0 + .unwrap_or_else(|| builder.create_dynamic_constant_constant(0)) + } + + // Build's a monomial, with a given list of powers (term) and coefficients + fn build_mono(&self, builder : &mut Builder, term : &Vec<i64>, + coeff : &Ratio<i64>) -> DynamicConstantID { + let term_id = term.iter().enumerate() + .filter(|(_, p)| **p != 0) + .map(|(v, p)| self.build_power(builder, v, *p)) + .collect::<Vec<_>>().into_iter() + .reduce(|x, y| builder.create_dynamic_constant_add(x, y)); + + match term_id { + None => { // This means all powers of the term are 0, so we just + // output the coefficient + if !coeff.is_integer() { + panic!("Dynamic constant is a non-integer constant") } else { - let present = dterm.iter().enumerate().filter(|(_, p)| **p != 0) - .collect::<Vec<_>>(); - if present.len() != 1 || *present[0].1 != 1 { - todo!("Dynamic constant expression generation: {:?}", self) + let val : i64 = coeff.to_integer(); + if val < 0 { + panic!("Dynamic constant is a negative constant") } else { - builder.create_dynamic_constant_parameter(present[0].0) + builder.create_dynamic_constant_constant(val as usize) } } - } + }, + Some(term) => { + if coeff.is_one() { term } + else { + let numer : i64 = *coeff.numer(); + let denom : i64 = *coeff.denom(); // > 0 + let coeff_id = self.build_coeff(builder, numer, denom); + builder.create_dynamic_constant_mul(coeff_id, term) + } + }, + } + } + + // Build's a dynamic constant that is a certain power of a specific variable + fn build_power(&self, builder : &mut Builder, v : usize, power : i64) -> DynamicConstantID { + assert!(power != 0); + let power_pos = power.abs() as usize; + + + let var_id = builder.create_dynamic_constant_parameter(v); + let power_id = iter::repeat(var_id).take(power_pos) + .map(|_| var_id) + .reduce(|x, y| builder.create_dynamic_constant_mul(x, y)) + .expect("Power is non-zero"); + + if power > 0 { power_id } + else { + let one_id = builder.create_dynamic_constant_constant(1); + builder.create_dynamic_constant_div(one_id, power_id) + } + } + + // Build's a dynamic constant that is a certain rational number (n / d) + fn build_coeff(&self, builder : &mut Builder, n : i64, d : i64) -> DynamicConstantID { + assert!(n != 0); assert!(d > 0); + + let n_abs_term = builder.create_dynamic_constant_constant(n.abs() as usize); + let n_term = + if n > 0 { n_abs_term } + else { + let zero_term = builder.create_dynamic_constant_constant(0); + builder.create_dynamic_constant_sub(zero_term, n_abs_term) + }; + + if d == 1 { n_term } + else { + let d_term = builder.create_dynamic_constant_constant(d as usize); + builder.create_dynamic_constant_div(n_term, d_term) } } } diff --git a/juno_frontend/src/main.rs b/juno_frontend/src/main.rs index 886c6365601c6f45e5ffd0dd9104421fb36fa896..67c89255fb6b869bace5b0dd9ae1b6906fae736b 100644 --- a/juno_frontend/src/main.rs +++ b/juno_frontend/src/main.rs @@ -62,6 +62,7 @@ fn main() { pm.add_pass(hercules_opt::pass::Pass::Verify); } add_verified_pass!(pm, args, PhiElim); + pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); add_pass!(pm, args, CCP); add_pass!(pm, args, DCE); add_pass!(pm, args, GVN);