diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index c134a972f6d7440c25f763a843379e1724fee6ca..9f2b7ef40adc2933de4d2715c654fecceb087903 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -108,8 +108,8 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { } } - // Constructs an editor but only makes the nodes with at least one of the set of labels as - // mutable + // Constructs an editor but only makes the nodes with at least one of the + // set of labels as mutable. pub fn new_labeled( function: &'a mut Function, function_id: FunctionID, @@ -151,6 +151,41 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { } } + // Constructs an editor but makes every node immutable. + pub fn new_immutable( + function: &'a mut Function, + function_id: FunctionID, + constants: &'a RefCell<Vec<Constant>>, + dynamic_constants: &'a RefCell<Vec<DynamicConstant>>, + types: &'a RefCell<Vec<Type>>, + labels: &'a RefCell<Vec<String>>, + def_use: &ImmutableDefUseMap, + ) -> Self { + let mut_def_use = (0..function.nodes.len()) + .map(|idx| { + def_use + .get_users(NodeID::new(idx)) + .into_iter() + .map(|x| *x) + .collect() + }) + .collect(); + + let mutable_nodes = bitvec![u8, Lsb0; 0; function.nodes.len()]; + + FunctionEditor { + function, + function_id, + constants, + dynamic_constants, + types, + labels, + mut_def_use, + mutable_nodes, + modified: false, + } + } + pub fn modified(&self) -> bool { self.modified } diff --git a/juno_samples/cava/src/gpu.sch b/juno_samples/cava/src/gpu.sch index 594cbfa35602de515a98d9e58d9a67c99454c993..f440dacde5d4dc0f9d42599e19923ccba91d82ab 100644 --- a/juno_samples/cava/src/gpu.sch +++ b/juno_samples/cava/src/gpu.sch @@ -2,9 +2,8 @@ gvn(*); phi-elim(*); dce(*); -inline(*); -let out = auto-outline(*); -gpu(out.cava); +inline(denoise); +gpu(scale, demosaic, denoise, transform, gamut, tone_map, descale); ip-sroa(*); sroa(*); diff --git a/juno_samples/cava/src/main.rs b/juno_samples/cava/src/main.rs index e8a7e4e94393c8684c2a10a1e040e3be3f2600cb..8368a74f42d2a795d9febfb42ec5cf1707958772 100644 --- a/juno_samples/cava/src/main.rs +++ b/juno_samples/cava/src/main.rs @@ -16,8 +16,6 @@ use image::ImageError; use clap::Parser; -use std::mem; - juno_build::juno!("cava"); fn run_cava( diff --git a/juno_samples/edge_detection/src/gpu.sch b/juno_samples/edge_detection/src/gpu.sch index a1bf06a4f6cbc8b7a2de0193461b3d78aefea408..1e51efb9d84da1d20496d04c6aa939ffc2bc4123 100644 --- a/juno_samples/edge_detection/src/gpu.sch +++ b/juno_samples/edge_detection/src/gpu.sch @@ -2,9 +2,7 @@ gvn(*); phi-elim(*); dce(*); -inline(*); -let out = auto-outline(*); -gpu(out.edge_detection); +gpu(gaussian_smoothing, laplacian_estimate, zero_crossings, gradient, max_gradient, reject_zero_crossings); ip-sroa(*); sroa(*); diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index f7dd102d3cef24d06411fa9b0652209e2c39e7f7..5343005e5cc8df5e4ce08100239467072bb872dd 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1270,6 +1270,7 @@ fn selection_as_set( fn build_selection<'a>( pm: &'a mut PassManager, selection: Option<Vec<CodeLocation>>, + create_editors_for_nothing_functions: bool, ) -> Vec<Option<FunctionEditor<'a>>> { // Build def uses, which are needed for the editors we'll construct pm.make_def_uses(); @@ -1284,7 +1285,21 @@ fn build_selection<'a>( .zip(def_uses.iter()) .enumerate() .map(|(idx, ((func, selected), def_use))| match selected { - FunctionSelection::Nothing() => None, + FunctionSelection::Nothing() => { + if create_editors_for_nothing_functions { + Some(FunctionEditor::new_immutable( + func, + FunctionID::new(idx), + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + def_use, + )) + } else { + None + } + } FunctionSelection::Everything() => Some(FunctionEditor::new( func, FunctionID::new(idx), @@ -1336,7 +1351,7 @@ fn run_pass( let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); for (((func, fork_join_map), reduce_einsum), nodes_in_fork_joins) in - build_selection(pm, selection) + build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) .zip(reduce_einsums.iter()) @@ -1437,7 +1452,7 @@ fn run_pass( assert!(args.is_empty()); pm.make_reverse_postorders(); let reverse_postorders = pm.reverse_postorders.take().unwrap(); - for (func, reverse_postorder) in build_selection(pm, selection) + for (func, reverse_postorder) in build_selection(pm, selection, false) .into_iter() .zip(reverse_postorders.iter()) { @@ -1452,7 +1467,7 @@ fn run_pass( } Pass::CRC => { assert!(args.is_empty()); - for func in build_selection(pm, selection) { + for func in build_selection(pm, selection, false) { let Some(mut func) = func else { continue; }; @@ -1464,7 +1479,7 @@ fn run_pass( } Pass::DCE => { assert!(args.is_empty()); - for func in build_selection(pm, selection) { + for func in build_selection(pm, selection, false) { let Some(mut func) = func else { continue; }; @@ -1487,7 +1502,7 @@ fn run_pass( let devices = pm.devices.take().unwrap(); // Modify the selection to include callers of selected functions. - let mut editors = build_selection(pm, selection) + let mut editors = build_selection(pm, selection, false) .into_iter() .filter_map(|editor| editor.map(|editor| (editor.func_id(), editor))) .collect(); @@ -1504,7 +1519,7 @@ fn run_pass( assert!(args.is_empty()); pm.make_fork_join_maps(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); - for (func, fork_join_map) in build_selection(pm, selection) + for (func, fork_join_map) in build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) { @@ -1536,10 +1551,11 @@ fn run_pass( pm.make_reduce_cycles(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); let reduce_cycles = pm.reduce_cycles.take().unwrap(); - for ((func, fork_join_map), reduce_cycles) in build_selection(pm, selection.clone()) - .into_iter() - .zip(fork_join_maps.iter()) - .zip(reduce_cycles.iter()) + for ((func, fork_join_map), reduce_cycles) in + build_selection(pm, selection.clone(), false) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(reduce_cycles.iter()) { let Some(mut func) = func else { continue; @@ -1630,7 +1646,7 @@ fn run_pass( let loops = pm.loops.take().unwrap(); let control_subgraphs = pm.control_subgraphs.take().unwrap(); for (((func, fork_join_map), loop_nest), control_subgraph) in - build_selection(pm, selection.clone()) + build_selection(pm, selection.clone(), false) .into_iter() .zip(fork_join_maps.iter()) .zip(loops.iter()) @@ -1735,7 +1751,7 @@ fn run_pass( } Pass::GVN => { assert!(args.is_empty()); - for func in build_selection(pm, selection) { + for func in build_selection(pm, selection, false) { let Some(mut func) = func else { continue; }; @@ -1754,7 +1770,7 @@ fn run_pass( let reduce_cycles = pm.reduce_cycles.take().unwrap(); let no_reset_constants = pm.no_reset_constants.take().unwrap(); for (((func, fork_join_map), reduce_cycles), no_reset_constants) in - build_selection(pm, selection) + build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) .zip(reduce_cycles.iter()) @@ -1775,17 +1791,13 @@ fn run_pass( } Pass::Inline => { assert!(args.is_empty()); - if let Some(_) = selection { - return Err(SchedulerError::PassError { - pass: "inline".to_string(), - error: "must be applied to the entire module (currently)".to_string(), - }); - } - pm.make_callgraph(); let callgraph = pm.callgraph.take().unwrap(); - let mut editors = build_editors(pm); + let mut editors: Vec<_> = build_selection(pm, selection, true) + .into_iter() + .map(|editor| editor.unwrap()) + .collect(); inline(&mut editors, &callgraph); for func in editors { @@ -1816,7 +1828,7 @@ fn run_pass( } Pass::LiftDCMath => { assert!(args.is_empty()); - for func in build_selection(pm, selection) { + for func in build_selection(pm, selection, false) { let Some(mut func) = func else { continue; }; @@ -1896,7 +1908,7 @@ fn run_pass( } Pass::PhiElim => { assert!(args.is_empty()); - for func in build_selection(pm, selection) { + for func in build_selection(pm, selection, false) { let Some(mut func) = func else { continue; }; @@ -1911,7 +1923,7 @@ fn run_pass( pm.make_typing(); let typing = pm.typing.take().unwrap(); - for (func, types) in build_selection(pm, selection) + for (func, types) in build_selection(pm, selection, false) .into_iter() .zip(typing.iter()) { @@ -1931,7 +1943,7 @@ fn run_pass( let reverse_postorders = pm.reverse_postorders.take().unwrap(); let typing = pm.typing.take().unwrap(); - for ((func, reverse_postorder), types) in build_selection(pm, selection) + for ((func, reverse_postorder), types) in build_selection(pm, selection, false) .into_iter() .zip(reverse_postorders.iter()) .zip(typing.iter()) @@ -1952,7 +1964,7 @@ fn run_pass( let fork_join_maps = pm.fork_join_maps.take().unwrap(); let reduce_cycles = pm.reduce_cycles.take().unwrap(); - for ((func, fork_join_map), reduce_cycles) in build_selection(pm, selection) + for ((func, fork_join_map), reduce_cycles) in build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) .zip(reduce_cycles.iter()) @@ -1973,7 +1985,7 @@ fn run_pass( let reverse_postorders = pm.reverse_postorders.take().unwrap(); let typing = pm.typing.take().unwrap(); - for ((func, reverse_postorder), types) in build_selection(pm, selection) + for ((func, reverse_postorder), types) in build_selection(pm, selection, false) .into_iter() .zip(reverse_postorders.iter()) .zip(typing.iter()) @@ -1995,7 +2007,7 @@ fn run_pass( let fork_join_maps = pm.fork_join_maps.take().unwrap(); let loops = pm.loops.take().unwrap(); - for ((func, fork_join_map), loop_tree) in build_selection(pm, selection) + for ((func, fork_join_map), loop_tree) in build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) .zip(loops.iter()) @@ -2033,7 +2045,7 @@ fn run_pass( assert_eq!(*guarded_flag, true); pm.make_fork_join_maps(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); - for (func, fork_join_map) in build_selection(pm, selection) + for (func, fork_join_map) in build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) { @@ -2050,7 +2062,7 @@ fn run_pass( assert!(args.is_empty()); pm.make_fork_join_maps(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); - for (func, fork_join_map) in build_selection(pm, selection) + for (func, fork_join_map) in build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) { @@ -2073,7 +2085,7 @@ fn run_pass( let loops = pm.loops.take().unwrap(); let control_subgraphs = pm.control_subgraphs.take().unwrap(); for (((func, fork_join_map), loop_nest), control_subgraph) in - build_selection(pm, selection) + build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) .zip(loops.iter()) @@ -2090,7 +2102,7 @@ fn run_pass( } Pass::WritePredication => { assert!(args.is_empty()); - for func in build_selection(pm, selection) { + for func in build_selection(pm, selection, false) { let Some(mut func) = func else { continue; };