diff --git a/hercules_opt/src/inline.rs b/hercules_opt/src/inline.rs index 38ed1b22d2d81be971aa867857fd01e3752ecc0a..9b0a9200b6301a8928f5fda2f70b515fc1d45dde 100644 --- a/hercules_opt/src/inline.rs +++ b/hercules_opt/src/inline.rs @@ -307,7 +307,11 @@ impl ParameterLattice { * These functions can have that constant "inlined" - the parameter is removed * and all uses of the parameter becomes uses of the constant directly. */ -pub fn const_inline(editors: &mut [FunctionEditor], callgraph: &CallGraph) { +pub fn const_inline( + editors: &mut [FunctionEditor], + callgraph: &CallGraph, + inline_collections: bool, +) { // Run const inlining on each function, starting at the most shallow // function first, since we want to propagate constants down the call graph. for func_id in callgraph.topo().into_iter().rev() { @@ -361,22 +365,29 @@ pub fn const_inline(editors: &mut [FunctionEditor], callgraph: &CallGraph) { let mut param_tys = edit.get_param_types().clone(); let mut decrement_index_by = 0; for idx in 0..param_tys.len() { - if let Some(node) = match param_lattice[idx] { - ParameterLattice::Top => Some(Node::Undef { ty: param_tys[idx] }), - ParameterLattice::Constant(id) => Some(Node::Constant { id }), - ParameterLattice::DynamicConstant(id, _) => { - // Rust moment. - let maybe_cons = edit.get_dynamic_constant(id).try_constant(); - if let Some(val) = maybe_cons { - Some(Node::DynamicConstant { - id: edit.add_dynamic_constant(DynamicConstant::Constant(val)), - }) - } else { - None + if (inline_collections + || edit + .get_type(param_tys[idx - decrement_index_by]) + .is_primitive()) + && let Some(node) = match param_lattice[idx] { + ParameterLattice::Top => Some(Node::Undef { + ty: param_tys[idx - decrement_index_by], + }), + ParameterLattice::Constant(id) => Some(Node::Constant { id }), + ParameterLattice::DynamicConstant(id, _) => { + // Rust moment. + let maybe_cons = edit.get_dynamic_constant(id).try_constant(); + if let Some(val) = maybe_cons { + Some(Node::DynamicConstant { + id: edit.add_dynamic_constant(DynamicConstant::Constant(val)), + }) + } else { + None + } } + _ => None, } - _ => None, - } && let Some(ids) = param_idx_to_ids.get(&idx) + && let Some(ids) = param_idx_to_ids.get(&idx) { let node = edit.add_node(node); for id in ids { diff --git a/juno_samples/rodinia/srad/benches/srad_bench.rs b/juno_samples/rodinia/srad/benches/srad_bench.rs index d327454002a6f9cabe4c40f74098570ea0d22d66..728702d9bcc18405ef291945f81413f49f5715af 100644 --- a/juno_samples/rodinia/srad/benches/srad_bench.rs +++ b/juno_samples/rodinia/srad/benches/srad_bench.rs @@ -13,8 +13,8 @@ fn srad_bench(c: &mut Criterion) { let mut r = runner!(srad); let niter = 100; let lambda = 0.5; - let nrows = 502; - let ncols = 458; + let nrows = 512; + let ncols = 512; let image = "data/image.pgm".to_string(); let Image { image: image_ori, diff --git a/juno_samples/rodinia/srad/src/gpu.sch b/juno_samples/rodinia/srad/src/gpu.sch index 149d5cd2fd71005ade5cdbb3461e08b3e65ab34f..f7885f9b2e9ed693054be3166a4ca6c285aa8700 100644 --- a/juno_samples/rodinia/srad/src/gpu.sch +++ b/juno_samples/rodinia/srad/src/gpu.sch @@ -1,23 +1,38 @@ -gvn(*); -dce(*); +macro simpl!(X) { + ccp(X); + simplify-cfg(X); + lift-dc-math(X); + gvn(X); + phi-elim(X); + dce(X); + infer-schedules(X); +} + phi-elim(*); -dce(*); +let init_loop = outline(srad@loop1); +let main_loops = outline(srad@loop2 | srad@loop3); +gpu(init_loop, main_loops, extract, compress); +simpl!(*); +const-inline[true](*); crc(*); -dce(*); slf(*); -dce(*); - -let auto = auto-outline(srad); -gpu(auto.srad); - -inline(auto.srad); -inline(auto.srad); -delete-uncalled(*); - -sroa[false](auto.srad); -dce(*); -float-collections(*); -dce(*); +write-predication(*); +simpl!(*); +predication(*); +simpl!(*); +predication(*); +simpl!(*); +fixpoint { + forkify(*); + fork-guard-elim(*); + fork-coalesce(*); +} +simpl!(*); +reduce-slf(*); +simpl!(*); +array-slf(*); +simpl!(*); +slf(*); +simpl!(*); gcm(*); - diff --git a/juno_samples/rodinia/srad/src/srad.jn b/juno_samples/rodinia/srad/src/srad.jn index 3e016a99b574c1dcde982e7277a5cbcdc1743c19..6074bf8cb12ccc2ad29c1086d7620b3ef98bcf59 100644 --- a/juno_samples/rodinia/srad/src/srad.jn +++ b/juno_samples/rodinia/srad/src/srad.jn @@ -50,10 +50,10 @@ fn srad<nrows, ncols: usize>( let varROI = (sum2 / nelems as f32) - meanROI * meanROI; let q0sqr = varROI / (meanROI * meanROI); - let dN : f32[ncols, nrows]; - let dS : f32[ncols, nrows]; - let dE : f32[ncols, nrows]; - let dW : f32[ncols, nrows]; + @dirs let dN : f32[ncols, nrows]; + @dirs let dS : f32[ncols, nrows]; + @dirs let dE : f32[ncols, nrows]; + @dirs let dW : f32[ncols, nrows]; let c : f32[ncols, nrows]; diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index a0db884492120a43d0bb8fff89e689746ef1579e..6aa85fe53689cf015497e56850ef0c197ccbdae0 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -54,14 +54,15 @@ impl Pass { pub fn is_valid_num_args(&self, num: usize) -> bool { match self { Pass::ArrayToProduct => num == 0 || num == 1, + Pass::ConstInline => num == 0 || num == 1, Pass::ForkChunk => num == 4, Pass::ForkExtend => num == 1, Pass::ForkFissionBufferize => num == 2 || num == 1, Pass::ForkInterchange => num == 2, + Pass::InterproceduralSROA => num == 0 || num == 1, Pass::Print => num == 1, Pass::Rename => num == 1, Pass::SROA => num == 0 || num == 1, - Pass::InterproceduralSROA => num == 0 || num == 1, Pass::Xdot => num == 0 || num == 1, _ => num == 0, } @@ -70,14 +71,15 @@ impl Pass { pub fn valid_arg_nums(&self) -> &'static str { match self { Pass::ArrayToProduct => "0 or 1", + Pass::ConstInline => "0 or 1", Pass::ForkChunk => "4", Pass::ForkExtend => "1", Pass::ForkFissionBufferize => "1 or 2", Pass::ForkInterchange => "2", + Pass::InterproceduralSROA => "0 or 1", Pass::Print => "1", Pass::Rename => "1", Pass::SROA => "0 or 1", - Pass::InterproceduralSROA => "0 or 1", Pass::Xdot => "0 or 1", _ => "0", } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index e049f985e0db36ae78368b8d33c01d22744fdcc6..70d8e4278169ebdbe9985e00ede161acbe05c24d 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1837,7 +1837,17 @@ fn run_pass( pm.clear_analyses(); } Pass::ConstInline => { - assert!(args.is_empty()); + let inline_collections = match args.get(0) { + Some(Value::Boolean { val }) => *val, + Some(_) => { + return Err(SchedulerError::PassError { + pass: "constInline".to_string(), + error: "expected boolean argument".to_string(), + }); + } + None => true, + }; + pm.make_callgraph(); let callgraph = pm.callgraph.take().unwrap(); @@ -1845,7 +1855,7 @@ fn run_pass( .into_iter() .map(|editor| editor.unwrap()) .collect(); - const_inline(&mut editors, &callgraph); + const_inline(&mut editors, &callgraph, inline_collections); for func in editors { changed |= func.modified();