From e936110be361a33ecd4e6f3756ca94a44dd9a341 Mon Sep 17 00:00:00 2001 From: rarbore2 <rarbore2@illinois.edu> Date: Fri, 14 Feb 2025 19:12:18 -0600 Subject: [PATCH] Some Cava optimization --- hercules_cg/src/rt.rs | 2 +- hercules_ir/src/dot.rs | 13 ++- hercules_ir/src/einsum.rs | 31 +++++++ hercules_opt/src/fork_transforms.rs | 120 ++++++++++++++++++++++--- hercules_opt/src/utils.rs | 7 ++ juno_samples/cava/build.rs | 2 + juno_samples/cava/src/cava.jn | 97 ++++++++++---------- juno_samples/cava/src/cpu.sch | 134 ++++++++++++++++++++++++++++ juno_samples/cava/src/image_proc.rs | 37 +++++--- juno_samples/cava/src/main.rs | 83 ++++++++++------- juno_scheduler/src/compile.rs | 3 + juno_scheduler/src/ir.rs | 12 ++- juno_scheduler/src/lang.l | 4 +- juno_scheduler/src/pm.rs | 24 +++++ 14 files changed, 457 insertions(+), 112 deletions(-) create mode 100644 juno_samples/cava/src/cpu.sch diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 090253d4..f758fed3 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -533,7 +533,7 @@ impl<'a> RTContext<'a> { { write!(block, "backing_{}.byte_add(", device.name())?; self.codegen_dynamic_constant(offset, block)?; - write!(block, "), ")? + write!(block, " as usize), ")? 
} for dc in dynamic_constants { self.codegen_dynamic_constant(*dc, block)?; diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs index 9c6c5f17..0e084085 100644 --- a/hercules_ir/src/dot.rs +++ b/hercules_ir/src/dot.rs @@ -24,8 +24,8 @@ pub fn xdot_module( bbs: Option<&Vec<BasicBlocks>>, ) { let mut tmp_path = temp_dir(); - let mut rng = rand::thread_rng(); - let num: u64 = rng.gen(); + let mut rng = rand::rng(); + let num: u64 = rng.random(); tmp_path.push(format!("hercules_dot_{}.dot", num)); let mut file = File::create(&tmp_path).expect("PANIC: Unable to open output file."); let mut contents = String::new(); @@ -43,11 +43,11 @@ pub fn xdot_module( .expect("PANIC: Unable to generate output file contents."); file.write_all(contents.as_bytes()) .expect("PANIC: Unable to write output file contents."); + println!("Graphviz written to: {}", tmp_path.display()); Command::new("xdot") .args([&tmp_path]) .output() .expect("PANIC: Couldn't execute xdot. Is xdot installed?"); - println!("Graphviz written to: {}", tmp_path.display()); } /* @@ -108,6 +108,13 @@ pub fn write_dot<W: Write>( for node_id in (0..function.nodes.len()).map(NodeID::new) { let node = &function.nodes[node_id.idx()]; + let skip = node.is_constant() + || node.is_dynamic_constant() + || node.is_undef() + || node.is_parameter(); + if skip { + continue; + } let dst_control = node.is_control(); for u in def_use::get_uses(&node).as_ref() { let src_control = function.nodes[u.idx()].is_control(); diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs index 5dd0fe5d..b222e1bc 100644 --- a/hercules_ir/src/einsum.rs +++ b/hercules_ir/src/einsum.rs @@ -34,6 +34,7 @@ pub enum MathExpr { Unary(UnaryOperator, MathID), Binary(BinaryOperator, MathID, MathID), Ternary(TernaryOperator, MathID, MathID, MathID), + IntrinsicFunc(Intrinsic, Box<[MathID]>), } pub type MathEnv = Vec<MathExpr>; @@ -224,6 +225,16 @@ impl<'a> EinsumContext<'a> { let third = self.compute_math_expr(third); MathExpr::Ternary(op, 
first, second, third)
             }
+            Node::IntrinsicCall {
+                intrinsic,
+                ref args,
+            } => {
+                let args = args
+                    .into_iter()
+                    .map(|id| self.compute_math_expr(*id))
+                    .collect();
+                MathExpr::IntrinsicFunc(intrinsic, args)
+            }
             Node::Read {
                 collect,
                 ref indices,
@@ -322,6 +333,14 @@ impl<'a> EinsumContext<'a> {
                 let third = self.substitute_new_dims(third);
                 self.intern_math_expr(MathExpr::Ternary(op, first, second, third))
             }
+            MathExpr::IntrinsicFunc(intrinsic, ref args) => {
+                let args = args
+                    .clone()
+                    .iter()
+                    .map(|id| self.substitute_new_dims(*id))
+                    .collect();
+                self.intern_math_expr(MathExpr::IntrinsicFunc(intrinsic, args))
+            }
             _ => id,
         }
     }
@@ -355,6 +374,9 @@ pub fn opaque_nodes_in_expr(env: &MathEnv, id: MathID) -> HashSet<NodeID> {
                 stack.push(second);
                 stack.push(third);
             }
+            MathExpr::IntrinsicFunc(_, ref args) => {
+                stack.extend(args);
+            }
         }
     }
     set
@@ -412,5 +434,14 @@ pub fn debug_print_math_expr(id: MathID, env: &MathEnv) {
             debug_print_math_expr(third, env);
             print!(")");
         }
+        MathExpr::IntrinsicFunc(intrinsic, ref args) => {
+            print!("{}(", intrinsic.lower_case_name());
+            let mut sep = "";
+            for arg in args {
+                print!("{}", std::mem::replace(&mut sep, ", "));
+                debug_print_math_expr(*arg, env);
+            }
+            print!(")");
+        }
     }
 }
diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs
index 72b5716f..c32a517e 100644
--- a/hercules_opt/src/fork_transforms.rs
+++ b/hercules_opt/src/fork_transforms.rs
@@ -912,16 +912,16 @@ pub fn chunk_fork_unguarded(
                 let outer = DynamicConstant::div(new_factors[dim_idx], tile_size);
                 new_factors.insert(dim_idx + 1, tile_size);
                 new_factors[dim_idx] = edit.add_dynamic_constant(outer);
-                
+
                 let new_fork = Node::Fork {
                     control: old_control,
                     factors: new_factors.into(),
                 };
                 let new_fork = edit.add_node(new_fork);
-                
+
                 edit = edit.replace_all_uses(fork, new_fork)?;
                 edit.sub_edit(fork, new_fork);
-                
+
                 for (tid, node) in fork_users {
                     let Node::ThreadID {
                         control: _,
@@ -945,7 +945,7 @@ pub fn chunk_fork_unguarded(
                             dimension: tid_dim + 1,
                        };
let tile_tid = edit.add_node(tile_tid); - + let tile_size = edit.add_node(Node::DynamicConstant { id: tile_size }); let mul = edit.add_node(Node::Binary { left: tid, @@ -965,23 +965,23 @@ pub fn chunk_fork_unguarded( edit = edit.delete_node(fork)?; Ok(edit) }); - }, + } TileOrder::TileOuter => { editor.edit(|mut edit| { let inner = DynamicConstant::div(new_factors[dim_idx], tile_size); new_factors.insert(dim_idx, tile_size); - let inner_dc_id = edit.add_dynamic_constant(inner); + let inner_dc_id = edit.add_dynamic_constant(inner); new_factors[dim_idx + 1] = inner_dc_id; - + let new_fork = Node::Fork { control: old_control, factors: new_factors.into(), }; let new_fork = edit.add_node(new_fork); - + edit = edit.replace_all_uses(fork, new_fork)?; edit.sub_edit(fork, new_fork); - + for (tid, node) in fork_users { let Node::ThreadID { control: _, @@ -1000,13 +1000,12 @@ pub fn chunk_fork_unguarded( edit.sub_edit(tid, new_tid); edit = edit.delete_node(tid)?; } else if tid_dim == dim_idx { - let tile_tid = Node::ThreadID { control: new_fork, dimension: tid_dim, }; let tile_tid = edit.add_node(tile_tid); - let inner_dc = edit.add_node(Node::DynamicConstant { id: inner_dc_id } ); + let inner_dc = edit.add_node(Node::DynamicConstant { id: inner_dc_id }); let mul = edit.add_node(Node::Binary { left: tid, right: inner_dc, @@ -1027,7 +1026,6 @@ pub fn chunk_fork_unguarded( }); } } - } pub fn merge_all_fork_dims(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) { @@ -1350,3 +1348,101 @@ pub fn fork_unroll( Ok(edit) }) } + +/* + * Looks for fork-joins that are next to each other, not inter-dependent, and + * have the same bounds. These fork-joins can be fused, pooling together all + * their reductions. 
+ */ +pub fn fork_fusion_all_forks( + editor: &mut FunctionEditor, + fork_join_map: &HashMap<NodeID, NodeID>, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, +) { + for (fork, join) in fork_join_map { + if editor.is_mutable(*fork) + && fork_fusion(editor, *fork, *join, fork_join_map, nodes_in_fork_joins) + { + break; + } + } +} + +/* + * Tries to fuse a given fork join with the immediately following fork-join, if + * it exists. + */ +fn fork_fusion( + editor: &mut FunctionEditor, + top_fork: NodeID, + top_join: NodeID, + fork_join_map: &HashMap<NodeID, NodeID>, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, +) -> bool { + let nodes = &editor.func().nodes; + // Rust operator precedence is not such that these can be put in one big + // let-else statement. Sad! + let Some(bottom_fork) = editor + .get_users(top_join) + .filter(|id| nodes[id.idx()].is_control()) + .next() + else { + return false; + }; + let Some(bottom_join) = fork_join_map.get(&bottom_fork) else { + return false; + }; + let (_, top_factors) = nodes[top_fork.idx()].try_fork().unwrap(); + let (bottom_fork_pred, bottom_factors) = nodes[bottom_fork.idx()].try_fork().unwrap(); + assert_eq!(bottom_fork_pred, top_join); + let top_join_pred = nodes[top_join.idx()].try_join().unwrap(); + let bottom_join_pred = nodes[bottom_join.idx()].try_join().unwrap(); + + // The fork factors must be identical. + if top_factors != bottom_factors { + return false; + } + + // Check that no iterated users of the top's reduces are in the bottom fork- + // join (iteration stops at a phi or reduce outside the bottom fork-join). 
+ for reduce in editor + .get_users(top_join) + .filter(|id| nodes[id.idx()].is_reduce()) + { + let mut visited = HashSet::new(); + visited.insert(reduce); + let mut workset = vec![reduce]; + while let Some(pop) = workset.pop() { + for u in editor.get_users(pop) { + if nodes_in_fork_joins[&bottom_fork].contains(&u) { + return false; + } else if (nodes[u.idx()].is_phi() || nodes[u.idx()].is_reduce()) + && !nodes_in_fork_joins[&top_fork].contains(&u) + { + } else if !visited.contains(&u) && !nodes_in_fork_joins[&top_fork].contains(&u) { + visited.insert(u); + workset.push(u); + } + } + } + } + + // Perform the fusion. + editor.edit(|mut edit| { + if bottom_join_pred != bottom_fork { + // If there is control flow in the bottom fork-join, stitch it into + // the top fork-join. + edit = edit.replace_all_uses_where(bottom_fork, top_join_pred, |id| { + nodes_in_fork_joins[&bottom_fork].contains(id) + })?; + edit = + edit.replace_all_uses_where(top_join_pred, bottom_join_pred, |id| *id == top_join)?; + } + // Replace the bottom fork and join with the top fork and join. 
+ edit = edit.replace_all_uses(bottom_fork, top_fork)?; + edit = edit.replace_all_uses(*bottom_join, top_join)?; + edit = edit.delete_node(bottom_fork)?; + edit = edit.delete_node(*bottom_join)?; + Ok(edit) + }) +} diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs index c5e0d934..3f12ad7c 100644 --- a/hercules_opt/src/utils.rs +++ b/hercules_opt/src/utils.rs @@ -474,6 +474,13 @@ pub fn materialize_simple_einsum_expr( third, }) } + MathExpr::IntrinsicFunc(intrinsic, ref args) => { + let args = args + .into_iter() + .map(|id| materialize_simple_einsum_expr(edit, *id, env, dim_substs)) + .collect(); + edit.add_node(Node::IntrinsicCall { intrinsic, args }) + } _ => panic!(), } } diff --git a/juno_samples/cava/build.rs b/juno_samples/cava/build.rs index 1b6dddf4..ff7e2b6b 100644 --- a/juno_samples/cava/build.rs +++ b/juno_samples/cava/build.rs @@ -13,6 +13,8 @@ fn main() { JunoCompiler::new() .file_in_src("cava.jn") .unwrap() + .schedule_in_src("cpu.sch") + .unwrap() .build() .unwrap(); } diff --git a/juno_samples/cava/src/cava.jn b/juno_samples/cava/src/cava.jn index 95a73f5b..8158bf0a 100644 --- a/juno_samples/cava/src/cava.jn +++ b/juno_samples/cava/src/cava.jn @@ -1,47 +1,35 @@ fn medianMatrix<a : number, rows, cols : usize>(m : a[rows, cols]) -> a { - const n : usize = rows * cols; - - let tmp : a[rows * cols]; - for i = 0 to rows * cols { - tmp[i] = m[i / cols, i % cols]; - } - - for i = 0 to n - 1 { - for j = 0 to n - i - 1 { - if tmp[j] > tmp[j+1] { - let t : a = tmp[j]; - tmp[j] = tmp[j+1]; - tmp[j+1] = t; + @median { + const n : usize = rows * cols; + + @tmp let tmp : a[rows * cols]; + for i = 0 to rows * cols { + tmp[i] = m[i / cols, i % cols]; + } + + @medianOuter for i = 0 to n - 1 { + for j = 0 to n - i - 1 { + if tmp[j] > tmp[j+1] { + let t : a = tmp[j]; + tmp[j] = tmp[j+1]; + tmp[j+1] = t; + } } } + + return tmp[n / 2]; } - - return tmp[n / 2]; } const CHAN : u64 = 3; fn scale<row : usize, col : usize>(input : u8[CHAN, row, col]) -> 
f32[CHAN, row, col] { - let res : f32[CHAN, row, col]; + @res1 let res : f32[CHAN, row, col]; for chan = 0 to CHAN { for r = 0 to row { for c = 0 to col { - res[chan, r, c] = input[chan, r, c] as f32 * 1.0 / 255; - } - } - } - - return res; -} - -fn descale<row : usize, col : usize>(input : f32[CHAN, row, col]) -> u8[CHAN, row, col] { - let res : u8[CHAN, row, col]; - - for chan = 0 to CHAN { - for r = 0 to row { - for c = 0 to col { - res[chan, r, c] = min!::<f32>(max!::<f32>(input[chan, r, c] * 255, 0), 255) as u8; + res[chan, r, c] = input[chan, r, c] as f32 / 255.0; } } } @@ -50,7 +38,7 @@ fn descale<row : usize, col : usize>(input : f32[CHAN, row, col]) -> u8[CHAN, ro } fn demosaic<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, row, col] { - let res : f32[CHAN, row, col]; + @res2 let res : f32[CHAN, row, col]; for r = 1 to row-1 { for c = 1 to col-1 { @@ -102,13 +90,13 @@ fn demosaic<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, } fn denoise<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, row, col] { - let res : f32[CHAN, row, col]; + @res let res : f32[CHAN, row, col]; for chan = 0 to CHAN { for r = 0 to row { for c = 0 to col { if r >= 1 && r < row - 1 && c >= 1 && c < col - 1 { - let filter : f32[3][3]; // same as [3, 3] + @filter let filter : f32[3][3]; // same as [3, 3] for i = 0 to 3 by 1 { for j = 0 to 3 by 1 { filter[i, j] = input[chan, r + i - 1, c + j - 1]; @@ -129,7 +117,7 @@ fn denoise<row : usize, col : usize>(input : f32[CHAN, row, col]) -> f32[CHAN, r fn transform<row : usize, col : usize> (input : f32[CHAN, row, col], tstw_trans : f32[CHAN, CHAN]) -> f32[CHAN, row, col] { - let result : f32[CHAN, row, col]; + @res let result : f32[CHAN, row, col]; for chan = 0 to CHAN { for r = 0 to row { @@ -152,11 +140,11 @@ fn gamut<row : usize, col : usize, num_ctrl_pts : usize>( weights : f32[num_ctrl_pts, CHAN], coefs : f32[4, CHAN] ) -> f32[CHAN, row, col] { - let result : f32[CHAN, row, 
col]; - let l2_dist : f32[num_ctrl_pts]; + @res let result : f32[CHAN, row, col]; for r = 0 to row { for c = 0 to col { + @l2 let l2_dist : f32[num_ctrl_pts]; for cp = 0 to num_ctrl_pts { let v1 = input[0, r, c] - ctrl_pts[cp, 0]; let v2 = input[1, r, c] - ctrl_pts[cp, 1]; @@ -164,8 +152,8 @@ fn gamut<row : usize, col : usize, num_ctrl_pts : usize>( let v = v1 * v1 + v2 * v2 + v3 * v3; l2_dist[cp] = sqrt!::<f32>(v); } - - for chan = 0 to CHAN { + + @channel_loop for chan = 0 to CHAN { let chan_val : f32 = 0.0; for cp = 0 to num_ctrl_pts { chan_val += l2_dist[cp] * weights[cp, chan]; @@ -184,7 +172,7 @@ fn gamut<row : usize, col : usize, num_ctrl_pts : usize>( fn tone_map<row : usize, col:usize> (input : f32[CHAN, row, col], tone_map : f32[256, CHAN]) -> f32[CHAN, row, col] { - let result : f32[CHAN, row, col]; + @res1 let result : f32[CHAN, row, col]; for chan = 0 to CHAN { for r = 0 to row { @@ -198,6 +186,20 @@ fn tone_map<row : usize, col:usize> return result; } +fn descale<row : usize, col : usize>(input : f32[CHAN, row, col]) -> u8[CHAN, row, col] { + @res2 let res : u8[CHAN, row, col]; + + for chan = 0 to CHAN { + for r = 0 to row { + for c = 0 to col { + res[chan, r, c] = min!::<f32>(max!::<f32>(input[chan, r, c] * 255, 0), 255) as u8; + } + } + } + + return res; +} + #[entry] fn cava<r, c, num_ctrl_pts : usize>( input : u8[CHAN, r, c], @@ -207,11 +209,12 @@ fn cava<r, c, num_ctrl_pts : usize>( coefs : f32[4, CHAN], tonemap : f32[256, CHAN], ) -> u8[CHAN, r, c] { - let scaled = scale::<r, c>(input); - let demosc = demosaic::<r, c>(scaled); - let denosd = denoise::<r, c>(demosc); - let transf = transform::<r, c>(denosd, TsTw); - let gamutd = gamut::<r, c, num_ctrl_pts>(transf, ctrl_pts, weights, coefs); - let tonemd = tone_map::<r, c>(gamutd, tonemap); - return descale::<r, c>(tonemd); + @fuse1 let scaled = scale::<r, c>(input); + @fuse1 let demosc = demosaic::<r, c>(scaled); + @fuse2 let denosd = denoise::<r, c>(demosc); + @fuse3 let transf = transform::<r, 
c>(denosd, TsTw); + @fuse4 let gamutd = gamut::<r, c, num_ctrl_pts>(transf, ctrl_pts, weights, coefs); + @fuse5 let tonemd = tone_map::<r, c>(gamutd, tonemap); + @fuse5 let dscald = descale::<r, c>(tonemd); + return dscald; } diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch new file mode 100644 index 00000000..623dcfbc --- /dev/null +++ b/juno_samples/cava/src/cpu.sch @@ -0,0 +1,134 @@ +macro simpl!(X) { + ccp(X); + simplify-cfg(X); + lift-dc-math(X); + gvn(X); + phi-elim(X); + dce(X); + infer-schedules(X); +} + +simpl!(*); + +let fuse1 = outline(cava@fuse1); +inline(fuse1); + +let fuse2 = outline(cava@fuse2); +inline(fuse2); + +let fuse3 = outline(cava@fuse3); +inline(fuse3); + +let fuse4 = outline(cava@fuse4); +inline(fuse4); + +let fuse5 = outline(cava@fuse5); +inline(fuse5); + +ip-sroa(*); +sroa(*); +simpl!(*); + +no-memset(fuse1@res1); +no-memset(fuse1@res2); +fixpoint { + forkify(fuse1); + fork-guard-elim(fuse1); + fork-coalesce(fuse1); +} +simpl!(fuse1); +array-slf(fuse1); + +inline(fuse2); +no-memset(fuse2@res); +no-memset(fuse2@filter); +no-memset(fuse2@tmp); +fixpoint { + forkify(fuse2); + fork-guard-elim(fuse2); + fork-coalesce(fuse2); +} +simpl!(fuse2); +predication(fuse2); +simpl!(fuse2); + +let median = outline(fuse2@median); +fork-unroll(median@medianOuter); +simpl!(median); +fixpoint { + forkify(median); + fork-guard-elim(median); +} +simpl!(median); +fixpoint { + fork-unroll(median); +} +ccp(median); +array-to-product(median); +sroa(median); +phi-elim(median); +predication(median); +simpl!(median); + +inline(fuse2); +ip-sroa(*); +sroa(*); +array-slf(fuse2); +write-predication(fuse2); +simpl!(fuse2); + +no-memset(fuse3@res); +fixpoint { + forkify(fuse3); + fork-guard-elim(fuse3); + fork-coalesce(fuse3); +} +simpl!(fuse3); + +no-memset(fuse4@res); +no-memset(fuse4@l2); +fixpoint { + forkify(fuse4); + fork-guard-elim(fuse4); + fork-coalesce(fuse4); +} +simpl!(fuse4); +fork-unroll(fuse4@channel_loop); +simpl!(fuse4); +fixpoint 
{ + fork-fusion(fuse4@channel_loop); +} +simpl!(fuse4); +array-slf(fuse4); +simpl!(fuse4); + +no-memset(fuse5@res1); +no-memset(fuse5@res2); +fixpoint { + forkify(fuse5); + fork-guard-elim(fuse5); + fork-coalesce(fuse5); +} +simpl!(fuse5); +array-slf(fuse5); +simpl!(fuse5); + +delete-uncalled(*); +simpl!(*); +xdot[true](*); + +simpl!(fuse1); +unforkify(fuse1); +fork-split(fuse2); +unforkify(fuse2); +fork-split(fuse3); +unforkify(fuse3); +fork-split(fuse4); +unforkify(fuse4); +fork-split(fuse5); +unforkify(fuse5); + +simpl!(*); + +delete-uncalled(*); +gcm(*); diff --git a/juno_samples/cava/src/image_proc.rs b/juno_samples/cava/src/image_proc.rs index 94a7c14f..7115bac8 100644 --- a/juno_samples/cava/src/image_proc.rs +++ b/juno_samples/cava/src/image_proc.rs @@ -17,41 +17,54 @@ fn read_bin_u32(f: &mut File) -> Result<u32, Error> { Ok(u32::from_le_bytes(buffer)) } -pub fn read_raw<P: AsRef<Path>>(image: P) -> Result<RawImage, Error> { +pub fn read_raw<P: AsRef<Path>>( + image: P, + crop_rows: Option<usize>, + crop_cols: Option<usize>, +) -> Result<RawImage, Error> { let mut file = File::open(image)?; - let rows = read_bin_u32(&mut file)? as usize; - let cols = read_bin_u32(&mut file)? as usize; + let in_rows = read_bin_u32(&mut file)? as usize; + let in_cols = read_bin_u32(&mut file)? as usize; + let out_rows = crop_rows.unwrap_or(in_rows); + let out_cols = crop_cols.unwrap_or(in_cols); let chans = read_bin_u32(&mut file)? 
as usize; assert!( chans == CHAN, "Channel size read from the binary file doesn't match the default value" ); + assert!(in_rows >= out_rows); + assert!(in_cols >= out_cols); - let size: usize = rows * cols * CHAN; - let mut buffer: Vec<u8> = vec![0; size]; + let in_size: usize = in_rows * in_cols * CHAN; + let mut buffer: Vec<u8> = vec![0; in_size]; // The file has pixels is a 3D array with dimensions height, width, channel file.read_exact(&mut buffer)?; let chunked_channels = buffer.chunks(CHAN).collect::<Vec<_>>(); - let chunked = chunked_channels.chunks(cols).collect::<Vec<_>>(); + let chunked = chunked_channels.chunks(in_cols).collect::<Vec<_>>(); let input: &[&[&[u8]]] = chunked.as_slice(); // We need the image in a 3D array with dimensions channel, height, width - let mut pixels: Vec<u8> = vec![0; size]; - let mut pixels_columns = pixels.chunks_mut(cols).collect::<Vec<_>>(); - let mut pixels_chunked = pixels_columns.chunks_mut(rows).collect::<Vec<_>>(); + let out_size: usize = out_rows * out_cols * CHAN; + let mut pixels: Vec<u8> = vec![0; out_size]; + let mut pixels_columns = pixels.chunks_mut(out_cols).collect::<Vec<_>>(); + let mut pixels_chunked = pixels_columns.chunks_mut(out_rows).collect::<Vec<_>>(); let result: &mut [&mut [&mut [u8]]] = pixels_chunked.as_mut_slice(); - for row in 0..rows { - for col in 0..cols { + for row in 0..out_rows { + for col in 0..out_cols { for chan in 0..CHAN { result[chan][row][col] = input[row][col][chan]; } } } - Ok(RawImage { rows, cols, pixels }) + Ok(RawImage { + rows: out_rows, + cols: out_cols, + pixels, + }) } pub fn extern_image(rows: usize, cols: usize, image: &[u8]) -> RgbImage { diff --git a/juno_samples/cava/src/main.rs b/juno_samples/cava/src/main.rs index b4a0f6fd..142ed703 100644 --- a/juno_samples/cava/src/main.rs +++ b/juno_samples/cava/src/main.rs @@ -19,24 +19,31 @@ juno_build::juno!("cava"); // Individual lifetimes are not needed in this example but should probably be generated for // flexibility async 
fn safe_run<'a, 'b: 'a, 'c: 'a, 'd: 'a, 'e: 'a, 'f: 'a, 'g: 'a>( - runner: &'a mut HerculesRunner_cava, r: u64, c: u64, num_ctrl_pts: u64, - input: &'b HerculesImmBox<'b, u8>, tstw: &'c HerculesImmBox<'c, f32>, - ctrl_pts: &'d HerculesImmBox<'d, f32>, weights: &'e HerculesImmBox<'e, f32>, - coefs: &'f HerculesImmBox<'f, f32>, tonemap: &'g HerculesImmBox<'g, f32>, + runner: &'a mut HerculesRunner_cava, + r: u64, + c: u64, + num_ctrl_pts: u64, + input: &'b HerculesImmBox<'b, u8>, + tstw: &'c HerculesImmBox<'c, f32>, + ctrl_pts: &'d HerculesImmBox<'d, f32>, + weights: &'e HerculesImmBox<'e, f32>, + coefs: &'f HerculesImmBox<'f, f32>, + tonemap: &'g HerculesImmBox<'g, f32>, ) -> HerculesMutBox<'a, u8> { HerculesMutBox::from( - runner.run( - r, - c, - num_ctrl_pts, - input.to(), - tstw.to(), - ctrl_pts.to(), - weights.to(), - coefs.to(), - tonemap.to() - ) - .await + runner + .run( + r, + c, + num_ctrl_pts, + input.to(), + tstw.to(), + ctrl_pts.to(), + weights.to(), + coefs.to(), + tonemap.to(), + ) + .await, ) } @@ -68,16 +75,17 @@ fn run_cava( let mut r = runner!(cava); async_std::task::block_on(async { - safe_run(&mut r, - rows as u64, - cols as u64, - num_ctrl_pts as u64, - &image, - &tstw, - &ctrl_pts, - &weights, - &coefs, - &tonemap, + safe_run( + &mut r, + rows as u64, + cols as u64, + num_ctrl_pts as u64, + &image, + &tstw, + &ctrl_pts, + &weights, + &coefs, + &tonemap, ) .await }) @@ -139,6 +147,10 @@ struct CavaInputs { #[clap(long = "output-verify", value_name = "PATH")] output_verify: Option<String>, cam_model: String, + #[clap(short, long)] + crop_rows: Option<usize>, + #[clap(short, long)] + crop_cols: Option<usize>, } fn cava_harness(args: CavaInputs) { @@ -148,8 +160,11 @@ fn cava_harness(args: CavaInputs) { verify, output_verify, cam_model, + crop_rows, + crop_cols, } = args; - let RawImage { rows, cols, pixels } = read_raw(input).expect("Error loading image"); + let RawImage { rows, cols, pixels } = + read_raw(input, crop_rows, 
crop_cols).expect("Error loading image"); let CamModel { tstw, num_ctrl_pts, @@ -159,6 +174,10 @@ fn cava_harness(args: CavaInputs) { tonemap, } = load_cam_model(cam_model, CHAN).expect("Error loading camera model"); + println!( + "Running cava with {} rows, {} columns, and {} control points.", + rows, cols, num_ctrl_pts + ); let result = run_cava( rows, cols, @@ -224,19 +243,21 @@ fn cava_test_small() { verify: true, output_verify: None, cam_model: "cam_models/NikonD7000".to_string(), + crop_rows: Some(144), + crop_cols: Some(192), }); } -// Disabling the larger test because of how long it takes -/* #[test] -fn cava_test() { +#[ignore] +fn cava_test_full() { cava_harness(CavaInputs { input: "examples/raw_tulips.bin".to_string(), output: None, verify: true, output_verify: None, cam_model: "cam_models/NikonD7000".to_string(), + crop_rows: None, + crop_cols: None, }); } -*/ diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 0652f8f2..7c92e00d 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -129,6 +129,7 @@ impl FromStr for Appliable { "fork-interchange" => Ok(Appliable::Pass(ir::Pass::ForkInterchange)), "fork-chunk" | "fork-tile" => Ok(Appliable::Pass(ir::Pass::ForkChunk)), "fork-unroll" | "unroll" => Ok(Appliable::Pass(ir::Pass::ForkUnroll)), + "fork-fusion" | "fusion" => Ok(Appliable::Pass(ir::Pass::ForkFusion)), "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), "outline" => Ok(Appliable::Pass(ir::Pass::Outline)), "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)), @@ -146,6 +147,8 @@ impl FromStr for Appliable { "serialize" => Ok(Appliable::Pass(ir::Pass::Serialize)), "write-predication" => Ok(Appliable::Pass(ir::Pass::WritePredication)), + "print" => Ok(Appliable::Pass(ir::Pass::Print)), + "cpu" | "llvm" => Ok(Appliable::Device(Device::LLVM)), "gpu" | "cuda" | "nvidia" => Ok(Appliable::Device(Device::CUDA)), "host" | "rust" | "rust-async" => Ok(Appliable::Device(Device::AsyncRust)), diff 
--git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index e92f1d37..205cd70b 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -13,6 +13,7 @@ pub enum Pass { ForkCoalesce, ForkDimMerge, ForkFissionBufferize, + ForkFusion, ForkGuardElim, ForkInterchange, ForkSplit, @@ -27,6 +28,7 @@ pub enum Pass { Outline, PhiElim, Predication, + Print, ReduceSLF, Rename, ReuseProducts, @@ -44,11 +46,12 @@ impl Pass { pub fn is_valid_num_args(&self, num: usize) -> bool { match self { Pass::ArrayToProduct => num == 0 || num == 1, - Pass::Rename => num == 1, - Pass::Xdot => num == 0 || num == 1, Pass::ForkChunk => num == 4, Pass::ForkFissionBufferize => num == 2, Pass::ForkInterchange => num == 2, + Pass::Print => num == 1, + Pass::Rename => num == 1, + Pass::Xdot => num == 0 || num == 1, _ => num == 0, } } @@ -56,11 +59,12 @@ impl Pass { pub fn valid_arg_nums(&self) -> &'static str { match self { Pass::ArrayToProduct => "0 or 1", - Pass::Rename => "1", - Pass::Xdot => "0 or 1", Pass::ForkChunk => "4", Pass::ForkFissionBufferize => "2", Pass::ForkInterchange => "2", + Pass::Print => "1", + Pass::Rename => "1", + Pass::Xdot => "0 or 1", _ => "0", } } diff --git a/juno_scheduler/src/lang.l b/juno_scheduler/src/lang.l index afe596b2..ca75276e 100644 --- a/juno_scheduler/src/lang.l +++ b/juno_scheduler/src/lang.l @@ -43,8 +43,8 @@ panic[\t \n\r]+after "panic_after" print[\t \n\r]+iter "print_iter" stop[\t \n\r]+after "stop_after" -[a-zA-Z][a-zA-Z0-9_\-]*! "MACRO" -[a-zA-Z][a-zA-Z0-9_\-]* "ID" +[a-zA-Z_][a-zA-Z0-9_\-]*! 
"MACRO" +[a-zA-Z_][a-zA-Z0-9_\-]* "ID" [0-9]+ "INT" \"[a-zA-Z0-9_\-\s\.]*\" "STRING" diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 8c2ecb19..a4783a93 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2526,6 +2526,27 @@ fn run_pass( fields: new_fork_joins, }; } + Pass::ForkFusion => { + assert!(args.is_empty()); + pm.make_fork_join_maps(); + pm.make_nodes_in_fork_joins(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); + for ((func, fork_join_map), nodes_in_fork_joins) in + build_selection(pm, selection, false) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(nodes_in_fork_joins.iter()) + { + let Some(mut func) = func else { + continue; + }; + fork_fusion_all_forks(&mut func, fork_join_map, nodes_in_fork_joins); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } Pass::ForkDimMerge => { assert!(args.is_empty()); pm.make_fork_join_maps(); @@ -2657,6 +2678,9 @@ fn run_pass( // Put BasicBlocks back, since it's needed for Codegen. pm.bbs = bbs; } + Pass::Print => { + println!("{:?}", args.get(0)); + } } println!("Ran Pass: {:?}", pass); -- GitLab