From e67655e6c6d80c0fdee0c00fe49ba4f1aaf7b0ca Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Mon, 14 Oct 2024 20:31:30 -0500
Subject: [PATCH] Correct type-checking of calls

---
 hercules_ir/src/typecheck.rs | 211 ++++++++++++++++++++++++++++++++++-
 1 file changed, 207 insertions(+), 4 deletions(-)

diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs
index d2b97e53..505b6d24 100644
--- a/hercules_ir/src/typecheck.rs
+++ b/hercules_ir/src/typecheck.rs
@@ -92,13 +92,19 @@ pub fn typecheck(
         ref functions,
         ref mut types,
         ref constants,
-        ref dynamic_constants,
+        ref mut dynamic_constants,
     } = module;
     let mut reverse_type_map: HashMap<Type, TypeID> = types
         .iter()
         .enumerate()
         .map(|(idx, ty)| (ty.clone(), TypeID::new(idx)))
         .collect();
+    let mut reverse_dynamic_constant_map: HashMap<DynamicConstant, DynamicConstantID>
+        = dynamic_constants
+          .iter()
+          .enumerate()
+          .map(|(idx, ty)| (ty.clone(), DynamicConstantID::new(idx)))
+          .collect();
 
     // Step 2: run dataflow. This is an occurrence of dataflow where the flow
     // function performs a non-associative operation on the predecessor "out"
@@ -115,6 +121,7 @@ pub fn typecheck(
                     constants,
                     dynamic_constants,
                     &mut reverse_type_map,
+                    &mut reverse_dynamic_constant_map,
                 )
             })
         })
