use std::collections::{HashMap, HashSet}; use std::fmt::{Error, Write}; use hercules_ir::*; //use egg::*; use crate::*; pub fn rewrite_math_expressions( editor: &mut FunctionEditor, device: Device, typing: &Vec<TypeID>, fork_join_map: &HashMap<NodeID, NodeID>, nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, reduce_einsums: &(MathEnv, HashMap<NodeID, MathID>), ) { panic!("PANIC: The rewrite math expressions pass is currently disabled, as including egg increases compile times and we're not using it currently."); } /* define_language! { enum MathLanguage { "zero" = Zero, "one" = One, ForkDim(i64), "tid" = ThreadID(Id), "sum" = SumReduction(Box<[Id]>), "array" = Comprehension(Box<[Id]>), "read" = Read(Box<[Id]>), "+" = Add([Id; 2]), "*" = Mul([Id; 2]), "library_gemm" = LibraryGemm([Id; 2]), Opaque(Symbol), } } fn make_rules() -> Vec<Rewrite<MathLanguage, ()>> { vec![ rewrite!("add-zero"; "(+ zero ?a)" => "?a"), rewrite!("mul-zero"; "(* zero ?a)" => "zero"), rewrite!("mul-one"; "(* one ?a)" => "?a"), rewrite!("library-gemm"; "(array ?i ?k (sum ?j (* (read ?i ?j ?A) (read ?j ?k ?B))))" => "(library_gemm ?A ?B)"), ] } pub fn rewrite_math_expressions( editor: &mut FunctionEditor, device: Device, typing: &Vec<TypeID>, fork_join_map: &HashMap<NodeID, NodeID>, nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, reduce_einsums: &(MathEnv, HashMap<NodeID, MathID>), ) { let join_fork_map: HashMap<_, _> = fork_join_map .into_iter() .map(|(fork, join)| (*join, *fork)) .collect(); // Step 1: figure out how many fork-joins each reduce is in. We want to // rewrite outer reductions before inner ones. let mut depth: HashMap<NodeID, u32> = HashMap::new(); for (_, inside) in nodes_in_fork_joins { for id in inside { if editor.func().nodes[id.idx()].is_reduce() { *depth.entry(*id).or_default() += 1; } } } let mut reduces: Vec<(u32, NodeID)> = depth.into_iter().map(|(id, depth)| (depth, id)).collect(); reduces.sort(); for (_, reduce) in reduces { let (join, init, _) = editor.func().nodes[reduce.idx()].try_reduce().unwrap(); let fork = join_fork_map[&join]; // Step 2: convert the reduce to an expression in egg and rewrite it. let mut s = String::new(); egg_print_math_expr(reduce_einsums.1[&reduce], &reduce_einsums.0, &mut s).unwrap(); let expr: RecExpr<MathLanguage> = s.parse().unwrap(); let egraph = Runner::default().with_expr(&expr).run(&make_rules()).egraph; // Step 3: match the smallest expression against patterns for known // library functions. let gemm_pattern: Pattern<MathLanguage> = "(library_gemm ?A ?B)".parse().unwrap(); let mut matches = gemm_pattern.search(&egraph); if matches.len() > 0 && let m = matches.remove(0) && m.substs.len() > 0 { let left_id = get_node_ids_from_subst("?A", &m.substs[0], &egraph); let right_id = get_node_ids_from_subst("?B", &m.substs[0], &egraph); let call = Node::LibraryCall { library_function: LibraryFunction::GEMM, args: vec![init, left_id, right_id].into_boxed_slice(), ty: typing[reduce.idx()], device, }; let success = editor.edit(|mut edit| { let call = edit.add_node(call); edit.replace_all_uses_where(reduce, call, |id| { !nodes_in_fork_joins[&fork].contains(id) }) }); if success { return; } } } } fn get_node_ids_from_subst(var: &str, subst: &Subst, egraph: &EGraph<MathLanguage, ()>) -> NodeID { let id = *subst.get(var.parse::<Var>().unwrap()).unwrap(); let expr = egraph.id_to_expr(id); let MathLanguage::Opaque(sym) = expr.last().unwrap() else { todo!(); }; let sym = sym.as_str(); assert!(sym.chars().nth(0).unwrap() == 'n'); NodeID::new(sym[1..].parse().unwrap()) } fn egg_print_math_expr<W: Write>(id: MathID, env: &MathEnv, w: &mut W) -> Result<(), Error> { match env[id.idx()] { MathExpr::Zero(_) => write!(w, "zero"), MathExpr::One(_) => write!(w, "one"), MathExpr::OpaqueNode(id) => write!(w, "n{}", id.idx()), MathExpr::ThreadID(dim) => write!(w, "{}", dim.0), MathExpr::SumReduction(id, ref dims) => { write!(w, "(sum")?; for dim in dims { write!(w, " {}", dim.0)?; } write!(w, " ")?; egg_print_math_expr(id, env, w)?; write!(w, ")") } MathExpr::Comprehension(id, ref dims) => { write!(w, "(array")?; for dim in dims { write!(w, " {}", dim.0)?; } write!(w, " ")?; egg_print_math_expr(id, env, w)?; write!(w, ")") } MathExpr::Read(id, ref pos) => { write!(w, "(read")?; for pos in pos { write!(w, " ")?; egg_print_math_expr(*pos, env, w)?; } write!(w, " ")?; egg_print_math_expr(id, env, w)?; write!(w, ")") } MathExpr::Binary(op, left, right) => { write!(w, "(")?; match op { BinaryOperator::Add => write!(w, "+ "), BinaryOperator::Mul => write!(w, "* "), _ => Err(Error::default()), }?; egg_print_math_expr(left, env, w)?; write!(w, " ")?; egg_print_math_expr(right, env, w)?; write!(w, ")") } _ => Err(Error::default()), } } */