diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs index 5dd0fe5d6f7a6fe9743554ed357044b13845dd23..b222e1bc709cc7a95b790b0780faa32224063f77 100644 --- a/hercules_ir/src/einsum.rs +++ b/hercules_ir/src/einsum.rs @@ -34,6 +34,7 @@ pub enum MathExpr { Unary(UnaryOperator, MathID), Binary(BinaryOperator, MathID, MathID), Ternary(TernaryOperator, MathID, MathID, MathID), + IntrinsicFunc(Intrinsic, Box<[MathID]>), } pub type MathEnv = Vec<MathExpr>; @@ -224,6 +225,16 @@ impl<'a> EinsumContext<'a> { let third = self.compute_math_expr(third); MathExpr::Ternary(op, first, second, third) } + Node::IntrinsicCall { + intrinsic, + ref args, + } => { + let args = args + .into_iter() + .map(|id| self.compute_math_expr(*id)) + .collect(); + MathExpr::IntrinsicFunc(intrinsic, args) + } Node::Read { collect, ref indices, @@ -322,6 +333,14 @@ impl<'a> EinsumContext<'a> { let third = self.substitute_new_dims(third); self.intern_math_expr(MathExpr::Ternary(op, first, second, third)) } + MathExpr::IntrinsicFunc(intrinsic, ref args) => { + let args = args + .clone() + .iter() + .map(|id| self.substitute_new_dims(*id)) + .collect(); + self.intern_math_expr(MathExpr::IntrinsicFunc(intrinsic, args)) + } _ => id, } } @@ -355,6 +374,9 @@ pub fn opaque_nodes_in_expr(env: &MathEnv, id: MathID) -> HashSet<NodeID> { stack.push(second); stack.push(third); } + MathExpr::IntrinsicFunc(_, ref args) => { + stack.extend(args); + } } } set @@ -412,5 +434,14 @@ pub fn debug_print_math_expr(id: MathID, env: &MathEnv) { debug_print_math_expr(third, env); print!(")"); } + MathExpr::IntrinsicFunc(intrinsic, ref args) => { + print!("{}(", intrinsic.lower_case_name()); + debug_print_math_expr(id, env); + for arg in args { + print!(", "); + debug_print_math_expr(*arg, env); + } + print!(")"); + } } } diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs index c5e0d934fb2b4673890e3e7f1d209db2e028d513..3f12ad7c5a9e6b36e0a9e43fa35c5117e44b09eb 100644 --- a/hercules_opt/src/utils.rs +++ b/hercules_opt/src/utils.rs @@ -474,6 +474,13 @@ pub fn materialize_simple_einsum_expr( third, }) } + MathExpr::IntrinsicFunc(intrinsic, ref args) => { + let args = args + .into_iter() + .map(|id| materialize_simple_einsum_expr(edit, *id, env, dim_substs)) + .collect(); + edit.add_node(Node::IntrinsicCall { intrinsic, args }) + } _ => panic!(), } } diff --git a/juno_samples/cava/src/cpu.sch b/juno_samples/cava/src/cpu.sch index 3021f6a0fc367e2656293460c3392efeccad909a..1ae1dc132942ff868a25080d297a25f06c8395a2 100644 --- a/juno_samples/cava/src/cpu.sch +++ b/juno_samples/cava/src/cpu.sch @@ -102,6 +102,9 @@ simpl!(fuse4); fork-unroll(fuse4@channel_loop); simpl!(fuse4); //fork-fusion(fuse4@channel_loop); +simpl!(fuse4); +array-slf(fuse4); +simpl!(fuse4); fork-split(fuse4); unforkify(fuse4);