Skip to content
Snippets Groups Projects
Commit 145a016c authored by rarbore2's avatar rarbore2
Browse files

Merge branch 'array-to-prod' into 'main'

Array to prod

See merge request !174
parents d5985d59 c84c8c7a
No related branches found
No related tags found
1 merge request!174Array to prod
Pipeline #201639 passed
......@@ -1231,6 +1231,16 @@ dependencies = [
"with_builtin_macros",
]
[[package]]
name = "juno_median_window"
version = "0.1.0"
dependencies = [
"async-std",
"hercules_rt",
"juno_build",
"with_builtin_macros",
]
[[package]]
name = "juno_multi_device"
version = "0.1.0"
......
......@@ -34,4 +34,5 @@ members = [
"juno_samples/fork_join_tests",
"juno_samples/multi_device",
"juno_samples/products",
"juno_samples/median_window",
]
use hercules_ir::define_id_type;
use hercules_ir::ir::*;
use bitvec::prelude::*;
use crate::*;
use std::collections::{HashMap, HashSet};
use std::marker::PhantomData;
/*
* Top level function for array to product which will convert constant
* sized arrays into products if the array is only accessed at indices which
* are constants.
*
* To identify the collections we can convert we look at each constant-sized
* array constant and compute the set which includes the constant node and is
* closed under the following properties:
* - For each collection in the set, its uses are in the set
* - For each node that uses a collection, all collections it uses are in the
* set
* From this set, we then determine whether this whole set can be converted to
* operating on products, rather than arrays, as follows
* - Each read and write node must be to a constant index
* - It may not contain any arguments (we could generate code to read a an
* array argument into a product, but do not do so for now)
* - There are call or return nodes in the set (this would mean that the
* collections are consumed by a call or return, again we could reconstruct
* the array where needed but do not do so for now and so have this
* restriction)
* - All nodes in the set are editable (if we cannot modify some node then the
* conversion will fail)
*
* The max_size argument allows the user to specify a limit on the size of arrays
* that should be converted to products. If the number of elements in the array
* is larger than the max size the array will not be converted.
*/
pub fn array_to_product(editor: &mut FunctionEditor, types: &[TypeID], max_size: Option<usize>) {
let replace_nodes = array_usage_analysis(editor, types, max_size);
let num_nodes = editor.func().nodes.len();
// Replace nodes
for node_idx in 0..num_nodes {
if !replace_nodes[node_idx] {
continue;
}
let node = NodeID::new(node_idx);
// We can replace the array(s) this node uses with a product. What we have to do depends on
// the type of the node
match &editor.func().nodes[node_idx] {
// Phi, Reduce, and Ternary just use the whole collection, they do not need to change,
// except as they will be modified by replace_all_uses_of
Node::Phi { .. }
| Node::Reduce { .. }
| Node::Ternary {
op: TernaryOperator::Select,
..
} => {}
Node::Constant { id } => {
assert!(editor.get_constant(*id).is_array());
let element: TypeID = editor.get_type(types[node_idx]).try_element_type().unwrap();
let dims: Vec<usize> = editor
.get_type(types[node_idx])
.try_extents()
.unwrap()
.iter()
.map(|dc| editor.get_dynamic_constant(*dc).try_constant().unwrap())
.collect();
// Replace the constant by a product that is a product (for each dimension) and the
// elements are zero'd
editor.edit(|mut edit| {
let element_zero = edit.add_zero_constant(element);
let (constant, _) = dims.into_iter().rfold(
(element_zero, element),
|(cur_const, cur_type), dim| {
let new_type = edit.add_type(Type::Product(vec![cur_type; dim].into()));
let new_const = edit.add_constant(Constant::Product(
new_type,
vec![cur_const; dim].into(),
));
(new_const, new_type)
},
);
let new_val = edit.add_node(Node::Constant { id: constant });
let edit = edit.replace_all_uses(node, new_val)?;
edit.delete_node(node)
});
}
Node::Read { collect, indices } => {
let collect = *collect;
let new_indices = convert_indices_to_prod(editor, indices);
editor.edit(|mut edit| {
let new_val = edit.add_node(Node::Read {
collect,
indices: new_indices,
});
let edit = edit.replace_all_uses(NodeID::new(node_idx), new_val)?;
edit.delete_node(node)
});
}
Node::Write {
collect,
data,
indices,
} => {
let collect = *collect;
let data = *data;
let new_indices = convert_indices_to_prod(editor, indices);
editor.edit(|mut edit| {
let new_val = edit.add_node(Node::Write {
collect,
data,
indices: new_indices,
});
let edit = edit.replace_all_uses(NodeID::new(node_idx), new_val)?;
edit.delete_node(node)
});
}
node => panic!("Node cannot be replaced: {:?}", node),
}
}
}
fn convert_indices_to_prod(editor: &FunctionEditor, indices: &[Index]) -> Box<[Index]> {
let mut result = vec![];
for index in indices {
match index {
Index::Position(positions) => {
for pos in positions {
let const_id = editor.func().nodes[pos.idx()]
.try_constant()
.expect("Array position must be constant");
match *editor.get_constant(const_id) {
Constant::UnsignedInteger64(idx) => result.push(Index::Field(idx as usize)),
ref val => panic!("Position should be u64 constant: {:?}", val),
}
}
}
index => panic!("Index cannot be replaced: {:?}", index),
}
}
result.into()
}
// Given the editor, while compute a mask of which nodes are to be converted
// from using a constant sized array into using a product
fn array_usage_analysis(
editor: &FunctionEditor,
types: &[TypeID],
max_size: Option<usize>,
) -> BitVec<u8, Lsb0> {
let num_nodes = editor.func().nodes.len();
// Step 1: identify the constant nodes that are constant sized arrays no larger than the
// max_size, these are what we are interested in converting into products
let sources = editor
.func()
.nodes
.iter()
.enumerate()
.filter_map(|(idx, node)| {
let Node::Constant { id } = node else {
return None;
};
let Constant::Array(array_type) = *editor.get_constant(*id) else {
return None;
};
let typ = editor.get_type(array_type);
let Some(dims) = typ.try_extents() else {
return None;
};
// Compute the total number of elements, the result is None if some dimension is not a
// constant and otherwise is Some(num_elements) which we can then compare with max_size
if let Some(elements) = dims.iter().fold(Some(1), |prod, dc| {
prod.and_then(|prod| {
editor
.get_dynamic_constant(*dc)
.try_constant()
.map(|dim| prod * dim)
})
}) {
if let Some(max_size) = max_size
&& elements > max_size
{
// Too many elements, don't convert
None
} else {
Some(NodeID::new(idx))
}
} else {
None
}
})
.collect::<Vec<_>>();
// Step 2: collect the collection information we need for the (whole) function. For each node
// that returns a collection (that in reference semantics returns the same reference as some of
// its inputs) union with all of its users. The nodes that matter in this are arguments,
// constants, writes, phis, selects, and reduces with array types.
let mut analysis = UnionFind::new();
for node_idx in 0..num_nodes {
let node_id = NodeID::new(node_idx);
if editor.get_type(types[node_idx]).is_array() {
match editor.func().nodes[node_idx] {
Node::Phi { .. }
| Node::Reduce { .. }
| Node::Parameter { .. }
| Node::Constant { .. }
| Node::Ternary {
op: TernaryOperator::Select,
..
}
| Node::Write { .. } => {
for user in editor.get_users(node_id) {
analysis.union(node_id, user);
}
}
_ => {}
}
}
}
let sets = analysis.sets(&sources);
// Step 3: determine which sets can be converted and mark the nodes in those sets
let mut result = bitvec![u8, Lsb0; 0; num_nodes];
for nodes in sets {
if nodes
.iter()
.all(|node_id| editor.is_mutable(*node_id) && can_replace(editor, *node_id))
{
for node_id in nodes {
result.set(node_id.idx(), true);
}
}
}
result
}
fn can_replace(editor: &FunctionEditor, node: NodeID) -> bool {
match &editor.func().nodes[node.idx()] {
// Reads and writes must be at constant indices
Node::Read { indices, .. } | Node::Write { indices, .. } => {
indices.iter().all(|idx| match idx {
Index::Position(pos) => pos
.iter()
.all(|node| editor.func().nodes[node.idx()].is_constant()),
_ => false,
})
}
// phi, reduce, constants, and select can always be replaced if their users and uses allow
// it, which is handled by the construction of the set
Node::Phi { .. }
| Node::Reduce { .. }
| Node::Constant { .. }
| Node::Ternary {
op: TernaryOperator::Select,
..
} => true,
// No other nodes allow replacement
_ => false,
}
}
define_id_type!(SetID);
#[derive(Clone, Debug)]
struct UnionFindNode {
parent: SetID,
rank: usize,
}
#[derive(Clone, Debug)]
struct UnionFind<T> {
sets: Vec<UnionFindNode>,
_phantom: PhantomData<T>,
}
impl<T: ID> UnionFind<T> {
pub fn new() -> Self {
UnionFind {
sets: vec![],
_phantom: PhantomData,
}
}
fn extend_past(&mut self, size: usize) {
for i in self.sets.len()..=size {
// The new nodes we add are in their own sets and have rank 0
self.sets.push(UnionFindNode {
parent: SetID::new(i),
rank: 0,
});
}
}
pub fn find(&mut self, x: T) -> SetID {
self.extend_past(x.idx());
self.find_set(x.idx())
}
fn find_set(&mut self, x: usize) -> SetID {
let mut parent = self.sets[x].parent;
if parent.idx() != x {
parent = self.find_set(parent.idx());
self.sets[x].parent = parent;
}
parent
}
pub fn union(&mut self, x: T, y: T) {
let x = self.find(x);
let y = self.find(y);
self.link(x, y);
}
fn link(&mut self, x: SetID, y: SetID) {
if self.sets[x.idx()].rank > self.sets[y.idx()].rank {
self.sets[y.idx()].parent = x;
} else {
self.sets[x.idx()].parent = y;
if self.sets[x.idx()].rank == self.sets[y.idx()].rank {
self.sets[y.idx()].rank += 1;
}
}
}
pub fn sets(&mut self, keys: &[T]) -> Vec<Vec<T>> {
let key_index = keys
.iter()
.enumerate()
.map(|(i, k)| (self.find(*k), i))
.collect::<HashMap<SetID, usize>>();
let mut result = vec![vec![]; keys.len()];
let num_elements = self.sets.len();
for i in 0..num_elements {
let set = self.find_set(i);
let Some(idx) = key_index.get(&set) else {
continue;
};
result[*idx].push(T::new(i));
}
result
}
}
#![feature(let_chains)]
pub mod array_to_prod;
pub mod ccp;
pub mod crc;
pub mod dce;
......@@ -26,6 +27,7 @@ pub mod sroa;
pub mod unforkify;
pub mod utils;
pub use crate::array_to_prod::*;
pub use crate::ccp::*;
pub use crate::crc::*;
pub use crate::dce::*;
......
......@@ -23,13 +23,6 @@ pub fn lift_dc_math(editor: &mut FunctionEditor) {
};
DynamicConstant::Constant(cons as usize)
}
Node::DynamicConstant { id } => {
let Some(cons) = evaluate_dynamic_constant(id, &*editor.get_dynamic_constants())
else {
continue;
};
DynamicConstant::Constant(cons)
}
Node::Binary { op, left, right } => {
let (left, right) = if let (
Node::DynamicConstant { id: left },
......
[package]
name = "juno_median_window"
version = "0.1.0"
authors = ["Aaron Councilman <aaronjc4@illinois.edu>"]
edition = "2021"
[[bin]]
name = "juno_median_window"
path = "src/main.rs"
[features]
cuda = ["juno_build/cuda", "hercules_rt/cuda"]
[build-dependencies]
juno_build = { path = "../../juno_build" }
[dependencies]
juno_build = { path = "../../juno_build" }
hercules_rt = { path = "../../hercules_rt" }
with_builtin_macros = "0.1.0"
async-std = "*"
use juno_build::JunoCompiler;
fn main() {
JunoCompiler::new()
.file_in_src("median.jn")
.unwrap()
.schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" })
.unwrap()
.build()
.unwrap();
}
gvn(*);
phi-elim(*);
dce(*);
inline(*);
delete-uncalled(*);
let out = auto-outline(*);
cpu(out.median_window);
ip-sroa(*);
sroa(*);
dce(*);
gvn(*);
phi-elim(*);
dce(*);
forkify(*);
fork-guard-elim(*);
lift-dc-math(*);
forkify(*);
fork-guard-elim(*);
fork-unroll(out.median_window@outer);
lift-dc-math(*);
fixpoint {
forkify(*);
fork-guard-elim(*);
fork-unroll(*);
}
ccp(*);
gvn(*);
dce(*);
array-to-product(*);
sroa(*);
phi-elim(*);
predication(*);
simplify-cfg(*);
dce(*);
gvn(*);
gcm(*);
gvn(*);
phi-elim(*);
dce(*);
inline(*);
delete-uncalled(*);
let out = auto-outline(*);
gpu(out.median_window);
ip-sroa(*);
sroa(*);
dce(*);
gvn(*);
phi-elim(*);
dce(*);
forkify(*);
fork-guard-elim(*);
lift-dc-math(*);
forkify(*);
fork-guard-elim(*);
fork-unroll(out.median_window@outer);
lift-dc-math(*);
fixpoint {
forkify(*);
fork-guard-elim(*);
fork-unroll(*);
}
ccp(*);
gvn(*);
dce(*);
array-to-product(*);
sroa(*);
phi-elim(*);
predication(*);
simplify-cfg(*);
dce(*);
gvn(*);
gcm(*);
#![feature(concat_idents)]
juno_build::juno!("median");
use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo};
fn main() {
let m = vec![86, 72, 14, 5, 55,
25, 98, 89, 3, 66,
44, 81, 27, 3, 40,
18, 4, 57, 93, 34,
70, 50, 50, 18, 34];
let m = HerculesImmBox::from(m.as_slice());
let mut r = runner!(median_window);
let res =
async_std::task::block_on(async {
r.run(m.to()).await
});
assert_eq!(res, 57);
}
#[test]
fn test_median_window() {
main()
}
fn median_matrix<t: number, n, m: usize>(x: t[n, m]) -> t {
let tmp : t[n * m];
for i = 0 to n {
for j = 0 to m {
tmp[i * m + j] = x[i, j];
}
}
const cnt = n * m;
@outer for i = 0 to cnt - 1 {
for j = 0 to cnt - i - 1 {
if tmp[j] > tmp[j + 1] {
let t = tmp[j];
tmp[j] = tmp[j + 1];
tmp[j + 1] = t;
}
}
}
return tmp[cnt / 2];
}
#[entry]
fn median_window(x: i32[5, 5]) -> i32 {
let window: i32[3, 3];
for i = 0 to 3 {
for j = 0 to 3 {
window[i, j] = x[i + 1, j + 1];
}
}
return median_matrix::<_, 3, 3>(window);
}
......@@ -18,7 +18,7 @@ pub enum ScheduleCompilerError {
UndefinedMacro(String, Location),
NoSuchPass(String, Location),
IncorrectArguments {
expected: usize,
expected: String,
actual: usize,
loc: Location,
},
......@@ -81,11 +81,20 @@ enum Appliable {
}
impl Appliable {
fn num_args(&self) -> usize {
// Tests whether a given number of arguments is a valid number of arguments for this
fn is_valid_num_args(&self, num: usize) -> bool {
match self {
Appliable::Pass(pass) => pass.num_args(),
Appliable::Pass(pass) => pass.is_valid_num_args(num),
// Delete uncalled, Schedules, and devices do not take arguments
_ => 0,
_ => num == 0,
}
}
// Returns a description of the number of arguments this requires
fn valid_arg_nums(&self) -> &'static str {
match self {
Appliable::Pass(pass) => pass.valid_arg_nums(),
_ => "0",
}
}
}
......@@ -96,6 +105,9 @@ impl FromStr for Appliable {
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"array-slf" => Ok(Appliable::Pass(ir::Pass::ArraySLF)),
"array-to-product" | "array-to-prod" | "a2p" => {
Ok(Appliable::Pass(ir::Pass::ArrayToProduct))
}
"auto-outline" => Ok(Appliable::Pass(ir::Pass::AutoOutline)),
"ccp" => Ok(Appliable::Pass(ir::Pass::CCP)),
"crc" | "collapse-read-chains" => Ok(Appliable::Pass(ir::Pass::CRC)),
......@@ -306,9 +318,9 @@ fn compile_expr(
.parse()
.map_err(|s| ScheduleCompilerError::NoSuchPass(s, lexer.line_col(name)))?;
if args.len() != func.num_args() {
if !func.is_valid_num_args(args.len()) {
return Err(ScheduleCompilerError::IncorrectArguments {
expected: func.num_args(),
expected: func.valid_arg_nums().to_string(),
actual: args.len(),
loc: lexer.line_col(span),
});
......@@ -365,7 +377,7 @@ fn compile_expr(
if args.len() != params.len() {
return Err(ScheduleCompilerError::IncorrectArguments {
expected: params.len(),
expected: params.len().to_string(),
actual: args.len(),
loc: lexer.line_col(span),
});
......
......@@ -3,6 +3,7 @@ use hercules_ir::ir::{Device, Schedule};
#[derive(Debug, Copy, Clone)]
pub enum Pass {
ArraySLF,
ArrayToProduct,
AutoOutline,
CCP,
CRC,
......@@ -39,13 +40,25 @@ pub enum Pass {
}
impl Pass {
pub fn num_args(&self) -> usize {
pub fn is_valid_num_args(&self, num: usize) -> bool {
match self {
Pass::Xdot => 1,
Pass::ForkChunk => 4,
Pass::ForkFissionBufferize => 2,
Pass::ForkInterchange => 2,
_ => 0,
Pass::ArrayToProduct => num == 0 || num == 1,
Pass::Xdot => num == 0 || num == 1,
Pass::ForkChunk => num == 4,
Pass::ForkFissionBufferize => num == 2,
Pass::ForkInterchange => num == 2,
_ => num == 0,
}
}
pub fn valid_arg_nums(&self) -> &'static str {
match self {
Pass::ArrayToProduct => "0 or 1",
Pass::Xdot => "0 or 1",
Pass::ForkChunk => "4",
Pass::ForkFissionBufferize => "2",
Pass::ForkInterchange => "2",
_ => "0",
}
}
}
......
......@@ -1521,6 +1521,34 @@ fn run_pass(
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::ArrayToProduct => {
assert!(args.len() <= 1);
let max_size = match args.get(0) {
Some(Value::Integer { val }) => Some(*val),
Some(_) => {
return Err(SchedulerError::PassError {
pass: "array-to-product".to_string(),
error: "expected integer argument".to_string(),
});
}
None => None,
};
pm.make_typing();
let typing = pm.typing.take().unwrap();
for (func, types) in build_selection(pm, selection, false)
.into_iter()
.zip(typing.iter())
{
let Some(mut func) = func else {
continue;
};
array_to_product(&mut func, types, max_size);
changed |= func.modified();
}
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::AutoOutline => {
let Some(funcs) = selection_of_functions(pm, selection) else {
return Err(SchedulerError::PassError {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment