Skip to content
Snippets Groups Projects
Commit c84c8c7a authored by Aaron Councilman's avatar Aaron Councilman Committed by rarbore2
Browse files

Array to prod

parent d5985d59
No related branches found
No related tags found
1 merge request!174Array to prod
......@@ -1231,6 +1231,16 @@ dependencies = [
"with_builtin_macros",
]
[[package]]
name = "juno_median_window"
version = "0.1.0"
dependencies = [
"async-std",
"hercules_rt",
"juno_build",
"with_builtin_macros",
]
[[package]]
name = "juno_multi_device"
version = "0.1.0"
......
......@@ -34,4 +34,5 @@ members = [
"juno_samples/fork_join_tests",
"juno_samples/multi_device",
"juno_samples/products",
"juno_samples/median_window",
]
use hercules_ir::define_id_type;
use hercules_ir::ir::*;
use bitvec::prelude::*;
use crate::*;
use std::collections::{HashMap, HashSet};
use std::marker::PhantomData;
/*
* Top level function for array to product which will convert constant
* sized arrays into products if the array is only accessed at indices which
* are constants.
*
* To identify the collections we can convert we look at each constant-sized
* array constant and compute the set which includes the constant node and is
* closed under the following properties:
* - For each collection in the set, its uses are in the set
* - For each node that uses a collection, all collections it uses are in the
* set
* From this set, we then determine whether this whole set can be converted to
* operating on products, rather than arrays, as follows
* - Each read and write node must be to a constant index
* - It may not contain any arguments (we could generate code to read an
* array argument into a product, but do not do so for now)
* - There are no call or return nodes in the set (this would mean that the
* collections are consumed by a call or return, again we could reconstruct
* the array where needed but do not do so for now and so have this
* restriction)
* - All nodes in the set are editable (if we cannot modify some node then the
* conversion will fail)
*
* The max_size argument allows the user to specify a limit on the size of arrays
* that should be converted to products. If the number of elements in the array
* is larger than the max size the array will not be converted.
*/
// Rewrites every node selected by array_usage_analysis so that it operates on a (nested) product
// instead of a constant-sized array. `types` is indexed by node index and `max_size` is forwarded
// to the analysis unchanged.
pub fn array_to_product(editor: &mut FunctionEditor, types: &[TypeID], max_size: Option<usize>) {
// Mask, indexed by node index, of the nodes that are to be rewritten
let replace_nodes = array_usage_analysis(editor, types, max_size);
// The node count is read once, before any edit is applied, so the loop range is fixed
let num_nodes = editor.func().nodes.len();
// Replace nodes
for node_idx in 0..num_nodes {
if !replace_nodes[node_idx] {
continue;
}
let node = NodeID::new(node_idx);
// We can replace the array(s) this node uses with a product. What we have to do depends on
// the type of the node
match &editor.func().nodes[node_idx] {
// Phi, Reduce, and Ternary just use the whole collection, they do not need to change,
// except as they will be modified by replace_all_uses_of
Node::Phi { .. }
| Node::Reduce { .. }
| Node::Ternary {
op: TernaryOperator::Select,
..
} => {}
Node::Constant { id } => {
assert!(editor.get_constant(*id).is_array());
// Element type and constant extents of the array being replaced; the unwraps rely on
// the analysis having only selected array constants whose extents are all constants
let element: TypeID = editor.get_type(types[node_idx]).try_element_type().unwrap();
let dims: Vec<usize> = editor
.get_type(types[node_idx])
.try_extents()
.unwrap()
.iter()
.map(|dc| editor.get_dynamic_constant(*dc).try_constant().unwrap())
.collect();
// Replace the constant by a product that is a product (for each dimension) and the
// elements are zero'd
editor.edit(|mut edit| {
let element_zero = edit.add_zero_constant(element);
// Fold from the last dimension to the first: the innermost product holds `dim`
// copies of the element and every earlier dimension wraps `dim` copies of the
// product built so far
let (constant, _) = dims.into_iter().rfold(
(element_zero, element),
|(cur_const, cur_type), dim| {
let new_type = edit.add_type(Type::Product(vec![cur_type; dim].into()));
let new_const = edit.add_constant(Constant::Product(
new_type,
vec![cur_const; dim].into(),
));
(new_const, new_type)
},
);
// Swap the array constant node for a node holding the product constant
let new_val = edit.add_node(Node::Constant { id: constant });
let edit = edit.replace_all_uses(node, new_val)?;
edit.delete_node(node)
});
}
Node::Read { collect, indices } => {
let collect = *collect;
// Constant positions become field indices; this panics on any other index form
let new_indices = convert_indices_to_prod(editor, indices);
editor.edit(|mut edit| {
// Same read, but indexing fields of the product
let new_val = edit.add_node(Node::Read {
collect,
indices: new_indices,
});
let edit = edit.replace_all_uses(NodeID::new(node_idx), new_val)?;
edit.delete_node(node)
});
}
Node::Write {
collect,
data,
indices,
} => {
let collect = *collect;
let data = *data;
// Constant positions become field indices; this panics on any other index form
let new_indices = convert_indices_to_prod(editor, indices);
editor.edit(|mut edit| {
// Same write, but indexing fields of the product
let new_val = edit.add_node(Node::Write {
collect,
data,
indices: new_indices,
});
let edit = edit.replace_all_uses(NodeID::new(node_idx), new_val)?;
edit.delete_node(node)
});
}
// The analysis only marks sets in which every node passes can_replace, so any other node
// kind reaching this point is treated as a bug
node => panic!("Node cannot be replaced: {:?}", node),
}
}
}
// Translates array indices into product field indices. Every position of every
// Index::Position entry must be a constant node holding a u64; each one yields one
// Index::Field, in order, so a multi-dimensional position becomes a chain of field accesses.
// Panics on any other index form, on a non-constant position, or on a non-u64 constant.
fn convert_indices_to_prod(editor: &FunctionEditor, indices: &[Index]) -> Box<[Index]> {
    let mut fields = Vec::new();
    for index in indices {
        let Index::Position(positions) = index else {
            panic!("Index cannot be replaced: {:?}", index);
        };
        for pos in positions.iter() {
            let const_id = editor.func().nodes[pos.idx()]
                .try_constant()
                .expect("Array position must be constant");
            let field = match *editor.get_constant(const_id) {
                Constant::UnsignedInteger64(value) => value as usize,
                ref other => panic!("Position should be u64 constant: {:?}", other),
            };
            fields.push(Index::Field(field));
        }
    }
    fields.into_boxed_slice()
}
// Given the editor, will compute a mask of which nodes are to be converted
// from using a constant sized array into using a product
// Returns a bit mask with one bit per node of the function; a set bit means the node belongs to a
// group of array users that can be rewritten to operate on products.
//
// - `types` is indexed by node index.
// - `max_size`, when present, is the largest element count an array constant may have and still
//   be considered for conversion.
fn array_usage_analysis(
    editor: &FunctionEditor,
    types: &[TypeID],
    max_size: Option<usize>,
) -> BitVec<u8, Lsb0> {
    let num_nodes = editor.func().nodes.len();

    // Step 1: identify the constant nodes that are constant sized arrays no larger than the
    // max_size, these are what we are interested in converting into products
    let sources = editor
        .func()
        .nodes
        .iter()
        .enumerate()
        .filter_map(|(idx, node)| {
            let Node::Constant { id } = node else {
                return None;
            };
            let Constant::Array(array_type) = *editor.get_constant(*id) else {
                return None;
            };
            let typ = editor.get_type(array_type);
            let dims = typ.try_extents()?;
            // Total number of elements. The fold stops with None as soon as a dimension is not a
            // constant. It also stops with None if the running product would overflow usize: such
            // an array is never worth converting, and a wrapped product could otherwise slip under
            // max_size.
            let elements = dims.iter().try_fold(1usize, |prod, dc| {
                editor
                    .get_dynamic_constant(*dc)
                    .try_constant()
                    .and_then(|dim| prod.checked_mul(dim))
            })?;
            match max_size {
                // Too many elements, don't convert
                Some(max_size) if elements > max_size => None,
                _ => Some(NodeID::new(idx)),
            }
        })
        .collect::<Vec<_>>();

    // Step 2: collect the collection information we need for the (whole) function. For each node
    // that returns a collection (that in reference semantics returns the same reference as some of
    // its inputs) union with all of its users. The nodes that matter in this are arguments,
    // constants, writes, phis, selects, and reduces with array types.
    // NOTE(review): other array-typed nodes (for example a call result) are not unioned with their
    // users here; confirm such a node can never feed a phi/select together with a source constant.
    let mut analysis = UnionFind::new();
    for node_idx in 0..num_nodes {
        let node_id = NodeID::new(node_idx);
        if editor.get_type(types[node_idx]).is_array() {
            match editor.func().nodes[node_idx] {
                Node::Phi { .. }
                | Node::Reduce { .. }
                | Node::Parameter { .. }
                | Node::Constant { .. }
                | Node::Ternary {
                    op: TernaryOperator::Select,
                    ..
                }
                | Node::Write { .. } => {
                    for user in editor.get_users(node_id) {
                        analysis.union(node_id, user);
                    }
                }
                _ => {}
            }
        }
    }
    // One group of nodes per source; sources that ended up in the same set share a single group
    let sets = analysis.sets(&sources);

    // Step 3: determine which sets can be converted and mark the nodes in those sets. A set is
    // converted only as a whole: every member must be mutable and pass can_replace.
    let mut result = bitvec![u8, Lsb0; 0; num_nodes];
    for nodes in sets {
        if nodes
            .iter()
            .all(|node_id| editor.is_mutable(*node_id) && can_replace(editor, *node_id))
        {
            for node_id in nodes {
                result.set(node_id.idx(), true);
            }
        }
    }
    result
}
fn can_replace(editor: &FunctionEditor, node: NodeID) -> bool {
match &editor.func().nodes[node.idx()] {
// Reads and writes must be at constant indices
Node::Read { indices, .. } | Node::Write { indices, .. } => {
indices.iter().all(|idx| match idx {
Index::Position(pos) => pos
.iter()
.all(|node| editor.func().nodes[node.idx()].is_constant()),
_ => false,
})
}
// phi, reduce, constants, and select can always be replaced if their users and uses allow
// it, which is handled by the construction of the set
Node::Phi { .. }
| Node::Reduce { .. }
| Node::Constant { .. }
| Node::Ternary {
op: TernaryOperator::Select,
..
} => true,
// No other nodes allow replacement
_ => false,
}
}
define_id_type!(SetID);
// One entry of the union-find forest, stored at the index of the element it describes.
#[derive(Clone, Debug)]
struct UnionFindNode {
// Link towards the representative of the set; an entry whose parent index equals its own index
// is a representative
parent: SetID,
// Rank compared by `link` to decide which of two representatives becomes the parent
rank: usize,
}
// Union-find (disjoint set) structure over ids of type T. Entries are created lazily, so an id
// only occupies a slot once it (or a larger id) has been passed to `find` or `union`.
#[derive(Clone, Debug)]
struct UnionFind<T> {
// Forest entries, indexed by the id's index
sets: Vec<UnionFindNode>,
// T only appears in method signatures, so it is recorded here without storing a value
_phantom: PhantomData<T>,
}
impl<T: ID> UnionFind<T> {
    // Creates an empty structure; elements are added on demand by `find` and `union`.
    pub fn new() -> Self {
        UnionFind {
            sets: vec![],
            _phantom: PhantomData,
        }
    }

