diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs
index 3f53ea406a4b371ecb5eccbd90d75a8ec867b4b6..28215366d8960ded5426f05e01eec2e4f936b19f 100644
--- a/hercules_opt/src/ccp.rs
+++ b/hercules_opt/src/ccp.rs
@@ -5,6 +5,7 @@ use std::iter::zip;
 
 use self::hercules_ir::dataflow::*;
 use self::hercules_ir::ir::*;
+use self::hercules_ir::def_use::get_uses;
 
 use crate::*;
 
@@ -180,6 +181,39 @@ pub fn ccp(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>) {
         ccp_flow_function(inputs, node_id, editor)
     });
 
+    // Check the results of CCP. In particular we assert that for every reachable node (except for
+    // region and phi nodes) all of its uses are also reachable. For phi nodes we check this
+    // property only on live branches, and for a reachable region we check that at least one of
+    // its predecessors is reachable.
+    for node_idx in 0..result.len() {
+        if !result[node_idx].is_reachable() {
+            continue;
+        }
+        match &editor.func().nodes[node_idx] {
+            Node::Region { preds } => {
+                assert!(preds.iter().any(|n| result[n.idx()].is_reachable()))
+            }
+            Node::Phi { control, data } => {
+                assert!(result[control.idx()].is_reachable());
+                let region_preds =
+                    if let Node::Region { preds } = &editor.func().nodes[control.idx()] {
+                        preds
+                    } else {
+                        panic!("A phi's control input must be a region node.")
+                    };
+                assert!(zip(region_preds.iter(), data.iter()).all(|(region, data)| {
+                    !result[region.idx()].is_reachable() || result[data.idx()].is_reachable()
+                }));
+            }
+            _ => {
+                assert!(get_uses(&editor.func().nodes[node_idx])
+                    .as_ref()
+                    .iter()
+                    .all(|n| result[n.idx()].is_reachable()));
+            }
+        }
+    }
+
     // Step 2: propagate constants. For each node that was found to have a constant value, we
     // create a node for that constant value, replace uses of the original node with the constant,
     // and finally delete the original node
@@ -380,14 +414,8 @@ fn ccp_flow_function(
         }),
         // If node has only one output, if doesn't directly handle crossover of
         // reachability and constant propagation. Read handles that.
-        Node::If { control, cond } => {
-            assert!(!inputs[control.idx()].is_reachable() || inputs[cond.idx()].is_reachable());
-            inputs[control.idx()].clone()
-        }
-        Node::Match { control, sum } => {
-            assert!(!inputs[control.idx()].is_reachable() || inputs[sum.idx()].is_reachable());
-            inputs[control.idx()].clone()
-        }
+        Node::If { control, cond: _ } => inputs[control.idx()].clone(),
+        Node::Match { control, sum: _ } => inputs[control.idx()].clone(),
         Node::Fork {
             control,
             factors: _,
@@ -426,7 +454,6 @@ fn ccp_flow_function(
             control,
             dimension: _,
         } => inputs[control.idx()].clone(),
-        // TODO: At least for now, reduce nodes always produce unknown values.
         Node::Reduce {
             control,
             init,
@@ -434,7 +461,6 @@ fn ccp_flow_function(
         } => {
             let reachability = inputs[control.idx()].reachability.clone();
             if reachability == ReachabilityLattice::Reachable {
-                assert!(inputs[init.idx()].is_reachable());
                 let mut constant = inputs[init.idx()].constant.clone();
                 if inputs[reduct.idx()].is_reachable() {
                     constant = ConstantLattice::meet(&constant, &inputs[reduct.idx()].constant);
diff --git a/juno_samples/nested_ccp/src/main.rs b/juno_samples/nested_ccp/src/main.rs
index 80f92c0b9f9e600a157fa22784843438a7445aae..9b38476ec4ce2b606f1505463ed737eaa418639f 100644
--- a/juno_samples/nested_ccp/src/main.rs
+++ b/juno_samples/nested_ccp/src/main.rs
@@ -10,17 +10,27 @@ juno_build::juno!("nested_ccp");
 fn main() {
     async_std::task::block_on(async {
         let a: Box<[f32]> = Box::new([17.0, 18.0, 19.0]);
+        let b: Box<[i32]> = Box::new([12, 16, 4, 18, 23, 56, 93, 22, 14]);
         let mut a_bytes: Box<[u8]> = Box::new([0; 12]);
+        let mut b_bytes: Box<[u8]> = Box::new([0; 36]);
         unsafe {
             copy_nonoverlapping(
                 Box::as_ptr(&a) as *const u8,
                 Box::as_mut_ptr(&mut a_bytes) as *mut u8,
                 12,
             );
+            copy_nonoverlapping(
+                Box::as_ptr(&b) as *const u8,
+                Box::as_mut_ptr(&mut b_bytes) as *mut u8,
+                36,
+            );
         };
-        let output = ccp_example(a_bytes).await;
-        println!("{}", output);
-        assert_eq!(output, 1.0);
+        let output_example = ccp_example(a_bytes).await;
+        let output_median = median_array(9, b_bytes).await;
+        println!("{}", output_example);
+        println!("{}", output_median);
+        assert_eq!(output_example, 1.0);
+        assert_eq!(output_median, 18);
     });
 }
 
diff --git a/juno_samples/nested_ccp/src/nested_ccp.jn b/juno_samples/nested_ccp/src/nested_ccp.jn
index 7a9575854ee5220c67ffa3c4a784458646677119..ffa2d4f1dfedefe2286e982df0d8289205698862 100644
--- a/juno_samples/nested_ccp/src/nested_ccp.jn
+++ b/juno_samples/nested_ccp/src/nested_ccp.jn
@@ -13,3 +13,18 @@ fn ccp_example(arg : f32[3]) -> f32 {
   if false { for i = 0 to 3 by 1 { x = arg[i]; arg[i] = 0; } }
   return x;
 }
+
+#[entry]
+fn median_array<n : usize>(arr : i32[n]) -> i32 {
+  for i = 0 to n - 1 {
+    for j = 0 to n - i - 1 {
+      if arr[j] > arr[j+1] {
+        let t = arr[j];
+        arr[j] = arr[j+1];
+        arr[j+1] = t;
+      }
+    }
+  }
+
+  return arr[n / 2];
+}