use crate::ir;
use crate::parser;
use juno_utils::env::Env;
use juno_utils::stringtab::StringTable;
use hercules_ir::ir::{Device, Schedule};
use lrlex::DefaultLexerTypes;
use lrpar::NonStreamingLexer;
use std::fmt;
use std::str::FromStr;
/// A source range as `((start_line, start_col), (end_line, end_col))`, as produced by the
/// lexer's `line_col`.
type Location = ((usize, usize), (usize, usize));

/// Errors reported while compiling a parsed schedule into the scheduler IR.
///
/// Derives `Debug` so results carrying this error can be `unwrap`ped / `{:?}`-printed; the
/// user-facing text comes from the `Display` implementation.
#[derive(Debug)]
pub enum ScheduleCompilerError {
    /// A macro was invoked that is not defined in any enclosing scope.
    UndefinedMacro(String, Location),
    /// A function-call name did not resolve to any pass, schedule, or device.
    NoSuchPass(String, Location),
    /// A pass or macro was applied to the wrong number of arguments.
    IncorrectArguments {
        /// Description of the accepted argument count(s).
        expected: String,
        /// Number of arguments actually supplied.
        actual: usize,
        loc: Location,
    },
}
impl fmt::Display for ScheduleCompilerError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ScheduleCompilerError::UndefinedMacro(name, loc) => write!(
f,
"({}, {}) -- ({}, {}): Undefined macro '{}'",
loc.0 .0, loc.0 .1, loc.1 .0, loc.1 .1, name
),
ScheduleCompilerError::NoSuchPass(name, loc) => write!(
f,
"({}, {}) -- ({}, {}): Undefined pass '{}'",
loc.0 .0, loc.0 .1, loc.1 .0, loc.1 .1, name
),
ScheduleCompilerError::IncorrectArguments {
expected,
actual,
loc,
} => write!(
f,
"({}, {}) -- ({}, {}): Expected {} arguments, found {}",
loc.0 .0, loc.0 .1, loc.1 .0, loc.1 .1, expected, actual
),
}
}
}
/// Compiles a parsed schedule (a list of operations) into a single IR block statement.
///
/// A fresh macro name table and macro environment are created for the compilation, with one
/// outermost scope opened for top-level macro declarations.
pub fn compile_schedule(
    sched: parser::OperationList,
    lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>,
) -> Result<ir::ScheduleStmt, ScheduleCompilerError> {
    let mut macro_names = StringTable::new();
    let mut macro_env = Env::new();
    // Top-level scope for macros declared directly in the schedule.
    macro_env.open_scope();

    let body = compile_ops_as_block(sched, lexer, &mut macro_names, &mut macro_env)?;
    Ok(ir::ScheduleStmt::Block { body })
}
/// A compiled macro definition, stored in the macro environment and expanded inline at every
/// invocation site.
#[derive(Debug, Clone)]
struct MacroInfo {
    /// Names of the macro's parameters; invocation arguments are bound to them positionally.
    params: Vec<String>,
    /// Name under which the macro body refers to the selection the macro is invoked on.
    selection_name: String,
    /// The compiled macro body; it becomes the result expression of the block generated at
    /// each invocation.
    def: ir::ScheduleExp,
}
/// Everything a name in function-call position can resolve to.
enum Appliable {
    /// Attaches a schedule to the selection; compiles to a statement.
    Schedule(Schedule),
    /// Attaches a device to the selection; compiles to a statement.
    Device(Device),
    /// An ordinary compiler pass.
    Pass(ir::Pass),
    /// DeleteUncalled requires special handling because it changes FunctionIDs, so it is not
    /// treated like a pass.
    DeleteUncalled,
}
impl Appliable {
    /// Returns whether `num` is an acceptable number of arguments for this appliable.
    ///
    /// Passes decide for themselves; every other appliable takes no arguments.
    fn is_valid_num_args(&self, num: usize) -> bool {
        match self {
            Appliable::Pass(pass) => pass.is_valid_num_args(num),
            // Delete uncalled, schedules, and devices do not take arguments. Listed explicitly
            // (no `_` arm) so that adding a variant forces this decision to be revisited.
            Appliable::DeleteUncalled | Appliable::Schedule(_) | Appliable::Device(_) => {
                num == 0
            }
        }
    }

    /// Returns a human-readable description of the accepted argument count(s), used in
    /// `IncorrectArguments` error messages.
    fn valid_arg_nums(&self) -> &'static str {
        match self {
            Appliable::Pass(pass) => pass.valid_arg_nums(),
            Appliable::DeleteUncalled | Appliable::Schedule(_) | Appliable::Device(_) => "0",
        }
    }
}
impl FromStr for Appliable {
    type Err = String;

    /// Resolves a schedule function name to the thing it applies.
    ///
    /// Matching is case-sensitive (the caller lower-cases names before parsing). Many names
    /// have aliases. An unrecognized name is handed back unchanged as the error value.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Ordinary passes fall through to the single `Appliable::Pass` wrap at the end; every
        // other kind of appliable (and the error case) returns early.
        let pass = match s {
            "array-slf" => ir::Pass::ArraySLF,
            "array-to-product" | "array-to-prod" | "a2p" => ir::Pass::ArrayToProduct,
            "auto-outline" => ir::Pass::AutoOutline,
            "ccp" => ir::Pass::CCP,
            "crc" | "collapse-read-chains" => ir::Pass::CRC,
            "clean-monoid-reduces" => ir::Pass::CleanMonoidReduces,
            "const-inline" => ir::Pass::ConstInline,
            "dce" => ir::Pass::DCE,
            "float-collections" | "collections" => ir::Pass::FloatCollections,
            "fork-chunk" | "fork-tile" => ir::Pass::ForkChunk,
            "fork-coalesce" => ir::Pass::ForkCoalesce,
            "fork-dim-merge" => ir::Pass::ForkDimMerge,
            "fork-extend" => ir::Pass::ForkExtend,
            "fork-fission-bufferize" | "fork-fission" => ir::Pass::ForkFissionBufferize,
            "fork-fusion" | "fusion" => ir::Pass::ForkFusion,
            "fork-guard-elim" => ir::Pass::ForkGuardElim,
            "fork-interchange" => ir::Pass::ForkInterchange,
            "fork-reshape" => ir::Pass::ForkReshape,
            "fork-split" => ir::Pass::ForkSplit,
            "fork-unroll" | "unroll" => ir::Pass::ForkUnroll,
            "forkify" => ir::Pass::Forkify,
            "gcm" | "bbs" => ir::Pass::GCM,
            "gvn" => ir::Pass::GVN,
            "infer-schedules" => ir::Pass::InferSchedules,
            "inline" => ir::Pass::Inline,
            "ip-sroa" | "interprocedural-sroa" => ir::Pass::InterproceduralSROA,
            "lift-dc-math" => ir::Pass::LiftDCMath,
            "loop-bound-canon" => ir::Pass::LoopBoundCanon,
            "outline" => ir::Pass::Outline,
            "phi-elim" => ir::Pass::PhiElim,
            "predication" => ir::Pass::Predication,
            "print" => ir::Pass::Print,
            "reduce-slf" => ir::Pass::ReduceSLF,
            "rename" => ir::Pass::Rename,
            "reuse-products" => ir::Pass::ReuseProducts,
            "rewrite" | "rewrite-math" | "rewrite-math-expressions" => {
                ir::Pass::RewriteMathExpressions
            }
            "serialize" => ir::Pass::Serialize,
            "simplify-cfg" => ir::Pass::SimplifyCFG,
            "slf" | "store-load-forward" => ir::Pass::SLF,
            "sroa" => ir::Pass::SROA,
            "unforkify" => ir::Pass::Unforkify,
            "unforkify-one" => ir::Pass::UnforkifyOne,
            "verify" => ir::Pass::Verify,
            "write-predication" => ir::Pass::WritePredication,
            "xdot" => ir::Pass::Xdot,

            // Not a pass: changes FunctionIDs, so it is handled specially.
            "delete-uncalled" => return Ok(Appliable::DeleteUncalled),

            // Device annotations.
            "cpu" | "llvm" => return Ok(Appliable::Device(Device::LLVM)),
            "gpu" | "cuda" | "nvidia" => return Ok(Appliable::Device(Device::CUDA)),
            "host" | "rust" | "rust-async" => return Ok(Appliable::Device(Device::AsyncRust)),

            // Schedule annotations.
            "monoid" | "associative" => {
                return Ok(Appliable::Schedule(Schedule::MonoidReduce))
            }
            "parallel-fork" => return Ok(Appliable::Schedule(Schedule::ParallelFork)),
            "parallel-reduce" => return Ok(Appliable::Schedule(Schedule::ParallelReduce)),
            "no-memset" | "no-reset" => {
                return Ok(Appliable::Schedule(Schedule::NoResetConstant))
            }
            "task-parallel" | "async-call" => {
                return Ok(Appliable::Schedule(Schedule::AsyncCall))
            }

            unknown => return Err(unknown.to_string()),
        };
        Ok(Appliable::Pass(pass))
    }
}
/// Compiles an operation list into a flat list of IR statements, e.g. for use as the body of
/// a block statement.
///
/// A trailing final expression is compiled in statement position (its value, if any, is
/// bound to `_`).
///
/// Implemented as a loop rather than recursion (mirroring `compile_ops_as_expr`). This avoids
/// stack depth proportional to the schedule length, and it avoids the quadratic cost of
/// re-extending each statement's vector with the whole compiled tail.
fn compile_ops_as_block(
    mut sched: parser::OperationList,
    lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>,
    macrostab: &mut StringTable,
    macros: &mut Env<usize, MacroInfo>,
) -> Result<Vec<ir::ScheduleStmt>, ScheduleCompilerError> {
    let mut block = vec![];
    loop {
        match sched {
            parser::OperationList::NilStmt() => return Ok(block),
            parser::OperationList::FinalExpr(expr) => {
                block.push(compile_exp_as_stmt(expr, lexer, macrostab, macros)?);
                return Ok(block);
            }
            parser::OperationList::ConsStmt(stmt, ops) => {
                // Statements are compiled in source order, so macro declarations are visible
                // to everything after them.
                block.extend(compile_stmt(stmt, lexer, macrostab, macros)?);
                sched = *ops;
            }
        }
    }
}
/// Compiles one schedule statement into zero or more IR statements.
///
/// A multi-binding `let` binds the whole tuple to a fresh temporary and then binds each named
/// variable to the corresponding tuple field. A fixpoint body is compiled in its own macro
/// scope. A macro declaration is recorded in `macros` and emits no IR.
fn compile_stmt(
    stmt: parser::Stmt,
    lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>,
    macrostab: &mut StringTable,
    macros: &mut Env<usize, MacroInfo>,
) -> Result<Vec<ir::ScheduleStmt>, ScheduleCompilerError> {
    let stmts = match stmt {
        parser::Stmt::LetStmt { var, expr, .. } => {
            let name = lexer.span_str(var).to_string();
            let exp = compile_exp_as_expr(expr, lexer, macrostab, macros)?;
            vec![ir::ScheduleStmt::Let { var: name, exp }]
        }
        parser::Stmt::LetsStmt { vars, expr, .. } => {
            let tuple_var = format!("{}_tmp", macros.uniq());
            let tuple_exp = compile_exp_as_expr(expr, lexer, macrostab, macros)?;
            let mut out = vec![ir::ScheduleStmt::Let {
                var: tuple_var.clone(),
                exp: tuple_exp,
            }];
            // Project each tuple field into its own named variable.
            for (field, v) in vars.into_iter().enumerate() {
                out.push(ir::ScheduleStmt::Let {
                    var: lexer.span_str(v).to_string(),
                    exp: ir::ScheduleExp::TupleField {
                        lhs: Box::new(ir::ScheduleExp::Variable {
                            var: tuple_var.clone(),
                        }),
                        field,
                    },
                });
            }
            out
        }
        parser::Stmt::AssignStmt { var, rhs, .. } => {
            let name = lexer.span_str(var).to_string();
            let exp = compile_exp_as_expr(rhs, lexer, macrostab, macros)?;
            vec![ir::ScheduleStmt::Assign { var: name, exp }]
        }
        parser::Stmt::ExprStmt { exp, .. } => {
            vec![compile_exp_as_stmt(exp, lexer, macrostab, macros)?]
        }
        parser::Stmt::Fixpoint { limit, body, .. } => {
            let limit = match limit {
                parser::FixpointLimit::NoLimit { .. } => ir::FixpointLimit::NoLimit(),
                parser::FixpointLimit::StopAfter { limit, .. } => ir::FixpointLimit::StopAfter(
                    lexer
                        .span_str(limit)
                        .parse()
                        .expect("Parsing ensures integer"),
                ),
                parser::FixpointLimit::PanicAfter { limit, .. } => ir::FixpointLimit::PanicAfter(
                    lexer
                        .span_str(limit)
                        .parse()
                        .expect("Parsing ensures integer"),
                ),
                parser::FixpointLimit::PrintIter { .. } => ir::FixpointLimit::PrintIter(),
            };
            // Close the scope before propagating any error so the environment stays balanced.
            macros.open_scope();
            let inner = compile_ops_as_block(*body, lexer, macrostab, macros);
            macros.close_scope();
            let inner = inner?;
            vec![ir::ScheduleStmt::Fixpoint {
                body: Box::new(ir::ScheduleStmt::Block { body: inner }),
                limit,
            }]
        }
        parser::Stmt::MacroDecl { def, .. } => {
            let parser::MacroDecl {
                name,
                params,
                selection_name,
                def,
            } = def;
            let macro_id = macrostab.lookup_string(lexer.span_str(name).to_string());
            let selection_name = lexer.span_str(selection_name).to_string();
            let params = params
                .into_iter()
                .map(|p| lexer.span_str(p).to_string())
                .collect();
            let info = compile_macro_def(*def, params, selection_name, lexer, macrostab, macros)?;
            macros.insert(macro_id, info);
            vec![]
        }
    };
    Ok(stmts)
}
/// Compiles an expression that appears in statement position.
///
/// Expressions that already compiled to a statement are returned unchanged; value-producing
/// expressions are bound to the placeholder variable `_`.
fn compile_exp_as_stmt(
    expr: parser::Expr,
    lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>,
    macrostab: &mut StringTable,
    macros: &mut Env<usize, MacroInfo>,
) -> Result<ir::ScheduleStmt, ScheduleCompilerError> {
    let compiled = compile_expr(expr, lexer, macrostab, macros)?;
    let stmt = match compiled {
        ExprResult::Stmt(stmt) => stmt,
        ExprResult::Expr(exp) => ir::ScheduleStmt::Let {
            var: String::from("_"),
            exp,
        },
    };
    Ok(stmt)
}
/// Compiles an expression that appears in value position.
///
/// If the expression compiled to a statement, it is wrapped in a block whose result is an
/// empty record.
fn compile_exp_as_expr(
    expr: parser::Expr,
    lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>,
    macrostab: &mut StringTable,
    macros: &mut Env<usize, MacroInfo>,
) -> Result<ir::ScheduleExp, ScheduleCompilerError> {
    let compiled = compile_expr(expr, lexer, macrostab, macros)?;
    Ok(match compiled {
        ExprResult::Stmt(stm) => ir::ScheduleExp::Block {
            body: vec![stm],
            res: Box::new(ir::ScheduleExp::Record { fields: vec![] }),
        },
        ExprResult::Expr(exp) => exp,
    })
}
/// Outcome of compiling a parser expression.
///
/// Some expressions, such as schedule or device annotations, compile to IR statements rather
/// than value-producing IR expressions.
enum ExprResult {
    /// The expression compiled to an IR statement.
    Stmt(ir::ScheduleStmt),
    /// The expression compiled to a value-producing IR expression.
    Expr(ir::ScheduleExp),
}
fn compile_expr(
expr: parser::Expr,
lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>,
macrostab: &mut StringTable,
macros: &mut Env<usize, MacroInfo>,
) -> Result<ExprResult, ScheduleCompilerError> {
match expr {
parser::Expr::Function {
span,
name,
args,
selection,
} => {
let func: Appliable = lexer
.span_str(name)
.to_lowercase()
.parse()
.map_err(|s| ScheduleCompilerError::NoSuchPass(s, lexer.line_col(name)))?;
if !func.is_valid_num_args(args.len()) {
return Err(ScheduleCompilerError::IncorrectArguments {
expected: func.valid_arg_nums().to_string(),
actual: args.len(),
loc: lexer.line_col(span),
});
}
let mut arg_vals = vec![];
for arg in args {
arg_vals.push(compile_exp_as_expr(arg, lexer, macrostab, macros)?);
}
let selection = compile_selector(selection, lexer, macrostab, macros)?;
match func {
Appliable::Pass(pass) => Ok(ExprResult::Expr(ir::ScheduleExp::RunPass {
pass,
args: arg_vals,
on: selection,
})),
Appliable::DeleteUncalled => {
Ok(ExprResult::Expr(ir::ScheduleExp::DeleteUncalled {
on: selection,
}))
}
Appliable::Schedule(sched) => Ok(ExprResult::Stmt(ir::ScheduleStmt::AddSchedule {
sched,
on: selection,
})),
Appliable::Device(device) => Ok(ExprResult::Stmt(ir::ScheduleStmt::AddDevice {
device,
on: selection,
})),
}
}
parser::Expr::Macro {
span,
name,
args,
selection,
} => {
let name_str = lexer.span_str(name).to_string();
let macro_id = macrostab.lookup_string(name_str.clone());
let Some(macro_def) = macros.lookup(¯o_id) else {
return Err(ScheduleCompilerError::UndefinedMacro(
name_str,
lexer.line_col(name),
));
};
let macro_def: MacroInfo = macro_def.clone();
let MacroInfo {
params,
selection_name,
def,
} = macro_def;
if args.len() != params.len() {
return Err(ScheduleCompilerError::IncorrectArguments {
expected: params.len().to_string(),
actual: args.len(),
loc: lexer.line_col(span),
});
}
// To initialize the macro's arguments, we have to do this in two steps, we first
// evaluate all of the arguments and store them into new variables, using names that
// cannot conflict with other values in the program and then we assign those variables
// to the macro's parameters; this avoids any shadowing issues, for instance:
// macro![3, x] where macro!'s arguments are named x and y becomes
// let #0 = 3; let #1 = x; let x = #0; let y = #1;
// which has the desired semantics, as opposed to
// let x = 3; let y = x;
let mut arg_eval = vec![];
let mut arg_setters = vec![];
for (i, (exp, var)) in args.into_iter().zip(params.into_iter()).enumerate() {
let tmp = format!("#{}", i);
arg_eval.push(ir::ScheduleStmt::Let {
var: tmp.clone(),
exp: compile_exp_as_expr(exp, lexer, macrostab, macros)?,
});
arg_setters.push(ir::ScheduleStmt::Let {
var,
exp: ir::ScheduleExp::Variable { var: tmp },
});
}
// Set the selection
arg_eval.push(ir::ScheduleStmt::Let {
var: selection_name,
exp: ir::ScheduleExp::Selection {
selection: compile_selector(selection, lexer, macrostab, macros)?,
},
});
// Combine the evaluation and initialization code
arg_eval.extend(arg_setters);
Ok(ExprResult::Expr(ir::ScheduleExp::Block {
body: arg_eval,
res: Box::new(def),
}))
}
parser::Expr::Variable { span } => {
let var = lexer.span_str(span).to_string();
Ok(ExprResult::Expr(ir::ScheduleExp::Variable { var }))
}
parser::Expr::Integer { span } => {
let val: usize = lexer.span_str(span).parse().expect("Parsing");
Ok(ExprResult::Expr(ir::ScheduleExp::Integer { val }))
}
parser::Expr::Boolean { span: _, val } => {
Ok(ExprResult::Expr(ir::ScheduleExp::Boolean { val }))
}
parser::Expr::String { span } => {
let string = lexer.span_str(span);
let val = string[1..string.len() - 1].to_string();
Ok(ExprResult::Expr(ir::ScheduleExp::String { val }))
}
parser::Expr::Field {
span: _,
lhs,
field,
} => {
let field = lexer.span_str(field).to_string();
let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?;
Ok(ExprResult::Expr(ir::ScheduleExp::Field {
collect: Box::new(lhs),
field,
}))
}
parser::Expr::BlockExpr { span: _, body } => {
compile_ops_as_expr(*body, lexer, macrostab, macros)
}
parser::Expr::Record { span: _, fields } => {
let mut result = vec![];
for (name, expr) in fields {
let name = lexer.span_str(name).to_string();
let expr = compile_exp_as_expr(expr, lexer, macrostab, macros)?;
result.push((name, expr));
}
Ok(ExprResult::Expr(ir::ScheduleExp::Record { fields: result }))
}
parser::Expr::SetOp {
span: _,
op,
lhs,
rhs,
} => {
let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?;
let rhs = compile_exp_as_expr(*rhs, lexer, macrostab, macros)?;
Ok(ExprResult::Expr(ir::ScheduleExp::SetOp {
op,
lhs: Box::new(lhs),
rhs: Box::new(rhs),
}))
}
parser::Expr::Tuple { span: _, exps } => {
let exprs = exps
.into_iter()
.map(|e| compile_exp_as_expr(e, lexer, macrostab, macros))
.fold(Ok(vec![]), |mut res, exp| {
let mut res = res?;
res.push(exp?);
Ok(res)
})?;
Ok(ExprResult::Expr(ir::ScheduleExp::Tuple { exprs }))
}
parser::Expr::TupleField {
span: _,
lhs,
field,
} => {
let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?;
let field = lexer.span_str(field).parse().expect("Parsing");
Ok(ExprResult::Expr(ir::ScheduleExp::TupleField {
lhs: Box::new(lhs),
field,
}))
}
}
}
/// Compiles an operation list appearing in expression position (e.g. a block expression).
///
/// A list ending in a final expression yields an IR block expression whose result is that
/// expression. A list ending without one yields a block statement.
fn compile_ops_as_expr(
    mut sched: parser::OperationList,
    lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>,
    macrostab: &mut StringTable,
    macros: &mut Env<usize, MacroInfo>,
) -> Result<ExprResult, ScheduleCompilerError> {
    let mut stmts = Vec::new();
    loop {
        sched = match sched {
            parser::OperationList::ConsStmt(stmt, rest) => {
                stmts.extend(compile_stmt(stmt, lexer, macrostab, macros)?);
                *rest
            }
            parser::OperationList::FinalExpr(expr) => {
                let res = compile_exp_as_expr(expr, lexer, macrostab, macros)?;
                return Ok(ExprResult::Expr(ir::ScheduleExp::Block {
                    body: stmts,
                    res: Box::new(res),
                }));
            }
            parser::OperationList::NilStmt() => {
                return Ok(ExprResult::Stmt(ir::ScheduleStmt::Block { body: stmts }));
            }
        };
    }
}
/// Compiles a selector: `*` selects everything; otherwise each listed expression is compiled
/// in value position, stopping at the first error.
fn compile_selector(
    sel: parser::Selector,
    lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>,
    macrostab: &mut StringTable,
    macros: &mut Env<usize, MacroInfo>,
) -> Result<ir::Selector, ScheduleCompilerError> {
    match sel {
        parser::Selector::SelectAll { .. } => Ok(ir::Selector::Everything()),
        parser::Selector::SelectExprs { exprs, .. } => exprs
            .into_iter()
            .map(|exp| compile_exp_as_expr(exp, lexer, macrostab, macros))
            .collect::<Result<Vec<_>, _>>()
            .map(ir::Selector::Selection),
    }
}
/// Compiles a macro declaration's body into a `MacroInfo` ready to be stored in the macro
/// environment.
///
/// The body is compiled in a fresh macro scope, just as fixpoint bodies are. Helper macros
/// declared inside the body therefore stay local to it and are not left in the enclosing
/// environment once the definition is compiled. A body that compiles to a statement is
/// wrapped in a block whose result is an empty record.
///
/// # Errors
/// Propagates any error from compiling the body.
fn compile_macro_def(
    body: parser::OperationList,
    params: Vec<String>,
    selection_name: String,
    lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>,
    macrostab: &mut StringTable,
    macros: &mut Env<usize, MacroInfo>,
) -> Result<MacroInfo, ScheduleCompilerError> {
    // FIXME: The body should be checked in an environment that prohibits running anything on
    // everything (*) and check that only local variables/parameters are used

    // Close the scope before propagating any error so the environment stays balanced.
    macros.open_scope();
    let compiled = compile_ops_as_expr(body, lexer, macrostab, macros);
    macros.close_scope();

    let def = match compiled? {
        ExprResult::Expr(expr) => expr,
        ExprResult::Stmt(stmt) => ir::ScheduleExp::Block {
            body: vec![stmt],
            res: Box::new(ir::ScheduleExp::Record { fields: vec![] }),
        },
    };
    Ok(MacroInfo {
        params,
        selection_name,
        def,
    })
}