diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index c4e71f8b95fc24d07eb9cf9cbdf7377d8ea5bc39..06a53fdb373dcedf69673b443b6c1c99d3e6cab7 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -155,6 +155,8 @@ pub fn collection_objects( typing: &ModuleTyping, callgraph: &CallGraph, ) -> CollectionObjects { + panic!("Collections analysis needs to be updated to handle multi-return"); + // Analyze functions in reverse topological order, since the analysis of a // function depends on all functions it calls. let mut collection_objects: CollectionObjects = BTreeMap::new(); diff --git a/hercules_ir/src/def_use.rs b/hercules_ir/src/def_use.rs index 9953134577ffbc3867c6f60224d7b20b62aa1e9e..e9ba4576d06d8ba4c140f0b27328b57c9b933896 100644 --- a/hercules_ir/src/def_use.rs +++ b/hercules_ir/src/def_use.rs @@ -230,10 +230,7 @@ pub fn get_uses(node: &Node) -> NodeUses { control, selection: _, } => NodeUses::One([*control]), - Node::DataProjection { - data, - selection: _, - } => NodeUses::One([*data]), + Node::DataProjection { data, selection: _ } => NodeUses::One([*data]), Node::Undef { ty: _ } => NodeUses::One([NodeID::new(0)]), } } @@ -340,10 +337,7 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> { control, selection: _, } => NodeUsesMut::One([control]), - Node::DataProjection { - data, - selection: _, - } => NodeUsesMut::One([data]), + Node::DataProjection { data, selection: _ } => NodeUsesMut::One([data]), Node::Undef { ty: _ } => NodeUsesMut::Zero, } } diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs index a7f890f87bb04b526f2c023b110c2d815d2d747f..aff1f9c52eb8735db828620a15e27a6395b20f9f 100644 --- a/hercules_ir/src/dot.rs +++ b/hercules_ir/src/dot.rs @@ -353,10 +353,7 @@ fn write_node<W: Write>( control: _, selection, } - | Node::DataProjection { - data: _, - selection, - } => write!(&mut suffix, "{}", selection)?, + | Node::DataProjection { data: _, selection } => write!(&mut suffix, "{}", selection)?, _ => {} }; diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index 42730f772823557a5b722a6048449679f22eefc6..b41b1f6fca55e05262135e9b159029f103f41078 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -509,9 +509,13 @@ fn parse_return<'a>( nom::character::complete::char(','), nom::character::complete::multispace0, )), - parse_identifier)(ir_text)?; + parse_identifier, + )(ir_text)?; let control = context.borrow_mut().get_node_id(control); - let data = data.into_iter().map(|d| context.borrow_mut().get_node_id(d)).collect(); + let data = data + .into_iter() + .map(|d| context.borrow_mut().get_node_id(d)) + .collect(); Ok((ir_text, Node::Return { control, data })) } @@ -1027,8 +1031,7 @@ fn parse_constant<'a>( ) -> nom::IResult<&'a str, Constant> { let (ir_text, constant) = match ty { // There are not control constants. - Type::Control - | Type::MultiReturn(_) => Err(nom::Err::Error(nom::error::Error { + Type::Control | Type::MultiReturn(_) => Err(nom::Err::Error(nom::error::Error { input: ir_text, code: nom::error::ErrorKind::IsNot, }))?, diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index 919da6408f8ac72741353e7a36a57319ccc1866e..2a3f9fb1aa86dd092d048bf79b51aa118d2489b8 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -441,7 +441,11 @@ fn typeflow( return inputs[0].clone(); } - for (idx, (input, return_type)) in inputs[1..].iter().zip(function.return_types.iter()).enumerate() { + for (idx, (input, return_type)) in inputs[1..] + .iter() + .zip(function.return_types.iter()) + .enumerate() + { if let Concrete(id) = input { if *id != *return_type { return Error(format!("Return node's data input at index {} does not match function's return type.", idx)); @@ -1070,10 +1074,7 @@ fn typeflow( // Type is the type of the _if node inputs[0].clone() } - Node::DataProjection { - data: _, - selection, - } => { + Node::DataProjection { data: _, selection } => { if let Concrete(type_id) = inputs[0] { match &types[type_id.idx()] { Type::MultiReturn(types) => { diff --git a/juno_frontend/src/codegen.rs b/juno_frontend/src/codegen.rs index 533fd268399722b50216c6a3c21c6ecf0603d528..0902bc61d429794dd67ea737f20553ae3185ba21 100644 --- a/juno_frontend/src/codegen.rs +++ b/juno_frontend/src/codegen.rs @@ -118,11 +118,11 @@ impl CodeGenerator<'_> { param_types.push(solver_inst.lower_type(&mut self.builder.builder, *ty)); } - let return_types = - func.return_types - .iter() - .map(|t| solver_inst.lower_type(&mut self.builder.builder, *t)) - .collect::<Vec<_>>(); + let return_types = func + .return_types + .iter() + .map(|t| solver_inst.lower_type(&mut self.builder.builder, *t)) + .collect::<Vec<_>>(); let (func_id, entry) = self .builder diff --git a/juno_frontend/src/semant.rs b/juno_frontend/src/semant.rs index ae696b4f9be5e9a6bec7d380534a2de11682154a..f0736d2b63296b8006e874ce629aa08ac6ec3de6 100644 --- a/juno_frontend/src/semant.rs +++ b/juno_frontend/src/semant.rs @@ -751,7 +751,7 @@ fn analyze_program( let return_types = rets .into_iter() - .map(|ty| + .map(|ty| { match process_type( ty, num_dyn_const, @@ -769,7 +769,7 @@ fn analyze_program( types.new_primitive(types::Primitive::Unit) } } - ) + }) .collect::<Vec<_>>(); if !errors.is_empty() { @@ -778,7 +778,11 @@ fn analyze_program( // Compute the proper type accounting for the inouts (which become returns) let inout_types = inouts.iter().map(|e| e.get_type()).collect::<Vec<_>>(); - let pure_return_types = return_types.clone().into_iter().chain(inout_types.into_iter()).collect::<Vec<_>>(); + let pure_return_types = return_types + .clone() + .into_iter() + .chain(inout_types.into_iter()) + .collect::<Vec<_>>(); // Finally, we have a properly built environment and we can // start processing the body @@ -802,13 +806,7 @@ fn analyze_program( if return_types.is_empty() { // Insert return at the end body = Stmt::BlockStmt { - body: vec![ - body, - generate_return( - vec![], - &inouts, - ), - ], + body: vec![body, generate_return(vec![], &inouts)], }; } else { Err(singleton_error(ErrorMessage::SemanticError( @@ -1574,16 +1572,11 @@ fn process_stmt( labels: &mut StringTable, ) -> Result<(Stmt, bool), ErrorMessages> { match stmt { - parser::Stmt::LetStmt { - span, - var, - init, - } => { - let (_, pattern, typ) = - match var { - LetBind::Single { span, pattern, typ } => (span, Either::Left(pattern), typ), - LetBind::Multi { span, patterns } => (span, Either::Right(patterns), None), - }; + parser::Stmt::LetStmt { span, var, init } => { + let (_, pattern, typ) = match var { + LetBind::Single { span, pattern, typ } => (span, Either::Left(pattern), typ), + LetBind::Multi { span, patterns } => (span, Either::Right(patterns), None), + }; if typ.is_none() && init.is_none() { return Err(singleton_error(ErrorMessage::SemanticError( @@ -1662,15 +1655,20 @@ fn process_stmt( if return_types.len() != patterns.len() { return Err(singleton_error(ErrorMessage::SemanticError( span_to_loc(span, lexer), - format!("Expected {} pattern, found {} patterns", return_types.len(), patterns.len()), + format!( + "Expected {} pattern, found {} patterns", + return_types.len(), + patterns.len() + ), ))); } // Process each pattern after extracting the appropriate value from the call - for (index, (pat, ret_typ)) in - patterns.into_iter() - .zip(return_types.clone().into_iter()) - .enumerate() { + for (index, (pat, ret_typ)) in patterns + .into_iter() + .zip(return_types.clone().into_iter()) + .enumerate() + { let extract_var = env.uniq(); res.push(Stmt::AssignStmt { var: extract_var, @@ -1678,11 +1676,19 @@ fn process_stmt( call: Box::new(Expr::Variable { var, typ }), index, typ: ret_typ, - } + }, }); res.extend( process_irrefutable_pattern( - pat, false, extract_var, ret_typ, lexer, stringtab, env, types, false + pat, + false, + extract_var, + ret_typ, + lexer, + stringtab, + env, + types, + false, )? .0, ); @@ -1690,7 +1696,6 @@ fn process_stmt( } } - Ok((Stmt::BlockStmt { body: res }, true)) } parser::Stmt::ConstStmt { @@ -2257,7 +2262,8 @@ fn process_stmt( "Expected {} return values found {}", return_types.len(), vals.len(), - )))); + ), + ))); } let return_vals = vals @@ -2276,20 +2282,17 @@ fn process_stmt( ))) } }) - .fold(Ok(vec![]), - |res, val| { - match (res, val) { - (Ok(mut res), Ok(val)) => { - res.push(val); - Ok(res) - } - (Ok(_), Err(msg)) => Err(msg), - (Err(msg), Ok(_)) => Err(msg), - (Err(mut msgs), Err(msg)) => { - msgs.extend(msg); - Err(msgs) - } - } + .fold(Ok(vec![]), |res, val| match (res, val) { + (Ok(mut res), Ok(val)) => { + res.push(val); + Ok(res) + } + (Ok(_), Err(msg)) => Err(msg), + (Err(msg), Ok(_)) => Err(msg), + (Err(mut msgs), Err(msg)) => { + msgs.extend(msg); + Err(msgs) + } })?; // We return both the actual return values and the inout arguments @@ -4911,12 +4914,11 @@ fn process_expr( if !errors.is_empty() { Err(errors) } else { - let single_type = - if return_types.len() == 1 { - Some(return_types[0]) - } else { - None - }; + let single_type = if return_types.len() == 1 { + Some(return_types[0]) + } else { + None + }; let num_returns = return_types.len(); let call = Expr::CallExpr { func, diff --git a/juno_frontend/src/types.rs b/juno_frontend/src/types.rs index 6e59169d57dcb44aa2fcead5236dd4b0de3d6453..4099c56704846454fa0fcea2a9cdc6dbc9ac1531 100644 --- a/juno_frontend/src/types.rs +++ b/juno_frontend/src/types.rs @@ -555,7 +555,6 @@ impl TypeSolver { // Note that MultReturn types never unify with anything (even itself), this is // intentional and makes it so that the only way MultiReturns can be used is to // destruct them - _ => false, } } @@ -707,8 +706,11 @@ impl TypeSolver { | TypeForm::Struct { name, .. } | TypeForm::Union { name, .. } => stringtab(*name), TypeForm::AnyOfKind { kind, .. } => kind.to_string(), - TypeForm::MultiReturn { types } => - types.iter().map(|t| self.to_string(*t, stringtab)).collect::<Vec<_>>().join(", "), + TypeForm::MultiReturn { types } => types + .iter() + .map(|t| self.to_string(*t, stringtab)) + .collect::<Vec<_>>() + .join(", "), } }