diff --git a/juno_samples/matmul/build.rs b/juno_samples/matmul/build.rs index d2813388e0e7a1d7bd1696ffbb641e629096e2c2..7bc2083cf07c063eccda5855fa4fed3bfca91f87 100644 --- a/juno_samples/matmul/build.rs +++ b/juno_samples/matmul/build.rs @@ -1,24 +1,11 @@ use juno_build::JunoCompiler; fn main() { - #[cfg(not(feature = "cuda"))] - { - JunoCompiler::new() - .file_in_src("matmul.jn") - .unwrap() - .schedule_in_src("cpu.sch") - .unwrap() - .build() - .unwrap(); - } - #[cfg(feature = "cuda")] - { - JunoCompiler::new() - .file_in_src("matmul.jn") - .unwrap() - .schedule_in_src("gpu.sch") - .unwrap() - .build() - .unwrap(); - } + JunoCompiler::new() + .file_in_src("matmul.jn") + .unwrap() + .schedule_in_src("matmul.sch") + .unwrap() + .build() + .unwrap(); } diff --git a/juno_samples/matmul/src/cpu.sch b/juno_samples/matmul/src/cpu.sch deleted file mode 100644 index 69f1811d08a5691de6dc10fa86f0720e416fb85a..0000000000000000000000000000000000000000 --- a/juno_samples/matmul/src/cpu.sch +++ /dev/null @@ -1,61 +0,0 @@ -macro optimize!(X) { - gvn(X); - phi-elim(X); - dce(X); - ip-sroa(X); - sroa(X); - dce(X); - gvn(X); - phi-elim(X); - dce(X); -} - -macro codegen-prep!(X) { - optimize!(X); - gcm(X); - float-collections(X); - dce(X); - gcm(X); -} - -macro forkify!(X) { - fixpoint { - forkify(X); - fork-guard-elim(X); - } -} - -macro fork-tile { - fork-tile[n, 0, false, true](X); -} - -macro parallelize!(X) { - parallel-fork(X); - parallel-reduce(X); -} - -macro unforkify!(X) { - fork-split(X); - unforkify(X); -} - -optimize!(*); -forkify!(*); -associative(matmul@outer); - -// Parallelize by computing output array as 16 chunks -let par = matmul@outer \ matmul@inner; -fork-tile; -let (outer, inner, _) = fork-reshape[[1, 3], [0], [2]](par); -parallelize!(outer \ inner); - -let body = outline(inner); -cpu(body); - -// Tile for cache, assuming 64B cache lines -fork-tile; -let (outer, inner) = fork-reshape[[0, 2, 4, 1, 3], [5]](body); - -reduce-slf(inner); -unforkify!(body); -codegen-prep!(*); diff --git a/juno_samples/matmul/src/gpu.sch b/juno_samples/matmul/src/gpu.sch deleted file mode 100644 index 76808149e7b99dce97e43f0e936536d3d13b7417..0000000000000000000000000000000000000000 --- a/juno_samples/matmul/src/gpu.sch +++ /dev/null @@ -1,26 +0,0 @@ -phi-elim(*); - -forkify(*); -fork-guard-elim(*); -dce(*); - -fixpoint { - reduce-slf(*); - slf(*); - infer-schedules(*); -} -fork-coalesce(*); -infer-schedules(*); -dce(*); -rewrite(*); -fixpoint { - simplify-cfg(*); - dce(*); -} - -ip-sroa(*); -sroa(*); -dce(*); - -float-collections(*); -gcm(*); diff --git a/juno_samples/matmul/src/matmul.sch b/juno_samples/matmul/src/matmul.sch new file mode 100644 index 0000000000000000000000000000000000000000..306997f58eb217f9ce301dc18e418c412e6df621 --- /dev/null +++ b/juno_samples/matmul/src/matmul.sch @@ -0,0 +1,81 @@ +macro optimize!(X) { + gvn(X); + phi-elim(X); + dce(X); + ip-sroa(X); + sroa(X); + dce(X); + gvn(X); + phi-elim(X); + dce(X); +} + +macro codegen-prep!(X) { + optimize!(X); + gcm(X); + float-collections(X); + dce(X); + gcm(X); +} + +macro forkify!(X) { + fixpoint { + forkify(X); + fork-guard-elim(X); + } +} + +macro fork-tile { + fork-tile[n, 0, false, true](X); +} + +macro parallelize!(X) { + parallel-fork(X); + parallel-reduce(X); +} + +macro unforkify!(X) { + fork-split(X); + unforkify(X); +} + +optimize!(*); +forkify!(*); + +if feature("cuda") { + fixpoint { + reduce-slf(*); + slf(*); + infer-schedules(*); + } + fork-coalesce(*); + infer-schedules(*); + dce(*); + rewrite(*); + fixpoint { + simplify-cfg(*); + dce(*); + } + + optimize!(*); + codegen-prep!(*); +} else { + associative(matmul@outer); + + // Parallelize by computing output array as 16 chunks + let par = matmul@outer \ matmul@inner; + fork-tile; + let (outer, inner, _) = fork-reshape[[1, 3], [0], [2]](par); + parallelize!(outer \ inner); + + let body = outline(inner); + cpu(body); + + // Tile for cache, assuming 64B cache lines + fork-tile; + let (outer, inner) = fork-reshape[[0, 2, 4, 1, 3], [5]](body); + + reduce-slf(inner); + unforkify!(body); + codegen-prep!(*); +} diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 9d020c64ccef3b9c0a79694876c5b0ace606f938..8b68ed7111de9009458dbdc26b511e66ee13db58 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -22,6 +22,7 @@ pub enum ScheduleCompilerError { actual: usize, loc: Location, }, + SemanticError(String, Location), } impl fmt::Display for ScheduleCompilerError { @@ -46,6 +47,11 @@ impl fmt::Display for ScheduleCompilerError { "({}, {}) -- ({}, {}): Expected {} arguments, found {}", loc.0 .0, loc.0 .1, loc.1 .0, loc.1 .1, expected, actual ), + ScheduleCompilerError::SemanticError(msg, loc) => write!( + f, + "({}, {}) -- ({}, {}): {}", + loc.0 .0, loc.0 .1, loc.1 .0, loc.1 .1, msg, + ), } } } @@ -76,6 +82,8 @@ enum Appliable { // DeleteUncalled requires special handling because it changes FunctionIDs, so it is not // treated like a pass DeleteUncalled, + // Test whether a feature is enabled + Feature, Schedule(Schedule), Device(Device), } @@ -85,6 +93,8 @@ impl Appliable { fn is_valid_num_args(&self, num: usize) -> bool { match self { Appliable::Pass(pass) => pass.is_valid_num_args(num), + // Testing whether a feature is enabled takes the feature instead of a selection, so it + // has 0 arguments // Delete uncalled, Schedules, and devices do not take arguments _ => num == 0, } @@ -158,6 +168,8 @@ impl FromStr for Appliable { "serialize" => Ok(Appliable::Pass(ir::Pass::Serialize)), "write-predication" => Ok(Appliable::Pass(ir::Pass::WritePredication)), + "feature" => Ok(Appliable::Feature), + "print" => Ok(Appliable::Pass(ir::Pass::Print)), "cpu" | "llvm" => Ok(Appliable::Device(Device::LLVM)), @@ -275,6 +287,35 @@ fn compile_stmt( limit, }]) } + parser::Stmt::IfThenElse { + span: _, + cond, + thn, + els, + } => { + let cond = compile_exp_as_expr(cond, lexer, macrostab, macros)?; + + macros.open_scope(); + let thn = ir::ScheduleStmt::Block { + body: compile_ops_as_block(*thn, lexer, macrostab, macros)?, + }; + macros.close_scope(); + + macros.open_scope(); + let els = match els { + Some(els) => ir::ScheduleStmt::Block { + body: compile_ops_as_block(*els, lexer, macrostab, macros)?, + }, + None => ir::ScheduleStmt::Block { body: vec![] }, + }; + macros.close_scope(); + + Ok(vec![ir::ScheduleStmt::IfThenElse { + cond, + thn: Box::new(thn), + els: Box::new(els), + }]) + } parser::Stmt::MacroDecl { span: _, def } => { let parser::MacroDecl { name, @@ -380,6 +421,17 @@ fn compile_expr( on: selection, })) } + Appliable::Feature => match selection { + ir::Selector::Selection(mut args) if args.len() == 1 => { + Ok(ExprResult::Expr(ir::ScheduleExp::Feature { + feature: Box::new(args.pop().unwrap()), + })) + } + _ => Err(ScheduleCompilerError::SemanticError( + "feature requires exactly one argument as its selection".to_string(), + lexer.line_col(span), + )), + }, Appliable::Schedule(sched) => Ok(ExprResult::Stmt(ir::ScheduleStmt::AddSchedule { sched, on: selection, diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index ab1495b816c99452560d03c0addf77a5aec18974..bacb4142c62df5b0bf974bcb1703400dd3caa02c 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -121,6 +121,9 @@ pub enum ScheduleExp { DeleteUncalled { on: Selector, }, + Feature { + feature: Box<ScheduleExp>, + }, Record { fields: Vec<(String, ScheduleExp)>, }, @@ -180,4 +183,9 @@ pub enum ScheduleStmt { device: Device, on: Selector, }, + IfThenElse { + cond: ScheduleExp, + thn: Box<ScheduleStmt>, + els: Box<ScheduleStmt>, + }, } diff --git a/juno_scheduler/src/lang.l b/juno_scheduler/src/lang.l index af154fce3d489b7b078607e52888d76d08f8e5fe..1f4f8723de4d950920d503a11c9145ad1d15d24a 100644 --- a/juno_scheduler/src/lang.l +++ b/juno_scheduler/src/lang.l @@ -20,12 +20,15 @@ \. "." apply "apply" +else "else" fixpoint "fixpoint" +if "if" let "let" macro "macro_keyword" on "on" set "set" target "target" +then "then" true "true" false "false" diff --git a/juno_scheduler/src/lang.y b/juno_scheduler/src/lang.y index 451f035b8122ac1a694567327a44776c9326fff2..55c82b9dc368e7cade9500fb852c9d48320be7d2 100644 --- a/juno_scheduler/src/lang.y +++ b/juno_scheduler/src/lang.y @@ -27,6 +27,10 @@ Stmt -> Stmt { Stmt::ExprStmt { span: $span, exp: $1 } } | 'fixpoint' FixpointLimit '{' Schedule '}' { Stmt::Fixpoint { span: $span, limit: $2, body: Box::new($4) } } + | 'if' Expr '{' Schedule '}' + { Stmt::IfThenElse { span: $span, cond: $2, thn: Box::new($4), els: None } } + | 'if' Expr '{' Schedule '}' 'else' '{' Schedule '}' + { Stmt::IfThenElse { span: $span, cond: $2, thn: Box::new($4), els: Some(Box::new($8)) } } | MacroDecl { Stmt::MacroDecl { span: $span, def: $1 } } ; @@ -163,6 +167,7 @@ pub enum Stmt { AssignStmt { span: Span, var: Span, rhs: Expr }, ExprStmt { span: Span, exp: Expr }, Fixpoint { span: Span, limit: FixpointLimit, body: Box<OperationList> }, + IfThenElse { span: Span, cond: Expr, thn: Box<OperationList>, els: Option<Box<OperationList>> }, MacroDecl { span: Span, def: MacroDecl }, } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 62bdaf739b69b52a0b540c6a602e1e5beb632786..4c981bd98ebd8eb605352cd3a78128c72e0e4ae9 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -16,6 +16,7 @@ use juno_utils::stringtab::StringTable; use std::cell::RefCell; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; +use std::env; use std::fmt; use std::fs::File; use std::io::Write; @@ -1199,6 +1200,23 @@ fn schedule_interpret( // were made Ok(i > 1) } + ScheduleStmt::IfThenElse { cond, thn, els } => { + let (cond, modified) = interp_expr(pm, cond, stringtab, env, functions)?; + let Value::Boolean { val: cond } = cond else { + return Err(SchedulerError::SemanticError( + "Condition must be a boolean value".to_string(), + )); + }; + let changed = schedule_interpret( + pm, + if cond { &*thn } else { &*els }, + stringtab, + env, + functions, + )?; + + Ok(modified || changed) + } ScheduleStmt::Block { body } => { let mut modified = false; env.open_scope(); @@ -1443,6 +1461,23 @@ fn interp_expr( changed, )) } + ScheduleExp::Feature { feature } => { + let (feature, modified) = interp_expr(pm, &*feature, stringtab, env, functions)?; + let Value::String { val } = feature else { + return Err(SchedulerError::SemanticError( + "Feature expects a single string argument (instead of a selection)".to_string(), + )); + }; + // To test for features, the scheduler needs to be invoked from a build script so that + // Cargo provides the enabled features via environment variables + let key = "CARGO_FEATURE_".to_string() + &val.to_uppercase().replace("-", "_"); + Ok(( + Value::Boolean { + val: env::var(key).is_ok(), + }, + modified, + )) + } ScheduleExp::Record { fields } => { let mut result = HashMap::new(); let mut changed = false;