diff --git a/DESIGN.md b/DESIGN.md index cd51e630fe4316741a31778e104c747529a6f294..c34c3023305a1625e524d65f5e5369a12c23d1be 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -8,6 +8,8 @@ Hercules' is a compiler targeting heterogenous devices. The key goals of Hercule - Design an intermediate representation that allows for fine-grained control of what code is executed on what device in a system. - Develop a runtime system capable of dynamically scheduling generated code fragments on a heterogenous machine. +The following sections contain information on how Hercules is designed to meet these goals. + ## Front-end Language Design TODO: @aaronjc4 @@ -24,15 +26,21 @@ The Hercules' compiler is split into the following components: The IR of the Hercules compiler is similar to the sea of nodes IR presented in "A Simple Graph-Based Intermediate Representation", with a few differences. -- There are dynamic constants, which are constants provided dynamically to the runtime system - these can be used to specify array type sizes, unlike normal runtime values. +- There are dynamic constants, which are constants provided dynamically to the conductor (this is the runtime system, [see the section describing the conductor](#the-conductor)) - these can be used to specify array type sizes, unlike normal runtime values. - There is no single global store. The closest analog are individual values with an array type, which support dynamically indexed read and write operations. - There is no I/O, or other side effects. - There is no recursion. -- The implementation of Hercules IR does not follow the original object oriented design. +- The implementation of Hercules IR does not follow the original object oriented design of sea-of-nodes. A key design consideration of Hercules IR is the absence of a concept of memory. A downside of this approach is that any language targetting Hecules IR must also be very restrictive regarding memory - in practice, this means tightly controlling or eliminating first-class references. The upside is that the compiler has complete freedom to layout data however it likes in memory when performing code generation. This includes deciding which data resides in which address spaces, which is a necessary ability for a compiler striving to have fine-grained control over what operations are computed on what devices. -In addition to not having a generalized memory, Hercules IR has no functionality for calling functions with side-effects, or doing IO. In other words, Hercules is a pure IR (it's not functional, as functions aren't first class values). This may be changed in the future - we could support effectful programs by giving call operators a control input and output edge. However, at least for now, we need to work with the simplest IR possible, so the IR is pure. +In addition to not having a generalized memory, Hercules IR has no functionality for calling functions with side-effects, or doing IO. In other words, Hercules is a pure IR (it's not functional, as functions aren't first class values). This may be changed in the future - we could support effectful programs by giving call operators a control input and output edge. However, at least for now, we'd like to work with the simplest IR possible, so the IR is pure. + +The key idea behind the sea of nodes IR is that control flow and data flow are represented in the same graph. The entire program thus can be represented by one large flow graph. This has several nice properties, the primary of which being that instructions are unordered except by true dependencies. This alleviates most code motion concerns, and also makes peephole optimizations more practical. Additionally, loop invariant code is neither "inside" nor "outside" a loop in the sea of nodes. Thus, any optimizations benefitting from a particular assumption about the position of loop invariant code works without needing to do code motion. Deciding whether code lives inside a loop or not becomes a scheduling concern. + +We chose to use a sea of nodes based IR because we believe it will be easier to partition than a CFG + basic block style IR. A CFG + basic block IR is inherently two-level - there is the control flow level in the CFG, and the data flow in the basic blocks. Partitioning a function across these two levels is a challenging task. As shown by previous work (HPVM), introducing more graph levels into the IR makes partitioning harder, not easier. We want Hercules to have fine-grained control over which code executes where. This requires Hercules' compiler IR to have as few graph levels as reasonable. + +See [IR.md](IR.md) for a more specific description of Hercules IR. ### Optimizations @@ -42,7 +50,7 @@ TODO: @rarbore2 ### Partitioning -Partitioning is responsible for deciding which operations in the IR graph are executed on which devices. Additionally, operations are broken up into shards - every node in a shard executes on the same device, and the runtime system schedules execution at the shard level. Partitioning is conceptually very similar to instruction selection. Each shard can be thought of as a single instruction, and the device the shard is executed on can be thought of as the particular instruction being selected. In instruction selection, there is not only the choice of which instructions to use, but also how to partition the potentially many operations in the IR into a smaller number of target instructions. Similarly, partitioning Hercules IR must decide which operations are grouped together into the same shard, and for each shard, which device it should execute on. The set of operations each potential target device is capable of executing is crucial information when forming the shard boundaries, so this cannot be performed optimally as a sequential two step process. +Partitioning is responsible for deciding which operations in the IR graph are executed on which devices. Additionally, operations are broken up into shards - every node in a shard executes on the same device, and the runtime system schedules execution at the shard level. Partitioning is conceptually very similar to instruction selection. Each shard can be thought of as a single instruction, and the device the shard is executed on can be thought of as the particular instruction being selected. In instruction selection, there is not only the choice of which instructions to use, but also how to partition the potentially many operations in the IR into a smaller number of target instructions. Similarly, the Hercules IR partitioning process must decide which operations are grouped together into the same shard, and for each shard, which device it should execute on. The set of operations each potential target device is capable of executing is crucial information when forming the shard boundaries, so this cannot be performed optimally as a sequential two step process. TODO: @rarbore2 @@ -52,8 +60,8 @@ Hercules uses LLVM for generating CPU and GPU code. Memory is "introduced" into TODO: @rarbore2 -## Runtime System +## The Conductor -The runtime system is responsible for dynamically executing code generated by Hercules. It exposes a Rust API for executing Hercules code. It takes care of memory allocation, synchronization, and scheduling. +The conductor is responsible for dynamically executing code generated by Hercules. It exposes a Rust API for executing Hercules code. It takes care of memory allocation, synchronization, and scheduling. It is what is called the "runtime" in other systems - we chose a different name as there are events that happen distinctly as "conductor time" (such as providing dynamic constants), rather than at "runtime" (where the generated code is actually executed). TODO: @rarbore2 diff --git a/IR.md b/IR.md new file mode 100644 index 0000000000000000000000000000000000000000..57c6332075f9c8b366c6b794f8cd47849fb08836 --- /dev/null +++ b/IR.md @@ -0,0 +1,99 @@ +# Hercules IR + +Hercules IR is structured as following: +- One entire program lives in one "Module". +- Each module contains a set of functions, as well as interned types, constants, and dynamic constants. The most important element of a module is its resident functions. +- Each function consists of a name, a set of types for its parameters, a return type, a list of nodes, and the number of dynamic constants it takes as argument. Types are not needed for dynamic constants, since all dynamic constants have type u64. The most important element of a function is its node list. +- There are control and data types. The control type is parameterized by a list of thread replication factors. The primitive data types are boolean, signed integers, unsigned integers, and floating point numbers. The integer types can hold 8, 16, 32, or 64 bits. The floating point types can hold 32 or 64 bits. The compound types are product, summation, and arrays. A product type is a tuple, containing some number of children data types. A summation type is a union, containing exactly one of some number of children data types at runtime. An array is a dynamically indexable collection of elements, where each element is the same type. The size of the array is part of the type, and is represented with a dynamic constant. +- Dynamic constants are constants provided to the conductor when a Hercules IR program is started. Through this mechanism, Hercules IR can represent programs operating on a variable number of array elements, while forbidding runtime dynamic memory allocation (all dynamic memory allocation happens in the conductor). +- The nodes in a function are structured as a flow graph, which an explicit start node. Although control and data flow from definitions to uses, def-use edges are stored implicitly in the IR. Each node stores its predecessor nodes, so use-def edges are stored explicitly. To query the def-use edges in an IR graph, use the `def_use` function. + +Below, all of the nodes in Hercules IR are described. + +## Start + +The start node of the IR flow graph. This node is implicitly defined in the text format. It takes no inputs. Its output type is the empty control type (control with no thread replication factors). + +## Region + +Region nodes are the mechanism for merging multiple branches inside Hercules IR. A region node takes at least one input - each input must have a control type, and all of the inputs must have the same control type. The output type of the region node is the same control type as all of its inputs. The main purpose of a region node is to drive some number of [phi](#phi) nodes. + +## If + +The branch mechanism in Hercules IR. An if node takes two inputs - a control predecessor, and a condition. The control predecessor must have control type, and the condition must have boolean type. The output type is a product of two control types, which are the same as the control input's type. Every if node must be followed directly by two [read\_prod](#readprod) nodes, each of which reads differing elements of the if node's output product. This is the mechanism by which the output edges from the if node (and also the [match](#match) node) are labelled, even though nodes only explicitly store their input edges. + +## Fork + +Fork (and [join](#join)) nodes are the mechanism for representing data-parallelism inside Hercules IR. A fork node takes one input - a control predecessor. A fork node also stores a thread replication factor (TRF), represented as a dynamic constant. The output type of a fork node is a control type, which is the same as the type of the control predecessor, with the TRF pushed to the end of the control type's factor list. Conceptually, for every thread that comes in to a fork node, TRF threads come out. A fork node can drive any number of children [thread\_id](#threadid) nodes. Each fork must have a single corresponding [join](#join) node - the fork must dominate the join node, and the join node must post-dominate the fork node (in the control flow subgraph). + +## Join + +Join (and [fork](#fork)) nodes are the mechanism for synchronizing data-parallel threads inside Hercules IR. A join nodes takes one input - a control predecessor. The output type of a join node is a control type, which is the same as the type of the control predecessor, with the last factor in the control type's list removed. Conceptually, after all threads created by the corresponding fork reach the join, then and only then does the join output a single thread. A join node can drive any number of children [collect](#collect) nodes. Each join must have a single corresponding [fork](#fork) node - the join must post-dominate the fork node, and the fork node must dominate the join node (in the control flow subgraph). + +## Phi + +Phi nodes merge potentially many data sources into one data output, driven by a corresponding region node. Phi nodes in Hercules IR perform the same function as phi nodes in other SSA-based IRs. Phi nodes take at least one input - a control predecessor, and some number of data inputs. The control predecessor of a phi node must be a region node. The data inputs must all have the same type. The output of the phi node has that data type. In the sea of nodes execution model, a phi node can be thought of as "latching" when its corresponding region node is reached. The phi node will latch to output the value of the input corresponding to the input that control traversed to reach the region node. After latching, the phi node's output won't change until the region node is reached again. + +## ThreadID + +The thread\_id node provides the thread ID as a datum to children nodes after a [fork](#fork) has been performed. A thread\_id node takes one input - a control predecessor. The control predecessor must be a [fork](#fork) node. The output type is a 64-bit unsigned integer. The output thread IDs generated by a thread\_id node range from 0 to TRF - 1, inclusive, where TRF is the thread replication factor of the input [fork](#fork) node. + +## Collect + +The collect node collects data from multiple executing threads, and puts them all into an array. A collect node takes two inputs - a control predecessor, and a data input. The control predecessor must be a [join](#join) node. The data input must have a non-control type. The output type will be an array, where the element type will be the type of the data input. The extent of the array will be equal to the thread replication factor of the [fork](#fork) node corresponding to the input [join](#join) node. For each datum input, the thread ID corresponding to that datum will be the index the datum is inserted into the array. + +## Return + +The return node returns some data from the current function. A return node has two inputs - a control predecessor, and a data input. The control predecessor must have a control type with an empty factor list - just as only one thread starts the execution of a function, only one thread can return from a function. The data input must have the same type as the function's return type. No node should use a return node as input (technically, the output type of a return node is an empty product type). + +## Parameter + +The parameter node represents a parameter of the function. A parameter node takes no inputs. A parameter node stores the parameter index of the function it corresponds to. Its value at runtime is the index-th argument to the function. Its output type is the type of the index-th parameter of the function. + +## Constant + +The constant node represents a constant value. A constant node takes no inputs. A constant node stores the constant ID of the constant it corresponds to. Its value at runtime is the constant it references. Its output type is the type of the constant it references. + +## DynamicConstant + +The dynamic\_constant node represents a dynamic constant, used as a runtime value. A dynamic\_constant node takes no inputs. A dynamic\_constant node stores the dynamic constant ID of the dynamic constant it corresponds to. Its value at runtime is the value of the dynamic constant it references, which is calculated at conductor time. Its output type is a 64-bit unsigned integer. + +## Unary + +The unary node represents a basic unary operation. A unary node takes one input - a data input. The data input must have a non-control type. A unary node additionally stores which unary operation it performs. The output type of the unary node is the same as its input type. The acceptable input data type depends on the unary operation. + +## Binary + +The binary node represents a basic binary operation. A binary node takes two inputs - a left data input, and a right data input. The left and right data inputs must be the same non-control type. A binary node additionally stores the binary operation it performs. The output type of the binary node is the same as its input type. The acceptable input data type depends on the binary operation. + +## Call + +The call node passes its inputs to a function, and outputs the result of the function call. A call node takes some number of data inputs. A call node also stores a reference to the function it calls. The number and types of the data inputs must match the referenced function. A call node also stores references to dynamic constants it uses as inputs to the function. The number of dynamic constants references must match the number of dynamic constant inputs of the referenced function. The output type of a call node is the return type of the referenced function. A call node notably does not take as input or output a control type. This is because all operations in Hercules IR are pure, including arbitrary function calls. Thus, the only things affecting a function call are the data inputs, and (conceptually) the function may be called an arbitrary amount of times. + +## ReadProd + +The read\_prod node reads an element from an product typed value. A read\_prod node takes one data input. A read\_prod node also stores the index into the product it reads. The type of the data input must be a product type. The index must be a valid index into the product type. The output type of a read\_prod node is the type of the index-th element in the product (0-indexed). + +## WriteProd + +The write\_prod node modifies an input product with an input datum, and outputs the new product. A write\_prod node takes two inputs - one product input, and one data input. A write\_prod node also stores the index into the product it writes. The type of the product input must be a product type. The type of the data input must be the same as the index-th element in the product (0-indexed). The output type of a write\_prod node is the same as the product input type. + +## ReadArray + +The read\_array node reads an element from an array typed value. A read\_array node takes two inputs - one array input, and one index input. The type of the array input must be an array type. The type of the index input must be an integer type. The output type of a read\_array node is the element type of the array input's array type. At runtime, if an out-of-bounds array access occurs, the conductor will eventually notify the host. + +## WriteArray + +The write\_array node modifies an input array with an input datum. A write\_array node takes three inputs - one array input, one data input, and one index input. The type of the array input must be an array type. The type of the data input must be the same as the element type of the array input's array type. The type of the index input must be an integer type. The output type of a write\_array node is the same as the array input's array type. At runtime, if an out-of-bounds array access occurs, the conductor will eventually notify the host. + +## Match + +The match node branches based on the variant of a sum typed value. A match node takes two inputs - a control predecessor, and a sum input. The control predecessor must have control type, and the sum input must have a sum type. The output type is a product of N control types, where N is the number of possible variants in the sum input's sum type. The control types in the product are the same as the control input's type. Every match node must be followed directly by N [read\_prod](#readprod) nodes, each of which reads differing elements of the match node's output product. This is the mechanism by which the output edges from the match node (and also the [if](#if) node) are labelled, even though nodes only explicitly store their input edges. + +## BuildSum + +The build\_sum node creates a sum typed value from a datum. A build\_sum node takes one input - a data input. A build\_sum node additionally stores the sum type it builds, as well as which variant of the aforementioned sum type it builds. The stored variant must be a valid variant inside the stored sum type. The type of the data input must match the type of the variant of the sum type. The output type of a build\_sum node is the aforementioned sum type. + +## ExtractSum + +The extract\_sum node extracts the concrete value inside a sum value, given a particular variant to extract. An extract\_sum node takes one input - a data input. The data input must have a sum type. An extract\_sum node also stored the variant it extracts. The stored variant must be a valid variant of the data input's sum type. The output type of an extract\_sum node is the type of the specified variant of the data input's sum type. At runtime, if the input sum value holds the stored variant, the output of an extract\_sum node is the value inside that variant in the sum value. If the input sum value holds a different variant, the output of an extract\_sum node is defined as the bit-pattern of all zeros for the output type of the extract\_sum node. diff --git a/hercules_ir/src/dataflow.rs b/hercules_ir/src/dataflow.rs index 9dfa68670dad8460a2dc8b2dfe1508ee28d66277..69a00b65118aa357df93cfc833e8f738f28df500 100644 --- a/hercules_ir/src/dataflow.rs +++ b/hercules_ir/src/dataflow.rs @@ -31,7 +31,7 @@ pub fn forward_dataflow<L, F>( ) -> Vec<L> where L: Semilattice, - F: FnMut(&[&L], &Node) -> L, + F: FnMut(&[&L], NodeID) -> L, { // Step 1: compute NodeUses for each node in function. let uses: Vec<NodeUses> = function.nodes.iter().map(|n| get_uses(n)).collect(); @@ -41,7 +41,7 @@ where .map(|id| { flow_function( &vec![&(if id == 0 { L::bottom() } else { L::top() }); uses[id].as_ref().len()], - &function.nodes[id], + NodeID::new(id), ) }) .collect(); @@ -63,7 +63,7 @@ where } // Compute new "out" value from predecessor "out" values. - let new_out = flow_function(&pred_outs[..], &function.nodes[node_id.idx()]); + let new_out = flow_function(&pred_outs[..], *node_id); if outs[node_id.idx()] != new_out { change = true; } diff --git a/hercules_ir/src/def_use.rs b/hercules_ir/src/def_use.rs index 0e750ca59e7e0baa237c9e6c0b0eece6d3b5a811..1eb578fa067be176bb0a9acabd80dd782e1ff5a4 100644 --- a/hercules_ir/src/def_use.rs +++ b/hercules_ir/src/def_use.rs @@ -108,13 +108,15 @@ pub fn get_uses<'a>(node: &'a Node) -> NodeUses<'a> { Node::Region { preds } => NodeUses::Variable(preds), Node::If { control, cond } => NodeUses::Two([*control, *cond]), Node::Fork { control, factor: _ } => NodeUses::One([*control]), - Node::Join { control, data } => NodeUses::Two([*control, *data]), + Node::Join { control } => NodeUses::One([*control]), Node::Phi { control, data } => { let mut uses: Vec<NodeID> = Vec::from(&data[..]); uses.push(*control); NodeUses::Phi(uses.into_boxed_slice()) } - Node::Return { control, value } => NodeUses::Two([*control, *value]), + Node::ThreadID { control } => NodeUses::One([*control]), + Node::Collect { control, data } => NodeUses::Two([*control, *data]), + Node::Return { control, data } => NodeUses::Two([*control, *data]), Node::Parameter { index: _ } => NodeUses::Zero, Node::Constant { id: _ } => NodeUses::Zero, Node::DynamicConstant { id: _ } => NodeUses::Zero, diff --git a/hercules_ir/src/dom.rs b/hercules_ir/src/dom.rs new file mode 100644 index 0000000000000000000000000000000000000000..002cbd04a8b1d16d714bb72fe17a5e719e87aa79 --- /dev/null +++ b/hercules_ir/src/dom.rs @@ -0,0 +1,166 @@ +extern crate bitvec; + +use crate::*; + +use std::collections::HashMap; + +/* + * Custom type for storing a dominator tree. For each control node, store its + * immediate dominator. + */ +#[derive(Debug, Clone)] +pub struct DomTree { + idom: HashMap<NodeID, NodeID>, +} + +impl DomTree { + pub fn imm_dom(&self, x: NodeID) -> Option<NodeID> { + self.idom.get(&x).map(|x| x.clone()) + } + + pub fn does_imm_dom(&self, a: NodeID, b: NodeID) -> bool { + self.imm_dom(b) == Some(a) + } + + pub fn does_dom(&self, a: NodeID, b: NodeID) -> bool { + let mut iter = Some(b); + + // Go up dominator tree until finding a, or root of tree. + while let Some(b) = iter { + if b == a { + return true; + } + iter = self.imm_dom(b); + } + false + } + + pub fn does_prop_dom(&self, a: NodeID, b: NodeID) -> bool { + a != b && self.does_dom(a, b) + } +} + +/* + * Top level function for calculating dominator trees. Uses the semi-NCA + * algorithm, as described in "Finding Dominators in Practice". + */ +pub fn dominator(subgraph: &Subgraph, root: NodeID) -> DomTree { + // Step 1: compute pre-order DFS of subgraph. + let (preorder, mut parents) = preorder(&subgraph, root); + let mut node_numbers = HashMap::new(); + for (number, node) in preorder.iter().enumerate() { + node_numbers.insert(node, number); + } + parents.insert(root, root); + let mut idom = HashMap::new(); + for w in preorder[1..].iter() { + // Each idom starts as the parent node. + idom.insert(*w, parents[w]); + } + + // Step 2: define snca_compress, which will be used to compute semi- + // dominators, and initialize various variables. + let mut semi = vec![0; preorder.len()]; + let mut labels: Vec<_> = (0..preorder.len()).collect(); + let mut ancestors = vec![0; preorder.len()]; + fn snca_compress( + v_n: usize, + mut ancestors: Vec<usize>, + mut labels: Vec<usize>, + ) -> (Vec<usize>, Vec<usize>) { + let u_n = ancestors[v_n]; + + if u_n != 0 { + (ancestors, labels) = snca_compress(u_n, ancestors, labels); + if labels[u_n] < labels[v_n] { + labels[v_n] = labels[u_n]; + } + ancestors[v_n] = ancestors[u_n]; + } + + (ancestors, labels) + } + + // Step 3: compute semi-dominators. + for w_n in (1..preorder.len()).rev() { + semi[w_n] = w_n; + for v in subgraph.preds(preorder[w_n]) { + let v_n = node_numbers[&v]; + (ancestors, labels) = snca_compress(v_n, ancestors, labels); + semi[w_n] = std::cmp::min(semi[w_n], labels[v_n]); + } + labels[w_n] = semi[w_n]; + ancestors[w_n] = node_numbers[&parents[&preorder[w_n]]]; + } + + // Step 4: compute idom. + for v_n in 1..preorder.len() { + let v = preorder[v_n]; + while node_numbers[&idom[&v]] > semi[v_n] { + *idom.get_mut(&v).unwrap() = idom[&idom[&v]]; + } + } + + DomTree { idom } +} + +fn preorder(subgraph: &Subgraph, root: NodeID) -> (Vec<NodeID>, HashMap<NodeID, NodeID>) { + // Initialize order vector and visited hashmap for tracking which nodes have + // been visited. + let order = Vec::with_capacity(subgraph.num_nodes() as usize); + + // Explicitly keep track of parents in DFS tree. Doubles as a visited set. + let parents = HashMap::new(); + + // Order and parents are threaded through arguments / return pair of + // reverse_postorder_helper for ownership reasons. + preorder_helper(root, None, subgraph, order, parents) +} + +fn preorder_helper( + node: NodeID, + parent: Option<NodeID>, + subgraph: &Subgraph, + mut order: Vec<NodeID>, + mut parents: HashMap<NodeID, NodeID>, +) -> (Vec<NodeID>, HashMap<NodeID, NodeID>) { + assert!(subgraph.contains_node(node)); + if parents.contains_key(&node) { + // If already visited, return early. + (order, parents) + } else { + // Keep track of DFS parent for region nodes. + if let Some(parent) = parent { + // Only node where the above isn't true is the start node, which + // has no incoming edge. Thus, there's no need to insert the start + // node into the parents map for tracking visitation. + parents.insert(node, parent); + } + + // Before iterating users, push this node. + order.push(node); + + // Iterate over users. + for user in subgraph.succs(node) { + (order, parents) = preorder_helper(user, Some(node), subgraph, order, parents); + } + + (order, parents) + } +} + +/* + * Top level function for calculating post-dominator trees. Reverses the edges + * in the subgraph, and then runs normal dominator analysis. Takes an owned + * subgraph, since we need to reverse it. Also take a fake root node ID to + * insert in the reversed subgraph. This will be the root of the resulting + * dominator tree. + */ +pub fn postdominator(subgraph: Subgraph, fake_root: NodeID) -> DomTree { + // Step 1: reverse the subgraph. + let reversed_subgraph = subgraph.reverse(fake_root); + + // Step 2: run dominator analysis on the reversed subgraph. Use the fake + // root as the root of the dominator analysis. + dominator(&reversed_subgraph, fake_root) +} diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs index dd85cc9972f9217bcd8dbcafa13afcdd48f534ad..f149a50d482f407ee8cea8073275222bb20d1662 100644 --- a/hercules_ir/src/dot.rs +++ b/hercules_ir/src/dot.rs @@ -50,11 +50,11 @@ fn write_node<W: std::fmt::Write>( visited.insert(NodeID::new(j), name.clone()); let visited = match node { Node::Start => { - write!(w, "{} [label=\"start\"];\n", name)?; + write!(w, "{} [xlabel={}, label=\"start\"];\n", name, j)?; visited } Node::Region { preds } => { - write!(w, "{} [label=\"region\"];\n", name)?; + write!(w, "{} [xlabel={}, label=\"region\"];\n", name, j)?; for (idx, pred) in preds.iter().enumerate() { let (pred_name, tmp_visited) = write_node(i, pred.idx(), module, visited, w)?; visited = tmp_visited; @@ -67,7 +67,7 @@ fn write_node<W: std::fmt::Write>( visited } Node::If { control, cond } => { - write!(w, "{} [label=\"if\"];\n", name)?; + write!(w, "{} [xlabel={}, label=\"if\"];\n", name, j)?; let (control_name, visited) = write_node(i, control.idx(), module, visited, w)?; let (cond_name, visited) = write_node(i, cond.idx(), module, visited, w)?; write!( @@ -81,8 +81,9 @@ fn write_node<W: std::fmt::Write>( Node::Fork { control, factor } => { write!( w, - "{} [label=\"fork<{:?}>\"];\n", + "{} [xlabel={}, label=\"fork<{:?}>\"];\n", name, + j, module.dynamic_constants[factor.idx()] )?; let (control_name, visited) = write_node(i, control.idx(), module, visited, w)?; @@ -93,20 +94,18 @@ fn write_node<W: std::fmt::Write>( )?; visited } - Node::Join { control, data } => { - write!(w, "{} [label=\"join\"];\n", name,)?; + Node::Join { control } => { + write!(w, "{} [xlabel={}, label=\"join\"];\n", name, j)?; let (control_name, visited) = write_node(i, control.idx(), module, visited, w)?; - let (data_name, visited) = write_node(i, data.idx(), module, visited, w)?; write!( w, "{} -> {} [label=\"control\", style=\"dashed\"];\n", control_name, name )?; - write!(w, "{} -> {} [label=\"data\"];\n", data_name, name)?; visited } Node::Phi { control, data } => { - write!(w, "{} [label=\"phi\"];\n", name)?; + write!(w, "{} [xlabel={}, label=\"phi\"];\n", name, j)?; let (control_name, mut visited) = write_node(i, control.idx(), module, visited, w)?; write!( w, @@ -120,27 +119,56 @@ fn write_node<W: std::fmt::Write>( } visited } - Node::Return { control, value } => { + Node::ThreadID { control } => { + write!(w, "{} [xlabel={}, label=\"thread_id\"];\n", name, j)?; + let (control_name, visited) = write_node(i, control.idx(), module, visited, w)?; + write!( + w, + "{} -> {} [label=\"control\", style=\"dashed\"];\n", + control_name, name + )?; + visited + } + Node::Collect { control, data } => { + let (control_name, visited) = write_node(i, control.idx(), module, visited, w)?; + let (data_name, visited) = write_node(i, data.idx(), module, visited, w)?; + write!(w, "{} [xlabel={}, label=\"collect\"];\n", name, j)?; + write!( + w, + "{} -> {} [label=\"control\", style=\"dashed\"];\n", + control_name, name + )?; + write!(w, "{} -> {} [label=\"data\"];\n", data_name, name)?; + visited + } + Node::Return { control, data } => { let (control_name, visited) = write_node(i, control.idx(), module, visited, w)?; - let (value_name, visited) = write_node(i, value.idx(), module, visited, w)?; - write!(w, "{} [label=\"return\"];\n", name)?; + let (data_name, visited) = write_node(i, data.idx(), module, visited, w)?; + write!(w, "{} [xlabel={}, label=\"return\"];\n", name, j)?; write!( w, "{} -> {} [label=\"control\", style=\"dashed\"];\n", control_name, name )?; - write!(w, "{} -> {} [label=\"value\"];\n", value_name, name)?; + write!(w, "{} -> {} [label=\"data\"];\n", data_name, name)?; visited } Node::Parameter { index } => { - write!(w, "{} [label=\"param #{}\"];\n", name, index + 1)?; + write!( + w, + "{} [xlabel={}, label=\"param #{}\"];\n", + name, + j, + index + 1 + )?; visited } Node::Constant { id } => { write!( w, - "{} [label=\"{:?}\"];\n", + "{} [xlabel={}, label=\"{:?}\"];\n", name, + j, module.constants[id.idx()] )?; visited @@ -148,20 +176,33 @@ fn write_node<W: std::fmt::Write>( Node::DynamicConstant { id } => { write!( w, - "{} [label=\"dynamic_constant({:?})\"];\n", + "{} [xlabel={}, label=\"dynamic_constant({:?})\"];\n", name, + j, module.dynamic_constants[id.idx()] )?; visited } Node::Unary { input, op } => { - write!(w, "{} [label=\"{}\"];\n", name, op.lower_case_name())?; + write!( + w, + "{} [xlabel={}, label=\"{}\"];\n", + name, + j, + op.lower_case_name() + )?; let (input_name, visited) = write_node(i, input.idx(), module, visited, w)?; write!(w, "{} -> {} [label=\"input\"];\n", input_name, name)?; visited } Node::Binary { left, right, op } => { - write!(w, "{} [label=\"{}\"];\n", name, op.lower_case_name())?; + write!( + w, + "{} [xlabel={}, label=\"{}\"];\n", + name, + j, + op.lower_case_name() + )?; let (left_name, visited) = write_node(i, left.idx(), module, visited, w)?; let (right_name, visited) = write_node(i, right.idx(), module, visited, w)?; write!(w, "{} -> {} [label=\"left\"];\n", left_name, name)?; @@ -173,7 +214,7 @@ fn write_node<W: std::fmt::Write>( dynamic_constants, args, } => { - write!(w, "{} [label=\"call<", name,)?; + write!(w, "{} [xlabel={}, label=\"call<", name, j)?; for (idx, id) in dynamic_constants.iter().enumerate() { let dc = &module.dynamic_constants[id.idx()]; if idx == 0 { @@ -198,13 +239,21 @@ fn write_node<W: std::fmt::Write>( visited } Node::ReadProd { prod, index } => { - write!(w, "{} [label=\"read_prod({})\"];\n", name, index)?; + write!( + w, + "{} [xlabel={}, label=\"read_prod({})\"];\n", + name, j, index + )?; let (prod_name, visited) = write_node(i, prod.idx(), module, visited, w)?; write!(w, "{} -> {} [label=\"prod\"];\n", prod_name, name)?; visited } Node::WriteProd { prod, data, index } => { - write!(w, "{} [label=\"write_prod({})\"];\n", name, index)?; + write!( + w, + "{} [xlabel={}, label=\"write_prod({})\"];\n", + name, j, index + )?; let (prod_name, visited) = write_node(i, prod.idx(), module, visited, w)?; let (data_name, visited) = write_node(i, data.idx(), module, visited, w)?; write!(w, "{} -> {} [label=\"prod\"];\n", prod_name, name)?; @@ -212,7 +261,7 @@ fn write_node<W: std::fmt::Write>( visited } Node::ReadArray { array, index } => { - write!(w, "{} [label=\"read_array\"];\n", name)?; + write!(w, "{} [xlabel={}, label=\"read_array\"];\n", name, j)?; let (array_name, visited) = write_node(i, array.idx(), module, visited, w)?; write!(w, "{} -> {} [label=\"array\"];\n", array_name, name)?; let (index_name, visited) = write_node(i, index.idx(), module, visited, w)?; @@ -220,7 +269,7 @@ fn write_node<W: std::fmt::Write>( visited } Node::WriteArray { array, data, index } => { - write!(w, "{} [label=\"write_array\"];\n", name)?; + write!(w, "{} [xlabel={}, label=\"write_array\"];\n", name, j)?; let (array_name, visited) = write_node(i, array.idx(), module, visited, w)?; write!(w, "{} -> {} [label=\"array\"];\n", array_name, name)?; let (data_name, visited) = write_node(i, data.idx(), module, visited, w)?; @@ -230,7 +279,7 @@ fn write_node<W: std::fmt::Write>( visited } Node::Match { control, sum } => { - write!(w, "{} [label=\"match\"];\n", name)?; + write!(w, "{} [xlabel={}, label=\"match\"];\n", name, j)?; let (control_name, visited) = write_node(i, control.idx(), module, visited, w)?; write!( w, @@ -248,8 +297,9 @@ fn write_node<W: std::fmt::Write>( } => { write!( w, - "{} [label=\"build_sum({:?}, {})\"];\n", + "{} [xlabel={}, label=\"build_sum({:?}, {})\"];\n", name, + j, module.types[sum_ty.idx()], variant )?; @@ -258,7 +308,11 @@ fn write_node<W: std::fmt::Write>( visited } Node::ExtractSum { data, variant } => { - write!(w, "{} [label=\"extract_sum({})\"];\n", name, variant)?; + write!( + w, + "{} [xlabel={}, label=\"extract_sum({})\"];\n", + name, j, variant + )?; let (data_name, visited) = write_node(i, data.idx(), module, visited, w)?; write!(w, "{} -> {} [label=\"data\"];\n", data_name, name)?; visited diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index d07484a5337c11db35d9db48b27a09470273807e..9400f0265f7937ec253c3d2279180042cffb1365 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -145,13 +145,13 @@ pub enum Constant { /* * Dynamic constants are unsigned 64-bit integers passed to a Hercules function - * at runtime using the Hercules runtime API. They cannot be the result of + * at runtime using the Hercules conductor API. They cannot be the result of * computations in Hercules IR. For a single execution of a Hercules function, * dynamic constants are constant throughout execution. This provides a * mechanism by which Hercules functions can operate on arrays with variable * length, while not needing Hercules functions to perform dynamic memory - * allocation - by providing dynamic constants to the runtime API, the runtime - * can allocate memory as necessary. + * allocation - by providing dynamic constants to the conductor API, the + * conductor can allocate memory as necessary. */ #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum DynamicConstant { @@ -161,15 +161,14 @@ pub enum DynamicConstant { /* * Hercules IR is a combination of a possibly cylic control flow graph, and - * many acyclic data flow graphs. Each node represents some operation on input - * values (including control), and produces some output value. Operations that - * conceptually produce multiple outputs (such as an if node) produce a product - * type instead. For example, the if node produces prod(control(N), + * many possibly cyclic data flow graphs. Each node represents some operation on + * input values (including control), and produces some output value. Operations + * that conceptually produce multiple outputs (such as an if node) produce a + * product type instead. For example, the if node produces prod(control(N), * control(N)), where the first control token represents the false branch, and - * the second control token represents the true branch. Another example is the - * fork node, which produces prod(control(N, K), u64), where the u64 is the - * thread ID. Functions are devoid of side effects, so call nodes don't take as - * input or output control tokens. There is also no global memory - use arrays. + * the second control token represents the true branch. Functions are devoid of + * side effects, so call nodes don't take as input or output control tokens. + * There is also no global memory - use arrays. */ #[derive(Debug, Clone)] pub enum Node { @@ -187,15 +186,21 @@ pub enum Node { }, Join { control: NodeID, - data: NodeID, }, Phi { control: NodeID, data: Box<[NodeID]>, }, + ThreadID { + control: NodeID, + }, + Collect { + control: NodeID, + data: NodeID, + }, Return { control: NodeID, - value: NodeID, + data: NodeID, }, Parameter { index: usize, @@ -284,7 +289,7 @@ impl Node { pub fn is_return(&self) -> bool { if let Node::Return { control: _, - value: _, + data: _, } = self { true @@ -305,17 +310,19 @@ impl Node { control: _, factor: _, } => "Fork", - Node::Join { - control: _, - data: _, - } => "Join", + Node::Join { control: _ } => "Join", Node::Phi { control: _, data: _, } => "Phi", + Node::ThreadID { control: _ } => "ThreadID", + Node::Collect { + control: _, + data: _, + } => "Collect", Node::Return { control: _, - value: _, + data: _, } => "Return", Node::Parameter { index: _ } => "Parameter", Node::DynamicConstant { id: _ } => "DynamicConstant", @@ -368,17 +375,19 @@ impl Node { control: _, factor: _, } => "fork", - Node::Join { - control: _, - data: _, - } => "join", + Node::Join { control: _ } => "join", Node::Phi { control: _, data: _, } => "phi", + Node::ThreadID { control: _ } => "thread_id", + Node::Collect { + control: _, + data: _, + } => "collect", Node::Return { control: _, - value: _, + data: _, } => "return", Node::Parameter { index: _ } => "parameter", Node::DynamicConstant { id: _ } => "dynamic_constant", @@ -485,7 +494,7 @@ impl BinaryOperator { /* * Rust things to make newtyped IDs usable. */ -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct FunctionID(u32); impl FunctionID { @@ -498,7 +507,7 @@ impl FunctionID { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct NodeID(u32); impl NodeID { @@ -511,7 +520,7 @@ impl NodeID { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct ConstantID(u32); impl ConstantID { @@ -524,7 +533,7 @@ impl ConstantID { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct TypeID(u32); impl TypeID { @@ -537,7 +546,7 @@ impl TypeID { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct DynamicConstantID(u32); impl DynamicConstantID { diff --git a/hercules_ir/src/lib.rs b/hercules_ir/src/lib.rs index cd66d5fb5ea57a37aa6b85f09f5b1683f02146f0..094873f27dd31583fb72c838ac79279c9afaef1d 100644 --- a/hercules_ir/src/lib.rs +++ b/hercules_ir/src/lib.rs @@ -1,15 +1,19 @@ pub mod dataflow; pub mod def_use; +pub mod dom; pub mod dot; pub mod ir; pub mod parse; +pub mod subgraph; pub mod typecheck; pub mod verify; pub use crate::dataflow::*; pub use crate::def_use::*; +pub use crate::dom::*; pub use crate::dot::*; pub use crate::ir::*; pub use crate::parse::*; +pub use crate::subgraph::*; pub use crate::typecheck::*; pub use crate::verify::*; diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index a54f81096e42e98618342f10dde30f20b7ee85f9..fcabe771c6e42e10e7d0784785ea6a6b5a344ddb 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -275,6 +275,8 @@ fn parse_node<'a>( "fork" => parse_fork(ir_text, context)?, "join" => parse_join(ir_text, context)?, "phi" => parse_phi(ir_text, context)?, + "thread_id" => parse_thread_id(ir_text, context)?, + "collect" => parse_collect(ir_text, context)?, "return" => parse_return(ir_text, context)?, "constant" => parse_constant_node(ir_text, context)?, "dynamic_constant" => parse_dynamic_constant_node(ir_text, context)?, @@ -365,15 +367,14 @@ fn parse_fork<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IRes } fn parse_join<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IResult<&'a str, Node> { - let (ir_text, (control, data)) = parse_tuple2(parse_identifier, parse_identifier)(ir_text)?; + let (ir_text, (control,)) = parse_tuple1(parse_identifier)(ir_text)?; let control = context.borrow_mut().get_node_id(control); - let data = context.borrow_mut().get_node_id(data); // A join node doesn't need to explicitly store a join factor. The join // factor is implicitly stored at the tail of the control token's type // level list of thread spawn factors. Intuitively, fork pushes to the end // of this list, while join just pops from the end of this list. - Ok((ir_text, Node::Join { control, data })) + Ok((ir_text, Node::Join { control })) } fn parse_phi<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IResult<&'a str, Node> { @@ -396,14 +397,33 @@ fn parse_phi<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IResu Ok((ir_text, Node::Phi { control, data })) } +fn parse_thread_id<'a>( + ir_text: &'a str, + context: &RefCell<Context<'a>>, +) -> nom::IResult<&'a str, Node> { + let (ir_text, (control,)) = parse_tuple1(parse_identifier)(ir_text)?; + let control = context.borrow_mut().get_node_id(control); + Ok((ir_text, Node::ThreadID { control })) +} + +fn parse_collect<'a>( + ir_text: &'a str, + context: &RefCell<Context<'a>>, +) -> nom::IResult<&'a str, Node> { + let (ir_text, (control, data)) = parse_tuple2(parse_identifier, parse_identifier)(ir_text)?; + let control = context.borrow_mut().get_node_id(control); + let data = context.borrow_mut().get_node_id(data); + Ok((ir_text, Node::Collect { control, data })) +} + fn parse_return<'a>( ir_text: &'a str, context: &RefCell<Context<'a>>, ) -> nom::IResult<&'a str, Node> { - let (ir_text, (control, value)) = parse_tuple2(parse_identifier, parse_identifier)(ir_text)?; + let (ir_text, (control, data)) = parse_tuple2(parse_identifier, parse_identifier)(ir_text)?; let control = context.borrow_mut().get_node_id(control); - let value = context.borrow_mut().get_node_id(value); - Ok((ir_text, Node::Return { control, value })) + let data = context.borrow_mut().get_node_id(data); + Ok((ir_text, Node::Return { control, data })) } fn parse_constant_node<'a>( diff --git a/hercules_ir/src/subgraph.rs b/hercules_ir/src/subgraph.rs new file mode 100644 index 0000000000000000000000000000000000000000..a75bc637ff66f75e7f4fa2307dd74a3441b785d6 --- /dev/null +++ b/hercules_ir/src/subgraph.rs @@ -0,0 +1,248 @@ +use crate::*; + +use std::collections::HashMap; + +/* + * In various parts of the compiler, we want to consider a subset of a complete + * function graph. For example, for dominators, we often only want to find the + * dominator tree of only the control subgraph. + */ +#[derive(Debug, Clone)] +pub struct Subgraph { + nodes: Vec<NodeID>, + node_numbers: HashMap<NodeID, u32>, + first_forward_edges: Vec<u32>, + forward_edges: Vec<u32>, + first_backward_edges: Vec<u32>, + backward_edges: Vec<u32>, +} + +pub struct SubgraphIterator<'a> { + nodes: &'a Vec<NodeID>, + edges: &'a [u32], +} + +impl<'a> Iterator for SubgraphIterator<'a> { + type Item = NodeID; + + fn next(&mut self) -> Option<Self::Item> { + if self.edges.len() == 0 { + None + } else { + let id = self.edges[0]; + self.edges = &self.edges[1..]; + Some(self.nodes[id as usize]) + } + } +} + +impl Subgraph { + pub fn num_nodes(&self) -> u32 { + self.nodes.len() as u32 + } + + pub fn contains_node(&self, id: NodeID) -> bool { + self.node_numbers.contains_key(&id) + } + + pub fn preds(&self, id: NodeID) -> SubgraphIterator { + let number = self.node_numbers[&id]; + if ((number + 1) as usize) < self.first_backward_edges.len() { + SubgraphIterator { + nodes: &self.nodes, + edges: &self.backward_edges[(self.first_backward_edges[number as usize] as usize) + ..(self.first_backward_edges[number as usize + 1] as usize)], + } + } else { + SubgraphIterator { + nodes: &self.nodes, + edges: &self.backward_edges + [(self.first_backward_edges[number as usize] as usize)..], + } + } + } + + pub fn succs(&self, id: NodeID) -> SubgraphIterator { + let number = self.node_numbers[&id]; + if ((number + 1) as usize) < self.first_forward_edges.len() { + SubgraphIterator { + nodes: &self.nodes, + edges: &self.forward_edges[(self.first_forward_edges[number as usize] as usize) + ..(self.first_forward_edges[number as usize + 1] as usize)], + } + } else { + SubgraphIterator { + nodes: &self.nodes, + edges: &self.forward_edges[(self.first_forward_edges[number as usize] as usize)..], + } + } + } + + pub fn reverse(self, new_root: NodeID) -> Self { + let Subgraph { + mut nodes, + mut node_numbers, + first_forward_edges, + forward_edges, + mut first_backward_edges, + mut backward_edges, + } = self; + + // Since we need to add a "new" root to the subgraph, we first need to + // identify all the nodes with no forward edges. We're going to + // simultaneously add the new backward edges from the old leaves to the + // new root. + let mut leaf_numbers = vec![]; + let mut new_first_forward_edges = vec![]; + let mut new_forward_edges = vec![]; + let mut old_forward_edges_idx = 0; + for number in 0..nodes.len() as u32 { + new_first_forward_edges.push(new_forward_edges.len() as u32); + let num_edges = if ((number + 1) as usize) < first_forward_edges.len() { + first_forward_edges[number as usize + 1] - first_forward_edges[number as usize] + } else { + forward_edges.len() as u32 - first_forward_edges[number as usize] + }; + if num_edges == 0 { + // Node number of new root will be largest in subgraph. + new_forward_edges.push(nodes.len() as u32); + leaf_numbers.push(number); + } else { + for _ in 0..num_edges { + new_forward_edges.push(forward_edges[old_forward_edges_idx]); + old_forward_edges_idx += 1; + } + } + } + + // There are no backward edges from the root node. + new_first_forward_edges.push(new_forward_edges.len() as u32); + + // To reverse the edges in the graph, just swap the forward and backward + // edge vectors. Thus, we add the forward edges from the new root to + // the old leaves in the backward edge arrays. + node_numbers.insert(new_root, nodes.len() as u32); + nodes.push(new_root); + first_backward_edges.push(backward_edges.len() as u32); + for leaf in leaf_numbers.iter() { + backward_edges.push(*leaf); + } + + // Swap forward and backward edges. + assert!(nodes.len() == first_backward_edges.len()); + assert!(nodes.len() == new_first_forward_edges.len()); + Subgraph { + nodes, + node_numbers, + first_forward_edges: first_backward_edges, + forward_edges: backward_edges, + first_backward_edges: new_first_forward_edges, + backward_edges: new_forward_edges, + } + } +} + +/* + * Top level subgraph construction routine. Takes a function reference and a + * predicate - the predicate selects which nodes from the function will be + * included in the subgraph. An edge is added to the subgraph if it's between + * two nodes that each pass the predicate. + */ +pub fn subgraph<F>(function: &Function, def_use: &ImmutableDefUseMap, predicate: F) -> Subgraph +where + F: Fn(&Node) -> bool, +{ + let mut subgraph = Subgraph { + nodes: vec![], + node_numbers: HashMap::new(), + first_forward_edges: vec![], + forward_edges: vec![], + first_backward_edges: vec![], + backward_edges: vec![], + }; + + // Step 1: collect predicated nodes. + for (idx, node) in function.nodes.iter().enumerate() { + if predicate(node) { + subgraph + .node_numbers + .insert(NodeID::new(idx), subgraph.nodes.len() as u32); + subgraph.nodes.push(NodeID::new(idx)); + } + } + + // Step 2: collect backwards edges. This is fairly easy, since use-def + // edges are explicitly stored. + for id in subgraph.nodes.iter() { + subgraph + .first_backward_edges + .push(subgraph.backward_edges.len() as u32); + let uses = get_uses(&function.nodes[id.idx()]); + for use_id in uses.as_ref() { + // Any predecessor node that satisfies the predicate already got + // added to node numbers. We need to get the node number anyway, + // so we don't have to do a redundant predicate check. + if let Some(number) = subgraph.node_numbers.get(use_id) { + subgraph.backward_edges.push(*number); + } + } + } + + // Step 3: collect forwards edges. This is also easy, since we already have + // the def-use edges of this function. + for id in subgraph.nodes.iter() { + subgraph + .first_forward_edges + .push(subgraph.forward_edges.len() as u32); + + // Only difference is that we iterate over users, not uses. + let users = def_use.get_users(*id); + for user_id in users.as_ref() { + // Any successor node that satisfies the predicate already got + // added to node numbers. We need to get the node number anyway, + // so we don't have to do a redundant predicate check. + if let Some(number) = subgraph.node_numbers.get(user_id) { + subgraph.forward_edges.push(*number); + } + } + } + + subgraph +} + +/* + * Get the control subgraph of a function. + */ +pub fn control_subgraph(function: &Function, def_use: &ImmutableDefUseMap) -> Subgraph { + use Node::*; + + subgraph(function, def_use, |node| match node { + Start + | Region { preds: _ } + | If { + control: _, + cond: _, + } + | Fork { + control: _, + factor: _, + } + | Join { control: _ } + | Return { + control: _, + data: _, + } + | Match { control: _, sum: _ } => true, + ReadProd { prod, index } => match function.nodes[prod.idx()] { + // ReadProd nodes are control nodes if their predecessor is a + // legal control node, and if it's the right index. + Match { control: _, sum: _ } + | If { + control: _, + cond: _, + } => true, + _ => false, + }, + _ => false, + }) +} diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index 4bfa9faecb57ce29137b0d892765e837b1bc4f9d..6722a66dbaa0aff09ad8d5dcf254ff4ac11c24bb 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -98,6 +98,10 @@ pub fn typecheck( .map(|(idx, ty)| (ty.clone(), TypeID::new(idx))) .collect(); + // Also create a join replication factor map. This is needed to typecheck + // collect node. + let mut join_factor_map: HashMap<NodeID, DynamicConstantID> = HashMap::new(); + // Step 2: run dataflow. This is an occurrence of dataflow where the flow // function performs a non-associative operation on the predecessor "out" // values. @@ -113,6 +117,7 @@ pub fn typecheck( constants, dynamic_constants, &mut reverse_type_map, + &mut join_factor_map, ) }) }) @@ -143,13 +148,14 @@ pub fn typecheck( */ fn typeflow( inputs: &[&TypeSemilattice], - node: &Node, + node_id: NodeID, function: &Function, functions: &Vec<Function>, types: &mut Vec<Type>, constants: &Vec<Constant>, dynamic_constants: &Vec<DynamicConstant>, reverse_type_map: &mut HashMap<Type, TypeID>, + join_factor_map: &mut HashMap<NodeID, DynamicConstantID>, ) -> TypeSemilattice { // Whenever we want to reference a specific type (for example, for the // start node), we need to get its type ID. This helper function gets the @@ -170,7 +176,7 @@ fn typeflow( // Each node requires different type logic. This unfortunately results in a // large match statement. Oh well. Each arm returns the lattice value for // the "out" type of the node. - match node { + match &function.nodes[node_id.idx()] { Node::Start => { if inputs.len() != 0 { return Error(String::from("Start node must have zero inputs.")); @@ -253,18 +259,14 @@ fn typeflow( let mut new_factors = factors.clone().into_vec(); new_factors.push(*factor); - // Out type is a pair - first element is the control type, - // second is the index type (u64). Each thread gets a - // different thread ID at runtime. + // Out type is control type, with the new thread spawn + // factor. let control_out_id = get_type_id( Type::Control(new_factors.into_boxed_slice()), types, reverse_type_map, ); - let index_out_id = - get_type_id(Type::UnsignedInteger64, types, reverse_type_map); - let out_ty = Type::Product(Box::new([control_out_id, index_out_id])); - return Concrete(get_type_id(out_ty, types, reverse_type_map)); + return Concrete(control_out_id); } else { return Error(String::from( "Fork node's input cannot have non-control type.", @@ -274,59 +276,38 @@ fn typeflow( inputs[0].clone() } - Node::Join { - control: _, - data: _, - } => { - if inputs.len() != 2 { + Node::Join { control: _ } => { + if inputs.len() != 1 { return Error(String::from("Join node must have exactly two inputs.")); } - // If the data input isn't concrete, we can't assemble a concrete - // output type yet, so just return data input's type (either - // unconstrained or error) instead. - if let Concrete(data_id) = inputs[1] { - if types[data_id.idx()].is_control() { - return Error(String::from( - "Join node's second input must not have a control type.", - )); - } - - // Similarly, if the control input isn't concrete yet, we can't - // assemble a concrete output type, so just return the control - // input non-concrete type. - if let Concrete(control_id) = inputs[0] { - if let Type::Control(factors) = &types[control_id.idx()] { - // Join removes a factor from the factor list. - if factors.len() == 0 { - return Error(String::from("Join node's first input must have a control type with at least one thread replication factor.")); - } - let mut new_factors = factors.clone().into_vec(); - let dc_id = new_factors.pop().unwrap(); - - // Out type is a pair - first element is the control - // type, second is the result array from the parallel - // computation. - let control_out_id = get_type_id( - Type::Control(new_factors.into_boxed_slice()), - types, - reverse_type_map, - ); - let array_out_id = - get_type_id(Type::Array(*data_id, dc_id), types, reverse_type_map); - let out_ty = Type::Product(Box::new([control_out_id, array_out_id])); - return Concrete(get_type_id(out_ty, types, reverse_type_map)); - } else { - return Error(String::from( - "Join node's first input cannot have non-control type.", - )); + // If the control input isn't concrete yet, we can't assemble a + // concrete output type, so just return the control input non- + // concrete type. + if let Concrete(control_id) = inputs[0] { + if let Type::Control(factors) = &types[control_id.idx()] { + // Join removes a factor from the factor list. + if factors.len() == 0 { + return Error(String::from("Join node's first input must have a control type with at least one thread replication factor.")); } + let mut new_factors = factors.clone().into_vec(); + join_factor_map.insert(node_id, new_factors.pop().unwrap()); + + // Out type is the new control type. + let control_out_id = get_type_id( + Type::Control(new_factors.into_boxed_slice()), + types, + reverse_type_map, + ); + return Concrete(control_out_id); } else { - return inputs[0].clone(); + return Error(String::from( + "Join node's first input cannot have non-control type.", + )); } } - inputs[1].clone() + inputs[0].clone() } Node::Phi { control: _, @@ -365,9 +346,67 @@ fn typeflow( meet } + Node::ThreadID { control: _ } => { + if inputs.len() != 1 { + return Error(String::from("ThreadID node must have exactly one input.")); + } + + // If type of control input is an error, we must propagate it. + if inputs[0].is_error() { + return inputs[0].clone(); + } + + // Type of thread ID is always u64. + Concrete(get_type_id( + Type::UnsignedInteger64, + types, + reverse_type_map, + )) + } + Node::Collect { control, data: _ } => { + if inputs.len() != 2 { + return Error(String::from("Collect node must have exactly two inputs.")); + } + + if let (Concrete(control_id), Concrete(data_id)) = (inputs[0], inputs[1]) { + // Check control input is control. + if let Type::Control(_) = types[control_id.idx()] { + } else { + return Error(String::from( + "Collect node's control input must have control type.", + )); + } + + // Check data input isn't control. + if let Type::Control(_) = types[data_id.idx()] { + return Error(String::from( + "Collect node's data input must not have control type.", + )); + } + + // Unfortunately, the type of the control input doesn't contain + // the thread replication factor this collect node is operating + // with. We use the join replication factor map side data + // structure to store the replication factor each join reduces + // over to make this easier. + if let Some(factor) = join_factor_map.get(control) { + let array_out_id = + get_type_id(Type::Array(*data_id, *factor), types, reverse_type_map); + Concrete(array_out_id) + } else { + // If the join factor map doesn't contain the control + // input, stay optimistic. + Unconstrained + } + } else if inputs[0].is_error() { + inputs[0].clone() + } else { + inputs[1].clone() + } + } Node::Return { control: _, - value: _, + data: _, } => { if inputs.len() != 2 { return Error(String::from("Return node must have exactly two inputs.")); @@ -702,6 +741,10 @@ fn typeflow( } else if let Concrete(data_id) = inputs[1] { if elem_tys[*index] != *data_id { return Error(format!("WriteProd node's data input doesn't match the type of the element at index {} inside the product type.", index)); + } else if let Type::Control(_) = &types[data_id.idx()] { + return Error(String::from( + "WriteProd node's data input cannot have a control type.", + )); } } else if inputs[1].is_error() { // If an input lattice value is an error, we must @@ -773,6 +816,10 @@ fn typeflow( if let Concrete(data_id) = inputs[1] { if elem_id != *data_id { return Error(String::from("WriteArray node's array and data inputs must have compatible types (type of data input must be the same as the array input's element type).")); + } else if let Type::Control(_) = &types[data_id.idx()] { + return Error(String::from( + "WriteArray node's data input cannot have a control type.", + )); } } } else { @@ -831,6 +878,12 @@ fn typeflow( } if let Concrete(id) = inputs[0] { + if let Type::Control(_) = &types[id.idx()] { + return Error(String::from( + "BuildSum node's data input cannot have a control type.", + )); + } + // BuildSum node stores its own result type. if let Type::Summation(variants) = &types[sum_ty.idx()] { // Must reference an existing variant. diff --git a/hercules_ir/src/verify.rs b/hercules_ir/src/verify.rs index a737aafef6dc776be2b13e8f824455a8f9fd2e68..89192c01a48cc58d14180bb88f70faeab3f787c4 100644 --- a/hercules_ir/src/verify.rs +++ b/hercules_ir/src/verify.rs @@ -31,6 +31,16 @@ pub fn verify(module: &mut Module) -> Result<ModuleTyping, String> { { verify_structure(function, def_use, typing, &module.types)?; } + + // Check SSA, fork, and join dominance relations. + for (function, def_use) in zip(module.functions.iter(), def_uses) { + let subgraph = control_subgraph(function, &def_use); + let dom = dominator(&subgraph, NodeID::new(0)); + let postdom = postdominator(subgraph, NodeID::new(function.nodes.len())); + println!("{:?}", dom); + println!("{:?}", postdom); + } + Ok(typing) } @@ -48,20 +58,11 @@ fn verify_structure( for (idx, node) in function.nodes.iter().enumerate() { let users = def_use.get_users(NodeID::new(idx)); match node { - // If, fork, and join nodes all have the same structural - // constraints - each must have exactly two ReadProd users, which + // Each if node must have exactly two ReadProd users, which // reference differing elements of the node's output product. Node::If { control: _, cond: _, - } - | Node::Fork { - control: _, - factor: _, - } - | Node::Join { - control: _, - data: _, } => { if users.len() != 2 { Err(format!( @@ -100,10 +101,28 @@ fn verify_structure( Err("Phi node's control input must be a region node.")?; } } + // ThreadID nodes must depend on a fork node. + Node::ThreadID { control } => { + if let Node::Fork { + control: _, + factor: _, + } = function.nodes[control.idx()] + { + } else { + Err("ThreadID node's control input must be a fork node.")?; + } + } + // Collect nodes must depend on a join node. + Node::Collect { control, data: _ } => { + if let Node::Join { control: _ } = function.nodes[control.idx()] { + } else { + Err("Collect node's control input must be a join node.")?; + } + } // Return nodes must have no users. Node::Return { control: _, - value: _, + data: _, } => { if users.len() != 0 { Err(format!( diff --git a/samples/matmul.hir b/samples/matmul.hir index 511bdfa8118194d31e0c21b12e83c62eba4318ed..af13ce95489cd5675f61c9023d78f0866a9d4e68 100644 --- a/samples/matmul.hir +++ b/samples/matmul.hir @@ -1,10 +1,8 @@ fn matmul<3>(a: array(array(f32, #1), #0), b: array(array(f32, #2), #1)) -> array(array(f32, #2), #0) - i = fork(start, #0) - i_ctrl = read_prod(i, 0) - i_idx = read_prod(i, 1) - k = fork(i_ctrl, #2) - k_ctrl = read_prod(k, 0) - k_idx = read_prod(k, 1) + i_ctrl = fork(start, #0) + i_idx = thread_id(i_ctrl) + k_ctrl = fork(i_ctrl, #2) + k_idx = thread_id(k_ctrl) zero_idx = constant(u64, 0) one_idx = constant(u64, 1) zero_val = constant(f32, 0) @@ -23,10 +21,8 @@ fn matmul<3>(a: array(array(f32, #1), #0), b: array(array(f32, #2), #1)) -> arra if = if(loop, less) if_false = read_prod(if, 0) if_true = read_prod(if, 1) - k_join = join(if_false, sum_inc) - k_join_ctrl = read_prod(k_join, 0) - k_join_data = read_prod(k_join, 1) - i_join = join(k_join_ctrl, k_join_data) - i_join_ctrl = read_prod(i_join, 0) - i_join_data = read_prod(i_join, 1) + k_join_ctrl = join(if_false) + k_join_data = collect(k_join_ctrl, sum_inc) + i_join_ctrl = join(k_join_ctrl) + i_join_data = collect(i_join_ctrl, k_join_data) r = return(i_join_ctrl, i_join_data) diff --git a/samples/simple1.hir b/samples/simple1.hir index 415b2bc3f9a710fceccd222d4ac47c348669a0d5..92c1435b14047a0e4293ee5b814a70abc59a4a91 100644 --- a/samples/simple1.hir +++ b/samples/simple1.hir @@ -1,11 +1,5 @@ fn myfunc(x: i32) -> i32 - y = call<5>(add, x, x) - r = return(start, y) - -fn add<1>(x: i32, y: i32) -> i32 - c = constant(i8, 5) - dc = dynamic_constant(#0) - r = return(start, s) - w = add(z, c) - s = add(w, dc) - z = add(x, y) \ No newline at end of file + a = region(start) + b = region(start) + c = region(a, b) + d = return(c, x)