Skip to content
Snippets Groups Projects
Commit 0b1a9eb3 authored by Russel Arbore's avatar Russel Arbore
Browse files

Add intrinsics to math exprs, array slf gets rid of first ctrl pts

look in gamut
parent 0958995e
No related branches found
No related tags found
1 merge request!178Some Cava optimization
Pipeline #201655 passed
......@@ -34,6 +34,7 @@ pub enum MathExpr {
Unary(UnaryOperator, MathID),
Binary(BinaryOperator, MathID, MathID),
Ternary(TernaryOperator, MathID, MathID, MathID),
IntrinsicFunc(Intrinsic, Box<[MathID]>),
}
pub type MathEnv = Vec<MathExpr>;
......@@ -224,6 +225,16 @@ impl<'a> EinsumContext<'a> {
let third = self.compute_math_expr(third);
MathExpr::Ternary(op, first, second, third)
}
Node::IntrinsicCall {
intrinsic,
ref args,
} => {
let args = args
.into_iter()
.map(|id| self.compute_math_expr(*id))
.collect();
MathExpr::IntrinsicFunc(intrinsic, args)
}
Node::Read {
collect,
ref indices,
......@@ -322,6 +333,14 @@ impl<'a> EinsumContext<'a> {
let third = self.substitute_new_dims(third);
self.intern_math_expr(MathExpr::Ternary(op, first, second, third))
}
MathExpr::IntrinsicFunc(intrinsic, ref args) => {
let args = args
.clone()
.iter()
.map(|id| self.substitute_new_dims(*id))
.collect();
self.intern_math_expr(MathExpr::IntrinsicFunc(intrinsic, args))
}
_ => id,
}
}
......@@ -355,6 +374,9 @@ pub fn opaque_nodes_in_expr(env: &MathEnv, id: MathID) -> HashSet<NodeID> {
stack.push(second);
stack.push(third);
}
MathExpr::IntrinsicFunc(_, ref args) => {
stack.extend(args);
}
}
}
set
......@@ -412,5 +434,14 @@ pub fn debug_print_math_expr(id: MathID, env: &MathEnv) {
debug_print_math_expr(third, env);
print!(")");
}
MathExpr::IntrinsicFunc(intrinsic, ref args) => {
print!("{}(", intrinsic.lower_case_name());
debug_print_math_expr(id, env);
for arg in args {
print!(", ");
debug_print_math_expr(*arg, env);
}
print!(")");
}
}
}
......@@ -474,6 +474,13 @@ pub fn materialize_simple_einsum_expr(
third,
})
}
MathExpr::IntrinsicFunc(intrinsic, ref args) => {
let args = args
.into_iter()
.map(|id| materialize_simple_einsum_expr(edit, *id, env, dim_substs))
.collect();
edit.add_node(Node::IntrinsicCall { intrinsic, args })
}
_ => panic!(),
}
}
......
......@@ -102,6 +102,9 @@ simpl!(fuse4);
fork-unroll(fuse4@channel_loop);
simpl!(fuse4);
//fork-fusion(fuse4@channel_loop);
simpl!(fuse4);
array-slf(fuse4);
simpl!(fuse4);
fork-split(fuse4);
unforkify(fuse4);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment