diff --git a/Cargo.lock b/Cargo.lock index 623fc35c9260676fc9b683bd63e96ac7cbc31a2c..d6ebd2f72e582004cbc7beaf36fe21c77e264442 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -867,6 +867,7 @@ name = "hercules_ir" version = "0.1.0" dependencies = [ "bitvec", + "either", "nom", "ordered-float", "rand", diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs index ea326f8a0310fa082c240b5f52000f9c79e0be57..344554b65280f5ff5dc2979034048cd20e86f3bb 100644 --- a/hercules_cg/src/cpu.rs +++ b/hercules_cg/src/cpu.rs @@ -589,12 +589,12 @@ impl<'a> CPUContext<'a> { ) -> Result<(), Error> { let body = &mut block.body; for dc in dynamic_constants_bottom_up(&self.dynamic_constants) { - match self.dynamic_constants[dc.idx()] { + match &self.dynamic_constants[dc.idx()] { DynamicConstant::Constant(val) => { write!(body, " %dc{} = bitcast i64 {} to i64\n", dc.idx(), val)? } DynamicConstant::Parameter(idx) => { - if idx < num_dc_params as usize { + if *idx < num_dc_params as usize { write!( body, " %dc{} = bitcast i64 %dc_p{} to i64\n", @@ -605,13 +605,31 @@ impl<'a> CPUContext<'a> { write!(body, " %dc{} = bitcast i64 0 to i64\n", dc.idx())? } } - DynamicConstant::Add(left, right) => write!( - body, - " %dc{} = add i64%dc{},%dc{}\n", - dc.idx(), - left.idx(), - right.idx() - )?, + DynamicConstant::Add(xs) => { + let mut xs = xs.iter().peekable(); + let mut cur_value = format!("%dc{}", xs.next().unwrap().idx()); + let mut idx = 0; + while let Some(x) = xs.next() { + let new_val = format!( + "%dc{}{}", + dc.idx(), + if xs.peek().is_some() { + format!(".{}", idx) + } else { + "".to_string() + } + ); + write!( + body, + " {} = add i64{},%dc{}\n", + new_val, + cur_value, + x.idx() + )?; + cur_value = new_val; + idx += 1; + } + } DynamicConstant::Sub(left, right) => write!( body, " %dc{} = sub i64%dc{},%dc{}\n", @@ -619,13 +637,31 @@ impl<'a> CPUContext<'a> { left.idx(), right.idx() )?, - DynamicConstant::Mul(left, right) => write!( - body, - " %dc{} = mul i64%dc{},%dc{}\n", - dc.idx(), - left.idx(), - right.idx() - )?, + DynamicConstant::Mul(xs) => { + let mut xs = xs.iter().peekable(); + let mut cur_value = format!("%dc{}", xs.next().unwrap().idx()); + let mut idx = 0; + while let Some(x) = xs.next() { + let new_val = format!( + "%dc{}{}", + dc.idx(), + if xs.peek().is_some() { + format!(".{}", idx) + } else { + "".to_string() + } + ); + write!( + body, + " {} = mul i64{},%dc{}\n", + new_val, + cur_value, + x.idx() + )?; + cur_value = new_val; + idx += 1; + } + } DynamicConstant::Div(left, right) => write!( body, " %dc{} = udiv i64%dc{},%dc{}\n", @@ -640,20 +676,56 @@ impl<'a> CPUContext<'a> { left.idx(), right.idx() )?, - DynamicConstant::Min(left, right) => write!( - body, - " %dc{} = call i64 @llvm.umin.i64(i64%dc{},i64%dc{})\n", - dc.idx(), - left.idx(), - right.idx() - )?, - DynamicConstant::Max(left, right) => write!( - body, - " %dc{} = call i64 @llvm.umax.i64(i64%dc{},i64%dc{})\n", - dc.idx(), - left.idx(), - right.idx() - )?, + DynamicConstant::Min(xs) => { + let mut xs = xs.iter().peekable(); + let mut cur_value = format!("%dc{}", xs.next().unwrap().idx()); + let mut idx = 0; + while let Some(x) = xs.next() { + let new_val = format!( + "%dc{}{}", + dc.idx(), + if xs.peek().is_some() { + format!(".{}", idx) + } else { + "".to_string() + } + ); + write!( + body, + " {} = call i64 @llvm.umin.i64(i64{},i64%dc{}))\n", + new_val, + cur_value, + x.idx() + )?; + cur_value = new_val; + idx += 1; + } + } + DynamicConstant::Max(xs) => { + let mut xs = xs.iter().peekable(); + let mut cur_value = format!("%dc{}", xs.next().unwrap().idx()); + let mut idx = 0; + while let Some(x) = xs.next() { + let new_val = format!( + "%dc{}{}", + dc.idx(), + if xs.peek().is_some() { + format!(".{}", idx) + } else { + "".to_string() + } + ); + write!( + body, + " {} = call i64 @llvm.umax.i64(i64{},i64%dc{}))\n", + new_val, + cur_value, + x.idx() + )?; + cur_value = new_val; + idx += 1; + } + } } } Ok(()) diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index f97180ea24d3bf1b69810d6e79cf68fcb292e9fb..916d6520ae211aa0fe7a7f71f055f8acd0b01066 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -443,57 +443,89 @@ impl<'a> RTContext<'a> { id: DynamicConstantID, w: &mut W, ) -> Result<(), Error> { - match self.module.dynamic_constants[id.idx()] { + match &self.module.dynamic_constants[id.idx()] { DynamicConstant::Constant(val) => write!(w, "{}", val)?, DynamicConstant::Parameter(idx) => write!(w, "dc_p{}", idx)?, - DynamicConstant::Add(left, right) => { + DynamicConstant::Add(xs) => { write!(w, "(")?; - self.codegen_dynamic_constant(left, w)?; - write!(w, "+")?; - self.codegen_dynamic_constant(right, w)?; + let mut xs = xs.iter(); + self.codegen_dynamic_constant(*xs.next().unwrap(), w)?; + for x in xs { + write!(w, "+")?; + self.codegen_dynamic_constant(*x, w)?; + } write!(w, ")")?; } DynamicConstant::Sub(left, right) => { write!(w, "(")?; - self.codegen_dynamic_constant(left, w)?; + self.codegen_dynamic_constant(*left, w)?; write!(w, "-")?; - self.codegen_dynamic_constant(right, w)?; + self.codegen_dynamic_constant(*right, w)?; write!(w, ")")?; } - DynamicConstant::Mul(left, right) => { + DynamicConstant::Mul(xs) => { write!(w, "(")?; - self.codegen_dynamic_constant(left, w)?; - write!(w, "*")?; - self.codegen_dynamic_constant(right, w)?; + let mut xs = xs.iter(); + self.codegen_dynamic_constant(*xs.next().unwrap(), w)?; + for x in xs { + write!(w, "*")?; + self.codegen_dynamic_constant(*x, w)?; + } write!(w, ")")?; } DynamicConstant::Div(left, right) => { write!(w, "(")?; - self.codegen_dynamic_constant(left, w)?; + self.codegen_dynamic_constant(*left, w)?; write!(w, "/")?; - self.codegen_dynamic_constant(right, w)?; + self.codegen_dynamic_constant(*right, w)?; write!(w, ")")?; } DynamicConstant::Rem(left, right) => { write!(w, "(")?; - self.codegen_dynamic_constant(left, w)?; + self.codegen_dynamic_constant(*left, w)?; write!(w, "%")?; - self.codegen_dynamic_constant(right, w)?; + self.codegen_dynamic_constant(*right, w)?; write!(w, ")")?; } - DynamicConstant::Min(left, right) => { - write!(w, "::core::cmp::min(")?; - self.codegen_dynamic_constant(left, w)?; - write!(w, ",")?; - self.codegen_dynamic_constant(right, w)?; - write!(w, ")")?; + DynamicConstant::Min(xs) => { + let mut xs = xs.iter().peekable(); + + // Track the number of parentheses we open that need to be closed later + let mut opens = 0; + while let Some(x) = xs.next() { + if xs.peek().is_none() { + // For the last element, we just print it + self.codegen_dynamic_constant(*x, w)?; + } else { + // Otherwise, we create a new call to min and print the element as the + // first argument + write!(w, "::core::cmp::min(")?; + self.codegen_dynamic_constant(*x, w)?; + write!(w, ",")?; + opens += 1; + } + } + for _ in 0..opens { + write!(w, ")")?; + } } - DynamicConstant::Max(left, right) => { - write!(w, "::core::cmp::max(")?; - self.codegen_dynamic_constant(left, w)?; - write!(w, ",")?; - self.codegen_dynamic_constant(right, w)?; - write!(w, ")")?; + DynamicConstant::Max(xs) => { + let mut xs = xs.iter().peekable(); + + let mut opens = 0; + while let Some(x) = xs.next() { + if xs.peek().is_none() { + self.codegen_dynamic_constant(*x, w)?; + } else { + write!(w, "::core::cmp::max(")?; + self.codegen_dynamic_constant(*x, w)?; + write!(w, ",")?; + opens += 1; + } + } + for _ in 0..opens { + write!(w, ")")?; + } } } Ok(()) diff --git a/hercules_ir/Cargo.toml b/hercules_ir/Cargo.toml index deda9cc58758f6cc834aadcc8e4ec66625fefb4b..26950d4b7700d19326e6ea61aa2488b4c5d5df59 100644 --- a/hercules_ir/Cargo.toml +++ b/hercules_ir/Cargo.toml @@ -10,3 +10,4 @@ nom = "*" ordered-float = { version = "*", features = ["serde"] } bitvec = "*" serde = { version = "*", features = ["derive"] } +either = "*" diff --git a/hercules_ir/src/build.rs b/hercules_ir/src/build.rs index 1dd326c3ad1abf24e7bfa4aa1f28dfb8255af0e9..b804404524bb26e1e52c8f751bc416b7df84040d 100644 --- a/hercules_ir/src/build.rs +++ b/hercules_ir/src/build.rs @@ -1,4 +1,5 @@ use std::collections::{HashMap, HashSet}; +use std::ops::Deref; use crate::*; @@ -25,6 +26,23 @@ pub struct Builder<'a> { module: Module, } +impl<'a> DynamicConstantView for Builder<'a> { + fn get_dynconst(&self, id: DynamicConstantID) -> impl Deref<Target = DynamicConstant> + '_ { + &self.module.dynamic_constants[id.idx()] + } + + fn add_dynconst(&mut self, dc: DynamicConstant) -> DynamicConstantID { + if let Some(id) = self.interned_dynamic_constants.get(&dc) { + *id + } else { + let id = DynamicConstantID::new(self.module.dynamic_constants.len()); + self.module.dynamic_constants.push(dc.clone()); + self.interned_dynamic_constants.insert(dc, id); + id + } + } +} + /* * Since the builder doesn't provide string names for nodes, we need a different * mechanism for allowing one to allocate node IDs before actually creating the @@ -70,17 +88,6 @@ impl<'a> Builder<'a> { } } - fn intern_dynamic_constant(&mut self, dyn_cons: DynamicConstant) -> DynamicConstantID { - if let Some(id) = self.interned_dynamic_constants.get(&dyn_cons) { - *id - } else { - let id = DynamicConstantID::new(self.interned_dynamic_constants.len()); - self.interned_dynamic_constants.insert(dyn_cons.clone(), id); - self.module.dynamic_constants.push(dyn_cons); - id - } - } - pub fn add_label(&mut self, label: &String) -> LabelID { if let Some(id) = self.interned_labels.get(label) { *id @@ -406,11 +413,11 @@ impl<'a> Builder<'a> { } pub fn create_dynamic_constant_constant(&mut self, val: usize) -> DynamicConstantID { - self.intern_dynamic_constant(DynamicConstant::Constant(val)) + self.dc_const(val) } - pub fn create_dynamic_constant_parameter(&mut self, val: usize) -> DynamicConstantID { - self.intern_dynamic_constant(DynamicConstant::Parameter(val)) + pub fn create_dynamic_constant_parameter(&mut self, idx: usize) -> DynamicConstantID { + self.dc_param(idx) } pub fn create_dynamic_constant_add( @@ -418,7 +425,14 @@ impl<'a> Builder<'a> { x: DynamicConstantID, y: DynamicConstantID, ) -> DynamicConstantID { - self.intern_dynamic_constant(DynamicConstant::Add(x, y)) + self.dc_add(vec![x, y]) + } + + pub fn create_dynamic_constant_add_many( + &mut self, + xs: Vec<DynamicConstantID>, + ) -> DynamicConstantID { + self.dc_add(xs) } pub fn create_dynamic_constant_sub( @@ -426,7 +440,7 @@ impl<'a> Builder<'a> { x: DynamicConstantID, y: DynamicConstantID, ) -> DynamicConstantID { - self.intern_dynamic_constant(DynamicConstant::Sub(x, y)) + self.dc_sub(x, y) } pub fn create_dynamic_constant_mul( @@ -434,7 +448,14 @@ impl<'a> Builder<'a> { x: DynamicConstantID, y: DynamicConstantID, ) -> DynamicConstantID { - self.intern_dynamic_constant(DynamicConstant::Mul(x, y)) + self.dc_mul(vec![x, y]) + } + + pub fn create_dynamic_constant_mul_many( + &mut self, + xs: Vec<DynamicConstantID>, + ) -> DynamicConstantID { + self.dc_mul(xs) } pub fn create_dynamic_constant_div( @@ -442,7 +463,7 @@ impl<'a> Builder<'a> { x: DynamicConstantID, y: DynamicConstantID, ) -> DynamicConstantID { - self.intern_dynamic_constant(DynamicConstant::Div(x, y)) + self.dc_div(x, y) } pub fn create_dynamic_constant_rem( @@ -450,7 +471,7 @@ impl<'a> Builder<'a> { x: DynamicConstantID, y: DynamicConstantID, ) -> DynamicConstantID { - self.intern_dynamic_constant(DynamicConstant::Rem(x, y)) + self.dc_rem(x, y) } pub fn create_field_index(&self, idx: usize) -> Index { diff --git a/hercules_ir/src/dc_normalization.rs b/hercules_ir/src/dc_normalization.rs new file mode 100644 index 0000000000000000000000000000000000000000..e9f8f23aa7c05d47601c4a88087775ad82257ec0 --- /dev/null +++ b/hercules_ir/src/dc_normalization.rs @@ -0,0 +1,206 @@ +use crate::*; + +use std::cmp::{max, min}; +use std::collections::BTreeSet; +use std::ops::Deref; + +use either::Either; + +pub trait DynamicConstantView { + fn get_dynconst(&self, id: DynamicConstantID) -> impl Deref<Target = DynamicConstant> + '_; + fn add_dynconst(&mut self, dc: DynamicConstant) -> DynamicConstantID; + + fn dc_const(&mut self, val: usize) -> DynamicConstantID { + self.add_dynconst(DynamicConstant::Constant(val)) + } + + fn dc_param(&mut self, index: usize) -> DynamicConstantID { + self.add_dynconst(DynamicConstant::Parameter(index)) + } + + fn dc_add(&mut self, dcs: Vec<DynamicConstantID>) -> DynamicConstantID { + let mut constant_val = 0; + let mut fields = vec![]; + + for dc in dcs { + match self.get_dynconst(dc).deref() { + DynamicConstant::Constant(x) => constant_val += x, + DynamicConstant::Add(xs) => fields.extend_from_slice(xs), + _ => fields.push(dc), + } + } + + // If either there are no fields or the constant is non-zero, add it + if constant_val != 0 || fields.len() == 0 { + fields.push(self.add_dynconst(DynamicConstant::Constant(constant_val))); + } + + if fields.len() <= 1 { + // If there is only one term to add, just return it + fields[0] + } else { + fields.sort(); + self.add_dynconst(DynamicConstant::Add(fields)) + } + } + + fn dc_mul(&mut self, dcs: Vec<DynamicConstantID>) -> DynamicConstantID { + let mut constant_val = 1; + let mut fields = vec![]; + + for dc in dcs { + match self.get_dynconst(dc).deref() { + DynamicConstant::Constant(x) => constant_val *= x, + DynamicConstant::Mul(xs) => fields.extend_from_slice(xs), + _ => fields.push(dc), + } + } + + if constant_val == 0 { + return self.add_dynconst(DynamicConstant::Constant(0)); + } + + if constant_val != 1 || fields.len() == 0 { + fields.push(self.add_dynconst(DynamicConstant::Constant(constant_val))); + } + + if fields.len() <= 1 { + fields[0] + } else { + fields.sort(); + self.add_dynconst(DynamicConstant::Mul(fields)) + } + } + + fn dc_min(&mut self, dcs: Vec<DynamicConstantID>) -> DynamicConstantID { + let mut constant_val: Option<usize> = None; + // For min and max we track the fields via a set during normalization as this removes + // duplicates (and we use a BTreeSet as it can produce its elements in sorted order) + let mut fields = BTreeSet::new(); + + for dc in dcs { + match self.get_dynconst(dc).deref() { + DynamicConstant::Constant(x) => { + if let Some(cur_min) = constant_val { + constant_val = Some(min(cur_min, *x)); + } else { + constant_val = Some(*x); + } + } + DynamicConstant::Min(xs) => fields.extend(xs), + _ => { + fields.insert(dc); + } + } + } + + if let Some(const_val) = constant_val { + // Since dynamic constants are non-negative, ignore the constant if it is 0 + if const_val != 0 { + fields.insert(self.add_dynconst(DynamicConstant::Constant(const_val))); + } + } + + if fields.len() == 0 { + // The minimum of 0 dynamic constants is 0 since dynamic constants are non-negative + self.add_dynconst(DynamicConstant::Constant(0)) + } else if fields.len() <= 1 { + *fields.first().unwrap() + } else { + self.add_dynconst(DynamicConstant::Min(fields.into_iter().collect())) + } + } + + fn dc_max(&mut self, dcs: Vec<DynamicConstantID>) -> DynamicConstantID { + let mut constant_val: Option<usize> = None; + let mut fields = BTreeSet::new(); + + for dc in dcs { + match self.get_dynconst(dc).deref() { + DynamicConstant::Constant(x) => { + if let Some(cur_max) = constant_val { + constant_val = Some(max(cur_max, *x)); + } else { + constant_val = Some(*x); + } + } + DynamicConstant::Max(xs) => fields.extend(xs), + _ => { + fields.insert(dc); + } + } + } + + if let Some(const_val) = constant_val { + fields.insert(self.add_dynconst(DynamicConstant::Constant(const_val))); + } + + assert!( + fields.len() > 0, + "Max of 0 dynamic constant expressions is undefined" + ); + + if fields.len() <= 1 { + *fields.first().unwrap() + } else { + self.add_dynconst(DynamicConstant::Max(fields.into_iter().collect())) + } + } + + fn dc_sub(&mut self, x: DynamicConstantID, y: DynamicConstantID) -> DynamicConstantID { + let dc = match (self.get_dynconst(x).deref(), self.get_dynconst(y).deref()) { + (DynamicConstant::Constant(x), DynamicConstant::Constant(y)) => { + Either::Left(DynamicConstant::Constant(x - y)) + } + (_, DynamicConstant::Constant(0)) => Either::Right(x), + _ => Either::Left(DynamicConstant::Sub(x, y)), + }; + + match dc { + Either::Left(dc) => self.add_dynconst(dc), + Either::Right(id) => id, + } + } + + fn dc_div(&mut self, x: DynamicConstantID, y: DynamicConstantID) -> DynamicConstantID { + let dc = match (self.get_dynconst(x).deref(), self.get_dynconst(y).deref()) { + (DynamicConstant::Constant(x), DynamicConstant::Constant(y)) => { + Either::Left(DynamicConstant::Constant(x / y)) + } + (_, DynamicConstant::Constant(1)) => Either::Right(x), + _ => Either::Left(DynamicConstant::Div(x, y)), + }; + + match dc { + Either::Left(dc) => self.add_dynconst(dc), + Either::Right(id) => id, + } + } + + fn dc_rem(&mut self, x: DynamicConstantID, y: DynamicConstantID) -> DynamicConstantID { + let dc = match (self.get_dynconst(x).deref(), self.get_dynconst(y).deref()) { + (DynamicConstant::Constant(x), DynamicConstant::Constant(y)) => { + Either::Left(DynamicConstant::Constant(x % y)) + } + _ => Either::Left(DynamicConstant::Rem(x, y)), + }; + + match dc { + Either::Left(dc) => self.add_dynconst(dc), + Either::Right(id) => id, + } + } + + fn dc_normalize(&mut self, dc: DynamicConstant) -> DynamicConstantID { + match dc { + DynamicConstant::Add(xs) => self.dc_add(xs), + DynamicConstant::Mul(xs) => self.dc_mul(xs), + DynamicConstant::Min(xs) => self.dc_min(xs), + DynamicConstant::Max(xs) => self.dc_max(xs), + DynamicConstant::Sub(x, y) => self.dc_sub(x, y), + DynamicConstant::Div(x, y) => self.dc_div(x, y), + DynamicConstant::Rem(x, y) => self.dc_rem(x, y), + _ => self.add_dynconst(dc), + } + } +} diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index bf7806dcc371112419bbf7e21ef3b206df3a2e32..9f9188af9708c0d82ab49a8330b8127554b3d8e5 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -1,4 +1,3 @@ -use std::cmp::{max, min}; use std::collections::HashSet; use std::fmt::Write; use std::ops::Coroutine; @@ -120,13 +119,14 @@ pub enum DynamicConstant { // function is this). Parameter(usize), // Supported integer operations on dynamic constants. - Add(DynamicConstantID, DynamicConstantID), + Add(Vec<DynamicConstantID>), + Mul(Vec<DynamicConstantID>), + Min(Vec<DynamicConstantID>), + Max(Vec<DynamicConstantID>), + Sub(DynamicConstantID, DynamicConstantID), - Mul(DynamicConstantID, DynamicConstantID), Div(DynamicConstantID, DynamicConstantID), Rem(DynamicConstantID, DynamicConstantID), - Min(DynamicConstantID, DynamicConstantID), - Max(DynamicConstantID, DynamicConstantID), } /* @@ -464,21 +464,31 @@ impl Module { match &self.dynamic_constants[dc_id.idx()] { DynamicConstant::Constant(cons) => write!(w, "{}", cons), DynamicConstant::Parameter(param) => write!(w, "#{}", param), - DynamicConstant::Add(x, y) - | DynamicConstant::Sub(x, y) - | DynamicConstant::Mul(x, y) + DynamicConstant::Add(xs) + | DynamicConstant::Mul(xs) + | DynamicConstant::Min(xs) + | DynamicConstant::Max(xs) => { + match &self.dynamic_constants[dc_id.idx()] { + DynamicConstant::Add(_) => write!(w, "+")?, + DynamicConstant::Mul(_) => write!(w, "*")?, + DynamicConstant::Min(_) => write!(w, "min")?, + DynamicConstant::Max(_) => write!(w, "max")?, + _ => (), + } + write!(w, "(")?; + for arg in xs { + self.write_dynamic_constant(*arg, w)?; + write!(w, ",")?; + } + write!(w, ")") + } + DynamicConstant::Sub(x, y) | DynamicConstant::Div(x, y) - | DynamicConstant::Rem(x, y) - | DynamicConstant::Min(x, y) - | DynamicConstant::Max(x, y) => { + | DynamicConstant::Rem(x, y) => { match &self.dynamic_constants[dc_id.idx()] { - DynamicConstant::Add(_, _) => write!(w, "+")?, DynamicConstant::Sub(_, _) => write!(w, "-")?, - DynamicConstant::Mul(_, _) => write!(w, "*")?, DynamicConstant::Div(_, _) => write!(w, "/")?, DynamicConstant::Rem(_, _) => write!(w, "%")?, - DynamicConstant::Min(_, _) => write!(w, "min")?, - DynamicConstant::Max(_, _) => write!(w, "max")?, _ => (), } write!(w, "(")?; @@ -639,15 +649,37 @@ pub fn dynamic_constants_bottom_up( if visited[id.idx()] { continue; } - match dynamic_constants[id.idx()] { - DynamicConstant::Add(left, right) - | DynamicConstant::Sub(left, right) - | DynamicConstant::Mul(left, right) - | DynamicConstant::Div(left, right) - | DynamicConstant::Rem(left, right) => { + match &dynamic_constants[id.idx()] { + DynamicConstant::Add(args) + | DynamicConstant::Mul(args) + | DynamicConstant::Min(args) + | DynamicConstant::Max(args) => { // We have to yield the children of this node before // this node itself. We keep track of which nodes have // yielded using visited. + if args + .iter() + .any(|i| i.idx() >= visited.len() || invalid[i.idx()]) + { + // This is an invalid dynamic constant and should be skipped + invalid.set(id.idx(), true); + continue; + } + + if args.iter().all(|i| visited[i.idx()]) { + // Since all children have been yielded, we yield ourself + visited.set(id.idx(), true); + yield id; + } else { + // Otherwise push self onto stack so that the children will get popped + // first + stack.push(id); + stack.extend(args.clone()); + } + } + DynamicConstant::Sub(left, right) + | DynamicConstant::Div(left, right) + | DynamicConstant::Rem(left, right) => { if left.idx() >= visited.len() || right.idx() >= visited.len() || invalid[left.idx()] @@ -664,8 +696,8 @@ pub fn dynamic_constants_bottom_up( // Push ourselves, then children, so that children // get popped first. stack.push(id); - stack.push(left); - stack.push(right); + stack.push(*left); + stack.push(*right); } } _ => { @@ -991,6 +1023,34 @@ impl Constant { } impl DynamicConstant { + pub fn add(x: DynamicConstantID, y: DynamicConstantID) -> Self { + Self::Add(vec![x, y]) + } + + pub fn sub(x: DynamicConstantID, y: DynamicConstantID) -> Self { + Self::Sub(x, y) + } + + pub fn mul(x: DynamicConstantID, y: DynamicConstantID) -> Self { + Self::Mul(vec![x, y]) + } + + pub fn div(x: DynamicConstantID, y: DynamicConstantID) -> Self { + Self::Div(x, y) + } + + pub fn rem(x: DynamicConstantID, y: DynamicConstantID) -> Self { + Self::Rem(x, y) + } + + pub fn min(x: DynamicConstantID, y: DynamicConstantID) -> Self { + Self::Min(vec![x, y]) + } + + pub fn max(x: DynamicConstantID, y: DynamicConstantID) -> Self { + Self::Max(vec![x, y]) + } + pub fn is_parameter(&self) -> bool { if let DynamicConstant::Parameter(_) = self { true @@ -1028,33 +1088,12 @@ pub fn evaluate_dynamic_constant( cons: DynamicConstantID, dcs: &Vec<DynamicConstant>, ) -> Option<usize> { - match dcs[cons.idx()] { - DynamicConstant::Constant(cons) => Some(cons), - DynamicConstant::Parameter(_) => None, - DynamicConstant::Add(left, right) => { - Some(evaluate_dynamic_constant(left, dcs)? + evaluate_dynamic_constant(right, dcs)?) - } - DynamicConstant::Sub(left, right) => { - Some(evaluate_dynamic_constant(left, dcs)? - evaluate_dynamic_constant(right, dcs)?) - } - DynamicConstant::Mul(left, right) => { - Some(evaluate_dynamic_constant(left, dcs)? * evaluate_dynamic_constant(right, dcs)?) - } - DynamicConstant::Div(left, right) => { - Some(evaluate_dynamic_constant(left, dcs)? / evaluate_dynamic_constant(right, dcs)?) - } - DynamicConstant::Rem(left, right) => { - Some(evaluate_dynamic_constant(left, dcs)? % evaluate_dynamic_constant(right, dcs)?) - } - DynamicConstant::Min(left, right) => Some(min( - evaluate_dynamic_constant(left, dcs)?, - evaluate_dynamic_constant(right, dcs)?, - )), - DynamicConstant::Max(left, right) => Some(max( - evaluate_dynamic_constant(left, dcs)?, - evaluate_dynamic_constant(right, dcs)?, - )), - } + // Because of normalization, if a dynamic constant can be expressed as a constant it must be a + // constant + let DynamicConstant::Constant(cons) = dcs[cons.idx()] else { + return None; + }; + Some(cons) } /* diff --git a/hercules_ir/src/lib.rs b/hercules_ir/src/lib.rs index 185a28886595330dc6e05e6f9286397d1f5a30d5..fc59a74c4b453d24deddc7141456bc0b21bb6e5d 100644 --- a/hercules_ir/src/lib.rs +++ b/hercules_ir/src/lib.rs @@ -11,6 +11,7 @@ pub mod build; pub mod callgraph; pub mod collections; pub mod dataflow; +pub mod dc_normalization; pub mod def_use; pub mod device; pub mod dom; @@ -28,6 +29,7 @@ pub use crate::build::*; pub use crate::callgraph::*; pub use crate::collections::*; pub use crate::dataflow::*; +pub use crate::dc_normalization::*; pub use crate::def_use::*; pub use crate::device::*; pub use crate::dom::*; diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index cdad54f935afd7b79eaf258b8ec0a83415ecb5ef..257dd4d998341feb2ad6e326a1e4b9e58b31c1e3 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -1,5 +1,6 @@ use std::cell::RefCell; use std::collections::{HashMap, HashSet}; +use std::ops::Deref; use std::str::FromStr; use crate::*; @@ -87,16 +88,7 @@ impl<'a> Context<'a> { } fn get_dynamic_constant_id(&mut self, dynamic_constant: DynamicConstant) -> DynamicConstantID { - if let Some(id) = self.interned_dynamic_constants.get(&dynamic_constant) { - *id - } else { - let id = DynamicConstantID::new(self.interned_dynamic_constants.len()); - self.interned_dynamic_constants - .insert(dynamic_constant.clone(), id); - self.reverse_dynamic_constant_map - .insert(id, dynamic_constant); - id - } + self.dc_normalize(dynamic_constant) } fn get_label_id(&mut self, label: String) -> LabelID { @@ -110,6 +102,23 @@ impl<'a> Context<'a> { } } +impl<'a> DynamicConstantView for Context<'a> { + fn get_dynconst(&self, id: DynamicConstantID) -> impl Deref<Target = DynamicConstant> + '_ { + &self.reverse_dynamic_constant_map[&id] + } + + fn add_dynconst(&mut self, dc: DynamicConstant) -> DynamicConstantID { + if let Some(id) = self.interned_dynamic_constants.get(&dc) { + *id + } else { + let id = DynamicConstantID::new(self.reverse_dynamic_constant_map.len()); + self.interned_dynamic_constants.insert(dc.clone(), id); + self.reverse_dynamic_constant_map.insert(id, dc); + id + } + } +} + /* * A module is just a file with a list of functions. */ @@ -946,9 +955,9 @@ fn parse_dynamic_constant<'a>( ), )), |(op, (x, y))| match op { - '+' => DynamicConstant::Add(x, y), + '+' => DynamicConstant::Add(vec![x, y]), '-' => DynamicConstant::Sub(x, y), - '*' => DynamicConstant::Mul(x, y), + '*' => DynamicConstant::Mul(vec![x, y]), '/' => DynamicConstant::Div(x, y), '%' => DynamicConstant::Rem(x, y), _ => panic!("Invalid parse"), diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index a80dd422128bd3ba2ab6436272943ff1b2deb82f..f7ea397e49355029c7b5cbc0fa534494bc747c6d 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -1,6 +1,6 @@ -use std::cmp::{max, min}; use std::collections::HashMap; use std::iter::zip; +use std::ops::Deref; use crate::*; @@ -184,18 +184,20 @@ fn typeflow( dynamic_constants: &Vec<DynamicConstant>, num_parameters: u32, ) -> bool { - match dynamic_constants[root.idx()] { + match &dynamic_constants[root.idx()] { DynamicConstant::Constant(_) => true, - DynamicConstant::Parameter(idx) => idx < num_parameters as usize, - DynamicConstant::Add(x, y) - | DynamicConstant::Sub(x, y) - | DynamicConstant::Mul(x, y) + DynamicConstant::Parameter(idx) => *idx < num_parameters as usize, + DynamicConstant::Add(xs) + | DynamicConstant::Mul(xs) + | DynamicConstant::Min(xs) + | DynamicConstant::Max(xs) => xs + .iter() + .all(|dc| check_dynamic_constants(*dc, dynamic_constants, num_parameters)), + DynamicConstant::Sub(x, y) | DynamicConstant::Div(x, y) - | DynamicConstant::Rem(x, y) - | DynamicConstant::Min(x, y) - | DynamicConstant::Max(x, y) => { - check_dynamic_constants(x, dynamic_constants, num_parameters) - && check_dynamic_constants(y, dynamic_constants, num_parameters) + | DynamicConstant::Rem(x, y) => { + check_dynamic_constants(*x, dynamic_constants, num_parameters) + && check_dynamic_constants(*y, dynamic_constants, num_parameters) } } } @@ -733,10 +735,20 @@ fn typeflow( } } + // Construct the substitution object + let mut subst = DCSubst::new( + types, + reverse_type_map, + dynamic_constants, + reverse_dynamic_constant_map, + dc_args, + ); + // Check argument types. for (input, param_ty) in zip(inputs.iter().skip(1), callee.param_types.iter()) { + let param_ty = subst.type_subst(*param_ty); if let Concrete(input_id) = input { - if !types_match(types, dynamic_constants, dc_args, *param_ty, *input_id) { + if param_ty != *input_id { return Error(String::from( "Call node mismatches argument types with callee function.", )); @@ -747,14 +759,7 @@ fn typeflow( } } - Concrete(type_subst( - types, - dynamic_constants, - reverse_type_map, - reverse_dynamic_constant_map, - dc_args, - callee.return_type, - )) + Concrete(subst.type_subst(callee.return_type)) } Node::IntrinsicCall { intrinsic, args: _ } => { let num_params = match intrinsic { @@ -1071,307 +1076,154 @@ pub fn cast_compatible(src_ty: &Type, dst_ty: &Type) -> bool { } /* - * Determine if the given type matches the parameter type when the provided - * dynamic constants are substituted in for the dynamic constants used in the - * parameter type. + * Data structures and methods for substituting given dynamic constant arguments into the provided + * types and dynamic constants */ -fn types_match( - types: &Vec<Type>, - dynamic_constants: &Vec<DynamicConstant>, - dc_args: &Box<[DynamicConstantID]>, - param: TypeID, - input: TypeID, -) -> bool { - // Note that we can't just check whether the type ids are equal since them - // being equal does not mean they match when we properly substitute in the - // dynamic constant arguments - - match (&types[param.idx()], &types[input.idx()]) { - (Type::Control, Type::Control) - | (Type::Boolean, Type::Boolean) - | (Type::Integer8, Type::Integer8) - | (Type::Integer16, Type::Integer16) - | (Type::Integer32, Type::Integer32) - | (Type::Integer64, Type::Integer64) - | (Type::UnsignedInteger8, Type::UnsignedInteger8) - | (Type::UnsignedInteger16, Type::UnsignedInteger16) - | (Type::UnsignedInteger32, Type::UnsignedInteger32) - | (Type::UnsignedInteger64, Type::UnsignedInteger64) - | (Type::Float32, Type::Float32) - | (Type::Float64, Type::Float64) => true, - (Type::Product(ps), Type::Product(is)) | (Type::Summation(ps), Type::Summation(is)) => { - ps.len() == is.len() - && ps - .iter() - .zip(is.iter()) - .all(|(p, i)| types_match(types, dynamic_constants, dc_args, *p, *i)) - } - (Type::Array(p, pds), Type::Array(i, ids)) => { - types_match(types, dynamic_constants, dc_args, *p, *i) - && pds.len() == ids.len() - && pds - .iter() - .zip(ids.iter()) - .all(|(pd, id)| dyn_consts_match(dynamic_constants, dc_args, *pd, *id)) - } - (_, _) => false, - } +struct DCSubst<'a> { + types: &'a mut Vec<Type>, + reverse_type_map: &'a mut HashMap<Type, TypeID>, + dynamic_constants: &'a mut Vec<DynamicConstant>, + reverse_dynamic_constant_map: &'a mut HashMap<DynamicConstant, DynamicConstantID>, + dc_args: &'a [DynamicConstantID], } -/* - * Determine if the given dynamic constant matches the parameter's dynamic - * constants when the provided dynamic constants are substituted in for the - * dynamic constants used in the parameter's dynamic constant. Implement dynamic - * constant normalization here as well - i.e., 1 * 2 * 3 = 6. - */ -fn dyn_consts_match( - dynamic_constants: &Vec<DynamicConstant>, - dc_args: &Box<[DynamicConstantID]>, - left: DynamicConstantID, - right: DynamicConstantID, -) -> bool { - // First, try evaluating the DCs and seeing if they're the same value. - if let (Some(cons1), Some(cons2)) = ( - evaluate_dynamic_constant(left, dynamic_constants), - evaluate_dynamic_constant(right, dynamic_constants), - ) { - return cons1 == cons2; +impl<'a> DynamicConstantView for DCSubst<'a> { + fn get_dynconst(&self, id: DynamicConstantID) -> impl Deref<Target = DynamicConstant> + '_ { + &self.dynamic_constants[id.idx()] } - match ( - &dynamic_constants[left.idx()], - &dynamic_constants[right.idx()], - ) { - (DynamicConstant::Constant(x), DynamicConstant::Constant(y)) => x == y, - (DynamicConstant::Parameter(l), DynamicConstant::Parameter(r)) => l == r, - (DynamicConstant::Parameter(i), _) => dyn_consts_match( - dynamic_constants, - dc_args, - min(right, dc_args[*i]), - max(right, dc_args[*i]), - ), - (_, DynamicConstant::Parameter(i)) => dyn_consts_match( + fn add_dynconst(&mut self, dc: DynamicConstant) -> DynamicConstantID { + if let Some(id) = self.reverse_dynamic_constant_map.get(&dc) { + *id + } else { + let id = DynamicConstantID::new(self.dynamic_constants.len()); + self.reverse_dynamic_constant_map.insert(dc.clone(), id); + self.dynamic_constants.push(dc); + id + } + } +} + +impl<'a> DCSubst<'a> { + fn new( + types: &'a mut Vec<Type>, + reverse_type_map: &'a mut HashMap<Type, TypeID>, + dynamic_constants: &'a mut Vec<DynamicConstant>, + reverse_dynamic_constant_map: &'a mut HashMap<DynamicConstant, DynamicConstantID>, + dc_args: &'a [DynamicConstantID], + ) -> Self { + Self { + types, + reverse_type_map, dynamic_constants, + reverse_dynamic_constant_map, dc_args, - min(left, dc_args[*i]), - max(left, dc_args[*i]), - ), - (DynamicConstant::Add(ll, lr), DynamicConstant::Add(rl, rr)) - | (DynamicConstant::Mul(ll, lr), DynamicConstant::Mul(rl, rr)) - | (DynamicConstant::Min(ll, lr), DynamicConstant::Min(rl, rr)) - | (DynamicConstant::Max(ll, lr), DynamicConstant::Max(rl, rr)) => { - // Normalize for associative ops by always looking at smaller DC ID - // as left arm and larger DC ID as right arm. - dyn_consts_match(dynamic_constants, dc_args, min(*ll, *lr), min(*rl, *rr)) - && dyn_consts_match(dynamic_constants, dc_args, max(*ll, *lr), max(*rl, *rr)) - } - (DynamicConstant::Sub(ll, lr), DynamicConstant::Sub(rl, rr)) - | (DynamicConstant::Div(ll, lr), DynamicConstant::Div(rl, rr)) - | (DynamicConstant::Rem(ll, lr), DynamicConstant::Rem(rl, rr)) => { - dyn_consts_match(dynamic_constants, dc_args, *ll, *rl) - && dyn_consts_match(dynamic_constants, dc_args, *lr, *rr) } - (_, _) => false, } -} -/* - * Substitutes the given dynamic constant arguments into the provided type and - * returns the appropriate typeID (potentially creating new types and dynamic - * constants in the process) - */ -fn type_subst( - types: &mut Vec<Type>, - dynamic_constants: &mut Vec<DynamicConstant>, - reverse_type_map: &mut HashMap<Type, TypeID>, - reverse_dynamic_constant_map: &mut HashMap<DynamicConstant, DynamicConstantID>, - dc_args: &Box<[DynamicConstantID]>, - typ: TypeID, -) -> TypeID { - fn intern_type( - ty: Type, - types: &mut Vec<Type>, - reverse_type_map: &mut HashMap<Type, TypeID>, - ) -> TypeID { - if let Some(id) = reverse_type_map.get(&ty) { + fn intern_type(&mut self, ty: Type) -> TypeID { + if let Some(id) = self.reverse_type_map.get(&ty) { *id } else { - let id = TypeID::new(types.len()); - reverse_type_map.insert(ty.clone(), id); - types.push(ty); + let id = TypeID::new(self.types.len()); + self.reverse_type_map.insert(ty.clone(), id); + self.types.push(ty); id } } - match &types[typ.idx()] { - Type::Control - | Type::Boolean - | Type::Integer8 - | Type::Integer16 - | Type::Integer32 - | Type::Integer64 - | Type::UnsignedInteger8 - | Type::UnsignedInteger16 - | Type::UnsignedInteger32 - | Type::UnsignedInteger64 - | Type::Float32 - | Type::Float64 => typ, - Type::Product(ts) => { - let mut new_ts = vec![]; - for t in ts.clone().iter() { - new_ts.push(type_subst( - types, - dynamic_constants, - reverse_type_map, - reverse_dynamic_constant_map, - dc_args, - *t, - )); + fn type_subst(&mut self, typ: TypeID) -> TypeID { + match &self.types[typ.idx()] { + Type::Control + | Type::Boolean + | Type::Integer8 + | Type::Integer16 + | Type::Integer32 + | Type::Integer64 + | Type::UnsignedInteger8 + | Type::UnsignedInteger16 + | Type::UnsignedInteger32 + | Type::UnsignedInteger64 + | Type::Float32 + | Type::Float64 => typ, + Type::Product(ts) => { + let new_ts = ts.clone().iter().map(|t| self.type_subst(*t)).collect(); + self.intern_type(Type::Product(new_ts)) } - intern_type(Type::Product(new_ts.into()), types, reverse_type_map) - } - Type::Summation(ts) => { - let mut new_ts = vec![]; - for t in ts.clone().iter() { - new_ts.push(type_subst( - types, - dynamic_constants, - reverse_type_map, - reverse_dynamic_constant_map, - dc_args, - *t, - )); + Type::Summation(ts) => { + let new_ts = ts.clone().iter().map(|t| self.type_subst(*t)).collect(); + self.intern_type(Type::Summation(new_ts)) } - intern_type(Type::Summation(new_ts.into()), types, reverse_type_map) - } - Type::Array(elem, dims) => { - let ds = dims.clone(); - let new_elem = type_subst( - types, - dynamic_constants, - reverse_type_map, - reverse_dynamic_constant_map, - dc_args, - *elem, - ); - let mut new_dims = vec![]; - for d in ds.iter() { - new_dims.push(dyn_const_subst( - dynamic_constants, - reverse_dynamic_constant_map, - dc_args, - *d, - )); + Type::Array(elem, dims) => { + let elem = *elem; + let new_dims = dims + .clone() + .iter() + .map(|d| self.dyn_const_subst(*d)) + .collect(); + let new_elem = self.type_subst(elem); + self.intern_type(Type::Array(new_elem, new_dims)) } - intern_type( - Type::Array(new_elem, new_dims.into()), - types, - reverse_type_map, - ) - } - } -} - -fn dyn_const_subst( - dynamic_constants: &mut Vec<DynamicConstant>, - reverse_dynamic_constant_map: &mut HashMap<DynamicConstant, DynamicConstantID>, - dc_args: &Box<[DynamicConstantID]>, - dyn_const: DynamicConstantID, -) -> DynamicConstantID { - fn intern_dyn_const( - dc: DynamicConstant, - dynamic_constants: &mut Vec<DynamicConstant>, - reverse_dynamic_constant_map: &mut HashMap<DynamicConstant, DynamicConstantID>, - ) -> DynamicConstantID { - if let Some(id) = reverse_dynamic_constant_map.get(&dc) { - *id - } else { - let id = DynamicConstantID::new(dynamic_constants.len()); - reverse_dynamic_constant_map.insert(dc.clone(), id); - dynamic_constants.push(dc); - id } } - match &dynamic_constants[dyn_const.idx()] { - DynamicConstant::Constant(_) => dyn_const, - DynamicConstant::Parameter(i) => dc_args[*i], - DynamicConstant::Add(l, r) => { - let x = *l; - let y = *r; - let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, x); - let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, y); - intern_dyn_const( - DynamicConstant::Add(sx, sy), - dynamic_constants, - reverse_dynamic_constant_map, - ) - } - DynamicConstant::Sub(l, r) => { - let x = *l; - let y = *r; - let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, x); - let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, y); - intern_dyn_const( - DynamicConstant::Sub(sx, sy), - dynamic_constants, - reverse_dynamic_constant_map, - ) - } - DynamicConstant::Mul(l, r) => { - let x = *l; - let y = *r; - let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, x); - let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, y); - intern_dyn_const( - DynamicConstant::Mul(sx, sy), - dynamic_constants, - reverse_dynamic_constant_map, - ) - } - DynamicConstant::Div(l, r) => { - let x = *l; - let y = *r; - let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, x); - let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, y); - intern_dyn_const( - DynamicConstant::Div(sx, sy), - dynamic_constants, - reverse_dynamic_constant_map, - ) - } - DynamicConstant::Rem(l, r) => { - let x = *l; - let y = *r; - let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, x); - let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, y); - intern_dyn_const( - DynamicConstant::Rem(sx, sy), - dynamic_constants, - reverse_dynamic_constant_map, - ) - } - DynamicConstant::Min(l, r) => { - let x = *l; - let y = *r; - let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, x); - let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, y); - intern_dyn_const( - DynamicConstant::Min(sx, sy), - dynamic_constants, - reverse_dynamic_constant_map, - ) - } - DynamicConstant::Max(l, r) => { - let x = *l; - let y = *r; - let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, x); - let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, y); - intern_dyn_const( - DynamicConstant::Max(sx, sy), - dynamic_constants, - reverse_dynamic_constant_map, - ) + fn dyn_const_subst(&mut self, dyn_const: DynamicConstantID) -> DynamicConstantID { + match &self.dynamic_constants[dyn_const.idx()] { + DynamicConstant::Constant(_) => dyn_const, + DynamicConstant::Parameter(i) => self.dc_args[*i], + DynamicConstant::Add(xs) => { + let sxs = xs + .clone() + .into_iter() + .map(|dc| self.dyn_const_subst(dc)) + .collect(); + self.dc_add(sxs) + } + DynamicConstant::Sub(l, r) => { + let x = *l; + let y = *r; + let sx = self.dyn_const_subst(x); + let sy = self.dyn_const_subst(y); + self.dc_sub(sx, sy) + } + DynamicConstant::Mul(xs) => { + let sxs = xs + .clone() + .into_iter() + .map(|dc| self.dyn_const_subst(dc)) + .collect(); + self.dc_mul(sxs) + } + DynamicConstant::Div(l, r) => { + let x = *l; + let y = *r; + let sx = self.dyn_const_subst(x); + let sy = self.dyn_const_subst(y); + self.dc_div(sx, sy) + } + DynamicConstant::Rem(l, r) => { + let x = *l; + let y = *r; + let sx = self.dyn_const_subst(x); + let sy = self.dyn_const_subst(y); + self.dc_rem(sx, sy) + } + DynamicConstant::Min(xs) => { + let sxs = xs + .clone() + .into_iter() + .map(|dc| self.dyn_const_subst(dc)) + .collect(); + self.dc_min(sxs) + } + DynamicConstant::Max(xs) => { + let sxs = xs + .clone() + .into_iter() + .map(|dc| self.dyn_const_subst(dc)) + .collect(); + self.dc_max(sxs) + } } } } diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index 39f1184cc947a35418641a817a86321343f101fc..1d5860574f4384887a334b78d060d2d4fd53e010 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -9,6 +9,7 @@ use either::Either; use hercules_ir::def_use::*; use hercules_ir::ir::*; +use hercules_ir::DynamicConstantView; /* * Helper object for editing Hercules functions in a trackable manner. Edits @@ -743,22 +744,7 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { } pub fn add_dynamic_constant(&mut self, dynamic_constant: DynamicConstant) -> DynamicConstantID { - let pos = self - .editor - .dynamic_constants - .borrow() - .iter() - .chain(self.added_dynamic_constants.iter()) - .position(|c| *c == dynamic_constant); - if let Some(idx) = pos { - DynamicConstantID::new(idx) - } else { - let id = DynamicConstantID::new( - self.editor.dynamic_constants.borrow().len() + self.added_dynamic_constants.len(), - ); - self.added_dynamic_constants.push(dynamic_constant); - id - } + self.dc_normalize(dynamic_constant) } pub fn get_dynamic_constant( @@ -788,6 +774,31 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { } } +impl<'a, 'b> DynamicConstantView for FunctionEdit<'a, 'b> { + fn get_dynconst(&self, id: DynamicConstantID) -> impl Deref<Target = DynamicConstant> + '_ { + self.get_dynamic_constant(id) + } + + fn add_dynconst(&mut self, dc: DynamicConstant) -> DynamicConstantID { + let pos = self + .editor + .dynamic_constants + .borrow() + .iter() + .chain(self.added_dynamic_constants.iter()) + .position(|c| *c == dc); + if let Some(idx) = pos { + DynamicConstantID::new(idx) + } else { + let id = DynamicConstantID::new( + self.editor.dynamic_constants.borrow().len() + self.added_dynamic_constants.len(), + ); + self.added_dynamic_constants.push(dc); + id + } + } +} + #[cfg(test)] mod editor_tests { #[allow(unused_imports)] diff --git a/hercules_opt/src/fork_guard_elim.rs b/hercules_opt/src/fork_guard_elim.rs index 1abb89672ae1d5c4f0f34578ca9d8eb2d69a2bc0..052fd0e493327fceb3bd1b1918d4d4aafc93bf79 100644 --- a/hercules_opt/src/fork_guard_elim.rs +++ b/hercules_opt/src/fork_guard_elim.rs @@ -1,4 +1,5 @@ use std::collections::{HashMap, HashSet}; +use std::ops::Deref; use hercules_ir::*; @@ -70,22 +71,24 @@ fn guarded_fork( }; let mut factors = factors.iter().enumerate().map(|(idx, dc)| { - let DynamicConstant::Max(l, r) = *editor.get_dynamic_constant(*dc) else { + let factor = editor.get_dynamic_constant(*dc); + let DynamicConstant::Max(xs) = factor.deref() else { return Factor::Normal(*dc); }; - // There really needs to be a better way to work w/ associativity. - let binding = [(l, r), (r, l)]; - let id = binding.iter().find_map(|(a, b)| { - let DynamicConstant::Constant(1) = *editor.get_dynamic_constant(*a) else { - return None; - }; - Some(b) - }); - - match id { - Some(v) => Factor::Max(idx, *v), - None => Factor::Normal(*dc), + // Filter out any terms which are just 1s + let non_ones = xs.iter().filter(|i| { + if let DynamicConstant::Constant(1) = editor.get_dynamic_constant(**i).deref() { + false + } else { + true + } + }).collect::<Vec<_>>(); + // If we're left with just one term x, we had max { 1, x } + if non_ones.len() == 1 { + Factor::Max(idx, *non_ones[0]) + } else { + Factor::Normal(*dc) } }); diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs index ce9ac1412f1253bff6589ec668db63725183ca6c..ec4e9fbcc22d9f1c8a53652173706b40c5b12e65 100644 --- a/hercules_opt/src/forkify.rs +++ b/hercules_opt/src/forkify.rs @@ -265,9 +265,8 @@ pub fn forkify_loop( let bound_dc_id = { let mut max_id = DynamicConstantID::new(0); editor.edit(|mut edit| { - // FIXME: Maybe add_dynamic_constant should intern? let one_id = edit.add_dynamic_constant(DynamicConstant::Constant(1)); - max_id = edit.add_dynamic_constant(DynamicConstant::Max(one_id, bound_dc_id)); + max_id = edit.add_dynamic_constant(DynamicConstant::max(one_id, bound_dc_id)); Ok(edit) }); max_id diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index 6d36e8ac2cd6c8fdff46f4e46b25a95e5b15db51..271bfaf1da55f6c8ab342d06853532eb5ce99fff 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -1064,9 +1064,9 @@ fn align(edit: &mut FunctionEdit, mut acc: DynamicConstantID, align: usize) -> D if align != 1 { let align_dc = edit.add_dynamic_constant(DynamicConstant::Constant(align)); let align_m1_dc = edit.add_dynamic_constant(DynamicConstant::Constant(align - 1)); - acc = edit.add_dynamic_constant(DynamicConstant::Add(acc, align_m1_dc)); - acc = edit.add_dynamic_constant(DynamicConstant::Div(acc, align_dc)); - acc = edit.add_dynamic_constant(DynamicConstant::Mul(acc, align_dc)); + acc = edit.add_dynamic_constant(DynamicConstant::add(acc, align_m1_dc)); + acc = edit.add_dynamic_constant(DynamicConstant::div(acc, align_dc)); + acc = edit.add_dynamic_constant(DynamicConstant::mul(acc, align_dc)); } acc } @@ -1098,7 +1098,7 @@ fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) -> // the field. let field_size = type_size(edit, field, alignments); acc_size = align(edit, acc_size, alignments[field.idx()]); - acc_size = edit.add_dynamic_constant(DynamicConstant::Add(acc_size, field_size)); + acc_size = edit.add_dynamic_constant(DynamicConstant::add(acc_size, field_size)); } // Finally, round up to the alignment of the whole product, since // the size needs to be a multiple of the alignment. @@ -1112,11 +1112,11 @@ fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) -> // Pick the size of the largest variant, since that's the most // memory we would need. let variant_size = type_size(edit, variant, alignments); - acc_size = edit.add_dynamic_constant(DynamicConstant::Max(acc_size, variant_size)); + acc_size = edit.add_dynamic_constant(DynamicConstant::max(acc_size, variant_size)); } // Add one byte for the discriminant and align the whole summation. let one = edit.add_dynamic_constant(DynamicConstant::Constant(1)); - acc_size = edit.add_dynamic_constant(DynamicConstant::Add(acc_size, one)); + acc_size = edit.add_dynamic_constant(DynamicConstant::add(acc_size, one)); acc_size = align(edit, acc_size, alignments[ty_id.idx()]); acc_size } @@ -1124,7 +1124,7 @@ fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) -> // The layout of an array is row-major linear in memory. let mut acc_size = type_size(edit, elem, alignments); for bound in bounds { - acc_size = edit.add_dynamic_constant(DynamicConstant::Mul(acc_size, bound)); + acc_size = edit.add_dynamic_constant(DynamicConstant::mul(acc_size, bound)); } acc_size } @@ -1160,7 +1160,7 @@ fn object_allocation( *total = align(&mut edit, *total, alignments[typing[id.idx()].idx()]); offsets.insert(id, *total); let type_size = type_size(&mut edit, typing[id.idx()], alignments); - *total = edit.add_dynamic_constant(DynamicConstant::Add(*total, type_size)); + *total = edit.add_dynamic_constant(DynamicConstant::add(*total, type_size)); } } Node::Call { @@ -1169,7 +1169,13 @@ fn object_allocation( ref dynamic_constants, args: _, } => { - let dynamic_constants = dynamic_constants.clone(); + let dynamic_constants = dynamic_constants.to_vec(); + let dc_args = (0..dynamic_constants.len()) + .map(|i| edit.add_dynamic_constant(DynamicConstant::Parameter(i))); + let substs = dc_args + .zip(dynamic_constants.into_iter()) + .collect::<HashMap<_, _>>(); + for device in BACKED_DEVICES { if let Some(mut callee_backing_size) = backing_allocations[&callee] .get(&device) @@ -1183,26 +1189,12 @@ fn object_allocation( offsets.insert(id, *total); // Substitute the dynamic constant parameters in the // callee's backing size. - let first_dc = edit.num_dynamic_constants() + 10000; - for (p_idx, dc_n) in zip(0..dynamic_constants.len(), first_dc..) { - let dc_a = - edit.add_dynamic_constant(DynamicConstant::Parameter(p_idx)); - callee_backing_size = substitute_dynamic_constants( - dc_a, - DynamicConstantID::new(dc_n), - callee_backing_size, - &mut edit, - ); - } - for (dc_n, dc_b) in zip(first_dc.., dynamic_constants.iter()) { - callee_backing_size = substitute_dynamic_constants( - DynamicConstantID::new(dc_n), - *dc_b, - callee_backing_size, - &mut edit, - ); - } - *total = edit.add_dynamic_constant(DynamicConstant::Add( + callee_backing_size = substitute_dynamic_constants( + &substs, + callee_backing_size, + &mut edit, + ); + *total = edit.add_dynamic_constant(DynamicConstant::add( *total, callee_backing_size, )); diff --git a/hercules_opt/src/inline.rs b/hercules_opt/src/inline.rs index 1d2bac97d848ace910d614a42743c6ea5fe3aa9e..848d957f1e25c37b448aea7b35b0f5c5100c6d69 100644 --- a/hercules_opt/src/inline.rs +++ b/hercules_opt/src/inline.rs @@ -119,7 +119,12 @@ fn inline_func( // Assemble all the info we'll need to do the edit. let dcs_a = &dc_param_idx_to_dc_id[..dynamic_constants.len()]; - let dcs_b = dynamic_constants.clone(); + let dcs_b = dynamic_constants.to_vec(); + let substs = dcs_a + .iter() + .map(|i| *i) + .zip(dcs_b.into_iter()) + .collect::<HashMap<_, _>>(); let args = args.clone(); let old_num_nodes = editor.func().nodes.len(); let old_id_to_new_id = |old_id: NodeID| NodeID::new(old_id.idx() + old_num_nodes); @@ -163,39 +168,7 @@ fn inline_func( || node.is_dynamic_constant() || node.is_call() { - // We have to perform the subsitution in two steps. First, - // we map every dynamic constant A to a non-sense dynamic - // constant ID. Second, we map each non-sense dynamic - // constant ID to the appropriate dynamic constant B. Why - // not just do this in one step from A to B? We update - // dynamic constants one at a time, so imagine the following - // A -> B mappings: - // ID 0 -> ID 1 - // ID 1 -> ID 0 - // First, we apply the first mapping. This changes all - // references to dynamic constant 0 to dynamic constant 1. - // Then, we apply the second mapping. This updates all - // already present references to dynamic constant 1, as well - // as the new references we just made in the first step. We - // actually want to institute all the updates - // *simultaneously*, hence the two step maneuver. - let first_dc = edit.num_dynamic_constants() + 10000; - for (dc_a, dc_n) in zip(dcs_a, first_dc..) { - substitute_dynamic_constants_in_node( - *dc_a, - DynamicConstantID::new(dc_n), - &mut node, - &mut edit, - ); - } - for (dc_n, dc_b) in zip(first_dc.., dcs_b.iter()) { - substitute_dynamic_constants_in_node( - DynamicConstantID::new(dc_n), - *dc_b, - &mut node, - &mut edit, - ); - } + substitute_dynamic_constants_in_node(&substs, &mut node, &mut edit); } let mut uses = get_uses_mut(&mut node); for u in uses.as_mut() { diff --git a/hercules_opt/src/interprocedural_sroa.rs b/hercules_opt/src/interprocedural_sroa.rs index f597cd80347d94a7c927d6fe085d80f843e280eb..f22c1fe8410bdb008dbffe4acb90fa1679f9e44e 100644 --- a/hercules_opt/src/interprocedural_sroa.rs +++ b/hercules_opt/src/interprocedural_sroa.rs @@ -313,31 +313,22 @@ fn compress_return_products(editors: &mut Vec<FunctionEditor>, all_callsites_edi // If this becomes a common pattern, it would be worth creating // a better abstraction around bulk replacement. - let new_dcs = (*dynamic_constants).clone(); + let new_dcs = (*dynamic_constants).to_vec(); + let old_dcs = dc_param_idx_to_dc_id[..new_dcs.len()].to_vec(); + assert_eq!(old_dcs.len(), new_dcs.len()); + let substs = old_dcs + .into_iter() + .zip(new_dcs.into_iter()) + .collect::<HashMap<_, _>>(); let edit_successful = editor.edit(|mut edit| { - let old_dcs = dc_param_idx_to_dc_id[..new_dcs.len()].to_vec().clone(); let mut substituted = old_return_type_ids[function_id.idx()]; - assert_eq!(old_dcs.len(), new_dcs.len()); - let first_dc = edit.num_dynamic_constants() + 10000; - for (dc_a, dc_n) in zip(old_dcs, first_dc..) { - substituted = substitute_dynamic_constants_in_type( - dc_a, - DynamicConstantID::new(dc_n), - substituted, - &mut edit, - ); - } - - for (dc_n, dc_b) in zip(first_dc.., new_dcs.iter()) { - substituted = substitute_dynamic_constants_in_type( - DynamicConstantID::new(dc_n), - *dc_b, - substituted, - &mut edit, - ); - } + let substituted = substitute_dynamic_constants_in_type( + &substs, + old_return_type_ids[function_id.idx()], + &mut edit, + ); let (expanded_product, readers) = uncompress_product(&mut edit, &call_node_id, &substituted); @@ -419,34 +410,26 @@ fn remove_return_singletons(editors: &mut Vec<FunctionEditor>, all_callsites_edi for call_node_id in call_node_ids { let (_, function, dc_args, _) = editor.func().nodes[call_node_id.idx()].try_call().unwrap(); - let dc_args = dc_args.clone(); + + let dc_args = dc_args.to_vec(); if singleton_removed[function.idx()] { let edit_successful = editor.edit(|mut edit| { - let mut substituted = old_return_type_ids[function.idx()]; - let first_dc = edit.num_dynamic_constants() + 10000; - let dc_params: Vec<_> = (0..dc_args.len()) + let dc_params = (0..dc_args.len()) .map(|param_idx| { edit.add_dynamic_constant(DynamicConstant::Parameter(param_idx)) }) - .collect(); - for (dc_a, dc_n) in zip(dc_params, first_dc..) { - substituted = substitute_dynamic_constants_in_type( - dc_a, - DynamicConstantID::new(dc_n), - substituted, - &mut edit, - ); - } - - for (dc_n, dc_b) in zip(first_dc.., dc_args.iter()) { - substituted = substitute_dynamic_constants_in_type( - DynamicConstantID::new(dc_n), - *dc_b, - substituted, - &mut edit, - ); - } + .collect::<Vec<_>>(); + let substs = dc_params + .into_iter() + .zip(dc_args.into_iter()) + .collect::<HashMap<_, _>>(); + + let substituted = substitute_dynamic_constants_in_type( + &substs, + old_return_type_ids[function.idx()], + &mut edit, + ); let empty_constant_id = edit.add_zero_constant(substituted); let empty_node_id = edit.add_node(Node::Constant { id: empty_constant_id, diff --git a/hercules_opt/src/lift_dc_math.rs b/hercules_opt/src/lift_dc_math.rs index afdb212064d84a0191f87ce366d67b7ea6728fa8..8256c889085a9b2902c6d4d5c8fd5a9fa2e77429 100644 --- a/hercules_opt/src/lift_dc_math.rs +++ b/hercules_opt/src/lift_dc_math.rs @@ -41,11 +41,11 @@ pub fn lift_dc_math(editor: &mut FunctionEditor) { continue; }; match op { - BinaryOperator::Add => DynamicConstant::Add(left, right), - BinaryOperator::Sub => DynamicConstant::Sub(left, right), - BinaryOperator::Mul => DynamicConstant::Mul(left, right), - BinaryOperator::Div => DynamicConstant::Div(left, right), - BinaryOperator::Rem => DynamicConstant::Rem(left, right), + BinaryOperator::Add => DynamicConstant::add(left, right), + BinaryOperator::Sub => DynamicConstant::sub(left, right), + BinaryOperator::Mul => DynamicConstant::mul(left, right), + BinaryOperator::Div => DynamicConstant::div(left, right), + BinaryOperator::Rem => DynamicConstant::rem(left, right), _ => { continue; } @@ -64,8 +64,8 @@ pub fn lift_dc_math(editor: &mut FunctionEditor) { continue; }; match intrinsic { - Intrinsic::Min => DynamicConstant::Min(left, right), - Intrinsic::Max => DynamicConstant::Max(left, right), + Intrinsic::Min => DynamicConstant::min(left, right), + Intrinsic::Max => DynamicConstant::max(left, right), _ => { continue; } diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs index 7ad48c1c09cc8542d1b521e3d8e12fe271ef1d39..2ab4e094a47f2ce1805b924560a51f30d12951d6 100644 --- a/hercules_opt/src/utils.rs +++ b/hercules_opt/src/utils.rs @@ -1,5 +1,4 @@ -use std::collections::HashMap; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::iter::zip; use hercules_ir::def_use::*; @@ -9,12 +8,11 @@ use nestify::nest; use crate::*; /* - * Substitute all uses of a dynamic constant A with dynamic constant B in a - * type. Return the substituted version of the type, once memozied. + * Substitute all uses of dynamic constants in a type that are keys in the substs map with the + * dynamic constant value for that key. Return the substituted version of the type, once memoized. */ pub(crate) fn substitute_dynamic_constants_in_type( - dc_a: DynamicConstantID, - dc_b: DynamicConstantID, + substs: &HashMap<DynamicConstantID, DynamicConstantID>, ty: TypeID, edit: &mut FunctionEdit, ) -> TypeID { @@ -24,7 +22,7 @@ pub(crate) fn substitute_dynamic_constants_in_type( Type::Product(ref fields) => { let new_fields = fields .into_iter() - .map(|field_id| substitute_dynamic_constants_in_type(dc_a, dc_b, *field_id, edit)) + .map(|field_id| substitute_dynamic_constants_in_type(substs, *field_id, edit)) .collect(); if new_fields != *fields { edit.add_type(Type::Product(new_fields)) @@ -35,9 +33,7 @@ pub(crate) fn substitute_dynamic_constants_in_type( Type::Summation(ref variants) => { let new_variants = variants .into_iter() - .map(|variant_id| { - substitute_dynamic_constants_in_type(dc_a, dc_b, *variant_id, edit) - }) + .map(|variant_id| substitute_dynamic_constants_in_type(substs, *variant_id, edit)) .collect(); if new_variants != *variants { edit.add_type(Type::Summation(new_variants)) @@ -46,10 +42,10 @@ pub(crate) fn substitute_dynamic_constants_in_type( } } Type::Array(elem_ty, ref dims) => { - let new_elem_ty = substitute_dynamic_constants_in_type(dc_a, dc_b, elem_ty, edit); + let new_elem_ty = substitute_dynamic_constants_in_type(substs, elem_ty, edit); let new_dims = dims .into_iter() - .map(|dim_id| substitute_dynamic_constants(dc_a, dc_b, *dim_id, edit)) + .map(|dim_id| substitute_dynamic_constants(substs, *dim_id, edit)) .collect(); if new_elem_ty != elem_ty || new_dims != *dims { edit.add_type(Type::Array(new_elem_ty, new_dims)) @@ -62,107 +58,105 @@ pub(crate) fn substitute_dynamic_constants_in_type( } /* - * Substitute all uses of a dynamic constant A with dynamic constant B in a - * dynamic constant C. Return the substituted version of C, once memoized. Takes - * a mutable edit instead of an editor since this may create new dynamic - * constants, which can only be done inside an edit. + * Substitute all uses of dynamic constants in a dynamic constant dc that are keys in the + * substs map and replace them with their appropriate replacement values. Return the substituted + * version of dc, once memoized. Takes a mutable edit instead of an editor since this may create + * new dynamic constants, which can only be done inside an edit. */ pub(crate) fn substitute_dynamic_constants( - dc_a: DynamicConstantID, - dc_b: DynamicConstantID, - dc_c: DynamicConstantID, + substs: &HashMap<DynamicConstantID, DynamicConstantID>, + dc: DynamicConstantID, edit: &mut FunctionEdit, ) -> DynamicConstantID { - // If C is just A, then just replace all of C with B. - if dc_a == dc_c { - return dc_b; + // If this dynamic constant should be substituted, just return the substitution + if let Some(subst) = substs.get(&dc) { + return *subst; } - // Since we substitute non-sense dynamic constant IDs earlier, we explicitly - // check that the provided ID to replace inside of is valid. Otherwise, - // ignore. - if dc_c.idx() >= edit.num_dynamic_constants() { - return dc_c; - } - - // If C is not just A, look inside of it to possibly substitute a child DC. - let dc_clone = edit.get_dynamic_constant(dc_c).clone(); + // Look inside the dynamic constant to perform substitution in its children + let dc_clone = edit.get_dynamic_constant(dc).clone(); match dc_clone { - DynamicConstant::Constant(_) | DynamicConstant::Parameter(_) => dc_c, - // This is a certified Rust moment. - DynamicConstant::Add(left, right) => { - let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit); - let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit); - if new_left != left || new_right != right { - edit.add_dynamic_constant(DynamicConstant::Add(new_left, new_right)) + DynamicConstant::Constant(_) | DynamicConstant::Parameter(_) => dc, + DynamicConstant::Add(xs) => { + let new_xs = xs + .iter() + .map(|x| substitute_dynamic_constants(substs, *x, edit)) + .collect::<Vec<_>>(); + if new_xs != xs { + edit.add_dynamic_constant(DynamicConstant::Add(new_xs)) } else { - dc_c + dc } } DynamicConstant::Sub(left, right) => { - let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit); - let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit); + let new_left = substitute_dynamic_constants(substs, left, edit); + let new_right = substitute_dynamic_constants(substs, right, edit); if new_left != left || new_right != right { edit.add_dynamic_constant(DynamicConstant::Sub(new_left, new_right)) } else { - dc_c + dc } } - DynamicConstant::Mul(left, right) => { - let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit); - let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit); - if new_left != left || new_right != right { - edit.add_dynamic_constant(DynamicConstant::Mul(new_left, new_right)) + DynamicConstant::Mul(xs) => { + let new_xs = xs + .iter() + .map(|x| substitute_dynamic_constants(substs, *x, edit)) + .collect::<Vec<_>>(); + if new_xs != xs { + edit.add_dynamic_constant(DynamicConstant::Mul(new_xs)) } else { - dc_c + dc } } DynamicConstant::Div(left, right) => { - let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit); - let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit); + let new_left = substitute_dynamic_constants(substs, left, edit); + let new_right = substitute_dynamic_constants(substs, right, edit); if new_left != left || new_right != right { edit.add_dynamic_constant(DynamicConstant::Div(new_left, new_right)) } else { - dc_c + dc } } DynamicConstant::Rem(left, right) => { - let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit); - let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit); + let new_left = substitute_dynamic_constants(substs, left, edit); + let new_right = substitute_dynamic_constants(substs, right, edit); if new_left != left || new_right != right { edit.add_dynamic_constant(DynamicConstant::Rem(new_left, new_right)) } else { - dc_c + dc } } - DynamicConstant::Min(left, right) => { - let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit); - let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit); - if new_left != left || new_right != right { - edit.add_dynamic_constant(DynamicConstant::Min(new_left, new_right)) + DynamicConstant::Min(xs) => { + let new_xs = xs + .iter() + .map(|x| substitute_dynamic_constants(substs, *x, edit)) + .collect::<Vec<_>>(); + if new_xs != xs { + edit.add_dynamic_constant(DynamicConstant::Min(new_xs)) } else { - dc_c + dc } } - DynamicConstant::Max(left, right) => { - let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit); - let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit); - if new_left != left || new_right != right { - edit.add_dynamic_constant(DynamicConstant::Max(new_left, new_right)) + DynamicConstant::Max(xs) => { + let new_xs = xs + .iter() + .map(|x| substitute_dynamic_constants(substs, *x, edit)) + .collect::<Vec<_>>(); + if new_xs != xs { + edit.add_dynamic_constant(DynamicConstant::Max(new_xs)) } else { - dc_c + dc } } } } /* - * Substitute all uses of a dynamic constant A with dynamic constant B in a - * constant. Return the substituted version of the constant, once memozied. + * Substitute all uses of the dynamic constants specified by the subst map in a constant. Return + * the substituted version of the constant, once memozied. */ pub(crate) fn substitute_dynamic_constants_in_constant( - dc_a: DynamicConstantID, - dc_b: DynamicConstantID, + substs: &HashMap<DynamicConstantID, DynamicConstantID>, cons: ConstantID, edit: &mut FunctionEdit, ) -> ConstantID { @@ -170,12 +164,10 @@ pub(crate) fn substitute_dynamic_constants_in_constant( let cons_clone = edit.get_constant(cons).clone(); match cons_clone { Constant::Product(ty, fields) => { - let new_ty = substitute_dynamic_constants_in_type(dc_a, dc_b, ty, edit); + let new_ty = substitute_dynamic_constants_in_type(substs, ty, edit); let new_fields = fields .iter() - .map(|field_id| { - substitute_dynamic_constants_in_constant(dc_a, dc_b, *field_id, edit) - }) + .map(|field_id| substitute_dynamic_constants_in_constant(substs, *field_id, edit)) .collect(); if new_ty != ty || new_fields != fields { edit.add_constant(Constant::Product(new_ty, new_fields)) @@ -184,8 +176,8 @@ pub(crate) fn substitute_dynamic_constants_in_constant( } } Constant::Summation(ty, idx, variant) => { - let new_ty = substitute_dynamic_constants_in_type(dc_a, dc_b, ty, edit); - let new_variant = substitute_dynamic_constants_in_constant(dc_a, dc_b, variant, edit); + let new_ty = substitute_dynamic_constants_in_type(substs, ty, edit); + let new_variant = substitute_dynamic_constants_in_constant(substs, variant, edit); if new_ty != ty || new_variant != variant { edit.add_constant(Constant::Summation(new_ty, idx, new_variant)) } else { @@ -193,7 +185,7 @@ pub(crate) fn substitute_dynamic_constants_in_constant( } } Constant::Array(ty) => { - let new_ty = substitute_dynamic_constants_in_type(dc_a, dc_b, ty, edit); + let new_ty = substitute_dynamic_constants_in_type(substs, ty, edit); if new_ty != ty { edit.add_constant(Constant::Array(new_ty)) } else { @@ -205,12 +197,10 @@ pub(crate) fn substitute_dynamic_constants_in_constant( } /* - * Substitute all uses of a dynamic constant A with dynamic constant B in a - * node. + * Substitute all uses of the dynamic constants specified by the subst map in a node. */ pub(crate) fn substitute_dynamic_constants_in_node( - dc_a: DynamicConstantID, - dc_b: DynamicConstantID, + substs: &HashMap<DynamicConstantID, DynamicConstantID>, node: &mut Node, edit: &mut FunctionEdit, ) { @@ -220,14 +210,14 @@ pub(crate) fn substitute_dynamic_constants_in_node( factors, } => { for factor in factors.into_iter() { - *factor = substitute_dynamic_constants(dc_a, dc_b, *factor, edit); + *factor = substitute_dynamic_constants(substs, *factor, edit); } } Node::Constant { id } => { - *id = substitute_dynamic_constants_in_constant(dc_a, dc_b, *id, edit); + *id = substitute_dynamic_constants_in_constant(substs, *id, edit); } Node::DynamicConstant { id } => { - *id = substitute_dynamic_constants(dc_a, dc_b, *id, edit); + *id = substitute_dynamic_constants(substs, *id, edit); } Node::Call { control: _, @@ -236,7 +226,7 @@ pub(crate) fn substitute_dynamic_constants_in_node( args: _, } => { for dc_arg in dynamic_constants.into_iter() { - *dc_arg = substitute_dynamic_constants(dc_a, dc_b, *dc_arg, edit); + *dc_arg = substitute_dynamic_constants(substs, *dc_arg, edit); } } _ => {} diff --git a/hercules_test/hercules_interpreter/src/interpreter.rs b/hercules_test/hercules_interpreter/src/interpreter.rs index a78330e4f08075be053593b41dba0f412687f5f1..871e304a2f8fb285cc9d8c64d4aa62ec5eef3a1d 100644 --- a/hercules_test/hercules_interpreter/src/interpreter.rs +++ b/hercules_test/hercules_interpreter/src/interpreter.rs @@ -69,17 +69,17 @@ pub fn dyn_const_value( match dc { DynamicConstant::Constant(v) => *v, DynamicConstant::Parameter(v) => dyn_const_params[*v], - DynamicConstant::Add(a, b) => { - dyn_const_value(a, dyn_const_values, dyn_const_params) - + dyn_const_value(b, dyn_const_values, dyn_const_params) + DynamicConstant::Add(xs) => { + xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) + .fold(0, |s, v| s + v) } DynamicConstant::Sub(a, b) => { dyn_const_value(a, dyn_const_values, dyn_const_params) - dyn_const_value(b, dyn_const_values, dyn_const_params) } - DynamicConstant::Mul(a, b) => { - dyn_const_value(a, dyn_const_values, dyn_const_params) - * dyn_const_value(b, dyn_const_values, dyn_const_params) + DynamicConstant::Mul(xs) => { + xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) + .fold(1, |p, v| p * v) } DynamicConstant::Div(a, b) => { dyn_const_value(a, dyn_const_values, dyn_const_params) @@ -89,14 +89,28 @@ pub fn dyn_const_value( dyn_const_value(a, dyn_const_values, dyn_const_params) % dyn_const_value(b, dyn_const_values, dyn_const_params) } - DynamicConstant::Max(a, b) => max( - dyn_const_value(a, dyn_const_values, dyn_const_params), - dyn_const_value(b, dyn_const_values, dyn_const_params), - ), - DynamicConstant::Min(a, b) => min( - dyn_const_value(a, dyn_const_values, dyn_const_params), - dyn_const_value(b, dyn_const_values, dyn_const_params), - ), + DynamicConstant::Max(xs) => { + xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) + .fold(None, |m, v| { + if let Some(m) = m { + Some(max(m, v)) + } else { + Some(v) + } + }) + .unwrap() + } + DynamicConstant::Min(xs) => { + xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) + .fold(None, |m, v| { + if let Some(m) = m { + Some(min(m, v)) + } else { + Some(v) + } + }) + .unwrap() + } } } diff --git a/juno_frontend/src/dynconst.rs b/juno_frontend/src/dynconst.rs index defab822d19ebb99191ad5a7d387916247dcd7df..511dfa341e4bcb8100f5f6761e068f9a803257e3 100644 --- a/juno_frontend/src/dynconst.rs +++ b/juno_frontend/src/dynconst.rs @@ -291,16 +291,16 @@ impl DynConst { .map(|(d, c)| self.build_mono(builder, d, c)) .partition(|(_, neg)| !*neg); - let pos_sum = pos - .into_iter() - .map(|(t, _)| t) - .reduce(|x, y| builder.create_dynamic_constant_add(x, y)) - .unwrap_or_else(|| builder.create_dynamic_constant_constant(0)); + let pos_sum = + builder.create_dynamic_constant_add_many(pos.into_iter().map(|(t, _)| t).collect()); - let neg_sum = neg - .into_iter() - .map(|(t, _)| t) - .reduce(|x, y| builder.create_dynamic_constant_add(x, y)); + let neg_sum = if neg.is_empty() { + None + } else { + Some( + builder.create_dynamic_constant_add_many(neg.into_iter().map(|(t, _)| t).collect()), + ) + }; match neg_sum { None => pos_sum, @@ -317,72 +317,61 @@ impl DynConst { term: &Vec<i64>, coeff: &Ratio<i64>, ) -> (DynamicConstantID, bool) { - let term_id = term + let (pos, neg): (Vec<_>, Vec<_>) = term .iter() .enumerate() .filter(|(_, p)| **p != 0) .map(|(v, p)| self.build_power(builder, v, *p)) - .collect::<Vec<_>>() - .into_iter() - .reduce(|x, y| builder.create_dynamic_constant_mul(x, y)); - - match term_id { - None => { - // This means all powers of the term are 0, so we just - // output the coefficient - if !coeff.is_integer() { - panic!("Dynamic constant is a non-integer constant") - } else { - let val: i64 = coeff.to_integer(); - ( - builder.create_dynamic_constant_constant(val.abs() as usize), - val < 0, - ) - } - } - Some(term) => { - if coeff.is_one() { - (term, false) - } else { - let numer: i64 = coeff.numer().abs(); - let denom: i64 = *coeff.denom(); // > 0 - - let with_numer = if numer == 1 { - term - } else { - let numer_id = builder.create_dynamic_constant_constant(numer as usize); - builder.create_dynamic_constant_mul(numer_id, term) - }; - let with_denom = if denom == 1 { - with_numer - } else { - let denom_id = builder.create_dynamic_constant_constant(denom as usize); - builder.create_dynamic_constant_div(with_numer, denom_id) - }; - - (with_denom, numer < 0) - } + .partition(|(_, neg)| !*neg); + let mut pos: Vec<_> = pos.into_iter().map(|(t, _)| t).collect(); + let mut neg: Vec<_> = neg.into_iter().map(|(t, _)| t).collect(); + + let numerator = { + let numer: i64 = coeff.numer().abs(); + let numer_dc = builder.create_dynamic_constant_constant(numer as usize); + pos.push(numer_dc); + builder.create_dynamic_constant_mul_many(pos) + }; + + let denominator = { + let denom: i64 = *coeff.denom(); + assert!(denom > 0); + + if neg.is_empty() && denom == 1 { + None + } else { + let denom_dc = builder.create_dynamic_constant_constant(denom as usize); + neg.push(denom_dc); + Some(builder.create_dynamic_constant_mul_many(neg)) } + }; + + if let Some(denominator) = denominator { + ( + builder.create_dynamic_constant_div(numerator, denominator), + *coeff.numer() < 0, + ) + } else { + (numerator, *coeff.numer() < 0) } } // Build's a dynamic constant that is a certain power of a specific variable - fn build_power(&self, builder: &mut Builder, v: usize, power: i64) -> DynamicConstantID { + // Returns the dynamic constant id of variable raised to the absolute value of the power and a + // boolean indicating whether the power is actually negative + fn build_power( + &self, + builder: &mut Builder, + v: usize, + power: i64, + ) -> (DynamicConstantID, bool) { assert!(power != 0); let power_pos = power.abs() as usize; let var_id = builder.create_dynamic_constant_parameter(v); - let power_id = iter::repeat(var_id) - .take(power_pos) - .map(|_| var_id) - .reduce(|x, y| builder.create_dynamic_constant_mul(x, y)) - .expect("Power is non-zero"); + let power_id = + builder.create_dynamic_constant_mul_many((0..power_pos).map(|_| var_id).collect()); - if power > 0 { - power_id - } else { - let one_id = builder.create_dynamic_constant_constant(1); - builder.create_dynamic_constant_div(one_id, power_id) - } + (power_id, power < 0) } } diff --git a/juno_samples/concat/src/concat.jn b/juno_samples/concat/src/concat.jn index 2471671e69af7e9c73e2347abfe56e2db722d1d0..70c741b6c36ed780cbc4107028850c198d943483 100644 --- a/juno_samples/concat/src/concat.jn +++ b/juno_samples/concat/src/concat.jn @@ -30,3 +30,22 @@ fn concat_entry(a : i32) -> i32 { let arr3 = concat::<i32, 3, 6>(arr1, arr2); return sum::<i32, 9>(arr3); } + +#[entry] +fn concat_switch<n: usize>(b: i32, m: i32[n]) -> i32[n + 2] { + let ex : i32[2]; + ex[0] = 0; + ex[1] = 1; + + let x = concat::<_, 2, n>(ex, m); + let y = concat::<_, n, 2>(m, ex); + + let s = 0; + + s += sum::<i32, n + 2>(x); + s += sum::<i32, 2 + n>(x); + s += sum::<i32, n + 2>(y); + s += sum::<i32, 2 + n>(y); + + return if s < b then x else y; +} diff --git a/juno_samples/concat/src/main.rs b/juno_samples/concat/src/main.rs index db3f37fdaa6047d146ebd899cc7178e2b135d7ee..8bcd7ba54a7b61fe15c97ba74d37fc0e867aa277 100644 --- a/juno_samples/concat/src/main.rs +++ b/juno_samples/concat/src/main.rs @@ -1,6 +1,7 @@ #![feature(concat_idents)] use hercules_rt::runner; +use hercules_rt::HerculesCPURef; juno_build::juno!("concat"); @@ -10,6 +11,21 @@ fn main() { let output = r.run(7).await; println!("{}", output); assert_eq!(output, 42); + + const N: usize = 3; + let arr : Box<[i32]> = (2..=4).collect(); + let arr = HerculesCPURef::from_slice(&arr); + + let mut r = runner!(concat_switch); + let output = r.run(N as u64, 50, arr.clone()).await; + let result = output.as_slice::<i32>(); + println!("{:?}", result); + assert_eq!(result, [0, 1, 2, 3, 4]); + + let output = r.run(N as u64, 30, arr).await; + let result = output.as_slice::<i32>(); + println!("{:?}", result); + assert_eq!(result, [2, 3, 4, 0, 1]); }); }