    // Grows the forest so that index `size` (inclusive) is valid.
    fn extend_past(&mut self, size: usize) {
        for i in self.sets.len()..=size {
            // The new nodes we add are in their own sets and have rank 0
            self.sets.push(UnionFindNode {
                parent: SetID::new(i),
                rank: 0,
            });
        }
    }

    // Returns the representative of the set containing `x`, creating the entry for `x` first if
    // it does not exist yet.
    pub fn find(&mut self, x: T) -> SetID {
        self.extend_past(x.idx());
        self.find_set(x.idx())
    }

    // Returns the representative of the entry at index `x`, which must already exist. Every entry
    // visited on the way is re-pointed directly at the representative (path compression).
    fn find_set(&mut self, x: usize) -> SetID {
        let mut parent = self.sets[x].parent;
        if parent.idx() != x {
            parent = self.find_set(parent.idx());
            self.sets[x].parent = parent;
        }
        parent
    }

    // Merges the sets containing `x` and `y`.
    pub fn union(&mut self, x: T, y: T) {
        let x = self.find(x);
        let y = self.find(y);
        self.link(x, y);
    }

    // Links two representatives, attaching the lower-ranked one below the higher-ranked one.
    fn link(&mut self, x: SetID, y: SetID) {
        // Both arguments already name the same set: there is nothing to merge. Without this guard
        // the equal-rank branch below would bump the rank on every redundant union, so ranks
        // would stop reflecting tree height and union-by-rank would lose its effect.
        if x.idx() == y.idx() {
            return;
        }
        if self.sets[x.idx()].rank > self.sets[y.idx()].rank {
            self.sets[y.idx()].parent = x;
        } else {
            self.sets[x.idx()].parent = y;
            // Equal ranks: the merged tree may be one level taller than either input
            if self.sets[x.idx()].rank == self.sets[y.idx()].rank {
                self.sets[y.idx()].rank += 1;
            }
        }
    }

