From bf0d479a4027ab8938179a41a8a13353da9b4488 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Tue, 18 Feb 2025 15:08:38 -0600 Subject: [PATCH] Front-end changes for multi-return --- juno_frontend/src/codegen.rs | 57 +++-- juno_frontend/src/labeled_builder.rs | 4 +- juno_frontend/src/lang.y | 49 ++-- juno_frontend/src/semant.rs | 370 ++++++++++++++------------- juno_frontend/src/ssa.rs | 4 +- juno_frontend/src/types.rs | 63 ++--- 6 files changed, 295 insertions(+), 252 deletions(-) diff --git a/juno_frontend/src/codegen.rs b/juno_frontend/src/codegen.rs index a3197041..4e6f89e8 100644 --- a/juno_frontend/src/codegen.rs +++ b/juno_frontend/src/codegen.rs @@ -118,15 +118,18 @@ impl CodeGenerator<'_> { param_types.push(solver_inst.lower_type(&mut self.builder.builder, *ty)); } - let return_type = - solver_inst.lower_type(&mut self.builder.builder, func.return_type); + let return_types = + func.return_types + .iter() + .map(|t| solver_inst.lower_type(&mut self.builder.builder, *t)) + .collect::<Vec<_>>(); let (func_id, entry) = self .builder .create_function( &name, param_types, - return_type, + return_types, func.num_dyn_consts as u32, func.entry, ) @@ -264,10 +267,16 @@ impl CodeGenerator<'_> { // loop Some(block_exit) } - Stmt::ReturnStmt { expr } => { - let (val_ret, block_ret) = self.codegen_expr(expr, types, ssa, cur_block); + Stmt::ReturnStmt { exprs } => { + let mut vals = vec![]; + let mut block = cur_block; + for expr in exprs { + let (val_ret, block_ret) = self.codegen_expr(expr, types, ssa, block); + vals.push(val_ret); + block = block_ret; + } let mut return_node = self.builder.allocate_node(); - return_node.build_return(block_ret, val_ret); + return_node.build_return(block, vals); self.builder.add_node(return_node); None } @@ -482,6 +491,7 @@ impl CodeGenerator<'_> { ty_args, dyn_consts, args, + num_returns, // number of non-inout returns (which are first) .. } => { // We start by lowering the type arguments to TypeIDs @@ -541,30 +551,27 @@ impl CodeGenerator<'_> { // Read each of the "inout values" and perform the SSA update let has_inouts = !inouts.is_empty(); - // TODO: We should omit unit returns, if we do so the + 1 below is not needed for (idx, var) in inouts.into_iter().enumerate() { - let index = self.builder.builder.create_field_index(idx + 1); - let mut read = self.builder.allocate_node(); - let read_id = read.id(); - read.build_read(call_id, vec![index].into()); - self.builder.add_node(read); + let index = self.builder.builder.create_field_index(num_returns + idx); + let mut proj = self.builder.allocate_node(); + let proj_id = proj.id(); + proj.build_data_projection(call_id, index); + self.builder.add_node(proj); - ssa.write_variable(var, block, read_id); + ssa.write_variable(var, block, proj_id); } - // Read the "actual return" value and return it - let result = if !has_inouts { - call_id - } else { - let value_index = self.builder.builder.create_field_index(0); - let mut read = self.builder.allocate_node(); - let read_id = read.id(); - read.build_read(call_id, vec![value_index].into()); - self.builder.add_node(read); - read_id - }; + (call_id, block) + } + Expr::CallExtract { call, index, .. } => { + let (call, block) = self.codegen_expr(call, types, ssa, cur_block); + + let mut proj = self.builder.allocate_node(); + let proj_id = proj.id(); + proj.build_data_projection(call, index); + self.builder.add_node(proj); - (result, block) + (proj_id, block) } Expr::Intrinsic { id, diff --git a/juno_frontend/src/labeled_builder.rs b/juno_frontend/src/labeled_builder.rs index 15bed6c2..869485e9 100644 --- a/juno_frontend/src/labeled_builder.rs +++ b/juno_frontend/src/labeled_builder.rs @@ -33,14 +33,14 @@ impl<'a> LabeledBuilder<'a> { &mut self, name: &str, param_types: Vec<TypeID>, - return_type: TypeID, + return_types: Vec<TypeID>, num_dynamic_constants: u32, entry: bool, ) -> Result<(FunctionID, NodeID), String> { let (func, entry) = self.builder.create_function( name, param_types, - return_type, + return_types, num_dynamic_constants, entry, )?; diff --git a/juno_frontend/src/lang.y b/juno_frontend/src/lang.y index be9161aa..b9efe1fa 100644 --- a/juno_frontend/src/lang.y +++ b/juno_frontend/src/lang.y @@ -167,17 +167,17 @@ ConstDecl -> Result<Top, ()> FuncDecl -> Result<Top, ()> : PubOption 'fn' 'ID' TypeVars '(' Arguments ')' Stmts { Ok(Top::FuncDecl{ span : $span, public : $1?, attr : None, name : span_of_tok($3)?, - ty_vars : $4?, args : $6?, ty : None, body : $8? }) } + ty_vars : $4?, args : $6?, rets: vec![], body : $8? }) } | 'FUNC_ATTR' PubOption 'fn' 'ID' TypeVars '(' Arguments ')' Stmts { Ok(Top::FuncDecl{ span : $span, public : $2?, attr : Some(span_of_tok($1)?), - name : span_of_tok($4)?, ty_vars : $5?, args : $7?, ty : None, + name : span_of_tok($4)?, ty_vars : $5?, args : $7?, rets: vec![], body : $9? }) } - | PubOption 'fn' 'ID' TypeVars '(' Arguments ')' '->' Type Stmts + | PubOption 'fn' 'ID' TypeVars '(' Arguments ')' '->' Types Stmts { Ok(Top::FuncDecl{ span : $span, public : $1?, attr : None, name : span_of_tok($3)?, - ty_vars : $4?, args : $6?, ty : Some($9?), body : $10? }) } - | 'FUNC_ATTR' PubOption 'fn' 'ID' TypeVars '(' Arguments ')' '->' Type Stmts + ty_vars : $4?, args : $6?, rets: $9?, body : $10? }) } + | 'FUNC_ATTR' PubOption 'fn' 'ID' TypeVars '(' Arguments ')' '->' Types Stmts { Ok(Top::FuncDecl{ span : $span, public : $2?, attr : Some(span_of_tok($1)?), - name : span_of_tok($4)?, ty_vars : $5?, args : $7?, ty : Some($10?), + name : span_of_tok($4)?, ty_vars : $5?, args : $7?, rets: $10?, body : $11? }) } ; Arguments -> Result<Vec<(Option<Span>, VarBind)>, ()> @@ -198,6 +198,18 @@ VarBind -> Result<VarBind, ()> | Pattern ':' Type { Ok(VarBind{ span : $span, pattern : $1?, typ : Some($3?) }) } ; +LetBind -> Result<LetBind, ()> + : VarBind { + let VarBind { span, pattern, typ } = $1?; + Ok(LetBind::Single { span, pattern, typ }) + } + | PatternsCommaS ',' Pattern { + let mut pats = $1?; + pats.push($3?); + Ok(LetBind::Multi { span: $span, patterns: pats }) + } + ; + Pattern -> Result<Pattern, ()> : '_' { Ok(Pattern::Wildcard { span : $span }) } | IntLit { let (span, base) = $1?; @@ -240,9 +252,9 @@ StructPatterns -> Result<(VecDeque<(Id, Pattern)>, bool), ()> ; Stmt -> Result<Stmt, ()> - : 'let' VarBind ';' + : 'let' LetBind ';' { Ok(Stmt::LetStmt{ span : $span, var : $2?, init : None }) } - | 'let' VarBind '=' Expr ';' + | 'let' LetBind '=' Expr ';' { Ok(Stmt::LetStmt{ span : $span, var : $2?, init : Some($4?) }) } | 'const' VarBind ';' { Ok(Stmt::ConstStmt{ span : $span, var : $2?, init : None }) } @@ -305,10 +317,8 @@ Stmt -> Result<Stmt, ()> inclusive: true, step: None, body: Box::new($8?) }) } | 'while' NonStructExpr Stmts { Ok(Stmt::WhileStmt{ span : $span, cond : $2?, body : Box::new($3?) }) } - | 'return' ';' - { Ok(Stmt::ReturnStmt{ span : $span, expr : None }) } - | 'return' Expr ';' - { Ok(Stmt::ReturnStmt{ span : $span, expr : Some($2?)}) } + | 'return' Exprs ';' + { Ok(Stmt::ReturnStmt{ span : $span, vals: $2?}) } | 'break' ';' { Ok(Stmt::BreakStmt{ span : $span }) } | 'continue' ';' @@ -659,6 +669,15 @@ pub struct VarBind { pub span : Span, pub pattern : Pattern, pub typ : Option<Ty #[derive(Debug)] pub struct Case { pub span : Span, pub pat : Vec<Pattern>, pub body : Stmt } +// Let bindings are different from other bindings because they can be used to +// destruct multi-return function values, and so can actually contain multiple +// patterns +#[derive(Debug)] +pub enum LetBind { + Single { span: Span, pattern: Pattern, typ: Option<Type> }, + Multi { span: Span, patterns: Vec<Pattern> }, +} + #[derive(Debug)] pub enum Top { Import { span : Span, name : ImportName }, @@ -666,7 +685,7 @@ pub enum Top { ConstDecl { span : Span, public : bool, name : Id, ty : Option<Type>, body : Expr }, FuncDecl { span : Span, public : bool, attr : Option<Span>, name : Id, ty_vars : Vec<TypeVar>, args : Vec<(Option<Span>, VarBind)>, // option is for inout - ty : Option<Type>, body : Stmt }, + rets : Vec<Type>, body : Stmt }, ModDecl { span : Span, public : bool, name : Id, body : Vec<Top> }, } @@ -688,7 +707,7 @@ pub enum Type { #[derive(Debug)] pub enum Stmt { - LetStmt { span : Span, var : VarBind, init : Option<Expr> }, + LetStmt { span : Span, var : LetBind, init : Option<Expr> }, ConstStmt { span : Span, var : VarBind, init : Option<Expr> }, AssignStmt { span : Span, lhs : LExpr, assign : AssignOp, assign_span : Span, rhs : Expr }, IfStmt { span : Span, cond : Expr, thn : Box<Stmt>, els : Option<Box<Stmt>> }, @@ -697,7 +716,7 @@ pub enum Stmt { ForStmt { span : Span, var : VarBind, init : Expr, bound : Expr, inclusive: bool, step : Option<(bool, Span, IntBase)>, body : Box<Stmt> }, WhileStmt { span : Span, cond : Expr, body : Box<Stmt> }, - ReturnStmt { span : Span, expr : Option<Expr> }, + ReturnStmt { span : Span, vals : Vec<Expr> }, BreakStmt { span : Span }, ContinueStmt { span : Span }, BlockStmt { span : Span, body : Vec<Stmt> }, diff --git a/juno_frontend/src/semant.rs b/juno_frontend/src/semant.rs index bfd4cf7f..ae696b4f 100644 --- a/juno_frontend/src/semant.rs +++ b/juno_frontend/src/semant.rs @@ -46,7 +46,7 @@ enum Entity { index: usize, type_args: Vec<parser::Kind>, args: Vec<(types::Type, bool)>, - return_type: types::Type, + return_types: Vec<types::Type>, }, } @@ -202,7 +202,7 @@ pub struct Function { pub num_dyn_consts: usize, pub num_type_args: usize, pub arguments: Vec<(usize, Type)>, - pub return_type: Type, + pub return_types: Vec<Type>, pub body: Stmt, pub entry: bool, } @@ -236,7 +236,7 @@ pub enum Stmt { body: Box<Stmt>, }, ReturnStmt { - expr: Expr, + exprs: Vec<Expr>, }, BreakStmt {}, ContinueStmt {}, @@ -328,6 +328,14 @@ pub enum Expr { ty_args: Vec<Type>, dyn_consts: Vec<DynConst>, args: Vec<Either<Expr, usize>>, + // Include the number of Juno returns (i.e. non-inouts) for codegen + num_returns: usize, + typ: Type, + }, + // A projection from a call + CallExtract { + call: Box<Expr>, + index: usize, typ: Type, }, Intrinsic { @@ -415,57 +423,20 @@ fn convert_binary_op(op: parser::BinaryOp) -> BinaryOp { impl Expr { pub fn get_type(&self) -> Type { match self { - Expr::Variable { var: _, typ } - | Expr::DynConst { val: _, typ } - | Expr::Read { - index: _, - val: _, - typ, - } - | Expr::Write { - index: _, - val: _, - rep: _, - typ, - } - | Expr::Tuple { vals: _, typ } - | Expr::Union { - tag: _, - val: _, - typ, - } - | Expr::Constant { val: _, typ } - | Expr::UnaryExp { - op: _, - expr: _, - typ, - } - | Expr::BinaryExp { - op: _, - lhs: _, - rhs: _, - typ, - } - | Expr::CastExpr { expr: _, typ } - | Expr::CondExpr { - cond: _, - thn: _, - els: _, - typ, - } - | Expr::CallExpr { - func: _, - ty_args: _, - dyn_consts: _, - args: _, - typ, - } - | Expr::Intrinsic { - id: _, - ty_args: _, - args: _, - typ, - } + Expr::Variable { typ, .. } + | Expr::DynConst { typ, .. } + | Expr::Read { typ, .. } + | Expr::Write { typ, .. } + | Expr::Tuple { typ, .. } + | Expr::Union { typ, .. } + | Expr::Constant { typ, .. } + | Expr::UnaryExp { typ, .. } + | Expr::BinaryExp { typ, .. } + | Expr::CastExpr { typ, .. } + | Expr::CondExpr { typ, .. } + | Expr::CallExpr { typ, .. } + | Expr::CallExtract { typ, .. } + | Expr::Intrinsic { typ, .. } | Expr::Zero { typ } => *typ, } } @@ -650,7 +621,7 @@ fn analyze_program( name, ty_vars, args, - ty, + rets, body, } => { // TODO: Handle public @@ -778,44 +749,36 @@ fn analyze_program( } } - let return_type = { - // A missing return type is implicitly void - let ty = ty.unwrap_or(parser::Type::PrimType { - span: span, - typ: parser::Primitive::Void, - }); - match process_type( - ty, - num_dyn_const, - lexer, - &mut stringtab, - &env, - &mut types, - true, - ) { - Ok(ty) => ty, - Err(mut errs) => { - errors.append(&mut errs); - types.new_primitive(types::Primitive::Unit) + let return_types = rets + .into_iter() + .map(|ty| + match process_type( + ty, + num_dyn_const, + lexer, + &mut stringtab, + &env, + &mut types, + true, + ) { + Ok(ty) => ty, + Err(mut errs) => { + errors.append(&mut errs); + // Type we return doesn't matter, error will be propagated upwards + // next, but need to return something + types.new_primitive(types::Primitive::Unit) + } } - } - }; + ) + .collect::<Vec<_>>(); if !errors.is_empty() { return Err(errors); } // Compute the proper type accounting for the inouts (which become returns) - let mut inout_types = inouts.iter().map(|e| e.get_type()).collect::<Vec<_>>(); - - let mut return_types = vec![return_type]; - return_types.extend(inout_types); - // TODO: Ideally we would omit unit returns - let pure_return_type = if return_types.len() == 1 { - return_types.pop().unwrap() - } else { - types.new_tuple(return_types) - }; + let inout_types = inouts.iter().map(|e| e.get_type()).collect::<Vec<_>>(); + let pure_return_types = return_types.clone().into_iter().chain(inout_types.into_iter()).collect::<Vec<_>>(); // Finally, we have a properly built environment and we can // start processing the body @@ -827,7 +790,7 @@ fn analyze_program( &mut env, &mut types, false, - return_type, + &return_types, &inouts, &mut labels, )?; @@ -835,19 +798,15 @@ fn analyze_program( if end_reachable { // The end of a function being reachable (i.e. there is some possible path // where there is no return statement) is an error unless the return type is - // void - if types.unify_void(return_type) { + // empty + if return_types.is_empty() { // Insert return at the end body = Stmt::BlockStmt { body: vec![ body, generate_return( - Expr::Tuple { - vals: vec![], - typ: types.new_primitive(types::Primitive::Unit), - }, + vec![], &inouts, - &mut types, ), ], }; @@ -876,7 +835,7 @@ fn analyze_program( .iter() .map(|(ty, is, _)| (*ty, *is)) .collect::<Vec<_>>(), - return_type: return_type, + return_types, }, ); @@ -889,7 +848,7 @@ fn analyze_program( .iter() .map(|(t, _, v)| (*v, *t)) .collect::<Vec<_>>(), - return_type: pure_return_type, + return_types: pure_return_types, body: body, entry: entry, }); @@ -1610,21 +1569,22 @@ fn process_stmt( env: &mut Env<usize, Entity>, types: &mut TypeSolver, in_loop: bool, - return_type: Type, + return_types: &[Type], inouts: &Vec<Expr>, labels: &mut StringTable, ) -> Result<(Stmt, bool), ErrorMessages> { match stmt { parser::Stmt::LetStmt { span, - var: - VarBind { - span: v_span, - pattern, - typ, - }, + var, init, } => { + let (_, pattern, typ) = + match var { + LetBind::Single { span, pattern, typ } => (span, Either::Left(pattern), typ), + LetBind::Multi { span, patterns } => (span, Either::Right(patterns), None), + }; + if typ.is_none() && init.is_none() { return Err(singleton_error(ErrorMessage::SemanticError( span_to_loc(span, lexer), @@ -1676,12 +1636,60 @@ fn process_stmt( let mut res = vec![]; res.push(Stmt::AssignStmt { var, val }); - res.extend( - process_irrefutable_pattern( - pattern, false, var, typ, lexer, stringtab, env, types, false, - )? - .0, - ); + + match pattern { + Either::Left(pattern) => { + if let Some(return_types) = types.get_return_types(typ) { + return Err(singleton_error(ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Expected {} patterns, found 1 pattern", return_types.len()), + ))); + } + res.extend( + process_irrefutable_pattern( + pattern, false, var, typ, lexer, stringtab, env, types, false, + )? + .0, + ); + } + Either::Right(patterns) => { + let Some(return_types) = types.get_return_types(typ) else { + return Err(singleton_error(ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Expected 1 pattern, found {} patterns", patterns.len()), + ))); + }; + if return_types.len() != patterns.len() { + return Err(singleton_error(ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Expected {} pattern, found {} patterns", return_types.len(), patterns.len()), + ))); + } + + // Process each pattern after extracting the appropriate value from the call + for (index, (pat, ret_typ)) in + patterns.into_iter() + .zip(return_types.clone().into_iter()) + .enumerate() { + let extract_var = env.uniq(); + res.push(Stmt::AssignStmt { + var: extract_var, + val: Expr::CallExtract { + call: Box::new(Expr::Variable { var, typ }), + index, + typ: ret_typ, + } + }); + res.extend( + process_irrefutable_pattern( + pat, false, extract_var, ret_typ, lexer, stringtab, env, types, false + )? + .0, + ); + } + } + } + Ok((Stmt::BlockStmt { body: res }, true)) } @@ -1689,7 +1697,7 @@ fn process_stmt( span, var: VarBind { - span: v_span, + span: _v_span, pattern, typ, }, @@ -1935,7 +1943,7 @@ fn process_stmt( env, types, in_loop, - return_type, + return_types, inouts, labels, ); @@ -1952,7 +1960,7 @@ fn process_stmt( env, types, in_loop, - return_type, + return_types, inouts, labels, ) @@ -2110,7 +2118,7 @@ fn process_stmt( env, types, true, - return_type, + return_types, inouts, labels, )?; @@ -2214,7 +2222,7 @@ fn process_stmt( env, types, true, - return_type, + return_types, inouts, labels, ); @@ -2241,36 +2249,52 @@ fn process_stmt( true, )) } - parser::Stmt::ReturnStmt { span, expr } => { - let return_val = if expr.is_none() && types.unify_void(return_type) { - Expr::Constant { - val: (Literal::Unit, return_type), - typ: return_type, - } - } else if expr.is_none() { - Err(singleton_error(ErrorMessage::SemanticError( + parser::Stmt::ReturnStmt { span, vals } => { + if return_types.len() != vals.len() { + return Err(singleton_error(ErrorMessage::SemanticError( span_to_loc(span, lexer), format!( - "Expected return of type {} found no return value", - unparse_type(types, return_type, stringtab) - ), - )))? - } else { - let val = process_expr(expr.unwrap(), num_dyn_const, lexer, stringtab, env, types)?; - let typ = val.get_type(); - if !types.unify(return_type, typ) { - Err(singleton_error(ErrorMessage::TypeError( - span_to_loc(span, lexer), - unparse_type(types, return_type, stringtab), - unparse_type(types, typ, stringtab), - )))? - } - val - }; + "Expected {} return values found {}", + return_types.len(), + vals.len(), + )))); + } - // We return a tuple of the return value and of the inout variables + let return_vals = vals + .into_iter() + .zip(return_types.iter()) + .map(|(expr, typ)| { + let expr_span = expr.span(); + let val = process_expr(expr, num_dyn_const, lexer, stringtab, env, types)?; + if types.unify(*typ, val.get_type()) { + Ok(val) + } else { + Err(singleton_error(ErrorMessage::TypeError( + span_to_loc(expr_span, lexer), + unparse_type(types, *typ, stringtab), + unparse_type(types, val.get_type(), stringtab), + ))) + } + }) + .fold(Ok(vec![]), + |res, val| { + match (res, val) { + (Ok(mut res), Ok(val)) => { + res.push(val); + Ok(res) + } + (Ok(_), Err(msg)) => Err(msg), + (Err(msg), Ok(_)) => Err(msg), + (Err(mut msgs), Err(msg)) => { + msgs.extend(msg); + Err(msgs) + } + } + })?; + + // We return both the actual return values and the inout arguments // Statements after a return are never reachable - Ok((generate_return(return_val, inouts, types), false)) + Ok((generate_return(return_vals, inouts), false)) } parser::Stmt::BreakStmt { span } => { if !in_loop { @@ -2318,7 +2342,7 @@ fn process_stmt( env, types, in_loop, - return_type, + return_types, inouts, labels, ) { @@ -2384,7 +2408,7 @@ fn process_stmt( env, types, in_loop, - return_type, + return_types, inouts, labels, )?; @@ -4723,7 +4747,7 @@ fn process_expr( index: function, type_args: kinds, args: func_args, - return_type, + return_types, }) => { let func = *function; @@ -4813,11 +4837,11 @@ fn process_expr( } tys }; - let return_typ = if let Some(res) = - types.instantiate(*return_type, &type_vars, &dyn_consts) - { - res - } else { + let return_types = return_types + .iter() + .map(|t| types.instantiate(*t, &type_vars, &dyn_consts)) + .collect::<Option<Vec<_>>>(); + let Some(return_types) = return_types else { return Err(singleton_error(ErrorMessage::SemanticError( span_to_loc(span, lexer), "Failure in variable substitution".to_string(), @@ -4887,13 +4911,30 @@ fn process_expr( if !errors.is_empty() { Err(errors) } else { - Ok(Expr::CallExpr { - func: func, + let single_type = + if return_types.len() == 1 { + Some(return_types[0]) + } else { + None + }; + let num_returns = return_types.len(); + let call = Expr::CallExpr { + func, ty_args: type_vars, - dyn_consts: dyn_consts, + dyn_consts, args: arg_vals, - typ: return_typ, - }) + num_returns, + typ: types.new_multi_return(return_types), + }; + if let Some(return_type) = single_type { + Ok(Expr::CallExtract { + call: Box::new(call), + index: 0, + typ: return_type, + }) + } else { + Ok(call) + } } } } @@ -5024,25 +5065,9 @@ fn process_expr( } } -fn generate_return(expr: Expr, inouts: &Vec<Expr>, types: &mut TypeSolver) -> Stmt { - let inout_types = inouts.iter().map(|e| e.get_type()).collect::<Vec<_>>(); - - let mut return_types = vec![expr.get_type()]; - return_types.extend(inout_types); - - let mut return_vals = vec![expr]; - return_vals.extend_from_slice(inouts); - - let val = if return_vals.len() == 1 { - return_vals.pop().unwrap() - } else { - Expr::Tuple { - vals: return_vals, - typ: types.new_tuple(return_types), - } - }; - - Stmt::ReturnStmt { expr: val } +fn generate_return(mut exprs: Vec<Expr>, inouts: &[Expr]) -> Stmt { + exprs.extend_from_slice(inouts); + Stmt::ReturnStmt { exprs: exprs } } fn convert_primitive(prim: parser::Primitive) -> types::Primitive { @@ -5098,6 +5123,7 @@ fn process_irrefutable_pattern( "Bound variables must be local names, without a package separator".to_string(), ))); } + assert!(types.get_return_types(typ).is_none()); let nm = intern_package_name(&name, lexer, stringtab)[0]; let variable = env.uniq(); diff --git a/juno_frontend/src/ssa.rs b/juno_frontend/src/ssa.rs index 7076d622..9dbc0bfd 100644 --- a/juno_frontend/src/ssa.rs +++ b/juno_frontend/src/ssa.rs @@ -45,10 +45,10 @@ impl SSA { let right_proj = right_builder.id(); // True branch - left_builder.build_projection(if_builder.id(), 1); + left_builder.build_control_projection(if_builder.id(), 1); // False branch - right_builder.build_projection(if_builder.id(), 0); + right_builder.build_control_projection(if_builder.id(), 0); builder.add_node(left_builder); builder.add_node(right_builder); diff --git a/juno_frontend/src/types.rs b/juno_frontend/src/types.rs index edb51db5..6e59169d 100644 --- a/juno_frontend/src/types.rs +++ b/juno_frontend/src/types.rs @@ -204,6 +204,11 @@ enum TypeForm { kind: parser::Kind, loc: Location, }, + + // The types of call nodes are MultiReturns + MultiReturn { + types: Vec<Type>, + }, } #[derive(Debug)] @@ -279,6 +284,10 @@ impl TypeSolver { }) } + pub fn new_multi_return(&mut self, types: Vec<Type>) -> Type { + self.create_type(TypeForm::MultiReturn { types }) + } + fn create_type(&mut self, typ: TypeForm) -> Type { let idx = self.types.len(); self.types.push(typ); @@ -543,26 +552,13 @@ impl TypeSolver { } } + // Note that MultReturn types never unify with anything (even itself), this is + // intentional and makes it so that the only way MultiReturns can be used is to + // destruct them + _ => false, } } - /* - pub fn is_tuple(&self, Type { val } : Type) -> bool { - match &self.types[val] { - TypeForm::Tuple(_) => true, - TypeForm::OtherType(t) => self.is_tuple(*t), - _ => false, - } - } - - pub fn get_num_fields(&self, Type { val } : Type) -> Option<usize> { - match &self.types[val] { - TypeForm::Tuple(fields) => { Some(fields.len()) }, - TypeForm::OtherType(t) => self.get_num_fields(*t), - _ => None, - } - } - */ // Returns the types of the fields of a tuple pub fn get_fields(&self, Type { val }: Type) -> Option<&Vec<Type>> { @@ -676,26 +672,13 @@ impl TypeSolver { } } - /* - pub fn get_constructor_list(&self, Type { val } : Type) -> Option<Vec<usize>> { - match &self.types[val] { - TypeForm::Union { name : _, id : _, constr : _, names } => { - Some(names.keys().map(|i| *i).collect::<Vec<_>>()) - }, - TypeForm::OtherType(t) => self.get_constructor_list(*t), - _ => None, - } - } - - - fn is_type_var_num(&self, num : usize, Type { val } : Type) -> bool { - match &self.types[val] { - TypeForm::TypeVar { name : _, index, .. } => *index == num, - TypeForm::OtherType(t) => self.is_type_var_num(num, *t), - _ => false, - } + pub fn get_return_types(&self, Type { val }: Type) -> Option<&Vec<Type>> { + match &self.types[val] { + TypeForm::MultiReturn { types } => Some(types), + TypeForm::OtherType { other, .. } => self.get_return_types(*other), + _ => None, } - */ + } pub fn to_string(&self, Type { val }: Type, stringtab: &dyn Fn(usize) -> String) -> String { match &self.types[val] { @@ -724,6 +707,8 @@ impl TypeSolver { | TypeForm::Struct { name, .. } | TypeForm::Union { name, .. } => stringtab(*name), TypeForm::AnyOfKind { kind, .. } => kind.to_string(), + TypeForm::MultiReturn { types } => + types.iter().map(|t| self.to_string(*t, stringtab)).collect::<Vec<_>>().join(", "), } } @@ -825,6 +810,9 @@ impl TypeSolver { Some(Type { val }) } } + TypeForm::MultiReturn { .. } => { + panic!("Multi-Return types should never be instantiated") + } } } @@ -969,6 +957,9 @@ impl TypeSolverInst<'_> { TypeForm::AnyOfKind { .. } => { panic!("TypeSolverInst only works on solved types which do not have AnyOfKinds") } + TypeForm::MultiReturn { .. } => { + panic!("MultiReturn types should never be lowered") + } }; match solution { -- GitLab