Skip to content
Snippets Groups Projects
rewrite_math_expressions.rs 5.43 KiB
use std::collections::{HashMap, HashSet};
use std::fmt::{Error, Write};

use hercules_ir::*;

use egg::*;

use crate::*;

// `MathLanguage` is the egg term language that the rewrite rules in
// `make_rules` operate on. Terms are s-expressions such as `(+ zero ?a)`; the
// string to the left of `=` on each variant below is that variant's operator
// token in such patterns. Only `//` comments are used inside the macro body
// because they are discarded before `define_language!` matches its input.
define_language! {
    enum MathLanguage {
        // Nullary constants. `make_rules` uses `zero` as the additive identity
        // and multiplicative annihilator, and `one` as the multiplicative
        // identity.
        "zero" = Zero,
        "one" = One,

        // Leaf carrying an i64 payload; presumably the extent of a fork
        // dimension — NOTE(review): confirm against the code that builds terms.
        ForkDim(i64),
        // Exactly one child; presumably the thread/loop index of a fork
        // dimension — TODO confirm.
        "tid" = ThreadID(Id),
        // Variadic. In the `library-gemm` rule it appears as `(sum ?j body)`:
        // the first operand is the index being summed over (it is shared by
        // both reads), and the last operand is the summed expression.
        "sum" = SumReduction(Box<[Id]>),
        // Variadic. In the `library-gemm` rule it appears as
        // `(array ?i ?k body)`: the leading operands are the output indices and
        // the last operand is the per-element expression.
        "array" = Comprehension(Box<[Id]>),

        // Variadic array access. The `library-gemm` rule writes it as
        // `(read ?i ?j ?A)`: index operands first, then the array being read
        // (the last operand is what becomes a `library_gemm` input).
        "read" = Read(Box<[Id]>),

        // Binary arithmetic. `make_rules` defines no commutativity rules, so
        // its identities only fire when the constant is the LEFT operand.
        "+" = Add([Id; 2]),
        "*" = Mul([Id; 2]),

        // Produced by the `library-gemm` rewrite; the operands are the
        // matched `?A` and `?B` arrays, in that order.
        "library_gemm" = LibraryGemm([Id; 2]),

        // Leaf holding an arbitrary symbol. Presumably the fallback for leaves
        // that match none of the operators above (egg tries variants in
        // declaration order) — NOTE(review): confirm.
        Opaque(Symbol),
    }
}

/// Builds the rewrite rules applied to `MathLanguage` e-graphs.
///
/// Rules are returned in a fixed order: the three algebraic identities
/// (`add-zero`, `mul-zero`, `mul-one`) followed by `library-gemm`, which
/// recognizes a matrix-multiply comprehension and replaces it with a
/// `library_gemm` node.
fn make_rules() -> Vec<Rewrite<MathLanguage, ()>> {
    let mut rules: Vec<Rewrite<MathLanguage, ()>> = Vec::with_capacity(4);

    // Algebraic identities. Only the forms with the constant on the left are
    // matched; no commutativity rules are included.
    rules.push(rewrite!("add-zero"; "(+ zero ?a)" => "?a"));
    rules.push(rewrite!("mul-zero"; "(* zero ?a)" => "zero"));
    rules.push(rewrite!("mul-one"; "(* one ?a)" => "?a"));

    // Library idiom: array over (i, k) of sum over j of A[i, j] * B[j, k]
    // becomes a single call on the two input arrays.
    rules.push(rewrite!(
        "library-gemm";
        "(array ?i ?k (sum ?j (* (read ?i ?j ?A) (read ?j ?k ?B))))"
            => "(library_gemm ?A ?B)"
    ));

    rules
}

pub fn rewrite_math_expressions(
    editor: &mut FunctionEditor,
    device: Device,
    typing: &Vec<TypeID>,
    fork_join_map: &HashMap<NodeID, NodeID>,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
    reduce_einsums: &(MathEnv, HashMap<NodeID, MathID>),
) {
    let join_fork_map: HashMap<_, _> = fork_join_map
        .into_iter()
        .map(|(fork, join)| (*join, *fork))
        .collect();

    // Step 1: figure out how many fork-joins each reduce is in. We want to
    // rewrite outer reductions before inner ones.
    let mut depth: HashMap<NodeID, u32> = HashMap::new();
    for (_, inside) in nodes_in_fork_joins {
        for id in inside {
            if editor.func().nodes[id.idx()].is_reduce() {
                *depth.entry(*id).or_default() += 1;
            }
        }
    }
    let mut reduces: Vec<(u32, NodeID)> =
        depth.into_iter().map(|(id, depth)| (depth, id)).collect();
    reduces.sort();

    for (_, reduce) in reduces {
        let (join, init, _) = editor.func().nodes[reduce.idx()].try_reduce().unwrap();
        let fork = join_fork_map[&join];