    // Returns, for each key, the members of the set that key belongs to. The result has one
    // vector per key, in key order. When several keys belong to the same set, only the last of
    // them receives the members and the others get an empty vector. Only elements that currently
    // have an entry in the forest are reported.
    pub fn sets(&mut self, keys: &[T]) -> Vec<Vec<T>> {
        // Representative -> position of the key that owns the set in the result
        let key_index = keys
            .iter()
            .enumerate()
            .map(|(i, k)| (self.find(*k), i))
            .collect::<HashMap<SetID, usize>>();
        let mut result = vec![vec![]; keys.len()];
        let num_elements = self.sets.len();
        for i in 0..num_elements {
            let set = self.find_set(i);
            // Elements whose set contains no key are of no interest
            let Some(idx) = key_index.get(&set) else {
                continue;
            };
            result[*idx].push(T::new(i));
        }
        result
    }
}
#![feature(let_chains)]
pub mod array_to_prod;
pub mod ccp;
pub mod crc;
pub mod dce;
......@@ -26,6 +27,7 @@ pub mod sroa;
pub mod unforkify;
pub mod utils;
pub use crate::array_to_prod::*;
pub use crate::ccp::*;
pub use crate::crc::*;
pub use crate::dce::*;
......
......@@ -23,13 +23,6 @@ pub fn lift_dc_math(editor: &mut FunctionEditor) {
};
DynamicConstant::Constant(cons as usize)
}
Node::DynamicConstant { id } => {
let Some(cons) = evaluate_dynamic_constant(id, &*editor.get_dynamic_constants())
else {
continue;
};
DynamicConstant::Constant(cons)
}
Node::Binary { op, left, right } => {
let (left, right) = if let (
Node::DynamicConstant { id: left },
......
[package]
name = "juno_median_window"
version = "0.1.0"
authors = ["Aaron Councilman <aaronjc4@illinois.edu>"]
edition = "2021"
[[bin]]
name = "juno_median_window"
path = "src/main.rs"
[features]
cuda = ["juno_build/cuda", "hercules_rt/cuda"]
[build-dependencies]
juno_build = { path = "../../juno_build" }
[dependencies]
juno_build = { path = "../../juno_build" }
hercules_rt = { path = "../../hercules_rt" }
with_builtin_macros = "0.1.0"
async-std = "*"
use juno_build::JunoCompiler;
// Build script entry point: compiles `median.jn` from the crate's `src` directory with the
// schedule that matches the enabled feature set. Any failure aborts the build via `unwrap`.
fn main() {
    // The `cuda` feature selects the GPU schedule, otherwise the CPU schedule is used
    let schedule = if cfg!(feature = "cuda") {
        "gpu.sch"
    } else {
        "cpu.sch"
    };
    let compiler = JunoCompiler::new().file_in_src("median.jn").unwrap();
    let compiler = compiler.schedule_in_src(schedule).unwrap();
    compiler.build().unwrap();
}
gvn(*);
phi-elim(*);
dce(*);
inline(*);
delete-uncalled(*);
let out = auto-outline(*);
cpu(out.median_window);
ip-sroa(*);
sroa(*);
dce(*);
gvn(*);
phi-elim(*);
dce(*);
forkify(*);
fork-guard-elim(*);
lift-dc-math(*);
forkify(*);
fork-guard-elim(*);
fork-unroll(out.median_window@outer);
lift-dc-math(*);
fixpoint {
forkify(*);
fork-guard-elim(*);
fork-unroll(*);
}
ccp(*);
gvn(*);
dce(*);
array-to-product(*);
sroa(*);
phi-elim(*);
predication(*);
simplify-cfg(*);
dce(*);
gvn(*);
gcm(*);
gvn(*);
phi-elim(*);
dce(*);
inline(*);
delete-uncalled(*);
let out = auto-outline(*);
gpu(out.median_window);
ip-sroa(*);
sroa(*);
dce(*);
gvn(*);
phi-elim(*);
dce(*);
forkify(*);
fork-guard-elim(*);
lift-dc-math(*);
forkify(*);
fork-guard-elim(*);
fork-unroll(out.median_window@outer);
lift-dc-math(*);
fixpoint {
forkify(*);
fork-guard-elim(*);
fork-unroll(*);
}
ccp(*);
gvn(*);
dce(*);
array-to-product(*);
sroa(*);
phi-elim(*);
predication(*);
simplify-cfg(*);
dce(*);
gvn(*);
gcm(*);
#![feature(concat_idents)]
juno_build::juno!("median");
use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo};
// Runs the compiled `median_window` entry point on a fixed 25-value input and checks the result.
fn main() {
    // 5 x 5 input matrix, written out one row of five values per line
    let input = vec![
        86, 72, 14, 5, 55,
        25, 98, 89, 3, 66,
        44, 81, 27, 3, 40,
        18, 4, 57, 93, 34,
        70, 50, 50, 18, 34,
    ];
    let input = HerculesImmBox::from(input.as_slice());
    let mut run = runner!(median_window);
    // The runner is asynchronous, so drive it to completion on the current thread
    let median = async_std::task::block_on(async { run.run(input.to()).await });
    assert_eq!(median, 57);
}
// Runs the sample's `main` under the test harness; `main` itself asserts the expected result, so
// a wrong median fails this test.
#[test]
fn test_median_window() {
main()
}
// Returns the median of the n x m matrix x: the elements are copied into a flat buffer, the
// buffer is sorted in ascending order and the element at index (n * m) / 2 is returned.
fn median_matrix<t: number, n, m: usize>(x: t[n, m]) -> t {
// Flatten x into tmp, placing x[i, j] at index i * m + j
let tmp : t[n * m];
for i = 0 to n {
for j = 0 to m {
tmp[i * m + j] = x[i, j];
}
}
const cnt = n * m;
// Bubble sort: each outer iteration compares and swaps adjacent pairs in the first cnt - i
// elements. The schedules in this sample refer to this loop through the `outer` label.
@outer for i = 0 to cnt - 1 {
for j = 0 to cnt - i - 1 {
if tmp[j] > tmp[j + 1] {
let t = tmp[j];
tmp[j] = tmp[j + 1];
tmp[j + 1] = t;
}
}
}
// Middle element of the sorted buffer
return tmp[cnt / 2];
}
// Entry point: returns the median of the 3 x 3 window of x that starts at row 1, column 1.
#[entry]
fn median_window(x: i32[5, 5]) -> i32 {
// Copy x[1..3, 1..3] into the window
let window: i32[3, 3];
for i = 0 to 3 {
for j = 0 to 3 {
window[i, j] = x[i + 1, j + 1];
}
}
return median_matrix::<_, 3, 3>(window);
}
......@@ -18,7 +18,7 @@ pub enum ScheduleCompilerError {
UndefinedMacro(String, Location),
NoSuchPass(String, Location),
IncorrectArguments {
expected: usize,
expected: String,
actual: usize,
loc: Location,
},
......@@ -81,11 +81,20 @@ enum Appliable {
}
impl Appliable {
fn num_args(&self) -> usize {
// Tests whether a given number of arguments is a valid number of arguments for this
fn is_valid_num_args(&self, num: usize) -> bool {
match self {
Appliable::Pass(pass) => pass.num_args(),
Appliable::Pass(pass) => pass.is_valid_num_args(num),
// Delete uncalled, Schedules, and devices do not take arguments
_ => 0,
_ => num == 0,
}
}
// Returns a description of the number of arguments this requires
fn valid_arg_nums(&self) -> &'static str {
match self {
Appliable::Pass(pass) => pass.valid_arg_nums(),
_ => "0",
}
}
}
......@@ -96,6 +105,9 @@ impl FromStr for Appliable {
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"array-slf" => Ok(Appliable::Pass(ir::Pass::ArraySLF)),
"array-to-product" | "array-to-prod" | "a2p" => {
Ok(Appliable::Pass(ir::Pass::ArrayToProduct))
}
"auto-outline" => Ok(Appliable::Pass(ir::Pass::AutoOutline)),
"ccp" => Ok(Appliable::Pass(ir::Pass::CCP)),
"crc" | "collapse-read-chains" => Ok(Appliable::Pass(ir::Pass::CRC)),
......@@ -306,9 +318,9 @@ fn compile_expr(
.parse()
.map_err(|s| ScheduleCompilerError::NoSuchPass(s, lexer.line_col(name)))?;
if args.len() != func.num_args() {
if !func.is_valid_num_args(args.len()) {
return Err(ScheduleCompilerError::IncorrectArguments {
expected: func.num_args(),
expected: func.valid_arg_nums().to_string(),
actual: args.len(),
loc: lexer.line_col(span),
});
......@@ -365,7 +377,7 @@ fn compile_expr(
if args.len() != params.len() {
return Err(ScheduleCompilerError::IncorrectArguments {
expected: params.len(),
expected: params.len().to_string(),
actual: args.len(),
loc: lexer.line_col(span),
});
......
......@@ -3,6 +3,7 @@ use hercules_ir::ir::{Device, Schedule};
#[derive(Debug, Copy, Clone)]
pub enum Pass {
ArraySLF,
ArrayToProduct,
AutoOutline,
CCP,
CRC,
......@@ -39,13 +40,25 @@ pub enum Pass {
}
impl Pass {
pub fn num_args(&self) -> usize {
pub fn is_valid_num_args(&self, num: usize) -> bool {
match self {
Pass::Xdot => 1,
Pass::ForkChunk => 4,
Pass::ForkFissionBufferize => 2,
Pass::ForkInterchange => 2,
_ => 0,
Pass::ArrayToProduct => num == 0 || num == 1,
Pass::Xdot => num == 0 || num == 1,
Pass::ForkChunk => num == 4,
Pass::ForkFissionBufferize => num == 2,
Pass::ForkInterchange => num == 2,
_ => num == 0,
}
}
pub fn valid_arg_nums(&self) -> &'static str {
match self {
Pass::ArrayToProduct => "0 or 1",
Pass::Xdot => "0 or 1",
Pass::ForkChunk => "4",
Pass::ForkFissionBufferize => "2",
Pass::ForkInterchange => "2",
_ => "0",
}
}
}
......
......@@ -1521,6 +1521,34 @@ fn run_pass(
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::ArrayToProduct => {
assert!(args.len() <= 1);
let max_size = match args.get(0) {
Some(Value::Integer { val }) => Some(*val),
Some(_) => {
return Err(SchedulerError::PassError {
pass: "array-to-product".to_string(),
error: "expected integer argument".to_string(),
});
}
None => None,
};
pm.make_typing();
let typing = pm.typing.take().unwrap();
for (func, types) in build_selection(pm, selection, false)
.into_iter()
.zip(typing.iter())
{
let Some(mut func) = func else {
continue;
};
array_to_product(&mut func, types, max_size);
changed |= func.modified();
}
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::AutoOutline => {
let Some(funcs) = selection_of_functions(pm, selection) else {
return Err(SchedulerError::PassError {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment