From a42eb07651ac9062fc9cfa288e519a70e6afa82e Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Thu, 13 Feb 2025 12:08:16 -0600
Subject: [PATCH] Fix inout codegen

---
 juno_frontend/src/semant.rs | 195 +++++++++++++++++++++++-------------
 1 file changed, 128 insertions(+), 67 deletions(-)

diff --git a/juno_frontend/src/semant.rs b/juno_frontend/src/semant.rs
index 86979128..b8d04035 100644
--- a/juno_frontend/src/semant.rs
+++ b/juno_frontend/src/semant.rs
@@ -723,8 +723,9 @@ fn analyze_program(
                 // We collect the list of the argument types, whether they are inout, and their
                 // unique variable number
                 let mut arg_info: Vec<(Type, bool, usize)> = vec![];
-                // We collect the list of the types and variable numbers of the inout arguments
-                let mut inouts: Vec<(Type, usize)> = vec![];
+                // We collect the list of expressions that should be returned for the inout
+                // arguments
+                let mut inouts: Vec<Expr> = vec![];
 
                 // A collection of errors we encounter processing the arguments
                 let mut errors = LinkedList::new();
@@ -747,9 +748,6 @@ fn analyze_program(
                         Ok(ty) => {
                             let var = env.uniq();
 
-                            if inout.is_some() {
-                                inouts.push((ty, var));
-                            }
                             arg_info.push((ty, inout.is_some(), var));
 
                             match process_irrefutable_pattern(
@@ -761,9 +759,13 @@ fn analyze_program(
                                 &mut stringtab,
                                 &mut env,
                                 &mut types,
+                                inout.is_some(),
                             ) {
-                                Ok(prep) => {
+                                Ok((prep, expr)) => {
                                     stmts.extend(prep);
+                                    if inout.is_some() {
+                                        inouts.push(expr.unwrap());
+                                    }
                                 }
                                 Err(mut errs) => {
                                     errors.append(&mut errs);
@@ -804,13 +806,11 @@ fn analyze_program(
                 }
 
                 // Compute the proper type accounting for the inouts (which become returns)
-                let mut inout_types = inouts.iter().map(|(t, _)| *t).collect::<Vec<_>>();
+                let mut inout_types = inouts.iter().map(|e| e.get_type()).collect::<Vec<_>>();
 
-                let inout_tuple = types.new_tuple(inout_types.clone());
+                let inout_tuple = types.new_tuple(inout_types);
                 let pure_return_type = types.new_tuple(vec![return_type, inout_tuple]);
 
-                let inout_variables = inouts.iter().map(|(_, v)| *v).collect::<Vec<_>>();
-
                 // Finally, we have a properly built environment and we can
                 // start processing the body
                 let (mut body, end_reachable) = process_stmt(
@@ -822,8 +822,7 @@ fn analyze_program(
                     &mut types,
                     false,
                     return_type,
-                    &inout_variables,
-                    &inout_types,
+                    &inouts,
                     &mut labels,
                 )?;
 
@@ -841,8 +840,7 @@ fn analyze_program(
                                         vals: vec![],
                                         typ: types.new_primitive(types::Primitive::Unit),
                                     },
-                                    &inout_variables,
-                                    &inout_types,
+                                    &inouts,
                                     &mut types,
                                 ),
                             ],
@@ -1607,8 +1605,7 @@ fn process_stmt(
     types: &mut TypeSolver,
     in_loop: bool,
     return_type: Type,
-    inout_vars: &Vec<usize>,
-    inout_types: &Vec<Type>,
+    inouts: &Vec<Expr>,
     labels: &mut StringTable,
 ) -> Result<(Stmt, bool), ErrorMessages> {
     match stmt {
@@ -1673,9 +1670,12 @@ fn process_stmt(
 
             let mut res = vec![];
             res.push(Stmt::AssignStmt { var, val });
-            res.extend(process_irrefutable_pattern(
-                pattern, false, var, typ, lexer, stringtab, env, types,
-            )?);
+            res.extend(
+                process_irrefutable_pattern(
+                    pattern, false, var, typ, lexer, stringtab, env, types, false,
+                )?
+                .0,
+            );
 
             Ok((Stmt::BlockStmt { body: res }, true))
         }
@@ -1740,9 +1740,12 @@ fn process_stmt(
 
             let mut res = vec![];
             res.push(Stmt::AssignStmt { var, val });
-            res.extend(process_irrefutable_pattern(
-                pattern, false, var, typ, lexer, stringtab, env, types,
-            )?);
+            res.extend(
+                process_irrefutable_pattern(
+                    pattern, false, var, typ, lexer, stringtab, env, types, false,
+                )?
+                .0,
+            );
 
             Ok((Stmt::BlockStmt { body: res }, true))
         }
@@ -1927,8 +1930,7 @@ fn process_stmt(
                 types,
                 in_loop,
                 return_type,
-                inout_vars,
-                inout_types,
+                inouts,
                 labels,
             );
             env.close_scope();
@@ -1945,8 +1947,7 @@ fn process_stmt(
                     types,
                     in_loop,
                     return_type,
-                    inout_vars,
-                    inout_types,
+                    inouts,
                     labels,
                 )
                 .map(|(s, b)| (Some(s), b)),
@@ -2103,8 +2104,7 @@ fn process_stmt(
                 types,
                 true,
                 return_type,
-                inout_vars,
-                inout_types,
+                inouts,
                 labels,
             )?;
 
@@ -2199,8 +2199,7 @@ fn process_stmt(
                 types,
                 true,
                 return_type,
-                inout_vars,
-                inout_types,
+                inouts,
                 labels,
             );
             env.close_scope();
@@ -2255,10 +2254,7 @@ fn process_stmt(
 
             // We return a tuple of the return value and of the inout variables
             // Statements after a return are never reachable
-            Ok((
-                generate_return(return_val, inout_vars, inout_types, types),
-                false,
-            ))
+            Ok((generate_return(return_val, inouts, types), false))
         }
         parser::Stmt::BreakStmt { span } => {
             if !in_loop {
@@ -2307,8 +2303,7 @@ fn process_stmt(
                     types,
                     in_loop,
                     return_type,
-                    inout_vars,
-                    inout_types,
+                    inouts,
                     labels,
                 ) {
                     Err(mut errs) => {
@@ -2374,8 +2369,7 @@ fn process_stmt(
                 types,
                 in_loop,
                 return_type,
-                inout_vars,
-                inout_types,
+                inouts,
                 labels,
             )?;
             Ok((
@@ -5014,24 +5008,12 @@ fn process_expr(
     }
 }
 
-fn generate_return(
-    expr: Expr,
-    vars: &Vec<usize>,
-    var_types: &Vec<Type>,
-    types: &mut TypeSolver,
-) -> Stmt {
-    let var_exprs = vars
-        .iter()
-        .zip(var_types.iter())
-        .map(|(var, typ)| Expr::Variable {
-            var: *var,
-            typ: *typ,
-        })
-        .collect::<Vec<_>>();
-
-    let inout_type = types.new_tuple(var_types.clone());
+fn generate_return(expr: Expr, inouts: &Vec<Expr>, types: &mut TypeSolver) -> Stmt {
+    let inout_types = inouts.iter().map(|e| e.get_type()).collect();
+    let inout_type = types.new_tuple(inout_types);
+
     let inout_vals = Expr::Tuple {
-        vals: var_exprs,
+        vals: inouts.clone(),
         typ: inout_type,
     };
 
@@ -5068,6 +5050,9 @@ fn convert_primitive(prim: parser::Primitive) -> types::Primitive {
 // Processes an irrefutable pattern by extracting the pieces from the given variable which has the
 // given type. Adds any variables in that pattern to the environment and returns a list of
 // statements that handle the pattern
+// If the return_expr variable is set to true, it also returns an expression that reconstructs the
+// value of the variable from the fields created by the pattern, this is used for inout variables
+// where a pattern in its binding creates variables whose final values need to be returned
 fn process_irrefutable_pattern(
     pat: parser::Pattern,
     is_const: bool,
@@ -5077,9 +5062,17 @@ fn process_irrefutable_pattern(
     stringtab: &mut StringTable,
     env: &mut Env<usize, Entity>,
     types: &mut TypeSolver,
-) -> Result<Vec<Stmt>, ErrorMessages> {
+    return_expr: bool,
+) -> Result<(Vec<Stmt>, Option<Expr>), ErrorMessages> {
     match pat {
-        Pattern::Wildcard { .. } => Ok(vec![]),
+        Pattern::Wildcard { .. } => Ok((
+            vec![],
+            if return_expr {
+                Some(Expr::Variable { var, typ })
+            } else {
+                None
+            },
+        )),
         Pattern::Variable { span, name } => {
             if name.len() != 1 {
                 return Err(singleton_error(ErrorMessage::SemanticError(
@@ -5099,10 +5092,17 @@ fn process_irrefutable_pattern(
                 },
             );
 
-            Ok(vec![Stmt::AssignStmt {
-                var: variable,
-                val: Expr::Variable { var, typ },
-            }])
+            Ok((
+                vec![Stmt::AssignStmt {
+                    var: variable,
+                    val: Expr::Variable { var, typ },
+                }],
+                if return_expr {
+                    Some(Expr::Variable { var: variable, typ })
+                } else {
+                    None
+                },
+            ))
         }
         Pattern::TuplePattern { span, pats } => {
             let Some(fields) = types.get_fields(typ) else {
@@ -5128,6 +5128,7 @@ fn process_irrefutable_pattern(
             }
 
             let mut res = vec![];
+            let mut exprs = vec![];
             let mut errors = LinkedList::new();
 
             for (idx, (pat, field)) in pats.into_iter().zip(fields.into_iter()).enumerate() {
@@ -5143,15 +5144,35 @@ fn process_irrefutable_pattern(
                 });
 
                 match process_irrefutable_pattern(
-                    pat, is_const, variable, field, lexer, stringtab, env, types,
+                    pat,
+                    is_const,
+                    variable,
+                    field,
+                    lexer,
+                    stringtab,
+                    env,
+                    types,
+                    return_expr,
                 ) {
-                    Ok(stmts) => res.extend(stmts),
+                    Ok((stmts, expr)) => {
+                        res.extend(stmts);
+                        if return_expr {
+                            exprs.push(expr.unwrap());
+                        }
+                    }
                     Err(errs) => errors.extend(errs),
                 }
             }
 
             if errors.is_empty() {
-                Ok(res)
+                Ok((
+                    res,
+                    if return_expr {
+                        Some(Expr::Tuple { vals: exprs, typ })
+                    } else {
+                        None
+                    },
+                ))
             } else {
                 Err(errors)
             }
@@ -5227,6 +5248,7 @@ fn process_irrefutable_pattern(
                         .into_iter()
                         .collect::<HashSet<_>>();
                     let mut res = vec![];
+                    let mut exprs = vec![None; unused_fields.len()];
                     let mut errors = LinkedList::new();
 
                     for (field_name, pat) in pats {
@@ -5263,10 +5285,22 @@ fn process_irrefutable_pattern(
                                         },
                                     });
                                     match process_irrefutable_pattern(
-                                        pat, is_const, variable, field_typ, lexer, stringtab, env,
+                                        pat,
+                                        is_const,
+                                        variable,
+                                        field_typ,
+                                        lexer,
+                                        stringtab,
+                                        env,
                                         types,
+                                        return_expr,
                                     ) {
-                                        Ok(stmts) => res.extend(stmts),
+                                        Ok((stmts, expr)) => {
+                                            res.extend(stmts);
+                                            if return_expr {
+                                                exprs[idx] = expr;
+                                            }
+                                        }
                                         Err(errs) => errors.extend(errs),
                                     }
                                 }
@@ -5286,12 +5320,39 @@ fn process_irrefutable_pattern(
                                     .join(", ")
                             ),
                         ));
+                    } else if return_expr {
+                        for field in unused_fields {
+                            let (idx, field_typ) = types.get_field(struct_typ, field).unwrap();
+                            let variable = env.uniq();
+                            res.push(Stmt::AssignStmt {
+                                var: variable,
+                                val: Expr::Read {
+                                    index: vec![Index::Field(idx)],
+                                    val: Box::new(Expr::Variable { var, typ }),
+                                    typ: field_typ,
+                                },
+                            });
+                            exprs[idx] = Some(Expr::Variable {
+                                var: variable,
+                                typ: field_typ,
+                            });
+                        }
                     }
 
                     if !errors.is_empty() {
                         Err(errors)
                     } else {
-                        Ok(res)
+                        Ok((
+                            res,
+                            if return_expr {
+                                Some(Expr::Tuple {
+                                    vals: exprs.into_iter().map(|v| v.unwrap()).collect(),
+                                    typ,
+                                })
+                            } else {
+                                None
+                            },
+                        ))
                     }
                 }
             }
-- 
GitLab