@@ -150,8 +157,9 @@ fn typeflow(
     functions: &Vec<Function>,
     types: &mut Vec<Type>,
     constants: &Vec<Constant>,
-    dynamic_constants: &Vec<DynamicConstant>,
+    dynamic_constants: &mut Vec<DynamicConstant>,
     reverse_type_map: &mut HashMap<Type, TypeID>,
+    reverse_dynamic_constant_map : &mut HashMap<DynamicConstant, DynamicConstantID>,
 ) -> TypeSemilattice {
     // Whenever we want to reference a specific type (for example, for the
     // start node), we need to get its type ID. This helper function gets the
@@ -738,7 +746,7 @@ fn typeflow(
             // Check argument types.
             for (input, param_ty) in zip(inputs.iter(), callee.param_types.iter()) {
                 if let Concrete(input_id) = input {
-                    if input_id != param_ty {
+                    if !types_match(types, dynamic_constants, dc_args, *param_ty, *input_id) {
                         return Error(String::from(
                             "Call node mismatches argument types with callee function.",
                         ));
@@ -749,7 +757,8 @@ fn typeflow(
                 }
             }
 
-            Concrete(callee.return_type)
+            Concrete(type_subst(types, dynamic_constants, reverse_type_map,
+                                reverse_dynamic_constant_map, dc_args, callee.return_type))
         }
         Node::IntrinsicCall { intrinsic, args: _ } => {
             let num_params = match intrinsic {
@@ -1104,3 +1113,197 @@ pub fn cast_compatible(src_ty: &Type, dst_ty: &Type) -> bool {
     // not from a floating point type to a boolean type.
     src_ty.is_primitive() && dst_ty.is_primitive() && !(src_ty.is_float() && dst_ty.is_bool())
 }
+
+/*
+ * Determine if the given type matches the parameter type when the provided
+ * dynamic constants are substituted in for the dynamic constants used in the
+ * parameter type.
+ */
+fn types_match(types: &Vec<Type>, dynamic_constants: &Vec<DynamicConstant>,
+                   dc_args : &Box<[DynamicConstantID]>, param : TypeID, input : TypeID) -> bool {
+    // Note that we can't just check whether the type ids are equal since them
+    // being equal does not mean they match when we properly substitute in the
+    // dynamic constant arguments
+
+    match (&types[param.idx()], &types[input.idx()]) {
+        (Type::Control, Type::Control) | (Type::Boolean, Type::Boolean)
+        | (Type::Integer8, Type::Integer8) | (Type::Integer16, Type::Integer16)
+        | (Type::Integer32, Type::Integer32) | (Type::Integer64, Type::Integer64)
+        | (Type::UnsignedInteger8, Type::UnsignedInteger8)
+        | (Type::UnsignedInteger16, Type::UnsignedInteger16)
+        | (Type::UnsignedInteger32, Type::UnsignedInteger32)
+        | (Type::UnsignedInteger64, Type::UnsignedInteger64)
+        | (Type::Float32, Type::Float32) | (Type::Float64, Type::Float64)
+            => true,
+        (Type::Product(ps), Type::Product(is))
+        | (Type::Summation(ps), Type::Summation(is)) => {
+            ps.len() == is.len()
+            && ps.iter().zip(is.iter())
+                 .all(|(p, i)| types_match(types, dynamic_constants, dc_args, *p, *i))
+        },
+        (Type::Array(p, pds), Type::Array(i, ids)) => {
+            types_match(types, dynamic_constants, dc_args, *p, *i)
+            && pds.len() == ids.len()
+            && pds.iter().zip(ids.iter())
+                  .all(|(pd, id)| dyn_consts_match(dynamic_constants, dc_args, *pd, *id))
+        },
+        (_, _) => false,
+    }
+}
+
+/*
+ * Determine if the given dynamic constant matches the parameter's dynamic
+ * constants when the provided dynamic constants are substituted in for the
+ * dynamic constants used in the parameter's dynamic constant
+ */
+fn dyn_consts_match(dynamic_constants: &Vec<DynamicConstant>,
+                    dc_args: &Box<[DynamicConstantID]>, param: DynamicConstantID,
+                    input: DynamicConstantID) -> bool {
+    match (&dynamic_constants[param.idx()], &dynamic_constants[input.idx()]) {
+        (DynamicConstant::Constant(x), DynamicConstant::Constant(y))
+            => x == y,
+        (DynamicConstant::Parameter(i), _)
+            => input == dc_args[*i],
+        (DynamicConstant::Add(pl, pr), DynamicConstant::Add(il, ir))
+        | (DynamicConstant::Sub(pl, pr), DynamicConstant::Sub(il, ir))
+        | (DynamicConstant::Mul(pl, pr), DynamicConstant::Mul(il, ir))
+        | (DynamicConstant::Div(pl, pr), DynamicConstant::Div(il, ir))
+        | (DynamicConstant::Rem(pl, pr), DynamicConstant::Rem(il, ir))
+            => dyn_consts_match(dynamic_constants, dc_args, *pl, *il)
+            && dyn_consts_match(dynamic_constants, dc_args, *pr, *ir),
+        (_, _) => false,
+    }
+}
+
+/*
+ * Substitutes the given dynamic constant arguments into the provided type and
+ * returns the appropriate typeID (potentially creating new types and dynamic
+ * constants in the process)
+ */
+fn type_subst(types: &mut Vec<Type>, dynamic_constants: &mut Vec<DynamicConstant>,
+              reverse_type_map: &mut HashMap<Type, TypeID>,
+              reverse_dynamic_constant_map: &mut HashMap<DynamicConstant, DynamicConstantID>,
+              dc_args: &Box<[DynamicConstantID]>, typ : TypeID) -> TypeID {
+
+    fn intern_type(ty : Type, types: &mut Vec<Type>, reverse_type_map: &mut HashMap<Type, TypeID>)
+        -> TypeID {
+        if let Some(id) = reverse_type_map.get(&ty) {
+            *id
+        } else {
+            let id = TypeID::new(types.len());
+            reverse_type_map.insert(ty.clone(), id);
+            types.push(ty);
+            id
+        }
+    }
+
+    match &types[typ.idx()] {
+        Type::Control | Type::Boolean | Type::Integer8 | Type::Integer16
+            | Type::Integer32 | Type::Integer64 | Type::UnsignedInteger8
+            | Type::UnsignedInteger16 | Type::UnsignedInteger32
+            | Type::UnsignedInteger64 | Type::Float32 | Type::Float64
+        => typ,
+        Type::Product(ts) => {
+            let mut new_ts = vec![];
+            for t in ts.clone().iter() {
+                new_ts.push(type_subst(types, dynamic_constants, reverse_type_map,
+                                       reverse_dynamic_constant_map, dc_args, *t));
+            }
+            intern_type(Type::Product(new_ts.into()), types, reverse_type_map)
+        },
+        Type::Summation(ts) => {
+            let mut new_ts = vec![];
+            for t in ts.clone().iter() {
+                new_ts.push(type_subst(types, dynamic_constants, reverse_type_map,
+                                       reverse_dynamic_constant_map, dc_args, *t));
+            }
+            intern_type(Type::Summation(new_ts.into()), types, reverse_type_map)
+        },
+        Type::Array(elem, dims) => {
+            let ds = dims.clone();
+            let new_elem = type_subst(types, dynamic_constants, reverse_type_map,
+                                      reverse_dynamic_constant_map, dc_args, *elem);
+            let mut new_dims = vec![];
+            for d in ds.iter() {
+                new_dims.push(dyn_const_subst(dynamic_constants,
+                                              reverse_dynamic_constant_map,
+                                              dc_args, *d));
+            }
+            intern_type(Type::Array(new_elem, new_dims.into()), types, reverse_type_map)
+        },
+    }
+}
+
+fn dyn_const_subst(dynamic_constants: &mut Vec<DynamicConstant>,
+                   reverse_dynamic_constant_map: &mut HashMap<DynamicConstant, DynamicConstantID>,
+                   dc_args: &Box<[DynamicConstantID]>, dyn_const : DynamicConstantID)
+    -> DynamicConstantID {
+
+    fn intern_dyn_const(dc: DynamicConstant, dynamic_constants: &mut Vec<DynamicConstant>,
+                reverse_dynamic_constant_map: &mut HashMap<DynamicConstant, DynamicConstantID>)
+        -> DynamicConstantID {
+        if let Some(id) = reverse_dynamic_constant_map.get(&dc) {
+            *id
+        } else {
+            let id = DynamicConstantID::new(dynamic_constants.len());
+            reverse_dynamic_constant_map.insert(dc.clone(), id);
+            dynamic_constants.push(dc);
+            id
+        }
+    }
+
+    match &dynamic_constants[dyn_const.idx()] {
+        DynamicConstant::Constant(_) => dyn_const,
+        DynamicConstant::Parameter(i) => dc_args[*i],
+        DynamicConstant::Add(l, r) => {
+            let x = *l;
+            let y = *r;
+            let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map,
+                                     dc_args, x);
+            let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map,
+                                     dc_args, y);
+            intern_dyn_const(DynamicConstant::Add(sx, sy), dynamic_constants,
+                             reverse_dynamic_constant_map)
+        },
+        DynamicConstant::Sub(l, r) => {
+            let x = *l;
+            let y = *r;
+            let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map,
+                                     dc_args, x);
+            let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map,
+                                     dc_args, y);
+            intern_dyn_const(DynamicConstant::Sub(sx, sy), dynamic_constants,
+                             reverse_dynamic_constant_map)
+        },
+        DynamicConstant::Mul(l, r) => {
+            let x = *l;
+            let y = *r;
+            let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map,
+                                     dc_args, x);
+            let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map,
+                                     dc_args, y);
+            intern_dyn_const(DynamicConstant::Mul(sx, sy), dynamic_constants,
+                             reverse_dynamic_constant_map)
+        },
+        DynamicConstant::Div(l, r) => {
+            let x = *l;
+            let y = *r;
+            let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map,
+                                     dc_args, x);
+            let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map,
+                                     dc_args, y);
+            intern_dyn_const(DynamicConstant::Div(sx, sy), dynamic_constants,
+                             reverse_dynamic_constant_map)
+        },
+        DynamicConstant::Rem(l, r) => {
+            let x = *l;
+            let y = *r;
+            let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map,
+                                     dc_args, x);
+            let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map,
+                                     dc_args, y);
+            intern_dyn_const(DynamicConstant::Rem(sx, sy), dynamic_constants,
+                             reverse_dynamic_constant_map)
+        },
+    }
+}
-- 
GitLab