From bf0d479a4027ab8938179a41a8a13353da9b4488 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Tue, 18 Feb 2025 15:08:38 -0600 Subject: [PATCH 01/19] Front-end changes for multi-return --- juno_frontend/src/codegen.rs | 57 +++-- juno_frontend/src/labeled_builder.rs | 4 +- juno_frontend/src/lang.y | 49 ++-- juno_frontend/src/semant.rs | 370 ++++++++++++++------------- juno_frontend/src/ssa.rs | 4 +- juno_frontend/src/types.rs | 63 ++--- 6 files changed, 295 insertions(+), 252 deletions(-) diff --git a/juno_frontend/src/codegen.rs b/juno_frontend/src/codegen.rs index a3197041..4e6f89e8 100644 --- a/juno_frontend/src/codegen.rs +++ b/juno_frontend/src/codegen.rs @@ -118,15 +118,18 @@ impl CodeGenerator<'_> { param_types.push(solver_inst.lower_type(&mut self.builder.builder, *ty)); } - let return_type = - solver_inst.lower_type(&mut self.builder.builder, func.return_type); + let return_types = + func.return_types + .iter() + .map(|t| solver_inst.lower_type(&mut self.builder.builder, *t)) + .collect::<Vec<_>>(); let (func_id, entry) = self .builder .create_function( &name, param_types, - return_type, + return_types, func.num_dyn_consts as u32, func.entry, ) @@ -264,10 +267,16 @@ impl CodeGenerator<'_> { // loop Some(block_exit) } - Stmt::ReturnStmt { expr } => { - let (val_ret, block_ret) = self.codegen_expr(expr, types, ssa, cur_block); + Stmt::ReturnStmt { exprs } => { + let mut vals = vec![]; + let mut block = cur_block; + for expr in exprs { + let (val_ret, block_ret) = self.codegen_expr(expr, types, ssa, block); + vals.push(val_ret); + block = block_ret; + } let mut return_node = self.builder.allocate_node(); - return_node.build_return(block_ret, val_ret); + return_node.build_return(block, vals); self.builder.add_node(return_node); None } @@ -482,6 +491,7 @@ impl CodeGenerator<'_> { ty_args, dyn_consts, args, + num_returns, // number of non-inout returns (which are first) .. 
} => { // We start by lowering the type arguments to TypeIDs @@ -541,30 +551,27 @@ impl CodeGenerator<'_> { // Read each of the "inout values" and perform the SSA update let has_inouts = !inouts.is_empty(); - // TODO: We should omit unit returns, if we do so the + 1 below is not needed for (idx, var) in inouts.into_iter().enumerate() { - let index = self.builder.builder.create_field_index(idx + 1); - let mut read = self.builder.allocate_node(); - let read_id = read.id(); - read.build_read(call_id, vec![index].into()); - self.builder.add_node(read); + let index = self.builder.builder.create_field_index(num_returns + idx); + let mut proj = self.builder.allocate_node(); + let proj_id = proj.id(); + proj.build_data_projection(call_id, index); + self.builder.add_node(proj); - ssa.write_variable(var, block, read_id); + ssa.write_variable(var, block, proj_id); } - // Read the "actual return" value and return it - let result = if !has_inouts { - call_id - } else { - let value_index = self.builder.builder.create_field_index(0); - let mut read = self.builder.allocate_node(); - let read_id = read.id(); - read.build_read(call_id, vec![value_index].into()); - self.builder.add_node(read); - read_id - }; + (call_id, block) + } + Expr::CallExtract { call, index, .. 
} => { + let (call, block) = self.codegen_expr(call, types, ssa, cur_block); + + let mut proj = self.builder.allocate_node(); + let proj_id = proj.id(); + proj.build_data_projection(call, index); + self.builder.add_node(proj); - (result, block) + (proj_id, block) } Expr::Intrinsic { id, diff --git a/juno_frontend/src/labeled_builder.rs b/juno_frontend/src/labeled_builder.rs index 15bed6c2..869485e9 100644 --- a/juno_frontend/src/labeled_builder.rs +++ b/juno_frontend/src/labeled_builder.rs @@ -33,14 +33,14 @@ impl<'a> LabeledBuilder<'a> { &mut self, name: &str, param_types: Vec<TypeID>, - return_type: TypeID, + return_types: Vec<TypeID>, num_dynamic_constants: u32, entry: bool, ) -> Result<(FunctionID, NodeID), String> { let (func, entry) = self.builder.create_function( name, param_types, - return_type, + return_types, num_dynamic_constants, entry, )?; diff --git a/juno_frontend/src/lang.y b/juno_frontend/src/lang.y index be9161aa..b9efe1fa 100644 --- a/juno_frontend/src/lang.y +++ b/juno_frontend/src/lang.y @@ -167,17 +167,17 @@ ConstDecl -> Result<Top, ()> FuncDecl -> Result<Top, ()> : PubOption 'fn' 'ID' TypeVars '(' Arguments ')' Stmts { Ok(Top::FuncDecl{ span : $span, public : $1?, attr : None, name : span_of_tok($3)?, - ty_vars : $4?, args : $6?, ty : None, body : $8? }) } + ty_vars : $4?, args : $6?, rets: vec![], body : $8? }) } | 'FUNC_ATTR' PubOption 'fn' 'ID' TypeVars '(' Arguments ')' Stmts { Ok(Top::FuncDecl{ span : $span, public : $2?, attr : Some(span_of_tok($1)?), - name : span_of_tok($4)?, ty_vars : $5?, args : $7?, ty : None, + name : span_of_tok($4)?, ty_vars : $5?, args : $7?, rets: vec![], body : $9? }) } - | PubOption 'fn' 'ID' TypeVars '(' Arguments ')' '->' Type Stmts + | PubOption 'fn' 'ID' TypeVars '(' Arguments ')' '->' Types Stmts { Ok(Top::FuncDecl{ span : $span, public : $1?, attr : None, name : span_of_tok($3)?, - ty_vars : $4?, args : $6?, ty : Some($9?), body : $10? 
}) } - | 'FUNC_ATTR' PubOption 'fn' 'ID' TypeVars '(' Arguments ')' '->' Type Stmts + ty_vars : $4?, args : $6?, rets: $9?, body : $10? }) } + | 'FUNC_ATTR' PubOption 'fn' 'ID' TypeVars '(' Arguments ')' '->' Types Stmts { Ok(Top::FuncDecl{ span : $span, public : $2?, attr : Some(span_of_tok($1)?), - name : span_of_tok($4)?, ty_vars : $5?, args : $7?, ty : Some($10?), + name : span_of_tok($4)?, ty_vars : $5?, args : $7?, rets: $10?, body : $11? }) } ; Arguments -> Result<Vec<(Option<Span>, VarBind)>, ()> @@ -198,6 +198,18 @@ VarBind -> Result<VarBind, ()> | Pattern ':' Type { Ok(VarBind{ span : $span, pattern : $1?, typ : Some($3?) }) } ; +LetBind -> Result<LetBind, ()> + : VarBind { + let VarBind { span, pattern, typ } = $1?; + Ok(LetBind::Single { span, pattern, typ }) + } + | PatternsCommaS ',' Pattern { + let mut pats = $1?; + pats.push($3?); + Ok(LetBind::Multi { span: $span, patterns: pats }) + } + ; + Pattern -> Result<Pattern, ()> : '_' { Ok(Pattern::Wildcard { span : $span }) } | IntLit { let (span, base) = $1?; @@ -240,9 +252,9 @@ StructPatterns -> Result<(VecDeque<(Id, Pattern)>, bool), ()> ; Stmt -> Result<Stmt, ()> - : 'let' VarBind ';' + : 'let' LetBind ';' { Ok(Stmt::LetStmt{ span : $span, var : $2?, init : None }) } - | 'let' VarBind '=' Expr ';' + | 'let' LetBind '=' Expr ';' { Ok(Stmt::LetStmt{ span : $span, var : $2?, init : Some($4?) }) } | 'const' VarBind ';' { Ok(Stmt::ConstStmt{ span : $span, var : $2?, init : None }) } @@ -305,10 +317,8 @@ Stmt -> Result<Stmt, ()> inclusive: true, step: None, body: Box::new($8?) }) } | 'while' NonStructExpr Stmts { Ok(Stmt::WhileStmt{ span : $span, cond : $2?, body : Box::new($3?) 
}) } - | 'return' ';' - { Ok(Stmt::ReturnStmt{ span : $span, expr : None }) } - | 'return' Expr ';' - { Ok(Stmt::ReturnStmt{ span : $span, expr : Some($2?)}) } + | 'return' Exprs ';' + { Ok(Stmt::ReturnStmt{ span : $span, vals: $2?}) } | 'break' ';' { Ok(Stmt::BreakStmt{ span : $span }) } | 'continue' ';' @@ -659,6 +669,15 @@ pub struct VarBind { pub span : Span, pub pattern : Pattern, pub typ : Option<Ty #[derive(Debug)] pub struct Case { pub span : Span, pub pat : Vec<Pattern>, pub body : Stmt } +// Let bindings are different from other bindings because they can be used to +// destruct multi-return function values, and so can actually contain multiple +// patterns +#[derive(Debug)] +pub enum LetBind { + Single { span: Span, pattern: Pattern, typ: Option<Type> }, + Multi { span: Span, patterns: Vec<Pattern> }, +} + #[derive(Debug)] pub enum Top { Import { span : Span, name : ImportName }, @@ -666,7 +685,7 @@ pub enum Top { ConstDecl { span : Span, public : bool, name : Id, ty : Option<Type>, body : Expr }, FuncDecl { span : Span, public : bool, attr : Option<Span>, name : Id, ty_vars : Vec<TypeVar>, args : Vec<(Option<Span>, VarBind)>, // option is for inout - ty : Option<Type>, body : Stmt }, + rets : Vec<Type>, body : Stmt }, ModDecl { span : Span, public : bool, name : Id, body : Vec<Top> }, } @@ -688,7 +707,7 @@ pub enum Type { #[derive(Debug)] pub enum Stmt { - LetStmt { span : Span, var : VarBind, init : Option<Expr> }, + LetStmt { span : Span, var : LetBind, init : Option<Expr> }, ConstStmt { span : Span, var : VarBind, init : Option<Expr> }, AssignStmt { span : Span, lhs : LExpr, assign : AssignOp, assign_span : Span, rhs : Expr }, IfStmt { span : Span, cond : Expr, thn : Box<Stmt>, els : Option<Box<Stmt>> }, @@ -697,7 +716,7 @@ pub enum Stmt { ForStmt { span : Span, var : VarBind, init : Expr, bound : Expr, inclusive: bool, step : Option<(bool, Span, IntBase)>, body : Box<Stmt> }, WhileStmt { span : Span, cond : Expr, body : Box<Stmt> }, - ReturnStmt { 
span : Span, expr : Option<Expr> }, + ReturnStmt { span : Span, vals : Vec<Expr> }, BreakStmt { span : Span }, ContinueStmt { span : Span }, BlockStmt { span : Span, body : Vec<Stmt> }, diff --git a/juno_frontend/src/semant.rs b/juno_frontend/src/semant.rs index bfd4cf7f..ae696b4f 100644 --- a/juno_frontend/src/semant.rs +++ b/juno_frontend/src/semant.rs @@ -46,7 +46,7 @@ enum Entity { index: usize, type_args: Vec<parser::Kind>, args: Vec<(types::Type, bool)>, - return_type: types::Type, + return_types: Vec<types::Type>, }, } @@ -202,7 +202,7 @@ pub struct Function { pub num_dyn_consts: usize, pub num_type_args: usize, pub arguments: Vec<(usize, Type)>, - pub return_type: Type, + pub return_types: Vec<Type>, pub body: Stmt, pub entry: bool, } @@ -236,7 +236,7 @@ pub enum Stmt { body: Box<Stmt>, }, ReturnStmt { - expr: Expr, + exprs: Vec<Expr>, }, BreakStmt {}, ContinueStmt {}, @@ -328,6 +328,14 @@ pub enum Expr { ty_args: Vec<Type>, dyn_consts: Vec<DynConst>, args: Vec<Either<Expr, usize>>, + // Include the number of Juno returns (i.e. 
non-inouts) for codegen + num_returns: usize, + typ: Type, + }, + // A projection from a call + CallExtract { + call: Box<Expr>, + index: usize, typ: Type, }, Intrinsic { @@ -415,57 +423,20 @@ fn convert_binary_op(op: parser::BinaryOp) -> BinaryOp { impl Expr { pub fn get_type(&self) -> Type { match self { - Expr::Variable { var: _, typ } - | Expr::DynConst { val: _, typ } - | Expr::Read { - index: _, - val: _, - typ, - } - | Expr::Write { - index: _, - val: _, - rep: _, - typ, - } - | Expr::Tuple { vals: _, typ } - | Expr::Union { - tag: _, - val: _, - typ, - } - | Expr::Constant { val: _, typ } - | Expr::UnaryExp { - op: _, - expr: _, - typ, - } - | Expr::BinaryExp { - op: _, - lhs: _, - rhs: _, - typ, - } - | Expr::CastExpr { expr: _, typ } - | Expr::CondExpr { - cond: _, - thn: _, - els: _, - typ, - } - | Expr::CallExpr { - func: _, - ty_args: _, - dyn_consts: _, - args: _, - typ, - } - | Expr::Intrinsic { - id: _, - ty_args: _, - args: _, - typ, - } + Expr::Variable { typ, .. } + | Expr::DynConst { typ, .. } + | Expr::Read { typ, .. } + | Expr::Write { typ, .. } + | Expr::Tuple { typ, .. } + | Expr::Union { typ, .. } + | Expr::Constant { typ, .. } + | Expr::UnaryExp { typ, .. } + | Expr::BinaryExp { typ, .. } + | Expr::CastExpr { typ, .. } + | Expr::CondExpr { typ, .. } + | Expr::CallExpr { typ, .. } + | Expr::CallExtract { typ, .. } + | Expr::Intrinsic { typ, .. 
} | Expr::Zero { typ } => *typ, } } @@ -650,7 +621,7 @@ fn analyze_program( name, ty_vars, args, - ty, + rets, body, } => { // TODO: Handle public @@ -778,44 +749,36 @@ fn analyze_program( } } - let return_type = { - // A missing return type is implicitly void - let ty = ty.unwrap_or(parser::Type::PrimType { - span: span, - typ: parser::Primitive::Void, - }); - match process_type( - ty, - num_dyn_const, - lexer, - &mut stringtab, - &env, - &mut types, - true, - ) { - Ok(ty) => ty, - Err(mut errs) => { - errors.append(&mut errs); - types.new_primitive(types::Primitive::Unit) + let return_types = rets + .into_iter() + .map(|ty| + match process_type( + ty, + num_dyn_const, + lexer, + &mut stringtab, + &env, + &mut types, + true, + ) { + Ok(ty) => ty, + Err(mut errs) => { + errors.append(&mut errs); + // Type we return doesn't matter, error will be propagated upwards + // next, but need to return something + types.new_primitive(types::Primitive::Unit) + } } - } - }; + ) + .collect::<Vec<_>>(); if !errors.is_empty() { return Err(errors); } // Compute the proper type accounting for the inouts (which become returns) - let mut inout_types = inouts.iter().map(|e| e.get_type()).collect::<Vec<_>>(); - - let mut return_types = vec![return_type]; - return_types.extend(inout_types); - // TODO: Ideally we would omit unit returns - let pure_return_type = if return_types.len() == 1 { - return_types.pop().unwrap() - } else { - types.new_tuple(return_types) - }; + let inout_types = inouts.iter().map(|e| e.get_type()).collect::<Vec<_>>(); + let pure_return_types = return_types.clone().into_iter().chain(inout_types.into_iter()).collect::<Vec<_>>(); // Finally, we have a properly built environment and we can // start processing the body @@ -827,7 +790,7 @@ fn analyze_program( &mut env, &mut types, false, - return_type, + &return_types, &inouts, &mut labels, )?; @@ -835,19 +798,15 @@ fn analyze_program( if end_reachable { // The end of a function being reachable (i.e. 
there is some possible path // where there is no return statement) is an error unless the return type is - // void - if types.unify_void(return_type) { + // empty + if return_types.is_empty() { // Insert return at the end body = Stmt::BlockStmt { body: vec![ body, generate_return( - Expr::Tuple { - vals: vec![], - typ: types.new_primitive(types::Primitive::Unit), - }, + vec![], &inouts, - &mut types, ), ], }; @@ -876,7 +835,7 @@ fn analyze_program( .iter() .map(|(ty, is, _)| (*ty, *is)) .collect::<Vec<_>>(), - return_type: return_type, + return_types, }, ); @@ -889,7 +848,7 @@ fn analyze_program( .iter() .map(|(t, _, v)| (*v, *t)) .collect::<Vec<_>>(), - return_type: pure_return_type, + return_types: pure_return_types, body: body, entry: entry, }); @@ -1610,21 +1569,22 @@ fn process_stmt( env: &mut Env<usize, Entity>, types: &mut TypeSolver, in_loop: bool, - return_type: Type, + return_types: &[Type], inouts: &Vec<Expr>, labels: &mut StringTable, ) -> Result<(Stmt, bool), ErrorMessages> { match stmt { parser::Stmt::LetStmt { span, - var: - VarBind { - span: v_span, - pattern, - typ, - }, + var, init, } => { + let (_, pattern, typ) = + match var { + LetBind::Single { span, pattern, typ } => (span, Either::Left(pattern), typ), + LetBind::Multi { span, patterns } => (span, Either::Right(patterns), None), + }; + if typ.is_none() && init.is_none() { return Err(singleton_error(ErrorMessage::SemanticError( span_to_loc(span, lexer), @@ -1676,12 +1636,60 @@ fn process_stmt( let mut res = vec![]; res.push(Stmt::AssignStmt { var, val }); - res.extend( - process_irrefutable_pattern( - pattern, false, var, typ, lexer, stringtab, env, types, false, - )? 
- .0, - ); + + match pattern { + Either::Left(pattern) => { + if let Some(return_types) = types.get_return_types(typ) { + return Err(singleton_error(ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Expected {} patterns, found 1 pattern", return_types.len()), + ))); + } + res.extend( + process_irrefutable_pattern( + pattern, false, var, typ, lexer, stringtab, env, types, false, + )? + .0, + ); + } + Either::Right(patterns) => { + let Some(return_types) = types.get_return_types(typ) else { + return Err(singleton_error(ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Expected 1 pattern, found {} patterns", patterns.len()), + ))); + }; + if return_types.len() != patterns.len() { + return Err(singleton_error(ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!("Expected {} pattern, found {} patterns", return_types.len(), patterns.len()), + ))); + } + + // Process each pattern after extracting the appropriate value from the call + for (index, (pat, ret_typ)) in + patterns.into_iter() + .zip(return_types.clone().into_iter()) + .enumerate() { + let extract_var = env.uniq(); + res.push(Stmt::AssignStmt { + var: extract_var, + val: Expr::CallExtract { + call: Box::new(Expr::Variable { var, typ }), + index, + typ: ret_typ, + } + }); + res.extend( + process_irrefutable_pattern( + pat, false, extract_var, ret_typ, lexer, stringtab, env, types, false + )? 
+ .0, + ); + } + } + } + Ok((Stmt::BlockStmt { body: res }, true)) } @@ -1689,7 +1697,7 @@ fn process_stmt( span, var: VarBind { - span: v_span, + span: _v_span, pattern, typ, }, @@ -1935,7 +1943,7 @@ fn process_stmt( env, types, in_loop, - return_type, + return_types, inouts, labels, ); @@ -1952,7 +1960,7 @@ fn process_stmt( env, types, in_loop, - return_type, + return_types, inouts, labels, ) @@ -2110,7 +2118,7 @@ fn process_stmt( env, types, true, - return_type, + return_types, inouts, labels, )?; @@ -2214,7 +2222,7 @@ fn process_stmt( env, types, true, - return_type, + return_types, inouts, labels, ); @@ -2241,36 +2249,52 @@ fn process_stmt( true, )) } - parser::Stmt::ReturnStmt { span, expr } => { - let return_val = if expr.is_none() && types.unify_void(return_type) { - Expr::Constant { - val: (Literal::Unit, return_type), - typ: return_type, - } - } else if expr.is_none() { - Err(singleton_error(ErrorMessage::SemanticError( + parser::Stmt::ReturnStmt { span, vals } => { + if return_types.len() != vals.len() { + return Err(singleton_error(ErrorMessage::SemanticError( span_to_loc(span, lexer), format!( - "Expected return of type {} found no return value", - unparse_type(types, return_type, stringtab) - ), - )))? - } else { - let val = process_expr(expr.unwrap(), num_dyn_const, lexer, stringtab, env, types)?; - let typ = val.get_type(); - if !types.unify(return_type, typ) { - Err(singleton_error(ErrorMessage::TypeError( - span_to_loc(span, lexer), - unparse_type(types, return_type, stringtab), - unparse_type(types, typ, stringtab), - )))? 
- } - val - }; + "Expected {} return values found {}", + return_types.len(), + vals.len(), + )))); + } - // We return a tuple of the return value and of the inout variables + let return_vals = vals + .into_iter() + .zip(return_types.iter()) + .map(|(expr, typ)| { + let expr_span = expr.span(); + let val = process_expr(expr, num_dyn_const, lexer, stringtab, env, types)?; + if types.unify(*typ, val.get_type()) { + Ok(val) + } else { + Err(singleton_error(ErrorMessage::TypeError( + span_to_loc(expr_span, lexer), + unparse_type(types, *typ, stringtab), + unparse_type(types, val.get_type(), stringtab), + ))) + } + }) + .fold(Ok(vec![]), + |res, val| { + match (res, val) { + (Ok(mut res), Ok(val)) => { + res.push(val); + Ok(res) + } + (Ok(_), Err(msg)) => Err(msg), + (Err(msg), Ok(_)) => Err(msg), + (Err(mut msgs), Err(msg)) => { + msgs.extend(msg); + Err(msgs) + } + } + })?; + + // We return both the actual return values and the inout arguments // Statements after a return are never reachable - Ok((generate_return(return_val, inouts, types), false)) + Ok((generate_return(return_vals, inouts), false)) } parser::Stmt::BreakStmt { span } => { if !in_loop { @@ -2318,7 +2342,7 @@ fn process_stmt( env, types, in_loop, - return_type, + return_types, inouts, labels, ) { @@ -2384,7 +2408,7 @@ fn process_stmt( env, types, in_loop, - return_type, + return_types, inouts, labels, )?; @@ -4723,7 +4747,7 @@ fn process_expr( index: function, type_args: kinds, args: func_args, - return_type, + return_types, }) => { let func = *function; @@ -4813,11 +4837,11 @@ fn process_expr( } tys }; - let return_typ = if let Some(res) = - types.instantiate(*return_type, &type_vars, &dyn_consts) - { - res - } else { + let return_types = return_types + .iter() + .map(|t| types.instantiate(*t, &type_vars, &dyn_consts)) + .collect::<Option<Vec<_>>>(); + let Some(return_types) = return_types else { return Err(singleton_error(ErrorMessage::SemanticError( span_to_loc(span, lexer), "Failure in variable 
substitution".to_string(), @@ -4887,13 +4911,30 @@ fn process_expr( if !errors.is_empty() { Err(errors) } else { - Ok(Expr::CallExpr { - func: func, + let single_type = + if return_types.len() == 1 { + Some(return_types[0]) + } else { + None + }; + let num_returns = return_types.len(); + let call = Expr::CallExpr { + func, ty_args: type_vars, - dyn_consts: dyn_consts, + dyn_consts, args: arg_vals, - typ: return_typ, - }) + num_returns, + typ: types.new_multi_return(return_types), + }; + if let Some(return_type) = single_type { + Ok(Expr::CallExtract { + call: Box::new(call), + index: 0, + typ: return_type, + }) + } else { + Ok(call) + } } } } @@ -5024,25 +5065,9 @@ fn process_expr( } } -fn generate_return(expr: Expr, inouts: &Vec<Expr>, types: &mut TypeSolver) -> Stmt { - let inout_types = inouts.iter().map(|e| e.get_type()).collect::<Vec<_>>(); - - let mut return_types = vec![expr.get_type()]; - return_types.extend(inout_types); - - let mut return_vals = vec![expr]; - return_vals.extend_from_slice(inouts); - - let val = if return_vals.len() == 1 { - return_vals.pop().unwrap() - } else { - Expr::Tuple { - vals: return_vals, - typ: types.new_tuple(return_types), - } - }; - - Stmt::ReturnStmt { expr: val } +fn generate_return(mut exprs: Vec<Expr>, inouts: &[Expr]) -> Stmt { + exprs.extend_from_slice(inouts); + Stmt::ReturnStmt { exprs: exprs } } fn convert_primitive(prim: parser::Primitive) -> types::Primitive { @@ -5098,6 +5123,7 @@ fn process_irrefutable_pattern( "Bound variables must be local names, without a package separator".to_string(), ))); } + assert!(types.get_return_types(typ).is_none()); let nm = intern_package_name(&name, lexer, stringtab)[0]; let variable = env.uniq(); diff --git a/juno_frontend/src/ssa.rs b/juno_frontend/src/ssa.rs index 7076d622..9dbc0bfd 100644 --- a/juno_frontend/src/ssa.rs +++ b/juno_frontend/src/ssa.rs @@ -45,10 +45,10 @@ impl SSA { let right_proj = right_builder.id(); // True branch - 
left_builder.build_projection(if_builder.id(), 1); + left_builder.build_control_projection(if_builder.id(), 1); // False branch - right_builder.build_projection(if_builder.id(), 0); + right_builder.build_control_projection(if_builder.id(), 0); builder.add_node(left_builder); builder.add_node(right_builder); diff --git a/juno_frontend/src/types.rs b/juno_frontend/src/types.rs index edb51db5..6e59169d 100644 --- a/juno_frontend/src/types.rs +++ b/juno_frontend/src/types.rs @@ -204,6 +204,11 @@ enum TypeForm { kind: parser::Kind, loc: Location, }, + + // The types of call nodes are MultiReturns + MultiReturn { + types: Vec<Type>, + }, } #[derive(Debug)] @@ -279,6 +284,10 @@ impl TypeSolver { }) } + pub fn new_multi_return(&mut self, types: Vec<Type>) -> Type { + self.create_type(TypeForm::MultiReturn { types }) + } + fn create_type(&mut self, typ: TypeForm) -> Type { let idx = self.types.len(); self.types.push(typ); @@ -543,26 +552,13 @@ impl TypeSolver { } } + // Note that MultReturn types never unify with anything (even itself), this is + // intentional and makes it so that the only way MultiReturns can be used is to + // destruct them + _ => false, } } - /* - pub fn is_tuple(&self, Type { val } : Type) -> bool { - match &self.types[val] { - TypeForm::Tuple(_) => true, - TypeForm::OtherType(t) => self.is_tuple(*t), - _ => false, - } - } - - pub fn get_num_fields(&self, Type { val } : Type) -> Option<usize> { - match &self.types[val] { - TypeForm::Tuple(fields) => { Some(fields.len()) }, - TypeForm::OtherType(t) => self.get_num_fields(*t), - _ => None, - } - } - */ // Returns the types of the fields of a tuple pub fn get_fields(&self, Type { val }: Type) -> Option<&Vec<Type>> { @@ -676,26 +672,13 @@ impl TypeSolver { } } - /* - pub fn get_constructor_list(&self, Type { val } : Type) -> Option<Vec<usize>> { - match &self.types[val] { - TypeForm::Union { name : _, id : _, constr : _, names } => { - Some(names.keys().map(|i| *i).collect::<Vec<_>>()) - }, - 
TypeForm::OtherType(t) => self.get_constructor_list(*t), - _ => None, - } - } - - - fn is_type_var_num(&self, num : usize, Type { val } : Type) -> bool { - match &self.types[val] { - TypeForm::TypeVar { name : _, index, .. } => *index == num, - TypeForm::OtherType(t) => self.is_type_var_num(num, *t), - _ => false, - } + pub fn get_return_types(&self, Type { val }: Type) -> Option<&Vec<Type>> { + match &self.types[val] { + TypeForm::MultiReturn { types } => Some(types), + TypeForm::OtherType { other, .. } => self.get_return_types(*other), + _ => None, } - */ + } pub fn to_string(&self, Type { val }: Type, stringtab: &dyn Fn(usize) -> String) -> String { match &self.types[val] { @@ -724,6 +707,8 @@ impl TypeSolver { | TypeForm::Struct { name, .. } | TypeForm::Union { name, .. } => stringtab(*name), TypeForm::AnyOfKind { kind, .. } => kind.to_string(), + TypeForm::MultiReturn { types } => + types.iter().map(|t| self.to_string(*t, stringtab)).collect::<Vec<_>>().join(", "), } } @@ -825,6 +810,9 @@ impl TypeSolver { Some(Type { val }) } } + TypeForm::MultiReturn { .. } => { + panic!("Multi-Return types should never be instantiated") + } } } @@ -969,6 +957,9 @@ impl TypeSolverInst<'_> { TypeForm::AnyOfKind { .. } => { panic!("TypeSolverInst only works on solved types which do not have AnyOfKinds") } + TypeForm::MultiReturn { .. 
} => { + panic!("MultiReturn types should never be lowered") + } }; match solution { -- GitLab From 2b326bd3dc54df360ddd41fb4d6da183944eacf7 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Tue, 18 Feb 2025 16:48:04 -0600 Subject: [PATCH 02/19] IR changes for multi-return --- hercules_ir/src/build.rs | 15 +++++--- hercules_ir/src/collections.rs | 14 ++++---- hercules_ir/src/def_use.rs | 22 +++++++++--- hercules_ir/src/ir.rs | 63 +++++++++++++++++++++++++--------- hercules_ir/src/parse.rs | 55 ++++++++++++++++++++++++----- hercules_ir/src/typecheck.rs | 44 ++++++++++++++++++++---- hercules_ir/src/verify.rs | 14 +++++--- 7 files changed, 176 insertions(+), 51 deletions(-) diff --git a/hercules_ir/src/build.rs b/hercules_ir/src/build.rs index 40538cef..3e966d53 100644 --- a/hercules_ir/src/build.rs +++ b/hercules_ir/src/build.rs @@ -392,6 +392,7 @@ impl<'a> Builder<'a> { pub fn create_constant_zero(&mut self, typ: TypeID) -> ConstantID { match &self.module.types[typ.idx()] { Type::Control => panic!("Cannot create constant for control types"), + Type::MultiReturn(..) 
=> panic!("Cannot create constant for multi-return types"), Type::Boolean => self.create_constant_bool(false), Type::Integer8 => self.create_constant_i8(0), Type::Integer16 => self.create_constant_i16(0), @@ -503,7 +504,7 @@ impl<'a> Builder<'a> { &mut self, name: &str, param_types: Vec<TypeID>, - return_type: TypeID, + return_types: Vec<TypeID>, num_dynamic_constants: u32, entry: bool, ) -> BuilderResult<(FunctionID, NodeID)> { @@ -515,7 +516,7 @@ impl<'a> Builder<'a> { self.module.functions.push(Function { name: name.to_owned(), param_types, - return_type, + return_types, num_dynamic_constants, entry, nodes: vec![Node::Start], @@ -594,11 +595,15 @@ impl NodeBuilder { }; } - pub fn build_projection(&mut self, control: NodeID, selection: usize) { - self.node = Node::Projection { control, selection }; + pub fn build_control_projection(&mut self, control: NodeID, selection: usize) { + self.node = Node::ControlProjection { control, selection }; } - pub fn build_return(&mut self, control: NodeID, data: NodeID) { + pub fn build_data_projection(&mut self, data: NodeID, selection: usize) { + self.node = Node::DataProjection { data, selection }; + } + + pub fn build_return(&mut self, control: NodeID, data: Box<[NodeID]>) { self.node = Node::Return { control, data }; } diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index 6b631519..c4e71f8b 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -328,7 +328,9 @@ pub fn collection_objects( let mut returned: BTreeSet<CollectionObjectID> = BTreeSet::new(); for node in func.nodes.iter() { if let Node::Return { control: _, data } = node { - returned.extend(&objects_per_node[data.idx()]); + for node in data { + returned.extend(&objects_per_node[node.idx()]); + } } } let returned = returned.into_iter().collect(); @@ -500,16 +502,16 @@ pub fn no_reset_constant_collections( collect: _, data, indices: _, - } - | Node::Return { control: _, data } => { + } => { 
Either::Left(zip(once(&full_indices), once(data))) } - Node::Call { + Node::Return { control: _, ref data } + | Node::Call { control: _, function: _, dynamic_constants: _, - ref args, - } => Either::Right(zip(repeat(&full_indices), args.into_iter().map(|id| *id))), + args: ref data, + } => Either::Right(zip(repeat(&full_indices), data.into_iter().map(|id| *id))), _ => return None, }; diff --git a/hercules_ir/src/def_use.rs b/hercules_ir/src/def_use.rs index ff0e08ed..a99c8a23 100644 --- a/hercules_ir/src/def_use.rs +++ b/hercules_ir/src/def_use.rs @@ -156,7 +156,11 @@ pub fn get_uses(node: &Node) -> NodeUses { init, reduct, } => NodeUses::Three([*control, *init, *reduct]), - Node::Return { control, data } => NodeUses::Two([*control, *data]), + Node::Return { control, data } => { + let mut uses: Vec<NodeID> = Vec::from(&data[..]); + uses.push(*control); + NodeUses::Variable(uses.into_boxed_slice()) + } Node::Parameter { index: _ } => NodeUses::One([NodeID::new(0)]), Node::Constant { id: _ } => NodeUses::One([NodeID::new(0)]), Node::DynamicConstant { id: _ } => NodeUses::One([NodeID::new(0)]), @@ -222,10 +226,14 @@ pub fn get_uses(node: &Node) -> NodeUses { NodeUses::Two([*collect, *data]) } } - Node::Projection { + Node::ControlProjection { control, selection: _, } => NodeUses::One([*control]), + Node::DataProjection { + data, + selection: _, + } => NodeUses::One([*data]), Node::Undef { ty: _ } => NodeUses::One([NodeID::new(0)]), } } @@ -260,7 +268,9 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> { init, reduct, } => NodeUsesMut::Three([control, init, reduct]), - Node::Return { control, data } => NodeUsesMut::Two([control, data]), + Node::Return { control, data } => { + NodeUsesMut::Variable(std::iter::once(control).chain(data.iter_mut()).collect()) + } Node::Parameter { index: _ } => NodeUsesMut::Zero, Node::Constant { id: _ } => NodeUsesMut::Zero, Node::DynamicConstant { id: _ } => NodeUsesMut::Zero, @@ -326,10 +336,14 @@ pub fn 
get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> { NodeUsesMut::Two([collect, data]) } } - Node::Projection { + Node::ControlProjection { control, selection: _, } => NodeUsesMut::One([control]), + Node::DataProjection { + data, + selection: _, + } => NodeUsesMut::One([data]), Node::Undef { ty: _ } => NodeUsesMut::Zero, } } diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index bf9698b3..68fdc26c 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -39,7 +39,7 @@ pub struct Module { pub struct Function { pub name: String, pub param_types: Vec<TypeID>, - pub return_type: TypeID, + pub return_types: Vec<TypeID>, pub num_dynamic_constants: u32, pub entry: bool, @@ -77,6 +77,7 @@ pub enum Type { Product(Box<[TypeID]>), Summation(Box<[TypeID]>), Array(TypeID, Box<[DynamicConstantID]>), + MultiReturn(Box<[TypeID]>), } /* @@ -186,7 +187,7 @@ pub enum Node { }, Return { control: NodeID, - data: NodeID, + data: Box<[NodeID]>, }, Parameter { index: usize, @@ -237,10 +238,14 @@ pub enum Node { data: NodeID, indices: Box<[Index]>, }, - Projection { + ControlProjection { control: NodeID, selection: usize, }, + DataProjection { + data: NodeID, + selection: usize, + }, Undef { ty: TypeID, }, @@ -434,6 +439,17 @@ impl Module { } write!(w, ")") } + Type::MultiReturn(fields) => { + write!(w, "MultiReturn(")?; + for idx in 0..fields.len() { + let field_ty_id = fields[idx]; + self.write_type(field_ty_id, w)?; + if idx + 1 < fields.len() { + write!(w, ", ")?; + } + } + write!(w, ")") + } }?; Ok(()) @@ -1262,12 +1278,19 @@ impl Node { ); define_pattern_predicate!(is_match, Node::Match { control: _, sum: _ }); define_pattern_predicate!( - is_projection, - Node::Projection { + is_control_projection, + Node::ControlProjection { control: _, selection: _ } ); + define_pattern_predicate!( + is_data_projection, + Node::DataProjection { + data: _, + selection: _ + } + ); define_pattern_predicate!(is_undef, Node::Undef { ty: _ }); @@ -1287,8 +1310,8 @@ impl Node { } 
} - pub fn try_proj(&self) -> Option<(NodeID, usize)> { - if let Node::Projection { control, selection } = self { + pub fn try_control_proj(&self) -> Option<(NodeID, usize)> { + if let Node::ControlProjection { control, selection } = self { Some((*control, *selection)) } else { None @@ -1303,9 +1326,9 @@ impl Node { } } - pub fn try_return(&self) -> Option<(NodeID, NodeID)> { + pub fn try_return(&self) -> Option<(NodeID, &[NodeID])> { if let Node::Return { control, data } = self { - Some((*control, *data)) + Some((*control, data)) } else { None } @@ -1479,8 +1502,8 @@ impl Node { } } - pub fn try_projection(&self, branch: usize) -> Option<NodeID> { - if let Node::Projection { control, selection } = self + pub fn try_control_projection(&self, branch: usize) -> Option<NodeID> { + if let Node::ControlProjection { control, selection } = self && branch == *selection { Some(*control) @@ -1560,10 +1583,14 @@ impl Node { data: _, indices: _, } => "Write", - Node::Projection { + Node::ControlProjection { control: _, selection: _, - } => "Projection", + } => "ControlProjection", + Node::DataProjection { + data: _, + selection: _, + } => "DataProjection", Node::Undef { ty: _ } => "Undef", } } @@ -1639,10 +1666,14 @@ impl Node { data: _, indices: _, } => "write", - Node::Projection { + Node::ControlProjection { control: _, selection: _, - } => "projection", + } => "control_projection", + Node::DataProjection { + data: _, + selection: _, + } => "data_projection", Node::Undef { ty: _ } => "undef", } } @@ -1655,7 +1686,7 @@ impl Node { || self.is_fork() || self.is_join() || self.is_return() - || self.is_projection() + || self.is_control_projection() } } diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index f1f4153a..42730f77 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -139,7 +139,7 @@ fn parse_module<'a>(ir_text: &'a str, context: Context<'a>) -> nom::IResult<&'a Function { name: String::from(""), param_types: vec![], - return_type: 
TypeID::new(0), + return_types: vec![], num_dynamic_constants: 0, entry: true, nodes: vec![], @@ -245,7 +245,14 @@ fn parse_function<'a>( let ir_text = nom::character::complete::char(')')(ir_text)?.0; let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::bytes::complete::tag("->")(ir_text)?.0; - let (ir_text, return_type) = parse_type_id(ir_text, context)?; + let (ir_text, return_types) = nom::multi::separated_list1( + nom::sequence::tuple(( + nom::character::complete::multispace0, + nom::character::complete::char(','), + nom::character::complete::multispace0, + )), + |text| parse_type_id(text, context), + )(ir_text)?; let (ir_text, nodes) = nom::multi::many1(|x| parse_node(x, context))(ir_text)?; // `nodes`, as returned by parsing, is in parse order, which may differ from @@ -286,7 +293,7 @@ fn parse_function<'a>( Function { name: String::from(function_name), param_types: params.into_iter().map(|x| x.5).collect(), - return_type, + return_types, num_dynamic_constants, entry: true, nodes: fixed_nodes, @@ -334,7 +341,8 @@ fn parse_node<'a>( "return" => parse_return(ir_text, context)?, "constant" => parse_constant_node(ir_text, context)?, "dynamic_constant" => parse_dynamic_constant_node(ir_text, context)?, - "projection" => parse_projection(ir_text, context)?, + "control_projection" => parse_control_projection(ir_text, context)?, + "data_projection" => parse_data_projection(ir_text, context)?, // Unary and binary ops are spelled out in the textual format, but we // parse them into Unary or Binary node kinds. 
"not" => parse_unary(ir_text, context, UnaryOperator::Not)?, @@ -489,9 +497,21 @@ fn parse_return<'a>( ir_text: &'a str, context: &RefCell<Context<'a>>, ) -> nom::IResult<&'a str, Node> { - let (ir_text, (control, data)) = parse_tuple2(parse_identifier, parse_identifier)(ir_text)?; + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let ir_text = nom::character::complete::char('(')(ir_text)?.0; + let (ir_text, control) = parse_identifier(ir_text)?; + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let ir_text = nom::character::complete::char(',')(ir_text)?.0; + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let (ir_text, data) = nom::multi::separated_list1( + nom::sequence::tuple(( + nom::character::complete::multispace0, + nom::character::complete::char(','), + nom::character::complete::multispace0, + )), + parse_identifier)(ir_text)?; let control = context.borrow_mut().get_node_id(control); - let data = context.borrow_mut().get_node_id(data); + let data = data.into_iter().map(|d| context.borrow_mut().get_node_id(d)).collect(); Ok((ir_text, Node::Return { control, data })) } @@ -719,7 +739,7 @@ fn parse_index<'a>( Ok((ir_text, idx)) } -fn parse_projection<'a>( +fn parse_control_projection<'a>( ir_text: &'a str, context: &RefCell<Context<'a>>, ) -> nom::IResult<&'a str, Node> { @@ -728,13 +748,29 @@ fn parse_projection<'a>( let control = context.borrow_mut().get_node_id(control); Ok(( ir_text, - Node::Projection { + Node::ControlProjection { control, selection: index, }, )) } +fn parse_data_projection<'a>( + ir_text: &'a str, + context: &RefCell<Context<'a>>, +) -> nom::IResult<&'a str, Node> { + let parse_usize = |x| parse_prim::<usize>(x, "1234567890"); + let (ir_text, (data, index)) = parse_tuple2(parse_identifier, parse_usize)(ir_text)?; + let data = context.borrow_mut().get_node_id(data); + Ok(( + ir_text, + Node::DataProjection { + data, + selection: index, + }, + )) +} + fn parse_read<'a>(ir_text: &'a 
str, context: &RefCell<Context<'a>>) -> nom::IResult<&'a str, Node> { let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::character::complete::char('(')(ir_text)?.0; @@ -991,7 +1027,8 @@ fn parse_constant<'a>( ) -> nom::IResult<&'a str, Constant> { let (ir_text, constant) = match ty { // There are not control constants. - Type::Control => Err(nom::Err::Error(nom::error::Error { + Type::Control + | Type::MultiReturn(_) => Err(nom::Err::Error(nom::error::Error { input: ir_text, code: nom::error::ErrorKind::IsNot, }))?, diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index 1ff890db..dca11fe7 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -441,12 +441,14 @@ fn typeflow( return inputs[0].clone(); } - if let Concrete(id) = inputs[1] { - if *id != function.return_type { - return Error(String::from("Return node's data input type must be the same as the function's return type.")); + for (idx, (input, return_type)) in inputs[1..].iter().zip(function.return_types.iter()).enumerate() { + if let Concrete(id) = input { + if *id != *return_type { + return Error(format!("Return node's data input at index {} does not match function's return type.", idx)); + } + } else if input.is_error() { + return (*input).clone(); } - } else if inputs[1].is_error() { - return inputs[1].clone(); } Concrete(get_type_id( @@ -759,7 +761,7 @@ fn typeflow( } } - Concrete(subst.type_subst(callee.return_type)) + Concrete(subst.build_return_type(&callee.return_types)) } Node::IntrinsicCall { intrinsic, args: _ } => { let num_params = match intrinsic { @@ -1061,13 +1063,35 @@ fn typeflow( TypeSemilattice::Error(msg) => TypeSemilattice::Error(msg), } } - Node::Projection { + Node::ControlProjection { control: _, selection: _, } => { // Type is the type of the _if node inputs[0].clone() } + Node::DataProjection { + data: _, + selection, + } => { + if let Concrete(type_id) = inputs[0] { + match &types[type_id.idx()] { + 
Type::MultiReturn(types) => { + if *selection >= types.len() { + return Error(String::from("Data projection's selection must be in range of the multi-return being indexed")); + } + return Concrete(*type_id); + } + _ => { + return Error(String::from( + "Data projection node must read from multi-return value.", + )); + } + } + } + + inputs[0].clone() + } Node::Undef { ty } => TypeSemilattice::Concrete(*ty), } } @@ -1138,6 +1162,11 @@ impl<'a> DCSubst<'a> { } } + fn build_return_type(&mut self, tys: &[TypeID]) -> TypeID { + let tys = tys.iter().map(|t| self.type_subst(*t)).collect(); + self.intern_type(Type::MultiReturn(tys)) + } + fn type_subst(&mut self, typ: TypeID) -> TypeID { match &self.types[typ.idx()] { Type::Control @@ -1172,6 +1201,7 @@ impl<'a> DCSubst<'a> { let new_elem = self.type_subst(elem); self.intern_type(Type::Array(new_elem, new_dims)) } + Type::MultiReturn(..) => panic!("A multi-return type should never be substituted"), } } diff --git a/hercules_ir/src/verify.rs b/hercules_ir/src/verify.rs index f188932e..b50ab0d2 100644 --- a/hercules_ir/src/verify.rs +++ b/hercules_ir/src/verify.rs @@ -251,11 +251,11 @@ fn verify_structure( Err(format!("If node must have 2 users, not {}.", users.len()))?; } if let ( - Node::Projection { + Node::ControlProjection { control: _, selection: result1, }, - Node::Projection { + Node::ControlProjection { control: _, selection: result2, }, @@ -290,7 +290,8 @@ fn verify_structure( Err("ThreadID node's control input must be a fork node.")?; } } - // Call nodes must depend on a region node. + // Call nodes must depend on a region node and its only users must + // be DataProjections. 
Node::Call { control, function: _, @@ -300,6 +301,11 @@ fn verify_structure( if !function.nodes[control.idx()].is_region() { Err("Call node's control input must be a region node.")?; } + for user in users { + if !function.nodes[user.idx()].is_data_projection() { + Err("Call node users must be DataProjection nodes.")?; + } + } } // Reduce nodes must depend on a join node. Node::Reduce { @@ -339,7 +345,7 @@ fn verify_structure( } let mut users_covered = bitvec![u8, Lsb0; 0; users.len()]; for user in users { - if let Node::Projection { + if let Node::ControlProjection { control: _, ref selection, } = function.nodes[user.idx()] -- GitLab From 53108467bf037b9aa1bcf036ee05d5521753d8fe Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Wed, 19 Feb 2025 08:16:53 -0600 Subject: [PATCH 03/19] Fixes to front-end and ir --- hercules_ir/src/def_use.rs | 4 ++-- hercules_ir/src/dot.rs | 8 ++++++++ hercules_ir/src/typecheck.rs | 6 +++--- juno_frontend/src/codegen.rs | 7 +++---- 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/hercules_ir/src/def_use.rs b/hercules_ir/src/def_use.rs index a99c8a23..99531345 100644 --- a/hercules_ir/src/def_use.rs +++ b/hercules_ir/src/def_use.rs @@ -157,8 +157,8 @@ pub fn get_uses(node: &Node) -> NodeUses { reduct, } => NodeUses::Three([*control, *init, *reduct]), Node::Return { control, data } => { - let mut uses: Vec<NodeID> = Vec::from(&data[..]); - uses.push(*control); + let mut uses: Vec<NodeID> = vec![*control]; + uses.extend(data); NodeUses::Variable(uses.into_boxed_slice()) } Node::Parameter { index: _ } => NodeUses::One([NodeID::new(0)]), diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs index 921a813d..a7f890f8 100644 --- a/hercules_ir/src/dot.rs +++ b/hercules_ir/src/dot.rs @@ -349,6 +349,14 @@ fn write_node<W: Write>( } } } + Node::ControlProjection { + control: _, + selection, + } + | Node::DataProjection { + data: _, + selection, + } => write!(&mut suffix, "{}", selection)?, _ => {} 
}; diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index dca11fe7..919da640 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -423,8 +423,8 @@ fn typeflow( control: _, data: _, } => { - if inputs.len() != 2 { - return Error(String::from("Return node must have exactly two inputs.")); + if inputs.len() < 1 { + return Error(String::from("Return node must have at least one input.")); } // Check type of control input first, since this may produce an @@ -1080,7 +1080,7 @@ fn typeflow( if *selection >= types.len() { return Error(String::from("Data projection's selection must be in range of the multi-return being indexed")); } - return Concrete(*type_id); + return Concrete(types[*selection]); } _ => { return Error(String::from( diff --git a/juno_frontend/src/codegen.rs b/juno_frontend/src/codegen.rs index 4e6f89e8..533fd268 100644 --- a/juno_frontend/src/codegen.rs +++ b/juno_frontend/src/codegen.rs @@ -276,7 +276,7 @@ impl CodeGenerator<'_> { block = block_ret; } let mut return_node = self.builder.allocate_node(); - return_node.build_return(block, vals); + return_node.build_return(block, vals.into()); self.builder.add_node(return_node); None } @@ -552,10 +552,9 @@ impl CodeGenerator<'_> { // Read each of the "inout values" and perform the SSA update let has_inouts = !inouts.is_empty(); for (idx, var) in inouts.into_iter().enumerate() { - let index = self.builder.builder.create_field_index(num_returns + idx); let mut proj = self.builder.allocate_node(); let proj_id = proj.id(); - proj.build_data_projection(call_id, index); + proj.build_data_projection(call_id, num_returns + idx); self.builder.add_node(proj); ssa.write_variable(var, block, proj_id); @@ -568,7 +567,7 @@ impl CodeGenerator<'_> { let mut proj = self.builder.allocate_node(); let proj_id = proj.id(); - proj.build_data_projection(call, index); + proj.build_data_projection(call, *index); self.builder.add_node(proj); (proj_id, block) -- GitLab From 
89cf63ebcec1055bf17a48ae6862891121515261 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Wed, 19 Feb 2025 08:19:51 -0600 Subject: [PATCH 04/19] Formatting --- hercules_ir/src/collections.rs | 2 + hercules_ir/src/def_use.rs | 10 +--- hercules_ir/src/dot.rs | 5 +- hercules_ir/src/parse.rs | 11 ++-- hercules_ir/src/typecheck.rs | 11 ++-- juno_frontend/src/codegen.rs | 10 ++-- juno_frontend/src/semant.rs | 100 +++++++++++++++++---------------- juno_frontend/src/types.rs | 8 ++- 8 files changed, 79 insertions(+), 78 deletions(-) diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index c4e71f8b..06a53fdb 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -155,6 +155,8 @@ pub fn collection_objects( typing: &ModuleTyping, callgraph: &CallGraph, ) -> CollectionObjects { + panic!("Collections analysis needs to be updated to handle multi-return"); + // Analyze functions in reverse topological order, since the analysis of a // function depends on all functions it calls. 
let mut collection_objects: CollectionObjects = BTreeMap::new(); diff --git a/hercules_ir/src/def_use.rs b/hercules_ir/src/def_use.rs index 99531345..e9ba4576 100644 --- a/hercules_ir/src/def_use.rs +++ b/hercules_ir/src/def_use.rs @@ -230,10 +230,7 @@ pub fn get_uses(node: &Node) -> NodeUses { control, selection: _, } => NodeUses::One([*control]), - Node::DataProjection { - data, - selection: _, - } => NodeUses::One([*data]), + Node::DataProjection { data, selection: _ } => NodeUses::One([*data]), Node::Undef { ty: _ } => NodeUses::One([NodeID::new(0)]), } } @@ -340,10 +337,7 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> { control, selection: _, } => NodeUsesMut::One([control]), - Node::DataProjection { - data, - selection: _, - } => NodeUsesMut::One([data]), + Node::DataProjection { data, selection: _ } => NodeUsesMut::One([data]), Node::Undef { ty: _ } => NodeUsesMut::Zero, } } diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs index a7f890f8..aff1f9c5 100644 --- a/hercules_ir/src/dot.rs +++ b/hercules_ir/src/dot.rs @@ -353,10 +353,7 @@ fn write_node<W: Write>( control: _, selection, } - | Node::DataProjection { - data: _, - selection, - } => write!(&mut suffix, "{}", selection)?, + | Node::DataProjection { data: _, selection } => write!(&mut suffix, "{}", selection)?, _ => {} }; diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index 42730f77..b41b1f6f 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -509,9 +509,13 @@ fn parse_return<'a>( nom::character::complete::char(','), nom::character::complete::multispace0, )), - parse_identifier)(ir_text)?; + parse_identifier, + )(ir_text)?; let control = context.borrow_mut().get_node_id(control); - let data = data.into_iter().map(|d| context.borrow_mut().get_node_id(d)).collect(); + let data = data + .into_iter() + .map(|d| context.borrow_mut().get_node_id(d)) + .collect(); Ok((ir_text, Node::Return { control, data })) } @@ -1027,8 +1031,7 @@ fn 
parse_constant<'a>( ) -> nom::IResult<&'a str, Constant> { let (ir_text, constant) = match ty { // There are not control constants. - Type::Control - | Type::MultiReturn(_) => Err(nom::Err::Error(nom::error::Error { + Type::Control | Type::MultiReturn(_) => Err(nom::Err::Error(nom::error::Error { input: ir_text, code: nom::error::ErrorKind::IsNot, }))?, diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index 919da640..2a3f9fb1 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -441,7 +441,11 @@ fn typeflow( return inputs[0].clone(); } - for (idx, (input, return_type)) in inputs[1..].iter().zip(function.return_types.iter()).enumerate() { + for (idx, (input, return_type)) in inputs[1..] + .iter() + .zip(function.return_types.iter()) + .enumerate() + { if let Concrete(id) = input { if *id != *return_type { return Error(format!("Return node's data input at index {} does not match function's return type.", idx)); @@ -1070,10 +1074,7 @@ fn typeflow( // Type is the type of the _if node inputs[0].clone() } - Node::DataProjection { - data: _, - selection, - } => { + Node::DataProjection { data: _, selection } => { if let Concrete(type_id) = inputs[0] { match &types[type_id.idx()] { Type::MultiReturn(types) => { diff --git a/juno_frontend/src/codegen.rs b/juno_frontend/src/codegen.rs index 533fd268..0902bc61 100644 --- a/juno_frontend/src/codegen.rs +++ b/juno_frontend/src/codegen.rs @@ -118,11 +118,11 @@ impl CodeGenerator<'_> { param_types.push(solver_inst.lower_type(&mut self.builder.builder, *ty)); } - let return_types = - func.return_types - .iter() - .map(|t| solver_inst.lower_type(&mut self.builder.builder, *t)) - .collect::<Vec<_>>(); + let return_types = func + .return_types + .iter() + .map(|t| solver_inst.lower_type(&mut self.builder.builder, *t)) + .collect::<Vec<_>>(); let (func_id, entry) = self .builder diff --git a/juno_frontend/src/semant.rs b/juno_frontend/src/semant.rs index ae696b4f..f0736d2b 100644 --- 
a/juno_frontend/src/semant.rs +++ b/juno_frontend/src/semant.rs @@ -751,7 +751,7 @@ fn analyze_program( let return_types = rets .into_iter() - .map(|ty| + .map(|ty| { match process_type( ty, num_dyn_const, @@ -769,7 +769,7 @@ fn analyze_program( types.new_primitive(types::Primitive::Unit) } } - ) + }) .collect::<Vec<_>>(); if !errors.is_empty() { @@ -778,7 +778,11 @@ fn analyze_program( // Compute the proper type accounting for the inouts (which become returns) let inout_types = inouts.iter().map(|e| e.get_type()).collect::<Vec<_>>(); - let pure_return_types = return_types.clone().into_iter().chain(inout_types.into_iter()).collect::<Vec<_>>(); + let pure_return_types = return_types + .clone() + .into_iter() + .chain(inout_types.into_iter()) + .collect::<Vec<_>>(); // Finally, we have a properly built environment and we can // start processing the body @@ -802,13 +806,7 @@ fn analyze_program( if return_types.is_empty() { // Insert return at the end body = Stmt::BlockStmt { - body: vec![ - body, - generate_return( - vec![], - &inouts, - ), - ], + body: vec![body, generate_return(vec![], &inouts)], }; } else { Err(singleton_error(ErrorMessage::SemanticError( @@ -1574,16 +1572,11 @@ fn process_stmt( labels: &mut StringTable, ) -> Result<(Stmt, bool), ErrorMessages> { match stmt { - parser::Stmt::LetStmt { - span, - var, - init, - } => { - let (_, pattern, typ) = - match var { - LetBind::Single { span, pattern, typ } => (span, Either::Left(pattern), typ), - LetBind::Multi { span, patterns } => (span, Either::Right(patterns), None), - }; + parser::Stmt::LetStmt { span, var, init } => { + let (_, pattern, typ) = match var { + LetBind::Single { span, pattern, typ } => (span, Either::Left(pattern), typ), + LetBind::Multi { span, patterns } => (span, Either::Right(patterns), None), + }; if typ.is_none() && init.is_none() { return Err(singleton_error(ErrorMessage::SemanticError( @@ -1662,15 +1655,20 @@ fn process_stmt( if return_types.len() != patterns.len() { return 
Err(singleton_error(ErrorMessage::SemanticError( span_to_loc(span, lexer), - format!("Expected {} pattern, found {} patterns", return_types.len(), patterns.len()), + format!( + "Expected {} pattern, found {} patterns", + return_types.len(), + patterns.len() + ), ))); } // Process each pattern after extracting the appropriate value from the call - for (index, (pat, ret_typ)) in - patterns.into_iter() - .zip(return_types.clone().into_iter()) - .enumerate() { + for (index, (pat, ret_typ)) in patterns + .into_iter() + .zip(return_types.clone().into_iter()) + .enumerate() + { let extract_var = env.uniq(); res.push(Stmt::AssignStmt { var: extract_var, @@ -1678,11 +1676,19 @@ fn process_stmt( call: Box::new(Expr::Variable { var, typ }), index, typ: ret_typ, - } + }, }); res.extend( process_irrefutable_pattern( - pat, false, extract_var, ret_typ, lexer, stringtab, env, types, false + pat, + false, + extract_var, + ret_typ, + lexer, + stringtab, + env, + types, + false, )? .0, ); @@ -1690,7 +1696,6 @@ fn process_stmt( } } - Ok((Stmt::BlockStmt { body: res }, true)) } parser::Stmt::ConstStmt { @@ -2257,7 +2262,8 @@ fn process_stmt( "Expected {} return values found {}", return_types.len(), vals.len(), - )))); + ), + ))); } let return_vals = vals @@ -2276,20 +2282,17 @@ fn process_stmt( ))) } }) - .fold(Ok(vec![]), - |res, val| { - match (res, val) { - (Ok(mut res), Ok(val)) => { - res.push(val); - Ok(res) - } - (Ok(_), Err(msg)) => Err(msg), - (Err(msg), Ok(_)) => Err(msg), - (Err(mut msgs), Err(msg)) => { - msgs.extend(msg); - Err(msgs) - } - } + .fold(Ok(vec![]), |res, val| match (res, val) { + (Ok(mut res), Ok(val)) => { + res.push(val); + Ok(res) + } + (Ok(_), Err(msg)) => Err(msg), + (Err(msg), Ok(_)) => Err(msg), + (Err(mut msgs), Err(msg)) => { + msgs.extend(msg); + Err(msgs) + } })?; // We return both the actual return values and the inout arguments @@ -4911,12 +4914,11 @@ fn process_expr( if !errors.is_empty() { Err(errors) } else { - let single_type = - if 
return_types.len() == 1 { - Some(return_types[0]) - } else { - None - }; + let single_type = if return_types.len() == 1 { + Some(return_types[0]) + } else { + None + }; let num_returns = return_types.len(); let call = Expr::CallExpr { func, diff --git a/juno_frontend/src/types.rs b/juno_frontend/src/types.rs index 6e59169d..4099c567 100644 --- a/juno_frontend/src/types.rs +++ b/juno_frontend/src/types.rs @@ -555,7 +555,6 @@ impl TypeSolver { // Note that MultReturn types never unify with anything (even itself), this is // intentional and makes it so that the only way MultiReturns can be used is to // destruct them - _ => false, } } @@ -707,8 +706,11 @@ impl TypeSolver { | TypeForm::Struct { name, .. } | TypeForm::Union { name, .. } => stringtab(*name), TypeForm::AnyOfKind { kind, .. } => kind.to_string(), - TypeForm::MultiReturn { types } => - types.iter().map(|t| self.to_string(*t, stringtab)).collect::<Vec<_>>().join(", "), + TypeForm::MultiReturn { types } => types + .iter() + .map(|t| self.to_string(*t, stringtab)) + .collect::<Vec<_>>() + .join(", "), } } -- GitLab From 24e0a9b9999038de504bed4e87ca7e5b72835b4f Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Wed, 19 Feb 2025 09:50:48 -0600 Subject: [PATCH 05/19] Starting on opts --- hercules_opt/src/ccp.rs | 10 +- hercules_opt/src/editor.rs | 20 ++-- hercules_opt/src/fork_guard_elim.rs | 4 +- hercules_opt/src/inline.rs | 23 +++- hercules_opt/src/loop_bound_canon.rs | 2 +- hercules_opt/src/outline.rs | 75 ++++--------- hercules_opt/src/pred.rs | 2 +- hercules_opt/src/sroa.rs | 152 +++++++++++++++------------ hercules_opt/src/unforkify.rs | 8 +- hercules_opt/src/utils.rs | 34 ++++-- 10 files changed, 176 insertions(+), 154 deletions(-) diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs index b626148c..87d23c11 100644 --- a/hercules_opt/src/ccp.rs +++ b/hercules_opt/src/ccp.rs @@ -697,6 +697,11 @@ fn ccp_flow_function( }), constant: ConstantLattice::bottom(), }, + // Data 
projections are uninterpretable. + Node::DataProjection { data, selection: _ } => CCPLattice { + reachability: inputs[data.idx()].reachability.clone(), + constant: ConstantLattice::bottom(), + }, Node::IntrinsicCall { intrinsic, args } => { let mut new_reachability = ReachabilityLattice::bottom(); let mut new_constant = ConstantLattice::top(); @@ -961,8 +966,9 @@ fn ccp_flow_function( constant: ConstantLattice::bottom(), } } - // Projection handles reachability when following an if or match. - Node::Projection { control, selection } => match &editor.func().nodes[control.idx()] { + // Control projection handles reachability when following an if or match. + Node::ControlProjection { control, selection } => match &editor.func().nodes[control.idx()] + { Node::If { control: _, cond } => { let cond_constant = &inputs[cond.idx()].constant; let if_reachability = &inputs[control.idx()].reachability; diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index 16e5c326..51c27275 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -69,7 +69,7 @@ pub struct FunctionEdit<'a: 'b, 'b> { // Compute a def-use map entries iteratively. updated_def_use: BTreeMap<NodeID, HashSet<NodeID>>, updated_param_types: Option<Vec<TypeID>>, - updated_return_type: Option<TypeID>, + updated_return_types: Option<Vec<TypeID>>, // Keep track of which deleted and added node IDs directly correspond. sub_edits: Vec<(NodeID, NodeID)>, } @@ -208,7 +208,7 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { added_labels: Vec::new().into(), updated_def_use: BTreeMap::new(), updated_param_types: None, - updated_return_type: None, + updated_return_types: None, sub_edits: vec![], }; @@ -228,7 +228,7 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { added_labels, updated_def_use, updated_param_types, - updated_return_type, + updated_return_types, sub_edits, } = populated_edit; @@ -358,8 +358,8 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { } // Step 9: update return type if necessary. 
- if let Some(return_type) = updated_return_type { - editor.function.return_type = return_type; + if let Some(return_types) = updated_return_types { + editor.function.return_types = return_types; } true @@ -768,6 +768,9 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { } Type::Summation(tys) => Constant::Summation(id, 0, self.add_zero_constant(tys[0])), Type::Array(_, _) => Constant::Array(id), + Type::MultiReturn(_) => { + panic!("PANIC: Can't create zero constant for multi-return types.") + } }; self.add_constant(constant_to_construct) } @@ -791,6 +794,9 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => { panic!("PANIC: Can't create one constant of a collection type.") } + Type::MultiReturn(_) => { + panic!("PANIC: Can't create one constant for multi-return types.") + } }; self.add_constant(constant_to_construct) } @@ -835,8 +841,8 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { self.updated_param_types = Some(tys); } - pub fn set_return_type(&mut self, ty: TypeID) { - self.updated_return_type = Some(ty); + pub fn set_return_types(&mut self, tys: Vec<TypeID>) { + self.updated_return_types = Some(tys); } } diff --git a/hercules_opt/src/fork_guard_elim.rs b/hercules_opt/src/fork_guard_elim.rs index df40e60f..c480f266 100644 --- a/hercules_opt/src/fork_guard_elim.rs +++ b/hercules_opt/src/fork_guard_elim.rs @@ -95,7 +95,7 @@ fn guarded_fork( }); // Whose predecessor is a read from an if - let Node::Projection { + let Node::ControlProjection { control: if_node, ref selection, } = function.nodes[control.idx()] @@ -226,7 +226,7 @@ fn guarded_fork( return None; }; // Other predecessor needs to be the other projection from the guard's if - let Node::Projection { + let Node::ControlProjection { control: if_node2, ref selection, } = function.nodes[other_pred.idx()] diff --git a/hercules_opt/src/inline.rs b/hercules_opt/src/inline.rs index f01b2366..2a5ad9af 100644 --- a/hercules_opt/src/inline.rs +++ b/hercules_opt/src/inline.rs @@ -1,5 +1,4 
@@ use std::collections::HashMap; -use std::iter::zip; use hercules_ir::callgraph::*; use hercules_ir::def_use::*; @@ -125,13 +124,25 @@ fn inline_func( assert_eq!(call_pred.as_ref().len(), 1); let call_pred = call_pred.as_ref()[0]; let called_func = called[&function].func(); + let call_users = editor.get_users(id); + let call_projs = call_users + .map(|node_id| { + ( + node_id, + editor.func().nodes[node_id.idx()] + .try_data_proj() + .expect("PANIC: Call user is not a data projection") + .1, + ) + }) + .collect::<Vec<_>>(); // We can't inline calls to functions with multiple returns. let Some(called_return) = single_return_nodes[function.idx()] else { continue; }; let called_return_uses = get_uses(&called_func.nodes[called_return.idx()]); let called_return_pred = called_return_uses.as_ref()[0]; - let called_return_data = called_return_uses.as_ref()[1]; + let called_return_data = &called_return_uses.as_ref()[1..]; // Perform the actual edit. editor.edit(|mut edit| { @@ -209,8 +220,12 @@ fn inline_func( } } - // Finally, delete the call node. - edit = edit.replace_all_uses(id, old_id_to_new_id(called_return_data))?; + // Replace and delete the call's (data projection) users and the call node + for (proj_id, proj_idx) in call_proj { + edit = + edit.replace_all_uses(proj_id, old_id_to_new_id(called_return_data[proj_idx]))?; + edit = edit.delete_node(proj_id)?; + } edit = edit.delete_node(control)?; edit = edit.delete_node(id)?; diff --git a/hercules_opt/src/loop_bound_canon.rs b/hercules_opt/src/loop_bound_canon.rs index 680236f1..a1ad6257 100644 --- a/hercules_opt/src/loop_bound_canon.rs +++ b/hercules_opt/src/loop_bound_canon.rs @@ -113,7 +113,7 @@ pub fn canonicalize_single_loop_bounds( // FIXME: This is quite fragile. 
let guard_info: Option<(NodeID, NodeID, NodeID, NodeID)> = (|| { - let Node::Projection { + let Node::ControlProjection { control, selection: _, } = editor.node(loop_pred) diff --git a/hercules_opt/src/outline.rs b/hercules_opt/src/outline.rs index 874e75e7..6a9b6084 100644 --- a/hercules_opt/src/outline.rs +++ b/hercules_opt/src/outline.rs @@ -180,12 +180,11 @@ pub fn outline( editor.edit(|mut edit| { // Step 2: assemble the outlined function. let u32_ty = edit.add_type(Type::UnsignedInteger32); - let return_types: Box<[_]> = return_idx_to_inside_id + let return_types: Vec<_> = return_idx_to_inside_id .iter() .map(|id| typing[id.idx()]) .chain(callee_succ_return_idx.map(|_| u32_ty)) .collect(); - let single_return = return_types.len() == 1; let mut outlined = Function { name: format!( @@ -198,11 +197,7 @@ pub fn outline( .map(|id| typing[id.idx()]) .chain(callee_pred_param_idx.map(|_| u32_ty)) .collect(), - return_type: if single_return { - return_types[0] - } else { - edit.add_type(Type::Product(return_types)) - }, + return_types, num_dynamic_constants: edit.get_num_dynamic_constant_params(), entry: false, nodes: vec![], @@ -363,7 +358,6 @@ pub fn outline( outlined.nodes.extend(select_top_phi_inputs); // Add the return nodes. 
- let cons_id = edit.add_zero_constant(outlined.return_type); for ((exit, _), dom_return_values) in zip(exit_points.iter(), exit_point_dom_return_values.iter()) { @@ -398,29 +392,10 @@ pub fn outline( data_ids.push(cons_node_id); } - // Build the return value - let construct_id = if single_return { - assert!(data_ids.len() == 1); - data_ids.pop().unwrap() - } else { - let mut construct_id = NodeID::new(outlined.nodes.len()); - outlined.nodes.push(Node::Constant { id: cons_id }); - for (idx, data) in data_ids.into_iter().enumerate() { - let write = Node::Write { - collect: construct_id, - data: data, - indices: Box::new([Index::Field(idx)]), - }; - construct_id = NodeID::new(outlined.nodes.len()); - outlined.nodes.push(write); - } - construct_id - }; - // Return the return product. outlined.nodes.push(Node::Return { control: convert_id(*exit), - data: construct_id, + data: data_ids.into(), }); } @@ -515,29 +490,25 @@ pub fn outline( (new_region_id, call_id) }; - // Create the read nodes from the call node to get the outputs of the - // outlined function (if there are multiple returned values) - let output_reads: Vec<_> = if single_return { - vec![call_id] - } else { - (0..return_idx_to_inside_id.len()) - .map(|idx| { - let read = Node::Read { - collect: call_id, - indices: Box::new([Index::Field(idx)]), - }; - edit.add_node(read) - }) - .collect() - }; - let indicator_read = callee_succ_return_idx.map(|idx| { - let read = Node::Read { - collect: call_id, - indices: Box::new([Index::Field(idx)]), + // Create the data projection nodes from the call node to get the outputs of the outlined + // function + let output_projs: Vec<_> = (0..return_idx_to_inside_id.len()) + .map(|idx| { + let proj = Node::DataProjection { + data: call_id, + selection: idx, + }; + edit.add_node(proj) + }) + .collect(); + let indicator_proj = callee_succ_return_idx.map(|idx| { + let proj = Node::DataProjection { + data: call_id, + selection: idx, }; - edit.add_node(read) + edit.add_node(proj) 
}); - for (old_id, new_id) in zip(return_idx_to_inside_id.iter(), output_reads.iter()) { + for (old_id, new_id) in zip(return_idx_to_inside_id.iter(), output_projs.iter()) { edit = edit.replace_all_uses(*old_id, *new_id)?; } @@ -554,18 +525,18 @@ pub fn outline( }); let cmp_id = edit.add_node(Node::Binary { op: BinaryOperator::EQ, - left: indicator_read.unwrap(), + left: indicator_proj.unwrap(), right: indicator_cons_node_id, }); let if_id = edit.add_node(Node::If { control: if_tree_acc, cond: cmp_id, }); - let false_id = edit.add_node(Node::Projection { + let false_id = edit.add_node(Node::ControlProjection { control: if_id, selection: 0, }); - let true_id = edit.add_node(Node::Projection { + let true_id = edit.add_node(Node::ControlProjection { control: if_id, selection: 1, }); diff --git a/hercules_opt/src/pred.rs b/hercules_opt/src/pred.rs index 644c69d0..ed7c3a85 100644 --- a/hercules_opt/src/pred.rs +++ b/hercules_opt/src/pred.rs @@ -26,7 +26,7 @@ pub fn predication(editor: &mut FunctionEditor, typing: &Vec<TypeID>) { // Look for two projections with the same branch. let preds = preds.into_iter().filter_map(|id| { nodes[id.idx()] - .try_proj() + .try_control_proj() .map(|(branch, selection)| (*id, branch, selection)) }); // Index projections by if branch. diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index 8865f863..eff0a729 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -27,9 +27,10 @@ use crate::*; * are broken up into ternary nodes for the individual fields * * - Call: the call node can use a product value as an argument to another - * function, and can produce a product value as a result. 
Argument values - * will be constructed at the call site and the return value will be broken - * into individual fields + * function, argument values will be constructed at the call site + * + * - DataProjection: data projection nodes can produce a product value that was + * returned by a function, we will break the value into individual fields * * - Read: the read node reads primitive fields from product values - these get * replaced by a direct use of the field value @@ -71,8 +72,9 @@ pub fn sroa( // First: determine all nodes which interact with products (as described above) let mut product_nodes: Vec<NodeID> = vec![]; - // We track call and return nodes separately since they (may) require constructing new products - // for the call's arguments or the return's value + // We track call, data projection, and return nodes separately since they (may) require + // constructing new products for the call's arguments, data projection's value, or a + // returned value let mut call_return_nodes: Vec<NodeID> = vec![]; for node in reverse_postorder { @@ -303,37 +305,57 @@ pub fn sroa( } } - // We add all calls to the call/return list and check their arguments later - Node::Call { .. } => call_return_nodes.push(*node), - Node::Return { control: _, data } if can_sroa_type(editor, types[&data]) => { - call_return_nodes.push(*node) + // We add all calls and returns to the call/return list and check their + // arguments/return values later + Node::Call { .. } | Node::Return { .. } => call_return_nodes.push(*node), + // We add DataProjetion nodes that produce SROAable values + Node::DataProjection { .. } if can_sroa_type(editor, types[&node]) => { + call_return_nodes.push(*node); } _ => (), } } - // Next, we handle calls and returns. For returns, we will insert nodes that read each field of - // the returned product and then write them into a new product. 
These writes are not put into - // the list of product nodes since they must remain but the reads are so that they will be - // replaced later on. - // For calls, we do a similar process for each (product) argument. Additionally, if the call - // returns a product, we create reads for each field in that product and store it into our - // field map + // Next, we handle calls and returns. For returns, for each returned value that is a product, + // we will insert nodes that read each field of it and then write them into a new product. + // The writes we create are not put into the list of product nodes since they must remain but + // the reads are put in the list so that they will be replaced later on. + // For calls, we do a similar process for each (product) argument. + // For data projection that produce product values, we create reads for each field of that + // product and store it into our field map for node in call_return_nodes { match &editor.func().nodes[node.idx()] { Node::Return { control, data } => { - assert!(can_sroa_type(editor, types[&data])); let control = *control; - let new_data = reconstruct_product(editor, types[&data], *data, &mut product_nodes); - editor.edit(|mut edit| { - let new_return = edit.add_node(Node::Return { - control, - data: new_data, + let data = data.clone(); + + let (new_data, changed) = + data.into_iter() + .fold((vec![], false), |(mut vals, changed), val_id| { + if !can_sroa_type(editor, types[val_id]) { + vals.push(*val_id); + (vals, changed) + } else { + vals.push(reconstruct_product( + editor, + types[val_id], + *val_id, + &mut product_nodes, + )); + (vals, true) + } + }); + if changed { + editor.edit(|mut edit| { + let new_return = edit.add_node(Node::Return { + control, + data: new_data.into(), + }); + edit.sub_edit(node, new_return); + edit.delete_node(node) }); - edit.sub_edit(node, new_return); - edit.delete_node(node) - }); + } } Node::Call { control, @@ -346,53 +368,42 @@ pub fn sroa( let dynamic_constants = 
dynamic_constants.clone(); let args = args.clone(); - // If the call returns a product that we can sroa, we generate reads for each field - let fields = if can_sroa_type(editor, types[&node]) { - Some(generate_reads(editor, types[&node], node)) - } else { - None - }; + let (new_args, changed) = + args.into_iter() + .fold((vec![], false), |(mut vals, changed), arg| { + if !can_sroa_type(editor, types[arg]) { + vals.push(*arg); + (vals, changed) + } else { + vals.push(reconstruct_product( + editor, + types[arg], + *arg, + &mut product_nodes, + )); + (vals, true) + } + }); - let mut new_args = vec![]; - for arg in args { - if can_sroa_type(editor, types[&arg]) { - new_args.push(reconstruct_product( - editor, - types[&arg], - arg, - &mut product_nodes, - )); - } else { - new_args.push(arg); - } - } - editor.edit(|mut edit| { - let new_call = edit.add_node(Node::Call { - control, - function, - dynamic_constants, - args: new_args.into(), - }); - edit.sub_edit(node, new_call); - let edit = edit.replace_all_uses(node, new_call)?; - let edit = edit.delete_node(node)?; - - // Since we've replaced uses of calls with the new node, we update the type - // information so that we can retrieve the type of the new call if needed - // Because the other nodes we've created so far are only used in very - // particular ways (i.e. are not used by arbitrary nodes) we don't need their - // type information but do for the new calls - types.insert(new_call, types[&node]); - - match fields { - None => {} - Some(fields) => { - field_map.insert(new_call, fields); - } - } + if changed { + editor.edit(|mut edit| { + let new_call = edit.add_node(Node::Call { + control, + function, + dynamic_constants, + args: new_args.into(), + }); + edit.sub_edit(node, new_call); + let edit = edit.replace_all_uses(node, new_call)?; + let edit = edit.delete_node(node)?; - Ok(edit) - }); + Ok(edit) + }); + } + } + Node::DataProjection { .. 
} => { + assert!(can_sroa_type(editor, types[&node])); + field_map.insert(node, generate_reads(editor, types[&node], node)); } _ => panic!("Processing non-call or return node"), } @@ -1055,6 +1066,7 @@ fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID { add_const!(editor, Constant::Array(typ)) } Type::Control => panic!("Cannot create constant of control type"), + Type::MultiReturn(_) => panic!("Cannot create constant of multi-return type"), } } diff --git a/hercules_opt/src/unforkify.rs b/hercules_opt/src/unforkify.rs index b44ed8df..2d6cf7b3 100644 --- a/hercules_opt/src/unforkify.rs +++ b/hercules_opt/src/unforkify.rs @@ -205,11 +205,11 @@ pub fn unforkify( control: fork_control, cond: guard_cond_id, }; - let guard_taken_proj = Node::Projection { + let guard_taken_proj = Node::ControlProjection { control: guard_if_id, selection: 1, }; - let guard_skipped_proj = Node::Projection { + let guard_skipped_proj = Node::ControlProjection { control: guard_if_id, selection: 0, }; @@ -224,11 +224,11 @@ pub fn unforkify( control: join_control, cond: neq_id, }; - let proj_back = Node::Projection { + let proj_back = Node::ControlProjection { control: if_id, selection: 1, }; - let proj_exit = Node::Projection { + let proj_exit = Node::ControlProjection { control: if_id, selection: 0, }; diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs index 1806d5c7..c165c0a0 100644 --- a/hercules_opt/src/utils.rs +++ b/hercules_opt/src/utils.rs @@ -244,31 +244,42 @@ pub fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> { if returns.len() == 1 { return Some(returns[0]); } - let preds_before_returns: Vec<NodeID> = returns + let preds_before_returns: Box<[NodeID]> = returns .iter() .map(|ret_id| get_uses(&editor.func().nodes[ret_id.idx()]).as_ref()[0]) .collect(); - let data_to_return: Vec<NodeID> = returns - .iter() - .map(|ret_id| get_uses(&editor.func().nodes[ret_id.idx()]).as_ref()[1]) + + let num_return_data = 
editor.func().return_types.len(); + let data_to_return: Vec<Box<[NodeID]>> = (0..num_return_data) + .map(|idx| { + returns + .iter() + .map(|ret_id| get_uses(&editor.func().nodes[ret_id.idx()]).as_ref()[idx + 1]) + .collect() + }) .collect(); // All of the old returns get replaced in a single edit. let mut new_return = None; editor.edit(|mut edit| { let region = edit.add_node(Node::Region { - preds: preds_before_returns.into_boxed_slice(), - }); - let phi = edit.add_node(Node::Phi { - control: region, - data: data_to_return.into_boxed_slice(), + preds: preds_before_returns, }); + let return_vals = data_to_return + .into_iter() + .map(|data| { + edit.add_node(Node::Phi { + control: region, + data, + }) + }) + .collect(); for ret in returns { edit = edit.delete_node(ret)?; } new_return = Some(edit.add_node(Node::Return { control: region, - data: phi, + data: return_vals, })); Ok(edit) }); @@ -293,10 +304,11 @@ pub fn ensure_between_control_flow(editor: &mut FunctionEditor) -> Option<NodeID .filter(|id| editor.func().nodes[id.idx()].is_control()) .next() .unwrap(); - let Node::Return { control, data } = editor.func().nodes[ret.idx()] else { + let Node::Return { control, ref data } = editor.func().nodes[ret.idx()] else { panic!("PANIC: A Hercules function with only two control nodes must have a return node be the other control node, other than the start node.") }; assert_eq!(control, NodeID::new(0), "PANIC: The only other control node in a Hercules function, the return node, is not using the start node."); + let data = data.clone(); let mut region_id = None; editor.edit(|mut edit| { edit = edit.delete_node(ret)?; -- GitLab From 3e5f7614f050d433c73bab6aaa2acf83c61dfa3c Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Wed, 19 Feb 2025 14:40:57 -0600 Subject: [PATCH 06/19] Fixes to optimizations --- hercules_ir/src/ir.rs | 8 + hercules_opt/src/inline.rs | 14 +- hercules_opt/src/interprocedural_sroa.rs | 594 +++++++---------------- 
hercules_opt/src/sroa.rs | 61 +-- juno_scheduler/src/ir.rs | 2 + 5 files changed, 214 insertions(+), 465 deletions(-) diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 68fdc26c..7a0158fb 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -1318,6 +1318,14 @@ impl Node { } } + pub fn try_data_proj(&self) -> Option<(NodeID, usize)> { + if let Node::DataProjection { data, selection } = self { + Some((*data, *selection)) + } else { + None + } + } + pub fn try_phi(&self) -> Option<(NodeID, &[NodeID])> { if let Node::Phi { control, data } = self { Some((*control, data)) diff --git a/hercules_opt/src/inline.rs b/hercules_opt/src/inline.rs index 2a5ad9af..99187dd2 100644 --- a/hercules_opt/src/inline.rs +++ b/hercules_opt/src/inline.rs @@ -211,6 +211,13 @@ fn inline_func( }, )?; + // Replace and delete the call's (data projection) users + for (proj_id, proj_idx) in call_projs { + let proj_val = called_return_data[proj_idx]; + edit = edit.replace_all_uses(proj_id, old_id_to_new_id(proj_val))?; + edit = edit.delete_node(proj_id)?; + } + // Stitch uses of parameter nodes in the inlined function to the IDs // of arguments provided to the call node. for (node_idx, node) in called_func.nodes.iter().enumerate() { @@ -220,12 +227,7 @@ fn inline_func( } } - // Replace and delete the call's (data projection) users and the call node - for (proj_id, proj_idx) in call_proj { - edit = - edit.replace_all_uses(proj_id, old_id_to_new_id(called_return_data[proj_idx]))?; - edit = edit.delete_node(proj_id)?; - } + // Finally delete the call node edit = edit.delete_node(control)?; edit = edit.delete_node(id)?; diff --git a/hercules_opt/src/interprocedural_sroa.rs b/hercules_opt/src/interprocedural_sroa.rs index 944ef8fd..32fa9cc8 100644 --- a/hercules_opt/src/interprocedural_sroa.rs +++ b/hercules_opt/src/interprocedural_sroa.rs @@ -5,466 +5,196 @@ use hercules_ir::ir::*; use crate::*; -/** - * Given an editor for each function in a module, return V s.t. 
- * V[i] = true iff every call node to the function with index i - * is editable. If there are no calls to this function, V[i] = true. - */ -fn get_editable_callsites(editors: &mut Vec<FunctionEditor>) -> Vec<bool> { - let mut callsites_editable = vec![true; editors.len()]; - for editor in editors { - for (idx, (_, function, _, _)) in editor - .func() - .nodes - .iter() - .enumerate() - .filter_map(|(idx, node)| node.try_call().map(|c| (idx, c))) - { - if !editor.is_mutable(NodeID::new(idx)) { - callsites_editable[function.idx()] = false; - } - } - } - callsites_editable -} - -/** - * Given a type tree, return a Vec containing all leaves which are not units. - */ -fn get_nonempty_leaves(edit: &FunctionEdit, type_id: &TypeID) -> Vec<TypeID> { - let ty = edit.get_type(*type_id).clone(); - match ty { - Type::Product(type_ids) => { - let mut leaves = vec![]; - for type_id in type_ids { - leaves.extend(get_nonempty_leaves(&edit, &type_id)) - } - leaves - } - _ => vec![*type_id], - } -} - -/** - * Given a `source` NodeID which produces a product containing - * all nonempty leaves of the type tree for `type_id` in order, build - * a node producing the `type_id`. +/* + * Top-level function for running interprocedural analysis. * - * `offset` represents the index at which to begin reading - * elements of the `source` product. + * IP SROA expects that all nodes in all functions provided to it can be edited, + * since it needs to be able to modify both the functions whose types are being + * changed and call sites of those functions. What functions to run IP SROA on + * is therefore specified by a separate argument. * - * Returns a 3-tuple of - * 1. Node producing the `type` - * 2. "Next" offset, i.e. `offset` + number of reads performed to build (1) - * 3. 
List of node IDs which read `source` (tracked so that these will not - * be replaced by replace_all_uses_where) + * This optimization also takes an allow_sroa_arrays arguments (like non-IP + * SROA) which controls whether it will break up products of arrays. */ -fn build_uncompressed_product( - edit: &mut FunctionEdit, - source: &NodeID, - type_id: &TypeID, - offset: usize, -) -> (NodeID, usize, Vec<NodeID>) { - let ty = edit.get_type(*type_id).clone(); - match ty { - Type::Product(child_type_ids) => { - // Step 1. Create an empty constant for the type. We'll write - // child values into this constant. - let empty_constant_id = edit.add_zero_constant(*type_id); - let empty_constant_node = edit.add_node(Node::Constant { - id: empty_constant_id, - }); - // Step 2. Build a node that generates each inner type. - // Since `source` contains nonempty leaves *in order*, - // we must process inner types in order; as part of this, - // inner type i+1 must read from where inner type i left off, - // hence we track the `current_offset` at which we are reading. - // Similarly, to combine results of all recursive calls, - // we keep the invariant that, at iteration i+1, currently_writing_to - // is an instance of `type_id` for which the first i elements - // have been populated based on inorder nonempty leaves - // (and, at iteration 0, it is empty). 
- let mut current_offset = offset; - let mut currently_writing_to = empty_constant_node; - let mut readers = vec![]; - for (idx, child_type_id) in child_type_ids.iter().enumerate() { - let (child_data, next_offset, child_readers) = - build_uncompressed_product(edit, source, child_type_id, current_offset); - current_offset = next_offset; - currently_writing_to = edit.add_node(Node::Write { - collect: currently_writing_to, - data: child_data, - indices: Box::new([Index::Field(idx)]), - }); - readers.extend(child_readers) - } - (currently_writing_to, current_offset, readers) - } - _ => { - // If the type is not a product, then we've reached a nonempty - // leaf, which we must read from source. Since this is a single - // read, the new offset increases by only 1. - let reader = edit.add_node(Node::Read { - collect: *source, - indices: Box::new([Index::Field(offset)]), - }); - (reader, offset + 1, vec![reader]) +pub fn interprocedural_sroa( + editors: &mut Vec<FunctionEditor>, + types: &Vec<Vec<TypeID>>, + func_selection: &Vec<bool>, + allow_sroa_arrays: bool, +) { + let can_sroa_type = |editor: &FunctionEditor, typ: TypeID| { + editor.get_type(typ).is_product() + && (allow_sroa_arrays || !type_contains_array(editor, typ)) + }; + + let callsites = get_callsites(editors); + + for ((func_id, apply), callsites) in (0..func_selection.len()).map(FunctionID::new).zip(func_selection.iter()).zip(callsites.into_iter()) { + if !apply { + continue; } - } -} -/** - * Given a node with a product value, read the product's values - * *in order* into the nonempty leaves of a product type represented - * by type_id. Returns the ID of the resulting node, as well as the IDs - * of all nodes which read from `node_id`. 
- */ -fn uncompress_product( - edit: &mut FunctionEdit, - node_id: &NodeID, - type_id: &TypeID, -) -> (NodeID, Vec<NodeID>) { - let (uncompressed_value, _, readers) = build_uncompressed_product(edit, node_id, type_id, 0); - (uncompressed_value, readers) -} - -/** -* Let `read_from` be a node with a value of type `type_id`. -* Let `source` be a product value. -* Returns a node representing the value obtained by writing -* nonempty leaves of `read_from` *in order* into `source`, -* starting at `offset`. -* -* `source` should be a product type with at least enough indices -* to support this operation. Typically, `build_compressed_product` -* should be called initially with a `source` created by adding a -* zero constant for the flattened `type_id`. -* -* Returns: -* 1. The ID of the node to which all nonempty leaves have been written -* 2. The first offset after `offset` which was not written to. -*/ -fn build_compressed_product( - mut edit: &mut FunctionEdit, - source: &NodeID, - type_id: &TypeID, - offset: usize, - read_from: &NodeID, -) -> (NodeID, usize) { - let ty = edit.get_type(*type_id).clone(); - match ty { - Type::Product(child_type_ids) => { - // Iterate through child types in order. For each type, construct - // a node that reads the corresponding value from `read_from`, - // and pass it as the node to read from in the recursive call. 
- let mut next_offset = offset; - let mut next_destination = *source; - for (idx, child_type_id) in child_type_ids.iter().enumerate() { - let child_value = edit.add_node(Node::Read { - collect: *read_from, - indices: Box::new([Index::Field(idx)]), - }); - (next_destination, next_offset) = build_compressed_product( - &mut edit, - &next_destination, - &child_type_id, - next_offset, - &child_value, - ); + let editor: &mut FunctionEditor = &mut editors[func_id.idx()]; + let return_types = &editor.func().return_types.to_vec(); + + // We determine the new return types of the function and track a map + // that tells us how the old return values are constructed from the + // new ones + let mut new_return_types = vec![]; + let mut old_return_type_map = vec![]; + let mut changed = false; + + for ret_typ in return_types.iter() { + if !can_sroa_type(editor, *ret_typ) { + old_return_type_map.push(IndexTree::Leaf(new_return_types.len())); + new_return_types.push(*ret_typ); + } else { + let (types, index) = sroa_type(editor, *ret_typ, new_return_types.len()); + old_return_type_map.push(index); + new_return_types.extend(types); + changed = true; } - (next_destination, next_offset) } - _ => { - let writer = edit.add_node(Node::Write { - collect: *source, - data: *read_from, - indices: Box::new([Index::Field(offset)]), - }); - (writer, offset + 1) - } - } -} - -/** - * Given a node which has a value of the given type (which must be a product) - * generate a new product node created by inserting nonempty leaves of the - * source node *in order*. Returns the ID of this node, as well as the ID of - * its type. 
- */ -fn compress_product( - edit: &mut FunctionEdit, - node_id: &NodeID, - type_id: &TypeID, -) -> (NodeID, TypeID) { - let nonempty_leaves = get_nonempty_leaves(&edit, &type_id); - let compressed_type = Type::Product(nonempty_leaves.into_boxed_slice()); - let compressed_type_id = edit.add_type(compressed_type); - - let empty_compressed_constant_id = edit.add_zero_constant(compressed_type_id); - let empty_compressed_node_id = edit.add_node(Node::Constant { - id: empty_compressed_constant_id, - }); - - let (compressed_value, _) = - build_compressed_product(edit, &empty_compressed_node_id, type_id, 0, node_id); - - (compressed_value, compressed_type_id) -} - -fn compress_return_products(editors: &mut Vec<FunctionEditor>, all_callsites_editable: &Vec<bool>) { - // Track whether we successfully applied edits to return statements, - // so that callsites are only modified when returns were. This is - // initialized to false, so that `is_compressed` is false when - // the corresponding entry in `callsites_editable` is false. - let mut is_compressed = vec![false; editors.len()]; - let old_return_type_ids: Vec<_> = editors - .iter() - .map(|editor| editor.func().return_type) - .collect(); - - // Step 1. Track mapping of dynamic constant indexes to ids, so that - // we can substitute when generating empty constants later. The reason - // this works is that the following property is satisfied: - // Let f and g be two functions such that f has d_f dynamic constants - // and g has d_g dynamic constants. Wlog assume d_f < d_g. Then, the - // first d_f dynamic constants of g are the dynamic constants of f. - // For any call node, the ith dynamic constant in the node is provided - // for the ith dynamic constant of the function called. 
So, when we need - // to take a type and replace d function dynamic constants with their - // values from a call, it suffices to look at the first d entries of - // dc_param_idx_to_dc_id to get the id of the dynamic constants in the function, - // and then replace dc_param_idx_to_dc_id[i] with call.dynamic_constants[i], - // for all i. - let max_num_dc_params = editors - .iter() - .map(|editor| editor.func().num_dynamic_constants) - .max() - .unwrap(); - let mut dc_args = vec![]; - editors[0].edit(|mut edit| { - dc_args = (0..max_num_dc_params as usize) - .map(|i| edit.add_dynamic_constant(DynamicConstant::Parameter(i))) - .collect(); - Ok(edit) - }); - // Step 2. Modify the return type of all editors corresponding to a function - // for which we can edit every callsite, and the return type is a product. - for (idx, editor) in editors.iter_mut().enumerate() { - if !all_callsites_editable[idx] { + // If the return type is not changed by IP SROA, skip to the next function + if !changed { continue; } - let old_return_id = NodeID::new( - (0..editor.func().nodes.len()) - .filter(|idx| editor.func().nodes[*idx].is_return()) - .next() - .unwrap(), - ); - let old_return_type_id = old_return_type_ids[idx]; - - is_compressed[idx] = editor.get_type(editor.func().return_type).is_product() - && editor.edit(|mut edit| { - let return_node = edit.get_node(old_return_id); - let (return_control, return_data) = return_node.try_return().unwrap(); - - let (compressed_data_id, compressed_type_id) = - compress_product(&mut edit, &return_data, &old_return_type_id); + // Now, modify each return in the current function and the return type + let return_nodes = editor.func().nodes + .iter() + .enumerate() + .filter_map(|(idx, node)| if node.try_return().is_some() { + Some(NodeID::new(idx)) + } else { + None + }) + .collect::<Vec<_>>(); + let success = editor.edit(|mut edit| { + for node in return_nodes { + let Node::Return { control, data } = edit.get_node(node) else { + panic!() + }; + let 
control = *control; + let data = data.to_vec(); + + let mut new_data = vec![]; + for (idx, (data_id, update_info)) in data.into_iter().zip(old_return_type_map.iter()).enumerate() { + if let IndexTree::Leaf(new_idx) = update_info { + // Unchanged return value + assert!(new_data.len() == *new_idx); + new_data.push(data_id); + } else { + // SROA'd return value + let reads = generate_reads_edit(&mut edit, return_types[idx], data_id); + reads.zip(update_info).for_each(|_, (read_id, ret_idx)| { + assert!(new_data.len() == **ret_idx); + new_data.push(*read_id); + }); + } + } - edit.set_return_type(compressed_type_id); - let new_return_id = edit.add_node(Node::Return { - control: return_control, - data: compressed_data_id, + let new_ret = edit.add_node(Node::Return { + control, + data: new_data.into(), }); - edit.sub_edit(old_return_id, new_return_id); - let edit = edit.replace_all_uses(old_return_id, new_return_id)?; - edit.delete_node(old_return_id) - }); - } - - // Step 3: For every editor, update all mutable callsites corresponding to - // calls to functions which have been compressed. Since we only compress returns - // for functions for which every callsite is mutable, this should never fail, - // so we panic if it does. - for (_, editor) in editors.iter_mut().enumerate() { - let call_node_ids: Vec<_> = (0..editor.func().nodes.len()) - .map(NodeID::new) - .filter(|id| editor.func().nodes[id.idx()].is_call()) - .filter(|id| editor.is_mutable(*id)) - .collect(); - - for call_node_id in call_node_ids { - let (_, function_id, ref dynamic_constants, _) = - editor.func().nodes[call_node_id.idx()].try_call().unwrap(); - if !is_compressed[function_id.idx()] { - continue; + edit.sub_edit(node, new_ret); + edit = edit.delete_node(node)?; } - // Before creating the uncompressed product, we must update - // the type of the uncompressed product to reflect the dynamic - // constants provided when calling the function. 
Since we can - // only replace one constant at a time, we need to map - // constants to dummy values, and then map these to the - // replacement values (this prevents the case of replacements - // (0->1), (1->2) causing conflicts when we have [0, 1], we should - // get [1, 2], not [2, 2], which a naive loop would generate). - - // A similar loop exists in the inline pass but at the node level. - // If this becomes a common pattern, it would be worth creating - // a better abstraction around bulk replacement. - - let new_dcs = (*dynamic_constants).to_vec(); - let old_dcs = dc_args[..new_dcs.len()].to_vec(); - assert_eq!(old_dcs.len(), new_dcs.len()); - let substs = old_dcs - .into_iter() - .zip(new_dcs.into_iter()) - .collect::<HashMap<_, _>>(); - - let edit_successful = editor.edit(|mut edit| { - let substituted = substitute_dynamic_constants_in_type( - &substs, - old_return_type_ids[function_id.idx()], - &mut edit, - ); - - let (expanded_product, readers) = - uncompress_product(&mut edit, &call_node_id, &substituted); - edit.replace_all_uses_where(call_node_id, expanded_product, |id| { - !readers.contains(id) - }) - }); - - if !edit_successful { - panic!("Tried and failed to edit mutable callsite!"); + edit.set_return_types(new_return_types); + + Ok(edit) + }); + assert!(success, "IP SROA expects to be able to edit everything, specify what functions to IP SROA via the func_selection argument"); + + // Finally, update calls of this function + // In particular, we actually don't have to update the call node at all but have to update + // its DataProjection users + for (caller, callsite) in callsites { + let editor = &mut editors[caller.idx()]; + assert!(editor.func_id() == caller); + let projs = editor.get_users(callsite).collect::<Vec<_>>(); + for proj_id in projs { + let Node::DataProjection { data: _, selection } = editor.node(proj_id) else { + panic!("Call has a non data-projection user"); + }; + let new_return_info = &old_return_type_map[*selection]; + let typ 
= types[caller.idx()][proj_id.idx()]; + replace_returned_value(editor, proj_id, typ, new_return_info, callsite); } } } } -fn remove_return_singletons(editors: &mut Vec<FunctionEditor>, all_callsites_editable: &Vec<bool>) { - // Track whether we removed a singleton product from the return of each - // editor's function. Defaults to false so that if the function was not - // edited (i.e. because not all callsites are editable), then no callsites - // will be edited. - let mut singleton_removed = vec![false; editors.len()]; - let old_return_type_ids: Vec<_> = editors - .iter() - .map(|editor| editor.func().return_type) - .collect(); - - // Step 1. For all editors which correspond to a function for whic hall - // callsites are editable, modify their return type by extracting the - // value from the singleton and returning it directly. - for (idx, editor) in editors.iter_mut().enumerate() { - if !all_callsites_editable[idx] { - continue; - } - - let return_type = editor.get_type(old_return_type_ids[idx]).clone(); - singleton_removed[idx] = match return_type { - Type::Product(tys) if tys.len() == 1 && all_callsites_editable[idx] => { - let old_return_id = NodeID::new( - (0..editor.func().nodes.len()) - .filter(|idx| editor.func().nodes[*idx].is_return()) - .next() - .unwrap(), - ); - - editor.edit(|mut edit| { - let (old_control, old_data) = - edit.get_node(old_return_id).try_return().unwrap(); - - let extracted_singleton_id = edit.add_node(Node::Read { - collect: old_data, - indices: Box::new([Index::Field(0)]), - }); - let new_return_id = edit.add_node(Node::Return { - control: old_control, - data: extracted_singleton_id, - }); - edit.sub_edit(old_return_id, new_return_id); - edit.set_return_type(tys[0]); - - edit.delete_node(old_return_id) - }) +fn sroa_type(editor: &FunctionEditor, typ: TypeID, type_index: usize) -> (Vec<TypeID>, IndexTree<usize>) { + match &*editor.get_type(typ) { + Type::Product(ts) => { + let mut res_types = vec![]; + let mut index = type_index; + 
let mut children = vec![]; + for t in ts { + let (types, child) = sroa_type(editor, *t, index); + index += types.len(); + res_types.extend(types); + children.push(child); } - _ => false, + (res_types, IndexTree::Node(children)) } + _ => (vec![typ], IndexTree::Leaf(type_index)), } +} - // Step 2. For each editor, find all callsites and reconstruct - // the singleton product at each if the return of the corresponding - // function was modified. This should always succeed since we only - // edited functions for which all callsites were mutable, so panic - // if an edit does not succeed. - for editor in editors.iter_mut() { - let call_node_ids: Vec<_> = (0..editor.func().nodes.len()) - .map(NodeID::new) - .filter(|id| editor.func().nodes[id.idx()].is_call()) - .filter(|id| editor.is_mutable(*id)) - .collect(); - - for call_node_id in call_node_ids { - let (_, function, dc_args, _) = - editor.func().nodes[call_node_id.idx()].try_call().unwrap(); - - let dc_args = dc_args.to_vec(); - - if singleton_removed[function.idx()] { - let edit_successful = editor.edit(|mut edit| { - let dc_params = (0..dc_args.len()) - .map(|param_idx| { - edit.add_dynamic_constant(DynamicConstant::Parameter(param_idx)) - }) - .collect::<Vec<_>>(); - let substs = dc_params - .into_iter() - .zip(dc_args.into_iter()) - .collect::<HashMap<_, _>>(); - - let substituted = substitute_dynamic_constants_in_type( - &substs, - old_return_type_ids[function.idx()], - &mut edit, - ); - let empty_constant_id = edit.add_zero_constant(substituted); - let empty_node_id = edit.add_node(Node::Constant { - id: empty_constant_id, - }); - - let restored_singleton_id = edit.add_node(Node::Write { - collect: empty_node_id, - data: call_node_id, - indices: Box::new([Index::Field(0)]), - }); - edit.replace_all_uses_where(call_node_id, restored_singleton_id, |id| { - *id != restored_singleton_id - }) - }); +// Returns a list for each function of the call sites of that function +fn get_callsites(editors: 
&Vec<FunctionEditor>) -> Vec<Vec<(FunctionID, NodeID)>> { + let mut callsites = vec![vec![]; editors.len()]; - if !edit_successful { - panic!("Tried and failed to edit mutable callsite!"); - } - } + for editor in editors { + let caller = editor.func_id(); + for (callsite, (_, callee, _, _)) in editor + .func() + .nodes + .iter() + .enumerate() + .filter_map(|(idx, node)| node.try_call().map(|c| (idx, c))) { + assert!(editor.is_mutable(NodeID::new(callsite)), "IP SROA expects to be able to edit everything, specify what functions to IP SROA via the func_selection argument"); + callsites[callee.idx()].push((caller, NodeID::new(callsite))); } } + + callsites } -pub fn interprocedural_sroa(editors: &mut Vec<FunctionEditor>) { - // SROA is implemented in two phases. First, we flatten (or "compress") - // all product return types, so that they are only depth 1 products, - // and do not contain any empty products. - // Next, if any return type is now a singleton product, we - // remove the singleton and just retun the type directly. - // We only apply these changes to functions for which - // all their callsites are editable. - let all_callsites_editable = get_editable_callsites(editors); - compress_return_products(editors, &all_callsites_editable); - remove_return_singletons(editors, &all_callsites_editable); +// Replaces a projection node (from before the function signature change) based on the of_new_call +// description (which tells how to construct the value from the new returned values). 
+fn replace_returned_value( + editor: &mut FunctionEditor, + proj_id: NodeID, + proj_typ: TypeID, + of_new_call: &IndexTree<usize>, + call_node: NodeID, +) { + let constant = generate_constant(editor, proj_typ); + + let success = editor.edit(|mut edit| { + let mut new_val = edit.add_node(Node::Constant { + id: constant, + }); + of_new_call.for_each(|idx, selection| { + let new_proj = edit.add_node(Node::DataProjection { + data: call_node, + selection: *selection, + }); + new_val = edit.add_node(Node::Write { + collect: new_val, + data: new_proj, + indices: idx.clone().into(), + }); + }); - // Run DCE to prevent issues with schedule repair. - for editor in editors.iter_mut() { - dce(editor); - } + edit = edit.replace_all_uses(proj_id, new_val)?; + edit.delete_node(proj_id) + }); + assert!(success); } diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index eff0a729..68a1b25e 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -328,19 +328,19 @@ pub fn sroa( match &editor.func().nodes[node.idx()] { Node::Return { control, data } => { let control = *control; - let data = data.clone(); + let data = data.to_vec(); let (new_data, changed) = data.into_iter() .fold((vec![], false), |(mut vals, changed), val_id| { - if !can_sroa_type(editor, types[val_id]) { - vals.push(*val_id); + if !can_sroa_type(editor, types[&val_id]) { + vals.push(val_id); (vals, changed) } else { vals.push(reconstruct_product( editor, - types[val_id], - *val_id, + types[&val_id], + val_id, &mut product_nodes, )); (vals, true) @@ -366,19 +366,19 @@ pub fn sroa( let control = *control; let function = *function; let dynamic_constants = dynamic_constants.clone(); - let args = args.clone(); + let args = args.to_vec(); let (new_args, changed) = args.into_iter() .fold((vec![], false), |(mut vals, changed), arg| { - if !can_sroa_type(editor, types[arg]) { - vals.push(*arg); + if !can_sroa_type(editor, types[&arg]) { + vals.push(arg); (vals, changed) } else { 
vals.push(reconstruct_product( editor, - types[arg], - *arg, + types[&arg], + arg, &mut product_nodes, )); (vals, true) @@ -736,7 +736,7 @@ pub fn sroa( }); } -fn type_contains_array(editor: &FunctionEditor, typ: TypeID) -> bool { +pub fn type_contains_array(editor: &FunctionEditor, typ: TypeID) -> bool { match &*editor.get_type(typ) { Type::Array(_, _) => true, Type::Product(ts) | Type::Summation(ts) => { @@ -978,20 +978,31 @@ fn reconstruct_product( // Given a node val of type typ, adds nodes to the function which read all (leaf) fields of val and // returns an IndexTree that tracks the nodes reading each leaf field -fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> IndexTree<NodeID> { - let res = generate_reads_at_index(editor, typ, val, vec![]); - res +pub fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> IndexTree<NodeID> { + let mut result = None; + + editor.edit(|mut edit| { + result = Some(generate_reads_edit(&mut edit, typ, val)); + Ok(edit) + }); + + result.unwrap() +} + +// The same as generate_reads but for if we have a FunctionEdit rather than a FunctionEditor +pub fn generate_reads_edit(edit: &mut FunctionEdit, typ: TypeID, val: NodeID) -> IndexTree<NodeID> { + generate_reads_at_index_edit(edit, typ, val, vec![]) } // Given a node val of type which at the indices idx has type typ, construct reads of all (leaf) // fields within this sub-value of val and return the correspondence list -fn generate_reads_at_index( - editor: &mut FunctionEditor, +fn generate_reads_at_index_edit( + edit: &mut FunctionEdit, typ: TypeID, val: NodeID, idx: Vec<Index>, ) -> IndexTree<NodeID> { - let ts: Option<Vec<TypeID>> = if let Some(ts) = editor.get_type(typ).try_product() { + let ts: Option<Vec<TypeID>> = if let Some(ts) = edit.get_type(typ).try_product() { Some(ts.into()) } else { None @@ -1004,22 +1015,18 @@ fn generate_reads_at_index( for (i, t) in ts.into_iter().enumerate() { let mut new_idx = idx.clone(); 
new_idx.push(Index::Field(i)); - fields.push(generate_reads_at_index(editor, t, val, new_idx)); + fields.push(generate_reads_at_index_edit(edit, t, val, new_idx)); } IndexTree::Node(fields) } else { // For non-product types, we've reached a leaf so we generate the read and return it's // information - let mut read_id = None; - editor.edit(|mut edit| { - read_id = Some(edit.add_node(Node::Read { - collect: val, - indices: idx.clone().into(), - })); - Ok(edit) + let read_id = edit.add_node(Node::Read { + collect: val, + indices: idx.into(), }); - IndexTree::Leaf(read_id.expect("Add node canont fail")) + IndexTree::Leaf(read_id) } } @@ -1035,7 +1042,7 @@ macro_rules! add_const { } // Given a type, builds a default constant of that type -fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID { +pub fn generate_constant(editor: &mut FunctionEditor, typ: TypeID) -> ConstantID { let t = editor.get_type(typ).clone(); match t { diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index 4ac5a732..a888cf74 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -56,6 +56,7 @@ impl Pass { Pass::Print => num == 1, Pass::Rename => num == 1, Pass::SROA => num == 0 || num == 1, + Pass::InterproceduralSROA => num == 0 || num == 1, Pass::Xdot => num == 0 || num == 1, _ => num == 0, } @@ -70,6 +71,7 @@ impl Pass { Pass::Print => "1", Pass::Rename => "1", Pass::SROA => "0 or 1", + Pass::InterproceduralSROA => "0 or 1", Pass::Xdot => "0 or 1", _ => "0", } -- GitLab From 46d62811a72ae590bea2dc7dc382f42e21e88a97 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Wed, 19 Feb 2025 15:13:59 -0600 Subject: [PATCH 07/19] Pass manager fixes and example --- Cargo.lock | 10 ++++ Cargo.toml | 1 + juno_samples/multi_return/Cargo.toml | 21 ++++++++ juno_samples/multi_return/build.rs | 15 ++++++ juno_samples/multi_return/src/cpu.sch | 31 +++++++++++ juno_samples/multi_return/src/gpu.sch | 26 ++++++++++ 
juno_samples/multi_return/src/main.rs | 22 ++++++++ juno_samples/multi_return/src/multi_return.jn | 32 ++++++++++++ juno_scheduler/src/pm.rs | 52 +++++++++++-------- 9 files changed, 189 insertions(+), 21 deletions(-) create mode 100644 juno_samples/multi_return/Cargo.toml create mode 100644 juno_samples/multi_return/build.rs create mode 100644 juno_samples/multi_return/src/cpu.sch create mode 100644 juno_samples/multi_return/src/gpu.sch create mode 100644 juno_samples/multi_return/src/main.rs create mode 100644 juno_samples/multi_return/src/multi_return.jn diff --git a/Cargo.lock b/Cargo.lock index c438e846..32dc6a0e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1347,6 +1347,16 @@ dependencies = [ "with_builtin_macros", ] +[[package]] +name = "juno_multi_return" +version = "0.1.0" +dependencies = [ + "async-std", + "hercules_rt", + "juno_build", + "with_builtin_macros", +] + [[package]] name = "juno_patterns" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 42d28135..01f8cc13 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ members = [ "juno_samples/matmul", "juno_samples/median_window", "juno_samples/multi_device", + "juno_samples/multi_return", "juno_samples/patterns", "juno_samples/products", "juno_samples/rodinia/backprop", diff --git a/juno_samples/multi_return/Cargo.toml b/juno_samples/multi_return/Cargo.toml new file mode 100644 index 00000000..0fb3de94 --- /dev/null +++ b/juno_samples/multi_return/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "juno_multi_return" +version = "0.1.0" +authors = ["Aaron Councilman <aaronjc4@illinois.edu>"] +edition = "2021" + +[[bin]] +name = "juno_multi_return" +path = "src/main.rs" + +[features] +cuda = ["juno_build/cuda", "hercules_rt/cuda"] + +[build-dependencies] +juno_build = { path = "../../juno_build" } + +[dependencies] +juno_build = { path = "../../juno_build" } +hercules_rt = { path = "../../hercules_rt" } +with_builtin_macros = "0.1.0" +async-std = "*" diff --git 
a/juno_samples/multi_return/build.rs b/juno_samples/multi_return/build.rs new file mode 100644 index 00000000..3a8f9b1c --- /dev/null +++ b/juno_samples/multi_return/build.rs @@ -0,0 +1,15 @@ +use juno_build::JunoCompiler; + +fn main() { + JunoCompiler::new() + .file_in_src("multi_return.jn") + .unwrap() + .schedule_in_src(if cfg!(feature = "cuda") { + "gpu.sch" + } else { + "cpu.sch" + }) + .unwrap() + .build() + .unwrap(); +} diff --git a/juno_samples/multi_return/src/cpu.sch b/juno_samples/multi_return/src/cpu.sch new file mode 100644 index 00000000..03fb2585 --- /dev/null +++ b/juno_samples/multi_return/src/cpu.sch @@ -0,0 +1,31 @@ +gvn(*); +phi-elim(*); +dce(*); + +ip-sroa(*); +sroa(*); +dce(*); + +forkify(*); +fork-guard-elim(*); +gvn(*); +dce(*); + +inline(*); +delete-uncalled(*); + +let out = auto-outline(*); +cpu(out.rolling_sum_prod); + +fork-fusion(out.rolling_sum_prod); +gvn(*); +dce(*); + +float-collections(*); + +unforkify(*); +gvn(*); +ccp(*); +dce(*); + +gcm(*); diff --git a/juno_samples/multi_return/src/gpu.sch b/juno_samples/multi_return/src/gpu.sch new file mode 100644 index 00000000..e733551d --- /dev/null +++ b/juno_samples/multi_return/src/gpu.sch @@ -0,0 +1,26 @@ +gvn(*); +phi-elim(*); +dce(*); + +ip-sroa(*); +sroa(*); +dce(*); + +forkify(*); +fork-guard-elim(*); +gvn(*); +dce(*); + +inline(*); +delete-uncalled(*); + +let out = auto-outline(*); +gpu(out.rolling_sum_prod); + +fork-fusion(out.rolling_sum_prod); +gvn(*); +dce(*); + +float-collections(*); + +gcm(*); diff --git a/juno_samples/multi_return/src/main.rs b/juno_samples/multi_return/src/main.rs new file mode 100644 index 00000000..63479dba --- /dev/null +++ b/juno_samples/multi_return/src/main.rs @@ -0,0 +1,22 @@ +#![feature(concat_idents)] + +juno_build::juno!("median"); + +use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo}; + +fn main() { + let m = vec![ + 86, 72, 14, 5, 55, 25, 98, 89, 3, 66, 44, 81, 27, 3, 40, 18, 4, 57, 93, 34, 70, 50, 50, 18, + 34, + ]; + let m = 
HerculesImmBox::from(m.as_slice()); + + let mut r = runner!(median_window); + let res = async_std::task::block_on(async { r.run(m.to()).await }); + assert_eq!(res, 57); +} + +#[test] +fn test_median_window() { + main() +} diff --git a/juno_samples/multi_return/src/multi_return.jn b/juno_samples/multi_return/src/multi_return.jn new file mode 100644 index 00000000..a49df91c --- /dev/null +++ b/juno_samples/multi_return/src/multi_return.jn @@ -0,0 +1,32 @@ +fn rolling_sum<t: number, n: usize>(x: t[n]) -> t, t[n + 1] { + let rolling_sum: t[n + 1]; + let sum = 0; + + for i in 0..n { + rolling_sum[i] = sum; + sum += x[i]; + } + rolling_sum[n] = sum; + + return (sum, rolling_sum); +} + +fn rolling_prod<t: number, n: usize>(x: t[n]) -> t, t[n + 1] { + let rolling_prod: t[n + 1]; + let prod = 1; + + for i in 0..n { + rolling_prod[i] = prod; + prod *= x[i]; + } + rolling_prod[n] = prod; + + return prod, rolling_prod; +} + +#[entry] +fn rolling_sum_prod<n: usize>(x: f32[n]) -> f32[n + 1], f32[n + 1] { + let rsum = rolling_sum::<_, n>(x).1; + let _, rprod = rolling_prod::<_, n>(x); + return rsum, rprod; +} diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index d5c0af27..94f90048 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2077,16 +2077,32 @@ fn run_pass( pm.clear_analyses(); } Pass::InterproceduralSROA => { - assert!(args.is_empty()); - if let Some(_) = selection { - return Err(SchedulerError::PassError { - pass: "interproceduralSROA".to_string(), - error: "must be applied to the entire module".to_string(), - }); - } + let sroa_with_arrays = match args.get(0) { + Some(Value::Boolean { val }) => *val, + Some(_) => { + return Err(SchedulerError::PassError { + pass: "sroa".to_string(), + error: "expected boolean argument".to_string(), + }); + } + None => false, + }; + + let selection = selection_of_functions(pm, selection) + .ok_or_else(|| { + SchedulerError::PassError { + pass: "xdot".to_string(), + error: "expected coarse-grained 
selection (can't partially xdot a function)".to_string(), + } + })?; + let mut bool_selection = vec![false; pm.functions.len()]; + selection.into_iter().for_each(|func| bool_selection[func.idx()] = true); + + pm.make_typing(); + let typing = pm.typing.take().unwrap(); let mut editors = build_editors(pm); - interprocedural_sroa(&mut editors); + interprocedural_sroa(&mut editors, &typing, &bool_selection, sroa_with_arrays); for func in editors { changed |= func.modified(); @@ -2720,21 +2736,15 @@ fn run_pass( None => true, }; - let mut bool_selection = vec![]; - if let Some(selection) = selection { - bool_selection = vec![false; pm.functions.len()]; - for loc in selection { - let CodeLocation::Function(id) = loc else { - return Err(SchedulerError::PassError { + let selection = selection_of_functions(pm, selection) + .ok_or_else(|| { + SchedulerError::PassError { pass: "xdot".to_string(), error: "expected coarse-grained selection (can't partially xdot a function)".to_string(), - }); - }; - bool_selection[id.idx()] = true; - } - } else { - bool_selection = vec![true; pm.functions.len()]; - } + } + })?; + let mut bool_selection = vec![false; pm.functions.len()]; + selection.into_iter().for_each(|func| bool_selection[func.idx()] = true); pm.make_reverse_postorders(); if force_analyses { -- GitLab From 369b8beee80194974a63352e454b354d90c7e739 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Wed, 19 Feb 2025 15:14:15 -0600 Subject: [PATCH 08/19] Collections analysis --- hercules_ir/src/collections.rs | 99 +++++++++++++++++++--------------- 1 file changed, 57 insertions(+), 42 deletions(-) diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index 06a53fdb..fb3e6bbd 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -40,7 +40,7 @@ use crate::*; pub enum CollectionObjectOrigin { Parameter(usize), Constant(NodeID), - Call(NodeID), + DataProjection(NodeID), Undef(NodeID), } @@ -50,7 +50,7 @@ 
define_id_type!(CollectionObjectID); pub struct FunctionCollectionObjects { objects_per_node: Vec<Vec<CollectionObjectID>>, mutated: Vec<Vec<NodeID>>, - returned: Vec<CollectionObjectID>, + returned: Vec<Vec<CollectionObjectID>>, origins: Vec<CollectionObjectOrigin>, } @@ -92,8 +92,8 @@ impl FunctionCollectionObjects { .map(CollectionObjectID::new) } - pub fn returned_objects(&self) -> &Vec<CollectionObjectID> { - &self.returned + pub fn returned_objects(&self, selection: usize) -> &Vec<CollectionObjectID> { + &self.returned[selection] } pub fn is_mutated(&self, object: CollectionObjectID) -> bool { @@ -155,8 +155,6 @@ pub fn collection_objects( typing: &ModuleTyping, callgraph: &CallGraph, ) -> CollectionObjects { - panic!("Collections analysis needs to be updated to handle multi-return"); - // Analyze functions in reverse topological order, since the analysis of a // function depends on all functions it calls. let mut collection_objects: CollectionObjects = BTreeMap::new(); @@ -167,8 +165,9 @@ pub fn collection_objects( let typing = &typing[func_id.idx()]; let reverse_postorder = &reverse_postorders[func_id.idx()]; - // Find collection objects originating at parameters, constants, calls, - // or undefs. Each node may *originate* one collection object. + // Find collection objects originating at parameters, constants, + // data projections (of calls), or undefs. + // Each of these nodes may *originate* one collection object. 
let param_origins = func .param_types .iter() @@ -183,24 +182,29 @@ pub fn collection_objects( Node::Constant { id: _ } if !types[typing[idx].idx()].is_primitive() => { Some(CollectionObjectOrigin::Constant(NodeID::new(idx))) } - Node::Call { - control: _, - function: callee, - dynamic_constants: _, - args: _, - } if { + Node::DataProjection { data, selection } => { + let Node::Call { + control: _, + function: callee, + dynamic_constants: _, + args: _, + } = func.nodes[data.idx()] else { + panic!("Data-projection's data is not a call node"); + }; + let fco = &collection_objects[&callee]; - fco.returned - .iter() - .any(|returned| fco.origins[returned.idx()].try_parameter().is_none()) - } => - { - // If the callee may return a new collection object, then - // this call node originates a single collection object. The - // node may output multiple collection objects, say if the - // callee may return an object passed in as a parameter - - // this is determined later. - Some(CollectionObjectOrigin::Call(NodeID::new(idx))) + if fco.returned[*selection] + .iter() + .any(|returned| fco.origins[returned.idx()].try_parameter().is_some()) { + // If the callee may return a new collection object, then + // this data projection node originates a single collection object. The + // node may output multiple collection objects, say if the + // callee may return an object passed in as a parameter - + // this is determined later. + Some(CollectionObjectOrigin::DataProjection(NodeID::new(idx))) + } else { + None + } } Node::Undef { ty: _ } if !types[typing[idx].idx()].is_primitive() => { Some(CollectionObjectOrigin::Undef(NodeID::new(idx))) @@ -218,8 +222,8 @@ pub fn collection_objects( // - Reduce: reduces over an object, similar to phis. // - Parameter: may originate an object. // - Constant: may originate an object. - // - Call: may originate an object and may return an object passed in as - // a parameter. 
+ // - DataProjection: may originate an object and may return an object + // passed in to its associated call as a parameter. // - LibraryCall: may return an object passed in as a parameter, but may // not originate an object. // - Read: may extract a smaller object from the input - this is @@ -230,7 +234,13 @@ pub fn collection_objects( // mutation. // - Undef: may originate a dummy object. // - Ternary (select): selects between two objects, may output either. - let lattice = forward_dataflow(func, reverse_postorder, |inputs, id| { + let lattice = dataflow_global(func, reverse_postorder, |global_input, id| { + let inputs = get_uses(&func.nodes[id.idx()]) + .as_ref() + .iter() + .map(|id| &global_input[id.idx()]) + .collect::<Vec<_>>(); + match func.nodes[id.idx()] { Node::Phi { control: _, @@ -269,22 +279,27 @@ pub fn collection_objects( objs: obj.into_iter().collect(), } } - Node::Call { - control: _, - function: callee, - dynamic_constants: _, - args: _, - } if !types[typing[id.idx()].idx()].is_primitive() => { + Node::DataProjection { data, selection } + if !types[typing[id.idx()].idx()].is_primitive() => { + let Node::Call { + control: _, + function: callee, + dynamic_constants: _, + ref args, + } = func.nodes[data.idx()] else { + panic!(); + }; + let new_obj = origins .iter() - .position(|origin| *origin == CollectionObjectOrigin::Call(id)) + .position(|origin| *origin == CollectionObjectOrigin::DataProjection(id)) .map(CollectionObjectID::new); let fco = &collection_objects[&callee]; let param_objs = fco - .returned + .returned[selection] .iter() .filter_map(|returned| fco.origins[returned.idx()].try_parameter()) - .map(|param_index| inputs[param_index + 1]); + .map(|param_index| &global_input[args[param_index].idx()]); let mut objs: BTreeSet<_> = new_obj.into_iter().collect(); for param_objs in param_objs { @@ -326,16 +341,16 @@ pub fn collection_objects( .map(|l| l.objs.into_iter().collect()) .collect(); - // Look at the collection objects that each return 
may take as input. - let mut returned: BTreeSet<CollectionObjectID> = BTreeSet::new(); + // Look at the collection objects that each return value may take as input. + let mut returned: Vec<BTreeSet<CollectionObjectID>> = vec![BTreeSet::new(); func.return_types.len()]; for node in func.nodes.iter() { if let Node::Return { control: _, data } = node { - for node in data { - returned.extend(&objects_per_node[node.idx()]); + for (idx, node) in data.iter().enumerate() { + returned[idx].extend(&objects_per_node[node.idx()]); } } } - let returned = returned.into_iter().collect(); + let returned = returned.into_iter().map(|set| set.into_iter().collect()).collect(); // Determine which objects are potentially mutated. let mut mutated = vec![vec![]; origins.len()]; -- GitLab From fac3f220759f59826bed3c7d5828c0f4fb468a10 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Wed, 19 Feb 2025 16:10:25 -0600 Subject: [PATCH 09/19] Fix collections and gcm --- hercules_cg/src/lib.rs | 4 +- hercules_ir/src/collections.rs | 4 ++ hercules_ir/src/ir.rs | 8 +++ hercules_ir/src/parse.rs | 10 +-- hercules_opt/src/gcm.rs | 64 +++++++++++++------ juno_samples/multi_return/src/cpu.sch | 5 ++ juno_samples/multi_return/src/multi_return.jn | 2 +- 7 files changed, 70 insertions(+), 27 deletions(-) diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs index af2420d8..446231de 100644 --- a/hercules_cg/src/lib.rs +++ b/hercules_cg/src/lib.rs @@ -23,7 +23,7 @@ pub const LARGEST_ALIGNMENT: usize = 32; */ pub fn get_type_alignment(types: &Vec<Type>, ty: TypeID) -> usize { match types[ty.idx()] { - Type::Control => panic!(), + Type::Control | Type::MultiReturn(_) => panic!(), Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => 1, Type::Integer16 | Type::UnsignedInteger16 | Type::BFloat16 => 2, Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => 4, @@ -46,7 +46,7 @@ pub fn get_type_alignment(types: &Vec<Type>, ty: TypeID) -> usize { pub type 
FunctionNodeColors = ( BTreeMap<NodeID, Device>, Vec<Option<Device>>, - Option<Device>, + Vec<Option<Device>>, ); pub type NodeColors = BTreeMap<FunctionID, FunctionNodeColors>; diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index fb3e6bbd..cc0703ab 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -96,6 +96,10 @@ impl FunctionCollectionObjects { &self.returned[selection] } + pub fn all_returned_objects(&self) -> impl Iterator<Item = CollectionObjectID> + '_ { + self.returned.iter().flat_map(|colls| colls.iter().map(|c| *c)) + } + pub fn is_mutated(&self, object: CollectionObjectID) -> bool { !self.mutators(object).is_empty() } diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 7a0158fb..3d625a39 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -863,6 +863,14 @@ impl Type { } } + pub fn is_multireturn(&self) -> bool { + if let Type::MultiReturn(_) = self { + true + } else { + false + } + } + pub fn is_bool(&self) -> bool { self == &Type::Boolean } diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index a019f4d3..d61ff6e7 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -248,11 +248,11 @@ fn parse_function<'a>( let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::bytes::complete::tag("->")(ir_text)?.0; let (ir_text, return_types) = nom::multi::separated_list1( - nom::sequence::tuple(( + ( nom::character::complete::multispace0, nom::character::complete::char(','), nom::character::complete::multispace0, - )), + ), |text| parse_type_id(text, context), ).parse(ir_text)?; let (ir_text, nodes) = nom::multi::many1(|x| parse_node(x, context)).parse(ir_text)?; @@ -506,13 +506,13 @@ fn parse_return<'a>( let ir_text = nom::character::complete::char(',')(ir_text)?.0; let ir_text = nom::character::complete::multispace0(ir_text)?.0; let (ir_text, data) = nom::multi::separated_list1( - nom::sequence::tuple(( + ( 
nom::character::complete::multispace0, nom::character::complete::char(','), nom::character::complete::multispace0, - )), + ), parse_identifier, - )(ir_text)?; + ).parse(ir_text)?; let control = context.borrow_mut().get_node_id(control); let data = data .into_iter() diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index f2405893..c2ec4e94 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -127,7 +127,7 @@ pub fn gcm( let mut alignments = vec![]; Ref::map(editor.get_types(), |types| { for idx in 0..types.len() { - if types[idx].is_control() { + if types[idx].is_control() || types[idx].is_multireturn() { alignments.push(0); } else { alignments.push(get_type_alignment(types, TypeID::new(idx))); @@ -255,6 +255,15 @@ fn basic_blocks( dynamic_constants: _, args: _, } => bbs[idx] = Some(control), + Node::DataProjection { + data, + selection: _, + } => { + let Node::Call { control, .. } = function.nodes[data.idx()] else { + panic!(); + }; + bbs[idx] = Some(control); + } Node::Parameter { index: _ } => bbs[idx] = Some(NodeID::new(0)), _ if function.nodes[idx].is_control() => bbs[idx] = Some(NodeID::new(idx)), _ => {} @@ -508,7 +517,7 @@ fn basic_blocks( && objects[&func_id] .objects(id) .into_iter() - .any(|obj| objects[&func_id].returned_objects().contains(obj)); + .any(|obj| objects[&func_id].all_returned_objects().any(|ret| ret == *obj)); let old_nest = loops .header_of(location) .map(|header| loops.nesting(header).unwrap()); @@ -646,9 +655,9 @@ fn terminating_reads<'a>( ref args, } => Box::new(args.into_iter().enumerate().filter_map(move |(idx, arg)| { let objects = &objects[&callee]; - let returns = objects.returned_objects(); + let mut returns = objects.all_returned_objects(); let param_obj = objects.param_to_object(idx)?; - if !objects.is_mutated(param_obj) && !returns.contains(&param_obj) { + if !objects.is_mutated(param_obj) && !returns.any(|ret| ret == param_obj) { Some(*arg) } else { None @@ -692,9 +701,9 @@ fn forwarding_reads<'a>(
ref args, } => Box::new(args.into_iter().enumerate().filter_map(move |(idx, arg)| { let objects = &objects[&callee]; - let returns = objects.returned_objects(); + let mut returns = objects.all_returned_objects(); let param_obj = objects.param_to_object(idx)?; - if !objects.is_mutated(param_obj) && returns.contains(&param_obj) { + if !objects.is_mutated(param_obj) && returns.any(|ret| ret == param_obj) { Some(*arg) } else { None @@ -1218,7 +1227,7 @@ fn color_nodes( let mut func_colors = ( BTreeMap::new(), vec![None; editor.func().param_types.len()], - None, + vec![None; editor.func().return_types.len()], ); // Assigning nodes to devices is tricky due to function calls. Technically, @@ -1320,17 +1329,31 @@ fn color_nodes( equations.push((UTerm::Node(*arg), UTerm::Device(device))); } } + } + Node::DataProjection { + data, + selection, + } => { + let Node::Call { + control: _, + function: callee, + dynamic_constants: _, + ref args, + } = &nodes[data.idx()] else { + panic!() + }; - // If the callee has a definite device for the returned value, - // add an equation for the call node itself. - if let Some(device) = node_colors[&callee].2 { + // If the callee has a definite device for this returned value, + // add an equation for the data projection node itself. + if let Some(device) = node_colors[&callee].2[selection] { equations.push((UTerm::Node(id), UTerm::Device(device))); } - // For any object that may be returned by the callee that - // originates as a parameter in the callee, the device of the - // corresponding argument and call node itself must be equal. - for ret in objects[&callee].returned_objects() { + // For any object that may be returned in this position by the + // callee that originates as a parameter in the callee, the + // device of the corresponding argument and the data projection + // must be equal.
+ for ret in objects[&callee].returned_objects(selection) { if let Some(idx) = objects[&callee].origin(*ret).try_parameter() { equations.push((UTerm::Node(args[idx]), UTerm::Node(id))); } @@ -1365,11 +1388,13 @@ fn color_nodes( { assert!(func_colors.1[index].is_none(), "PANIC: Found multiple parameter nodes for the same index in GCM. Please just run GVN first."); func_colors.1[index] = Some(*device); - } else if let Node::Return { control: _, data } = nodes[id.idx()] - && let Some(device) = func_colors.0.get(&data) - { - assert!(func_colors.2.is_none(), "PANIC: Found multiple return nodes in GCM. Contact Russel if you see this, it's an easy fix."); - func_colors.2 = Some(*device); + } else if let Node::Return { control: _, ref data } = nodes[id.idx()] { + for (idx, val) in data.iter().enumerate() { + if let Some(device) = func_colors.0.get(val) { + assert!(func_colors.2[idx].is_none(), "PANIC: Found multiple return nodes in GCM. Contact Russel if you see this, it's an easy fix."); + func_colors.2[idx] = Some(*device); + } + } } } Some(func_colors) @@ -1420,6 +1445,7 @@ fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) -> let ty = edit.get_type(ty_id).clone(); let size = match ty { Type::Control => panic!(), + Type::MultiReturn(_) => panic!(), Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => { edit.add_dynamic_constant(DynamicConstant::Constant(1)) } diff --git a/juno_samples/multi_return/src/cpu.sch b/juno_samples/multi_return/src/cpu.sch index 03fb2585..972405f5 100644 --- a/juno_samples/multi_return/src/cpu.sch +++ b/juno_samples/multi_return/src/cpu.sch @@ -4,6 +4,10 @@ dce(*); ip-sroa(*); sroa(*); + +ip-sroa[true](rolling_sum); +sroa[true](rolling_sum, rolling_sum_prod); + dce(*); forkify(*); @@ -29,3 +33,4 @@ ccp(*); dce(*); gcm(*); +xdot[true](*); diff --git a/juno_samples/multi_return/src/multi_return.jn b/juno_samples/multi_return/src/multi_return.jn index a49df91c..84bab015 100644 --- 
a/juno_samples/multi_return/src/multi_return.jn +++ b/juno_samples/multi_return/src/multi_return.jn @@ -1,4 +1,4 @@ -fn rolling_sum<t: number, n: usize>(x: t[n]) -> t, t[n + 1] { +fn rolling_sum<t: number, n: usize>(x: t[n]) -> (t, t[n + 1]) { let rolling_sum: t[n + 1]; let sum = 0; -- GitLab From 0c59f141585a275d5ff826351c1f233c91af1798 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Thu, 20 Feb 2025 11:31:42 -0600 Subject: [PATCH 10/19] Multi return cpu functions --- hercules_cg/src/cpu.rs | 82 +++++++++++++++++++++++++++++++++--------- 1 file changed, 66 insertions(+), 16 deletions(-) diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs index 27daf2a1..45c0f467 100644 --- a/hercules_cg/src/cpu.rs +++ b/hercules_cg/src/cpu.rs @@ -60,19 +60,33 @@ struct LLVMBlock { impl<'a> CPUContext<'a> { fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> { // Dump the function signature. - if self.types[self.function.return_type.idx()].is_primitive() { - write!( - w, - "define dso_local {} @{}(", - self.get_type(self.function.return_type), - self.function.name - )?; + if self.function.return_types.len() == 1 { + let return_type = self.function.return_types[0]; + if self.types[return_type.idx()].is_primitive() { + write!( + w, + "define dso_local {} @{}(", + self.get_type(return_type), + self.function.name + )?; + } else { + write!( + w, + "define dso_local nonnull noundef {} @{}(", + self.get_type(return_type), + self.function.name + )?; + } } else { write!( w, - "define dso_local nonnull noundef {} @{}(", - self.get_type(self.function.return_type), - self.function.name + "%return.{} = type {{ {} }}\n", + self.function.name, + self.function.return_types + .iter() + .map(|t| self.get_type(*t)) + .collect::<Vec<_>>() + .join(", "), )?; } let mut first_param = true; @@ -110,6 +124,19 @@ impl<'a> CPUContext<'a> { )?; } } + // Lastly, if the function has multiple returns, is a pointer to the return struct + if 
self.function.return_types.len() != 1 { + if first_param { + first_param = false; + } else { + write!(w, ", ")?; + } + write!( + w, + "ptr noalias nofree nonnull noundef sret(%return.{}) %ret.ptr", + self.function.name, + )?; + } write!(w, ") {{\n")?; let mut blocks: BTreeMap<_, _> = (0..self.function.nodes.len()) @@ -171,7 +198,7 @@ impl<'a> CPUContext<'a> { // successor and are otherwise simple. Node::Start | Node::Region { preds: _ } - | Node::Projection { + | Node::ControlProjection { control: _, selection: _, } => { @@ -186,7 +213,7 @@ impl<'a> CPUContext<'a> { let mut succs = self.control_subgraph.succs(id); let succ1 = succs.next().unwrap(); let succ2 = succs.next().unwrap(); - let succ1_is_true = self.function.nodes[succ1.idx()].try_projection(1).is_some(); + let succ1_is_true = self.function.nodes[succ1.idx()].try_control_projection(1).is_some(); write!( term, " br {}, label %{}, label %{}\n", @@ -195,9 +222,32 @@ impl<'a> CPUContext<'a> { self.get_block_name(if succ1_is_true { succ2 } else { succ1 }), )? } - Node::Return { control: _, data } => { - let term = &mut blocks.get_mut(&id).unwrap().term; - write!(term, " ret {}\n", self.get_value(data, true))? + Node::Return { control: _, ref data } => { + if data.len() == 1 { + let ret_data = data[0]; + let term = &mut blocks.get_mut(&id).unwrap().term; + write!(term, " ret {}\n", self.get_value(ret_data, true))? + } else { + let term = &mut blocks.get_mut(&id).unwrap().term; + // Generate gep and stores into the output pointer + for (idx, val) in data.iter().enumerate() { + write!( + term, + " %ret_ptr.{} = getelementptr inbounds %return.{}, ptr %ret.ptr, i32 0, i32 {}\n", + idx, + self.function.name, + idx, + )?; + write!( + term, + " store {}, ptr %ret_ptr.{}\n", + self.get_value(*val, true), + idx, + )?; + } + // Finally return void + write!(term, " ret void\n")? 
+ } } _ => panic!( "PANIC: Can't lower {:?} in {}.", @@ -808,7 +858,7 @@ impl<'a> CPUContext<'a> { */ fn codegen_type_size(&self, ty: TypeID, body: &mut String) -> Result<String, Error> { match self.types[ty.idx()] { - Type::Control => panic!(), + Type::Control | Type::MultiReturn(_) => panic!(), Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => { Ok("1".to_string()) } -- GitLab From 44f38882fd4cc3719dc05f12ea986b895bbde7da Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Thu, 20 Feb 2025 14:18:53 -0600 Subject: [PATCH 11/19] Handle multi-return in async rust functions --- hercules_cg/src/rt.rs | 238 +++++++++++++++++++++++++++++++++--------- hercules_ir/src/ir.rs | 8 ++ 2 files changed, 194 insertions(+), 52 deletions(-) diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index d3013239..26ca9d41 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -192,40 +192,70 @@ impl<'a> RTContext<'a> { } write!(w, "p{}: {}", idx, self.get_type(func.param_types[idx]))?; } - write!(w, ") -> {} {{", self.get_type(func.return_type))?; + write!(w, ") -> ")?; + self.write_rust_return_type(w, &func.return_types)?; + write!(w, " {{")?; // Dump signatures for called device functions. - write!(w, "extern \"C\" {{")?; + // For single-return functions we directly expose the device function + // while for multi-return functions we generate a wrapper which handles + // allocation of the return struct and extracting values from it. 
This + // ensures that device function signatures match what they would be in + // AsyncRust for callee_id in self.callgraph.get_callees(self.func_id) { if self.devices[callee_id.idx()] == Device::AsyncRust { continue; } let callee = &self.module.functions[callee_id.idx()]; - write!(w, "fn {}(", callee.name)?; - let mut first_param = true; - if self.backing_allocations[&callee_id].contains_key(&self.devices[callee_id.idx()]) { - first_param = false; - write!(w, "backing: ::hercules_rt::__RawPtrSendSync")?; + let is_single_return = callee.return_types.len() == 1; + if is_single_return { + write!(w, "extern \"C\" {{")?; } - for idx in 0..callee.num_dynamic_constants { - if first_param { + self.write_device_signature_async(w, *callee_id)?; + if is_single_return { + write!(w, ";}}")?; + } else { + // Generate the wrapper function for multi-return device functions + write!(w, " {{ ")?; + // Define the return struct + write!(w, "#[repr(C)] struct ReturnStruct {{ {} }} ", + callee.return_types + .iter() + .enumerate() + .map(|(idx, t)| format!("f{}: {}", idx, self.get_type(*t))) + .collect::<Vec<_>>() + .join(", "), + )?; + // Declare the extern function's signature + write!(w, "extern \"C\" {{ ")?; + self.write_device_signature(w, *callee_id)?; + write!(w, "; }}")?; + // Create the return struct + write!(w, "let mut ret_struct: ::std::mem::MaybeUninit<ReturnStruct> = ::std::mem::MaybeUninit::uninit();")?; + // Call the device function + write!(w, "{}(", callee.name)?; + if self.backing_allocations[&callee_id].contains_key(&self.devices[callee_id.idx()]) { first_param = false; - } else { - write!(w, ", ")?; + write!(w, "backing, ")?; } - write!(w, "dc{}: u64", idx)?; - } - for (idx, ty) in callee.param_types.iter().enumerate() { - if first_param { - first_param = false; - } else { - write!(w, ", ")?; + for idx in 0..callee.num_dynamic_constants { + write!(w, "dc{}, ", idx)?; } - write!(w, "p{}: {}", idx, self.get_type(*ty))?; + for idx in 0..callee.param_types.len() { + 
write!(w, "p{}, ", idx)?; + } + write!(w, "ret_struct.as_mut_ptr());")?; + // Extract the result into a Rust product + write!(w, "let ret_struct = ret_struct.assume_init();")?; + write!(w, "({})", + (0..callee.return_types.len()) + .map(|idx| format!("ret_struct.f{}", idx)) + .collect::<Vec<_>>() + .join(", "), + )?; + write!(w, "}}")?; } - write!(w, ") -> {};", self.get_type(callee.return_type))?; } - write!(w, "}}")?; // Set up the root environment for the function. An environment is set // up for every created task in async closures, and there needs to be a @@ -301,7 +331,7 @@ impl<'a> RTContext<'a> { // successor and are otherwise simple. Node::Start | Node::Region { preds: _ } - | Node::Projection { + | Node::ControlProjection { control: _, selection: _, } => { @@ -320,7 +350,7 @@ impl<'a> RTContext<'a> { let mut succs = self.control_subgraph.succs(id); let succ1 = succs.next().unwrap(); let succ2 = succs.next().unwrap(); - let succ1_is_true = func.nodes[succ1.idx()].try_projection(1).is_some(); + let succ1_is_true = func.nodes[succ1.idx()].try_control_projection(1).is_some(); write!( epilogue, "control_token = if {} {{{}}} else {{{}}};}}", @@ -329,11 +359,20 @@ impl<'a> RTContext<'a> { if succ1_is_true { succ2 } else { succ1 }.idx(), )?; } - Node::Return { control: _, data } => { + Node::Return { control: _, ref data } => { let prologue = &mut blocks.get_mut(&id).unwrap().prologue; write!(prologue, "{} => {{", id.idx())?; let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue; - write!(epilogue, "return {};}}", self.get_value(data, id, false))?; + if data.len() == 1 { + write!(epilogue, "return {};}}", self.get_value(data[0], id, false))?; + } else { + write!(epilogue, "return ({});}}", + data.iter() + .map(|v| self.get_value(*v, id, false)) + .collect::<Vec<_>>() + .join(", "), + )?; + } } // Fork nodes open a new environment for defining an async closure. 
Node::Fork { @@ -574,8 +613,8 @@ impl<'a> RTContext<'a> { ref args, } => { assert_eq!(control, bb); - // The device backends ensure that device functions have the - // same interface as AsyncRust functions. + // The device backends and the wrappers we generated earlier ensure that device + // functions have the same interface as AsyncRust functions. let block = &mut blocks.get_mut(&bb).unwrap(); let block = &mut block.data; let is_async = func.schedules[id.idx()].contains(&Schedule::AsyncCall); @@ -628,6 +667,13 @@ impl<'a> RTContext<'a> { } write!(block, "){};", postfix)?; } + Node::DataProjection { data, selection } => { + let block = &mut blocks.get_mut(&bb).unwrap().data; + write!(block, "{} = {}.{};", + self.get_value(id, bb, true), + self.get_value(data, bb, false), + selection)?; + } Node::LibraryCall { library_function, ref args, @@ -1008,7 +1054,7 @@ impl<'a> RTContext<'a> { */ fn codegen_type_size(&self, ty: TypeID) -> String { match self.module.types[ty.idx()] { - Type::Control => panic!(), + Type::Control | Type::MultiReturn(_) => panic!(), Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => { "1".to_string() } @@ -1116,15 +1162,7 @@ impl<'a> RTContext<'a> { if is_reduce_on_child { "reduce" } else { "node" }, idx, self.get_type(self.typing[idx]), - if self.module.types[self.typing[idx].idx()].is_bool() { - "false" - } else if self.module.types[self.typing[idx].idx()].is_integer() { - "0" - } else if self.module.types[self.typing[idx].idx()].is_float() { - "0.0" - } else { - "::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())" - } + self.get_default_value(self.typing[idx]), )?; } } @@ -1402,8 +1440,101 @@ impl<'a> RTContext<'a> { } } - fn get_type(&self, id: TypeID) -> &'static str { - convert_type(&self.module.types[id.idx()]) + fn get_type(&self, id: TypeID) -> String { + convert_type(&self.module.types[id.idx()], &self.module.types) + } + + fn get_default_value(&self, idx: TypeID) -> String { + let typ = 
&self.module.types[idx.idx()]; + if typ.is_bool() { + "false".to_string() + } else if typ.is_integer() { + "0".to_string() + } else if typ.is_float() { + "0.0".to_string() + } else if let Some(ts) = typ.try_multi_return() { + format!("({})", ts.iter().map(|t| self.get_default_value(*t)).collect::<Vec<_>>().join(", ")) + } else { + "::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())".to_string() + } + } + + fn write_rust_return_type<W: Write>(&self, w: &mut W, tys: &[TypeID]) -> Result<(), Error> { + if tys.len() == 1 { + write!(w, "{}", self.get_type(tys[0])) + } else { + write!(w, "({})", + tys.iter() + .map(|t| self.get_type(*t)) + .collect::<Vec<_>>() + .join(", "), + ) + } + } + + // Writes the signature of a device function as if it were an async function, in particular + // this means that if the function is multi-return it will return a product in the produced + // Rust code + // Writes from the "fn" keyword up to the end of the return type + fn write_device_signature_async<W: Write>(&self, w: &mut W, func_id: FunctionID) -> Result<(), Error> { + let func = &self.module.functions[func_id.idx()]; + write!(w, "fn {}(", func.name)?; + let mut first_param = true; + if self.backing_allocations[&func_id].contains_key(&self.devices[func_id.idx()]) { + first_param = false; + write!(w, "backing: ::hercules_rt::__RawPtrSendSync")?; + } + for idx in 0..func.num_dynamic_constants { + if first_param { + first_param = false; + } else { + write!(w, ", ")?; + } + write!(w, "dc{}: u64", idx)?; + } + for (idx, ty) in func.param_types.iter().enumerate() { + if first_param { + first_param = false; + } else { + write!(w, ", ")?; + } + write!(w, "p{}: {}", idx, self.get_type(*ty))?; + } + write!(w, ") -> ")?; + self.write_rust_return_type(w, &func.return_types) + } + + // Writes the true signature of a device function + // Compared to the _async version this converts multi-return into a return struct + fn write_device_signature<W: Write>(&self, w: &mut W, func_id: 
FunctionID) -> Result<(), Error> { + let func = &self.module.functions[func_id.idx()]; + write!(w, "fn {}(", func.name)?; + let mut first_param = true; + if self.backing_allocations[&func_id].contains_key(&self.devices[func_id.idx()]) { + first_param = false; + write!(w, "backing: ::hercules_rt::__RawPtrSendSync")?; + } + for idx in 0..func.num_dynamic_constants { + if first_param { + first_param = false; + } else { + write!(w, ", ")?; + } + write!(w, "dc{}: u64", idx)?; + } + for (idx, ty) in func.param_types.iter().enumerate() { + if first_param { + first_param = false; + } else { + write!(w, ", ")?; + } + write!(w, "p{}: {}", idx, self.get_type(*ty))?; + } + if func.return_types.len() == 1 { + write!(w, ") -> {}", self.get_type(func.return_types[0])) + } else { + write!(w, ", ret_ptr: *mut ReturnStruct)") + } } fn library_prim_ty(&self, id: TypeID) -> &'static str { @@ -1426,21 +1557,24 @@ impl<'a> RTContext<'a> { } } -fn convert_type(ty: &Type) -> &'static str { +fn convert_type(ty: &Type, types: &[Type]) -> String { match ty { - Type::Boolean => "bool", - Type::Integer8 => "i8", - Type::Integer16 => "i16", - Type::Integer32 => "i32", - Type::Integer64 => "i64", - Type::UnsignedInteger8 => "u8", - Type::UnsignedInteger16 => "u16", - Type::UnsignedInteger32 => "u32", - Type::UnsignedInteger64 => "u64", - Type::Float32 => "f32", - Type::Float64 => "f64", + Type::Boolean => "bool".to_string(), + Type::Integer8 => "i8".to_string(), + Type::Integer16 => "i16".to_string(), + Type::Integer32 => "i32".to_string(), + Type::Integer64 => "i64".to_string(), + Type::UnsignedInteger8 => "u8".to_string(), + Type::UnsignedInteger16 => "u16".to_string(), + Type::UnsignedInteger32 => "u32".to_string(), + Type::UnsignedInteger64 => "u64".to_string(), + Type::Float32 => "f32".to_string(), + Type::Float64 => "f64".to_string(), Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => { - "::hercules_rt::__RawPtrSendSync" + "::hercules_rt::__RawPtrSendSync".to_string() + } + 
Type::MultiReturn(ts) => { + format!("({})", ts.iter().map(|t| convert_type(&types[t.idx()], types)).collect::<Vec<_>>().join(", ")) } _ => panic!(), } diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 3d625a39..d69c3cd7 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -979,6 +979,14 @@ impl Type { } } + pub fn try_multi_return(&self) -> Option<&[TypeID]> { + if let Type::MultiReturn(ts) = self { + Some(ts) + } else { + None + } + } + pub fn num_bits(&self) -> u8 { match self { Type::Boolean => 1, -- GitLab From 260f6f4737f049f049d733cd33c8bdf1b761715f Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Thu, 20 Feb 2025 15:08:13 -0600 Subject: [PATCH 12/19] Fixes for cpu and rt --- hercules_cg/src/cpu.rs | 5 + hercules_cg/src/rt.rs | 176 ++++++++++++++++---------- juno_samples/multi_return/src/cpu.sch | 1 - juno_samples/multi_return/src/main.rs | 19 ++- 4 files changed, 121 insertions(+), 80 deletions(-) diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs index 45c0f467..509b98c5 100644 --- a/hercules_cg/src/cpu.rs +++ b/hercules_cg/src/cpu.rs @@ -88,6 +88,11 @@ impl<'a> CPUContext<'a> { .collect::<Vec<_>>() .join(", "), )?; + write!( + w, + "define dso_local void @{}(", + self.function.name, + )?; } let mut first_param = true; // The first parameter is a pointer to CPU backing memory, if it's diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 26ca9d41..8a41b08d 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -211,7 +211,7 @@ impl<'a> RTContext<'a> { if is_single_return { write!(w, "extern \"C\" {{")?; } - self.write_device_signature_async(w, *callee_id)?; + self.write_device_signature_async(w, *callee_id, !is_single_return)?; if is_single_return { write!(w, ";}}")?; } else { @@ -1200,9 +1200,9 @@ impl<'a> RTContext<'a> { // references. 
let func = self.get_func(); let param_devices = &self.node_colors.1; - let return_device = self.node_colors.2; + let return_devices = &self.node_colors.2; let mut param_muts = vec![false; func.param_types.len()]; - let mut return_mut = true; + let mut return_muts = vec![true; func.return_types.len()]; let objects = &self.collection_objects[&self.func_id]; for idx in 0..func.param_types.len() { if let Some(object) = objects.param_to_object(idx) @@ -1211,11 +1211,14 @@ impl<'a> RTContext<'a> { param_muts[idx] = true; } } - for object in objects.returned_objects() { - if let Some(idx) = objects.origin(*object).try_parameter() - && !param_muts[idx] - { - return_mut = false; + let num_returns = func.return_types.len(); + for idx in 0..num_returns { + for object in objects.returned_objects(idx) { + if let Some(param_idx) = objects.origin(*object).try_parameter() + && !param_muts[param_idx] + { + return_muts[idx] = false; + } } } @@ -1245,27 +1248,38 @@ impl<'a> RTContext<'a> { } write!(w, "}}}}")?; - // Every reference that may be returned has the same lifetime. Every - // other reference gets its own unique lifetime. - let returned_origins: HashSet<_> = self.collection_objects[&self.func_id] - .returned_objects() - .into_iter() - .map(|obj| self.collection_objects[&self.func_id].origin(*obj)) - .collect(); - - write!(w, "async fn run<'runner, 'returned")?; - for idx in 0..func.param_types.len() { - write!(w, ", 'p{}", idx)?; + // Each returned reference, input reference, and the runner will have + // its own lifetime. 
We use lifetime bounds to ensure that the runner + // and parameters are borrowed for the lifetimes needed by the outputs + let returned_origins: Vec<HashSet<_>> = (0..num_returns) + .map(|idx| objects.returned_objects(idx) + .iter() + .map(|obj| objects.origin(*obj)) + .collect() + ).collect(); + + write!(w, "async fn run<'runner:")?; + for (ret_idx, origins) in returned_origins.iter().enumerate() { + if origins.iter().any(|origin| !origin.is_parameter()) { + write!(w, " 'r{} +", ret_idx)?; + } } - write!( - w, - ">(&'{} mut self", - if returned_origins.iter().any(|origin| !origin.is_parameter()) { - "returned" - } else { - "runner" + for idx in 0..num_returns { + write!(w, ", 'r{}", idx)?; + } + for idx in 0..func.param_types.len() { + write!(w, ", 'p{}:", idx)?; + for (ret_idx, origins) in returned_origins.iter().enumerate() { + if origins.iter().any(|origin| origin + .try_parameter() + .map(|oidx| idx == oidx) + .unwrap_or(false)) + { + write!(w, " 'r{} +", ret_idx)?; + } } - )?; + } + write!(w, ">(&'runner mut self")?; for idx in 0..func.num_dynamic_constants { write!(w, ", dc_p{}: u64", idx)?; } @@ -1281,37 +1295,35 @@ impl<'a> RTContext<'a> { let mutability = if param_muts[idx] { "Mut" } else { "" }; write!( w, - ", p{}: ::hercules_rt::Hercules{}Ref{}<'{}>", + ", p{}: ::hercules_rt::Hercules{}Ref{}<'p{}>", idx, device, mutability, - if returned_origins.iter().any(|origin| origin - .try_parameter() - .map(|oidx| idx == oidx) - .unwrap_or(false)) - { - "returned".to_string() - } else { - format!("p{}", idx) - } + idx, )?; } } - if self.module.types[func.return_type.idx()].is_primitive() { - write!(w, ") -> {} {{", self.get_type(func.return_type))?; - } else { - let device = match return_device { - Some(Device::LLVM) | None => "CPU", - Some(Device::CUDA) => "CUDA", - _ => panic!(), - }; - let mutability = if return_mut { "Mut" } else { "" }; - write!( - w, - ") -> ::hercules_rt::Hercules{}Ref{}<'returned> {{", - device, mutability - )?; - } + write!(w, ") -> 
{}{}{} {{", + if num_returns != 1 { "(" } else { "" }, + func.return_types.iter().enumerate() + .map(|(ret_idx, typ)| + if self.module.types[typ.idx()].is_primitive() { + self.get_type(*typ) + } else { + let device = match return_devices[ret_idx] { + Some(Device::LLVM) | None => "CPU", + Some(Device::CUDA) => "CUDA", + _ => panic!(), + }; + let mutability = if return_muts[ret_idx] { "Mut" } else { "" }; + format!("::hercules_rt::Hercules{}Ref{}<'r{}>", + device, mutability, ret_idx) + } + ) + .collect::<Vec<_>>() + .join(", "), + if num_returns != 1 { ")" } else { "" }, + )?; // Start with possibly re-allocating the backing memory if it's not // large enough. @@ -1367,22 +1379,48 @@ impl<'a> RTContext<'a> { write!(w, "p{}, ", idx)?; } write!(w, ").await;")?; - if self.module.types[func.return_type.idx()].is_primitive() { - write!(w, " ret")?; + // Return the result, appropriately wrapping pointers + if num_returns == 1 { + if self.module.types[func.return_types[0].idx()].is_primitive() { + write!(w, "ret")?; + } else { + let device = match return_devices[0] { + Some(Device::LLVM) | None => "CPU", + Some(Device::CUDA) => "CUDA", + _ => panic!(), + }; + let mutability = if return_muts[0] { "Mut" } else { "" }; + write!( + w, + "::hercules_rt::Hercules{}Ref{}::__from_parts(ret.0, {} as usize)", + device, + mutability, + self.codegen_type_size(func.return_types[0]) + )?; + } } else { - let device = match return_device { - Some(Device::LLVM) | None => "CPU", - Some(Device::CUDA) => "CUDA", - _ => panic!(), - }; - let mutability = if return_mut { "Mut" } else { "" }; - write!( - w, - "::hercules_rt::Hercules{}Ref{}::__from_parts(ret.0, {} as usize)", - device, - mutability, - self.codegen_type_size(func.return_type) - )?; + write!(w, "(")?; + for (idx, typ) in func.return_types.iter().enumerate() { + if self.module.types[typ.idx()].is_primitive() { + write!(w, "ret.{},", idx)?; + } else { + let device = match return_devices[idx] { + Some(Device::LLVM) | None => "CPU", + 
Some(Device::CUDA) => "CUDA", + _ => panic!(), + }; + let mutability = if return_muts[idx] { "Mut" } else { "" }; + write!( + w, + "::hercules_rt::Hercules{}Ref{}::__from_parts(ret.{}.0, {} as usize),", + device, + mutability, + idx, + self.codegen_type_size(func.return_types[idx]), + )?; + } + } + write!(w, ")")?; } write!(w, "}}}}")?; @@ -1476,9 +1514,9 @@ impl<'a> RTContext<'a> { // this means that if the function is multi-return it will return a product in the produced // Rust code // Writes from the "fn" keyword up to the end of the return type - fn write_device_signature_async<W: Write>(&self, w: &mut W, func_id: FunctionID) -> Result<(), Error> { + fn write_device_signature_async<W: Write>(&self, w: &mut W, func_id: FunctionID, is_unsafe: bool) -> Result<(), Error> { let func = &self.module.functions[func_id.idx()]; - write!(w, "fn {}(", func.name)?; + write!(w, "{}fn {}(", if is_unsafe { "unsafe " } else { "" }, func.name)?; let mut first_param = true; if self.backing_allocations[&func_id].contains_key(&self.devices[func_id.idx()]) { first_param = false; diff --git a/juno_samples/multi_return/src/cpu.sch b/juno_samples/multi_return/src/cpu.sch index 972405f5..d89b3b8b 100644 --- a/juno_samples/multi_return/src/cpu.sch +++ b/juno_samples/multi_return/src/cpu.sch @@ -33,4 +33,3 @@ ccp(*); dce(*); gcm(*); -xdot[true](*); diff --git a/juno_samples/multi_return/src/main.rs b/juno_samples/multi_return/src/main.rs index 63479dba..6966e0df 100644 --- a/juno_samples/multi_return/src/main.rs +++ b/juno_samples/multi_return/src/main.rs @@ -1,22 +1,21 @@ #![feature(concat_idents)] -juno_build::juno!("median"); +juno_build::juno!("multi_return"); use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo}; fn main() { - let m = vec![ - 86, 72, 14, 5, 55, 25, 98, 89, 3, 66, 44, 81, 27, 3, 40, 18, 4, 57, 93, 34, 70, 50, 50, 18, - 34, - ]; - let m = HerculesImmBox::from(m.as_slice()); + const N: usize = 32; + let a: Box<[f32]> = (1..=N).map(|i| i as f32).collect(); + let a 
= HerculesImmBox::from(a.as_ref()); + let mut r = runner!(rolling_sum_prod); + let (sums, prods) = async_std::task::block_on(async { r.run(N as u64, a.to()).await }); - let mut r = runner!(median_window); - let res = async_std::task::block_on(async { r.run(m.to()).await }); - assert_eq!(res, 57); + println!("Partial Sums: {:?}", sums.as_slice::<f32>()); + println!("Partial Prods: {:?}", prods.as_slice::<f32>()); } #[test] -fn test_median_window() { +fn test_multi_return() { main() } -- GitLab From 9973248a3fd4a9dd66fbe9b2d0234a530ac18c80 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Thu, 20 Feb 2025 15:35:09 -0600 Subject: [PATCH 13/19] Fix samples, CPU + RT working --- hercules_cg/src/rt.rs | 20 ++++++++++--- hercules_ir/src/parse.rs | 2 ++ hercules_samples/call/src/call.hir | 6 ++-- hercules_samples/ccp/src/ccp.hir | 8 ++--- hercules_samples/fac/src/fac.hir | 4 +-- juno_samples/rodinia/backprop/src/backprop.jn | 30 +++++++++---------- juno_samples/rodinia/backprop/src/main.rs | 19 ++++++++---- 7 files changed, 55 insertions(+), 34 deletions(-) diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 8a41b08d..cffed48a 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -669,10 +669,22 @@ impl<'a> RTContext<'a> { } Node::DataProjection { data, selection } => { let block = &mut blocks.get_mut(&bb).unwrap().data; - write!(block, "{} = {}.{};", - self.get_value(id, bb, true), - self.get_value(data, bb, false), - selection)?; + let Node::Call { function: callee_id, .. 
} = func.nodes[data.idx()] else { + panic!() + }; + if self.module.functions[callee_id.idx()].return_types.len() == 1 { + assert!(selection == 0); + write!(block, "{} = {};", + self.get_value(id, bb, true), + self.get_value(data, bb, false), + )?; + } else { + write!(block, "{} = {}.{};", + self.get_value(id, bb, true), + self.get_value(data, bb, false), + selection, + )?; + } } Node::LibraryCall { library_function, diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index d61ff6e7..a5a05d0f 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -513,6 +513,8 @@ fn parse_return<'a>( ), parse_identifier, ).parse(ir_text)?; + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let ir_text = nom::character::complete::char(')')(ir_text)?.0; let control = context.borrow_mut().get_node_id(control); let data = data .into_iter() diff --git a/hercules_samples/call/src/call.hir b/hercules_samples/call/src/call.hir index cecee343..77f5db2d 100644 --- a/hercules_samples/call/src/call.hir +++ b/hercules_samples/call/src/call.hir @@ -2,8 +2,10 @@ fn myfunc(x: u64) -> u64 cr1 = region(start) cr2 = region(cr1) c = constant(u64, 24) - y = call<16>(add, cr1, x, x) - z = call<10>(add, cr2, x, c) + cy = call<16>(add, cr1, x, x) + y = data_projection(cy, 0) + cz = call<10>(add, cr2, x, c) + z = data_projection(cz, 0) w = add(y, z) r = return(cr2, w) diff --git a/hercules_samples/ccp/src/ccp.hir b/hercules_samples/ccp/src/ccp.hir index b8e93994..e07df1d3 100644 --- a/hercules_samples/ccp/src/ccp.hir +++ b/hercules_samples/ccp/src/ccp.hir @@ -7,14 +7,14 @@ fn tricky(x: i32) -> i32 val = phi(loop, one, later_val) b = ne(one, val) if1 = if(loop, b) - if1_false = projection(if1, 0) - if1_true = projection(if1, 1) + if1_false = control_projection(if1, 0) + if1_true = control_projection(if1, 1) middle = region(if1_false, if1_true) inter_val = sub(two, val) later_val = phi(middle, inter_val, two) idx_dec = sub(idx, one) cond = gte(idx_dec, one) if2 
= if(middle, cond) - if2_false = projection(if2, 0) - if2_true = projection(if2, 1) + if2_false = control_projection(if2, 0) + if2_true = control_projection(if2, 1) r = return(if2_false, later_val) diff --git a/hercules_samples/fac/src/fac.hir b/hercules_samples/fac/src/fac.hir index e43dd8ca..aaf55c1d 100644 --- a/hercules_samples/fac/src/fac.hir +++ b/hercules_samples/fac/src/fac.hir @@ -8,6 +8,6 @@ fn fac(x: i32) -> i32 fac_acc = mul(fac, idx_inc) in_bounds = lt(idx_inc, x) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) r = return(if_false, fac_acc) diff --git a/juno_samples/rodinia/backprop/src/backprop.jn b/juno_samples/rodinia/backprop/src/backprop.jn index 2927dbb5..c7f4345b 100644 --- a/juno_samples/rodinia/backprop/src/backprop.jn +++ b/juno_samples/rodinia/backprop/src/backprop.jn @@ -18,7 +18,7 @@ fn layer_forward<n, m: usize>(vals: f32[n + 1], weights: f32[n + 1, m + 1]) -> f return result; } -fn output_error<n: usize>(target: f32[n + 1], actual: f32[n + 1]) -> (f32, f32[n + 1]) { +fn output_error<n: usize>(target: f32[n + 1], actual: f32[n + 1]) -> f32, f32[n + 1] { let errsum = 0.0; let delta : f32[n + 1]; @@ -29,14 +29,14 @@ fn output_error<n: usize>(target: f32[n + 1], actual: f32[n + 1]) -> (f32, f32[n errsum += abs!(delta[j]); } - return (errsum, delta); + return errsum, delta; } fn hidden_error<hidden_n, output_n: usize>( out_delta: f32[output_n + 1], hidden_weights: f32[hidden_n + 1, output_n + 1], hidden_vals: f32[hidden_n + 1], -) -> (f32, f32[hidden_n + 1]) { +) -> f32, f32[hidden_n + 1] { let errsum = 0.0; let delta : f32[hidden_n + 1]; @@ -52,7 +52,7 @@ fn hidden_error<hidden_n, output_n: usize>( errsum += abs!(delta[j]); } - return (errsum, delta); + return errsum, delta; } const ETA : f32 = 0.3; @@ -63,7 +63,7 @@ fn adjust_weights<n, m: usize>( vals: f32[n + 1], weights: f32[n + 1, m + 1], prev_weights: f32[n + 1, m + 1] -) 
-> (f32[n + 1, m + 1], f32[n + 1, m + 1]) { +) -> f32[n + 1, m + 1], f32[n + 1, m + 1] { for j in 1..=m { for k in 0..=n { let new_dw = ETA * delta[j] * vals[k] + MOMENTUM * prev_weights[k, j]; @@ -72,7 +72,7 @@ fn adjust_weights<n, m: usize>( } } - return (weights, prev_weights); + return weights, prev_weights; } #[entry] @@ -83,21 +83,19 @@ fn backprop<input_n, hidden_n, output_n: usize>( target: f32[output_n + 1], input_prev_weights: f32[input_n + 1, hidden_n + 1], hidden_prev_weights: f32[hidden_n + 1, output_n + 1], -//) -> (f32, f32, -// f32[input_n + 1, hidden_n + 1], f32[input_n + 1, hidden_n + 1], -// f32[hidden_n + 1, output_n + 1], f32[hidden_n + 1, output_n + 1]) { -) -> (f32, f32, f32) { +) -> f32, f32, + f32[input_n + 1, hidden_n + 1], f32[input_n + 1, hidden_n + 1], + f32[hidden_n + 1, output_n + 1], f32[hidden_n + 1, output_n + 1] { let hidden_vals = layer_forward::<input_n, hidden_n>(input_vals, input_weights); let output_vals = layer_forward::<hidden_n, output_n>(hidden_vals, hidden_weights); - let (out_err, out_delta) = output_error::<output_n>(target, output_vals); - let (hid_err, hid_delta) = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals); + let out_err, out_delta = output_error::<output_n>(target, output_vals); + let hid_err, hid_delta = hidden_error::<hidden_n, output_n>(out_delta, hidden_weights, hidden_vals); - let (hidden_weights, hidden_prev_weights) + let hidden_weights, hidden_prev_weights = adjust_weights::<hidden_n, output_n>(out_delta, hidden_vals, hidden_weights, hidden_prev_weights); - let (input_weights, input_prev_weights) + let input_weights, input_prev_weights = adjust_weights::<input_n, hidden_n>(hid_delta, input_vals, input_weights, input_prev_weights); - return (out_err, hid_err, input_weights[0, 0] + input_prev_weights[0, 0] + hidden_weights[0, 0] + hidden_prev_weights[0, 0]); - //return (input_weights, input_prev_weights, hidden_weights, hidden_prev_weights); + return out_err, hid_err, 
input_weights, input_prev_weights, hidden_weights, hidden_prev_weights; } diff --git a/juno_samples/rodinia/backprop/src/main.rs b/juno_samples/rodinia/backprop/src/main.rs index 848b0abb..23f78fe4 100644 --- a/juno_samples/rodinia/backprop/src/main.rs +++ b/juno_samples/rodinia/backprop/src/main.rs @@ -37,7 +37,14 @@ fn run_backprop( let mut hidden_prev_weights = HerculesMutBox::from(hidden_prev_weights.to_vec()); let mut runner = runner!(backprop); - let res = HerculesMutBox::from(async_std::task::block_on(async { + let ( + out_err, + hid_err, + input_weights, + input_prev_weights, + hidden_weights, + hidden_prev_weights + ) = async_std::task::block_on(async { runner .run( input_n, @@ -51,11 +58,11 @@ fn run_backprop( hidden_prev_weights.to(), ) .await - })) - .as_slice() - .to_vec(); - let out_err = res[0]; - let hid_err = res[1]; + }); + let mut input_weights = HerculesMutBox::from(input_weights); + let mut hidden_weights = HerculesMutBox::from(hidden_weights); + let mut input_prev_weights = HerculesMutBox::from(input_prev_weights); + let mut hidden_prev_weights = HerculesMutBox::from(hidden_prev_weights); ( out_err, -- GitLab From 7c997bd9519fbba0bd28dff29f3302b29d08d2b5 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Thu, 20 Feb 2025 15:51:57 -0600 Subject: [PATCH 14/19] Fixes --- hercules_cg/src/cpu.rs | 4 +--- hercules_cg/src/gpu.rs | 5 ++--- hercules_cg/src/rt.rs | 1 - 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs index 509b98c5..1f3ab0a4 100644 --- a/hercules_cg/src/cpu.rs +++ b/hercules_cg/src/cpu.rs @@ -131,9 +131,7 @@ impl<'a> CPUContext<'a> { } // Lastly, if the function has multiple returns, is a pointer to the return struct if self.function.return_types.len() != 1 { - if first_param { - first_param = false; - } else { + if !first_param { write!(w, ", ")?; } write!( diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index 17f0f893..4390e25c 100644 
--- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -4,7 +4,6 @@ extern crate hercules_ir; use std::collections::{BTreeMap, HashMap, HashSet}; use std::fmt::{Error, Write}; use std::fs::{File, OpenOptions}; -use std::io::Write as _; use self::hercules_ir::*; @@ -1516,7 +1515,7 @@ extern \"C\" {} {}(", let tabs = match &self.function.nodes[id.idx()] { Node::Start | Node::Region { preds: _ } - | Node::Projection { + | Node::ControlProjection { control: _, selection: _, } => { @@ -1528,7 +1527,7 @@ extern \"C\" {} {}(", let mut succs = self.control_subgraph.succs(id); let succ1 = succs.next().unwrap(); let succ2 = succs.next().unwrap(); - let succ1_is_true = self.function.nodes[succ1.idx()].try_projection(1).is_some(); + let succ1_is_true = self.function.nodes[succ1.idx()].try_control_projection(1).is_some(); let succ1_block_name = self.get_block_name(succ1, false); let succ2_block_name = self.get_block_name(succ2, false); write!( diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index cffed48a..a943cb4d 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -235,7 +235,6 @@ impl<'a> RTContext<'a> { // Call the device function write!(w, "{}(", callee.name)?; if self.backing_allocations[&callee_id].contains_key(&self.devices[callee_id.idx()]) { - first_param = false; write!(w, "backing, ")?; } for idx in 0..callee.num_dynamic_constants { -- GitLab From 21047255c6c0772eff83c29be643ef0572254a7f Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Thu, 20 Feb 2025 22:33:40 -0600 Subject: [PATCH 15/19] Fix merge --- hercules_opt/src/editor.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index bd449ce3..9cf5af72 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -820,6 +820,9 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => { panic!("PANIC: Can't create largest constant of a 
collection type.") } + Type::MultiReturn(_) => { + panic!("PANIC: Can't create largest constant for multi-return types.") + } }; self.add_constant(constant_to_construct) } @@ -843,6 +846,9 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => { panic!("PANIC: Can't create smallest constant of a collection type.") } + Type::MultiReturn(_) => { + panic!("PANIC: Can't create smallest constant for multi-return types.") + } }; self.add_constant(constant_to_construct) } -- GitLab From bc473a6fde0f42f1f514237ce2dac6a0c023323a Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Fri, 21 Feb 2025 08:53:45 -0600 Subject: [PATCH 16/19] Multi-return for gpu --- hercules_cg/src/gpu.rs | 181 ++++++++++++------ juno_samples/multi_return/src/gpu.sch | 4 + juno_samples/multi_return/src/main.rs | 28 ++- juno_samples/multi_return/src/multi_return.jn | 8 +- juno_samples/products/src/gpu.sch | 1 - 5 files changed, 155 insertions(+), 67 deletions(-) diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index d1a31d47..453d33d5 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -150,13 +150,17 @@ pub fn gpu_codegen<W: Write>( } } - let return_parameter = if collection_objects.returned_objects().len() == 1 { - collection_objects - .origin(*collection_objects.returned_objects().first().unwrap()) - .try_parameter() - } else { - None - }; + // Tracks for each return value whether it is always the same parameter + // collection + let return_parameters = (0..function.return_types.len()) + .map(|idx| if collection_objects.returned_objects(idx).len() == 1 { + collection_objects + .origin(*collection_objects.returned_objects(idx).first().unwrap()) + .try_parameter() + } else { + None + }) + .collect::<Vec<_>>(); let kernel_params = &GPUKernelParams { max_num_threads: 1024, @@ -181,7 +185,7 @@ pub fn gpu_codegen<W: Write>( fork_reduce_map, reduct_reduce_map, control_data_phi_map, - return_parameter, + 
return_parameters, kernel_params, }; ctx.codegen_function(w) @@ -210,7 +214,7 @@ struct GPUContext<'a> { fork_reduce_map: HashMap<NodeID, Vec<NodeID>>, reduct_reduce_map: HashMap<NodeID, Vec<NodeID>>, control_data_phi_map: HashMap<NodeID, Vec<(NodeID, NodeID)>>, - return_parameter: Option<usize>, + return_parameters: Vec<Option<usize>>, kernel_params: &'a GPUKernelParams, } @@ -262,7 +266,9 @@ impl GPUContext<'_> { fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> { // Emit all code up to the "goto" to Start's block let mut top = String::new(); - self.codegen_kernel_begin(self.return_parameter.is_none(), &mut top)?; + self.codegen_kernel_preamble(&mut top)?; + self.codegen_return_struct(&mut top)?; + self.codegen_kernel_begin(&mut top)?; let mut dynamic_shared_offset = "0".to_string(); self.codegen_dynamic_constants(&mut top)?; self.codegen_declare_data(&mut top)?; @@ -339,10 +345,7 @@ impl GPUContext<'_> { Ok(()) } - /* - * Emit kernel headers, signature, arguments, and dynamic shared memory declaration - */ - fn codegen_kernel_begin(&self, has_ret_var: bool, w: &mut String) -> Result<(), Error> { + fn codegen_kernel_preamble<W: Write>(&self, w: &mut W) -> Result<(), Error> { write!( w, " @@ -366,8 +369,23 @@ namespace cg = cooperative_groups; #define isqrt(a) ((int)sqrtf((float)(a))) ", - )?; + ) + } + + fn codegen_return_struct<W: Write>(&self, w: &mut W) -> Result<(), Error> { + write!(w, "struct return_{} {{ {} }};\n", + self.function.name, + self.function.return_types.iter().enumerate() + .map(|(idx, typ)| format!("{} f{};", self.get_type(*typ, false), idx)) + .collect::<Vec<_>>() + .join(" "), + ) + } + /* + * Emit kernel signature, arguments, and dynamic shared memory declaration + */ + fn codegen_kernel_begin<W: Write>(&self, w: &mut W) -> Result<(), Error> { write!( w, "__global__ void __launch_bounds__({}) {}_gpu(", @@ -403,11 +421,21 @@ namespace cg = cooperative_groups; }; write!(w, "{} p{}", param_type, idx)?; } - if has_ret_var { + 
let ret_fields = self + .return_parameters + .iter() + .enumerate() + .filter_map(|(idx, param)| if param.is_some() { + None + } else { + Some((idx, self.function.return_types[idx])) + }) + .collect::<Vec<(usize, TypeID)>>(); + if !ret_fields.is_empty() { if !first_param { write!(w, ", ")?; } - write!(w, "void* __restrict__ ret",)?; + write!(w, "return_{}* __restrict__ ret", self.function.name)?; } // Type is char since it's simplest to use single bytes for indexing @@ -599,17 +627,17 @@ namespace cg = cooperative_groups; dynamic_shared_offset: &str, w: &mut String, ) -> Result<(), Error> { - // The following steps are for host-side C function arguments, but we also - // need to pass arguments to kernel, so we keep track of the arguments here. - let ret_type = self.get_type(self.function.return_type, false); let mut pass_args = String::new(); - write!( - w, - " -extern \"C\" {} {}(", - ret_type.clone(), - self.function.name - )?; + + let is_multi_return = self.function.return_types.len() != 1; + write!(w, "extern \"C\" ")?; + if is_multi_return { + write!(w, "void")?; + } else { + write!(w, "{}", self.get_type(self.function.return_types[0], false))?; + } + write!(w, " {}(", self.function.name)?; + let mut first_param = true; // The first parameter is a pointer to GPU backing memory, if it's // needed. 
@@ -641,20 +669,42 @@ extern \"C\" {} {}(", write!(w, "{} p{}", param_type, idx)?; write!(pass_args, "p{}", idx)?; } + // If the function is multi-return, the last argument is the return pointer + // This is a CPU pointer, we will allocate a separate pointer used for the kernel's return + // arguments (if any) + if is_multi_return { + if !first_param { + write!(w, ", ")?; + } + write!(w, "return_{}* ret_ptr", self.function.name)?; + } write!(w, ") {{\n")?; // For case of dynamic block count self.codegen_dynamic_constants(w)?; - let has_ret_var = self.return_parameter.is_none(); - if has_ret_var { - // Allocate return parameter and lift to kernel argument - let ret_type_pnt = self.get_type(self.function.return_type, true); - write!(w, "\t{} ret;\n", ret_type_pnt)?; + + let (kernel_returns, param_returns) = + self.return_parameters.iter().enumerate() + .fold((vec![], vec![]), + |(mut kernel_returns, mut param_returns), (idx, param)| { + if let Some(param_idx) = param { + param_returns.push((idx, param_idx)); + } else { + kernel_returns.push((idx, self.function.return_types[idx])); + } + (kernel_returns, param_returns) + }); + + if !kernel_returns.is_empty() { + // Allocate kernel return struct + write!(w, "\treturn_{}* ret_cuda;\n", self.function.name)?; + write!(w, "\tcudaMalloc((void**)&ret_cuda, sizeof(return_{}));\n", self.function.name)?; + // Add the return pointer to the kernel arguments if !first_param { write!(pass_args, ", ")?; } - write!(pass_args, "ret")?; - write!(w, "\tcudaMalloc((void**)&ret, sizeof({}));\n", ret_type)?; + write!(pass_args, "ret_cuda")?; } + write!(w, "\tcudaError_t err;\n")?; write!( w, @@ -666,18 +716,38 @@ extern \"C\" {} {}(", w, "\tif (cudaSuccess != err) {{ printf(\"Error1: %s\\n\", cudaGetErrorString(err)); }}\n" )?; - if has_ret_var { - // Copy return from device to host, whether it's primitive value or collection pointer - write!(w, "\t{} host_ret;\n", ret_type)?; - write!( - w, - "\tcudaMemcpy(&host_ret, ret, sizeof({}), 
cudaMemcpyDeviceToHost);\n", - ret_type - )?; - write!(w, "\treturn host_ret;\n")?; + + if !is_multi_return { + if kernel_returns.is_empty() { + // A single return of a parameter, we can just return it directly + write!(w, "\treturn p{};\n", param_returns[0].1)?; + } else { + // A single return of a value computed on the device, we create a stack allocation + // and retrieve the value from the device and then return it + write!(w, "\t return_{} ret_host;\n", self.function.name)?; + write!(w, + "\tcudaMemcpy(&ret_host, ret_cuda, sizeof(return_{}), cudaMemcpyDeviceToHost);\n", + self.function.name, + )?; + write!(w, "\treturn ret_host.f0;\n")?; + } } else { - write!(w, "\treturn p{};\n", self.return_parameter.unwrap())?; + // Multi return is handle via an output pointer provided to this function + // If there are kernel returns then we copy those back from the device and then fill in + // the parameter returns + if !kernel_returns.is_empty() { + // Copy from the device directly into the output struct + write!(w, + "\tcudaMemcpy(ret_ptr, ret_cuda, sizeof(return_{}), cudaMemcpyDeviceToHost);\n", + self.function.name, + )?; + } + for (field_idx, param_idx) in param_returns { + write!(w, "\tret_ptr->f{} = p{};\n", field_idx, param_idx)?; + } + write!(w, "\treturn;\n")?; } + write!(w, "}}\n")?; Ok(()) } @@ -1710,20 +1780,17 @@ extern \"C\" {} {}(", } tabs } - Node::Return { control: _, data } => { - if self.return_parameter.is_none() { - // Since we lift return into a kernel argument, we write to that - // argument upon return. 
- let return_val = self.get_value(*data, false, false); - let return_type_ptr = self.get_type(self.function.return_type, true); - write!(w_term, "\tif (grid.thread_rank() == 0) {{\n")?; - write!( - w_term, - "\t\t*(reinterpret_cast<{}>(ret)) = {};\n", - return_type_ptr, return_val - )?; - write!(w_term, "\t}}\n")?; + Node::Return { control: _, ref data } => { + write!(w_term, "\tif (grid.thread_rank() == 0) {{\n")?; + for (idx, (data, param)) in data.iter().zip(self.return_parameters.iter()).enumerate() { + // For return values that are not identical to some parameter, we write it into + // the output struct + if !param.is_some() { + write!(w_term, "\t\tret->f{} = {};\n", idx, + self.get_value(*data, false, false))?; + } } + write!(w_term, "\t}}\n")?; write!(w_term, "\treturn;\n")?; 1 } diff --git a/juno_samples/multi_return/src/gpu.sch b/juno_samples/multi_return/src/gpu.sch index e733551d..0c0b569a 100644 --- a/juno_samples/multi_return/src/gpu.sch +++ b/juno_samples/multi_return/src/gpu.sch @@ -4,6 +4,10 @@ dce(*); ip-sroa(*); sroa(*); + +ip-sroa[true](rolling_sum); +sroa[true](rolling_sum, rolling_sum_prod); + dce(*); forkify(*); diff --git a/juno_samples/multi_return/src/main.rs b/juno_samples/multi_return/src/main.rs index 6966e0df..0b3508a7 100644 --- a/juno_samples/multi_return/src/main.rs +++ b/juno_samples/multi_return/src/main.rs @@ -2,17 +2,35 @@ juno_build::juno!("multi_return"); -use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo}; +use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo, HerculesMutBox}; fn main() { const N: usize = 32; let a: Box<[f32]> = (1..=N).map(|i| i as f32).collect(); - let a = HerculesImmBox::from(a.as_ref()); + let arg = HerculesImmBox::from(a.as_ref()); let mut r = runner!(rolling_sum_prod); - let (sums, prods) = async_std::task::block_on(async { r.run(N as u64, a.to()).await }); + let (sums, sum, prods, prod) = async_std::task::block_on(async { r.run(N as u64, arg.to()).await }); - println!("Partial Sums: 
{:?}", sums.as_slice::<f32>()); - println!("Partial Prods: {:?}", prods.as_slice::<f32>()); + let mut sums = HerculesMutBox::<f32>::from(sums); + let mut prods = HerculesMutBox::<f32>::from(prods); + + let (expected_sums, expected_sum) = a.iter() + .fold((vec![0.0], 0.0), |(mut sums, sum), v| { + let new_sum = sum + v; + sums.push(new_sum); + (sums, new_sum) + }); + let (expected_prods, expected_prod) = a.iter() + .fold((vec![1.0], 1.0), |(mut prods, prod), v| { + let new_prod = prod * v; + prods.push(new_prod); + (prods, new_prod) + }); + + assert_eq!(sum, expected_sum); + assert_eq!(sums.as_slice(), expected_sums.as_slice()); + assert_eq!(prod, expected_prod); + assert_eq!(prods.as_slice(), expected_prods.as_slice()); } #[test] diff --git a/juno_samples/multi_return/src/multi_return.jn b/juno_samples/multi_return/src/multi_return.jn index 84bab015..30b5576e 100644 --- a/juno_samples/multi_return/src/multi_return.jn +++ b/juno_samples/multi_return/src/multi_return.jn @@ -25,8 +25,8 @@ fn rolling_prod<t: number, n: usize>(x: t[n]) -> t, t[n + 1] { } #[entry] -fn rolling_sum_prod<n: usize>(x: f32[n]) -> f32[n + 1], f32[n + 1] { - let rsum = rolling_sum::<_, n>(x).1; - let _, rprod = rolling_prod::<_, n>(x); - return rsum, rprod; +fn rolling_sum_prod<n: usize>(x: f32[n]) -> f32[n + 1], f32, f32[n + 1], f32 { + let (sum, rsum) = rolling_sum::<_, n>(x); + let prod, rprod = rolling_prod::<_, n>(x); + return rsum, sum, rprod, prod; } diff --git a/juno_samples/products/src/gpu.sch b/juno_samples/products/src/gpu.sch index 5ef4c479..0a734bb2 100644 --- a/juno_samples/products/src/gpu.sch +++ b/juno_samples/products/src/gpu.sch @@ -5,7 +5,6 @@ dce(*); let out = auto-outline(*); gpu(out.product_read); -ip-sroa(*); sroa(*); reuse-products(*); crc(*); -- GitLab From 0f8b2d6bc01bec6844f260ba14b7eaa43a6dfbb5 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Fri, 21 Feb 2025 09:21:06 -0600 Subject: [PATCH 17/19] Fix interpreter --- 
.../hercules_interpreter/src/interpreter.rs | 17 +++++++++++++---- .../hercules_interpreter/src/value.rs | 3 +++ hercules_test/test_inputs/call.hir | 5 +++-- hercules_test/test_inputs/call_dc_params.hir | 5 +++-- hercules_test/test_inputs/ccp_example.hir | 8 ++++---- .../fork_fission/inner_loop.hir | 6 +++--- .../test_inputs/forkify/alternate_bounds.hir | 6 +++--- .../test_inputs/forkify/broken_sum.hir | 6 +++--- .../forkify/control_after_condition.hir | 10 +++++----- .../forkify/control_before_condition.hir | 10 +++++----- .../expected_fails.hir/bad_3nest_return.hir | 12 ++++++------ .../expected_fails.hir/bad_loop_tid_sum.hir | 6 +++--- .../test_inputs/forkify/inner_fork.hir | 6 +++--- .../test_inputs/forkify/inner_fork_complex.hir | 10 +++++----- .../test_inputs/forkify/loop_array_sum.hir | 6 +++--- .../test_inputs/forkify/loop_simple_iv.hir | 6 +++--- hercules_test/test_inputs/forkify/loop_sum.hir | 6 +++--- .../test_inputs/forkify/loop_tid_sum.hir | 6 +++--- .../test_inputs/forkify/merged_phi_cycle.hir | 6 +++--- .../test_inputs/forkify/nested_loop2.hir | 10 +++++----- .../test_inputs/forkify/nested_tid_sum.hir | 10 +++++----- .../test_inputs/forkify/nested_tid_sum_2.hir | 10 +++++----- .../test_inputs/forkify/phi_loop4.hir | 6 +++--- .../test_inputs/forkify/split_phi_cycle.hir | 6 +++--- .../test_inputs/forkify/super_nested_loop.hir | 12 ++++++------ .../loop_analysis/alternate_bounds.hir | 6 +++--- .../alternate_bounds_internal_control.hir | 6 +++--- .../alternate_bounds_internal_control2.hir | 6 +++--- .../alternate_bounds_nested_do_loop.hir | 10 +++++----- .../alternate_bounds_nested_do_loop2.hir | 10 +++++----- .../alternate_bounds_nested_do_loop_array.hir | 10 +++++----- ...alternate_bounds_nested_do_loop_guarded.hir | 18 +++++++++--------- .../alternate_bounds_use_after_loop.hir | 6 +++--- .../alternate_bounds_use_after_loop2.hir | 6 +++--- .../alternate_bounds_use_after_loop_no_tid.hir | 6 +++--- ...alternate_bounds_use_after_loop_no_tid2.hir | 6 
+++--- .../test_inputs/loop_analysis/broken_sum.hir | 6 +++--- .../loop_analysis/do_loop_far_guard.hir | 6 +++--- .../loop_analysis/do_loop_immediate_guard.hir | 10 +++++----- .../loop_analysis/do_loop_no_guard.hir | 6 +++--- .../loop_analysis/do_while_separate_body.hir | 6 +++--- .../loop_analysis/do_while_separate_body2.hir | 6 +++--- .../loop_analysis/loop_array_sum.hir | 6 +++--- .../loop_analysis/loop_body_count.hir | 6 +++--- .../test_inputs/loop_analysis/loop_sum.hir | 6 +++--- .../loop_analysis/loop_trip_count.hir | 6 +++--- .../loop_analysis/loop_trip_count_tuple.hir | 6 +++--- hercules_test/test_inputs/simple2.hir | 6 +++--- hercules_test/test_inputs/strset.hir | 4 ++-- hercules_test/test_inputs/sum_int1.hir | 6 +++--- hercules_test/test_inputs/sum_int2.hir | 8 ++++---- 51 files changed, 196 insertions(+), 182 deletions(-) diff --git a/hercules_test/hercules_interpreter/src/interpreter.rs b/hercules_test/hercules_interpreter/src/interpreter.rs index 2e352644..f9d666a5 100644 --- a/hercules_test/hercules_interpreter/src/interpreter.rs +++ b/hercules_test/hercules_interpreter/src/interpreter.rs @@ -529,6 +529,13 @@ impl<'a> FunctionExecutionState<'a> { state.run() } + Node::DataProjection { data, selection } => { + let data = self.handle_data(token, *data); + let InterpreterVal::MultiReturn(vs) = data else { + panic!(); + }; + vs[*selection].clone() + } Node::Read { collect, indices } => { let collection = self.handle_data(token, *collect); if let InterpreterVal::Undef(_) = collection { @@ -745,7 +752,7 @@ impl<'a> FunctionExecutionState<'a> { .succs(ctrl_token.curr) .find(|n| { self.get_function().nodes[n.idx()] - .try_projection(cond) + .try_control_projection(cond) .is_some() }) .expect("PANIC: No outgoing valid outgoing edge."); @@ -753,7 +760,7 @@ impl<'a> FunctionExecutionState<'a> { let ctrl_token = ctrl_token.moved_to(next); vec![ctrl_token] } - Node::Projection { .. } => { + Node::ControlProjection { .. 
} => { let next: NodeID = self .get_control_subgraph() .succs(ctrl_token.curr) @@ -861,8 +868,10 @@ impl<'a> FunctionExecutionState<'a> { } } Node::Return { control: _, data } => { - let result = self.handle_data(&ctrl_token, *data); - break 'outer result; + let results = data.iter() + .map(|data| self.handle_data(&ctrl_token, *data)) + .collect(); + break 'outer InterpreterVal::MultiReturn(results); } _ => { panic!("PANIC: Unexpected node in control subgraph {:?}", node); diff --git a/hercules_test/hercules_interpreter/src/value.rs b/hercules_test/hercules_interpreter/src/value.rs index dfc290b2..0f5716e7 100644 --- a/hercules_test/hercules_interpreter/src/value.rs +++ b/hercules_test/hercules_interpreter/src/value.rs @@ -36,6 +36,8 @@ pub enum InterpreterVal { // These can be freely? casted DynamicConstant(usize), ThreadID(usize), + + MultiReturn(Box<[InterpreterVal]>), } #[derive(Clone)] @@ -848,6 +850,7 @@ impl<'a> InterpreterVal { Type::Product(_) => todo!(), Type::Summation(_) => todo!(), Type::Array(type_id, _) => todo!(), + Type::MultiReturn(_) => todo!(), } } (_, Self::Undef(v)) => InterpreterVal::Undef(v), diff --git a/hercules_test/test_inputs/call.hir b/hercules_test/test_inputs/call.hir index 44748934..0ebf7c7d 100644 --- a/hercules_test/test_inputs/call.hir +++ b/hercules_test/test_inputs/call.hir @@ -1,7 +1,8 @@ fn myfunc(x: i32) -> i32 - y = call(add, x, x) + cy = call(add, x, x) + y = data_projection(cy, 0) r = return(start, y) fn add(x: i32, y: i32) -> i32 w = add(x, y) - r = return(start, w) \ No newline at end of file + r = return(start, w) diff --git a/hercules_test/test_inputs/call_dc_params.hir b/hercules_test/test_inputs/call_dc_params.hir index 5ccf2686..b8da9791 100644 --- a/hercules_test/test_inputs/call_dc_params.hir +++ b/hercules_test/test_inputs/call_dc_params.hir @@ -1,9 +1,10 @@ fn myfunc(x: u64) -> u64 - y = call<10, 4>(add, x, x) + cy = call<10, 4>(add, x, x) + y = data_projection(cy, 0) r = return(start, y) fn add<2>(x: u64, y: 
u64) -> u64 b = dynamic_constant(#1) r = return(start, z) w = add(x, y) - z = add(b, w) \ No newline at end of file + z = add(b, w) diff --git a/hercules_test/test_inputs/ccp_example.hir b/hercules_test/test_inputs/ccp_example.hir index 25b7379e..f8004b63 100644 --- a/hercules_test/test_inputs/ccp_example.hir +++ b/hercules_test/test_inputs/ccp_example.hir @@ -6,14 +6,14 @@ fn tricky(x: i32) -> i32 val = phi(loop, one, later_val) b = ne(one, val) if1 = if(loop, b) - if1_false = projection(if1, 0) - if1_true = projection(if1, 1) + if1_false = control_projection(if1, 0) + if1_true = control_projection(if1, 1) middle = region(if1_false, if1_true) inter_val = sub(two, val) later_val = phi(middle, inter_val, two) idx_dec = sub(idx, one) cond = gte(idx_dec, one) if2 = if(middle, cond) - if2_false = projection(if2, 0) - if2_true = projection(if2, 1) + if2_false = control_projection(if2, 0) + if2_true = control_projection(if2, 1) r = return(if2_false, later_val) diff --git a/hercules_test/test_inputs/fork_transforms/fork_fission/inner_loop.hir b/hercules_test/test_inputs/fork_transforms/fork_fission/inner_loop.hir index 0cc13b2f..b7458a43 100644 --- a/hercules_test/test_inputs/fork_transforms/fork_fission/inner_loop.hir +++ b/hercules_test/test_inputs/fork_transforms/fork_fission/inner_loop.hir @@ -11,8 +11,8 @@ fn fun<2>(x: u64) -> u64 idx_inc = add(idx, one_idx) in_bounds = lt(idx, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) j = join(if_false) tid = thread_id(f, 0) add1 = add(reduce1, idx) @@ -20,4 +20,4 @@ fn fun<2>(x: u64) -> u64 add2 = add(reduce2, idx_inc) reduce2 = reduce(j, zero, add2) out1 = add(reduce1, reduce2) - z = return(j, out1) \ No newline at end of file + z = return(j, out1) diff --git a/hercules_test/test_inputs/forkify/alternate_bounds.hir b/hercules_test/test_inputs/forkify/alternate_bounds.hir index 4a9ba015..7de8cf1e 100644 
--- a/hercules_test/test_inputs/forkify/alternate_bounds.hir +++ b/hercules_test/test_inputs/forkify/alternate_bounds.hir @@ -11,6 +11,6 @@ fn sum<1>(a: array(i32, #0)) -> i32 red_add = add(red, read) in_bounds = lt(idx_inc, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) - r = return(if_false, red_add) \ No newline at end of file + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) + r = return(if_false, red_add) diff --git a/hercules_test/test_inputs/forkify/broken_sum.hir b/hercules_test/test_inputs/forkify/broken_sum.hir index d15ef561..75b12350 100644 --- a/hercules_test/test_inputs/forkify/broken_sum.hir +++ b/hercules_test/test_inputs/forkify/broken_sum.hir @@ -11,6 +11,6 @@ fn sum<1>(a: array(i32, #0)) -> i32 red_add = add(red, read) in_bounds = lt(idx, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) - r = return(if_false, red_add) \ No newline at end of file + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) + r = return(if_false, red_add) diff --git a/hercules_test/test_inputs/forkify/control_after_condition.hir b/hercules_test/test_inputs/forkify/control_after_condition.hir index db40225b..a1a97fba 100644 --- a/hercules_test/test_inputs/forkify/control_after_condition.hir +++ b/hercules_test/test_inputs/forkify/control_after_condition.hir @@ -11,8 +11,8 @@ fn alt_sum<1>(a: array(i32, #0)) -> i32 rem = rem(idx, two_idx) odd = eq(rem, one_idx) negate_if = if(loop_continue, odd) - negate_if_false = projection(negate_if, 0) - negate_if_true = projection(negate_if, 1) + negate_if_false = control_projection(negate_if, 0) + negate_if_true = control_projection(negate_if, 1) negate_bottom = region(negate_if_false, negate_if_true) read = read(a, position(idx)) read_neg = neg(read) @@ -20,6 +20,6 @@ fn alt_sum<1>(a: array(i32, #0)) -> i32 red_add = add(red, read_phi) in_bounds = lt(idx, bound) if = if(loop, in_bounds) - 
loop_exit = projection(if, 0) - loop_continue = projection(if, 1) - r = return(loop_exit, red) \ No newline at end of file + loop_exit = control_projection(if, 0) + loop_continue = control_projection(if, 1) + r = return(loop_exit, red) diff --git a/hercules_test/test_inputs/forkify/control_before_condition.hir b/hercules_test/test_inputs/forkify/control_before_condition.hir index f24b565a..e351d714 100644 --- a/hercules_test/test_inputs/forkify/control_before_condition.hir +++ b/hercules_test/test_inputs/forkify/control_before_condition.hir @@ -11,8 +11,8 @@ fn alt_sum<1>(a: array(i32, #0)) -> i32 rem = rem(idx, two_idx) odd = eq(rem, one_idx) negate_if = if(loop, odd) - negate_if_false = projection(negate_if, 0) - negate_if_true = projection(negate_if, 1) + negate_if_false = control_projection(negate_if, 0) + negate_if_true = control_projection(negate_if, 1) negate_bottom = region(negate_if_false, negate_if_true) read = read(a, position(idx)) read_neg = neg(read) @@ -20,6 +20,6 @@ fn alt_sum<1>(a: array(i32, #0)) -> i32 red_add = add(red, read_phi) in_bounds = lt(idx, bound) if = if(negate_bottom, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) - r = return(if_false, red) \ No newline at end of file + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) + r = return(if_false, red) diff --git a/hercules_test/test_inputs/forkify/expected_fails.hir/bad_3nest_return.hir b/hercules_test/test_inputs/forkify/expected_fails.hir/bad_3nest_return.hir index f5ec4370..7599e6ec 100644 --- a/hercules_test/test_inputs/forkify/expected_fails.hir/bad_3nest_return.hir +++ b/hercules_test/test_inputs/forkify/expected_fails.hir/bad_3nest_return.hir @@ -16,18 +16,18 @@ fn loop<3>(a: u32) -> i32 outer_idx_inc = add(outer_idx, one_idx) outer_in_bounds = lt(outer_idx, outer_bound) inner_if = if(inner_loop, inner_in_bounds) - inner_if_false = projection(inner_if, 0) - inner_if_true = projection(inner_if, 1) + inner_if_false = 
control_projection(inner_if, 0) + inner_if_true = control_projection(inner_if, 1) outer_if = if(outer_loop, outer_in_bounds) - outer_if_false = projection(outer_if, 0) - outer_if_true = projection(outer_if, 1) + outer_if_false = control_projection(outer_if, 0) + outer_if_true = control_projection(outer_if, 1) outer_bound = dynamic_constant(#1) outer_outer_bound = dynamic_constant(#2) outer_outer_loop = region(start, outer_if_false) outer_outer_var = phi(outer_outer_loop, zero_var, outer_var) outer_outer_if = if(outer_outer_loop, outer_outer_in_bounds) - outer_outer_if_false = projection(outer_outer_if, 0) - outer_outer_if_true = projection(outer_outer_if, 1) + outer_outer_if_false = control_projection(outer_outer_if, 0) + outer_outer_if_true = control_projection(outer_outer_if, 1) outer_outer_idx = phi(outer_outer_loop, zero_idx, outer_outer_idx_inc, outer_outer_idx) outer_outer_idx_inc = add(outer_outer_idx, one_idx) outer_outer_in_bounds = lt(outer_outer_idx, outer_outer_bound) diff --git a/hercules_test/test_inputs/forkify/expected_fails.hir/bad_loop_tid_sum.hir b/hercules_test/test_inputs/forkify/expected_fails.hir/bad_loop_tid_sum.hir index 8dda179b..8f7d5e48 100644 --- a/hercules_test/test_inputs/forkify/expected_fails.hir/bad_loop_tid_sum.hir +++ b/hercules_test/test_inputs/forkify/expected_fails.hir/bad_loop_tid_sum.hir @@ -11,6 +11,6 @@ fn loop<1>(a: u64) -> u64 idx_inc = add(idx, one_idx) in_bounds = lt(idx, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) - r = return(if_false, var_inc) \ No newline at end of file + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) + r = return(if_false, var_inc) diff --git a/hercules_test/test_inputs/forkify/inner_fork.hir b/hercules_test/test_inputs/forkify/inner_fork.hir index e2c96a68..c603dc42 100644 --- a/hercules_test/test_inputs/forkify/inner_fork.hir +++ b/hercules_test/test_inputs/forkify/inner_fork.hir @@ -6,7 +6,7 @@ fn loop<2>(a: u32) -> 
i32 inner_bound = dynamic_constant(#0) outer_bound = dynamic_constant(#1) outer_loop = region(start, inner_join) - outer_if_true = projection(outer_if, 1) + outer_if_true = control_projection(outer_if, 1) inner_fork = fork(outer_if_true, #0) inner_join = join(inner_fork) outer_var = phi(outer_loop, zero_var, inner_var) @@ -17,6 +17,6 @@ fn loop<2>(a: u32) -> i32 outer_idx_inc = add(outer_idx, one_idx) outer_in_bounds = lt(outer_idx, outer_bound) outer_if = if(outer_loop, outer_in_bounds) - outer_if_false = projection(outer_if, 0) + outer_if_false = control_projection(outer_if, 0) r = return(outer_if_false, outer_var) - \ No newline at end of file + diff --git a/hercules_test/test_inputs/forkify/inner_fork_complex.hir b/hercules_test/test_inputs/forkify/inner_fork_complex.hir index 91eb00fa..c9488f7f 100644 --- a/hercules_test/test_inputs/forkify/inner_fork_complex.hir +++ b/hercules_test/test_inputs/forkify/inner_fork_complex.hir @@ -8,14 +8,14 @@ fn loop<2>(a: u32) -> u64 inner_bound = dynamic_constant(#0) outer_bound = dynamic_constant(#1) outer_loop = region(start, inner_condition_true_projection, inner_condition_false_projection ) - outer_if_true = projection(outer_if, 1) + outer_if_true = control_projection(outer_if, 1) other_phi_weird = phi(outer_loop, zero_var, inner_var, other_phi_weird) inner_fork = fork(outer_if_true, #0) inner_join = join(inner_fork) inner_condition_eq = eq(outer_idx, two) inner_condition_if = if(inner_join, inner_condition_eq) - inner_condition_true_projection = projection(inner_condition_if, 1) - inner_condition_false_projection = projection(inner_condition_if, 0) + inner_condition_true_projection = control_projection(inner_condition_if, 1) + inner_condition_false_projection = control_projection(inner_condition_if, 0) outer_var = phi(outer_loop, zero_var, inner_var, inner_var) inner_var = reduce(inner_join, outer_var, inner_var_inc) inner_var_inc = add(inner_var, inner_var_inc_3) @@ -26,7 +26,7 @@ fn loop<2>(a: u32) -> u64 
outer_idx_inc = add(outer_idx, one_idx) outer_in_bounds = lt(outer_idx, outer_bound) outer_if = if(outer_loop, outer_in_bounds) - outer_if_false = projection(outer_if, 0) + outer_if_false = control_projection(outer_if, 0) ret_val = add(outer_var, other_phi_weird) r = return(outer_if_false, ret_val) - \ No newline at end of file + diff --git a/hercules_test/test_inputs/forkify/loop_array_sum.hir b/hercules_test/test_inputs/forkify/loop_array_sum.hir index f9972b59..884d22d4 100644 --- a/hercules_test/test_inputs/forkify/loop_array_sum.hir +++ b/hercules_test/test_inputs/forkify/loop_array_sum.hir @@ -11,6 +11,6 @@ fn sum<1>(a: array(i32, #0)) -> i32 red_add = add(red, read) in_bounds = lt(idx, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) - r = return(if_false, red) \ No newline at end of file + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) + r = return(if_false, red) diff --git a/hercules_test/test_inputs/forkify/loop_simple_iv.hir b/hercules_test/test_inputs/forkify/loop_simple_iv.hir index c25b9a2c..c671b94c 100644 --- a/hercules_test/test_inputs/forkify/loop_simple_iv.hir +++ b/hercules_test/test_inputs/forkify/loop_simple_iv.hir @@ -7,6 +7,6 @@ fn loop<1>(a: u32) -> u64 idx_inc = add(idx, one_idx) in_bounds = lt(idx, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) - r = return(if_false, idx) \ No newline at end of file + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) + r = return(if_false, idx) diff --git a/hercules_test/test_inputs/forkify/loop_sum.hir b/hercules_test/test_inputs/forkify/loop_sum.hir index fd9c4deb..a236ddf7 100644 --- a/hercules_test/test_inputs/forkify/loop_sum.hir +++ b/hercules_test/test_inputs/forkify/loop_sum.hir @@ -11,6 +11,6 @@ fn loop<1>(a: u32) -> i32 idx_inc = add(idx, one_idx) in_bounds = lt(idx, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 
1) - r = return(if_false, var) \ No newline at end of file + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) + r = return(if_false, var) diff --git a/hercules_test/test_inputs/forkify/loop_tid_sum.hir b/hercules_test/test_inputs/forkify/loop_tid_sum.hir index 2d3ca34d..6a1e2c56 100644 --- a/hercules_test/test_inputs/forkify/loop_tid_sum.hir +++ b/hercules_test/test_inputs/forkify/loop_tid_sum.hir @@ -11,6 +11,6 @@ fn loop<1>(a: u64) -> u64 idx_inc = add(idx, one_idx) in_bounds = lt(idx, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) - r = return(if_false, var) \ No newline at end of file + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) + r = return(if_false, var) diff --git a/hercules_test/test_inputs/forkify/merged_phi_cycle.hir b/hercules_test/test_inputs/forkify/merged_phi_cycle.hir index cee473a0..2b276d3e 100644 --- a/hercules_test/test_inputs/forkify/merged_phi_cycle.hir +++ b/hercules_test/test_inputs/forkify/merged_phi_cycle.hir @@ -13,6 +13,6 @@ fn sum<1>(a: i32) -> u64 second_red_add_2 = add(first_red_add, two) in_bounds = lt(idx_inc, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) - r = return(if_false, first_red_add_2) \ No newline at end of file + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) + r = return(if_false, first_red_add_2) diff --git a/hercules_test/test_inputs/forkify/nested_loop2.hir b/hercules_test/test_inputs/forkify/nested_loop2.hir index 0f29ec74..c3c7d8e5 100644 --- a/hercules_test/test_inputs/forkify/nested_loop2.hir +++ b/hercules_test/test_inputs/forkify/nested_loop2.hir @@ -17,9 +17,9 @@ fn loop<2>(a: u32) -> i32 outer_idx_inc = add(outer_idx, one_idx) outer_in_bounds = lt(outer_idx, outer_bound) inner_if = if(inner_loop, inner_in_bounds) - inner_if_false = projection(inner_if, 0) - inner_if_true = projection(inner_if, 1) + inner_if_false = 
control_projection(inner_if, 0) + inner_if_true = control_projection(inner_if, 1) outer_if = if(outer_loop, outer_in_bounds) - outer_if_false = projection(outer_if, 0) - outer_if_true = projection(outer_if, 1) - r = return(outer_if_false, outer_var) \ No newline at end of file + outer_if_false = control_projection(outer_if, 0) + outer_if_true = control_projection(outer_if, 1) + r = return(outer_if_false, outer_var) diff --git a/hercules_test/test_inputs/forkify/nested_tid_sum.hir b/hercules_test/test_inputs/forkify/nested_tid_sum.hir index 5539202d..f7e4bda4 100644 --- a/hercules_test/test_inputs/forkify/nested_tid_sum.hir +++ b/hercules_test/test_inputs/forkify/nested_tid_sum.hir @@ -17,9 +17,9 @@ fn loop<2>(a: u32) -> u64 outer_idx_inc = add(outer_idx, one_idx) outer_in_bounds = lt(outer_idx, outer_bound) inner_if = if(inner_loop, inner_in_bounds) - inner_if_false = projection(inner_if, 0) - inner_if_true = projection(inner_if, 1) + inner_if_false = control_projection(inner_if, 0) + inner_if_true = control_projection(inner_if, 1) outer_if = if(outer_loop, outer_in_bounds) - outer_if_false = projection(outer_if, 0) - outer_if_true = projection(outer_if, 1) - r = return(outer_if_false, outer_var) \ No newline at end of file + outer_if_false = control_projection(outer_if, 0) + outer_if_true = control_projection(outer_if, 1) + r = return(outer_if_false, outer_var) diff --git a/hercules_test/test_inputs/forkify/nested_tid_sum_2.hir b/hercules_test/test_inputs/forkify/nested_tid_sum_2.hir index 9221fd47..50634a2c 100644 --- a/hercules_test/test_inputs/forkify/nested_tid_sum_2.hir +++ b/hercules_test/test_inputs/forkify/nested_tid_sum_2.hir @@ -18,9 +18,9 @@ fn loop<2>(a: u32) -> u64 outer_idx_inc = add(outer_idx, one_idx) outer_in_bounds = lt(outer_idx, outer_bound) inner_if = if(inner_loop, inner_in_bounds) - inner_if_false = projection(inner_if, 0) - inner_if_true = projection(inner_if, 1) + inner_if_false = control_projection(inner_if, 0) + inner_if_true = 
control_projection(inner_if, 1) outer_if = if(outer_loop, outer_in_bounds) - outer_if_false = projection(outer_if, 0) - outer_if_true = projection(outer_if, 1) - r = return(outer_if_false, outer_var) \ No newline at end of file + outer_if_false = control_projection(outer_if, 0) + outer_if_true = control_projection(outer_if, 1) + r = return(outer_if_false, outer_var) diff --git a/hercules_test/test_inputs/forkify/phi_loop4.hir b/hercules_test/test_inputs/forkify/phi_loop4.hir index e69ecc3d..9ce594da 100644 --- a/hercules_test/test_inputs/forkify/phi_loop4.hir +++ b/hercules_test/test_inputs/forkify/phi_loop4.hir @@ -11,6 +11,6 @@ fn loop<1>(a: u32) -> i32 idx_inc = add(idx, one_idx) in_bounds = lt(idx, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) - r = return(if_false, var_inc) \ No newline at end of file + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) + r = return(if_false, var_inc) diff --git a/hercules_test/test_inputs/forkify/split_phi_cycle.hir b/hercules_test/test_inputs/forkify/split_phi_cycle.hir index 96de73c8..a233230b 100644 --- a/hercules_test/test_inputs/forkify/split_phi_cycle.hir +++ b/hercules_test/test_inputs/forkify/split_phi_cycle.hir @@ -11,6 +11,6 @@ fn sum<1>(a: i32) -> u64 first_red_add_2 = add(first_red_add, two) in_bounds = lt(idx_inc, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) - r = return(if_false, first_red_add_2) \ No newline at end of file + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) + r = return(if_false, first_red_add_2) diff --git a/hercules_test/test_inputs/forkify/super_nested_loop.hir b/hercules_test/test_inputs/forkify/super_nested_loop.hir index 6853efbf..b568b85a 100644 --- a/hercules_test/test_inputs/forkify/super_nested_loop.hir +++ b/hercules_test/test_inputs/forkify/super_nested_loop.hir @@ -16,18 +16,18 @@ fn loop<3>(a: u32) -> i32 outer_idx_inc = add(outer_idx, 
one_idx) outer_in_bounds = lt(outer_idx, outer_bound) inner_if = if(inner_loop, inner_in_bounds) - inner_if_false = projection(inner_if, 0) - inner_if_true = projection(inner_if, 1) + inner_if_false = control_projection(inner_if, 0) + inner_if_true = control_projection(inner_if, 1) outer_if = if(outer_loop, outer_in_bounds) - outer_if_false = projection(outer_if, 0) - outer_if_true = projection(outer_if, 1) + outer_if_false = control_projection(outer_if, 0) + outer_if_true = control_projection(outer_if, 1) outer_bound = dynamic_constant(#1) outer_outer_bound = dynamic_constant(#2) outer_outer_loop = region(start, outer_if_false) outer_outer_var = phi(outer_outer_loop, zero_var, outer_var) outer_outer_if = if(outer_outer_loop, outer_outer_in_bounds) - outer_outer_if_false = projection(outer_outer_if, 0) - outer_outer_if_true = projection(outer_outer_if, 1) + outer_outer_if_false = control_projection(outer_outer_if, 0) + outer_outer_if_true = control_projection(outer_outer_if, 1) outer_outer_idx = phi(outer_outer_loop, zero_idx, outer_outer_idx_inc, outer_outer_idx) outer_outer_idx_inc = add(outer_outer_idx, one_idx) outer_outer_in_bounds = lt(outer_outer_idx, outer_outer_bound) diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds.hir index 4df92a18..a6ae209b 100644 --- a/hercules_test/test_inputs/loop_analysis/alternate_bounds.hir +++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds.hir @@ -9,6 +9,6 @@ fn sum<1>(a: u32) -> u64 red_add = add(red, one_idx) in_bounds = lt(idx_inc, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) - r = return(if_false, red) \ No newline at end of file + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) + r = return(if_false, red) diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_internal_control.hir 
b/hercules_test/test_inputs/loop_analysis/alternate_bounds_internal_control.hir index 8b4431bf..9bd6b626 100644 --- a/hercules_test/test_inputs/loop_analysis/alternate_bounds_internal_control.hir +++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_internal_control.hir @@ -15,9 +15,9 @@ fn sum<1>(a: u64) -> u64 red_add2 = add(red, inner_phi) in_bounds = lt(idx_inc, bound) if = if(inner_ctrl, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) plus_ten = add(red_add, ten) red_add_2_plus_blah = add(red2, plus_ten) final_add = add(inner_phi, red_add_2_plus_blah) - r = return(if_false, final_add) \ No newline at end of file + r = return(if_false, final_add) diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_internal_control2.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_internal_control2.hir index f4adf643..2801a165 100644 --- a/hercules_test/test_inputs/loop_analysis/alternate_bounds_internal_control2.hir +++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_internal_control2.hir @@ -13,9 +13,9 @@ fn sum<1>(a: u64) -> u64 red_add = add(red, two) in_bounds = lt(idx_inc, bound) if = if(inner_ctrl, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) plus_ten = add(red_add, ten) red_add_2_plus_blah = add(inner_phi, plus_ten) final_add = add(inner_phi, red_add_2_plus_blah) - r = return(if_false, final_add) \ No newline at end of file + r = return(if_false, final_add) diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop.hir index 52f70172..edbec0c5 100644 --- a/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop.hir +++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop.hir @@ 
-20,9 +20,9 @@ fn loop<2>(a: u64) -> u64 outer_idx_inc = add(outer_idx, one_idx) outer_in_bounds = lt(outer_idx_inc, outer_bound) inner_if = if(inner_loop, inner_in_bounds) - inner_if_false = projection(inner_if, 0) - inner_if_true = projection(inner_if, 1) + inner_if_false = control_projection(inner_if, 0) + inner_if_true = control_projection(inner_if, 1) outer_if = if(inner_if_false, outer_in_bounds) - outer_if_false = projection(outer_if, 0) - outer_if_true = projection(outer_if, 1) - r = return(outer_if_false, inner_var_inc) \ No newline at end of file + outer_if_false = control_projection(outer_if, 0) + outer_if_true = control_projection(outer_if, 1) + r = return(outer_if_false, inner_var_inc) diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop2.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop2.hir index f295b391..4e81871c 100644 --- a/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop2.hir +++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop2.hir @@ -17,9 +17,9 @@ fn loop<2>(a: u32) -> i32 outer_idx_inc = add(outer_idx, one_idx) outer_in_bounds = lt(outer_idx_inc, outer_bound) inner_if = if(inner_loop, inner_in_bounds) - inner_if_false = projection(inner_if, 0) - inner_if_true = projection(inner_if, 1) + inner_if_false = control_projection(inner_if, 0) + inner_if_true = control_projection(inner_if, 1) outer_if = if(inner_if_false, outer_in_bounds) - outer_if_false = projection(outer_if, 0) - outer_if_true = projection(outer_if, 1) - r = return(outer_if_false, inner_var_inc) \ No newline at end of file + outer_if_false = control_projection(outer_if, 0) + outer_if_true = control_projection(outer_if, 1) + r = return(outer_if_false, inner_var_inc) diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop_array.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop_array.hir index e5401779..98477c91 100644 
--- a/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop_array.hir +++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop_array.hir @@ -20,9 +20,9 @@ fn loop<2>(a: array(u64, #1)) -> u64 outer_idx_inc = add(outer_idx, one_idx) outer_in_bounds = lt(outer_idx_inc, outer_bound) inner_if = if(inner_loop, inner_in_bounds) - inner_if_false = projection(inner_if, 0) - inner_if_true = projection(inner_if, 1) + inner_if_false = control_projection(inner_if, 0) + inner_if_true = control_projection(inner_if, 1) outer_if = if(inner_if_false, outer_in_bounds) - outer_if_false = projection(outer_if, 0) - outer_if_true = projection(outer_if, 1) - r = return(outer_if_false, inner_var_inc) \ No newline at end of file + outer_if_false = control_projection(outer_if, 0) + outer_if_true = control_projection(outer_if, 1) + r = return(outer_if_false, inner_var_inc) diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop_guarded.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop_guarded.hir index b979ad42..eee77b6c 100644 --- a/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop_guarded.hir +++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_nested_do_loop_guarded.hir @@ -5,8 +5,8 @@ fn loop<2>(a: u64) -> u64 one_var = constant(u64, 1) ten = constant(u64, 10) outer_guard_if = if(start, outer_guard_lt) - outer_guard_if_false = projection(outer_guard_if, 0) - outer_guard_if_true = projection(outer_guard_if, 1) + outer_guard_if_false = control_projection(outer_guard_if, 0) + outer_guard_if_true = control_projection(outer_guard_if, 1) outer_guard_lt = lt(zero_idx, outer_bound) outer_join = region(outer_guard_if_false, outer_if_false) outer_join_var = phi(outer_join, zero_idx, join_var) @@ -16,8 +16,8 @@ fn loop<2>(a: u64) -> u64 inner_loop = region(guard_if_true, inner_if_true) guard_lt = lt(zero_idx, inner_bound) guard_if = if(outer_loop, guard_lt) - guard_if_true 
= projection(guard_if, 1) - guard_if_false = projection(guard_if, 0) + guard_if_true = control_projection(guard_if, 1) + guard_if_false = control_projection(guard_if, 0) guard_join = region(guard_if_false, inner_if_false) inner_idx = phi(inner_loop, zero_idx, inner_idx_inc) inner_idx_inc = add(inner_idx, one_idx) @@ -26,15 +26,15 @@ fn loop<2>(a: u64) -> u64 outer_idx_inc = add(outer_idx, one_idx) outer_in_bounds = lt(outer_idx_inc, outer_bound) inner_if = if(inner_loop, inner_in_bounds) - inner_if_false = projection(inner_if, 0) - inner_if_true = projection(inner_if, 1) + inner_if_false = control_projection(inner_if, 0) + inner_if_true = control_projection(inner_if, 1) outer_if = if(guard_join, outer_in_bounds) - outer_if_false = projection(outer_if, 0) - outer_if_true = projection(outer_if, 1) + outer_if_false = control_projection(outer_if, 0) + outer_if_true = control_projection(outer_if, 1) outer_var = phi(outer_loop, zero_var, join_var) inner_var = phi(inner_loop, outer_var, inner_var_inc) blah = mul(outer_idx, ten) blah2 = add(blah, inner_idx) inner_var_inc = add(inner_var, blah2) join_var = phi(guard_join, outer_var, inner_var_inc) - r = return(outer_join, outer_join_var) \ No newline at end of file + r = return(outer_join, outer_join_var) diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop.hir index 2fe4ca57..4a6e8cd6 100644 --- a/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop.hir +++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop.hir @@ -13,9 +13,9 @@ fn sum<1>(a: array(i32, #0)) -> i32 red_add = add(red, read) in_bounds = lt(idx_inc, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) plus_ten = add(red_add, ten) mult = mul(read, three) final = add(plus_ten, mult) - r = 
return(if_false, final) \ No newline at end of file + r = return(if_false, final) diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop2.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop2.hir index 760ae5ad..f735c8c6 100644 --- a/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop2.hir +++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop2.hir @@ -13,9 +13,9 @@ fn sum<1>(a: array(i32, #0)) -> i32 red_add = add(red, read) in_bounds = lt(idx_inc, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) plus_ten = add(red, ten) mult = mul(read, three) final = add(plus_ten, mult) - r = return(if_false, final) \ No newline at end of file + r = return(if_false, final) diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid.hir index 4b937509..c2f5e30a 100644 --- a/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid.hir +++ b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid.hir @@ -11,7 +11,7 @@ fn sum<1>(a: u64) -> u64 red_add = add(red, two) in_bounds = lt(idx_inc, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) plus_ten = add(red_add, ten) - r = return(if_false, plus_ten) \ No newline at end of file + r = return(if_false, plus_ten) diff --git a/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid2.hir b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid2.hir index fd06eb7d..f7a4af06 100644 --- a/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid2.hir +++ 
b/hercules_test/test_inputs/loop_analysis/alternate_bounds_use_after_loop_no_tid2.hir @@ -12,8 +12,8 @@ fn sum<1>(a: u64) -> u64 blah = phi(loop, zero_idx, red_add) in_bounds = lt(idx_inc, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) plus_ten = add(red_add, ten) plus_blah = add(blah, red_add) - r = return(if_false, plus_blah) \ No newline at end of file + r = return(if_false, plus_blah) diff --git a/hercules_test/test_inputs/loop_analysis/broken_sum.hir b/hercules_test/test_inputs/loop_analysis/broken_sum.hir index d15ef561..75b12350 100644 --- a/hercules_test/test_inputs/loop_analysis/broken_sum.hir +++ b/hercules_test/test_inputs/loop_analysis/broken_sum.hir @@ -11,6 +11,6 @@ fn sum<1>(a: array(i32, #0)) -> i32 red_add = add(red, read) in_bounds = lt(idx, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) - r = return(if_false, red_add) \ No newline at end of file + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) + r = return(if_false, red_add) diff --git a/hercules_test/test_inputs/loop_analysis/do_loop_far_guard.hir b/hercules_test/test_inputs/loop_analysis/do_loop_far_guard.hir index 4df92a18..a6ae209b 100644 --- a/hercules_test/test_inputs/loop_analysis/do_loop_far_guard.hir +++ b/hercules_test/test_inputs/loop_analysis/do_loop_far_guard.hir @@ -9,6 +9,6 @@ fn sum<1>(a: u32) -> u64 red_add = add(red, one_idx) in_bounds = lt(idx_inc, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) - r = return(if_false, red) \ No newline at end of file + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) + r = return(if_false, red) diff --git a/hercules_test/test_inputs/loop_analysis/do_loop_immediate_guard.hir b/hercules_test/test_inputs/loop_analysis/do_loop_immediate_guard.hir index a4732cde..bfdb673f 100644 --- 
a/hercules_test/test_inputs/loop_analysis/do_loop_immediate_guard.hir +++ b/hercules_test/test_inputs/loop_analysis/do_loop_immediate_guard.hir @@ -4,8 +4,8 @@ fn sum<1>(a: u64) -> u64 bound = dynamic_constant(#0) guard_lt = lt(zero_idx, bound) guard = if(start, guard_lt) - guard_true = projection(guard, 1) - guard_false = projection(guard, 0) + guard_true = control_projection(guard, 1) + guard_false = control_projection(guard, 0) loop = region(guard_true, if_true) inner_side_effect = region(loop) idx = phi(loop, zero_idx, idx_inc) @@ -15,7 +15,7 @@ fn sum<1>(a: u64) -> u64 join_phi = phi(final, zero_idx, red_add) in_bounds = lt(idx_inc, bound) if = if(inner_side_effect, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) final = region(guard_false, if_false) - r = return(final, join_phi) \ No newline at end of file + r = return(final, join_phi) diff --git a/hercules_test/test_inputs/loop_analysis/do_loop_no_guard.hir b/hercules_test/test_inputs/loop_analysis/do_loop_no_guard.hir index 9e22e14b..d48fe062 100644 --- a/hercules_test/test_inputs/loop_analysis/do_loop_no_guard.hir +++ b/hercules_test/test_inputs/loop_analysis/do_loop_no_guard.hir @@ -10,6 +10,6 @@ fn sum<1>(a: u64) -> u64 red_add = add(red, one_idx) in_bounds = lt(idx_inc, bound) if = if(inner_side_effect, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) - r = return(if_false, red_add) \ No newline at end of file + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) + r = return(if_false, red_add) diff --git a/hercules_test/test_inputs/loop_analysis/do_while_separate_body.hir b/hercules_test/test_inputs/loop_analysis/do_while_separate_body.hir index 42269040..435b6268 100644 --- a/hercules_test/test_inputs/loop_analysis/do_while_separate_body.hir +++ b/hercules_test/test_inputs/loop_analysis/do_while_separate_body.hir @@ -11,6 +11,6 @@ fn sum<1>(a: i32) -> u64 
red_add = add(outer_red, idx) in_bounds = lt(idx_inc, bound) if = if(inner_region, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) - r = return(if_false, inner_red) \ No newline at end of file + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) + r = return(if_false, inner_red) diff --git a/hercules_test/test_inputs/loop_analysis/do_while_separate_body2.hir b/hercules_test/test_inputs/loop_analysis/do_while_separate_body2.hir index a751952d..d1e2d4d6 100644 --- a/hercules_test/test_inputs/loop_analysis/do_while_separate_body2.hir +++ b/hercules_test/test_inputs/loop_analysis/do_while_separate_body2.hir @@ -13,6 +13,6 @@ fn sum<1>(a: i32) -> u64 red_mul = mul(red_add, idx) in_bounds = lt(idx_inc, bound) if = if(inner_region, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) - r = return(if_false, inner_red) \ No newline at end of file + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) + r = return(if_false, inner_red) diff --git a/hercules_test/test_inputs/loop_analysis/loop_array_sum.hir b/hercules_test/test_inputs/loop_analysis/loop_array_sum.hir index f9972b59..884d22d4 100644 --- a/hercules_test/test_inputs/loop_analysis/loop_array_sum.hir +++ b/hercules_test/test_inputs/loop_analysis/loop_array_sum.hir @@ -11,6 +11,6 @@ fn sum<1>(a: array(i32, #0)) -> i32 red_add = add(red, read) in_bounds = lt(idx, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) - r = return(if_false, red) \ No newline at end of file + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) + r = return(if_false, red) diff --git a/hercules_test/test_inputs/loop_analysis/loop_body_count.hir b/hercules_test/test_inputs/loop_analysis/loop_body_count.hir index c6f3cbf6..5ec745ba 100644 --- a/hercules_test/test_inputs/loop_analysis/loop_body_count.hir +++ b/hercules_test/test_inputs/loop_analysis/loop_body_count.hir @@ -11,6 +11,6 @@ fn 
loop<1>(a: u64) -> u64 idx_inc = add(idx, one_idx) in_bounds = lt(idx, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) - r = return(if_false, var) \ No newline at end of file + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) + r = return(if_false, var) diff --git a/hercules_test/test_inputs/loop_analysis/loop_sum.hir b/hercules_test/test_inputs/loop_analysis/loop_sum.hir index fd9c4deb..a236ddf7 100644 --- a/hercules_test/test_inputs/loop_analysis/loop_sum.hir +++ b/hercules_test/test_inputs/loop_analysis/loop_sum.hir @@ -11,6 +11,6 @@ fn loop<1>(a: u32) -> i32 idx_inc = add(idx, one_idx) in_bounds = lt(idx, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) - r = return(if_false, var) \ No newline at end of file + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) + r = return(if_false, var) diff --git a/hercules_test/test_inputs/loop_analysis/loop_trip_count.hir b/hercules_test/test_inputs/loop_analysis/loop_trip_count.hir index b756f090..799cc6d9 100644 --- a/hercules_test/test_inputs/loop_analysis/loop_trip_count.hir +++ b/hercules_test/test_inputs/loop_analysis/loop_trip_count.hir @@ -12,8 +12,8 @@ fn loop<1>(b: prod(u64, u64)) -> prod(u64, u64) idx_inc = add(idx, one_idx) in_bounds = lt(idx, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) tuple1 = write(c, var, field(0)) tuple2 = write(tuple1, idx, field(1)) - r = return(if_false, tuple2) \ No newline at end of file + r = return(if_false, tuple2) diff --git a/hercules_test/test_inputs/loop_analysis/loop_trip_count_tuple.hir b/hercules_test/test_inputs/loop_analysis/loop_trip_count_tuple.hir index b756f090..799cc6d9 100644 --- a/hercules_test/test_inputs/loop_analysis/loop_trip_count_tuple.hir +++ 
b/hercules_test/test_inputs/loop_analysis/loop_trip_count_tuple.hir @@ -12,8 +12,8 @@ fn loop<1>(b: prod(u64, u64)) -> prod(u64, u64) idx_inc = add(idx, one_idx) in_bounds = lt(idx, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) tuple1 = write(c, var, field(0)) tuple2 = write(tuple1, idx, field(1)) - r = return(if_false, tuple2) \ No newline at end of file + r = return(if_false, tuple2) diff --git a/hercules_test/test_inputs/simple2.hir b/hercules_test/test_inputs/simple2.hir index af5ac284..d4f1bebe 100644 --- a/hercules_test/test_inputs/simple2.hir +++ b/hercules_test/test_inputs/simple2.hir @@ -8,6 +8,6 @@ fn simple2(x: i32) -> i32 fac_acc = mul(fac, idx_inc) in_bounds = lt(idx_inc, x) if = if(loop, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) - r = return(if_false, fac_acc) \ No newline at end of file + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) + r = return(if_false, fac_acc) diff --git a/hercules_test/test_inputs/strset.hir b/hercules_test/test_inputs/strset.hir index 4c8b32ee..e8615f21 100644 --- a/hercules_test/test_inputs/strset.hir +++ b/hercules_test/test_inputs/strset.hir @@ -12,6 +12,6 @@ fn strset<1>(str: array(u8, #0), byte: u8) -> array(u8, #0) continue = ne(read, byte) if_cond = and(continue, in_bounds) if = if(loop, if_cond) - if_false = projection(if, 0) - if_true = projection(if, 1) + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) r = return(if_false, str_inc) diff --git a/hercules_test/test_inputs/sum_int1.hir b/hercules_test/test_inputs/sum_int1.hir index 4a9ba015..7de8cf1e 100644 --- a/hercules_test/test_inputs/sum_int1.hir +++ b/hercules_test/test_inputs/sum_int1.hir @@ -11,6 +11,6 @@ fn sum<1>(a: array(i32, #0)) -> i32 red_add = add(red, read) in_bounds = lt(idx_inc, bound) if = if(loop, in_bounds) - if_false = projection(if, 0) - 
if_true = projection(if, 1) - r = return(if_false, red_add) \ No newline at end of file + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) + r = return(if_false, red_add) diff --git a/hercules_test/test_inputs/sum_int2.hir b/hercules_test/test_inputs/sum_int2.hir index b5e9a5c0..bc614d4e 100644 --- a/hercules_test/test_inputs/sum_int2.hir +++ b/hercules_test/test_inputs/sum_int2.hir @@ -11,8 +11,8 @@ fn alt_sum<1>(a: array(i32, #0)) -> i32 rem = rem(idx, two_idx) odd = eq(rem, one_idx) negate_if = if(loop, odd) - negate_if_false = projection(negate_if, 0) - negate_if_true = projection(negate_if, 1) + negate_if_false = control_projection(negate_if, 0) + negate_if_true = control_projection(negate_if, 1) negate_bottom = region(negate_if_false, negate_if_true) read = read(a, position(idx)) read_neg = neg(read) @@ -20,6 +20,6 @@ fn alt_sum<1>(a: array(i32, #0)) -> i32 red_add = add(red, read_phi) in_bounds = lt(idx_inc, bound) if = if(negate_bottom, in_bounds) - if_false = projection(if, 0) - if_true = projection(if, 1) + if_false = control_projection(if, 0) + if_true = control_projection(if, 1) r = return(if_false, red_add) -- GitLab From 8540083f49c966bf9a5e9fd97f619a76ba9f96b5 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Fri, 21 Feb 2025 09:43:27 -0600 Subject: [PATCH 18/19] Unforkify mutli_return, gpu bug --- juno_samples/multi_return/src/gpu.sch | 1 + 1 file changed, 1 insertion(+) diff --git a/juno_samples/multi_return/src/gpu.sch b/juno_samples/multi_return/src/gpu.sch index 0c0b569a..f690086b 100644 --- a/juno_samples/multi_return/src/gpu.sch +++ b/juno_samples/multi_return/src/gpu.sch @@ -26,5 +26,6 @@ gvn(*); dce(*); float-collections(*); +unforkify(*); gcm(*); -- GitLab From 78c4ad8319cce45036c4ecf205adbd9f895f4e86 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Fri, 21 Feb 2025 09:45:01 -0600 Subject: [PATCH 19/19] Formatting --- hercules_cg/src/cpu.rs | 20 +-- 
hercules_cg/src/gpu.rs | 93 +++++++---- hercules_cg/src/rt.rs | 150 ++++++++++++------ hercules_ir/src/collections.rs | 40 +++-- hercules_ir/src/parse.rs | 6 +- hercules_opt/src/gcm.rs | 28 ++-- hercules_opt/src/interprocedural_sroa.rs | 37 +++-- hercules_opt/src/sroa.rs | 2 +- .../hercules_interpreter/src/interpreter.rs | 3 +- juno_samples/multi_return/src/main.rs | 18 +-- juno_samples/rodinia/backprop/src/main.rs | 38 ++--- juno_scheduler/src/pm.rs | 30 ++-- 12 files changed, 284 insertions(+), 181 deletions(-) diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs index 1f3ab0a4..6ad38fc0 100644 --- a/hercules_cg/src/cpu.rs +++ b/hercules_cg/src/cpu.rs @@ -82,17 +82,14 @@ impl<'a> CPUContext<'a> { w, "%return.{} = type {{ {} }}\n", self.function.name, - self.function.return_types + self.function + .return_types .iter() .map(|t| self.get_type(*t)) .collect::<Vec<_>>() .join(", "), )?; - write!( - w, - "define dso_local void @{}(", - self.function.name, - )?; + write!(w, "define dso_local void @{}(", self.function.name,)?; } let mut first_param = true; // The first parameter is a pointer to CPU backing memory, if it's @@ -216,7 +213,9 @@ impl<'a> CPUContext<'a> { let mut succs = self.control_subgraph.succs(id); let succ1 = succs.next().unwrap(); let succ2 = succs.next().unwrap(); - let succ1_is_true = self.function.nodes[succ1.idx()].try_control_projection(1).is_some(); + let succ1_is_true = self.function.nodes[succ1.idx()] + .try_control_projection(1) + .is_some(); write!( term, " br {}, label %{}, label %{}\n", @@ -225,7 +224,10 @@ impl<'a> CPUContext<'a> { self.get_block_name(if succ1_is_true { succ2 } else { succ1 }), )? 
} - Node::Return { control: _, ref data } => { + Node::Return { + control: _, + ref data, + } => { if data.len() == 1 { let ret_data = data[0]; let term = &mut blocks.get_mut(&id).unwrap().term; @@ -1027,7 +1029,7 @@ fn convert_intrinsic(intrinsic: &Intrinsic, ty: &Type) -> String { } else { panic!() } - }, + } Intrinsic::ACos => "acos", Intrinsic::ASin => "asin", Intrinsic::ATan => "atan", diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index 453d33d5..76aba7e0 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -153,12 +153,14 @@ pub fn gpu_codegen<W: Write>( // Tracks for each return value whether it is always the same parameter // collection let return_parameters = (0..function.return_types.len()) - .map(|idx| if collection_objects.returned_objects(idx).len() == 1 { - collection_objects - .origin(*collection_objects.returned_objects(idx).first().unwrap()) - .try_parameter() - } else { - None + .map(|idx| { + if collection_objects.returned_objects(idx).len() == 1 { + collection_objects + .origin(*collection_objects.returned_objects(idx).first().unwrap()) + .try_parameter() + } else { + None + } }) .collect::<Vec<_>>(); @@ -373,12 +375,17 @@ namespace cg = cooperative_groups; } fn codegen_return_struct<W: Write>(&self, w: &mut W) -> Result<(), Error> { - write!(w, "struct return_{} {{ {} }};\n", - self.function.name, - self.function.return_types.iter().enumerate() - .map(|(idx, typ)| format!("{} f{};", self.get_type(*typ, false), idx)) - .collect::<Vec<_>>() - .join(" "), + write!( + w, + "struct return_{} {{ {} }};\n", + self.function.name, + self.function + .return_types + .iter() + .enumerate() + .map(|(idx, typ)| format!("{} f{};", self.get_type(*typ, false), idx)) + .collect::<Vec<_>>() + .join(" "), ) } @@ -425,10 +432,12 @@ namespace cg = cooperative_groups; .return_parameters .iter() .enumerate() - .filter_map(|(idx, param)| if param.is_some() { - None - } else { - Some((idx, self.function.return_types[idx])) + 
.filter_map(|(idx, param)| { + if param.is_some() { + None + } else { + Some((idx, self.function.return_types[idx])) + } }) .collect::<Vec<(usize, TypeID)>>(); if !ret_fields.is_empty() { @@ -682,22 +691,26 @@ namespace cg = cooperative_groups; // For case of dynamic block count self.codegen_dynamic_constants(w)?; - let (kernel_returns, param_returns) = - self.return_parameters.iter().enumerate() - .fold((vec![], vec![]), - |(mut kernel_returns, mut param_returns), (idx, param)| { - if let Some(param_idx) = param { - param_returns.push((idx, param_idx)); - } else { - kernel_returns.push((idx, self.function.return_types[idx])); - } - (kernel_returns, param_returns) - }); + let (kernel_returns, param_returns) = self.return_parameters.iter().enumerate().fold( + (vec![], vec![]), + |(mut kernel_returns, mut param_returns), (idx, param)| { + if let Some(param_idx) = param { + param_returns.push((idx, param_idx)); + } else { + kernel_returns.push((idx, self.function.return_types[idx])); + } + (kernel_returns, param_returns) + }, + ); if !kernel_returns.is_empty() { // Allocate kernel return struct write!(w, "\treturn_{}* ret_cuda;\n", self.function.name)?; - write!(w, "\tcudaMalloc((void**)&ret_cuda, sizeof(return_{}));\n", self.function.name)?; + write!( + w, + "\tcudaMalloc((void**)&ret_cuda, sizeof(return_{}));\n", + self.function.name + )?; // Add the return pointer to the kernel arguments if !first_param { write!(pass_args, ", ")?; @@ -737,7 +750,8 @@ namespace cg = cooperative_groups; // the parameter returns if !kernel_returns.is_empty() { // Copy from the device directly into the output struct - write!(w, + write!( + w, "\tcudaMemcpy(ret_ptr, ret_cuda, sizeof(return_{}), cudaMemcpyDeviceToHost);\n", self.function.name, )?; @@ -1627,7 +1641,9 @@ namespace cg = cooperative_groups; let mut succs = self.control_subgraph.succs(id); let succ1 = succs.next().unwrap(); let succ2 = succs.next().unwrap(); - let succ1_is_true = 
self.function.nodes[succ1.idx()].try_control_projection(1).is_some(); + let succ1_is_true = self.function.nodes[succ1.idx()] + .try_control_projection(1) + .is_some(); let succ1_block_name = self.get_block_name(succ1, false); let succ2_block_name = self.get_block_name(succ2, false); write!( @@ -1780,14 +1796,23 @@ namespace cg = cooperative_groups; } tabs } - Node::Return { control: _, ref data } => { + Node::Return { + control: _, + ref data, + } => { write!(w_term, "\tif (grid.thread_rank() == 0) {{\n")?; - for (idx, (data, param)) in data.iter().zip(self.return_parameters.iter()).enumerate() { + for (idx, (data, param)) in + data.iter().zip(self.return_parameters.iter()).enumerate() + { // For return values that are not identical to some parameter, we write it into // the output struct if !param.is_some() { - write!(w_term, "\t\tret->f{} = {};\n", idx, - self.get_value(*data, false, false))?; + write!( + w_term, + "\t\tret->f{} = {};\n", + idx, + self.get_value(*data, false, false) + )?; } } write!(w_term, "\t}}\n")?; diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 19cf29f0..884129c7 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -218,13 +218,16 @@ impl<'a> RTContext<'a> { // Generate the wrapper function for multi-return device functions write!(w, " {{ ")?; // Define the return struct - write!(w, "#[repr(C)] struct ReturnStruct {{ {} }} ", - callee.return_types - .iter() - .enumerate() - .map(|(idx, t)| format!("f{}: {}", idx, self.get_type(*t))) - .collect::<Vec<_>>() - .join(", "), + write!( + w, + "#[repr(C)] struct ReturnStruct {{ {} }} ", + callee + .return_types + .iter() + .enumerate() + .map(|(idx, t)| format!("f{}: {}", idx, self.get_type(*t))) + .collect::<Vec<_>>() + .join(", "), )?; // Declare the extern function's signature write!(w, "extern \"C\" {{ ")?; @@ -234,7 +237,8 @@ impl<'a> RTContext<'a> { write!(w, "let mut ret_struct: ::std::mem::MaybeUninit<ReturnStruct> = ::std::mem::MaybeUninit::uninit();")?; // 
Call the device function write!(w, "{}(", callee.name)?; - if self.backing_allocations[&callee_id].contains_key(&self.devices[callee_id.idx()]) { + if self.backing_allocations[&callee_id].contains_key(&self.devices[callee_id.idx()]) + { write!(w, "backing, ")?; } for idx in 0..callee.num_dynamic_constants { @@ -246,11 +250,13 @@ impl<'a> RTContext<'a> { write!(w, "ret_struct.as_mut_ptr());")?; // Extract the result into a Rust product write!(w, "let ret_struct = ret_struct.assume_init();")?; - write!(w, "({})", + write!( + w, + "({})", (0..callee.return_types.len()) - .map(|idx| format!("ret_struct.f{}", idx)) - .collect::<Vec<_>>() - .join(", "), + .map(|idx| format!("ret_struct.f{}", idx)) + .collect::<Vec<_>>() + .join(", "), )?; write!(w, "}}")?; } @@ -358,14 +364,19 @@ impl<'a> RTContext<'a> { if succ1_is_true { succ2 } else { succ1 }.idx(), )?; } - Node::Return { control: _, ref data } => { + Node::Return { + control: _, + ref data, + } => { let prologue = &mut blocks.get_mut(&id).unwrap().prologue; write!(prologue, "{} => {{", id.idx())?; let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue; if data.len() == 1 { write!(epilogue, "return {};}}", self.get_value(data[0], id, false))?; } else { - write!(epilogue, "return ({});}}", + write!( + epilogue, + "return ({});}}", data.iter() .map(|v| self.get_value(*v, id, false)) .collect::<Vec<_>>() @@ -684,20 +695,28 @@ impl<'a> RTContext<'a> { } Node::DataProjection { data, selection } => { let block = &mut blocks.get_mut(&bb).unwrap().data; - let Node::Call { function: callee_id, .. } = func.nodes[data.idx()] else { + let Node::Call { + function: callee_id, + .. 
+ } = func.nodes[data.idx()] + else { panic!() }; if self.module.functions[callee_id.idx()].return_types.len() == 1 { assert!(selection == 0); - write!(block, "{} = {};", - self.get_value(id, bb, true), - self.get_value(data, bb, false), + write!( + block, + "{} = {};", + self.get_value(id, bb, true), + self.get_value(data, bb, false), )?; } else { - write!(block, "{} = {}.{};", - self.get_value(id, bb, true), - self.get_value(data, bb, false), - selection, + write!( + block, + "{} = {}.{};", + self.get_value(id, bb, true), + self.get_value(data, bb, false), + selection, )?; } } @@ -1296,11 +1315,14 @@ impl<'a> RTContext<'a> { // its own lifetime. We use lifetime bounds to ensure that the runner // and parameters are borrowed for the lifetimes needed by the outputs let returned_origins: Vec<HashSet<_>> = (0..num_returns) - .map(|idx| objects.returned_objects(idx) - .iter() - .map(|obj| objects.origin(*obj)) - .collect() - ).collect(); + .map(|idx| { + objects + .returned_objects(idx) + .iter() + .map(|obj| objects.origin(*obj)) + .collect() + }) + .collect(); write!(w, "async fn run<'runner:")?; for (ret_idx, origins) in returned_origins.iter().enumerate() { @@ -1314,11 +1336,12 @@ impl<'a> RTContext<'a> { for idx in 0..func.param_types.len() { write!(w, ", 'p{}:", idx)?; for (ret_idx, origins) in returned_origins.iter().enumerate() { - if origins.iter().any(|origin| origin - .try_parameter() - .map(|oidx| idx == oidx) - .unwrap_or(false)) - { + if origins.iter().any(|origin| { + origin + .try_parameter() + .map(|oidx| idx == oidx) + .unwrap_or(false) + }) { write!(w, " 'r{} +", ret_idx)?; } } @@ -1340,18 +1363,19 @@ impl<'a> RTContext<'a> { write!( w, ", p{}: ::hercules_rt::Hercules{}Ref{}<'p{}>", - idx, - device, - mutability, - idx, + idx, device, mutability, idx, )?; } } - write!(w, ") -> {}{}{} {{", + write!( + w, + ") -> {}{}{} {{", if num_returns != 1 { "(" } else { "" }, - func.return_types.iter().enumerate() - .map(|(ret_idx, typ)| - if 
self.module.types[typ.idx()].is_primitive() { + func.return_types + .iter() + .enumerate() + .map( + |(ret_idx, typ)| if self.module.types[typ.idx()].is_primitive() { self.get_type(*typ) } else { let device = match return_devices[ret_idx] { @@ -1360,8 +1384,10 @@ impl<'a> RTContext<'a> { _ => panic!(), }; let mutability = if return_muts[ret_idx] { "Mut" } else { "" }; - format!("::hercules_rt::Hercules{}Ref{}<'r{}>", - device, mutability, ret_idx) + format!( + "::hercules_rt::Hercules{}Ref{}<'r{}>", + device, mutability, ret_idx + ) } ) .collect::<Vec<_>>() @@ -1535,7 +1561,13 @@ impl<'a> RTContext<'a> { } else if typ.is_float() { "0.0".to_string() } else if let Some(ts) = typ.try_multi_return() { - format!("({})", ts.iter().map(|t| self.get_default_value(*t)).collect::<Vec<_>>().join(", ")) + format!( + "({})", + ts.iter() + .map(|t| self.get_default_value(*t)) + .collect::<Vec<_>>() + .join(", ") + ) } else { "::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())".to_string() } @@ -1545,7 +1577,9 @@ impl<'a> RTContext<'a> { if tys.len() == 1 { write!(w, "{}", self.get_type(tys[0])) } else { - write!(w, "({})", + write!( + w, + "({})", tys.iter() .map(|t| self.get_type(*t)) .collect::<Vec<_>>() @@ -1558,9 +1592,19 @@ impl<'a> RTContext<'a> { // this means that if the function is multi-return it will return a product in the produced // Rust code // Writes from the "fn" keyword up to the end of the return type - fn write_device_signature_async<W: Write>(&self, w: &mut W, func_id: FunctionID, is_unsafe: bool) -> Result<(), Error> { + fn write_device_signature_async<W: Write>( + &self, + w: &mut W, + func_id: FunctionID, + is_unsafe: bool, + ) -> Result<(), Error> { let func = &self.module.functions[func_id.idx()]; - write!(w, "{}fn {}(", if is_unsafe { "unsafe " } else { "" }, func.name)?; + write!( + w, + "{}fn {}(", + if is_unsafe { "unsafe " } else { "" }, + func.name + )?; let mut first_param = true; if 
self.backing_allocations[&func_id].contains_key(&self.devices[func_id.idx()]) { first_param = false; @@ -1588,7 +1632,11 @@ impl<'a> RTContext<'a> { // Writes the true signature of a device function // Compared to the _async version this converts multi-return into a return struct - fn write_device_signature<W: Write>(&self, w: &mut W, func_id: FunctionID) -> Result<(), Error> { + fn write_device_signature<W: Write>( + &self, + w: &mut W, + func_id: FunctionID, + ) -> Result<(), Error> { let func = &self.module.functions[func_id.idx()]; write!(w, "fn {}(", func.name)?; let mut first_param = true; @@ -1656,7 +1704,13 @@ fn convert_type(ty: &Type, types: &[Type]) -> String { "::hercules_rt::__RawPtrSendSync".to_string() } Type::MultiReturn(ts) => { - format!("({})", ts.iter().map(|t| convert_type(&types[t.idx()], types)).collect::<Vec<_>>().join(", ")) + format!( + "({})", + ts.iter() + .map(|t| convert_type(&types[t.idx()], types)) + .collect::<Vec<_>>() + .join(", ") + ) } _ => panic!(), } diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index cc0703ab..60f4fb1c 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -97,7 +97,9 @@ impl FunctionCollectionObjects { } pub fn all_returned_objects(&self) -> impl Iterator<Item = CollectionObjectID> + '_ { - self.returned.iter().flat_map(|colls| colls.iter().map(|c| *c)) + self.returned + .iter() + .flat_map(|colls| colls.iter().map(|c| *c)) } pub fn is_mutated(&self, object: CollectionObjectID) -> bool { @@ -187,19 +189,21 @@ pub fn collection_objects( Some(CollectionObjectOrigin::Constant(NodeID::new(idx))) } Node::DataProjection { data, selection } => { - let Node::Call { + let Node::Call { control: _, function: callee, dynamic_constants: _, args: _, - } = func.nodes[data.idx()] else { + } = func.nodes[data.idx()] + else { panic!("Data-projection's data is not a call node"); }; let fco = &collection_objects[&callee]; if fco.returned[*selection] - .iter() - 
.any(|returned| fco.origins[returned.idx()].try_parameter().is_some()) { + .iter() + .any(|returned| fco.origins[returned.idx()].try_parameter().is_some()) + { // If the callee may return a new collection object, then // this data projection node originates a single collection object. The // node may output multiple collection objects, say if the @@ -283,14 +287,16 @@ pub fn collection_objects( objs: obj.into_iter().collect(), } } - Node::DataProjection { data, selection } - if !types[typing[id.idx()].idx()].is_primitive() => { + Node::DataProjection { data, selection } + if !types[typing[id.idx()].idx()].is_primitive() => + { let Node::Call { control: _, function: callee, dynamic_constants: _, ref args, - } = func.nodes[data.idx()] else { + } = func.nodes[data.idx()] + else { panic!(); }; @@ -299,8 +305,7 @@ pub fn collection_objects( .position(|origin| *origin == CollectionObjectOrigin::DataProjection(id)) .map(CollectionObjectID::new); let fco = &collection_objects[&callee]; - let param_objs = fco - .returned[selection] + let param_objs = fco.returned[selection] .iter() .filter_map(|returned| fco.origins[returned.idx()].try_parameter()) .map(|param_index| &global_input[args[param_index].idx()]); @@ -346,7 +351,8 @@ pub fn collection_objects( .collect(); // Look at the collection objects that each return value may take as input. - let mut returned: Vec<BTreeSet<CollectionObjectID>> = vec![BTreeSet::new(); func.return_types.len()]; + let mut returned: Vec<BTreeSet<CollectionObjectID>> = + vec![BTreeSet::new(); func.return_types.len()]; for node in func.nodes.iter() { if let Node::Return { control: _, data } = node { for (idx, node) in data.iter().enumerate() { @@ -354,7 +360,10 @@ pub fn collection_objects( } } } - let returned = returned.into_iter().map(|set| set.into_iter().collect()).collect(); + let returned = returned + .into_iter() + .map(|set| set.into_iter().collect()) + .collect(); // Determine which objects are potentially mutated. 
let mut mutated = vec![vec![]; origins.len()]; @@ -523,10 +532,11 @@ pub fn no_reset_constant_collections( collect: _, data, indices: _, - } => { - Either::Left(zip(once(&full_indices), once(data))) + } => Either::Left(zip(once(&full_indices), once(data))), + Node::Return { + control: _, + ref data, } - Node::Return { control: _, ref data } | Node::Call { control: _, function: _, diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index a5a05d0f..9462df4d 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -254,7 +254,8 @@ fn parse_function<'a>( nom::character::complete::multispace0, ), |text| parse_type_id(text, context), - ).parse(ir_text)?; + ) + .parse(ir_text)?; let (ir_text, nodes) = nom::multi::many1(|x| parse_node(x, context)).parse(ir_text)?; // `nodes`, as returned by parsing, is in parse order, which may differ from @@ -512,7 +513,8 @@ fn parse_return<'a>( nom::character::complete::multispace0, ), parse_identifier, - ).parse(ir_text)?; + ) + .parse(ir_text)?; let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::character::complete::char(')')(ir_text)?.0; let control = context.borrow_mut().get_node_id(control); diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index c2ec4e94..d3119705 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -255,10 +255,7 @@ fn basic_blocks( dynamic_constants: _, args: _, } => bbs[idx] = Some(control), - Node::DataProjection { - data, - selection: _, - } => { + Node::DataProjection { data, selection: _ } => { let Node::Call { control, .. 
} = function.nodes[data.idx()] else { panic!(); }; @@ -514,10 +511,11 @@ fn basic_blocks( || function.nodes[id.idx()].is_undef()) && !types[typing[id.idx()].idx()].is_primitive(); let is_gpu_returned = devices[func_id.idx()] == Device::CUDA - && objects[&func_id] - .objects(id) - .into_iter() - .any(|obj| objects[&func_id].all_returned_objects().any(|ret| ret == *obj)); + && objects[&func_id].objects(id).into_iter().any(|obj| { + objects[&func_id] + .all_returned_objects() + .any(|ret| ret == *obj) + }); let old_nest = loops .header_of(location) .map(|header| loops.nesting(header).unwrap()); @@ -1330,16 +1328,14 @@ fn color_nodes( } } } - Node::DataProjection { - data, - selection, - } => { + Node::DataProjection { data, selection } => { let Node::Call { control: _, function: callee, dynamic_constants: _, ref args, - } = &nodes[data.idx()] else { + } = &nodes[data.idx()] + else { panic!() }; @@ -1388,7 +1384,11 @@ fn color_nodes( { assert!(func_colors.1[index].is_none(), "PANIC: Found multiple parameter nodes for the same index in GCM. Please just run GVN first."); func_colors.1[index] = Some(*device); - } else if let Node::Return { control: _, ref data } = nodes[id.idx()] { + } else if let Node::Return { + control: _, + ref data, + } = nodes[id.idx()] + { for (idx, val) in data.iter().enumerate() { if let Some(device) = func_colors.0.get(val) { assert!(func_colors.2[idx].is_none(), "PANIC: Found multiple return nodes in GCM. 
Contact Russel if you see this, it's an easy fix."); diff --git a/hercules_opt/src/interprocedural_sroa.rs b/hercules_opt/src/interprocedural_sroa.rs index 32fa9cc8..ad4ce19e 100644 --- a/hercules_opt/src/interprocedural_sroa.rs +++ b/hercules_opt/src/interprocedural_sroa.rs @@ -29,7 +29,11 @@ pub fn interprocedural_sroa( let callsites = get_callsites(editors); - for ((func_id, apply), callsites) in (0..func_selection.len()).map(FunctionID::new).zip(func_selection.iter()).zip(callsites.into_iter()) { + for ((func_id, apply), callsites) in (0..func_selection.len()) + .map(FunctionID::new) + .zip(func_selection.iter()) + .zip(callsites.into_iter()) + { if !apply { continue; } @@ -62,13 +66,17 @@ pub fn interprocedural_sroa( } // Now, modify each return in the current function and the return type - let return_nodes = editor.func().nodes + let return_nodes = editor + .func() + .nodes .iter() .enumerate() - .filter_map(|(idx, node)| if node.try_return().is_some() { - Some(NodeID::new(idx)) - } else { - None + .filter_map(|(idx, node)| { + if node.try_return().is_some() { + Some(NodeID::new(idx)) + } else { + None + } }) .collect::<Vec<_>>(); let success = editor.edit(|mut edit| { @@ -80,7 +88,9 @@ pub fn interprocedural_sroa( let data = data.to_vec(); let mut new_data = vec![]; - for (idx, (data_id, update_info)) in data.into_iter().zip(old_return_type_map.iter()).enumerate() { + for (idx, (data_id, update_info)) in + data.into_iter().zip(old_return_type_map.iter()).enumerate() + { if let IndexTree::Leaf(new_idx) = update_info { // Unchanged return value assert!(new_data.len() == *new_idx); @@ -128,7 +138,11 @@ pub fn interprocedural_sroa( } } -fn sroa_type(editor: &FunctionEditor, typ: TypeID, type_index: usize) -> (Vec<TypeID>, IndexTree<usize>) { +fn sroa_type( + editor: &FunctionEditor, + typ: TypeID, + type_index: usize, +) -> (Vec<TypeID>, IndexTree<usize>) { match &*editor.get_type(typ) { Type::Product(ts) => { let mut res_types = vec![]; @@ -157,7 +171,8 @@ fn 
get_callsites(editors: &Vec<FunctionEditor>) -> Vec<Vec<(FunctionID, NodeID)> .nodes .iter() .enumerate() - .filter_map(|(idx, node)| node.try_call().map(|c| (idx, c))) { + .filter_map(|(idx, node)| node.try_call().map(|c| (idx, c))) + { assert!(editor.is_mutable(NodeID::new(callsite)), "IP SROA expects to be able to edit everything, specify what functions to IP SROA via the func_selection argument"); callsites[callee.idx()].push((caller, NodeID::new(callsite))); } @@ -178,9 +193,7 @@ fn replace_returned_value( let constant = generate_constant(editor, proj_typ); let success = editor.edit(|mut edit| { - let mut new_val = edit.add_node(Node::Constant { - id: constant, - }); + let mut new_val = edit.add_node(Node::Constant { id: constant }); of_new_call.for_each(|idx, selection| { let new_proj = edit.add_node(Node::DataProjection { data: call_node, diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index 68a1b25e..e658ff88 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -985,7 +985,7 @@ pub fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> result = Some(generate_reads_edit(&mut edit, typ, val)); Ok(edit) }); - + result.unwrap() } diff --git a/hercules_test/hercules_interpreter/src/interpreter.rs b/hercules_test/hercules_interpreter/src/interpreter.rs index f9d666a5..8a577839 100644 --- a/hercules_test/hercules_interpreter/src/interpreter.rs +++ b/hercules_test/hercules_interpreter/src/interpreter.rs @@ -868,7 +868,8 @@ impl<'a> FunctionExecutionState<'a> { } } Node::Return { control: _, data } => { - let results = data.iter() + let results = data + .iter() .map(|data| self.handle_data(&ctrl_token, *data)) .collect(); break 'outer InterpreterVal::MultiReturn(results); diff --git a/juno_samples/multi_return/src/main.rs b/juno_samples/multi_return/src/main.rs index 0b3508a7..b0fd169f 100644 --- a/juno_samples/multi_return/src/main.rs +++ b/juno_samples/multi_return/src/main.rs @@ -9,19 +9,19 @@ fn main() { 
let a: Box<[f32]> = (1..=N).map(|i| i as f32).collect(); let arg = HerculesImmBox::from(a.as_ref()); let mut r = runner!(rolling_sum_prod); - let (sums, sum, prods, prod) = async_std::task::block_on(async { r.run(N as u64, arg.to()).await }); + let (sums, sum, prods, prod) = + async_std::task::block_on(async { r.run(N as u64, arg.to()).await }); let mut sums = HerculesMutBox::<f32>::from(sums); let mut prods = HerculesMutBox::<f32>::from(prods); - let (expected_sums, expected_sum) = a.iter() - .fold((vec![0.0], 0.0), |(mut sums, sum), v| { - let new_sum = sum + v; - sums.push(new_sum); - (sums, new_sum) - }); - let (expected_prods, expected_prod) = a.iter() - .fold((vec![1.0], 1.0), |(mut prods, prod), v| { + let (expected_sums, expected_sum) = a.iter().fold((vec![0.0], 0.0), |(mut sums, sum), v| { + let new_sum = sum + v; + sums.push(new_sum); + (sums, new_sum) + }); + let (expected_prods, expected_prod) = + a.iter().fold((vec![1.0], 1.0), |(mut prods, prod), v| { let new_prod = prod * v; prods.push(new_prod); (prods, new_prod) diff --git a/juno_samples/rodinia/backprop/src/main.rs b/juno_samples/rodinia/backprop/src/main.rs index 23f78fe4..fa80a7a5 100644 --- a/juno_samples/rodinia/backprop/src/main.rs +++ b/juno_samples/rodinia/backprop/src/main.rs @@ -37,28 +37,22 @@ fn run_backprop( let mut hidden_prev_weights = HerculesMutBox::from(hidden_prev_weights.to_vec()); let mut runner = runner!(backprop); - let ( - out_err, - hid_err, - input_weights, - input_prev_weights, - hidden_weights, - hidden_prev_weights - ) = async_std::task::block_on(async { - runner - .run( - input_n, - hidden_n, - output_n, - input_vals.to(), - input_weights.to(), - hidden_weights.to(), - target.to(), - input_prev_weights.to(), - hidden_prev_weights.to(), - ) - .await - }); + let (out_err, hid_err, input_weights, input_prev_weights, hidden_weights, hidden_prev_weights) = + async_std::task::block_on(async { + runner + .run( + input_n, + hidden_n, + output_n, + input_vals.to(), + 
input_weights.to(), + hidden_weights.to(), + target.to(), + input_prev_weights.to(), + hidden_prev_weights.to(), + ) + .await + }); let mut input_weights = HerculesMutBox::from(input_weights); let mut hidden_weights = HerculesMutBox::from(hidden_weights); let mut input_prev_weights = HerculesMutBox::from(input_prev_weights); diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 2c288097..84b25811 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2088,15 +2088,16 @@ fn run_pass( None => false, }; - let selection = selection_of_functions(pm, selection) - .ok_or_else(|| { - SchedulerError::PassError { - pass: "xdot".to_string(), - error: "expected coarse-grained selection (can't partially xdot a function)".to_string(), - } + let selection = + selection_of_functions(pm, selection).ok_or_else(|| SchedulerError::PassError { + pass: "xdot".to_string(), + error: "expected coarse-grained selection (can't partially xdot a function)" + .to_string(), })?; let mut bool_selection = vec![false; pm.functions.len()]; - selection.into_iter().for_each(|func| bool_selection[func.idx()] = true); + selection + .into_iter() + .for_each(|func| bool_selection[func.idx()] = true); pm.make_typing(); let typing = pm.typing.take().unwrap(); @@ -2733,15 +2734,16 @@ fn run_pass( None => true, }; - let selection = selection_of_functions(pm, selection) - .ok_or_else(|| { - SchedulerError::PassError { - pass: "xdot".to_string(), - error: "expected coarse-grained selection (can't partially xdot a function)".to_string(), - } + let selection = + selection_of_functions(pm, selection).ok_or_else(|| SchedulerError::PassError { + pass: "xdot".to_string(), + error: "expected coarse-grained selection (can't partially xdot a function)" + .to_string(), })?; let mut bool_selection = vec![false; pm.functions.len()]; - selection.into_iter().for_each(|func| bool_selection[func.idx()] = true); + selection + .into_iter() + .for_each(|func| bool_selection[func.idx()] = true); 
pm.make_reverse_postorders(); if force_analyses { -- GitLab