Skip to content
Snippets Groups Projects
Commit 1cdace83 authored by rarbore2's avatar rarbore2
Browse files

Einsum analysis

parent 31ed43d8
No related branches found
No related tags found
1 merge request!135Einsum analysis
use std::collections::{HashMap, HashSet};
use std::iter::zip;
use crate::*;
/*
* Math expressions are stored as a simple tree.
*/
// One dimension of a fork. The first field is the index of the dimension; the
// second field is the dynamic constant taken from the fork's factors for that
// dimension.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ForkDimension(pub usize, pub DynamicConstantID);
// A single math expression. Operands are not stored inline: they are referred
// to by `MathID`, which indexes into the `MathEnv` holding the expressions.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MathExpr {
// Zero constant of a specific type.
Zero(TypeID),
// One constant of a specific type.
One(TypeID),
// Opaque value corresponding to a particular node in the original IR.
OpaqueNode(NodeID),
// Thread ID from the fork corresponding to the reduce being expressed.
// Thread IDs from outside forks are considered OpaqueNodes.
ThreadID(ForkDimension),
// Sum reduction of the operand expression over the listed fork dimensions.
SumReduction(MathID, Box<[ForkDimension]>),
// Comprehend a scalar expression into an array over the listed fork
// dimensions.
Comprehension(MathID, Box<[ForkDimension]>),
// Read from an array: the array expression, then one expression per
// position index.
Read(MathID, Box<[MathID]>),
// Math ops, each holding its (left, right) operands.
Add(MathID, MathID),
Sub(MathID, MathID),
Mul(MathID, MathID),
Div(MathID, MathID),
Rem(MathID, MathID),
}
// Storage for math expressions. A `MathID` is used as an index into this
// vector.
pub type MathEnv = Vec<MathExpr>;
// Declare the `MathID` identifier type using the crate's `define_id_type!`
// macro.
define_id_type!(MathID);
/*
* Top level function to run "einsum" analysis on fork-joins. This is a terrible
* name for this analysis, since it's actually more general than identifying
* einsums, but einsum syntax has a passing resemblance to the idea of this
* analysis and it's what we keep calling it, so we're doomed to this bad name.
* The idea of this analysis is to convert some fork-joins into pure math
* expressions that we can rewrite into intrinsic functions for higher level
* operators like matmul. Specifically, this function returns a map from each
* reduce node to a math expression.
*/
pub fn einsum(
function: &Function,
types: &Vec<Type>,
constants: &Vec<Constant>,
def_use: &ImmutableDefUseMap,
typing: &Vec<TypeID>,
fork_join_map: &HashMap<NodeID, NodeID>,
fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
data_nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
) -> (MathEnv, HashMap<NodeID, MathID>) {
let mut env = vec![];
let mut rev_env = HashMap::new();
let mut result = HashMap::new();
// Iterate fork-joins bottom-up, since we need to compute the math
// expressions of inner reduces before getting to outer reduces. Since fork-
// joins are strictly nested, we can literally iterate entries of
// `fork_join_nest` in decreasing order of nesting size to accomplish this.
let mut nests: Vec<_> = fork_join_nest
.into_iter()
.filter(|(id, _)| function.nodes[id.idx()].is_fork())
.collect();
nests.sort_by(|a, b| b.1.len().cmp(&a.1.len()));
for fork in nests.into_iter().map(|(id, _)| *id) {
let Node::Fork {
control: _,
ref factors,
} = function.nodes[fork.idx()]
else {
panic!()
};
let join = fork_join_map[&fork];
let thread_ids: Vec<_> = def_use
.get_users(fork)
.into_iter()
.filter_map(|id| {
function.nodes[id.idx()]
.try_thread_id()
.map(|(_, dim)| (*id, dim))
})
.collect();
let reduces = def_use
.get_users(join)
.into_iter()
.filter_map(|id| function.nodes[id.idx()].try_reduce().map(|v| (*id, v)));
let mut ctx = EinsumContext {
function,
typing,
constants,
data_nodes_in_fork_joins,
fork,
factors,
thread_ids: &thread_ids,
so_far: &mut result,
env: &mut env,
rev_env: &mut rev_env,
};
// Compute a math expression for each reduce node in the fork-join with
// appropriate schedules.
for (reduce, (_, init, reduct)) in reduces {
// The reduce defines an array where each fork dimension corresponds
// to one array dimension.
if function.schedules[reduce.idx()].contains(&Schedule::ParallelReduce)
&& let Node::Write {
collect,
data,
ref indices,
} = function.nodes[reduct.idx()]
&& collect == reduce
&& indices.len() == 1
&& let Some(indices) = indices[0].try_position()
&& let Some(dimension_bounds) = indices
.into_iter()
.map(|id| {
function.nodes[id.idx()]
.try_thread_id()
.filter(|(tid_fork, _)| *tid_fork == fork)
.map(|(_, dim)| dim)
})
.collect::<Option<Vec<usize>>>()
&& let Type::Array(_, ref array_bounds) = types[typing[reduce.idx()].idx()]
&& zip(array_bounds.into_iter(), dimension_bounds.iter())
.all(|(array, fork)| *array == factors[*fork])
{
let data_expr = ctx.compute_math_expr(data);
let reduce_expr = MathExpr::Comprehension(
data_expr,
dimension_bounds
.into_iter()
.map(|dim| ForkDimension(dim, factors[dim]))
.collect(),
);
// We don't need to consider the initializer, since the writes
// cover the whole array.
let total_id = ctx.intern_math_expr(reduce_expr);
ctx.result_insert(reduce, total_id);
}
// The reduce defines a sum reduction over a set of fork dimensions.
else if function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
&& let Node::Binary {
op: BinaryOperator::Add,
left,
right,
} = function.nodes[reduct.idx()]
&& (left == reduce || right == reduce)
{
let data_expr = ctx.compute_math_expr(if left == reduce { right } else { left });
let reduce_expr = MathExpr::SumReduction(
data_expr,
factors
.into_iter()
.enumerate()
.map(|(dim, factor)| ForkDimension(dim, *factor))
.collect(),
);
// Add the initializer.
let reduce_expr_id = ctx.intern_math_expr(reduce_expr);
let init_expr_id = ctx.compute_math_expr(init);
let add_expr = MathExpr::Add(init_expr_id, reduce_expr_id);
let total_id = ctx.intern_math_expr(add_expr);
ctx.result_insert(reduce, total_id);
}
}
}
(env, result)
}
// State shared while computing the math expressions for the reduces of one
// fork-join.
struct EinsumContext<'a> {
// Function whose nodes are being expressed.
function: &'a Function,
// Type of each node, indexed by node ID.
typing: &'a Vec<TypeID>,
// Constants, indexed by constant ID; consulted to spot zero and one.
constants: &'a Vec<Constant>,
// NOTE(review): stored here but not read by any method visible in this file.
data_nodes_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>,
// Fork currently being analyzed.
fork: NodeID,
// Factors of `fork`, indexed by fork dimension.
factors: &'a [DynamicConstantID],
// Thread ID nodes using `fork`, each paired with its dimension.
thread_ids: &'a Vec<(NodeID, usize)>,
// Math expressions recorded for reduce nodes so far.
so_far: &'a mut HashMap<NodeID, MathID>,
// Every math expression created so far; indexed by `MathID`.
env: &'a mut MathEnv,
// Reverse of `env`, used to give equal expressions the same `MathID`.
rev_env: &'a mut HashMap<MathExpr, MathID>,
}
impl<'a> EinsumContext<'a> {
fn compute_math_expr(&mut self, id: NodeID) -> MathID {
let math_expr = match self.function.nodes[id.idx()] {
Node::Constant { id: cons_id } if self.constants[cons_id.idx()].is_zero() => {
MathExpr::Zero(self.typing[id.idx()])
}
Node::Constant { id: cons_id } if self.constants[cons_id.idx()].is_one() => {
MathExpr::One(self.typing[id.idx()])
}
Node::ThreadID { control, dimension } if control == self.fork => {
MathExpr::ThreadID(ForkDimension(dimension, self.factors[dimension]))
}
Node::Binary { op, left, right } if representable(op) => {
let left = self.compute_math_expr(left);
let right = self.compute_math_expr(right);
match op {
BinaryOperator::Add => MathExpr::Add(left, right),
BinaryOperator::Sub => MathExpr::Sub(left, right),
BinaryOperator::Mul => MathExpr::Mul(left, right),
BinaryOperator::Div => MathExpr::Div(left, right),
BinaryOperator::Rem => MathExpr::Rem(left, right),
_ => unreachable!(),
}
}
Node::Read {
collect,
ref indices,
} if indices.len() == 1
&& let Some(indices) = indices[0].try_position() =>
{
let collect = self.compute_math_expr(collect);
let indices = indices
.into_iter()
.map(|id| self.compute_math_expr(*id))
.collect();
MathExpr::Read(collect, indices)
}
Node::Reduce {
control: _,
init: _,
reduct: _,
} => {
let reduce = self.so_far[&id];
// Substitute opaque uses of thread ID nodes in inner expression
// with thread ID math expression, and increment inner-fork
// dimensions (alpha renaming).
return self.substitute_new_dims(reduce);
}
_ => MathExpr::OpaqueNode(id),
};
self.intern_math_expr(math_expr)
}
fn intern_math_expr(&mut self, expr: MathExpr) -> MathID {
if let Some(id) = self.rev_env.get(&expr) {
*id
} else {
let id = MathID::new(self.env.len());
self.env.push(expr.clone());
self.rev_env.insert(expr, id);
id
}
}
fn result_insert(&mut self, node: NodeID, math: MathID) {
self.so_far.insert(node, math);
}
fn substitute_new_dims(&mut self, id: MathID) -> MathID {
match self.env[id.idx()] {
MathExpr::OpaqueNode(opaque)
if let Some((_, dim)) = self
.thread_ids
.into_iter()
.filter(|(node, _)| *node == opaque)
.next() =>
{
self.intern_math_expr(MathExpr::ThreadID(ForkDimension(*dim, self.factors[*dim])))
}
MathExpr::ThreadID(dim) => self.intern_math_expr(MathExpr::ThreadID(ForkDimension(
dim.0 + self.factors.len(),
dim.1,
))),
MathExpr::SumReduction(id, ref dims) => {
let dims = dims
.into_iter()
.map(|dim| ForkDimension(dim.0 + self.factors.len(), dim.1))
.collect();
let id = self.substitute_new_dims(id);
self.intern_math_expr(MathExpr::SumReduction(id, dims))
}
MathExpr::Comprehension(id, ref dims) => {
let dims = dims
.into_iter()
.map(|dim| ForkDimension(dim.0 + self.factors.len(), dim.1))
.collect();
let id = self.substitute_new_dims(id);
self.intern_math_expr(MathExpr::Comprehension(id, dims))
}
MathExpr::Read(array, ref indices) => {
let indices = indices
.clone()
.iter()
.map(|id| self.substitute_new_dims(*id))
.collect();
let array = self.substitute_new_dims(array);
self.intern_math_expr(MathExpr::Read(array, indices))
}
MathExpr::Add(left, right) => {
let left = self.substitute_new_dims(left);
let right = self.substitute_new_dims(right);
self.intern_math_expr(MathExpr::Add(left, right))
}
MathExpr::Sub(left, right) => {
let left = self.substitute_new_dims(left);
let right = self.substitute_new_dims(right);
self.intern_math_expr(MathExpr::Sub(left, right))
}
MathExpr::Mul(left, right) => {
let left = self.substitute_new_dims(left);
let right = self.substitute_new_dims(right);
self.intern_math_expr(MathExpr::Mul(left, right))
}
MathExpr::Div(left, right) => {
let left = self.substitute_new_dims(left);
let right = self.substitute_new_dims(right);
self.intern_math_expr(MathExpr::Div(left, right))
}
MathExpr::Rem(left, right) => {
let left = self.substitute_new_dims(left);
let right = self.substitute_new_dims(right);
self.intern_math_expr(MathExpr::Rem(left, right))
}
_ => id,
}
}
fn debug_print_expr(&self, id: MathID) {
match self.env[id.idx()] {
MathExpr::Zero(_) => print!("0"),
MathExpr::One(_) => print!("1"),
MathExpr::OpaqueNode(id) => print!("{:?}", id),
MathExpr::ThreadID(dim) => print!("#{}", dim.0),
MathExpr::SumReduction(id, ref dims) => {
print!("Sum (");
for dim in dims {
print!("#{}/{:?},", dim.0, dim.1);
}
print!(") ");
self.debug_print_expr(id);
}
MathExpr::Comprehension(id, ref dims) => {
print!("[");
for dim in dims {
print!("#{}/{:?},", dim.0, dim.1);
}
print!("] ");
self.debug_print_expr(id);
}
MathExpr::Read(id, ref pos) => {
print!("read(");
self.debug_print_expr(id);
for pos in pos {
print!(", ");
self.debug_print_expr(*pos);
}
print!(")");
}
MathExpr::Add(left, right) => {
print!("+(");
self.debug_print_expr(left);
print!(", ");
self.debug_print_expr(right);
print!(")");
}
MathExpr::Sub(left, right) => {
print!("-(");
self.debug_print_expr(left);
print!(", ");
self.debug_print_expr(right);
print!(")");
}
MathExpr::Mul(left, right) => {
print!("*(");
self.debug_print_expr(left);
print!(", ");
self.debug_print_expr(right);
print!(")");
}
MathExpr::Div(left, right) => {
print!("/(");
self.debug_print_expr(left);
print!(", ");
self.debug_print_expr(right);
print!(")");
}
MathExpr::Rem(left, right) => {
print!("%(");
self.debug_print_expr(left);
print!(", ");
self.debug_print_expr(right);
print!(")");
}
}
}
}
// Whether a binary operator has a counterpart among the math ops of
// `MathExpr` (add, subtract, multiply, divide, remainder).
fn representable(op: BinaryOperator) -> bool {
    matches!(
        op,
        BinaryOperator::Add
            | BinaryOperator::Sub
            | BinaryOperator::Mul
            | BinaryOperator::Div
            | BinaryOperator::Rem
    )
}
......@@ -2,6 +2,7 @@
coroutines,
coroutine_trait,
let_chains,
if_let_guard,
stmt_expr_attributes,
iter_intersperse
)]
......@@ -14,6 +15,7 @@ pub mod def_use;
pub mod device;
pub mod dom;
pub mod dot;
pub mod einsum;
pub mod fork_join_analysis;
pub mod ir;
pub mod loops;
......@@ -30,6 +32,7 @@ pub use crate::def_use::*;
pub use crate::device::*;
pub use crate::dom::*;
pub use crate::dot::*;
pub use crate::einsum::*;
pub use crate::fork_join_analysis::*;
pub use crate::ir::*;
pub use crate::loops::*;
......
......@@ -6,6 +6,8 @@ auto-outline(*);
ip-sroa(*);
sroa(*);
dce(*);
infer-schedules(*);
fork-split(*);
unforkify(*);
dce(*);
......
......@@ -4,8 +4,6 @@ fn main() {
JunoCompiler::new()
.file_in_src("matmul.jn")
.unwrap()
//.schedule_in_src("sched.sch")
//.unwrap()
.build()
.unwrap();
}
// Setup macro: runs gvn, dce, and phi-elim, in that order, on the selection X.
macro juno-setup!(X) {
gvn(X);
dce(X);
phi-elim(X);
}
// Default pass pipeline for the selection X. The passes run in the order
// listed; dce is re-run after most of the other passes.
macro default!(X) {
dce(X);
crc(X);
dce(X);
slf(X);
dce(X);
inline(X);
ip-sroa(X);
sroa(X);
phi-elim(X);
dce(X);
ccp(X);
dce(X);
gvn(X);
dce(X);
write-predication(X);
phi-elim(X);
dce(X);
crc(X);
dce(X);
slf(X);
dce(X);
predication(X);
dce(X);
ccp(X);
dce(X);
gvn(X);
dce(X);
lift-dc-math(X);
dce(X);
gvn(X);
dce(X);
}
// Passes run on the selection X right before code generation.
// NOTE(review): the first three passes are applied to `*` rather than to the
// macro argument X — confirm this is intentional.
macro codegen-prep!(X) {
verify(*);
ip-sroa(*);
sroa(*);
infer-schedules(X);
dce(X);
gcm(X);
dce(X);
phi-elim(X);
float-collections(X);
gcm(X);
}
// Apply the setup and default pipelines to the `*` selection.
juno-setup!(*);
default!(*);
// your stuff here.
// Repeat the fork-related passes below inside a fixpoint block.
// NOTE(review): `stop after 13` presumably caps the number of iterations at
// 13 — confirm against the scheduler's documentation.
fixpoint stop after 13 {
forkify(*);
fork-guard-elim(*);
fork-coalesce(*);
phi-elim(*);
dce(*);
}
xdot[true](*);
// serialize(*);
// Run fork-split and unforkify, clean up with gvn and dce, then run
// auto-outline and the codegen preparation pipeline.
fork-split(*);
unforkify(*);
gvn(*);
dce(*);
auto-outline(*);
codegen-prep!(*);
......@@ -181,6 +181,7 @@ pub struct PassManager {
pub loops: Option<Vec<LoopTree>>,
pub reduce_cycles: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
pub reduce_einsums: Option<Vec<(MathEnv, HashMap<NodeID, MathID>)>>,
pub collection_objects: Option<CollectionObjects>,
pub callgraph: Option<CallGraph>,
pub devices: Option<Vec<Device>>,
......@@ -216,6 +217,7 @@ impl PassManager {
loops: None,
reduce_cycles: None,
data_nodes_in_fork_joins: None,
reduce_einsums: None,
collection_objects: None,
callgraph: None,
devices: None,
......@@ -392,6 +394,48 @@ impl PassManager {
}
}
// Fill in `reduce_einsums` with the result of the einsum analysis for every
// function, unless it has already been filled in. The analyses that einsum
// consumes are made first.
pub fn make_reduce_einsums(&mut self) {
    // Nothing to do when the analysis is already available.
    if self.reduce_einsums.is_some() {
        return;
    }

    self.make_def_uses();
    self.make_typing();
    self.make_fork_join_maps();
    self.make_fork_join_nests();
    self.make_data_nodes_in_fork_joins();

    let all_def_uses = self.def_uses.as_ref().unwrap();
    let all_typings = self.typing.as_ref().unwrap();
    let all_fork_join_maps = self.fork_join_maps.as_ref().unwrap();
    let all_fork_join_nests = self.fork_join_nests.as_ref().unwrap();
    let all_data_nodes = self.data_nodes_in_fork_joins.as_ref().unwrap();

    // Walk the per-function analyses in lockstep with the functions.
    let per_function = self
        .functions
        .iter()
        .zip(all_def_uses.iter())
        .zip(all_typings.iter())
        .zip(all_fork_join_maps.iter())
        .zip(all_fork_join_nests.iter())
        .zip(all_data_nodes.iter());

    let mut einsums = Vec::new();
    for (
        ((((function, def_use), typing), fork_join_map), fork_join_nest),
        data_nodes_in_fork_joins,
    ) in per_function
    {
        einsums.push(einsum(
            function,
            &self.types.borrow(),
            &self.constants.borrow(),
            def_use,
            typing,
            fork_join_map,
            fork_join_nest,
            data_nodes_in_fork_joins,
        ));
    }
    self.reduce_einsums = Some(einsums);
}
pub fn make_collection_objects(&mut self) {
if self.collection_objects.is_none() {
self.make_reverse_postorders();
......@@ -463,6 +507,7 @@ impl PassManager {
self.loops = None;
self.reduce_cycles = None;
self.data_nodes_in_fork_joins = None;
self.reduce_einsums = None;
self.collection_objects = None;
self.callgraph = None;
self.devices = None;
......@@ -643,7 +688,7 @@ impl PassManager {
}
pub fn schedule_codegen(
mut module: Module,
module: Module,
schedule: ScheduleStmt,
mut stringtab: StringTable,
mut env: Env<usize, Value>,
......@@ -913,8 +958,8 @@ fn add_schedule(pm: &mut PassManager, sched: Schedule, label_ids: Vec<LabelInfo>
.labels
.iter()
.enumerate()
.filter(|(i, ls)| ls.contains(&label))
.map(|(i, ls)| i)
.filter(|(_, ls)| ls.contains(&label))
.map(|(i, _)| i)
.collect::<Vec<_>>();
for node in nodes {
pm.functions[func.idx()].schedules[node].push(sched.clone());
......@@ -1350,7 +1395,7 @@ fn run_pass(
let Some(mut func) = func else {
continue;
};
// TODO: uses direct return from forkify for now instead of
// TODO: uses direct return from forkify for now instead of
// func.modified, see comment on top of `forkify` for why. Fix
// this eventually.
changed |= forkify(&mut func, control_subgraph, fork_join_map, loop_nest);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment