From b316e00e56dac70776352a8758178d8b2863e7f2 Mon Sep 17 00:00:00 2001 From: rarbore2 <rarbore2@illinois.edu> Date: Thu, 19 Oct 2023 16:27:09 -0500 Subject: [PATCH] CCP, DCE, GVN, IR Builder API, improve verification, rewrite dot visualization --- Cargo.lock | 66 +++ hercules_ir/src/build.rs | 521 +++++++++++++++++ hercules_ir/src/ccp.rs | 726 ++++++++++++++++++++++++ hercules_ir/src/dataflow.rs | 78 +-- hercules_ir/src/dce.rs | 45 ++ hercules_ir/src/def_use.rs | 73 +++ hercules_ir/src/dom.rs | 12 + hercules_ir/src/dot.rs | 338 ----------- hercules_ir/src/gvn.rs | 93 +++ hercules_ir/src/ir.rs | 253 ++++++++- hercules_ir/src/lib.rs | 10 +- hercules_ir/src/parse.rs | 7 + hercules_ir/src/typecheck.rs | 60 +- hercules_ir/src/verify.rs | 26 +- hercules_tools/Cargo.toml | 1 + hercules_tools/src/hercules_dot/dot.rs | 227 ++++++++ hercules_tools/src/hercules_dot/main.rs | 43 +- samples/ccp_example.hir | 19 + samples/gvn_example.hir | 8 + samples/invalid/bad_phi2.hir | 18 + 20 files changed, 2214 insertions(+), 410 deletions(-) create mode 100644 hercules_ir/src/build.rs create mode 100644 hercules_ir/src/ccp.rs create mode 100644 hercules_ir/src/dce.rs delete mode 100644 hercules_ir/src/dot.rs create mode 100644 hercules_ir/src/gvn.rs create mode 100644 hercules_tools/src/hercules_dot/dot.rs create mode 100644 samples/ccp_example.hir create mode 100644 samples/gvn_example.hir create mode 100644 samples/invalid/bad_phi2.hir diff --git a/Cargo.lock b/Cargo.lock index 13eff2b4..e5253607 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -68,6 +68,12 @@ dependencies = [ "wyz", ] +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + [[package]] name = "clap" version = "4.4.2" @@ -120,6 +126,17 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" +[[package]] +name = "getrandom" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "heck" version = "0.4.1" @@ -141,8 +158,15 @@ version = "0.1.0" dependencies = [ "clap", "hercules_ir", + "rand", ] +[[package]] +name = "libc" +version = "0.2.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a08173bc88b7955d1b3145aa561539096c421ac8debde8cbc3612ec635fee29b" + [[package]] name = "memchr" version = "2.6.3" @@ -183,6 +207,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + [[package]] name = "proc-macro2" version = "1.0.66" @@ -207,6 +237,36 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + [[package]] name = "strsim" version = "0.10.0" @@ -242,6 +302,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + [[package]] name = "windows-sys" version = "0.48.0" diff --git a/hercules_ir/src/build.rs b/hercules_ir/src/build.rs new file mode 100644 index 00000000..e1e6aa2e --- /dev/null +++ b/hercules_ir/src/build.rs @@ -0,0 +1,521 @@ +use std::collections::HashMap; + +use crate::*; + +/* + * The builder provides a clean API for programatically creating IR modules. + * The main function of the builder is to intern various parts of the IR. + */ +#[derive(Debug, Default)] +pub struct Builder<'a> { + // Intern function names. + function_ids: HashMap<&'a str, FunctionID>, + + // Intern types, constants, and dynamic constants on a per-module basis. + interned_types: HashMap<Type, TypeID>, + interned_constants: HashMap<Constant, ConstantID>, + interned_dynamic_constants: HashMap<DynamicConstant, DynamicConstantID>, + + // For product, summation, and array constant creation, it's useful to know + // the type of each constant. + constant_types: Vec<TypeID>, + + // The module being built. + module: Module, +} + +/* + * Since the builder doesn't provide string names for nodes, we need a different + * mechanism for allowing one to allocate node IDs before actually creating the + * node. This is required since there may be loops in the flow graph. We achieve + * this using NodeBuilders. Allocating a NodeBuilder allocates a Node ID, and + * the NodeBuilder can be later used to actually build an IR node. + */ +#[derive(Debug)] +pub struct NodeBuilder { + id: NodeID, + function_id: FunctionID, + node: Node, +} + +/* + * The IR builder may return errors when used incorrectly. + */ +type BuilderResult<T> = Result<T, String>; + +impl<'a> Builder<'a> { + fn intern_type(&mut self, ty: Type) -> TypeID { + if let Some(id) = self.interned_types.get(&ty) { + *id + } else { + let id = TypeID::new(self.interned_types.len()); + self.interned_types.insert(ty.clone(), id); + self.module.types.push(ty); + id + } + } + + fn intern_constant(&mut self, cons: Constant, ty: TypeID) -> ConstantID { + if let Some(id) = self.interned_constants.get(&cons) { + *id + } else { + let id = ConstantID::new(self.interned_constants.len()); + self.interned_constants.insert(cons.clone(), id); + self.module.constants.push(cons); + self.constant_types.push(ty); + id + } + } + + fn intern_dynamic_constant(&mut self, dyn_cons: DynamicConstant) -> DynamicConstantID { + if let Some(id) = self.interned_dynamic_constants.get(&dyn_cons) { + *id + } else { + let id = DynamicConstantID::new(self.interned_dynamic_constants.len()); + self.interned_dynamic_constants.insert(dyn_cons.clone(), id); + self.module.dynamic_constants.push(dyn_cons); + id + } + } + + pub fn create() -> Self { + Self::default() + } + + pub fn finish(self) -> Module { + self.module + } + + pub fn create_type_bool(&mut self) -> TypeID { + self.intern_type(Type::Boolean) + } + + pub fn create_type_i8(&mut self) -> TypeID { + self.intern_type(Type::Integer8) + } + + pub fn create_type_i16(&mut self) -> TypeID { + self.intern_type(Type::Integer16) + } + + pub fn create_type_i32(&mut self) -> TypeID { + self.intern_type(Type::Integer32) + } + + pub fn create_type_i64(&mut self) -> TypeID { + self.intern_type(Type::Integer64) + } + + pub fn create_type_u8(&mut self) -> TypeID { + self.intern_type(Type::UnsignedInteger8) + } + + pub fn create_type_u16(&mut self) -> TypeID { + self.intern_type(Type::UnsignedInteger16) + } + + pub fn create_type_u32(&mut self) -> TypeID { + self.intern_type(Type::UnsignedInteger32) + } + + pub fn create_type_u64(&mut self) -> TypeID { + self.intern_type(Type::UnsignedInteger64) + } + + pub fn create_type_f32(&mut self) -> TypeID { + self.intern_type(Type::Float32) + } + + pub fn create_type_f64(&mut self) -> TypeID { + self.intern_type(Type::Float64) + } + + pub fn create_type_prod(&mut self, tys: Box<[TypeID]>) -> TypeID { + self.intern_type(Type::Product(tys)) + } + + pub fn create_type_prod2(&mut self, a: TypeID, b: TypeID) -> TypeID { + self.create_type_prod(Box::new([a, b])) + } + + pub fn create_type_prod3(&mut self, a: TypeID, b: TypeID, c: TypeID) -> TypeID { + self.create_type_prod(Box::new([a, b, c])) + } + + pub fn create_type_prod4(&mut self, a: TypeID, b: TypeID, c: TypeID, d: TypeID) -> TypeID { + self.create_type_prod(Box::new([a, b, c, d])) + } + + pub fn create_type_prod5( + &mut self, + a: TypeID, + b: TypeID, + c: TypeID, + d: TypeID, + e: TypeID, + ) -> TypeID { + self.create_type_prod(Box::new([a, b, c, d, e])) + } + + pub fn create_type_prod6( + &mut self, + a: TypeID, + b: TypeID, + c: TypeID, + d: TypeID, + e: TypeID, + f: TypeID, + ) -> TypeID { + self.create_type_prod(Box::new([a, b, c, d, e, f])) + } + + pub fn create_type_prod7( + &mut self, + a: TypeID, + b: TypeID, + c: TypeID, + d: TypeID, + e: TypeID, + f: TypeID, + g: TypeID, + ) -> TypeID { + self.create_type_prod(Box::new([a, b, c, d, e, f, g])) + } + + pub fn create_type_prod8( + &mut self, + a: TypeID, + b: TypeID, + c: TypeID, + d: TypeID, + e: TypeID, + f: TypeID, + g: TypeID, + h: TypeID, + ) -> TypeID { + self.create_type_prod(Box::new([a, b, c, d, e, f, g, h])) + } + + pub fn create_type_sum(&mut self, tys: Box<[TypeID]>) -> TypeID { + self.intern_type(Type::Summation(tys)) + } + + pub fn create_type_sum2(&mut self, a: TypeID, b: TypeID) -> TypeID { + self.create_type_sum(Box::new([a, b])) + } + + pub fn create_type_sum3(&mut self, a: TypeID, b: TypeID, c: TypeID) -> TypeID { + self.create_type_sum(Box::new([a, b, c])) + } + + pub fn create_type_sum4(&mut self, a: TypeID, b: TypeID, c: TypeID, d: TypeID) -> TypeID { + self.create_type_sum(Box::new([a, b, c, d])) + } + + pub fn create_type_sum5( + &mut self, + a: TypeID, + b: TypeID, + c: TypeID, + d: TypeID, + e: TypeID, + ) -> TypeID { + self.create_type_sum(Box::new([a, b, c, d, e])) + } + + pub fn create_type_sum6( + &mut self, + a: TypeID, + b: TypeID, + c: TypeID, + d: TypeID, + e: TypeID, + f: TypeID, + ) -> TypeID { + self.create_type_sum(Box::new([a, b, c, d, e, f])) + } + + pub fn create_type_sum7( + &mut self, + a: TypeID, + b: TypeID, + c: TypeID, + d: TypeID, + e: TypeID, + f: TypeID, + g: TypeID, + ) -> TypeID { + self.create_type_sum(Box::new([a, b, c, d, e, f, g])) + } + + pub fn create_type_sum8( + &mut self, + a: TypeID, + b: TypeID, + c: TypeID, + d: TypeID, + e: TypeID, + f: TypeID, + g: TypeID, + h: TypeID, + ) -> TypeID { + self.create_type_sum(Box::new([a, b, c, d, e, f, g, h])) + } + + pub fn create_type_array(&mut self, elem: TypeID, dc: DynamicConstantID) -> TypeID { + self.intern_type(Type::Array(elem, dc)) + } + + pub fn create_constant_bool(&mut self, val: bool) -> ConstantID { + let ty = self.intern_type(Type::Boolean); + self.intern_constant(Constant::Boolean(val), ty) + } + + pub fn create_constant_i8(&mut self, val: i8) -> ConstantID { + let ty = self.intern_type(Type::Integer8); + self.intern_constant(Constant::Integer8(val), ty) + } + + pub fn create_constant_i16(&mut self, val: i16) -> ConstantID { + let ty = self.intern_type(Type::Integer16); + self.intern_constant(Constant::Integer16(val), ty) + } + + pub fn create_constant_i32(&mut self, val: i32) -> ConstantID { + let ty = self.intern_type(Type::Integer32); + self.intern_constant(Constant::Integer32(val), ty) + } + + pub fn create_constant_i64(&mut self, val: i64) -> ConstantID { + let ty = self.intern_type(Type::Integer64); + self.intern_constant(Constant::Integer64(val), ty) + } + + pub fn create_constant_u8(&mut self, val: u8) -> ConstantID { + let ty = self.intern_type(Type::UnsignedInteger8); + self.intern_constant(Constant::UnsignedInteger8(val), ty) + } + + pub fn create_constant_u16(&mut self, val: u16) -> ConstantID { + let ty = self.intern_type(Type::UnsignedInteger16); + self.intern_constant(Constant::UnsignedInteger16(val), ty) + } + + pub fn create_constant_u32(&mut self, val: u32) -> ConstantID { + let ty = self.intern_type(Type::UnsignedInteger32); + self.intern_constant(Constant::UnsignedInteger32(val), ty) + } + + pub fn create_constant_u64(&mut self, val: u64) -> ConstantID { + let ty = self.intern_type(Type::UnsignedInteger64); + self.intern_constant(Constant::UnsignedInteger64(val), ty) + } + + pub fn create_constant_f32(&mut self, val: f32) -> ConstantID { + let ty = self.intern_type(Type::Float32); + self.intern_constant( + Constant::Float32(ordered_float::OrderedFloat::<f32>(val)), + ty, + ) + } + + pub fn create_constant_f64(&mut self, val: f64) -> ConstantID { + let ty = self.intern_type(Type::Float64); + self.intern_constant( + Constant::Float64(ordered_float::OrderedFloat::<f64>(val)), + ty, + ) + } + + pub fn create_constant_prod(&mut self, cons: Box<[ConstantID]>) -> ConstantID { + let ty = self.create_type_prod(cons.iter().map(|x| self.constant_types[x.idx()]).collect()); + self.intern_constant(Constant::Product(ty, cons), ty) + } + + pub fn create_constant_sum( + &mut self, + ty: TypeID, + variant: u32, + cons: ConstantID, + ) -> BuilderResult<ConstantID> { + if let Type::Summation(variant_tys) = &self.module.types[ty.idx()] { + if variant as usize >= variant_tys.len() { + Err("Variant provided to create_constant_sum is too large for provided summation type.")? + } + if variant_tys[variant as usize] != self.constant_types[cons.idx()] { + Err("Constant provided to create_constant_sum doesn't match the summation type provided.")? + } + Ok(self.intern_constant(Constant::Summation(ty, variant, cons), ty)) + } else { + Err("Type provided to create_constant_sum is not a summation type.".to_owned()) + } + } + + pub fn create_constant_array( + &mut self, + elem_ty: TypeID, + cons: Box<[ConstantID]>, + ) -> BuilderResult<ConstantID> { + for con in cons.iter() { + if self.constant_types[con.idx()] != elem_ty { + Err("Constant provided to create_constant_array has a different type than the provided element type.")? + } + } + let dc = self.create_dynamic_constant_constant(cons.len()); + let ty = self.create_type_array(elem_ty, dc); + Ok(self.intern_constant(Constant::Array(ty, cons), ty)) + } + + pub fn create_dynamic_constant_constant(&mut self, val: usize) -> DynamicConstantID { + self.intern_dynamic_constant(DynamicConstant::Constant(val)) + } + + pub fn create_dynamic_constant_parameter(&mut self, val: usize) -> DynamicConstantID { + self.intern_dynamic_constant(DynamicConstant::Parameter(val)) + } + + pub fn create_function( + &mut self, + name: &'a str, + param_types: Vec<TypeID>, + return_type: TypeID, + num_dynamic_constants: u32, + ) -> BuilderResult<(FunctionID, NodeID)> { + if let Some(_) = self.function_ids.get(name) { + Err(format!("Can't create a function with name \"{}\", because a function with the same name has already been created.", name))? + } + + let id = FunctionID::new(self.module.functions.len()); + self.module.functions.push(Function { + name: name.to_owned(), + param_types, + return_type, + nodes: vec![Node::Start], + num_dynamic_constants, + }); + Ok((id, NodeID::new(0))) + } + + pub fn allocate_node(&mut self, function: FunctionID) -> NodeBuilder { + let id = NodeID::new(self.module.functions[function.idx()].nodes.len()); + self.module.functions[function.idx()] + .nodes + .push(Node::Start); + NodeBuilder { + id, + function_id: function, + node: Node::Start, + } + } + + pub fn add_node(&mut self, builder: NodeBuilder) -> BuilderResult<()> { + if let Node::Start = builder.node { + Err("Can't add node from a NodeBuilder before NodeBuilder has built a node.")? + } + self.module.functions[builder.function_id.idx()].nodes[builder.id.idx()] = builder.node; + Ok(()) + } +} + +impl NodeBuilder { + pub fn id(&self) -> NodeID { + self.id + } + + pub fn build_region(&mut self, preds: Box<[NodeID]>) { + self.node = Node::Region { preds }; + } + + pub fn build_if(&mut self, control: NodeID, cond: NodeID) { + self.node = Node::If { control, cond }; + } + + pub fn build_fork(&mut self, control: NodeID, factor: DynamicConstantID) { + self.node = Node::Fork { control, factor }; + } + + pub fn build_join(&mut self, control: NodeID) { + self.node = Node::Join { control }; + } + + pub fn build_phi(&mut self, control: NodeID, data: Box<[NodeID]>) { + self.node = Node::Phi { control, data }; + } + + pub fn build_threadid(&mut self, control: NodeID) { + self.node = Node::ThreadID { control }; + } + + pub fn build_collect(&mut self, control: NodeID, data: NodeID) { + self.node = Node::Collect { control, data }; + } + + pub fn build_return(&mut self, control: NodeID, data: NodeID) { + self.node = Node::Return { control, data }; + } + + pub fn build_parameter(&mut self, index: usize) { + self.node = Node::Parameter { index }; + } + + pub fn build_constant(&mut self, id: ConstantID) { + self.node = Node::Constant { id }; + } + + pub fn build_dynamicconstant(&mut self, id: DynamicConstantID) { + self.node = Node::DynamicConstant { id }; + } + + pub fn build_unary(&mut self, input: NodeID, op: UnaryOperator) { + self.node = Node::Unary { input, op }; + } + + pub fn build_binary(&mut self, left: NodeID, right: NodeID, op: BinaryOperator) { + self.node = Node::Binary { left, right, op }; + } + + pub fn build_call( + &mut self, + function: FunctionID, + dynamic_constants: Box<[DynamicConstantID]>, + args: Box<[NodeID]>, + ) { + self.node = Node::Call { + function, + dynamic_constants, + args, + }; + } + + pub fn build_readprod(&mut self, prod: NodeID, index: usize) { + self.node = Node::ReadProd { prod, index }; + } + + pub fn build_writeprod(&mut self, prod: NodeID, data: NodeID, index: usize) { + self.node = Node::WriteProd { prod, data, index }; + } + + pub fn build_readarray(&mut self, array: NodeID, index: NodeID) { + self.node = Node::ReadArray { array, index }; + } + + pub fn build_writearray(&mut self, array: NodeID, data: NodeID, index: NodeID) { + self.node = Node::WriteArray { array, data, index }; + } + + pub fn build_match(&mut self, control: NodeID, sum: NodeID) { + self.node = Node::Match { control, sum }; + } + + pub fn build_buildsum(&mut self, data: NodeID, sum_ty: TypeID, variant: usize) { + self.node = Node::BuildSum { + data, + sum_ty, + variant, + }; + } + + pub fn build_extractsum(&mut self, data: NodeID, variant: usize) { + self.node = Node::ExtractSum { data, variant }; + } +} diff --git a/hercules_ir/src/ccp.rs b/hercules_ir/src/ccp.rs new file mode 100644 index 00000000..ded30b1f --- /dev/null +++ b/hercules_ir/src/ccp.rs @@ -0,0 +1,726 @@ +use std::collections::HashMap; +use std::iter::zip; + +use crate::*; + +/* + * The ccp lattice tracks, for each node, the following information: + * 1. Reachability - is it possible for this node to be reached during any + * execution? + * 2. Constant - does this node evaluate to a constant expression? + * The ccp lattice is formulated as a combination of consistuent lattices. The + * flow function for the ccp dataflow analysis "crosses" information across the + * sub lattices - for example, whether a condition is constant may inform + * whether a branch target is reachable. This analysis uses interpreted + * constants, so constant one plus constant one results in constant two. + */ +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct CCPLattice { + reachability: ReachabilityLattice, + constant: ConstantLattice, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ReachabilityLattice { + Unreachable, + Reachable, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ConstantLattice { + Top, + Constant(Constant), + Bottom, +} + +impl CCPLattice { + fn is_reachable(&self) -> bool { + self.reachability == ReachabilityLattice::Reachable + } + + fn get_constant(&self) -> Option<Constant> { + if let ConstantLattice::Constant(cons) = &self.constant { + Some(cons.clone()) + } else { + None + } + } +} + +impl ConstantLattice { + fn is_top(&self) -> bool { + *self == ConstantLattice::Top + } + + fn is_bottom(&self) -> bool { + *self == ConstantLattice::Bottom + } +} + +impl Semilattice for CCPLattice { + fn meet(a: &Self, b: &Self) -> Self { + CCPLattice { + reachability: ReachabilityLattice::meet(&a.reachability, &b.reachability), + constant: ConstantLattice::meet(&a.constant, &b.constant), + } + } + + fn bottom() -> Self { + CCPLattice { + reachability: ReachabilityLattice::bottom(), + constant: ConstantLattice::bottom(), + } + } + + fn top() -> Self { + CCPLattice { + reachability: ReachabilityLattice::top(), + constant: ConstantLattice::top(), + } + } +} + +impl Semilattice for ReachabilityLattice { + fn meet(a: &Self, b: &Self) -> Self { + match (a, b) { + (ReachabilityLattice::Unreachable, ReachabilityLattice::Unreachable) => { + ReachabilityLattice::Unreachable + } + _ => ReachabilityLattice::Reachable, + } + } + + fn bottom() -> Self { + ReachabilityLattice::Reachable + } + + fn top() -> Self { + ReachabilityLattice::Unreachable + } +} + +impl Semilattice for ConstantLattice { + fn meet(a: &Self, b: &Self) -> Self { + match (a, b) { + (ConstantLattice::Top, b) => b.clone(), + (a, ConstantLattice::Top) => a.clone(), + (ConstantLattice::Constant(cons1), ConstantLattice::Constant(cons2)) => { + if cons1 == cons2 { + ConstantLattice::Constant(cons1.clone()) + } else { + ConstantLattice::Bottom + } + } + _ => ConstantLattice::Bottom, + } + } + + fn bottom() -> Self { + ConstantLattice::Bottom + } + + fn top() -> Self { + ConstantLattice::Top + } +} + +/* + * Top level function to run conditional constant propagation. + */ +pub fn ccp( + function: &mut Function, + constants: &mut Vec<Constant>, + def_use: &ImmutableDefUseMap, + reverse_postorder: &Vec<NodeID>, +) { + // Step 1: run ccp analysis to understand the function. + let result = forward_dataflow_global(&function, reverse_postorder, |inputs, node_id| { + ccp_flow_function(inputs, node_id, &function, &constants) + }); + + // Step 2: update uses of constants. Any node that doesn't produce a + // constant value, but does use a newly found constant value, needs to be + // updated to use the newly found constant. + + // Step 2.1: assemble reverse constant map. We created a bunch of constants + // during the analysis, so we need to intern them. + let mut reverse_constant_map: HashMap<Constant, ConstantID> = constants + .iter() + .enumerate() + .map(|(idx, cons)| (cons.clone(), ConstantID::new(idx))) + .collect(); + + // Helper function for interning constants in the lattice. + let mut get_constant_id = |cons| { + if let Some(id) = reverse_constant_map.get(&cons) { + *id + } else { + let id = ConstantID::new(reverse_constant_map.len()); + reverse_constant_map.insert(cons.clone(), id); + id + } + }; + + // Step 2.2: for every node, update uses of now constant nodes. We need to + // separately create constant nodes, since we are mutably looping over the + // function nodes separately. + let mut new_constant_nodes = vec![]; + let base_cons_node_idx = function.nodes.len(); + for node in function.nodes.iter_mut() { + for u in get_uses_mut(node).as_mut() { + let old_id = **u; + if let Some(cons) = result[old_id.idx()].get_constant() { + // Get ConstantID for this constant. + let cons_id = get_constant_id(cons); + + // Search new_constant_nodes for a constant IR node that already + // referenced this ConstantID. + if let Some(new_nodes_idx) = new_constant_nodes + .iter() + .enumerate() + .filter(|(_, id)| **id == cons_id) + .map(|(idx, _)| idx) + .next() + { + // If there is already a constant IR node, calculate what + // the NodeID will be for it, and set the use to that ID. + **u = NodeID::new(base_cons_node_idx + new_nodes_idx); + } else { + // If there is not already a constant IR node for this + // ConstantID, add this ConstantID to the new_constant_nodes + // list. Set the use to the corresponding NodeID for the new + // constant IR node. + let cons_node_id = NodeID::new(base_cons_node_idx + new_constant_nodes.len()); + new_constant_nodes.push(cons_id); + **u = cons_node_id; + } + } + } + } + + // Step 2.3: add new constant nodes into nodes of function. + for node in new_constant_nodes { + function.nodes.push(Node::Constant { id: node }); + } + + // Step 2.4: re-create module's constants vector from interning map. + *constants = vec![Constant::Boolean(false); reverse_constant_map.len()]; + for (cons, id) in reverse_constant_map { + constants[id.idx()] = cons; + } + + // Step 3: delete dead branches. Any nodes that are unreachable should be + // deleted. Any if or match nodes that now are light on users need to be + // removed immediately, since if and match nodes have requirements on the + // number of users. + + // Step 3.1: delete unreachable nodes. Loop over the length of the dataflow + // result instead of the function's node list, since in step 2, constant + // nodes were added that don't have a corresponding lattice result. + for idx in 0..result.len() { + if !result[idx].is_reachable() { + function.nodes[idx] = Node::Start; + } + } + + // Step 3.2: remove uses of data nodes in phi nodes corresponding to + // unreachable uses in corresponding region nodes. + for phi_id in (0..result.len()).map(NodeID::new) { + if let Node::Phi { control, data } = &function.nodes[phi_id.idx()] { + if let Node::Region { preds } = &function.nodes[control.idx()] { + let new_data = zip(preds.iter(), data.iter()) + .filter(|(pred, _)| result[pred.idx()].is_reachable()) + .map(|(_, datum)| *datum) + .collect(); + function.nodes[phi_id.idx()] = Node::Phi { + control: *control, + data: new_data, + }; + } + } + } + + // Step 3.3: remove uses of unreachable nodes in region nodes. + for node in function.nodes.iter_mut() { + if let Node::Region { preds } = node { + *preds = preds + .iter() + .filter(|pred| result[pred.idx()].is_reachable()) + .map(|x| *x) + .collect(); + } + } + + // Step 3.4: remove if and match nodes with one reachable user. + for branch_id in (0..result.len()).map(NodeID::new) { + if let Node::If { control, cond: _ } | Node::Match { control, sum: _ } = + function.nodes[branch_id.idx()].clone() + { + let users = def_use.get_users(branch_id); + let mut reachable_users = users + .iter() + .filter(|user| result[user.idx()].is_reachable()); + let the_reachable_user = reachable_users + .next() + .expect("During CCP, found a branch with no reachable users."); + + // The reachable users iterator will contain one user if we need to + // remove this branch node. + if let None = reachable_users.next() { + // The user is a ReadProd node, which in turn has one user. + let target = def_use.get_users(*the_reachable_user)[0]; + + // For each use in the target of the reachable ReadProd, turn it + // into a use of the node proceeding this branch node. + for u in get_uses_mut(&mut function.nodes[target.idx()]).as_mut() { + if **u == *the_reachable_user { + **u = control; + } + } + + // Remove this branch node, since it is malformed. Also remove + // all successor ReadProd nodes. + function.nodes[branch_id.idx()] = Node::Start; + for user in users { + function.nodes[user.idx()] = Node::Start; + } + } + } + } +} + +fn ccp_flow_function( + inputs: &[CCPLattice], + node_id: NodeID, + function: &Function, + old_constants: &Vec<Constant>, +) -> CCPLattice { + let node = &function.nodes[node_id.idx()]; + match node { + Node::Start => CCPLattice::bottom(), + Node::Region { preds } => preds.iter().fold(CCPLattice::top(), |val, id| { + CCPLattice::meet(&val, &inputs[id.idx()]) + }), + // If node has only one output, so doesn't directly handle crossover of + // reachability and constant propagation. ReadProd handles that. + Node::If { control, cond: _ } => inputs[control.idx()].clone(), + Node::Fork { control, factor: _ } => inputs[control.idx()].clone(), + Node::Join { control } => inputs[control.idx()].clone(), + // Phi nodes must look at the reachability of the inputs to its + // corresponding region node to determine the constant value being + // output. + Node::Phi { control, data } => { + // Get the control predecessors of the corresponding region. + let region_preds = if let Node::Region { preds } = &function.nodes[control.idx()] { + preds + } else { + panic!("A phi's control input must be a region node.") + }; + zip(region_preds.iter(), data.iter()).fold( + CCPLattice { + reachability: inputs[control.idx()].reachability.clone(), + constant: ConstantLattice::top(), + }, + |val, (control_id, data_id)| { + // If a control input to the region node is reachable, then + // and only then do we meet with the data input's constant + // lattice value. + if inputs[control_id.idx()].is_reachable() { + CCPLattice::meet(&val, &inputs[data_id.idx()]) + } else { + val + } + }, + ) + } + // TODO: This should produce a constant zero if the dynamic constant for + // for the corresponding fork is one. + Node::ThreadID { control } => inputs[control.idx()].clone(), + // TODO: At least for now, collect nodes always produce unknown values. + // It may be worthwile to add interpretation of constants for collect + // nodes, but it would involve plumbing dynamic constant and fork join + // pairing information here, and I don't feel like doing that. + Node::Collect { control, data: _ } => inputs[control.idx()].clone(), + Node::Return { control, data } => CCPLattice { + reachability: inputs[control.idx()].reachability.clone(), + constant: inputs[data.idx()].constant.clone(), + }, + Node::Parameter { index: _ } => CCPLattice::bottom(), + // A constant node is the "source" of concrete constant lattice values. + Node::Constant { id } => CCPLattice { + reachability: ReachabilityLattice::bottom(), + constant: ConstantLattice::Constant(old_constants[id.idx()].clone()), + }, + // TODO: This should really be constant interpreted, since dynamic + // constants as values are used frequently. + Node::DynamicConstant { id: _ } => CCPLattice::bottom(), + // Interpret unary op on constant. TODO: avoid UB. + Node::Unary { input, op } => { + let CCPLattice { + ref reachability, + ref constant, + } = inputs[input.idx()]; + + let new_constant = if let ConstantLattice::Constant(cons) = constant { + let new_cons = match (op, cons) { + (UnaryOperator::Not, Constant::Boolean(val)) => Constant::Boolean(!val), + (UnaryOperator::Neg, Constant::Integer8(val)) => Constant::Integer8(-val), + (UnaryOperator::Neg, Constant::Integer16(val)) => Constant::Integer16(-val), + (UnaryOperator::Neg, Constant::Integer32(val)) => Constant::Integer32(-val), + (UnaryOperator::Neg, Constant::Integer64(val)) => Constant::Integer64(-val), + (UnaryOperator::Neg, Constant::Float32(val)) => Constant::Float32(-val), + (UnaryOperator::Neg, Constant::Float64(val)) => Constant::Float64(-val), + (UnaryOperator::Bitflip, Constant::Integer8(val)) => Constant::Integer8(!val), + (UnaryOperator::Bitflip, Constant::Integer16(val)) => Constant::Integer16(!val), + (UnaryOperator::Bitflip, Constant::Integer32(val)) => Constant::Integer32(!val), + (UnaryOperator::Bitflip, Constant::Integer64(val)) => Constant::Integer64(!val), + _ => panic!("Unsupported combination of unary operation and constant value. Did typechecking succeed?") + }; + ConstantLattice::Constant(new_cons) + } else { + constant.clone() + }; + + CCPLattice { + reachability: reachability.clone(), + constant: new_constant, + } + } + // Interpret binary op on constants. TODO: avoid UB. + Node::Binary { left, right, op } => { + let CCPLattice { + reachability: ref left_reachability, + constant: ref left_constant, + } = inputs[left.idx()]; + let CCPLattice { + reachability: ref right_reachability, + constant: ref right_constant, + } = inputs[right.idx()]; + + let new_constant = if let ( + ConstantLattice::Constant(left_cons), + ConstantLattice::Constant(right_cons), + ) = (left_constant, right_constant) + { + let new_cons = match (op, left_cons, right_cons) { + (BinaryOperator::Add, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val + right_val), + (BinaryOperator::Add, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val + right_val), + (BinaryOperator::Add, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val + right_val), + (BinaryOperator::Add, Constant::Integer64(left_val), Constant::Integer64(right_val)) => Constant::Integer64(left_val + right_val), + (BinaryOperator::Add, Constant::UnsignedInteger8(left_val), Constant::UnsignedInteger8(right_val)) => Constant::UnsignedInteger8(left_val + right_val), + (BinaryOperator::Add, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val + right_val), + (BinaryOperator::Add, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val + right_val), + (BinaryOperator::Add, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val + right_val), + (BinaryOperator::Add, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(*left_val + *right_val), + (BinaryOperator::Add, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(*left_val + *right_val), + (BinaryOperator::Sub, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val - right_val), + (BinaryOperator::Sub, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val - right_val), + (BinaryOperator::Sub, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val - right_val), + (BinaryOperator::Sub, Constant::Integer64(left_val), Constant::Integer64(right_val)) => Constant::Integer64(left_val - right_val), + (BinaryOperator::Sub, Constant::UnsignedInteger8(left_val), Constant::UnsignedInteger8(right_val)) => Constant::UnsignedInteger8(left_val - right_val), + (BinaryOperator::Sub, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val - right_val), + (BinaryOperator::Sub, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val - right_val), + (BinaryOperator::Sub, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val - right_val), + (BinaryOperator::Sub, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(*left_val - *right_val), + (BinaryOperator::Sub, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(*left_val - *right_val), + (BinaryOperator::Mul, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val * right_val), + (BinaryOperator::Mul, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val * right_val), + (BinaryOperator::Mul, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val * right_val), + (BinaryOperator::Mul, Constant::Integer64(left_val), Constant::Integer64(right_val)) => Constant::Integer64(left_val * right_val), + (BinaryOperator::Mul, Constant::UnsignedInteger8(left_val), Constant::UnsignedInteger8(right_val)) => Constant::UnsignedInteger8(left_val * right_val), + (BinaryOperator::Mul, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val * right_val), + (BinaryOperator::Mul, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val * right_val), + (BinaryOperator::Mul, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val * right_val), + (BinaryOperator::Mul, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(*left_val * *right_val), + (BinaryOperator::Mul, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(*left_val * *right_val), + (BinaryOperator::Div, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val / right_val), + (BinaryOperator::Div, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val / right_val), + (BinaryOperator::Div, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val / right_val), + (BinaryOperator::Div, Constant::Integer64(left_val), Constant::Integer64(right_val)) => Constant::Integer64(left_val / right_val), + (BinaryOperator::Div, Constant::UnsignedInteger8(left_val), Constant::UnsignedInteger8(right_val)) => Constant::UnsignedInteger8(left_val / right_val), + (BinaryOperator::Div, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val / right_val), + (BinaryOperator::Div, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val / right_val), + (BinaryOperator::Div, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val / right_val), + (BinaryOperator::Div, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(*left_val / *right_val), + (BinaryOperator::Div, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(*left_val / *right_val), + (BinaryOperator::Rem, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val % right_val), + (BinaryOperator::Rem, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val % right_val), + (BinaryOperator::Rem, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val % right_val), + (BinaryOperator::Rem, Constant::Integer64(left_val), Constant::Integer64(right_val)) => Constant::Integer64(left_val % right_val), + (BinaryOperator::Rem, Constant::UnsignedInteger8(left_val), Constant::UnsignedInteger8(right_val)) => Constant::UnsignedInteger8(left_val % right_val), + (BinaryOperator::Rem, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val % right_val), + (BinaryOperator::Rem, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val % right_val), + (BinaryOperator::Rem, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val % right_val), + (BinaryOperator::Rem, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(*left_val % *right_val), + (BinaryOperator::Rem, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(*left_val % *right_val), + (BinaryOperator::LT, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Boolean(left_val < right_val), + (BinaryOperator::LT, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Boolean(left_val < right_val), + (BinaryOperator::LT, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Boolean(left_val < right_val), + (BinaryOperator::LT, Constant::Integer64(left_val), Constant::Integer64(right_val)) => Constant::Boolean(left_val < right_val), + (BinaryOperator::LT, Constant::UnsignedInteger8(left_val), Constant::UnsignedInteger8(right_val)) => Constant::Boolean(left_val < right_val), + (BinaryOperator::LT, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::Boolean(left_val < right_val), + (BinaryOperator::LT, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::Boolean(left_val < right_val), + (BinaryOperator::LT, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::Boolean(left_val < right_val), + (BinaryOperator::LT, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Boolean(*left_val < *right_val), + (BinaryOperator::LT, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Boolean(*left_val < *right_val), + (BinaryOperator::LTE, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Boolean(left_val <= right_val), + (BinaryOperator::LTE, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Boolean(left_val <= right_val), + (BinaryOperator::LTE, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Boolean(left_val <= right_val), + (BinaryOperator::LTE, Constant::Integer64(left_val), Constant::Integer64(right_val)) => Constant::Boolean(left_val <= right_val), + (BinaryOperator::LTE, Constant::UnsignedInteger8(left_val), Constant::UnsignedInteger8(right_val)) => Constant::Boolean(left_val <= right_val), + (BinaryOperator::LTE, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::Boolean(left_val <= right_val), + (BinaryOperator::LTE, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::Boolean(left_val <= right_val), + (BinaryOperator::LTE, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::Boolean(left_val <= right_val), + (BinaryOperator::LTE, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Boolean(*left_val <= *right_val), + (BinaryOperator::LTE, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Boolean(*left_val <= *right_val), + (BinaryOperator::GT, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Boolean(left_val > right_val), + (BinaryOperator::GT, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Boolean(left_val > right_val), + (BinaryOperator::GT, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Boolean(left_val > right_val), + (BinaryOperator::GT, Constant::Integer64(left_val), Constant::Integer64(right_val)) => Constant::Boolean(left_val > right_val), + (BinaryOperator::GT, Constant::UnsignedInteger8(left_val), Constant::UnsignedInteger8(right_val)) => Constant::Boolean(left_val > right_val), + (BinaryOperator::GT, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::Boolean(left_val > right_val), + (BinaryOperator::GT, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::Boolean(left_val > right_val), + (BinaryOperator::GT, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::Boolean(left_val > right_val), + (BinaryOperator::GT, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Boolean(*left_val > *right_val), + (BinaryOperator::GT, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Boolean(*left_val > *right_val), + (BinaryOperator::GTE, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Boolean(left_val >= right_val), + (BinaryOperator::GTE, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Boolean(left_val >= right_val), + (BinaryOperator::GTE, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Boolean(left_val >= right_val), + (BinaryOperator::GTE, Constant::Integer64(left_val), Constant::Integer64(right_val)) => Constant::Boolean(left_val >= right_val), + (BinaryOperator::GTE, Constant::UnsignedInteger8(left_val), Constant::UnsignedInteger8(right_val)) => Constant::Boolean(left_val >= right_val), + (BinaryOperator::GTE, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::Boolean(left_val >= right_val), + (BinaryOperator::GTE, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::Boolean(left_val >= right_val), + (BinaryOperator::GTE, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::Boolean(left_val >= right_val), + (BinaryOperator::GTE, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Boolean(*left_val >= *right_val), + (BinaryOperator::GTE, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Boolean(*left_val >= *right_val), + // EQ and NE can be implemented more easily, since we don't + // need to unpack the constants. + (BinaryOperator::EQ, left_val, right_val) => Constant::Boolean(left_val == right_val), + (BinaryOperator::NE, left_val, right_val) => Constant::Boolean(left_val != right_val), + (BinaryOperator::Or, Constant::Boolean(left_val), Constant::Boolean(right_val)) => Constant::Boolean(*left_val || *right_val), + (BinaryOperator::Or, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val | right_val), + (BinaryOperator::Or, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val | right_val), + (BinaryOperator::Or, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val | right_val), + (BinaryOperator::Or, Constant::Integer64(left_val), Constant::Integer64(right_val)) => Constant::Integer64(left_val | right_val), + (BinaryOperator::Or, Constant::UnsignedInteger8(left_val), Constant::UnsignedInteger8(right_val)) => Constant::UnsignedInteger8(left_val | right_val), + (BinaryOperator::Or, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val | right_val), + (BinaryOperator::Or, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val | right_val), + (BinaryOperator::Or, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val | right_val), + (BinaryOperator::And, Constant::Boolean(left_val), Constant::Boolean(right_val)) => Constant::Boolean(*left_val && *right_val), + (BinaryOperator::And, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val & right_val), + (BinaryOperator::And, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val & right_val), + (BinaryOperator::And, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val & right_val), + (BinaryOperator::And, Constant::Integer64(left_val), Constant::Integer64(right_val)) => Constant::Integer64(left_val & right_val), + (BinaryOperator::And, Constant::UnsignedInteger8(left_val), Constant::UnsignedInteger8(right_val)) => Constant::UnsignedInteger8(left_val & right_val), + (BinaryOperator::And, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val & right_val), + (BinaryOperator::And, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val & right_val), + (BinaryOperator::And, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val & right_val), + (BinaryOperator::Xor, Constant::Boolean(left_val), Constant::Boolean(right_val)) => Constant::Boolean(*left_val ^ *right_val), + (BinaryOperator::Xor, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val ^ right_val), + (BinaryOperator::Xor, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val ^ right_val), + (BinaryOperator::Xor, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val ^ right_val), + (BinaryOperator::Xor, Constant::Integer64(left_val), Constant::Integer64(right_val)) => Constant::Integer64(left_val ^ right_val), + (BinaryOperator::Xor, Constant::UnsignedInteger8(left_val), Constant::UnsignedInteger8(right_val)) => Constant::UnsignedInteger8(left_val ^ right_val), + (BinaryOperator::Xor, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val ^ right_val), + (BinaryOperator::Xor, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val ^ right_val), + (BinaryOperator::Xor, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val ^ right_val), + (BinaryOperator::LSh, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val << right_val), + (BinaryOperator::LSh, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val << right_val), + (BinaryOperator::LSh, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val << right_val), + (BinaryOperator::LSh, Constant::Integer64(left_val), Constant::Integer64(right_val)) => Constant::Integer64(left_val << right_val), + (BinaryOperator::LSh, Constant::UnsignedInteger8(left_val), Constant::UnsignedInteger8(right_val)) => Constant::UnsignedInteger8(left_val << right_val), + (BinaryOperator::LSh, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val << right_val), + (BinaryOperator::LSh, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val << right_val), + (BinaryOperator::LSh, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val << right_val), + (BinaryOperator::RSh, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val >> right_val), + (BinaryOperator::RSh, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val >> right_val), + (BinaryOperator::RSh, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val >> right_val), + (BinaryOperator::RSh, Constant::Integer64(left_val), Constant::Integer64(right_val)) => Constant::Integer64(left_val >> right_val), + (BinaryOperator::RSh, Constant::UnsignedInteger8(left_val), Constant::UnsignedInteger8(right_val)) => Constant::UnsignedInteger8(left_val >> right_val), + (BinaryOperator::RSh, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val >> right_val), + (BinaryOperator::RSh, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val >> right_val), + (BinaryOperator::RSh, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val >> right_val), + _ => panic!("Unsupported combination of binary operation and constant values. Did typechecking succeed?") + }; + ConstantLattice::Constant(new_cons) + } else if (left_constant.is_top() && !right_constant.is_bottom()) + || (!left_constant.is_bottom() && right_constant.is_top()) + { + ConstantLattice::top() + } else { + ConstantLattice::meet(left_constant, right_constant) + }; + + CCPLattice { + reachability: ReachabilityLattice::meet(left_reachability, right_reachability), + constant: new_constant, + } + } + // Call nodes are uninterpretable. + Node::Call { + function: _, + dynamic_constants: _, + args, + } => CCPLattice { + reachability: args.iter().fold(ReachabilityLattice::top(), |val, id| { + ReachabilityLattice::meet(&val, &inputs[id.idx()].reachability) + }), + constant: ConstantLattice::bottom(), + }, + // ReadProd handles reachability when following an if or match. + Node::ReadProd { prod, index } => match &function.nodes[prod.idx()] { + Node::If { control: _, cond } => { + let cond_constant = &inputs[cond.idx()].constant; + let if_reachability = &inputs[prod.idx()].reachability; + let if_constant = &inputs[prod.idx()].constant; + + let new_reachability = if cond_constant.is_top() { + ReachabilityLattice::top() + } else if let ConstantLattice::Constant(cons) = cond_constant { + if let Constant::Boolean(val) = cons { + if *val && *index == 0 { + // If condition is true and this is the false + // branch, then unreachable. + ReachabilityLattice::top() + } else if !val && *index == 1 { + // If condition is true and this is the true branch, + // then unreachable. + ReachabilityLattice::top() + } else { + if_reachability.clone() + } + } else { + panic!("Attempted to interpret ReadProd node, where corresponding if node has a non-boolean constant input. Did typechecking succeed?") + } + } else { + if_reachability.clone() + }; + + CCPLattice { + reachability: new_reachability, + constant: if_constant.clone(), + } + } + Node::Match { control: _, sum } => { + let sum_constant = &inputs[sum.idx()].constant; + let if_reachability = &inputs[prod.idx()].reachability; + let if_constant = &inputs[prod.idx()].constant; + + let new_reachability = if sum_constant.is_top() { + ReachabilityLattice::top() + } else if let ConstantLattice::Constant(cons) = sum_constant { + if let Constant::Summation(_, variant, _) = cons { + if *variant as usize != *index { + // If match variant is not the same as this branch, + // then unreachable. + ReachabilityLattice::top() + } else { + if_reachability.clone() + } + } else { + panic!("Attempted to interpret ReadProd node, where corresponding match node has a non-summation constant input. Did typechecking succeed?") + } + } else { + if_reachability.clone() + }; + + CCPLattice { + reachability: new_reachability, + constant: if_constant.clone(), + } + } + _ => { + let CCPLattice { + ref reachability, + ref constant, + } = inputs[prod.idx()]; + + let new_constant = if let ConstantLattice::Constant(cons) = constant { + let new_cons = if let Constant::Product(_, fields) = cons { + // Index into product constant to get result constant. + old_constants[fields[*index].idx()].clone() + } else { + panic!("Attempted to interpret ReadProd on non-product constant. Did typechecking succeed?") + }; + ConstantLattice::Constant(new_cons) + } else { + constant.clone() + }; + + CCPLattice { + reachability: reachability.clone(), + constant: new_constant, + } + } + }, + // WriteProd is uninterpreted for now. + Node::WriteProd { + prod, + data, + index: _, + } => CCPLattice { + reachability: ReachabilityLattice::meet( + &inputs[prod.idx()].reachability, + &inputs[data.idx()].reachability, + ), + constant: ConstantLattice::bottom(), + }, + Node::ReadArray { array, index } => { + let CCPLattice { + reachability: ref array_reachability, + constant: ref array_constant, + } = inputs[array.idx()]; + let CCPLattice { + reachability: ref index_reachability, + constant: ref index_constant, + } = inputs[index.idx()]; + + let new_constant = if let ( + ConstantLattice::Constant(array_cons), + ConstantLattice::Constant(index_cons), + ) = (array_constant, index_constant) + { + let new_cons = match (array_cons, index_cons) { + (Constant::Array(_, elems), Constant::UnsignedInteger8(idx)) => { + elems[*idx as usize] + } + (Constant::Array(_, elems), Constant::UnsignedInteger16(idx)) => { + elems[*idx as usize] + } + (Constant::Array(_, elems), Constant::UnsignedInteger32(idx)) => { + elems[*idx as usize] + } + (Constant::Array(_, elems), Constant::UnsignedInteger64(idx)) => { + elems[*idx as usize] + } + _ => panic!("Unsupported inputs to ReadArray node. Did typechecking succeed?"), + }; + ConstantLattice::Constant(old_constants[new_cons.idx()].clone()) + } else if (array_constant.is_top() && !index_constant.is_bottom()) + || (!array_constant.is_bottom() && index_constant.is_top()) + { + ConstantLattice::top() + } else { + ConstantLattice::meet(array_constant, index_constant) + }; + + CCPLattice { + reachability: ReachabilityLattice::meet(array_reachability, index_reachability), + constant: new_constant, + } + } + // WriteArray is uninterpreted for now. + Node::WriteArray { array, data, index } => CCPLattice { + reachability: ReachabilityLattice::meet( + &ReachabilityLattice::meet( + &inputs[array.idx()].reachability, + &inputs[data.idx()].reachability, + ), + &inputs[index.idx()].reachability, + ), + constant: ConstantLattice::bottom(), + }, + Node::Match { control, sum: _ } => inputs[control.idx()].clone(), + _ => CCPLattice::bottom(), + } +} diff --git a/hercules_ir/src/dataflow.rs b/hercules_ir/src/dataflow.rs index 92886d3b..0cb95e31 100644 --- a/hercules_ir/src/dataflow.rs +++ b/hercules_ir/src/dataflow.rs @@ -33,45 +33,50 @@ where L: Semilattice, F: FnMut(&[&L], NodeID) -> L, { - // Step 1: compute NodeUses for each node in function. - let uses: Vec<NodeUses> = function.nodes.iter().map(|n| get_uses(n)).collect(); + forward_dataflow_global(function, reverse_postorder, |global_outs, node_id| { + let uses = get_uses(&function.nodes[node_id.idx()]); + let pred_outs: Vec<_> = uses + .as_ref() + .iter() + .map(|id| &global_outs[id.idx()]) + .collect(); + flow_function(&pred_outs, node_id) + }) +} - // Step 2: create initial set of "out" points. - let start_node_output = flow_function(&[&L::bottom()], NodeID::new(0)); +/* + * The previous forward dataflow routine wraps around this dataflow routine, + * where the flow function doesn't just have access to this nodes input lattice + * values, but also all the current lattice values for all the nodes. This is + * useful for some dataflow analyses, such as reachability. The "global" in + * forward_dataflow_global refers to having a global view of the out lattice + * values. + */ +pub fn forward_dataflow_global<L, F>( + function: &Function, + reverse_postorder: &Vec<NodeID>, + mut flow_function: F, +) -> Vec<L> +where + L: Semilattice, + F: FnMut(&[L], NodeID) -> L, +{ + // Step 1: create initial set of "out" points. + let start_node_output = flow_function(&[], NodeID::new(0)); + let mut first_ins = vec![L::top(); function.nodes.len()]; + first_ins[0] = start_node_output; let mut outs: Vec<L> = (0..function.nodes.len()) - .map(|id| { - flow_function( - &vec![ - &(if id == 0 { - start_node_output.clone() - } else { - L::top() - }); - uses[id].as_ref().len() - ], - NodeID::new(id), - ) - }) + .map(|id| flow_function(&first_ins, NodeID::new(id))) .collect(); - // Step 3: peform main dataflow loop. + // Step 2: peform main dataflow loop. loop { let mut change = false; // Iterate nodes in reverse post order. for node_id in reverse_postorder { - // Assemble the "out" values of the predecessors of this node. This - // vector's definition is hopefully LICMed out, so that we don't do - // an allocation per node. This can't be done manually because of - // Rust's ownership rules (in particular, pred_outs holds a - // reference to a value inside outs, which is mutated below). - let mut pred_outs = vec![]; - for u in uses[node_id.idx()].as_ref() { - pred_outs.push(&outs[u.idx()]); - } - // Compute new "out" value from predecessor "out" values. - let new_out = flow_function(&pred_outs[..], *node_id); + let new_out = flow_function(&outs, *node_id); if outs[node_id.idx()] != new_out { change = true; } @@ -87,7 +92,7 @@ where } } - // Step 4: return "out" set. + // Step 3: return "out" set. outs } @@ -166,8 +171,7 @@ impl Semilattice for IntersectNodeSet { ); IntersectNodeSet::Bits(a.clone() & b) } - (IntersectNodeSet::Empty, _) => IntersectNodeSet::Empty, - (_, IntersectNodeSet::Empty) => IntersectNodeSet::Empty, + _ => IntersectNodeSet::Empty, } } @@ -215,8 +219,7 @@ impl Semilattice for UnionNodeSet { ); UnionNodeSet::Bits(a.clone() | b) } - (UnionNodeSet::Full, _) => UnionNodeSet::Full, - (_, UnionNodeSet::Full) => UnionNodeSet::Full, + _ => UnionNodeSet::Full, } } @@ -251,10 +254,9 @@ pub fn control_output_flow( function: &Function, ) -> UnionNodeSet { // Step 1: union inputs. - let mut out = UnionNodeSet::top(); - for input in inputs { - out = UnionNodeSet::meet(&out, input); - } + let mut out = inputs + .into_iter() + .fold(UnionNodeSet::top(), |a, b| UnionNodeSet::meet(&a, b)); let node = &function.nodes[node_id.idx()]; // Step 2: clear all bits, if applicable. diff --git a/hercules_ir/src/dce.rs b/hercules_ir/src/dce.rs new file mode 100644 index 00000000..1f56d864 --- /dev/null +++ b/hercules_ir/src/dce.rs @@ -0,0 +1,45 @@ +use crate::*; + +/* + * Top level function to run dead code elimination. Deletes nodes by setting + * nodes to gravestones. Works with a function already containing gravestones. + */ +pub fn dce(function: &mut Function) { + // Step 1: count number of users for each node. + let mut num_users = vec![0; function.nodes.len()]; + for (idx, node) in function.nodes.iter().enumerate() { + for u in get_uses(node).as_ref() { + num_users[u.idx()] += 1; + } + + // Return nodes shouldn't be considered dead code, so create a "phantom" + // user. + if node.is_return() { + num_users[idx] += 1; + } + } + + // Step 2: worklist over zero user nodes. + + // Worklist starts as list of all nodes with 0 users. + let mut worklist: Vec<_> = num_users + .iter() + .enumerate() + .filter(|(_, num_users)| **num_users == 0) + .map(|(idx, _)| idx) + .collect(); + while let Some(work) = worklist.pop() { + // Use start node as gravestone node value. + let mut gravestone = Node::Start; + std::mem::swap(&mut function.nodes[work], &mut gravestone); + + // Now that we set the gravestone, figure out other nodes that need to + // be added to the worklist. + for u in get_uses(&gravestone).as_ref() { + num_users[u.idx()] -= 1; + if num_users[u.idx()] == 0 { + worklist.push(u.idx()); + } + } + } +} diff --git a/hercules_ir/src/def_use.rs b/hercules_ir/src/def_use.rs index ea9ab07f..fbd75201 100644 --- a/hercules_ir/src/def_use.rs +++ b/hercules_ir/src/def_use.rs @@ -86,6 +86,19 @@ pub enum NodeUses<'a> { Phi(Box<[NodeID]>), } +/* + * Enum for storing mutable uses of node. Using get_uses_mut, one can easily + * modify the defs that a node uses. + */ +#[derive(Debug)] +pub enum NodeUsesMut<'a> { + Zero, + One([&'a mut NodeID; 1]), + Two([&'a mut NodeID; 2]), + Three([&'a mut NodeID; 3]), + Variable(Box<[&'a mut NodeID]>), +} + impl<'a> AsRef<[NodeID]> for NodeUses<'a> { fn as_ref(&self) -> &[NodeID] { match self { @@ -99,6 +112,18 @@ impl<'a> AsRef<[NodeID]> for NodeUses<'a> { } } +impl<'a> AsMut<[&'a mut NodeID]> for NodeUsesMut<'a> { + fn as_mut(&mut self) -> &mut [&'a mut NodeID] { + match self { + NodeUsesMut::Zero => &mut [], + NodeUsesMut::One(x) => x, + NodeUsesMut::Two(x) => x, + NodeUsesMut::Three(x) => x, + NodeUsesMut::Variable(x) => x, + } + } +} + /* * Construct NodeUses for a Node. */ @@ -144,3 +169,51 @@ pub fn get_uses<'a>(node: &'a Node) -> NodeUses<'a> { Node::ExtractSum { data, variant: _ } => NodeUses::One([*data]), } } + +/* + * Construct NodeUsesMut for a node. Note, this is not a one-to-one mutable + * analog of NodeUses. In particular, constant, dynamic constant, and parameter + * nodes all implicitly take as input the start node. However, this is not + * stored (it is an implict use), and thus can't be modified. Thus, those uses + * are not represented in NodeUsesMut, but are in NodeUses. + */ +pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> { + match node { + Node::Start => NodeUsesMut::Zero, + Node::Region { preds } => NodeUsesMut::Variable(preds.iter_mut().collect()), + Node::If { control, cond } => NodeUsesMut::Two([control, cond]), + Node::Fork { control, factor: _ } => NodeUsesMut::One([control]), + Node::Join { control } => NodeUsesMut::One([control]), + Node::Phi { control, data } => { + NodeUsesMut::Variable(std::iter::once(control).chain(data.iter_mut()).collect()) + } + Node::ThreadID { control } => NodeUsesMut::One([control]), + Node::Collect { control, data } => NodeUsesMut::Two([control, data]), + Node::Return { control, data } => NodeUsesMut::Two([control, data]), + Node::Parameter { index: _ } => NodeUsesMut::Zero, + Node::Constant { id: _ } => NodeUsesMut::Zero, + Node::DynamicConstant { id: _ } => NodeUsesMut::Zero, + Node::Unary { input, op: _ } => NodeUsesMut::One([input]), + Node::Binary { left, right, op: _ } => NodeUsesMut::Two([left, right]), + Node::Call { + function: _, + dynamic_constants: _, + args, + } => NodeUsesMut::Variable(args.iter_mut().collect()), + Node::ReadProd { prod, index: _ } => NodeUsesMut::One([prod]), + Node::WriteProd { + prod, + data, + index: _, + } => NodeUsesMut::Two([prod, data]), + Node::ReadArray { array, index } => NodeUsesMut::Two([array, index]), + Node::WriteArray { array, data, index } => NodeUsesMut::Three([array, data, index]), + Node::Match { control, sum } => NodeUsesMut::Two([control, sum]), + Node::BuildSum { + data, + sum_ty: _, + variant: _, + } => NodeUsesMut::One([data]), + Node::ExtractSum { data, variant: _ } => NodeUsesMut::One([data]), + } +} diff --git a/hercules_ir/src/dom.rs b/hercules_ir/src/dom.rs index d359db65..fc5fc398 100644 --- a/hercules_ir/src/dom.rs +++ b/hercules_ir/src/dom.rs @@ -46,6 +46,18 @@ impl DomTree { pub fn is_non_root(&self, x: NodeID) -> bool { self.idom.contains_key(&x) } + + /* + * Typically, node ID 0 is the root of the dom tree. Under this assumption, + * this function checks if a node is in the dom tree. + */ + pub fn contains_conventional(&self, x: NodeID) -> bool { + x == NodeID::new(0) || self.idom.contains_key(&x) + } + + pub fn get_underlying_map(&self) -> &HashMap<NodeID, NodeID> { + &self.idom + } } /* diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs deleted file mode 100644 index 559dcfc4..00000000 --- a/hercules_ir/src/dot.rs +++ /dev/null @@ -1,338 +0,0 @@ -use crate::*; - -use std::collections::HashMap; - -pub fn write_dot<W: std::fmt::Write>(module: &Module, w: &mut W) -> std::fmt::Result { - write!(w, "digraph \"Module\" {{\n")?; - write!(w, "compound=true\n")?; - for i in 0..module.functions.len() { - write_function(i, module, w)?; - } - write!(w, "}}\n")?; - Ok(()) -} - -fn write_function<W: std::fmt::Write>(i: usize, module: &Module, w: &mut W) -> std::fmt::Result { - write!(w, "subgraph {} {{\n", module.functions[i].name)?; - if module.functions[i].num_dynamic_constants > 0 { - write!( - w, - "label=\"{}<{}>\"\n", - module.functions[i].name, module.functions[i].num_dynamic_constants - )?; - } else { - write!(w, "label=\"{}\"\n", module.functions[i].name)?; - } - write!(w, "bgcolor=ivory4\n")?; - write!(w, "cluster=true\n")?; - let mut visited = HashMap::default(); - let function = &module.functions[i]; - for j in 0..function.nodes.len() { - visited = write_node(i, j, module, visited, w)?.1; - } - write!(w, "}}\n")?; - Ok(()) -} - -fn write_node<W: std::fmt::Write>( - i: usize, - j: usize, - module: &Module, - mut visited: HashMap<NodeID, String>, - w: &mut W, -) -> Result<(String, HashMap<NodeID, String>), std::fmt::Error> { - let id = NodeID::new(j); - if visited.contains_key(&id) { - Ok((visited.get(&id).unwrap().clone(), visited)) - } else { - let node = &module.functions[i].nodes[j]; - let name = format!("{}_{}_{}", node.lower_case_name(), i, j); - visited.insert(NodeID::new(j), name.clone()); - let visited = match node { - Node::Start => { - write!(w, "{} [xlabel={}, label=\"start\"];\n", name, j)?; - visited - } - Node::Region { preds } => { - write!(w, "{} [xlabel={}, label=\"region\"];\n", name, j)?; - for (idx, pred) in preds.iter().enumerate() { - let (pred_name, tmp_visited) = write_node(i, pred.idx(), module, visited, w)?; - visited = tmp_visited; - write!( - w, - "{} -> {} [label=\"pred {}\", style=\"dashed\"];\n", - pred_name, name, idx - )?; - } - visited - } - Node::If { control, cond } => { - write!(w, "{} [xlabel={}, label=\"if\"];\n", name, j)?; - let (control_name, visited) = write_node(i, control.idx(), module, visited, w)?; - let (cond_name, visited) = write_node(i, cond.idx(), module, visited, w)?; - write!( - w, - "{} -> {} [label=\"control\", style=\"dashed\"];\n", - control_name, name - )?; - write!(w, "{} -> {} [label=\"cond\"];\n", cond_name, name)?; - visited - } - Node::Fork { control, factor } => { - write!( - w, - "{} [xlabel={}, label=\"fork<{:?}>\"];\n", - name, - j, - module.dynamic_constants[factor.idx()] - )?; - let (control_name, visited) = write_node(i, control.idx(), module, visited, w)?; - write!( - w, - "{} -> {} [label=\"control\", style=\"dashed\"];\n", - control_name, name - )?; - visited - } - Node::Join { control } => { - write!(w, "{} [xlabel={}, label=\"join\"];\n", name, j)?; - let (control_name, visited) = write_node(i, control.idx(), module, visited, w)?; - write!( - w, - "{} -> {} [label=\"control\", style=\"dashed\"];\n", - control_name, name - )?; - visited - } - Node::Phi { control, data } => { - write!(w, "{} [xlabel={}, label=\"phi\"];\n", name, j)?; - let (control_name, mut visited) = write_node(i, control.idx(), module, visited, w)?; - write!( - w, - "{} -> {} [label=\"control\", style=\"dashed\"];\n", - control_name, name - )?; - for (idx, data) in data.iter().enumerate() { - let (data_name, tmp_visited) = write_node(i, data.idx(), module, visited, w)?; - visited = tmp_visited; - write!(w, "{} -> {} [label=\"data {}\"];\n", data_name, name, idx)?; - } - visited - } - Node::ThreadID { control } => { - write!(w, "{} [xlabel={}, label=\"thread_id\"];\n", name, j)?; - let (control_name, visited) = write_node(i, control.idx(), module, visited, w)?; - write!( - w, - "{} -> {} [label=\"control\", style=\"dashed\"];\n", - control_name, name - )?; - visited - } - Node::Collect { control, data } => { - let (control_name, visited) = write_node(i, control.idx(), module, visited, w)?; - let (data_name, visited) = write_node(i, data.idx(), module, visited, w)?; - write!(w, "{} [xlabel={}, label=\"collect\"];\n", name, j)?; - write!( - w, - "{} -> {} [label=\"control\", style=\"dashed\"];\n", - control_name, name - )?; - write!(w, "{} -> {} [label=\"data\"];\n", data_name, name)?; - visited - } - Node::Return { control, data } => { - let (control_name, visited) = write_node(i, control.idx(), module, visited, w)?; - let (data_name, visited) = write_node(i, data.idx(), module, visited, w)?; - write!(w, "{} [xlabel={}, label=\"return\"];\n", name, j)?; - write!( - w, - "{} -> {} [label=\"control\", style=\"dashed\"];\n", - control_name, name - )?; - write!(w, "{} -> {} [label=\"data\"];\n", data_name, name)?; - visited - } - Node::Parameter { index } => { - write!( - w, - "{} [xlabel={}, label=\"param #{}\"];\n", - name, - j, - index + 1 - )?; - write!( - w, - "start_{}_0 -> {} [label=\"start\", style=\"dashed\"];\n", - i, name - )?; - visited - } - Node::Constant { id } => { - write!( - w, - "{} [xlabel={}, label=\"{:?}\"];\n", - name, - j, - module.constants[id.idx()] - )?; - write!( - w, - "start_{}_0 -> {} [label=\"start\", style=\"dashed\"];\n", - i, name - )?; - visited - } - Node::DynamicConstant { id } => { - write!( - w, - "{} [xlabel={}, label=\"dynamic_constant({:?})\"];\n", - name, - j, - module.dynamic_constants[id.idx()] - )?; - write!( - w, - "start_{}_0 -> {} [label=\"start\", style=\"dashed\"];\n", - i, name - )?; - visited - } - Node::Unary { input, op } => { - write!( - w, - "{} [xlabel={}, label=\"{}\"];\n", - name, - j, - op.lower_case_name() - )?; - let (input_name, visited) = write_node(i, input.idx(), module, visited, w)?; - write!(w, "{} -> {} [label=\"input\"];\n", input_name, name)?; - visited - } - Node::Binary { left, right, op } => { - write!( - w, - "{} [xlabel={}, label=\"{}\"];\n", - name, - j, - op.lower_case_name() - )?; - let (left_name, visited) = write_node(i, left.idx(), module, visited, w)?; - let (right_name, visited) = write_node(i, right.idx(), module, visited, w)?; - write!(w, "{} -> {} [label=\"left\"];\n", left_name, name)?; - write!(w, "{} -> {} [label=\"right\"];\n", right_name, name)?; - visited - } - Node::Call { - function, - dynamic_constants, - args, - } => { - write!(w, "{} [xlabel={}, label=\"call<", name, j)?; - for (idx, id) in dynamic_constants.iter().enumerate() { - let dc = &module.dynamic_constants[id.idx()]; - if idx == 0 { - write!(w, "{:?}", dc)?; - } else { - write!(w, ", {:?}", dc)?; - } - } - write!(w, ">({})\"];\n", module.functions[function.idx()].name)?; - for (idx, arg) in args.iter().enumerate() { - let (arg_name, tmp_visited) = write_node(i, arg.idx(), module, visited, w)?; - visited = tmp_visited; - write!(w, "{} -> {} [label=\"arg {}\"];\n", arg_name, name, idx)?; - } - write!( - w, - "{} -> start_{}_0 [label=\"call\", lhead={}];\n", - name, - function.idx(), - module.functions[function.idx()].name - )?; - visited - } - Node::ReadProd { prod, index } => { - write!( - w, - "{} [xlabel={}, label=\"read_prod({})\"];\n", - name, j, index - )?; - let (prod_name, visited) = write_node(i, prod.idx(), module, visited, w)?; - write!(w, "{} -> {} [label=\"prod\"];\n", prod_name, name)?; - visited - } - Node::WriteProd { prod, data, index } => { - write!( - w, - "{} [xlabel={}, label=\"write_prod({})\"];\n", - name, j, index - )?; - let (prod_name, visited) = write_node(i, prod.idx(), module, visited, w)?; - let (data_name, visited) = write_node(i, data.idx(), module, visited, w)?; - write!(w, "{} -> {} [label=\"prod\"];\n", prod_name, name)?; - write!(w, "{} -> {} [label=\"data\"];\n", data_name, name)?; - visited - } - Node::ReadArray { array, index } => { - write!(w, "{} [xlabel={}, label=\"read_array\"];\n", name, j)?; - let (array_name, visited) = write_node(i, array.idx(), module, visited, w)?; - write!(w, "{} -> {} [label=\"array\"];\n", array_name, name)?; - let (index_name, visited) = write_node(i, index.idx(), module, visited, w)?; - write!(w, "{} -> {} [label=\"index\"];\n", index_name, name)?; - visited - } - Node::WriteArray { array, data, index } => { - write!(w, "{} [xlabel={}, label=\"write_array\"];\n", name, j)?; - let (array_name, visited) = write_node(i, array.idx(), module, visited, w)?; - write!(w, "{} -> {} [label=\"array\"];\n", array_name, name)?; - let (data_name, visited) = write_node(i, data.idx(), module, visited, w)?; - write!(w, "{} -> {} [label=\"data\"];\n", data_name, name)?; - let (index_name, visited) = write_node(i, index.idx(), module, visited, w)?; - write!(w, "{} -> {} [label=\"index\"];\n", index_name, name)?; - visited - } - Node::Match { control, sum } => { - write!(w, "{} [xlabel={}, label=\"match\"];\n", name, j)?; - let (control_name, visited) = write_node(i, control.idx(), module, visited, w)?; - write!( - w, - "{} -> {} [label=\"control\", style=\"dashed\"];\n", - control_name, name - )?; - let (sum_name, visited) = write_node(i, sum.idx(), module, visited, w)?; - write!(w, "{} -> {} [label=\"sum\"];\n", sum_name, name)?; - visited - } - Node::BuildSum { - data, - sum_ty, - variant, - } => { - write!( - w, - "{} [xlabel={}, label=\"build_sum({:?}, {})\"];\n", - name, - j, - module.types[sum_ty.idx()], - variant - )?; - let (data_name, visited) = write_node(i, data.idx(), module, visited, w)?; - write!(w, "{} -> {} [label=\"data\"];\n", data_name, name)?; - visited - } - Node::ExtractSum { data, variant } => { - write!( - w, - "{} [xlabel={}, label=\"extract_sum({})\"];\n", - name, j, variant - )?; - let (data_name, visited) = write_node(i, data.idx(), module, visited, w)?; - write!(w, "{} -> {} [label=\"data\"];\n", data_name, name)?; - visited - } - }; - Ok((visited.get(&id).unwrap().clone(), visited)) - } -} diff --git a/hercules_ir/src/gvn.rs b/hercules_ir/src/gvn.rs new file mode 100644 index 00000000..c8f77244 --- /dev/null +++ b/hercules_ir/src/gvn.rs @@ -0,0 +1,93 @@ +use std::collections::HashMap; + +use crate::*; + +/* + * Top level function to run global value numbering. In the sea of nodes, GVN is + * fairly simple compared to in a normal CFG. Needs access to constants for + * identity function simplification. + */ +pub fn gvn(function: &mut Function, constants: &Vec<Constant>, def_use: &ImmutableDefUseMap) { + // Step 1: create worklist (starts as all nodes) and value number hashmap. + let mut worklist: Vec<_> = (0..function.nodes.len()).rev().map(NodeID::new).collect(); + let mut value_numbers: HashMap<Node, NodeID> = HashMap::new(); + + // Step 2: do worklist. + while let Some(work) = worklist.pop() { + // First, iteratively simplify the work node by unwrapping identity + // functions. + let value = crawl_identities(work, function, constants); + + // Next, check if there is a value number for this simplified value yet. + if let Some(leader) = value_numbers.get(&function.nodes[value.idx()]) { + // Also need to check that leader is not the current work ID. The + // leader should never remove itself. + if *leader != work { + // If there is a value number (a previously found Node ID) for the + // current node, then replace all users' uses of the current work + // node ID with the value number node ID. + for user in def_use.get_users(work) { + for u in get_uses_mut(&mut function.nodes[user.idx()]).as_mut() { + if **u == work { + **u = *leader; + } + } + + // Since we modified user, it may now be congruent to other + // nodes, so add it back into the worklist. + worklist.push(*user); + } + + // Since all ex-users now use the value number node ID, delete this + // node. + function.nodes[work.idx()] = Node::Start; + + // Explicitly continue to branch away from adding current work + // as leader into value_numbers. + continue; + } + } + // If not found, insert the simplified node with its node ID as the + // value number. + value_numbers.insert(function.nodes[value.idx()].clone(), value); + } +} + +/* + * Helper function for unwrapping identity functions. + */ +fn crawl_identities(mut work: NodeID, function: &Function, constants: &Vec<Constant>) -> NodeID { + loop { + // TODO: replace with API for saner pattern matching on IR. Also, + // actually add the rest of the identity functions. + if let Node::Binary { + left, + right, + op: BinaryOperator::Add, + } = function.nodes[work.idx()] + { + if let Node::Constant { id } = function.nodes[left.idx()] { + if constants[id.idx()].is_zero() { + work = right; + continue; + } + } + } + + if let Node::Binary { + left, + right, + op: BinaryOperator::Add, + } = function.nodes[work.idx()] + { + if let Node::Constant { id } = function.nodes[right.idx()] { + if constants[id.idx()].is_zero() { + work = left; + continue; + } + } + } + + return work; + } +} diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index dee8fb9e..e237240f 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -1,12 +1,16 @@ extern crate ordered_float; +use std::fmt::Write; + +use crate::*; + /* * A module is a list of functions. Functions contain types, constants, and * dynamic constants, which are interned at the module level. Thus, if one * wants to run an intraprocedural pass in parallel, it is advised to first * destruct the module, then reconstruct it once finished. */ -#[derive(Debug, Clone)] +#[derive(Debug, Default, Clone)] pub struct Module { pub functions: Vec<Function>, pub types: Vec<Type>, @@ -14,6 +18,159 @@ pub struct Module { pub dynamic_constants: Vec<DynamicConstant>, } +impl Module { + /* + * There are many transformations that need to iterate over the functions + * in a module, while having mutable access to the interned types, + * constants, and dynamic constants in a module. This code is really ugly, + * so write it once. + */ + pub fn map<F>(self, mut func: F) -> Self + where + F: FnMut( + (Function, FunctionID), + (Vec<Type>, Vec<Constant>, Vec<DynamicConstant>), + ) -> (Function, (Vec<Type>, Vec<Constant>, Vec<DynamicConstant>)), + { + let Module { + functions, + types, + constants, + dynamic_constants, + } = self; + let mut stuff = (types, constants, dynamic_constants); + let functions = functions + .into_iter() + .enumerate() + .map(|(idx, function)| { + let mut new_stuff = (vec![], vec![], vec![]); + std::mem::swap(&mut stuff, &mut new_stuff); + let (function, mut new_stuff) = func((function, FunctionID::new(idx)), new_stuff); + std::mem::swap(&mut stuff, &mut new_stuff); + function + }) + .collect(); + let (types, constants, dynamic_constants) = stuff; + Module { + functions, + types, + constants, + dynamic_constants, + } + } + + /* + * Printing out types, constants, and dynamic constants fully requires a + * reference to the module, since references to other types, constants, and + * dynamic constants are done using IDs. + */ + pub fn write_type<W: Write>(&self, ty_id: TypeID, w: &mut W) -> std::fmt::Result { + match &self.types[ty_id.idx()] { + Type::Control(_) => write!(w, "Control"), + Type::Boolean => write!(w, "Boolean"), + Type::Integer8 => write!(w, "Integer8"), + Type::Integer16 => write!(w, "Integer16"), + Type::Integer32 => write!(w, "Integer32"), + Type::Integer64 => write!(w, "Integer64"), + Type::UnsignedInteger8 => write!(w, "UnsignedInteger8"), + Type::UnsignedInteger16 => write!(w, "UnsignedInteger16"), + Type::UnsignedInteger32 => write!(w, "UnsignedInteger32"), + Type::UnsignedInteger64 => write!(w, "UnsignedInteger64"), + Type::Float32 => write!(w, "Float32"), + Type::Float64 => write!(w, "Float64"), + Type::Product(fields) => { + write!(w, "Product(")?; + for idx in 0..fields.len() { + let field_ty_id = fields[idx]; + self.write_type(field_ty_id, w)?; + if idx + 1 < fields.len() { + write!(w, ", ")?; + } + } + write!(w, ")") + } + Type::Summation(fields) => { + write!(w, "Summation(")?; + for idx in 0..fields.len() { + let field_ty_id = fields[idx]; + self.write_type(field_ty_id, w)?; + if idx + 1 < fields.len() { + write!(w, ", ")?; + } + } + write!(w, ")") + } + Type::Array(elem, length) => { + write!(w, "Array(")?; + self.write_type(*elem, w)?; + write!(w, ", ")?; + self.write_dynamic_constant(*length, w)?; + write!(w, ")") + } + }?; + + Ok(()) + } + + pub fn write_constant<W: Write>(&self, cons_id: ConstantID, w: &mut W) -> std::fmt::Result { + match &self.constants[cons_id.idx()] { + Constant::Boolean(val) => write!(w, "{}", val), + Constant::Integer8(val) => write!(w, "{}", val), + Constant::Integer16(val) => write!(w, "{}", val), + Constant::Integer32(val) => write!(w, "{}", val), + Constant::Integer64(val) => write!(w, "{}", val), + Constant::UnsignedInteger8(val) => write!(w, "{}", val), + Constant::UnsignedInteger16(val) => write!(w, "{}", val), + Constant::UnsignedInteger32(val) => write!(w, "{}", val), + Constant::UnsignedInteger64(val) => write!(w, "{}", val), + Constant::Float32(val) => write!(w, "{}", val), + Constant::Float64(val) => write!(w, "{}", val), + Constant::Product(_, fields) => { + write!(w, "(")?; + for idx in 0..fields.len() { + let field_cons_id = fields[idx]; + self.write_constant(field_cons_id, w)?; + if idx + 1 < fields.len() { + write!(w, ", ")?; + } + } + write!(w, ")") + } + Constant::Summation(_, variant, field) => { + write!(w, "%{}(", variant)?; + self.write_constant(*field, w)?; + write!(w, ")") + } + Constant::Array(_, elems) => { + write!(w, "[")?; + for idx in 0..elems.len() { + let elem_cons_id = elems[idx]; + self.write_constant(elem_cons_id, w)?; + if idx + 1 < elems.len() { + write!(w, ", ")?; + } + } + write!(w, "]") + } + }?; + + Ok(()) + } + + pub fn write_dynamic_constant<W: Write>( + &self, + dc_id: DynamicConstantID, + w: &mut W, + ) -> std::fmt::Result { + match &self.dynamic_constants[dc_id.idx()] { + DynamicConstant::Constant(cons) => write!(w, "{}", cons), + DynamicConstant::Parameter(param) => write!(w, "#{}", param), + }?; + + Ok(()) + } +} + /* * A function has a name, a list of types for its parameters, a single return * type, a list of nodes in its sea-of-nodes style IR, and a number of dynamic @@ -31,6 +188,61 @@ pub struct Function { pub num_dynamic_constants: u32, } +impl Function { + /* + * Many transformations will delete nodes. There isn't strictly a gravestone + * node value, so use the start node as a gravestone value (for IDs other + * than 0). This function cleans up gravestoned nodes. + */ + pub fn delete_gravestones(&mut self) { + // Step 1: figure out which nodes are gravestones. + let mut gravestones = (0..self.nodes.len()) + .filter(|x| *x != 0 && self.nodes[*x].is_start()) + .map(|x| NodeID::new(x)); + + // Step 2: figure out the mapping between old node IDs and new node IDs. + let mut node_mapping = Vec::with_capacity(self.nodes.len()); + let mut next_gravestone = gravestones.next(); + let mut num_gravestones_passed = 0; + for idx in 0..self.nodes.len() { + if Some(NodeID::new(idx)) == next_gravestone { + node_mapping.push(NodeID::new(0)); + num_gravestones_passed += 1; + next_gravestone = gravestones.next(); + } else { + node_mapping.push(NodeID::new(idx - num_gravestones_passed)); + } + } + + // Step 3: create new nodes vector. Along the way, update all uses. + let mut old_nodes = vec![]; + std::mem::swap(&mut old_nodes, &mut self.nodes); + + let mut new_nodes = Vec::with_capacity(old_nodes.len() - num_gravestones_passed); + for (idx, mut node) in old_nodes.into_iter().enumerate() { + // Skip node if it's dead. + if idx != 0 && node.is_start() { + continue; + } + + // Update uses. + for u in get_uses_mut(&mut node).as_mut() { + let old_id = **u; + let new_id = node_mapping[old_id.idx()]; + if new_id == NodeID::new(0) && old_id != NodeID::new(0) { + panic!("While deleting gravestones, came across a use of a gravestoned node."); + } + **u = new_id; + } + + // Add to new_nodes. + new_nodes.push(node); + } + + std::mem::swap(&mut new_nodes, &mut self.nodes); + } +} + /* * Hercules IR has a fairly standard type system, with the exception of the * control type. Hercules IR is based off of the sea-of-nodes IR, the main @@ -143,6 +355,43 @@ pub enum Constant { Array(TypeID, Box<[ConstantID]>), } +impl Constant { + /* + * Useful for GVN. + */ + pub fn is_zero(&self) -> bool { + match self { + Constant::Integer8(0) => true, + Constant::Integer16(0) => true, + Constant::Integer32(0) => true, + Constant::Integer64(0) => true, + Constant::UnsignedInteger8(0) => true, + Constant::UnsignedInteger16(0) => true, + Constant::UnsignedInteger32(0) => true, + Constant::UnsignedInteger64(0) => true, + Constant::Float32(ord) => *ord == ordered_float::OrderedFloat::<f32>(0.0), + Constant::Float64(ord) => *ord == ordered_float::OrderedFloat::<f64>(0.0), + _ => false, + } + } + + pub fn is_one(&self) -> bool { + match self { + Constant::Integer8(1) => true, + Constant::Integer16(1) => true, + Constant::Integer32(1) => true, + Constant::Integer64(1) => true, + Constant::UnsignedInteger8(1) => true, + Constant::UnsignedInteger16(1) => true, + Constant::UnsignedInteger32(1) => true, + Constant::UnsignedInteger64(1) => true, + Constant::Float32(ord) => *ord == ordered_float::OrderedFloat::<f32>(1.0), + Constant::Float64(ord) => *ord == ordered_float::OrderedFloat::<f64>(1.0), + _ => false, + } + } +} + /* * Dynamic constants are unsigned 64-bit integers passed to a Hercules function * at runtime using the Hercules conductor API. They cannot be the result of @@ -170,7 +419,7 @@ pub enum DynamicConstant { * side effects, so call nodes don't take as input or output control tokens. * There is also no global memory - use arrays. */ -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Node { Start, Region { diff --git a/hercules_ir/src/lib.rs b/hercules_ir/src/lib.rs index 094873f2..943d0493 100644 --- a/hercules_ir/src/lib.rs +++ b/hercules_ir/src/lib.rs @@ -1,17 +1,23 @@ +pub mod build; +pub mod ccp; pub mod dataflow; +pub mod dce; pub mod def_use; pub mod dom; -pub mod dot; +pub mod gvn; pub mod ir; pub mod parse; pub mod subgraph; pub mod typecheck; pub mod verify; +pub use crate::build::*; +pub use crate::ccp::*; pub use crate::dataflow::*; +pub use crate::dce::*; pub use crate::def_use::*; pub use crate::dom::*; -pub use crate::dot::*; +pub use crate::gvn::*; pub use crate::ir::*; pub use crate::parse::*; pub use crate::subgraph::*; diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index 8b666be0..2373332e 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -6,6 +6,13 @@ use std::str::FromStr; use crate::*; +/* + * TODO: This parsing code was written before the generic build API was created. + * As a result, this parsing code duplicates much of the interning logic the + * build API is meant to abstract away. The parsing code should be re-written to + * use the new build API. + */ + /* * Top level parse function. */ diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index 74f82fdf..09f4cbb7 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -175,6 +175,21 @@ fn typeflow( } }; + // We need to make sure dynamic constant parameters reference valid dynamic + // constant parameters of the current function. This involves traversing a + // given dynamic constant expression to determine that all referenced + // parameter dynamic constants are valid. + fn check_dynamic_constants( + root: DynamicConstantID, + dynamic_constants: &Vec<DynamicConstant>, + num_parameters: u32, + ) -> bool { + match dynamic_constants[root.idx()] { + DynamicConstant::Parameter(idx) => idx < num_parameters as usize, + _ => true, + } + } + // Each node requires different type logic. This unfortunately results in a // large match statement. Oh well. Each arm returns the lattice value for // the "out" type of the node. @@ -250,14 +265,16 @@ fn typeflow( inputs[0].clone() } - Node::Fork { - control: _, - factor: _, - } => { + Node::Fork { control: _, factor } => { if inputs.len() != 1 { return Error(String::from("Fork node must have exactly one input.")); } + if !check_dynamic_constants(*factor, dynamic_constants, function.num_dynamic_constants) + { + return Error(String::from("Referenced parameter dynamic constant is not a valid dynamic constant parameter for the current function.")); + } + if let Concrete(id) = inputs[0] { if let Type::Control(factors) = &types[id.idx()] { // Fork adds a new factor to the thread spawn factor list. @@ -559,11 +576,15 @@ fn typeflow( } } } - Node::DynamicConstant { id: _ } => { + Node::DynamicConstant { id } => { if inputs.len() != 1 { return Error(String::from("DynamicConstant node must have one input.")); } + if !check_dynamic_constants(*id, dynamic_constants, function.num_dynamic_constants) { + return Error(String::from("Referenced parameter dynamic constant is not a valid dynamic constant parameter for the current function.")); + } + // Dynamic constants are always u64. Concrete(get_type_id( Type::UnsignedInteger64, @@ -586,6 +607,11 @@ fn typeflow( } } UnaryOperator::Neg => { + if types[id.idx()].is_unsigned() { + return Error(String::from( + "Neg unary node input cannot have unsigned type.", + )); + } if !types[id.idx()].is_arithmetic() { return Error(String::from( "Neg unary node input cannot have non-arithmetic type.", @@ -654,11 +680,15 @@ fn typeflow( // Equality operators potentially change the input type. return Concrete(get_type_id(Type::Boolean, types, reverse_type_map)); } - BinaryOperator::Or - | BinaryOperator::And - | BinaryOperator::Xor - | BinaryOperator::LSh - | BinaryOperator::RSh => { + BinaryOperator::Or | BinaryOperator::And | BinaryOperator::Xor => { + if !types[id.idx()].is_fixed() && !types[id.idx()].is_bool() { + return Error(format!( + "{:?} binary node input cannot have non-fixed type and non-boolean type.", + op, + )); + } + } + BinaryOperator::LSh | BinaryOperator::RSh => { if !types[id.idx()].is_fixed() { return Error(format!( "{:?} binary node input cannot have non-fixed type.", @@ -696,6 +726,16 @@ fn typeflow( )); } + for dc_id in dc_args.iter() { + if !check_dynamic_constants( + *dc_id, + dynamic_constants, + function.num_dynamic_constants, + ) { + return Error(String::from("Referenced parameter dynamic constant is not a valid dynamic constant parameter for the current function.")); + } + } + // Check argument types. for (input, param_ty) in zip(inputs.iter(), callee.param_types.iter()) { if let Concrete(input_id) = input { diff --git a/hercules_ir/src/verify.rs b/hercules_ir/src/verify.rs index 70399669..4c9db60b 100644 --- a/hercules_ir/src/verify.rs +++ b/hercules_ir/src/verify.rs @@ -65,7 +65,7 @@ pub fn verify( let fork_join_map = &fork_join_maps[idx]; // Calculate control output dependencies here, since they are not - // returned by verify. Pretty much only useful for verification. + // returned by verify. let control_output_dependencies = forward_dataflow(function, reverse_postorder, |inputs, id| { control_output_flow(inputs, id, function) @@ -352,16 +352,12 @@ fn verify_dominance_relationships( // If this node is a phi node, we need to handle adding dominance checks // completely differently. if let Node::Phi { control, data } = &function.nodes[idx] { - // Get the control predecessors of a region. This weird lambda trick - // is to get around needing to add another nesting level just to - // unpack the predecessor node. - let region_preds = (|| { - if let Node::Region { preds } = &function.nodes[control.idx()] { - preds - } else { - panic!("A phi's control input must be a region node.") - } - })(); + // Get the control predecessors of a region. + let region_preds = if let Node::Region { preds } = &function.nodes[control.idx()] { + preds + } else { + panic!("A phi's control input must be a region node.") + }; // The inputs to a phi node don't need to dominate the phi node. // However, the data inputs to a phi node do need to hold proper @@ -396,7 +392,7 @@ fn verify_dominance_relationships( // If the node to be added to the to_check vector isn't even in the // dominator tree, don't bother. It doesn't need to be checked for // dominance relations. - if !dom.is_non_root(this_id) { + if !dom.contains_conventional(this_id) { continue; } @@ -423,7 +419,7 @@ fn verify_dominance_relationships( // Verify that uses of phis / collect nodes are dominated // by the corresponding region / join nodes, respectively. Node::Phi { control, data: _ } | Node::Collect { control, data: _ } => { - if dom.is_non_root(this_id) && !dom.does_dom(control, this_id) { + if dom.contains_conventional(this_id) && !dom.does_dom(control, this_id) { Err(format!( "{} node (ID {}) doesn't dominate its use (ID {}).", function.nodes[pred_idx].upper_case_name(), @@ -435,7 +431,7 @@ fn verify_dominance_relationships( // Verify that uses of thread ID nodes are dominated by the // corresponding fork nodes. Node::ThreadID { control } => { - if dom.is_non_root(this_id) && !dom.does_dom(control, this_id) { + if dom.contains_conventional(this_id) && !dom.does_dom(control, this_id) { Err(format!( "ThreadID node (ID {}) doesn't dominate its use (ID {}).", pred_idx, @@ -449,7 +445,7 @@ fn verify_dominance_relationships( // flows through the collect node out of the fork-join, // because after the collect, the thread ID is no longer // considered an immediate control output use. - if postdom.is_non_root(this_id) + if postdom.contains_conventional(this_id) && !postdom.does_dom(*fork_join_map.get(&control).unwrap(), this_id) { Err(format!("ThreadID node's (ID {}) fork's join doesn't postdominate its use (ID {}).", pred_idx, this_id.idx()))?; diff --git a/hercules_tools/Cargo.toml b/hercules_tools/Cargo.toml index 260db105..458de0e0 100644 --- a/hercules_tools/Cargo.toml +++ b/hercules_tools/Cargo.toml @@ -10,3 +10,4 @@ path = "src/hercules_dot/main.rs" [dependencies] clap = { version = "*", features = ["derive"] } hercules_ir = { path = "../hercules_ir" } +rand = "*" diff --git a/hercules_tools/src/hercules_dot/dot.rs b/hercules_tools/src/hercules_dot/dot.rs new file mode 100644 index 00000000..6f41f85b --- /dev/null +++ b/hercules_tools/src/hercules_dot/dot.rs @@ -0,0 +1,227 @@ +extern crate hercules_ir; + +use std::collections::HashMap; +use std::fmt::Write; + +use self::hercules_ir::*; + +/* + * Top level function to write a module out as a dot graph. Takes references to + * many analysis results to generate a more informative dot graph. + */ +pub fn write_dot<W: Write>( + module: &ir::Module, + typing: &ModuleTyping, + doms: &Vec<DomTree>, + fork_join_maps: &Vec<HashMap<NodeID, NodeID>>, + w: &mut W, +) -> std::fmt::Result { + write_digraph_header(w)?; + + for function_id in (0..module.functions.len()).map(FunctionID::new) { + let function = &module.functions[function_id.idx()]; + write_subgraph_header(function_id, module, w)?; + + // Step 1: draw IR graph itself. This includes all IR nodes and all edges + // between IR nodes. + for node_id in (0..function.nodes.len()).map(NodeID::new) { + let node = &function.nodes[node_id.idx()]; + let dst_ty = &module.types[typing[function_id.idx()][node_id.idx()].idx()]; + let dst_strictly_control = node.is_strictly_control(); + let dst_control = dst_ty.is_control() || dst_strictly_control; + + // Control nodes are dark red, data nodes are dark blue. + let color = if dst_control { "darkred" } else { "darkblue" }; + + write_node(node_id, function_id, color, module, w)?; + + for u in def_use::get_uses(&node).as_ref() { + let src_ty = &module.types[typing[function_id.idx()][u.idx()].idx()]; + let src_strictly_control = function.nodes[u.idx()].is_strictly_control(); + let src_control = src_ty.is_control() || src_strictly_control; + + // An edge between control nodes is dashed. An edge between data + // nodes is filled. An edge between a control node and a data + // node is dotted. + let style = if dst_control && src_control { + "dashed" + } else if !dst_control && !src_control { + "" + } else { + "dotted" + }; + + write_edge( + node_id, + function_id, + *u, + function_id, + "black", + style, + module, + w, + )?; + } + } + + // Step 2: draw dominance edges in dark green. Don't draw post dominance + // edges because then xdot lays out the graph strangely. + let dom = &doms[function_id.idx()]; + for (child_id, parent_id) in dom.get_underlying_map() { + write_edge( + *child_id, + function_id, + *parent_id, + function_id, + "darkgreen", + "dotted", + &module, + w, + )?; + } + + // Step 3: draw fork join edges in dark magenta. + let fork_join_map = &fork_join_maps[function_id.idx()]; + for (fork_id, join_id) in fork_join_map { + write_edge( + *join_id, + function_id, + *fork_id, + function_id, + "darkmagenta", + "dotted", + &module, + w, + )?; + } + + write_graph_footer(w)?; + } + + write_graph_footer(w)?; + Ok(()) +} + +fn write_digraph_header<W: Write>(w: &mut W) -> std::fmt::Result { + write!(w, "digraph \"Module\" {{\n")?; + write!(w, "compound=true\n")?; + Ok(()) +} + +fn write_subgraph_header<W: Write>( + function_id: FunctionID, + module: &Module, + w: &mut W, +) -> std::fmt::Result { + let function = &module.functions[function_id.idx()]; + write!(w, "subgraph {} {{\n", function.name)?; + + // Write number of dynamic constants in brackets. + if function.num_dynamic_constants > 0 { + write!( + w, + "label=\"{}<{}>\"\n", + function.name, function.num_dynamic_constants + )?; + } else { + write!(w, "label=\"{}\"\n", function.name)?; + } + write!(w, "bgcolor=ivory4\n")?; + write!(w, "cluster=true\n")?; + Ok(()) +} + +fn write_graph_footer<W: Write>(w: &mut W) -> std::fmt::Result { + write!(w, "}}\n")?; + Ok(()) +} + +fn write_node<W: Write>( + node_id: NodeID, + function_id: FunctionID, + color: &str, + module: &Module, + w: &mut W, +) -> std::fmt::Result { + let node = &module.functions[function_id.idx()].nodes[node_id.idx()]; + + // Some nodes have additional information that need to get written after the + // node label. + let mut suffix = String::new(); + match node { + Node::Fork { control: _, factor } => module.write_dynamic_constant(*factor, &mut suffix)?, + Node::Parameter { index } => write!(&mut suffix, "#{}", index)?, + Node::Constant { id } => module.write_constant(*id, &mut suffix)?, + Node::DynamicConstant { id } => module.write_dynamic_constant(*id, &mut suffix)?, + Node::Call { + function, + dynamic_constants, + args: _, + } => { + write!(&mut suffix, "{}", module.functions[function.idx()].name)?; + for dc_id in dynamic_constants.iter() { + write!(&mut suffix, ", ")?; + module.write_dynamic_constant(*dc_id, &mut suffix)?; + } + } + Node::ReadProd { prod: _, index } => write!(&mut suffix, "{}", index)?, + Node::WriteProd { + prod: _, + data: _, + index, + } => write!(&mut suffix, "{}", index)?, + Node::BuildSum { + data: _, + sum_ty: _, + variant, + } => write!(&mut suffix, "{}", variant)?, + Node::ExtractSum { data: _, variant } => write!(&mut suffix, "{}", variant)?, + _ => {} + }; + + // If this is a node with additional information, add that to the node + // label. + let label = if suffix.is_empty() { + node.lower_case_name().to_owned() + } else { + format!("{} ({})", node.lower_case_name(), suffix) + }; + write!( + w, + "{}_{}_{} [xlabel={}, label=\"{}\", color={}];\n", + node.lower_case_name(), + function_id.idx(), + node_id.idx(), + node_id.idx(), + label, + color + )?; + Ok(()) +} + +fn write_edge<W: Write>( + dst_node_id: NodeID, + dst_function_id: FunctionID, + src_node_id: NodeID, + src_function_id: FunctionID, + color: &str, + style: &str, + module: &Module, + w: &mut W, +) -> std::fmt::Result { + let dst_node = &module.functions[dst_function_id.idx()].nodes[dst_node_id.idx()]; + let src_node = &module.functions[src_function_id.idx()].nodes[src_node_id.idx()]; + write!( + w, + "{}_{}_{} -> {}_{}_{} [color={}, style=\"{}\"];\n", + src_node.lower_case_name(), + src_function_id.idx(), + src_node_id.idx(), + dst_node.lower_case_name(), + dst_function_id.idx(), + dst_node_id.idx(), + color, + style, + )?; + Ok(()) +} diff --git a/hercules_tools/src/hercules_dot/main.rs b/hercules_tools/src/hercules_dot/main.rs index 16fc20f2..bb3efe57 100644 --- a/hercules_tools/src/hercules_dot/main.rs +++ b/hercules_tools/src/hercules_dot/main.rs @@ -1,4 +1,5 @@ extern crate clap; +extern crate rand; use std::env::temp_dir; use std::fs::File; @@ -7,6 +8,11 @@ use std::process::Command; use clap::Parser; +use rand::Rng; + +pub mod dot; +use dot::*; + #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { @@ -28,14 +34,41 @@ fn main() { .expect("PANIC: Unable to read input file contents."); let mut module = hercules_ir::parse::parse(&contents).expect("PANIC: Failed to parse Hercules IR file."); - let _ = hercules_ir::verify::verify(&mut module) - .expect("PANIC: Failed to verify Hercules IR module."); + let (def_uses, reverse_postorders, _typing, _doms, _postdoms, _fork_join_maps) = + hercules_ir::verify::verify(&mut module) + .expect("PANIC: Failed to verify Hercules IR module."); + + let mut module = module.map( + |(mut function, id), (types, mut constants, dynamic_constants)| { + hercules_ir::ccp::ccp( + &mut function, + &mut constants, + &def_uses[id.idx()], + &reverse_postorders[id.idx()], + ); + hercules_ir::dce::dce(&mut function); + function.delete_gravestones(); + + let def_use = hercules_ir::def_use::def_use(&function); + hercules_ir::gvn::gvn(&mut function, &constants, &def_use); + hercules_ir::dce::dce(&mut function); + function.delete_gravestones(); + + (function, (types, constants, dynamic_constants)) + }, + ); + let (_def_use, _reverse_postorders, typing, doms, _postdoms, fork_join_maps) = + hercules_ir::verify::verify(&mut module) + .expect("PANIC: Failed to verify Hercules IR module."); + if args.output.is_empty() { let mut tmp_path = temp_dir(); - tmp_path.push("hercules_dot.dot"); + let mut rng = rand::thread_rng(); + let num: u64 = rng.gen(); + tmp_path.push(format!("hercules_dot_{}.dot", num)); let mut file = File::create(tmp_path.clone()).expect("PANIC: Unable to open output file."); let mut contents = String::new(); - hercules_ir::dot::write_dot(&module, &mut contents) + write_dot(&module, &typing, &doms, &fork_join_maps, &mut contents) .expect("PANIC: Unable to generate output file contents."); file.write_all(contents.as_bytes()) .expect("PANIC: Unable to write output file contents."); @@ -46,7 +79,7 @@ fn main() { } else { let mut file = File::create(args.output).expect("PANIC: Unable to open output file."); let mut contents = String::new(); - hercules_ir::dot::write_dot(&module, &mut contents) + write_dot(&module, &typing, &doms, &fork_join_maps, &mut contents) .expect("PANIC: Unable to generate output file contents."); file.write_all(contents.as_bytes()) .expect("PANIC: Unable to write output file contents."); diff --git a/samples/ccp_example.hir b/samples/ccp_example.hir new file mode 100644 index 00000000..618a7573 --- /dev/null +++ b/samples/ccp_example.hir @@ -0,0 +1,19 @@ +fn tricky(x: i32) -> i32 + one = constant(i32, 1) + two = constant(i32, 2) + loop = region(start, if2_true) + idx = phi(loop, x, idx_dec) + val = phi(loop, one, later_val) + b = ne(one, val) + if1 = if(loop, b) + if1_false = read_prod(if1, 0) + if1_true = read_prod(if1, 1) + middle = region(if1_false, if1_true) + inter_val = sub(two, val) + later_val = phi(middle, inter_val, two) + idx_dec = sub(idx, one) + cond = gte(idx_dec, one) + if2 = if(middle, cond) + if2_false = read_prod(if2, 0) + if2_true = read_prod(if2, 1) + r = return(if2_false, later_val) \ No newline at end of file diff --git a/samples/gvn_example.hir b/samples/gvn_example.hir new file mode 100644 index 00000000..b5e3c8aa --- /dev/null +++ b/samples/gvn_example.hir @@ -0,0 +1,8 @@ +fn tricky(x: i32, y: i32) -> i32 + zero = constant(i32, 0) + xx = add(x, zero) + a = add(xx, y) + b = add(x, y) + bb = add(zero, b) + c = add(a, bb) + r = return(start, c) diff --git a/samples/invalid/bad_phi2.hir b/samples/invalid/bad_phi2.hir new file mode 100644 index 00000000..c03628cb --- /dev/null +++ b/samples/invalid/bad_phi2.hir @@ -0,0 +1,18 @@ +fn tricky(x: i32) -> i32 + one = constant(i32, 1) + two = constant(i32, 2) + loop = region(start, if2_true) + idx = phi(loop, x, idx_dec) + val = phi(loop, later_val, one) + b = ne(one, val) + if1 = if(loop, b) + if1_false = read_prod(if1, 0) + if1_true = read_prod(if1, 1) + middle = region(if1_false, if1_true) + later_val = phi(middle, val, two) + idx_dec = sub(idx, one) + cond = gte(idx_dec, one) + if2 = if(middle, cond) + if2_false = read_prod(if2, 0) + if2_true = read_prod(if2, 1) + r = return(if2_false, later_val) -- GitLab