diff --git a/juno_samples/matmul/src/cpu.sch b/juno_samples/matmul/src/cpu.sch index 7577aebba60a03f3b21de8c54b83be083c557e7c..69f1811d08a5691de6dc10fa86f0720e416fb85a 100644 --- a/juno_samples/matmul/src/cpu.sch +++ b/juno_samples/matmul/src/cpu.sch @@ -46,16 +46,15 @@ associative(matmul@outer); // Parallelize by computing output array as 16 chunks let par = matmul@outer \ matmul@inner; fork-tile; -let res = fork-reshape[[1, 3], [0], [2]](par); let outer = res.0; let inner = res.1; +let (outer, inner, _) = fork-reshape[[1, 3], [0], [2]](par); parallelize!(outer \ inner); let body = outline(inner); cpu(body); // Tile for cache, assuming 64B cache lines -fork-split(body); fork-tile; -let res = fork-reshape[[0, 2, 4, 1, 3], [5]](body); let outer = res.0; let inner = res.1; +let (outer, inner) = fork-reshape[[0, 2, 4, 1, 3], [5]](body); reduce-slf(inner); unforkify!(body); diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 0f473f0d3b96e56b796e6436109335adf5d5f038..51ba3699f53b29985866e8c4fda25c7ce7e5bd6e 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -207,6 +207,24 @@ fn compile_stmt( exp: compile_exp_as_expr(expr, lexer, macrostab, macros)?, }]) } + parser::Stmt::LetsStmt { span: _, vars, expr } => { + let tmp = format!("{}_tmp", macros.uniq()); + Ok(std::iter::once(ir::ScheduleStmt::Let { + var: tmp.clone(), + exp: compile_exp_as_expr(expr, lexer, macrostab, macros)?, + }).chain(vars.into_iter().enumerate() + .map(|(idx, v)| { + let var = lexer.span_str(v).to_string(); + ir::ScheduleStmt::Let { + var, + exp: ir::ScheduleExp::TupleField { + lhs: Box::new(ir::ScheduleExp::Variable { var: tmp.clone() }), + field: idx, + } + } + }) + ).collect()) + } parser::Stmt::AssignStmt { span: _, var, rhs } => { let var = lexer.span_str(var).to_string(); Ok(vec![ir::ScheduleStmt::Assign { diff --git a/juno_scheduler/src/lang.y b/juno_scheduler/src/lang.y index 69d50014a57865807bd9f82311a8852c21e648da..451f035b8122ac1a694567327a44776c9326fff2 100644 --- a/juno_scheduler/src/lang.y +++ b/juno_scheduler/src/lang.y @@ -19,6 +19,8 @@ Schedule -> OperationList Stmt -> Stmt : 'let' 'ID' '=' Expr ';' { Stmt::LetStmt { span: $span, var: span_of_tok($2), expr: $4 } } + | 'let' '(' Ids ')' '=' Expr ';' + { Stmt::LetsStmt { span: $span, vars: rev($3), expr: $6 } } | 'ID' '=' Expr ';' { Stmt::AssignStmt { span: $span, var: span_of_tok($1), rhs: $3 } } | Expr ';' @@ -157,6 +159,7 @@ pub enum OperationList { pub enum Stmt { LetStmt { span: Span, var: Span, expr: Expr }, + LetsStmt { span: Span, vars: Vec<Span>, expr: Expr }, AssignStmt { span: Span, var: Span, rhs: Expr }, ExprStmt { span: Span, exp: Expr }, Fixpoint { span: Span, limit: FixpointLimit, body: Box<OperationList> },