// rewrite_math_expressions.rs (5.43 KiB)
use std::collections::{HashMap, HashSet};
use std::fmt::{Error, Write};
use hercules_ir::*;
use egg::*;
use crate::*;
// Term language handed to egg for algebraic rewriting of math expressions.
//
// Variants written as `"name" = Variant(...)` are operators that are parsed and
// printed under that string name; `ForkDim(i64)` and `Opaque(Symbol)` are leaf
// variants that carry data instead of children.
//
// NOTE(review): egg presumably tries the data-carrying leaf variants in
// declaration order when parsing, so integer tokens become `ForkDim` before the
// catch-all `Opaque(Symbol)` can claim them. Keep `ForkDim` ahead of `Opaque`
// if the variants are ever reordered. Confirm against the egg version in use.
//
// Only `//` comments appear inside the macro body: the variant grammar of
// `define_language!` does not accept attributes such as `///` doc comments.
define_language! {
    enum MathLanguage {
        // Nullary constants. `make_rules` uses them as the additive identity,
        // multiplicative zero, and multiplicative identity.
        "zero" = Zero,
        "one" = One,
        // Integer leaf; presumably the index of a fork dimension (TODO confirm).
        ForkDim(i64),
        // One child; presumably the fork dimension whose thread id is read
        // (TODO confirm).
        "tid" = ThreadID(Id),
        // Variadic. The "library-gemm" pattern uses it as `(sum ?j body)`:
        // the summed-over index first, the summand last.
        "sum" = SumReduction(Box<[Id]>),
        // Variadic. The "library-gemm" pattern uses it as `(array ?i ?k body)`:
        // index children followed by the element expression.
        "array" = Comprehension(Box<[Id]>),
        // Variadic. The "library-gemm" pattern uses it as `(read ?i ?j ?A)`:
        // index children followed by the value being read from.
        "read" = Read(Box<[Id]>),
        // Binary arithmetic.
        "+" = Add([Id; 2]),
        "*" = Mul([Id; 2]),
        // Produced by the "library-gemm" rewrite as `(library_gemm ?A ?B)`.
        "library_gemm" = LibraryGemm([Id; 2]),
        // Symbol leaf; presumably a catch-all for atoms not modeled by the
        // variants above (TODO confirm how the producer uses it).
        Opaque(Symbol),
    }
}
// Builds the rewrite rules that egg applies to `MathLanguage` e-graphs.
//
// The identity rules and the GEMM pattern are each written for a single
// operand order (constant on the left; `A`-read before `B`-read). Without
// commutativity they would miss `(+ a zero)`, `(* a one)`, `(* a zero)`, and a
// matrix product written as `B[j][k] * A[i][j]`. The two commutativity rules
// put both operand orders in the same e-class, so every rule below matches
// regardless of operand order. These swaps are sound for ordinary integer and
// IEEE floating-point arithmetic. No associativity rule is included (floats are
// not associative), which also keeps e-graph growth bounded.
//
// NOTE(review): "mul-zero" rewrites `0 * a` to `0`, which does not hold for
// IEEE floats when `a` is NaN or infinite. Confirm that is acceptable for the
// programs being optimized.
fn make_rules() -> Vec<Rewrite<MathLanguage, ()>> {
    vec![
        // Additive identity, multiplicative zero, multiplicative identity.
        rewrite!("add-zero"; "(+ zero ?a)" => "?a"),
        rewrite!("mul-zero"; "(* zero ?a)" => "zero"),
        rewrite!("mul-one"; "(* one ?a)" => "?a"),
        // array[i][k] = sum over j of A[i][j] * B[j][k]  ==>  library GEMM on A, B.
        rewrite!(
            "library-gemm";
            "(array ?i ?k (sum ?j (* (read ?i ?j ?A) (read ?j ?k ?B))))"
                => "(library_gemm ?A ?B)"
        ),
        // a + b == b + a and a * b == b * a, so the one-sided patterns above
        // also fire when the operands appear in the other order.
        rewrite!("comm-add"; "(+ ?a ?b)" => "(+ ?b ?a)"),
        rewrite!("comm-mul"; "(* ?a ?b)" => "(* ?b ?a)"),
    ]
}
pub fn rewrite_math_expressions(
editor: &mut FunctionEditor,
device: Device,
typing: &Vec<TypeID>,
fork_join_map: &HashMap<NodeID, NodeID>,
nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
reduce_einsums: &(MathEnv, HashMap<NodeID, MathID>),
) {
let join_fork_map: HashMap<_, _> = fork_join_map
.into_iter()
.map(|(fork, join)| (*join, *fork))
.collect();
// Step 1: figure out how many fork-joins each reduce is in. We want to
// rewrite outer reductions before inner ones.
let mut depth: HashMap<NodeID, u32> = HashMap::new();
for (_, inside) in nodes_in_fork_joins {
for id in inside {
if editor.func().nodes[id.idx()].is_reduce() {
*depth.entry(*id).or_default() += 1;
}
}
}
let mut reduces: Vec<(u32, NodeID)> =
depth.into_iter().map(|(id, depth)| (depth, id)).collect();
reduces.sort();
for (_, reduce) in reduces {
let (join, init, _) = editor.func().nodes[reduce.idx()].try_reduce().unwrap();
let fork = join_fork_map[&join];