Skip to content
Snippets Groups Projects

Array store-to-load forwarding

Merged rarbore2 requested to merge fork_slf into main
Files
14
+ 92
117
@@ -31,11 +31,9 @@ pub enum MathExpr {
@@ -31,11 +31,9 @@ pub enum MathExpr {
Read(MathID, Box<[MathID]>),
Read(MathID, Box<[MathID]>),
// Math ops.
// Math ops.
Add(MathID, MathID),
Unary(UnaryOperator, MathID),
Sub(MathID, MathID),
Binary(BinaryOperator, MathID, MathID),
Mul(MathID, MathID),
Ternary(TernaryOperator, MathID, MathID, MathID),
Div(MathID, MathID),
Rem(MathID, MathID),
}
}
pub type MathEnv = Vec<MathExpr>;
pub type MathEnv = Vec<MathExpr>;
@@ -171,7 +169,7 @@ pub fn einsum(
@@ -171,7 +169,7 @@ pub fn einsum(
// Add the initializer.
// Add the initializer.
let reduce_expr_id = ctx.intern_math_expr(reduce_expr);
let reduce_expr_id = ctx.intern_math_expr(reduce_expr);
let init_expr_id = ctx.compute_math_expr(init);
let init_expr_id = ctx.compute_math_expr(init);
let add_expr = MathExpr::Add(init_expr_id, reduce_expr_id);
let add_expr = MathExpr::Binary(BinaryOperator::Add, init_expr_id, reduce_expr_id);
let total_id = ctx.intern_math_expr(add_expr);
let total_id = ctx.intern_math_expr(add_expr);
ctx.result_insert(reduce, total_id);
ctx.result_insert(reduce, total_id);
}
}
@@ -206,17 +204,25 @@ impl<'a> EinsumContext<'a> {
@@ -206,17 +204,25 @@ impl<'a> EinsumContext<'a> {
Node::ThreadID { control, dimension } if control == self.fork => {
Node::ThreadID { control, dimension } if control == self.fork => {
MathExpr::ThreadID(ForkDimension(dimension, self.factors[dimension]))
MathExpr::ThreadID(ForkDimension(dimension, self.factors[dimension]))
}
}
Node::Binary { op, left, right } if representable(op) => {
Node::Unary { op, input } => {
 
let input = self.compute_math_expr(input);
 
MathExpr::Unary(op, input)
 
}
 
Node::Binary { op, left, right } => {
let left = self.compute_math_expr(left);
let left = self.compute_math_expr(left);
let right = self.compute_math_expr(right);
let right = self.compute_math_expr(right);
match op {
MathExpr::Binary(op, left, right)
BinaryOperator::Add => MathExpr::Add(left, right),
}
BinaryOperator::Sub => MathExpr::Sub(left, right),
Node::Ternary {
BinaryOperator::Mul => MathExpr::Mul(left, right),
op,
BinaryOperator::Div => MathExpr::Div(left, right),
first,
BinaryOperator::Rem => MathExpr::Rem(left, right),
second,
_ => unreachable!(),
third,
}
} => {
 
let first = self.compute_math_expr(first);
 
let second = self.compute_math_expr(second);
 
let third = self.compute_math_expr(third);
 
MathExpr::Ternary(op, first, second, third)
}
}
Node::Read {
Node::Read {
collect,
collect,
@@ -301,103 +307,24 @@ impl<'a> EinsumContext<'a> {
@@ -301,103 +307,24 @@ impl<'a> EinsumContext<'a> {
let array = self.substitute_new_dims(array);
let array = self.substitute_new_dims(array);
self.intern_math_expr(MathExpr::Read(array, indices))
self.intern_math_expr(MathExpr::Read(array, indices))
}
}
MathExpr::Add(left, right) => {
MathExpr::Unary(op, input) => {
let left = self.substitute_new_dims(left);
let input = self.substitute_new_dims(input);
let right = self.substitute_new_dims(right);
self.intern_math_expr(MathExpr::Unary(op, input))
self.intern_math_expr(MathExpr::Add(left, right))
}
}
MathExpr::Sub(left, right) => {
MathExpr::Binary(op, left, right) => {
let left = self.substitute_new_dims(left);
let left = self.substitute_new_dims(left);
let right = self.substitute_new_dims(right);
let right = self.substitute_new_dims(right);
self.intern_math_expr(MathExpr::Sub(left, right))
self.intern_math_expr(MathExpr::Binary(op, left, right))
}
}
MathExpr::Mul(left, right) => {
MathExpr::Ternary(op, first, second, third) => {
let left = self.substitute_new_dims(left);
let first = self.substitute_new_dims(first);
let right = self.substitute_new_dims(right);
let second = self.substitute_new_dims(second);
self.intern_math_expr(MathExpr::Mul(left, right))
let third = self.substitute_new_dims(third);
}
self.intern_math_expr(MathExpr::Ternary(op, first, second, third))
MathExpr::Div(left, right) => {
let left = self.substitute_new_dims(left);
let right = self.substitute_new_dims(right);
self.intern_math_expr(MathExpr::Div(left, right))
}
MathExpr::Rem(left, right) => {
let left = self.substitute_new_dims(left);
let right = self.substitute_new_dims(right);
self.intern_math_expr(MathExpr::Rem(left, right))
}
}
_ => id,
_ => id,
}
}
}
}
fn debug_print_expr(&self, id: MathID) {
match self.env[id.idx()] {
MathExpr::Zero(_) => print!("0"),
MathExpr::One(_) => print!("1"),
MathExpr::OpaqueNode(id) => print!("{:?}", id),
MathExpr::ThreadID(dim) => print!("#{}", dim.0),
MathExpr::SumReduction(id, ref dims) => {
print!("Sum (");
for dim in dims {
print!("#{}/{:?},", dim.0, dim.1);
}
print!(") ");
self.debug_print_expr(id);
}
MathExpr::Comprehension(id, ref dims) => {
print!("[");
for dim in dims {
print!("#{}/{:?},", dim.0, dim.1);
}
print!("] ");
self.debug_print_expr(id);
}
MathExpr::Read(id, ref pos) => {
print!("read(");
self.debug_print_expr(id);
for pos in pos {
print!(", ");
self.debug_print_expr(*pos);
}
print!(")");
}
MathExpr::Add(left, right) => {
print!("+(");
self.debug_print_expr(left);
print!(", ");
self.debug_print_expr(right);
print!(")");
}
MathExpr::Sub(left, right) => {
print!("-(");
self.debug_print_expr(left);
print!(", ");
self.debug_print_expr(right);
print!(")");
}
MathExpr::Mul(left, right) => {
print!("*(");
self.debug_print_expr(left);
print!(", ");
self.debug_print_expr(right);
print!(")");
}
MathExpr::Div(left, right) => {
print!("/(");
self.debug_print_expr(left);
print!(", ");
self.debug_print_expr(right);
print!(")");
}
MathExpr::Rem(left, right) => {
print!("%(");
self.debug_print_expr(left);
print!(", ");
self.debug_print_expr(right);
print!(")");
}
}
}
}
}
pub fn opaque_nodes_in_expr(env: &MathEnv, id: MathID) -> HashSet<NodeID> {
pub fn opaque_nodes_in_expr(env: &MathEnv, id: MathID) -> HashSet<NodeID> {
@@ -416,26 +343,74 @@ pub fn opaque_nodes_in_expr(env: &MathEnv, id: MathID) -> HashSet<NodeID> {
@@ -416,26 +343,74 @@ pub fn opaque_nodes_in_expr(env: &MathEnv, id: MathID) -> HashSet<NodeID> {
stack.push(id);
stack.push(id);
stack.extend(ids);
stack.extend(ids);
}
}
MathExpr::Add(left, right)
MathExpr::Unary(_, input) => {
| MathExpr::Sub(left, right)
stack.push(input);
| MathExpr::Mul(left, right)
}
| MathExpr::Div(left, right)
MathExpr::Binary(_, left, right) => {
| MathExpr::Rem(left, right) => {
stack.push(left);
stack.push(left);
stack.push(right);
stack.push(right);
}
}
 
MathExpr::Ternary(_, first, second, third) => {
 
stack.push(first);
 
stack.push(second);
 
stack.push(third);
 
}
}
}
}
}
set
set
}
}
fn representable(op: BinaryOperator) -> bool {
pub fn debug_print_math_expr(id: MathID, env: &MathEnv) {
match op {
match env[id.idx()] {
BinaryOperator::Add
MathExpr::Zero(_) => print!("0"),
| BinaryOperator::Sub
MathExpr::One(_) => print!("1"),
| BinaryOperator::Mul
MathExpr::OpaqueNode(id) => print!("{:?}", id),
| BinaryOperator::Div
MathExpr::ThreadID(dim) => print!("#{}", dim.0),
| BinaryOperator::Rem => true,
MathExpr::SumReduction(id, ref dims) => {
_ => false,
print!("Sum (");
 
for dim in dims {
 
print!("#{}/{:?},", dim.0, dim.1);
 
}
 
print!(") ");
 
debug_print_math_expr(id, env);
 
}
 
MathExpr::Comprehension(id, ref dims) => {
 
print!("[");
 
for dim in dims {
 
print!("#{}/{:?},", dim.0, dim.1);
 
}
 
print!("] ");
 
debug_print_math_expr(id, env);
 
}
 
MathExpr::Read(id, ref pos) => {
 
print!("read(");
 
debug_print_math_expr(id, env);
 
for pos in pos {
 
print!(", ");
 
debug_print_math_expr(*pos, env);
 
}
 
print!(")");
 
}
 
MathExpr::Unary(op, input) => {
 
print!("{}(", op.lower_case_name());
 
debug_print_math_expr(input, env);
 
print!(")");
 
}
 
MathExpr::Binary(op, left, right) => {
 
print!("{}(", op.lower_case_name());
 
debug_print_math_expr(left, env);
 
print!(", ");
 
debug_print_math_expr(right, env);
 
print!(")");
 
}
 
MathExpr::Ternary(op, first, second, third) => {
 
print!("{}(", op.lower_case_name());
 
debug_print_math_expr(first, env);
 
print!(", ");
 
debug_print_math_expr(second, env);
 
print!(", ");
 
debug_print_math_expr(third, env);
 
print!(")");
 
}
}
}
}
}
Loading