From a42eb07651ac9062fc9cfa288e519a70e6afa82e Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Thu, 13 Feb 2025 12:08:16 -0600 Subject: [PATCH] Fix inout codegen --- juno_frontend/src/semant.rs | 195 +++++++++++++++++++++++------------- 1 file changed, 128 insertions(+), 67 deletions(-) diff --git a/juno_frontend/src/semant.rs b/juno_frontend/src/semant.rs index 86979128..b8d04035 100644 --- a/juno_frontend/src/semant.rs +++ b/juno_frontend/src/semant.rs @@ -723,8 +723,9 @@ fn analyze_program( // We collect the list of the argument types, whether they are inout, and their // unique variable number let mut arg_info: Vec<(Type, bool, usize)> = vec![]; - // We collect the list of the types and variable numbers of the inout arguments - let mut inouts: Vec<(Type, usize)> = vec![]; + // We collect the list of expressions that should be returned for the inout + // arguments + let mut inouts: Vec<Expr> = vec![]; // A collection of errors we encounter processing the arguments let mut errors = LinkedList::new(); @@ -747,9 +748,6 @@ fn analyze_program( Ok(ty) => { let var = env.uniq(); - if inout.is_some() { - inouts.push((ty, var)); - } arg_info.push((ty, inout.is_some(), var)); match process_irrefutable_pattern( @@ -761,9 +759,13 @@ fn analyze_program( &mut stringtab, &mut env, &mut types, + inout.is_some(), ) { - Ok(prep) => { + Ok((prep, expr)) => { stmts.extend(prep); + if inout.is_some() { + inouts.push(expr.unwrap()); + } } Err(mut errs) => { errors.append(&mut errs); @@ -804,13 +806,11 @@ fn analyze_program( } // Compute the proper type accounting for the inouts (which become returns) - let mut inout_types = inouts.iter().map(|(t, _)| *t).collect::<Vec<_>>(); + let mut inout_types = inouts.iter().map(|e| e.get_type()).collect::<Vec<_>>(); - let inout_tuple = types.new_tuple(inout_types.clone()); + let inout_tuple = types.new_tuple(inout_types); let pure_return_type = types.new_tuple(vec![return_type, inout_tuple]); - let inout_variables = inouts.iter().map(|(_, v)| *v).collect::<Vec<_>>(); - // Finally, we have a properly built environment and we can // start processing the body let (mut body, end_reachable) = process_stmt( @@ -822,8 +822,7 @@ fn analyze_program( &mut types, false, return_type, - &inout_variables, - &inout_types, + &inouts, &mut labels, )?; @@ -841,8 +840,7 @@ fn analyze_program( vals: vec![], typ: types.new_primitive(types::Primitive::Unit), }, - &inout_variables, - &inout_types, + &inouts, &mut types, ), ], @@ -1607,8 +1605,7 @@ fn process_stmt( types: &mut TypeSolver, in_loop: bool, return_type: Type, - inout_vars: &Vec<usize>, - inout_types: &Vec<Type>, + inouts: &Vec<Expr>, labels: &mut StringTable, ) -> Result<(Stmt, bool), ErrorMessages> { match stmt { @@ -1673,9 +1670,12 @@ fn process_stmt( let mut res = vec![]; res.push(Stmt::AssignStmt { var, val }); - res.extend(process_irrefutable_pattern( - pattern, false, var, typ, lexer, stringtab, env, types, - )?); + res.extend( + process_irrefutable_pattern( + pattern, false, var, typ, lexer, stringtab, env, types, false, + )? + .0, + ); Ok((Stmt::BlockStmt { body: res }, true)) } @@ -1740,9 +1740,12 @@ fn process_stmt( let mut res = vec![]; res.push(Stmt::AssignStmt { var, val }); - res.extend(process_irrefutable_pattern( - pattern, false, var, typ, lexer, stringtab, env, types, - )?); + res.extend( + process_irrefutable_pattern( + pattern, false, var, typ, lexer, stringtab, env, types, false, + )? + .0, + ); Ok((Stmt::BlockStmt { body: res }, true)) } @@ -1927,8 +1930,7 @@ fn process_stmt( types, in_loop, return_type, - inout_vars, - inout_types, + inouts, labels, ); env.close_scope(); @@ -1945,8 +1947,7 @@ fn process_stmt( types, in_loop, return_type, - inout_vars, - inout_types, + inouts, labels, ) .map(|(s, b)| (Some(s), b)), @@ -2103,8 +2104,7 @@ fn process_stmt( types, true, return_type, - inout_vars, - inout_types, + inouts, labels, )?; @@ -2199,8 +2199,7 @@ fn process_stmt( types, true, return_type, - inout_vars, - inout_types, + inouts, labels, ); env.close_scope(); @@ -2255,10 +2254,7 @@ fn process_stmt( // We return a tuple of the return value and of the inout variables // Statements after a return are never reachable - Ok(( - generate_return(return_val, inout_vars, inout_types, types), - false, - )) + Ok((generate_return(return_val, inouts, types), false)) } parser::Stmt::BreakStmt { span } => { if !in_loop { @@ -2307,8 +2303,7 @@ fn process_stmt( types, in_loop, return_type, - inout_vars, - inout_types, + inouts, labels, ) { Err(mut errs) => { @@ -2374,8 +2369,7 @@ fn process_stmt( types, in_loop, return_type, - inout_vars, - inout_types, + inouts, labels, )?; Ok(( @@ -5014,24 +5008,12 @@ fn process_expr( } } -fn generate_return( - expr: Expr, - vars: &Vec<usize>, - var_types: &Vec<Type>, - types: &mut TypeSolver, -) -> Stmt { - let var_exprs = vars - .iter() - .zip(var_types.iter()) - .map(|(var, typ)| Expr::Variable { - var: *var, - typ: *typ, - }) - .collect::<Vec<_>>(); - - let inout_type = types.new_tuple(var_types.clone()); +fn generate_return(expr: Expr, inouts: &Vec<Expr>, types: &mut TypeSolver) -> Stmt { + let inout_types = inouts.iter().map(|e| e.get_type()).collect(); + let inout_type = types.new_tuple(inout_types); + let inout_vals = Expr::Tuple { - vals: var_exprs, + vals: inouts.clone(), typ: inout_type, }; @@ -5068,6 +5050,9 @@ fn convert_primitive(prim: parser::Primitive) -> types::Primitive { // Processes an irrefutable pattern by extracting the pieces from the given variable which has the // given type. Adds any variables in that pattern to the environment and returns a list of // statements that handle the pattern +// If the return_expr variable is set to true, it also returns an expression that reconstructs the +// value of the variable from the fields created by the pattern, this is used for inout variables +// where a pattern in its binding creates variables whose final values need to be returned fn process_irrefutable_pattern( pat: parser::Pattern, is_const: bool, @@ -5077,9 +5062,17 @@ fn process_irrefutable_pattern( stringtab: &mut StringTable, env: &mut Env<usize, Entity>, types: &mut TypeSolver, -) -> Result<Vec<Stmt>, ErrorMessages> { + return_expr: bool, +) -> Result<(Vec<Stmt>, Option<Expr>), ErrorMessages> { match pat { - Pattern::Wildcard { .. } => Ok(vec![]), + Pattern::Wildcard { .. } => Ok(( + vec![], + if return_expr { + Some(Expr::Variable { var, typ }) + } else { + None + }, + )), Pattern::Variable { span, name } => { if name.len() != 1 { return Err(singleton_error(ErrorMessage::SemanticError( @@ -5099,10 +5092,17 @@ fn process_irrefutable_pattern( }, ); - Ok(vec![Stmt::AssignStmt { - var: variable, - val: Expr::Variable { var, typ }, - }]) + Ok(( + vec![Stmt::AssignStmt { + var: variable, + val: Expr::Variable { var, typ }, + }], + if return_expr { + Some(Expr::Variable { var: variable, typ }) + } else { + None + }, + )) } Pattern::TuplePattern { span, pats } => { let Some(fields) = types.get_fields(typ) else { @@ -5128,6 +5128,7 @@ fn process_irrefutable_pattern( } let mut res = vec![]; + let mut exprs = vec![]; let mut errors = LinkedList::new(); for (idx, (pat, field)) in pats.into_iter().zip(fields.into_iter()).enumerate() { @@ -5143,15 +5144,35 @@ fn process_irrefutable_pattern( }); match process_irrefutable_pattern( - pat, is_const, variable, field, lexer, stringtab, env, types, + pat, + is_const, + variable, + field, + lexer, + stringtab, + env, + types, + return_expr, ) { - Ok(stmts) => res.extend(stmts), + Ok((stmts, expr)) => { + res.extend(stmts); + if return_expr { + exprs.push(expr.unwrap()); + } + } Err(errs) => errors.extend(errs), } } if errors.is_empty() { - Ok(res) + Ok(( + res, + if return_expr { + Some(Expr::Tuple { vals: exprs, typ }) + } else { + None + }, + )) } else { Err(errors) } @@ -5227,6 +5248,7 @@ fn process_irrefutable_pattern( .into_iter() .collect::<HashSet<_>>(); let mut res = vec![]; + let mut exprs = vec![None; unused_fields.len()]; let mut errors = LinkedList::new(); for (field_name, pat) in pats { @@ -5263,10 +5285,22 @@ fn process_irrefutable_pattern( }, }); match process_irrefutable_pattern( - pat, is_const, variable, field_typ, lexer, stringtab, env, + pat, + is_const, + variable, + field_typ, + lexer, + stringtab, + env, types, + return_expr, ) { - Ok(stmts) => res.extend(stmts), + Ok((stmts, expr)) => { + res.extend(stmts); + if return_expr { + exprs[idx] = expr; + } + } Err(errs) => errors.extend(errs), } } @@ -5286,12 +5320,39 @@ fn process_irrefutable_pattern( .join(", ") ), )); + } else if return_expr { + for field in unused_fields { + let (idx, field_typ) = types.get_field(struct_typ, field).unwrap(); + let variable = env.uniq(); + res.push(Stmt::AssignStmt { + var: variable, + val: Expr::Read { + index: vec![Index::Field(idx)], + val: Box::new(Expr::Variable { var, typ }), + typ: field_typ, + }, + }); + exprs[idx] = Some(Expr::Variable { + var: variable, + typ: field_typ, + }); + } } if !errors.is_empty() { Err(errors) } else { - Ok(res) + Ok(( + res, + if return_expr { + Some(Expr::Tuple { + vals: exprs.into_iter().map(|v| v.unwrap()).collect(), + typ, + }) + } else { + None + }, + )) } } } -- GitLab