From e67655e6c6d80c0fdee0c00fe49ba4f1aaf7b0ca Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Mon, 14 Oct 2024 20:31:30 -0500 Subject: [PATCH] Correct type-checking of calls --- hercules_ir/src/typecheck.rs | 211 ++++++++++++++++++++++++++++++++++- 1 file changed, 207 insertions(+), 4 deletions(-) diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index d2b97e53..505b6d24 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -92,13 +92,19 @@ pub fn typecheck( ref functions, ref mut types, ref constants, - ref dynamic_constants, + ref mut dynamic_constants, } = module; let mut reverse_type_map: HashMap<Type, TypeID> = types .iter() .enumerate() .map(|(idx, ty)| (ty.clone(), TypeID::new(idx))) .collect(); + let mut reverse_dynamic_constant_map: HashMap<DynamicConstant, DynamicConstantID> + = dynamic_constants + .iter() + .enumerate() + .map(|(idx, ty)| (ty.clone(), DynamicConstantID::new(idx))) + .collect(); // Step 2: run dataflow. This is an occurrence of dataflow where the flow // function performs a non-associative operation on the predecessor "out" @@ -115,6 +121,7 @@ pub fn typecheck( constants, dynamic_constants, &mut reverse_type_map, + &mut reverse_dynamic_constant_map, ) }) }) @@ -150,8 +157,9 @@ fn typeflow( functions: &Vec<Function>, types: &mut Vec<Type>, constants: &Vec<Constant>, - dynamic_constants: &Vec<DynamicConstant>, + dynamic_constants: &mut Vec<DynamicConstant>, reverse_type_map: &mut HashMap<Type, TypeID>, + reverse_dynamic_constant_map : &mut HashMap<DynamicConstant, DynamicConstantID>, ) -> TypeSemilattice { // Whenever we want to reference a specific type (for example, for the // start node), we need to get its type ID. This helper function gets the @@ -738,7 +746,7 @@ fn typeflow( // Check argument types. for (input, param_ty) in zip(inputs.iter(), callee.param_types.iter()) { if let Concrete(input_id) = input { - if input_id != param_ty { + if !types_match(types, dynamic_constants, dc_args, *param_ty, *input_id) { return Error(String::from( "Call node mismatches argument types with callee function.", )); @@ -749,7 +757,8 @@ fn typeflow( } } - Concrete(callee.return_type) + Concrete(type_subst(types, dynamic_constants, reverse_type_map, + reverse_dynamic_constant_map, dc_args, callee.return_type)) } Node::IntrinsicCall { intrinsic, args: _ } => { let num_params = match intrinsic { @@ -1104,3 +1113,197 @@ pub fn cast_compatible(src_ty: &Type, dst_ty: &Type) -> bool { // not from a floating point type to a boolean type. src_ty.is_primitive() && dst_ty.is_primitive() && !(src_ty.is_float() && dst_ty.is_bool()) } + +/* + * Determine if the given type matches the parameter type when the provided + * dynamic constants are substituted in for the dynamic constants used in the + * parameter type. + */ +fn types_match(types: &Vec<Type>, dynamic_constants: &Vec<DynamicConstant>, + dc_args : &Box<[DynamicConstantID]>, param : TypeID, input : TypeID) -> bool { + // Note that we can't just check whether the type ids are equal since them + // being equal does not mean they match when we properly substitute in the + // dynamic constant arguments + + match (&types[param.idx()], &types[input.idx()]) { + (Type::Control, Type::Control) | (Type::Boolean, Type::Boolean) + | (Type::Integer8, Type::Integer8) | (Type::Integer16, Type::Integer16) + | (Type::Integer32, Type::Integer32) | (Type::Integer64, Type::Integer64) + | (Type::UnsignedInteger8, Type::UnsignedInteger8) + | (Type::UnsignedInteger16, Type::UnsignedInteger16) + | (Type::UnsignedInteger32, Type::UnsignedInteger32) + | (Type::UnsignedInteger64, Type::UnsignedInteger64) + | (Type::Float32, Type::Float32) | (Type::Float64, Type::Float64) + => true, + (Type::Product(ps), Type::Product(is)) + | (Type::Summation(ps), Type::Summation(is)) => { + ps.len() == is.len() + && ps.iter().zip(is.iter()) + .all(|(p, i)| types_match(types, dynamic_constants, dc_args, *p, *i)) + }, + (Type::Array(p, pds), Type::Array(i, ids)) => { + types_match(types, dynamic_constants, dc_args, *p, *i) + && pds.len() == ids.len() + && pds.iter().zip(ids.iter()) + .all(|(pd, id)| dyn_consts_match(dynamic_constants, dc_args, *pd, *id)) + }, + (_, _) => false, + } +} + +/* + * Determine if the given dynamic constant matches the parameter's dynamic + * constants when the provided dynamic constants are substituted in for the + * dynamic constants used in the parameter's dynamic constant + */ +fn dyn_consts_match(dynamic_constants: &Vec<DynamicConstant>, + dc_args: &Box<[DynamicConstantID]>, param: DynamicConstantID, + input: DynamicConstantID) -> bool { + match (&dynamic_constants[param.idx()], &dynamic_constants[input.idx()]) { + (DynamicConstant::Constant(x), DynamicConstant::Constant(y)) + => x == y, + (DynamicConstant::Parameter(i), _) + => input == dc_args[*i], + (DynamicConstant::Add(pl, pr), DynamicConstant::Add(il, ir)) + | (DynamicConstant::Sub(pl, pr), DynamicConstant::Sub(il, ir)) + | (DynamicConstant::Mul(pl, pr), DynamicConstant::Mul(il, ir)) + | (DynamicConstant::Div(pl, pr), DynamicConstant::Div(il, ir)) + | (DynamicConstant::Rem(pl, pr), DynamicConstant::Rem(il, ir)) + => dyn_consts_match(dynamic_constants, dc_args, *pl, *il) + && dyn_consts_match(dynamic_constants, dc_args, *pr, *ir), + (_, _) => false, + } +} + +/* + * Substitutes the given dynamic constant arguments into the provided type and + * returns the appropriate typeID (potentially creating new types and dynamic + * constants in the process) + */ +fn type_subst(types: &mut Vec<Type>, dynamic_constants: &mut Vec<DynamicConstant>, + reverse_type_map: &mut HashMap<Type, TypeID>, + reverse_dynamic_constant_map: &mut HashMap<DynamicConstant, DynamicConstantID>, + dc_args: &Box<[DynamicConstantID]>, typ : TypeID) -> TypeID { + + fn intern_type(ty : Type, types: &mut Vec<Type>, reverse_type_map: &mut HashMap<Type, TypeID>) + -> TypeID { + if let Some(id) = reverse_type_map.get(&ty) { + *id + } else { + let id = TypeID::new(types.len()); + reverse_type_map.insert(ty.clone(), id); + types.push(ty); + id + } + } + + match &types[typ.idx()] { + Type::Control | Type::Boolean | Type::Integer8 | Type::Integer16 + | Type::Integer32 | Type::Integer64 | Type::UnsignedInteger8 + | Type::UnsignedInteger16 | Type::UnsignedInteger32 + | Type::UnsignedInteger64 | Type::Float32 | Type::Float64 + => typ, + Type::Product(ts) => { + let mut new_ts = vec![]; + for t in ts.clone().iter() { + new_ts.push(type_subst(types, dynamic_constants, reverse_type_map, + reverse_dynamic_constant_map, dc_args, *t)); + } + intern_type(Type::Product(new_ts.into()), types, reverse_type_map) + }, + Type::Summation(ts) => { + let mut new_ts = vec![]; + for t in ts.clone().iter() { + new_ts.push(type_subst(types, dynamic_constants, reverse_type_map, + reverse_dynamic_constant_map, dc_args, *t)); + } + intern_type(Type::Summation(new_ts.into()), types, reverse_type_map) + }, + Type::Array(elem, dims) => { + let ds = dims.clone(); + let new_elem = type_subst(types, dynamic_constants, reverse_type_map, + reverse_dynamic_constant_map, dc_args, *elem); + let mut new_dims = vec![]; + for d in ds.iter() { + new_dims.push(dyn_const_subst(dynamic_constants, + reverse_dynamic_constant_map, + dc_args, *d)); + } + intern_type(Type::Array(new_elem, new_dims.into()), types, reverse_type_map) + }, + } +} + +fn dyn_const_subst(dynamic_constants: &mut Vec<DynamicConstant>, + reverse_dynamic_constant_map: &mut HashMap<DynamicConstant, DynamicConstantID>, + dc_args: &Box<[DynamicConstantID]>, dyn_const : DynamicConstantID) + -> DynamicConstantID { + + fn intern_dyn_const(dc: DynamicConstant, dynamic_constants: &mut Vec<DynamicConstant>, + reverse_dynamic_constant_map: &mut HashMap<DynamicConstant, DynamicConstantID>) + -> DynamicConstantID { + if let Some(id) = reverse_dynamic_constant_map.get(&dc) { + *id + } else { + let id = DynamicConstantID::new(dynamic_constants.len()); + reverse_dynamic_constant_map.insert(dc.clone(), id); + dynamic_constants.push(dc); + id + } + } + + match &dynamic_constants[dyn_const.idx()] { + DynamicConstant::Constant(_) => dyn_const, + DynamicConstant::Parameter(i) => dc_args[*i], + DynamicConstant::Add(l, r) => { + let x = *l; + let y = *r; + let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, + dc_args, x); + let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, + dc_args, y); + intern_dyn_const(DynamicConstant::Add(sx, sy), dynamic_constants, + reverse_dynamic_constant_map) + }, + DynamicConstant::Sub(l, r) => { + let x = *l; + let y = *r; + let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, + dc_args, x); + let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, + dc_args, y); + intern_dyn_const(DynamicConstant::Sub(sx, sy), dynamic_constants, + reverse_dynamic_constant_map) + }, + DynamicConstant::Mul(l, r) => { + let x = *l; + let y = *r; + let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, + dc_args, x); + let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, + dc_args, y); + intern_dyn_const(DynamicConstant::Mul(sx, sy), dynamic_constants, + reverse_dynamic_constant_map) + }, + DynamicConstant::Div(l, r) => { + let x = *l; + let y = *r; + let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, + dc_args, x); + let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, + dc_args, y); + intern_dyn_const(DynamicConstant::Div(sx, sy), dynamic_constants, + reverse_dynamic_constant_map) + }, + DynamicConstant::Rem(l, r) => { + let x = *l; + let y = *r; + let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, + dc_args, x); + let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, + dc_args, y); + intern_dyn_const(DynamicConstant::Rem(sx, sy), dynamic_constants, + reverse_dynamic_constant_map) + }, + } +} -- GitLab