From 59854269018ebac041bd16221271e627c4726d13 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Thu, 26 Sep 2024 10:02:11 -0500 Subject: [PATCH] Dynamic Constant Math --- hercules_cg/src/sched_gen.rs | 26 ++++++- hercules_ir/src/build.rs | 24 +++++- hercules_ir/src/ir.rs | 22 ++++++ hercules_ir/src/parse.rs | 20 ++++- hercules_ir/src/typecheck.rs | 8 +- hercules_samples/flatten.hir | 16 ++++ juno_frontend/examples/cava.jn | 6 +- juno_frontend/examples/test1.jn | 8 ++ juno_frontend/src/dynconst.rs | 126 +++++++++++++++++++++----------- juno_frontend/src/main.rs | 1 + 10 files changed, 206 insertions(+), 51 deletions(-) create mode 100644 hercules_samples/flatten.hir create mode 100644 juno_frontend/examples/test1.jn diff --git a/hercules_cg/src/sched_gen.rs b/hercules_cg/src/sched_gen.rs index 98d2a202..166923a6 100644 --- a/hercules_cg/src/sched_gen.rs +++ b/hercules_cg/src/sched_gen.rs @@ -1216,7 +1216,7 @@ impl<'a> FunctionContext<'a> { fn compile_dynamic_constant( &self, dc: DynamicConstantID, - _block: &mut SBlock, + block: &mut SBlock, partition_idx: usize, manifest: &Manifest, ) -> SValue { @@ -1231,6 +1231,30 @@ impl<'a> FunctionContext<'a> { .position(|(_, kind)| *kind == ParameterKind::DynamicConstant(idx)) .unwrap(), ), + + DynamicConstant::Add(left, right) + | DynamicConstant::Sub(left, right) + | DynamicConstant::Mul(left, right) + | DynamicConstant::Div(left, right) => { + let left = self.compile_dynamic_constant(left, block, partition_idx, manifest); + let right = self.compile_dynamic_constant(right, block, partition_idx, manifest); + let output_virt_reg = self.make_virt_reg(partition_idx); + block.insts.push(SInst::Binary { + left, + right, + op: match self.dynamic_constants[dc.idx()] { + DynamicConstant::Add(_, _) => SBinaryOperator::Add, + DynamicConstant::Sub(_, _) => SBinaryOperator::Sub, + DynamicConstant::Mul(_, _) => SBinaryOperator::Mul, + DynamicConstant::Div(_, _) => SBinaryOperator::Div, + _ => panic!(), + }, + }); + block + .virt_regs + .push((output_virt_reg, SType::UnsignedInteger64)); + SValue::VirtualRegister(output_virt_reg) + } } } diff --git a/hercules_ir/src/build.rs b/hercules_ir/src/build.rs index 6a73da89..59b4a9f9 100644 --- a/hercules_ir/src/build.rs +++ b/hercules_ir/src/build.rs @@ -376,8 +376,8 @@ impl<'a> Builder<'a> { Type::Float64 => self.create_constant_f64(0.0), Type::Product(fs) => { let mut cs = vec![]; - for t in fs.clone() { - cs.push(self.create_constant_zero(t)); + for t in fs.clone().iter() { + cs.push(self.create_constant_zero(*t)); } self.create_constant_prod(cs.into()) } @@ -399,6 +399,26 @@ impl<'a> Builder<'a> { self.intern_dynamic_constant(DynamicConstant::Parameter(val)) } + pub fn create_dynamic_constant_add(&mut self, x : DynamicConstantID, + y : DynamicConstantID) -> DynamicConstantID { + self.intern_dynamic_constant(DynamicConstant::Add(x, y)) + } + + pub fn create_dynamic_constant_sub(&mut self, x : DynamicConstantID, + y : DynamicConstantID) -> DynamicConstantID { + self.intern_dynamic_constant(DynamicConstant::Sub(x, y)) + } + + pub fn create_dynamic_constant_mul(&mut self, x : DynamicConstantID, + y : DynamicConstantID) -> DynamicConstantID { + self.intern_dynamic_constant(DynamicConstant::Mul(x, y)) + } + + pub fn create_dynamic_constant_div(&mut self, x : DynamicConstantID, + y : DynamicConstantID) -> DynamicConstantID { + self.intern_dynamic_constant(DynamicConstant::Div(x, y)) + } + pub fn create_field_index(&self, idx: usize) -> Index { Index::Field(idx) } diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 1ea2950b..8f568282 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -113,6 +113,11 @@ pub enum DynamicConstant { // The usize here is an index (which dynamic constant parameter of a // function is this). Parameter(usize), + + Add(DynamicConstantID, DynamicConstantID), + Sub(DynamicConstantID, DynamicConstantID), + Mul(DynamicConstantID, DynamicConstantID), + Div(DynamicConstantID, DynamicConstantID), } /* @@ -393,6 +398,23 @@ impl Module { match &self.dynamic_constants[dc_id.idx()] { DynamicConstant::Constant(cons) => write!(w, "{}", cons), DynamicConstant::Parameter(param) => write!(w, "#{}", param), + DynamicConstant::Add(x, y) + | DynamicConstant::Sub(x, y) + | DynamicConstant::Mul(x, y) + | DynamicConstant::Div(x, y) => { + match &self.dynamic_constants[dc_id.idx()] { + DynamicConstant::Add(_, _) => write!(w, "+")?, + DynamicConstant::Sub(_, _) => write!(w, "-")?, + DynamicConstant::Mul(_, _) => write!(w, "*")?, + DynamicConstant::Div(_, _) => write!(w, "/")?, + _ => (), + } + write!(w, "(")?; + self.write_dynamic_constant(*x, w)?; + write!(w, ", ")?; + self.write_dynamic_constant(*y, w)?; + write!(w, ")") + }, }?; Ok(()) diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index 5b9ff694..3d758eb5 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -883,14 +883,17 @@ fn parse_dynamic_constant_id<'a>( ir_text: &'a str, context: &RefCell<Context<'a>>, ) -> nom::IResult<&'a str, DynamicConstantID> { - let (ir_text, dynamic_constant) = parse_dynamic_constant(ir_text)?; + let (ir_text, dynamic_constant) = parse_dynamic_constant(ir_text, context)?; let id = context .borrow_mut() .get_dynamic_constant_id(dynamic_constant); Ok((ir_text, id)) } -fn parse_dynamic_constant<'a>(ir_text: &'a str) -> nom::IResult<&'a str, DynamicConstant> { +fn parse_dynamic_constant<'a>( + ir_text: &'a str, + context : &RefCell<Context<'a>>, +) -> nom::IResult<&'a str, DynamicConstant> { let ir_text = nom::character::complete::multispace0(ir_text)?.0; let (ir_text, dc) = nom::branch::alt(( nom::combinator::map( @@ -905,6 +908,19 @@ fn parse_dynamic_constant<'a>(ir_text: &'a str) -> nom::IResult<&'a str, Dynamic })), |(_, x)| DynamicConstant::Parameter(x), ), + // Dynamic constant math is written using a prefix function + nom::combinator::map( + nom::sequence::tuple(( + nom::character::complete::one_of("+-*/"), + parse_tuple2(|x| parse_dynamic_constant_id(x, context), + |x| parse_dynamic_constant_id(x, context)))), + |(op, (x, y))| + match op { '+' => DynamicConstant::Add(x, y), + '-' => DynamicConstant::Sub(x, y), + '*' => DynamicConstant::Mul(x, y), + '/' => DynamicConstant::Div(x, y), + _ => panic!("Invalid parse") } + ), ))(ir_text)?; Ok((ir_text, dc)) } diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index 6137982b..f9c2e07f 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -179,8 +179,14 @@ fn typeflow( num_parameters: u32, ) -> bool { match dynamic_constants[root.idx()] { + DynamicConstant::Constant(_) => true, DynamicConstant::Parameter(idx) => idx < num_parameters as usize, - _ => true, + DynamicConstant::Add(x, y) + | DynamicConstant::Sub(x, y) + | DynamicConstant::Mul(x, y) + | DynamicConstant::Div(x, y) + => check_dynamic_constants(x, dynamic_constants, num_parameters) + && check_dynamic_constants(y, dynamic_constants, num_parameters), } } diff --git a/hercules_samples/flatten.hir b/hercules_samples/flatten.hir new file mode 100644 index 00000000..453bebdf --- /dev/null +++ b/hercules_samples/flatten.hir @@ -0,0 +1,16 @@ +fn flatten<2>(x : array(i32, #0, #1))-> array(i32, *(#0, #1)) + c = constant(array(i32, *(#0, #1)), []) + i_ctrl = fork(start, #0) + i_idx = thread_id(i_ctrl) + j_ctrl = fork(i_ctrl, #1) + j_idx = thread_id(j_ctrl) + j_join_ctrl = join(j_ctrl) + i_join_ctrl = join(j_join_ctrl) + r = return(i_join_ctrl, update_i_c) + read = read(x, position(i_idx, j_idx)) + cols = dynamic_constant(#1) + row_idx = mul(i_idx, cols) + out_idx = add(row_idx, j_idx) + update_c = write(update_j_c, read, position(out_idx)) + update_j_c = reduce(j_join_ctrl, update_i_c, update_c) + update_i_c = reduce(i_join_ctrl, c, update_j_c) diff --git a/juno_frontend/examples/cava.jn b/juno_frontend/examples/cava.jn index 6c5bdfe9..ee46e476 100644 --- a/juno_frontend/examples/cava.jn +++ b/juno_frontend/examples/cava.jn @@ -120,9 +120,9 @@ fn denoise<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, r for c = 0 to col { if r >= 1 && r < row - 1 && c >= 1 && c < col - 1 { let filter : f32[3][3]; // same as [3, 3] - for i = -1 to 2 by 1 { - for j = -1 to 2 by 1 { - filter[i+1, j+1] = input[chan, r + i, c + j]; + for i = 0 to 3 by 1 { + for j = 0 to 3 by 1 { + filter[i, j] = input[chan, r + i - 1, c + j - 1]; } } res[chan, r, c] = medianMatrix::<f32, 3, 3>(filter); diff --git a/juno_frontend/examples/test1.jn b/juno_frontend/examples/test1.jn new file mode 100644 index 00000000..b0d26c02 --- /dev/null +++ b/juno_frontend/examples/test1.jn @@ -0,0 +1,8 @@ +fn test<x, y : usize>(a : i32[x, y]) -> i32[x, y] { + return a; +} + +fn main<x, y, z : usize>() -> i32[y, z] { + let n : i32[y, z]; + return test::<y, z>(n); +} diff --git a/juno_frontend/src/dynconst.rs b/juno_frontend/src/dynconst.rs index 2fcd637f..ba726299 100644 --- a/juno_frontend/src/dynconst.rs +++ b/juno_frontend/src/dynconst.rs @@ -1,6 +1,6 @@ /* A data structure for normalizing and performing computation over dynamic constant expressions */ use std::collections::HashMap; -use std::fmt; +use std::{fmt, iter}; use hercules_ir::{Builder, DynamicConstantID}; use num_rational::Ratio; @@ -251,50 +251,92 @@ impl DynConst { // Builds a dynamic constant in the IR pub fn build(&self, builder : &mut Builder) -> DynamicConstantID { - // Identify the terms with non-zero coefficients - let non_zero_coeff = self.terms.iter().filter(|(_, c)| !c.is_zero()) - .collect::<Vec<_>>(); - if non_zero_coeff.len() != 1 { - // Once the IR supports dynamic constant expressions, we'll need to - // sort the terms with non-zero coefficients and generate in a - // standardized manner to ensure that equivalent expressions always - // generate the same dynamic constant expression in the IR - todo!("Dynamic constant expression generation: {:?}", self) - } else { - let (dterm, dcoeff) = non_zero_coeff[0]; - // If the term with non-zero coefficient has a coefficient that is - // not 1, then this (currently) must just be a constant expression - // which must be a non-negative integer - if !dcoeff.is_one() { - if dterm.iter().all(|p| *p == 0) { - if !dcoeff.is_integer() { - panic!("Dynamic constant is a non-integer constant") - } else { - let val : i64 = dcoeff.to_integer(); - if val < 0 { - panic!("Dynamic constant is a negative constant") - } else { - builder.create_dynamic_constant_constant( - dcoeff.to_integer() as usize) - } - } + // Identify the terms with non-zero coefficients, based on the powers + let mut non_zero_coeff = self.terms.iter().filter(|(_, c)| !c.is_zero()) + .collect::<Vec<_>>(); + non_zero_coeff.sort_by(|(d1, _), (d2, _)| d1.cmp(d2)); + + let (pos, neg) : (Vec<_>, Vec<_>) = + non_zero_coeff.iter() + .map(|(d, c)| self.build_mono(builder, d, c)) + .partition(|(_, neg)| ! *neg); + + let pos_sum = pos.into_iter().map(|(t, _)| t) + .reduce(|x, y| builder.create_dynamic_constant_add(x, y)) + .unwrap_or_else(|| builder.create_dynamic_constant_constant(0)); + + let neg_sum = neg.into_iter().map(|(t, _)| t) + .reduce(|x, y| builder.create_dynamic_constant_add(x, y)); + + match neg_sum { + None => pos_sum, + Some(neg) => builder.create_dynamic_constant_sub(pos_sum, neg) + } + } + + // Build's a monomial, with a given list of powers (term) and coefficients + // Returns the dynamic constant id of the positive value and a boolean + // indicating whether the value should actually be negative + fn build_mono(&self, builder : &mut Builder, term : &Vec<i64>, + coeff : &Ratio<i64>) -> (DynamicConstantID, bool) { + let term_id = term.iter().enumerate() + .filter(|(_, p)| **p != 0) + .map(|(v, p)| self.build_power(builder, v, *p)) + .collect::<Vec<_>>().into_iter() + .reduce(|x, y| builder.create_dynamic_constant_add(x, y)); + + match term_id { + None => { // This means all powers of the term are 0, so we just + // output the coefficient + if !coeff.is_integer() { + panic!("Dynamic constant is a non-integer constant") } else { - todo!("Dynamic constant expression generation: {:?}", self) + let val : i64 = coeff.to_integer(); + (builder.create_dynamic_constant_constant(val.abs() as usize), + val < 0) } - } else { - if dterm.iter().all(|p| *p == 0) { - // Constant value 1 - builder.create_dynamic_constant_constant(1) - } else { - let present = dterm.iter().enumerate().filter(|(_, p)| **p != 0) - .collect::<Vec<_>>(); - if present.len() != 1 || *present[0].1 != 1 { - todo!("Dynamic constant expression generation: {:?}", self) - } else { - builder.create_dynamic_constant_parameter(present[0].0) - } + }, + Some(term) => { + if coeff.is_one() { (term, false) } + else { + let numer : i64 = coeff.numer().abs(); + let denom : i64 = *coeff.denom(); // > 0 + + let with_numer = + if numer == 1 { term } + else { + let numer_id = builder.create_dynamic_constant_constant(numer as usize); + builder.create_dynamic_constant_mul(numer_id, term) + }; + let with_denom = + if denom == 1 { with_numer } + else { + let denom_id = builder.create_dynamic_constant_constant(denom as usize); + builder.create_dynamic_constant_div(with_numer, denom_id) + }; + + (with_denom, numer < 0) } - } + }, + } + } + + // Build's a dynamic constant that is a certain power of a specific variable + fn build_power(&self, builder : &mut Builder, v : usize, power : i64) -> DynamicConstantID { + assert!(power != 0); + let power_pos = power.abs() as usize; + + + let var_id = builder.create_dynamic_constant_parameter(v); + let power_id = iter::repeat(var_id).take(power_pos) + .map(|_| var_id) + .reduce(|x, y| builder.create_dynamic_constant_mul(x, y)) + .expect("Power is non-zero"); + + if power > 0 { power_id } + else { + let one_id = builder.create_dynamic_constant_constant(1); + builder.create_dynamic_constant_div(one_id, power_id) } } } diff --git a/juno_frontend/src/main.rs b/juno_frontend/src/main.rs index 886c6365..67c89255 100644 --- a/juno_frontend/src/main.rs +++ b/juno_frontend/src/main.rs @@ -62,6 +62,7 @@ fn main() { pm.add_pass(hercules_opt::pass::Pass::Verify); } add_verified_pass!(pm, args, PhiElim); + pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); add_pass!(pm, args, CCP); add_pass!(pm, args, DCE); add_pass!(pm, args, GVN); -- GitLab