use crate::ir; use crate::parser; use juno_utils::env::Env; use juno_utils::stringtab::StringTable; use hercules_ir::ir::{Device, Schedule}; use lrlex::DefaultLexerTypes; use lrpar::NonStreamingLexer; use std::fmt; use std::str::FromStr; type Location = ((usize, usize), (usize, usize)); pub enum ScheduleCompilerError { UndefinedMacro(String, Location), NoSuchPass(String, Location), IncorrectArguments { expected: String, actual: usize, loc: Location, }, } impl fmt::Display for ScheduleCompilerError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { ScheduleCompilerError::UndefinedMacro(name, loc) => write!( f, "({}, {}) -- ({}, {}): Undefined macro '{}'", loc.0 .0, loc.0 .1, loc.1 .0, loc.1 .1, name ), ScheduleCompilerError::NoSuchPass(name, loc) => write!( f, "({}, {}) -- ({}, {}): Undefined pass '{}'", loc.0 .0, loc.0 .1, loc.1 .0, loc.1 .1, name ), ScheduleCompilerError::IncorrectArguments { expected, actual, loc, } => write!( f, "({}, {}) -- ({}, {}): Expected {} arguments, found {}", loc.0 .0, loc.0 .1, loc.1 .0, loc.1 .1, expected, actual ), } } } pub fn compile_schedule( sched: parser::OperationList, lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, ) -> Result<ir::ScheduleStmt, ScheduleCompilerError> { let mut macrostab = StringTable::new(); let mut macros = Env::new(); macros.open_scope(); Ok(ir::ScheduleStmt::Block { body: compile_ops_as_block(sched, lexer, &mut macrostab, &mut macros)?, }) } #[derive(Debug, Clone)] struct MacroInfo { params: Vec<String>, selection_name: String, def: ir::ScheduleExp, } enum Appliable { Pass(ir::Pass), // DeleteUncalled requires special handling because it changes FunctionIDs, so it is not // treated like a pass DeleteUncalled, Schedule(Schedule), Device(Device), } impl Appliable { // Tests whether a given number of arguments is a valid number of arguments for this fn is_valid_num_args(&self, num: usize) -> bool { match self { Appliable::Pass(pass) => pass.is_valid_num_args(num), // Delete uncalled, Schedules, and devices do not take arguments _ => num == 0, } } // Returns a description of the number of arguments this requires fn valid_arg_nums(&self) -> &'static str { match self { Appliable::Pass(pass) => pass.valid_arg_nums(), _ => "0", } } } impl FromStr for Appliable { type Err = String; fn from_str(s: &str) -> Result<Self, Self::Err> { match s { "array-slf" => Ok(Appliable::Pass(ir::Pass::ArraySLF)), "array-to-product" | "array-to-prod" | "a2p" => { Ok(Appliable::Pass(ir::Pass::ArrayToProduct)) } "auto-outline" => Ok(Appliable::Pass(ir::Pass::AutoOutline)), "ccp" => Ok(Appliable::Pass(ir::Pass::CCP)), "crc" | "collapse-read-chains" => Ok(Appliable::Pass(ir::Pass::CRC)), "clean-monoid-reduces" => Ok(Appliable::Pass(ir::Pass::CleanMonoidReduces)), "const-inline" => Ok(Appliable::Pass(ir::Pass::ConstInline)), "dce" => Ok(Appliable::Pass(ir::Pass::DCE)), "delete-uncalled" => Ok(Appliable::DeleteUncalled), "float-collections" | "collections" => Ok(Appliable::Pass(ir::Pass::FloatCollections)), "fork-guard-elim" => Ok(Appliable::Pass(ir::Pass::ForkGuardElim)), "fork-split" => Ok(Appliable::Pass(ir::Pass::ForkSplit)), "forkify" => Ok(Appliable::Pass(ir::Pass::Forkify)), "gcm" | "bbs" => Ok(Appliable::Pass(ir::Pass::GCM)), "gvn" => Ok(Appliable::Pass(ir::Pass::GVN)), "infer-schedules" => Ok(Appliable::Pass(ir::Pass::InferSchedules)), "inline" => Ok(Appliable::Pass(ir::Pass::Inline)), "ip-sroa" | "interprocedural-sroa" => { Ok(Appliable::Pass(ir::Pass::InterproceduralSROA)) } "fork-fission-bufferize" | "fork-fission" => { Ok(Appliable::Pass(ir::Pass::ForkFissionBufferize)) } "fork-dim-merge" => Ok(Appliable::Pass(ir::Pass::ForkDimMerge)), "fork-interchange" => Ok(Appliable::Pass(ir::Pass::ForkInterchange)), "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)), "fork-extend" => Ok(Appliable::Pass(ir::Pass::ForkExtend)), "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)), "fork-fusion" | "fusion" => Ok(Appliable::Pass(ir::Pass::ForkFusion)), "fork-reshape" => Ok(Appliable::Pass(ir::Pass::ForkReshape)), "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), "loop-bound-canon" => Ok(Appliable::Pass(ir::Pass::LoopBoundCanon)), "outline" => Ok(Appliable::Pass(ir::Pass::Outline)), "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)), "predication" => Ok(Appliable::Pass(ir::Pass::Predication)), "reduce-slf" => Ok(Appliable::Pass(ir::Pass::ReduceSLF)), "rename" => Ok(Appliable::Pass(ir::Pass::Rename)), "reuse-products" => Ok(Appliable::Pass(ir::Pass::ReuseProducts)), "rewrite" | "rewrite-math" | "rewrite-math-expressions" => { Ok(Appliable::Pass(ir::Pass::RewriteMathExpressions)) } "simplify-cfg" => Ok(Appliable::Pass(ir::Pass::SimplifyCFG)), "slf" | "store-load-forward" => Ok(Appliable::Pass(ir::Pass::SLF)), "sroa" => Ok(Appliable::Pass(ir::Pass::SROA)), "unforkify" => Ok(Appliable::Pass(ir::Pass::Unforkify)), "unforkify-one" => Ok(Appliable::Pass(ir::Pass::UnforkifyOne)), "fork-coalesce" => Ok(Appliable::Pass(ir::Pass::ForkCoalesce)), "verify" => Ok(Appliable::Pass(ir::Pass::Verify)), "xdot" => Ok(Appliable::Pass(ir::Pass::Xdot)), "serialize" => Ok(Appliable::Pass(ir::Pass::Serialize)), "write-predication" => Ok(Appliable::Pass(ir::Pass::WritePredication)), "print" => Ok(Appliable::Pass(ir::Pass::Print)), "cpu" | "llvm" => Ok(Appliable::Device(Device::LLVM)), "gpu" | "cuda" | "nvidia" => Ok(Appliable::Device(Device::CUDA)), "host" | "rust" | "rust-async" => Ok(Appliable::Device(Device::AsyncRust)), "monoid" | "associative" => Ok(Appliable::Schedule(Schedule::MonoidReduce)), "parallel-fork" => Ok(Appliable::Schedule(Schedule::ParallelFork)), "parallel-reduce" => Ok(Appliable::Schedule(Schedule::ParallelReduce)), "no-memset" | "no-reset" => Ok(Appliable::Schedule(Schedule::NoResetConstant)), "task-parallel" | "async-call" => Ok(Appliable::Schedule(Schedule::AsyncCall)), _ => Err(s.to_string()), } } } fn compile_ops_as_block( sched: parser::OperationList, lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, macrostab: &mut StringTable, macros: &mut Env<usize, MacroInfo>, ) -> Result<Vec<ir::ScheduleStmt>, ScheduleCompilerError> { match sched { parser::OperationList::NilStmt() => Ok(vec![]), parser::OperationList::FinalExpr(expr) => { Ok(vec![compile_exp_as_stmt(expr, lexer, macrostab, macros)?]) } parser::OperationList::ConsStmt(stmt, ops) => { let mut res = compile_stmt(stmt, lexer, macrostab, macros)?; res.extend(compile_ops_as_block(*ops, lexer, macrostab, macros)?); Ok(res) } } } fn compile_stmt( stmt: parser::Stmt, lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, macrostab: &mut StringTable, macros: &mut Env<usize, MacroInfo>, ) -> Result<Vec<ir::ScheduleStmt>, ScheduleCompilerError> { match stmt { parser::Stmt::LetStmt { span: _, var, expr } => { let var = lexer.span_str(var).to_string(); Ok(vec![ir::ScheduleStmt::Let { var, exp: compile_exp_as_expr(expr, lexer, macrostab, macros)?, }]) } parser::Stmt::LetsStmt { span: _, vars, expr, } => { let tmp = format!("{}_tmp", macros.uniq()); Ok(std::iter::once(ir::ScheduleStmt::Let { var: tmp.clone(), exp: compile_exp_as_expr(expr, lexer, macrostab, macros)?, }) .chain(vars.into_iter().enumerate().map(|(idx, v)| { let var = lexer.span_str(v).to_string(); ir::ScheduleStmt::Let { var, exp: ir::ScheduleExp::TupleField { lhs: Box::new(ir::ScheduleExp::Variable { var: tmp.clone() }), field: idx, }, } })) .collect()) } parser::Stmt::AssignStmt { span: _, var, rhs } => { let var = lexer.span_str(var).to_string(); Ok(vec![ir::ScheduleStmt::Assign { var, exp: compile_exp_as_expr(rhs, lexer, macrostab, macros)?, }]) } parser::Stmt::ExprStmt { span: _, exp } => { Ok(vec![compile_exp_as_stmt(exp, lexer, macrostab, macros)?]) } parser::Stmt::Fixpoint { span: _, limit, body, } => { let limit = match limit { parser::FixpointLimit::NoLimit { .. } => ir::FixpointLimit::NoLimit(), parser::FixpointLimit::StopAfter { span: _, limit } => { ir::FixpointLimit::StopAfter( lexer .span_str(limit) .parse() .expect("Parsing ensures integer"), ) } parser::FixpointLimit::PanicAfter { span: _, limit } => { ir::FixpointLimit::PanicAfter( lexer .span_str(limit) .parse() .expect("Parsing ensures integer"), ) } parser::FixpointLimit::PrintIter { .. } => ir::FixpointLimit::PrintIter(), }; macros.open_scope(); let body = compile_ops_as_block(*body, lexer, macrostab, macros); macros.close_scope(); Ok(vec![ir::ScheduleStmt::Fixpoint { body: Box::new(ir::ScheduleStmt::Block { body: body? }), limit, }]) } parser::Stmt::MacroDecl { span: _, def } => { let parser::MacroDecl { name, params, selection_name, def, } = def; let name = lexer.span_str(name).to_string(); let macro_id = macrostab.lookup_string(name); let selection_name = lexer.span_str(selection_name).to_string(); let params = params .into_iter() .map(|s| lexer.span_str(s).to_string()) .collect(); let def = compile_macro_def(*def, params, selection_name, lexer, macrostab, macros)?; macros.insert(macro_id, def); Ok(vec![]) } } } fn compile_exp_as_stmt( expr: parser::Expr, lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, macrostab: &mut StringTable, macros: &mut Env<usize, MacroInfo>, ) -> Result<ir::ScheduleStmt, ScheduleCompilerError> { match compile_expr(expr, lexer, macrostab, macros)? { ExprResult::Expr(exp) => Ok(ir::ScheduleStmt::Let { var: "_".to_string(), exp, }), ExprResult::Stmt(stm) => Ok(stm), } } fn compile_exp_as_expr( expr: parser::Expr, lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, macrostab: &mut StringTable, macros: &mut Env<usize, MacroInfo>, ) -> Result<ir::ScheduleExp, ScheduleCompilerError> { match compile_expr(expr, lexer, macrostab, macros)? { ExprResult::Expr(exp) => Ok(exp), ExprResult::Stmt(stm) => Ok(ir::ScheduleExp::Block { body: vec![stm], res: Box::new(ir::ScheduleExp::Record { fields: vec![] }), }), } } enum ExprResult { Expr(ir::ScheduleExp), Stmt(ir::ScheduleStmt), } fn compile_expr( expr: parser::Expr, lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, macrostab: &mut StringTable, macros: &mut Env<usize, MacroInfo>, ) -> Result<ExprResult, ScheduleCompilerError> { match expr { parser::Expr::Function { span, name, args, selection, } => { let func: Appliable = lexer .span_str(name) .to_lowercase() .parse() .map_err(|s| ScheduleCompilerError::NoSuchPass(s, lexer.line_col(name)))?; if !func.is_valid_num_args(args.len()) { return Err(ScheduleCompilerError::IncorrectArguments { expected: func.valid_arg_nums().to_string(), actual: args.len(), loc: lexer.line_col(span), }); } let mut arg_vals = vec![]; for arg in args { arg_vals.push(compile_exp_as_expr(arg, lexer, macrostab, macros)?); } let selection = compile_selector(selection, lexer, macrostab, macros)?; match func { Appliable::Pass(pass) => Ok(ExprResult::Expr(ir::ScheduleExp::RunPass { pass, args: arg_vals, on: selection, })), Appliable::DeleteUncalled => { Ok(ExprResult::Expr(ir::ScheduleExp::DeleteUncalled { on: selection, })) } Appliable::Schedule(sched) => Ok(ExprResult::Stmt(ir::ScheduleStmt::AddSchedule { sched, on: selection, })), Appliable::Device(device) => Ok(ExprResult::Stmt(ir::ScheduleStmt::AddDevice { device, on: selection, })), } } parser::Expr::Macro { span, name, args, selection, } => { let name_str = lexer.span_str(name).to_string(); let macro_id = macrostab.lookup_string(name_str.clone()); let Some(macro_def) = macros.lookup(¯o_id) else { return Err(ScheduleCompilerError::UndefinedMacro( name_str, lexer.line_col(name), )); }; let macro_def: MacroInfo = macro_def.clone(); let MacroInfo { params, selection_name, def, } = macro_def; if args.len() != params.len() { return Err(ScheduleCompilerError::IncorrectArguments { expected: params.len().to_string(), actual: args.len(), loc: lexer.line_col(span), }); } // To initialize the macro's arguments, we have to do this in two steps, we first // evaluate all of the arguments and store them into new variables, using names that // cannot conflict with other values in the program and then we assign those variables // to the macro's parameters; this avoids any shadowing issues, for instance: // macro![3, x] where macro!'s arguments are named x and y becomes // let #0 = 3; let #1 = x; let x = #0; let y = #1; // which has the desired semantics, as opposed to // let x = 3; let y = x; let mut arg_eval = vec![]; let mut arg_setters = vec![]; for (i, (exp, var)) in args.into_iter().zip(params.into_iter()).enumerate() { let tmp = format!("#{}", i); arg_eval.push(ir::ScheduleStmt::Let { var: tmp.clone(), exp: compile_exp_as_expr(exp, lexer, macrostab, macros)?, }); arg_setters.push(ir::ScheduleStmt::Let { var, exp: ir::ScheduleExp::Variable { var: tmp }, }); } // Set the selection arg_eval.push(ir::ScheduleStmt::Let { var: selection_name, exp: ir::ScheduleExp::Selection { selection: compile_selector(selection, lexer, macrostab, macros)?, }, }); // Combine the evaluation and initialization code arg_eval.extend(arg_setters); Ok(ExprResult::Expr(ir::ScheduleExp::Block { body: arg_eval, res: Box::new(def), })) } parser::Expr::Variable { span } => { let var = lexer.span_str(span).to_string(); Ok(ExprResult::Expr(ir::ScheduleExp::Variable { var })) } parser::Expr::Integer { span } => { let val: usize = lexer.span_str(span).parse().expect("Parsing"); Ok(ExprResult::Expr(ir::ScheduleExp::Integer { val })) } parser::Expr::Boolean { span: _, val } => { Ok(ExprResult::Expr(ir::ScheduleExp::Boolean { val })) } parser::Expr::String { span } => { let string = lexer.span_str(span); let val = string[1..string.len() - 1].to_string(); Ok(ExprResult::Expr(ir::ScheduleExp::String { val })) } parser::Expr::Field { span: _, lhs, field, } => { let field = lexer.span_str(field).to_string(); let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?; Ok(ExprResult::Expr(ir::ScheduleExp::Field { collect: Box::new(lhs), field, })) } parser::Expr::BlockExpr { span: _, body } => { compile_ops_as_expr(*body, lexer, macrostab, macros) } parser::Expr::Record { span: _, fields } => { let mut result = vec![]; for (name, expr) in fields { let name = lexer.span_str(name).to_string(); let expr = compile_exp_as_expr(expr, lexer, macrostab, macros)?; result.push((name, expr)); } Ok(ExprResult::Expr(ir::ScheduleExp::Record { fields: result })) } parser::Expr::SetOp { span: _, op, lhs, rhs, } => { let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?; let rhs = compile_exp_as_expr(*rhs, lexer, macrostab, macros)?; Ok(ExprResult::Expr(ir::ScheduleExp::SetOp { op, lhs: Box::new(lhs), rhs: Box::new(rhs), })) } parser::Expr::Tuple { span: _, exps } => { let exprs = exps .into_iter() .map(|e| compile_exp_as_expr(e, lexer, macrostab, macros)) .fold(Ok(vec![]), |mut res, exp| { let mut res = res?; res.push(exp?); Ok(res) })?; Ok(ExprResult::Expr(ir::ScheduleExp::Tuple { exprs })) } parser::Expr::TupleField { span: _, lhs, field, } => { let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?; let field = lexer.span_str(field).parse().expect("Parsing"); Ok(ExprResult::Expr(ir::ScheduleExp::TupleField { lhs: Box::new(lhs), field, })) } } } fn compile_ops_as_expr( mut sched: parser::OperationList, lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, macrostab: &mut StringTable, macros: &mut Env<usize, MacroInfo>, ) -> Result<ExprResult, ScheduleCompilerError> { let mut body = vec![]; loop { match sched { parser::OperationList::NilStmt() => { return Ok(ExprResult::Stmt(ir::ScheduleStmt::Block { body })); } parser::OperationList::FinalExpr(expr) => { return Ok(ExprResult::Expr(ir::ScheduleExp::Block { body, res: Box::new(compile_exp_as_expr(expr, lexer, macrostab, macros)?), })); } parser::OperationList::ConsStmt(stmt, ops) => { body.extend(compile_stmt(stmt, lexer, macrostab, macros)?); sched = *ops; } } } } fn compile_selector( sel: parser::Selector, lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, macrostab: &mut StringTable, macros: &mut Env<usize, MacroInfo>, ) -> Result<ir::Selector, ScheduleCompilerError> { match sel { parser::Selector::SelectAll { span: _ } => Ok(ir::Selector::Everything()), parser::Selector::SelectExprs { span: _, exprs } => { let mut res = vec![]; for exp in exprs { res.push(compile_exp_as_expr(exp, lexer, macrostab, macros)?); } Ok(ir::Selector::Selection(res)) } } } fn compile_macro_def( body: parser::OperationList, params: Vec<String>, selection_name: String, lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, macrostab: &mut StringTable, macros: &mut Env<usize, MacroInfo>, ) -> Result<MacroInfo, ScheduleCompilerError> { // FIXME: The body should be checked in an environment that prohibits running anything on // everything (*) and check that only local variables/parameters are used Ok(MacroInfo { params, selection_name, def: match compile_ops_as_expr(body, lexer, macrostab, macros)? { ExprResult::Expr(expr) => expr, ExprResult::Stmt(stmt) => ir::ScheduleExp::Block { body: vec![stmt], res: Box::new(ir::ScheduleExp::Record { fields: vec![] }), }, }, }) }