From bf0d479a4027ab8938179a41a8a13353da9b4488 Mon Sep 17 00:00:00 2001
From: Aaron Councilman <aaronjc4@illinois.edu>
Date: Tue, 18 Feb 2025 15:08:38 -0600
Subject: [PATCH] Front-end changes for multi-return

---
 juno_frontend/src/codegen.rs         |  57 +++--
 juno_frontend/src/labeled_builder.rs |   4 +-
 juno_frontend/src/lang.y             |  49 ++--
 juno_frontend/src/semant.rs          | 370 ++++++++++++++-------------
 juno_frontend/src/ssa.rs             |   4 +-
 juno_frontend/src/types.rs           |  63 ++---
 6 files changed, 295 insertions(+), 252 deletions(-)

diff --git a/juno_frontend/src/codegen.rs b/juno_frontend/src/codegen.rs
index a3197041..4e6f89e8 100644
--- a/juno_frontend/src/codegen.rs
+++ b/juno_frontend/src/codegen.rs
@@ -118,15 +118,18 @@ impl CodeGenerator<'_> {
                     param_types.push(solver_inst.lower_type(&mut self.builder.builder, *ty));
                 }
 
-                let return_type =
-                    solver_inst.lower_type(&mut self.builder.builder, func.return_type);
+                let return_types =
+                    func.return_types
+                        .iter()
+                        .map(|t| solver_inst.lower_type(&mut self.builder.builder, *t))
+                        .collect::<Vec<_>>();
 
                 let (func_id, entry) = self
                     .builder
                     .create_function(
                         &name,
                         param_types,
-                        return_type,
+                        return_types,
                         func.num_dyn_consts as u32,
                         func.entry,
                     )
@@ -264,10 +267,16 @@ impl CodeGenerator<'_> {
                 // loop
                 Some(block_exit)
             }
-            Stmt::ReturnStmt { expr } => {
-                let (val_ret, block_ret) = self.codegen_expr(expr, types, ssa, cur_block);
+            Stmt::ReturnStmt { exprs } => {
+                let mut vals = vec![];
+                let mut block = cur_block;
+                for expr in exprs {
+                    let (val_ret, block_ret) = self.codegen_expr(expr, types, ssa, block);
+                    vals.push(val_ret);
+                    block = block_ret;
+                }
                 let mut return_node = self.builder.allocate_node();
-                return_node.build_return(block_ret, val_ret);
+                return_node.build_return(block, vals);
                 self.builder.add_node(return_node);
                 None
             }
@@ -482,6 +491,7 @@ impl CodeGenerator<'_> {
                 ty_args,
                 dyn_consts,
                 args,
+                num_returns, // number of non-inout returns (which are first)
                 ..
             } => {
                 // We start by lowering the type arguments to TypeIDs
@@ -541,30 +551,27 @@ impl CodeGenerator<'_> {
 
                 // Read each of the "inout values" and perform the SSA update
                 let has_inouts = !inouts.is_empty();
-                // TODO: We should omit unit returns, if we do so the + 1 below is not needed
                 for (idx, var) in inouts.into_iter().enumerate() {
-                    let index = self.builder.builder.create_field_index(idx + 1);
-                    let mut read = self.builder.allocate_node();
-                    let read_id = read.id();
-                    read.build_read(call_id, vec![index].into());
-                    self.builder.add_node(read);
+                    let index = self.builder.builder.create_field_index(num_returns + idx);
+                    let mut proj = self.builder.allocate_node();
+                    let proj_id = proj.id();
+                    proj.build_data_projection(call_id, index);
+                    self.builder.add_node(proj);
 
-                    ssa.write_variable(var, block, read_id);
+                    ssa.write_variable(var, block, proj_id);
                 }
 
-                // Read the "actual return" value and return it
-                let result = if !has_inouts {
-                    call_id
-                } else {
-                    let value_index = self.builder.builder.create_field_index(0);
-                    let mut read = self.builder.allocate_node();
-                    let read_id = read.id();
-                    read.build_read(call_id, vec![value_index].into());
-                    self.builder.add_node(read);
-                    read_id
-                };
+                (call_id, block)
+            }
+            Expr::CallExtract { call, index, .. } => {
+                let (call, block) = self.codegen_expr(call, types, ssa, cur_block);
+
+                let mut proj = self.builder.allocate_node();
+                let proj_id = proj.id();
+                proj.build_data_projection(call, index);
+                self.builder.add_node(proj);
 
-                (result, block)
+                (proj_id, block)
             }
             Expr::Intrinsic {
                 id,
diff --git a/juno_frontend/src/labeled_builder.rs b/juno_frontend/src/labeled_builder.rs
index 15bed6c2..869485e9 100644
--- a/juno_frontend/src/labeled_builder.rs
+++ b/juno_frontend/src/labeled_builder.rs
@@ -33,14 +33,14 @@ impl<'a> LabeledBuilder<'a> {
         &mut self,
         name: &str,
         param_types: Vec<TypeID>,
-        return_type: TypeID,
+        return_types: Vec<TypeID>,
         num_dynamic_constants: u32,
         entry: bool,
     ) -> Result<(FunctionID, NodeID), String> {
         let (func, entry) = self.builder.create_function(
             name,
             param_types,
-            return_type,
+            return_types,
             num_dynamic_constants,
             entry,
         )?;
diff --git a/juno_frontend/src/lang.y b/juno_frontend/src/lang.y
index be9161aa..b9efe1fa 100644
--- a/juno_frontend/src/lang.y
+++ b/juno_frontend/src/lang.y
@@ -167,17 +167,17 @@ ConstDecl -> Result<Top, ()>
 FuncDecl -> Result<Top, ()>
   : PubOption 'fn' 'ID' TypeVars '(' Arguments ')' Stmts
       { Ok(Top::FuncDecl{ span : $span, public : $1?, attr : None, name : span_of_tok($3)?,
-                          ty_vars : $4?, args : $6?, ty : None, body : $8? }) }
+                          ty_vars : $4?, args : $6?, rets: vec![], body : $8? }) }
   | 'FUNC_ATTR' PubOption 'fn' 'ID' TypeVars '(' Arguments ')' Stmts
       { Ok(Top::FuncDecl{ span : $span, public : $2?, attr : Some(span_of_tok($1)?),
-                          name : span_of_tok($4)?, ty_vars : $5?, args : $7?, ty : None,
+                          name : span_of_tok($4)?, ty_vars : $5?, args : $7?, rets: vec![],
                           body : $9? }) }
-  | PubOption 'fn' 'ID' TypeVars '(' Arguments ')' '->' Type Stmts
+  | PubOption 'fn' 'ID' TypeVars '(' Arguments ')' '->' Types Stmts
       { Ok(Top::FuncDecl{ span : $span, public : $1?, attr : None, name : span_of_tok($3)?,
-                          ty_vars : $4?, args : $6?, ty : Some($9?), body : $10? }) }
-  | 'FUNC_ATTR' PubOption 'fn' 'ID' TypeVars '(' Arguments ')' '->' Type Stmts
+                          ty_vars : $4?, args : $6?, rets: $9?, body : $10? }) }
+  | 'FUNC_ATTR' PubOption 'fn' 'ID' TypeVars '(' Arguments ')' '->' Types Stmts
       { Ok(Top::FuncDecl{ span : $span, public : $2?, attr : Some(span_of_tok($1)?),
-                          name : span_of_tok($4)?, ty_vars : $5?, args : $7?, ty : Some($10?),
+                          name : span_of_tok($4)?, ty_vars : $5?, args : $7?, rets: $10?,
                           body : $11? }) }
   ;
 Arguments -> Result<Vec<(Option<Span>, VarBind)>, ()>
@@ -198,6 +198,18 @@ VarBind -> Result<VarBind, ()>
   | Pattern ':' Type { Ok(VarBind{ span : $span, pattern : $1?, typ : Some($3?) }) }
   ;
 
+LetBind -> Result<LetBind, ()>
+  : VarBind { 
+      let VarBind { span, pattern, typ } = $1?;
+      Ok(LetBind::Single { span, pattern, typ }) 
+  }
+  | PatternsCommaS ',' Pattern {
+      let mut pats = $1?;
+      pats.push($3?);
+      Ok(LetBind::Multi { span: $span, patterns: pats })
+  }
+  ;
+
 Pattern -> Result<Pattern, ()>
   : '_'                   { Ok(Pattern::Wildcard { span : $span }) }
   | IntLit                { let (span, base) = $1?;
@@ -240,9 +252,9 @@ StructPatterns -> Result<(VecDeque<(Id, Pattern)>, bool), ()>
   ;
 
 Stmt -> Result<Stmt, ()>
-  : 'let' VarBind ';'
+  : 'let' LetBind ';'
       { Ok(Stmt::LetStmt{ span : $span, var : $2?, init : None }) }
-  | 'let' VarBind '=' Expr ';'
+  | 'let' LetBind '=' Expr ';'
       { Ok(Stmt::LetStmt{ span : $span, var : $2?, init : Some($4?) }) }
   | 'const' VarBind ';'
       { Ok(Stmt::ConstStmt{ span : $span, var : $2?, init : None }) }
@@ -305,10 +317,8 @@ Stmt -> Result<Stmt, ()>
                           inclusive: true, step: None, body: Box::new($8?) }) }
   | 'while' NonStructExpr Stmts
       { Ok(Stmt::WhileStmt{ span : $span, cond : $2?, body : Box::new($3?) }) }
-  | 'return' ';'
-      { Ok(Stmt::ReturnStmt{ span : $span, expr : None }) }
-  | 'return' Expr ';'
-      { Ok(Stmt::ReturnStmt{ span : $span, expr : Some($2?)}) }
+  | 'return' Exprs ';'
+      { Ok(Stmt::ReturnStmt{ span : $span, vals: $2?}) }
   | 'break' ';'
       { Ok(Stmt::BreakStmt{ span : $span }) }
   | 'continue' ';'
@@ -659,6 +669,15 @@ pub struct VarBind { pub span : Span, pub pattern : Pattern, pub typ : Option<Ty
 #[derive(Debug)]
 pub struct Case { pub span : Span, pub pat : Vec<Pattern>, pub body : Stmt }
 
+// Let bindings are different from other bindings because they can be used to
+// destruct multi-return function values, and so can actually contain multiple
+// patterns
+#[derive(Debug)]
+pub enum LetBind {
+  Single { span: Span, pattern: Pattern, typ: Option<Type> },
+  Multi  { span: Span, patterns: Vec<Pattern> },
+}
+
 #[derive(Debug)]
 pub enum Top {
   Import    { span : Span, name : ImportName },
@@ -666,7 +685,7 @@ pub enum Top {
   ConstDecl { span : Span, public : bool, name : Id, ty : Option<Type>, body : Expr },
   FuncDecl  { span : Span, public : bool, attr : Option<Span>, name : Id, ty_vars : Vec<TypeVar>,
               args : Vec<(Option<Span>, VarBind)>, // option is for inout
-              ty : Option<Type>, body : Stmt },
+              rets : Vec<Type>, body : Stmt },
   ModDecl   { span : Span, public : bool, name : Id, body : Vec<Top> },
 }
 
@@ -688,7 +707,7 @@ pub enum Type {
 
 #[derive(Debug)]
 pub enum Stmt {
-  LetStmt    { span : Span, var : VarBind, init : Option<Expr> },
+  LetStmt    { span : Span, var : LetBind, init : Option<Expr> },
   ConstStmt  { span : Span, var : VarBind, init : Option<Expr> },
   AssignStmt { span : Span, lhs : LExpr, assign : AssignOp, assign_span : Span, rhs : Expr },
   IfStmt     { span : Span, cond : Expr, thn : Box<Stmt>, els : Option<Box<Stmt>> },
@@ -697,7 +716,7 @@ pub enum Stmt {
   ForStmt    { span : Span, var : VarBind, init : Expr, bound : Expr,
                inclusive: bool, step : Option<(bool, Span, IntBase)>, body : Box<Stmt> },
   WhileStmt  { span : Span, cond : Expr, body : Box<Stmt> },
-  ReturnStmt { span : Span, expr : Option<Expr> },
+  ReturnStmt { span : Span, vals : Vec<Expr> },
   BreakStmt  { span : Span },
   ContinueStmt { span : Span },
   BlockStmt    { span : Span, body : Vec<Stmt> },
diff --git a/juno_frontend/src/semant.rs b/juno_frontend/src/semant.rs
index bfd4cf7f..ae696b4f 100644
--- a/juno_frontend/src/semant.rs
+++ b/juno_frontend/src/semant.rs
@@ -46,7 +46,7 @@ enum Entity {
         index: usize,
         type_args: Vec<parser::Kind>,
         args: Vec<(types::Type, bool)>,
-        return_type: types::Type,
+        return_types: Vec<types::Type>,
     },
 }
 
@@ -202,7 +202,7 @@ pub struct Function {
     pub num_dyn_consts: usize,
     pub num_type_args: usize,
     pub arguments: Vec<(usize, Type)>,
-    pub return_type: Type,
+    pub return_types: Vec<Type>,
     pub body: Stmt,
     pub entry: bool,
 }
@@ -236,7 +236,7 @@ pub enum Stmt {
         body: Box<Stmt>,
     },
     ReturnStmt {
-        expr: Expr,
+        exprs: Vec<Expr>,
     },
     BreakStmt {},
     ContinueStmt {},
@@ -328,6 +328,14 @@ pub enum Expr {
         ty_args: Vec<Type>,
         dyn_consts: Vec<DynConst>,
         args: Vec<Either<Expr, usize>>,
+        // Include the number of Juno returns (i.e. non-inouts) for codegen
+        num_returns: usize,
+        typ: Type,
+    },
+    // A projection from a call
+    CallExtract {
+        call: Box<Expr>,
+        index: usize,
         typ: Type,
     },
     Intrinsic {
@@ -415,57 +423,20 @@ fn convert_binary_op(op: parser::BinaryOp) -> BinaryOp {
 impl Expr {
     pub fn get_type(&self) -> Type {
         match self {
-            Expr::Variable { var: _, typ }
-            | Expr::DynConst { val: _, typ }
-            | Expr::Read {
-                index: _,
-                val: _,
-                typ,
-            }
-            | Expr::Write {
-                index: _,
-                val: _,
-                rep: _,
-                typ,
-            }
-            | Expr::Tuple { vals: _, typ }
-            | Expr::Union {
-                tag: _,
-                val: _,
-                typ,
-            }
-            | Expr::Constant { val: _, typ }
-            | Expr::UnaryExp {
-                op: _,
-                expr: _,
-                typ,
-            }
-            | Expr::BinaryExp {
-                op: _,
-                lhs: _,
-                rhs: _,
-                typ,
-            }
-            | Expr::CastExpr { expr: _, typ }
-            | Expr::CondExpr {
-                cond: _,
-                thn: _,
-                els: _,
-                typ,
-            }
-            | Expr::CallExpr {
-                func: _,
-                ty_args: _,
-                dyn_consts: _,
-                args: _,
-                typ,
-            }
-            | Expr::Intrinsic {
-                id: _,
-                ty_args: _,
-                args: _,
-                typ,
-            }
+            Expr::Variable { typ, .. }
+            | Expr::DynConst { typ, .. }
+            | Expr::Read { typ, .. }
+            | Expr::Write { typ, .. }
+            | Expr::Tuple { typ, .. }
+            | Expr::Union { typ, .. }
+            | Expr::Constant { typ, .. }
+            | Expr::UnaryExp { typ, .. }
+            | Expr::BinaryExp { typ, .. }
+            | Expr::CastExpr { typ, .. }
+            | Expr::CondExpr { typ, .. }
+            | Expr::CallExpr { typ, .. }
+            | Expr::CallExtract { typ, .. }
+            | Expr::Intrinsic { typ, .. }
             | Expr::Zero { typ } => *typ,
         }
     }
@@ -650,7 +621,7 @@ fn analyze_program(
                 name,
                 ty_vars,
                 args,
-                ty,
+                rets,
                 body,
             } => {
                 // TODO: Handle public
@@ -778,44 +749,36 @@ fn analyze_program(
                     }
                 }
 
-                let return_type = {
-                    // A missing return type is implicitly void
-                    let ty = ty.unwrap_or(parser::Type::PrimType {
-                        span: span,
-                        typ: parser::Primitive::Void,
-                    });
-                    match process_type(
-                        ty,
-                        num_dyn_const,
-                        lexer,
-                        &mut stringtab,
-                        &env,
-                        &mut types,
-                        true,
-                    ) {
-                        Ok(ty) => ty,
-                        Err(mut errs) => {
-                            errors.append(&mut errs);
-                            types.new_primitive(types::Primitive::Unit)
+                let return_types = rets
+                    .into_iter()
+                    .map(|ty|
+                        match process_type(
+                            ty,
+                            num_dyn_const,
+                            lexer,
+                            &mut stringtab,
+                            &env,
+                            &mut types,
+                            true,
+                        ) {
+                            Ok(ty) => ty,
+                            Err(mut errs) => {
+                                errors.append(&mut errs);
+                                // Type we return doesn't matter, error will be propagated upwards
+                                // next, but need to return something
+                                types.new_primitive(types::Primitive::Unit)
+                            }
                         }
-                    }
-                };
+                    )
+                    .collect::<Vec<_>>();
 
                 if !errors.is_empty() {
                     return Err(errors);
                 }
 
                 // Compute the proper type accounting for the inouts (which become returns)
-                let mut inout_types = inouts.iter().map(|e| e.get_type()).collect::<Vec<_>>();
-
-                let mut return_types = vec![return_type];
-                return_types.extend(inout_types);
-                // TODO: Ideally we would omit unit returns
-                let pure_return_type = if return_types.len() == 1 {
-                    return_types.pop().unwrap()
-                } else {
-                    types.new_tuple(return_types)
-                };
+                let inout_types = inouts.iter().map(|e| e.get_type()).collect::<Vec<_>>();
+                let pure_return_types = return_types.clone().into_iter().chain(inout_types.into_iter()).collect::<Vec<_>>();
 
                 // Finally, we have a properly built environment and we can
                 // start processing the body
@@ -827,7 +790,7 @@ fn analyze_program(
                     &mut env,
                     &mut types,
                     false,
-                    return_type,
+                    &return_types,
                     &inouts,
                     &mut labels,
                 )?;
@@ -835,19 +798,15 @@ fn analyze_program(
                 if end_reachable {
                     // The end of a function being reachable (i.e. there is some possible path
                     // where there is no return statement) is an error unless the return type is
-                    // void
-                    if types.unify_void(return_type) {
+                    // empty
+                    if return_types.is_empty() {
                         // Insert return at the end
                         body = Stmt::BlockStmt {
                             body: vec![
                                 body,
                                 generate_return(
-                                    Expr::Tuple {
-                                        vals: vec![],
-                                        typ: types.new_primitive(types::Primitive::Unit),
-                                    },
+                                    vec![],
                                     &inouts,
-                                    &mut types,
                                 ),
                             ],
                         };
@@ -876,7 +835,7 @@ fn analyze_program(
                             .iter()
                             .map(|(ty, is, _)| (*ty, *is))
                             .collect::<Vec<_>>(),
-                        return_type: return_type,
+                        return_types,
                     },
                 );
 
@@ -889,7 +848,7 @@ fn analyze_program(
                         .iter()
                         .map(|(t, _, v)| (*v, *t))
                         .collect::<Vec<_>>(),
-                    return_type: pure_return_type,
+                    return_types: pure_return_types,
                     body: body,
                     entry: entry,
                 });
@@ -1610,21 +1569,22 @@ fn process_stmt(
     env: &mut Env<usize, Entity>,
     types: &mut TypeSolver,
     in_loop: bool,
-    return_type: Type,
+    return_types: &[Type],
     inouts: &Vec<Expr>,
     labels: &mut StringTable,
 ) -> Result<(Stmt, bool), ErrorMessages> {
     match stmt {
         parser::Stmt::LetStmt {
             span,
-            var:
-                VarBind {
-                    span: v_span,
-                    pattern,
-                    typ,
-                },
+            var,
             init,
         } => {
+            let (_, pattern, typ) =
+                match var {
+                    LetBind::Single { span, pattern, typ } => (span, Either::Left(pattern), typ),
+                    LetBind::Multi { span, patterns } => (span, Either::Right(patterns), None),
+                };
+
             if typ.is_none() && init.is_none() {
                 return Err(singleton_error(ErrorMessage::SemanticError(
                     span_to_loc(span, lexer),
@@ -1676,12 +1636,60 @@ fn process_stmt(
 
             let mut res = vec![];
             res.push(Stmt::AssignStmt { var, val });
-            res.extend(
-                process_irrefutable_pattern(
-                    pattern, false, var, typ, lexer, stringtab, env, types, false,
-                )?
-                .0,
-            );
+
+            match pattern {
+                Either::Left(pattern) => {
+                    if let Some(return_types) = types.get_return_types(typ) {
+                        return Err(singleton_error(ErrorMessage::SemanticError(
+                            span_to_loc(span, lexer),
+                            format!("Expected {} patterns, found 1 pattern", return_types.len()),
+                        )));
+                    }
+                    res.extend(
+                        process_irrefutable_pattern(
+                            pattern, false, var, typ, lexer, stringtab, env, types, false,
+                        )?
+                        .0,
+                    );
+                }
+                Either::Right(patterns) => {
+                    let Some(return_types) = types.get_return_types(typ) else {
+                        return Err(singleton_error(ErrorMessage::SemanticError(
+                            span_to_loc(span, lexer),
+                            format!("Expected 1 pattern, found {} patterns", patterns.len()),
+                        )));
+                    };
+                    if return_types.len() != patterns.len() {
+                        return Err(singleton_error(ErrorMessage::SemanticError(
+                            span_to_loc(span, lexer),
+                            format!("Expected {} pattern, found {} patterns", return_types.len(), patterns.len()),
+                        )));
+                    }
+
+                    // Process each pattern after extracting the appropriate value from the call
+                    for (index, (pat, ret_typ)) in
+                        patterns.into_iter()
+                            .zip(return_types.clone().into_iter())
+                            .enumerate() {
+                        let extract_var = env.uniq();
+                        res.push(Stmt::AssignStmt {
+                            var: extract_var,
+                            val: Expr::CallExtract {
+                                call: Box::new(Expr::Variable { var, typ }),
+                                index,
+                                typ: ret_typ,
+                            }
+                        });
+                        res.extend(
+                            process_irrefutable_pattern(
+                                pat, false, extract_var, ret_typ, lexer, stringtab, env, types, false
+                            )?
+                            .0,
+                        );
+                    }
+                }
+            }
+
 
             Ok((Stmt::BlockStmt { body: res }, true))
         }
@@ -1689,7 +1697,7 @@ fn process_stmt(
             span,
             var:
                 VarBind {
-                    span: v_span,
+                    span: _v_span,
                     pattern,
                     typ,
                 },
@@ -1935,7 +1943,7 @@ fn process_stmt(
                 env,
                 types,
                 in_loop,
-                return_type,
+                return_types,
                 inouts,
                 labels,
             );
@@ -1952,7 +1960,7 @@ fn process_stmt(
                     env,
                     types,
                     in_loop,
-                    return_type,
+                    return_types,
                     inouts,
                     labels,
                 )
@@ -2110,7 +2118,7 @@ fn process_stmt(
                 env,
                 types,
                 true,
-                return_type,
+                return_types,
                 inouts,
                 labels,
             )?;
@@ -2214,7 +2222,7 @@ fn process_stmt(
                 env,
                 types,
                 true,
-                return_type,
+                return_types,
                 inouts,
                 labels,
             );
@@ -2241,36 +2249,52 @@ fn process_stmt(
                 true,
             ))
         }
-        parser::Stmt::ReturnStmt { span, expr } => {
-            let return_val = if expr.is_none() && types.unify_void(return_type) {
-                Expr::Constant {
-                    val: (Literal::Unit, return_type),
-                    typ: return_type,
-                }
-            } else if expr.is_none() {
-                Err(singleton_error(ErrorMessage::SemanticError(
+        parser::Stmt::ReturnStmt { span, vals } => {
+            if return_types.len() != vals.len() {
+                return Err(singleton_error(ErrorMessage::SemanticError(
                     span_to_loc(span, lexer),
                     format!(
-                        "Expected return of type {} found no return value",
-                        unparse_type(types, return_type, stringtab)
-                    ),
-                )))?
-            } else {
-                let val = process_expr(expr.unwrap(), num_dyn_const, lexer, stringtab, env, types)?;
-                let typ = val.get_type();
-                if !types.unify(return_type, typ) {
-                    Err(singleton_error(ErrorMessage::TypeError(
-                        span_to_loc(span, lexer),
-                        unparse_type(types, return_type, stringtab),
-                        unparse_type(types, typ, stringtab),
-                    )))?
-                }
-                val
-            };
+                        "Expected {} return values found {}",
+                        return_types.len(),
+                        vals.len(),
+                ))));
+            }
 
-            // We return a tuple of the return value and of the inout variables
+            let return_vals = vals
+                .into_iter()
+                .zip(return_types.iter())
+                .map(|(expr, typ)| {
+                    let expr_span = expr.span();
+                    let val = process_expr(expr, num_dyn_const, lexer, stringtab, env, types)?;
+                    if types.unify(*typ, val.get_type()) {
+                        Ok(val)
+                    } else {
+                        Err(singleton_error(ErrorMessage::TypeError(
+                            span_to_loc(expr_span, lexer),
+                            unparse_type(types, *typ, stringtab),
+                            unparse_type(types, val.get_type(), stringtab),
+                        )))
+                    }
+                })
+                .fold(Ok(vec![]),
+                    |res, val| {
+                        match (res, val) {
+                            (Ok(mut res), Ok(val)) => {
+                                res.push(val);
+                                Ok(res)
+                            }
+                            (Ok(_), Err(msg)) => Err(msg),
+                            (Err(msg), Ok(_)) => Err(msg),
+                            (Err(mut msgs), Err(msg)) => {
+                                msgs.extend(msg);
+                                Err(msgs)
+                            }
+                        }
+                })?;
+
+            // We return both the actual return values and the inout arguments
             // Statements after a return are never reachable
-            Ok((generate_return(return_val, inouts, types), false))
+            Ok((generate_return(return_vals, inouts), false))
         }
         parser::Stmt::BreakStmt { span } => {
             if !in_loop {
@@ -2318,7 +2342,7 @@ fn process_stmt(
                     env,
                     types,
                     in_loop,
-                    return_type,
+                    return_types,
                     inouts,
                     labels,
                 ) {
@@ -2384,7 +2408,7 @@ fn process_stmt(
                 env,
                 types,
                 in_loop,
-                return_type,
+                return_types,
                 inouts,
                 labels,
             )?;
@@ -4723,7 +4747,7 @@ fn process_expr(
                     index: function,
                     type_args: kinds,
                     args: func_args,
-                    return_type,
+                    return_types,
                 }) => {
                     let func = *function;
 
@@ -4813,11 +4837,11 @@ fn process_expr(
                         }
                         tys
                     };
-                    let return_typ = if let Some(res) =
-                        types.instantiate(*return_type, &type_vars, &dyn_consts)
-                    {
-                        res
-                    } else {
+                    let return_types = return_types
+                        .iter()
+                        .map(|t| types.instantiate(*t, &type_vars, &dyn_consts))
+                        .collect::<Option<Vec<_>>>();
+                    let Some(return_types) = return_types else {
                         return Err(singleton_error(ErrorMessage::SemanticError(
                             span_to_loc(span, lexer),
                             "Failure in variable substitution".to_string(),
@@ -4887,13 +4911,30 @@ fn process_expr(
                     if !errors.is_empty() {
                         Err(errors)
                     } else {
-                        Ok(Expr::CallExpr {
-                            func: func,
+                        let single_type =
+                            if return_types.len() == 1 {
+                                Some(return_types[0])
+                            } else {
+                                None
+                            };
+                        let num_returns = return_types.len();
+                        let call = Expr::CallExpr {
+                            func,
                             ty_args: type_vars,
-                            dyn_consts: dyn_consts,
+                            dyn_consts,
                             args: arg_vals,
-                            typ: return_typ,
-                        })
+                            num_returns,
+                            typ: types.new_multi_return(return_types),
+                        };
+                        if let Some(return_type) = single_type {
+                            Ok(Expr::CallExtract {
+                                call: Box::new(call),
+                                index: 0,
+                                typ: return_type,
+                            })
+                        } else {
+                            Ok(call)
+                        }
                     }
                 }
             }
@@ -5024,25 +5065,9 @@ fn process_expr(
     }
 }
 
-fn generate_return(expr: Expr, inouts: &Vec<Expr>, types: &mut TypeSolver) -> Stmt {
-    let inout_types = inouts.iter().map(|e| e.get_type()).collect::<Vec<_>>();
-
-    let mut return_types = vec![expr.get_type()];
-    return_types.extend(inout_types);
-
-    let mut return_vals = vec![expr];
-    return_vals.extend_from_slice(inouts);
-
-    let val = if return_vals.len() == 1 {
-        return_vals.pop().unwrap()
-    } else {
-        Expr::Tuple {
-            vals: return_vals,
-            typ: types.new_tuple(return_types),
-        }
-    };
-
-    Stmt::ReturnStmt { expr: val }
+fn generate_return(mut exprs: Vec<Expr>, inouts: &[Expr]) -> Stmt {
+    exprs.extend_from_slice(inouts);
+    Stmt::ReturnStmt { exprs: exprs }
 }
 
 fn convert_primitive(prim: parser::Primitive) -> types::Primitive {
@@ -5098,6 +5123,7 @@ fn process_irrefutable_pattern(
                     "Bound variables must be local names, without a package separator".to_string(),
                 )));
             }
+            assert!(types.get_return_types(typ).is_none());
 
             let nm = intern_package_name(&name, lexer, stringtab)[0];
             let variable = env.uniq();
diff --git a/juno_frontend/src/ssa.rs b/juno_frontend/src/ssa.rs
index 7076d622..9dbc0bfd 100644
--- a/juno_frontend/src/ssa.rs
+++ b/juno_frontend/src/ssa.rs
@@ -45,10 +45,10 @@ impl SSA {
         let right_proj = right_builder.id();
 
         // True branch
-        left_builder.build_projection(if_builder.id(), 1);
+        left_builder.build_control_projection(if_builder.id(), 1);
 
         // False branch
-        right_builder.build_projection(if_builder.id(), 0);
+        right_builder.build_control_projection(if_builder.id(), 0);
 
         builder.add_node(left_builder);
         builder.add_node(right_builder);
diff --git a/juno_frontend/src/types.rs b/juno_frontend/src/types.rs
index edb51db5..6e59169d 100644
--- a/juno_frontend/src/types.rs
+++ b/juno_frontend/src/types.rs
@@ -204,6 +204,11 @@ enum TypeForm {
         kind: parser::Kind,
         loc: Location,
     },
+
+    // The types of call nodes are MultiReturns
+    MultiReturn {
+        types: Vec<Type>,
+    },
 }
 
 #[derive(Debug)]
@@ -279,6 +284,10 @@ impl TypeSolver {
         })
     }
 
+    pub fn new_multi_return(&mut self, types: Vec<Type>) -> Type {
+        self.create_type(TypeForm::MultiReturn { types })
+    }
+
     fn create_type(&mut self, typ: TypeForm) -> Type {
         let idx = self.types.len();
         self.types.push(typ);
@@ -543,26 +552,13 @@ impl TypeSolver {
                 }
             }
 
+            // Note that MultReturn types never unify with anything (even itself), this is
+            // intentional and makes it so that the only way MultiReturns can be used is to
+            // destruct them
+
             _ => false,
         }
     }
-    /*
-        pub fn is_tuple(&self, Type { val } : Type) -> bool {
-            match &self.types[val] {
-                TypeForm::Tuple(_) => true,
-                TypeForm::OtherType(t) => self.is_tuple(*t),
-                _ => false,
-            }
-        }
-
-        pub fn get_num_fields(&self, Type { val } : Type) -> Option<usize> {
-            match &self.types[val] {
-                TypeForm::Tuple(fields) => { Some(fields.len()) },
-                TypeForm::OtherType(t) => self.get_num_fields(*t),
-                _ => None,
-            }
-        }
-    */
 
     // Returns the types of the fields of a tuple
     pub fn get_fields(&self, Type { val }: Type) -> Option<&Vec<Type>> {
@@ -676,26 +672,13 @@ impl TypeSolver {
         }
     }
 
-    /*
-        pub fn get_constructor_list(&self, Type { val } : Type) -> Option<Vec<usize>> {
-            match &self.types[val] {
-                TypeForm::Union { name : _, id : _, constr : _, names } => {
-                    Some(names.keys().map(|i| *i).collect::<Vec<_>>())
-                },
-                TypeForm::OtherType(t) => self.get_constructor_list(*t),
-                _ => None,
-            }
-        }
-
-
-        fn is_type_var_num(&self, num : usize, Type { val } : Type) -> bool {
-            match &self.types[val] {
-                TypeForm::TypeVar { name : _, index, .. } => *index == num,
-                TypeForm::OtherType(t) => self.is_type_var_num(num, *t),
-                _ => false,
-            }
+    pub fn get_return_types(&self, Type { val }: Type) -> Option<&Vec<Type>> {
+        match &self.types[val] {
+            TypeForm::MultiReturn { types } => Some(types),
+            TypeForm::OtherType { other, .. } => self.get_return_types(*other),
+            _ => None,
         }
-    */
+    }
 
     pub fn to_string(&self, Type { val }: Type, stringtab: &dyn Fn(usize) -> String) -> String {
         match &self.types[val] {
@@ -724,6 +707,8 @@ impl TypeSolver {
             | TypeForm::Struct { name, .. }
             | TypeForm::Union { name, .. } => stringtab(*name),
             TypeForm::AnyOfKind { kind, .. } => kind.to_string(),
+            TypeForm::MultiReturn { types } =>
+                types.iter().map(|t| self.to_string(*t, stringtab)).collect::<Vec<_>>().join(", "),
         }
     }
 
@@ -825,6 +810,9 @@ impl TypeSolver {
                     Some(Type { val })
                 }
             }
+            TypeForm::MultiReturn { .. } => {
+                panic!("Multi-Return types should never be instantiated")
+            }
         }
     }
 
@@ -969,6 +957,9 @@ impl TypeSolverInst<'_> {
                 TypeForm::AnyOfKind { .. } => {
                     panic!("TypeSolverInst only works on solved types which do not have AnyOfKinds")
                 }
+                TypeForm::MultiReturn { .. } => {
+                    panic!("MultiReturn types should never be lowered")
+                }
             };
 
             match solution {
-- 
GitLab