Skip to content
Snippets Groups Projects

Some Cava optimization

Merged rarbore2 requested to merge cava_opt into main
Files
4
+ 31
0
@@ -34,6 +34,7 @@ pub enum MathExpr {
Unary(UnaryOperator, MathID),
Binary(BinaryOperator, MathID, MathID),
Ternary(TernaryOperator, MathID, MathID, MathID),
IntrinsicFunc(Intrinsic, Box<[MathID]>),
}
pub type MathEnv = Vec<MathExpr>;
@@ -224,6 +225,16 @@ impl<'a> EinsumContext<'a> {
let third = self.compute_math_expr(third);
MathExpr::Ternary(op, first, second, third)
}
Node::IntrinsicCall {
intrinsic,
ref args,
} => {
let args = args
.into_iter()
.map(|id| self.compute_math_expr(*id))
.collect();
MathExpr::IntrinsicFunc(intrinsic, args)
}
Node::Read {
collect,
ref indices,
@@ -322,6 +333,14 @@ impl<'a> EinsumContext<'a> {
let third = self.substitute_new_dims(third);
self.intern_math_expr(MathExpr::Ternary(op, first, second, third))
}
MathExpr::IntrinsicFunc(intrinsic, ref args) => {
let args = args
.clone()
.iter()
.map(|id| self.substitute_new_dims(*id))
.collect();
self.intern_math_expr(MathExpr::IntrinsicFunc(intrinsic, args))
}
_ => id,
}
}
@@ -355,6 +374,9 @@ pub fn opaque_nodes_in_expr(env: &MathEnv, id: MathID) -> HashSet<NodeID> {
stack.push(second);
stack.push(third);
}
MathExpr::IntrinsicFunc(_, ref args) => {
stack.extend(args);
}
}
}
set
@@ -412,5 +434,14 @@ pub fn debug_print_math_expr(id: MathID, env: &MathEnv) {
debug_print_math_expr(third, env);
print!(")");
}
MathExpr::IntrinsicFunc(intrinsic, ref args) => {
print!("{}(", intrinsic.lower_case_name());
debug_print_math_expr(id, env);
for arg in args {
print!(", ");
debug_print_math_expr(*arg, env);
}
print!(")");
}
}
}
Loading