From 1cdace83ebf1b518d58aec9d04d1c1dcb67586b3 Mon Sep 17 00:00:00 2001 From: rarbore2 <rarbore2@illinois.edu> Date: Thu, 30 Jan 2025 17:23:39 -0600 Subject: [PATCH] Einsum analysis --- hercules_ir/src/einsum.rs | 413 ++++++++++++++++++++++++++++ hercules_ir/src/lib.rs | 3 + hercules_samples/matmul/src/cpu.sch | 2 + juno_samples/matmul/build.rs | 2 - juno_samples/matmul/src/sched.sch | 76 ----- juno_scheduler/src/pm.rs | 53 +++- 6 files changed, 467 insertions(+), 82 deletions(-) create mode 100644 hercules_ir/src/einsum.rs delete mode 100644 juno_samples/matmul/src/sched.sch diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs new file mode 100644 index 00000000..c2521c22 --- /dev/null +++ b/hercules_ir/src/einsum.rs @@ -0,0 +1,413 @@ +use std::collections::{HashMap, HashSet}; +use std::iter::zip; + +use crate::*; + +/* + * Math expressions are stored as a simple tree. + */ +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct ForkDimension(pub usize, pub DynamicConstantID); + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum MathExpr { + // Zero constant of a specific type. + Zero(TypeID), + // One constant of a specific type. + One(TypeID), + // Opaque value corresponding to a particular node in the original IR. + OpaqueNode(NodeID), + + // Thread ID from the fork corresponding to the reduce being expressed. + // Thread IDs from outside forks are considered OpaqueNodes. + ThreadID(ForkDimension), + + // Sum reduction over a dimension of a fork. + SumReduction(MathID, Box<[ForkDimension]>), + // Comprehend a scalar expression into an array over fork dimensions. + Comprehension(MathID, Box<[ForkDimension]>), + + // Read from an array. + Read(MathID, Box<[MathID]>), + + // Math ops. + Add(MathID, MathID), + Sub(MathID, MathID), + Mul(MathID, MathID), + Div(MathID, MathID), + Rem(MathID, MathID), +} + +pub type MathEnv = Vec<MathExpr>; + +define_id_type!(MathID); + +/* + * Top level function to run "einsum" analysis on fork-joins. This is a terrible + * name for this analysis, since it's actually more general than identifying + * einsums, but einsum syntax has a passing resemblance to the idea of this + * analysis and it's what we keep calling it, so we're doomed to this bad name. + * The idea of this analysis is to convert some fork-joins into pure math + * expressions that we can rewrite into intrinsic functions for higher level + * operators like matmul. Specifically, this function returns a map from each + * reduce node to a math expression. + */ +pub fn einsum( + function: &Function, + types: &Vec<Type>, + constants: &Vec<Constant>, + def_use: &ImmutableDefUseMap, + typing: &Vec<TypeID>, + fork_join_map: &HashMap<NodeID, NodeID>, + fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, + data_nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, +) -> (MathEnv, HashMap<NodeID, MathID>) { + let mut env = vec![]; + let mut rev_env = HashMap::new(); + let mut result = HashMap::new(); + + // Iterate fork-joins bottom-up, since we need to compute the math + // expressions of inner reduces before getting to outer reduces. Since fork- + // joins are strictly nested, we can literally iterate entries of + // `fork_join_nest` in decreasing order of nesting size to accomplish this. + let mut nests: Vec<_> = fork_join_nest + .into_iter() + .filter(|(id, _)| function.nodes[id.idx()].is_fork()) + .collect(); + nests.sort_by(|a, b| b.1.len().cmp(&a.1.len())); + for fork in nests.into_iter().map(|(id, _)| *id) { + let Node::Fork { + control: _, + ref factors, + } = function.nodes[fork.idx()] + else { + panic!() + }; + let join = fork_join_map[&fork]; + let thread_ids: Vec<_> = def_use + .get_users(fork) + .into_iter() + .filter_map(|id| { + function.nodes[id.idx()] + .try_thread_id() + .map(|(_, dim)| (*id, dim)) + }) + .collect(); + let reduces = def_use + .get_users(join) + .into_iter() + .filter_map(|id| function.nodes[id.idx()].try_reduce().map(|v| (*id, v))); + let mut ctx = EinsumContext { + function, + typing, + constants, + data_nodes_in_fork_joins, + fork, + factors, + thread_ids: &thread_ids, + so_far: &mut result, + env: &mut env, + rev_env: &mut rev_env, + }; + + // Compute a math expression for each reduce node in the fork-join with + // appropriate schedules. + for (reduce, (_, init, reduct)) in reduces { + // The reduce defines an array where each fork dimension corresponds + // to one array dimension. + if function.schedules[reduce.idx()].contains(&Schedule::ParallelReduce) + && let Node::Write { + collect, + data, + ref indices, + } = function.nodes[reduct.idx()] + && collect == reduce + && indices.len() == 1 + && let Some(indices) = indices[0].try_position() + && let Some(dimension_bounds) = indices + .into_iter() + .map(|id| { + function.nodes[id.idx()] + .try_thread_id() + .filter(|(tid_fork, _)| *tid_fork == fork) + .map(|(_, dim)| dim) + }) + .collect::<Option<Vec<usize>>>() + && let Type::Array(_, ref array_bounds) = types[typing[reduce.idx()].idx()] + && zip(array_bounds.into_iter(), dimension_bounds.iter()) + .all(|(array, fork)| *array == factors[*fork]) + { + let data_expr = ctx.compute_math_expr(data); + let reduce_expr = MathExpr::Comprehension( + data_expr, + dimension_bounds + .into_iter() + .map(|dim| ForkDimension(dim, factors[dim])) + .collect(), + ); + // We don't need to consider the initializer, since the writes + // cover the whole array. + let total_id = ctx.intern_math_expr(reduce_expr); + ctx.result_insert(reduce, total_id); + } + // The reduce defines a sum reduction over a set of fork dimensions. + else if function.schedules[reduce.idx()].contains(&Schedule::TightAssociative) + && let Node::Binary { + op: BinaryOperator::Add, + left, + right, + } = function.nodes[reduct.idx()] + && (left == reduce || right == reduce) + { + let data_expr = ctx.compute_math_expr(if left == reduce { right } else { left }); + let reduce_expr = MathExpr::SumReduction( + data_expr, + factors + .into_iter() + .enumerate() + .map(|(dim, factor)| ForkDimension(dim, *factor)) + .collect(), + ); + // Add the initializer. + let reduce_expr_id = ctx.intern_math_expr(reduce_expr); + let init_expr_id = ctx.compute_math_expr(init); + let add_expr = MathExpr::Add(init_expr_id, reduce_expr_id); + let total_id = ctx.intern_math_expr(add_expr); + ctx.result_insert(reduce, total_id); + } + } + } + + (env, result) +} + +struct EinsumContext<'a> { + function: &'a Function, + typing: &'a Vec<TypeID>, + constants: &'a Vec<Constant>, + data_nodes_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>, + fork: NodeID, + factors: &'a [DynamicConstantID], + thread_ids: &'a Vec<(NodeID, usize)>, + so_far: &'a mut HashMap<NodeID, MathID>, + env: &'a mut MathEnv, + rev_env: &'a mut HashMap<MathExpr, MathID>, +} + +impl<'a> EinsumContext<'a> { + fn compute_math_expr(&mut self, id: NodeID) -> MathID { + let math_expr = match self.function.nodes[id.idx()] { + Node::Constant { id: cons_id } if self.constants[cons_id.idx()].is_zero() => { + MathExpr::Zero(self.typing[id.idx()]) + } + Node::Constant { id: cons_id } if self.constants[cons_id.idx()].is_one() => { + MathExpr::One(self.typing[id.idx()]) + } + Node::ThreadID { control, dimension } if control == self.fork => { + MathExpr::ThreadID(ForkDimension(dimension, self.factors[dimension])) + } + Node::Binary { op, left, right } if representable(op) => { + let left = self.compute_math_expr(left); + let right = self.compute_math_expr(right); + match op { + BinaryOperator::Add => MathExpr::Add(left, right), + BinaryOperator::Sub => MathExpr::Sub(left, right), + BinaryOperator::Mul => MathExpr::Mul(left, right), + BinaryOperator::Div => MathExpr::Div(left, right), + BinaryOperator::Rem => MathExpr::Rem(left, right), + _ => unreachable!(), + } + } + Node::Read { + collect, + ref indices, + } if indices.len() == 1 + && let Some(indices) = indices[0].try_position() => + { + let collect = self.compute_math_expr(collect); + let indices = indices + .into_iter() + .map(|id| self.compute_math_expr(*id)) + .collect(); + MathExpr::Read(collect, indices) + } + Node::Reduce { + control: _, + init: _, + reduct: _, + } => { + let reduce = self.so_far[&id]; + // Substitute opaque uses of thread ID nodes in inner expression + // with thread ID math expression, and increment inner-fork + // dimensions (alpha renaming). + return self.substitute_new_dims(reduce); + } + _ => MathExpr::OpaqueNode(id), + }; + self.intern_math_expr(math_expr) + } + + fn intern_math_expr(&mut self, expr: MathExpr) -> MathID { + if let Some(id) = self.rev_env.get(&expr) { + *id + } else { + let id = MathID::new(self.env.len()); + self.env.push(expr.clone()); + self.rev_env.insert(expr, id); + id + } + } + + fn result_insert(&mut self, node: NodeID, math: MathID) { + self.so_far.insert(node, math); + } + + fn substitute_new_dims(&mut self, id: MathID) -> MathID { + match self.env[id.idx()] { + MathExpr::OpaqueNode(opaque) + if let Some((_, dim)) = self + .thread_ids + .into_iter() + .filter(|(node, _)| *node == opaque) + .next() => + { + self.intern_math_expr(MathExpr::ThreadID(ForkDimension(*dim, self.factors[*dim]))) + } + MathExpr::ThreadID(dim) => self.intern_math_expr(MathExpr::ThreadID(ForkDimension( + dim.0 + self.factors.len(), + dim.1, + ))), + MathExpr::SumReduction(id, ref dims) => { + let dims = dims + .into_iter() + .map(|dim| ForkDimension(dim.0 + self.factors.len(), dim.1)) + .collect(); + let id = self.substitute_new_dims(id); + self.intern_math_expr(MathExpr::SumReduction(id, dims)) + } + MathExpr::Comprehension(id, ref dims) => { + let dims = dims + .into_iter() + .map(|dim| ForkDimension(dim.0 + self.factors.len(), dim.1)) + .collect(); + let id = self.substitute_new_dims(id); + self.intern_math_expr(MathExpr::Comprehension(id, dims)) + } + MathExpr::Read(array, ref indices) => { + let indices = indices + .clone() + .iter() + .map(|id| self.substitute_new_dims(*id)) + .collect(); + let array = self.substitute_new_dims(array); + self.intern_math_expr(MathExpr::Read(array, indices)) + } + MathExpr::Add(left, right) => { + let left = self.substitute_new_dims(left); + let right = self.substitute_new_dims(right); + self.intern_math_expr(MathExpr::Add(left, right)) + } + MathExpr::Sub(left, right) => { + let left = self.substitute_new_dims(left); + let right = self.substitute_new_dims(right); + self.intern_math_expr(MathExpr::Sub(left, right)) + } + MathExpr::Mul(left, right) => { + let left = self.substitute_new_dims(left); + let right = self.substitute_new_dims(right); + self.intern_math_expr(MathExpr::Mul(left, right)) + } + MathExpr::Div(left, right) => { + let left = self.substitute_new_dims(left); + let right = self.substitute_new_dims(right); + self.intern_math_expr(MathExpr::Div(left, right)) + } + MathExpr::Rem(left, right) => { + let left = self.substitute_new_dims(left); + let right = self.substitute_new_dims(right); + self.intern_math_expr(MathExpr::Rem(left, right)) + } + _ => id, + } + } + + fn debug_print_expr(&self, id: MathID) { + match self.env[id.idx()] { + MathExpr::Zero(_) => print!("0"), + MathExpr::One(_) => print!("1"), + MathExpr::OpaqueNode(id) => print!("{:?}", id), + MathExpr::ThreadID(dim) => print!("#{}", dim.0), + MathExpr::SumReduction(id, ref dims) => { + print!("Sum ("); + for dim in dims { + print!("#{}/{:?},", dim.0, dim.1); + } + print!(") "); + self.debug_print_expr(id); + } + MathExpr::Comprehension(id, ref dims) => { + print!("["); + for dim in dims { + print!("#{}/{:?},", dim.0, dim.1); + } + print!("] "); + self.debug_print_expr(id); + } + MathExpr::Read(id, ref pos) => { + print!("read("); + self.debug_print_expr(id); + for pos in pos { + print!(", "); + self.debug_print_expr(*pos); + } + print!(")"); + } + MathExpr::Add(left, right) => { + print!("+("); + self.debug_print_expr(left); + print!(", "); + self.debug_print_expr(right); + print!(")"); + } + MathExpr::Sub(left, right) => { + print!("-("); + self.debug_print_expr(left); + print!(", "); + self.debug_print_expr(right); + print!(")"); + } + MathExpr::Mul(left, right) => { + print!("*("); + self.debug_print_expr(left); + print!(", "); + self.debug_print_expr(right); + print!(")"); + } + MathExpr::Div(left, right) => { + print!("/("); + self.debug_print_expr(left); + print!(", "); + self.debug_print_expr(right); + print!(")"); + } + MathExpr::Rem(left, right) => { + print!("%("); + self.debug_print_expr(left); + print!(", "); + self.debug_print_expr(right); + print!(")"); + } + } + } +} + +fn representable(op: BinaryOperator) -> bool { + match op { + BinaryOperator::Add + | BinaryOperator::Sub + | BinaryOperator::Mul + | BinaryOperator::Div + | BinaryOperator::Rem => true, + _ => false, + } +} diff --git a/hercules_ir/src/lib.rs b/hercules_ir/src/lib.rs index 85dc277f..185a2888 100644 --- a/hercules_ir/src/lib.rs +++ b/hercules_ir/src/lib.rs @@ -2,6 +2,7 @@ coroutines, coroutine_trait, let_chains, + if_let_guard, stmt_expr_attributes, iter_intersperse )] @@ -14,6 +15,7 @@ pub mod def_use; pub mod device; pub mod dom; pub mod dot; +pub mod einsum; pub mod fork_join_analysis; pub mod ir; pub mod loops; @@ -30,6 +32,7 @@ pub use crate::def_use::*; pub use crate::device::*; pub use crate::dom::*; pub use crate::dot::*; +pub use crate::einsum::*; pub use crate::fork_join_analysis::*; pub use crate::ir::*; pub use crate::loops::*; diff --git a/hercules_samples/matmul/src/cpu.sch b/hercules_samples/matmul/src/cpu.sch index f7891b9b..c00a3314 100644 --- a/hercules_samples/matmul/src/cpu.sch +++ b/hercules_samples/matmul/src/cpu.sch @@ -6,6 +6,8 @@ auto-outline(*); ip-sroa(*); sroa(*); +dce(*); +infer-schedules(*); fork-split(*); unforkify(*); dce(*); diff --git a/juno_samples/matmul/build.rs b/juno_samples/matmul/build.rs index 511bf483..926fbc33 100644 --- a/juno_samples/matmul/build.rs +++ b/juno_samples/matmul/build.rs @@ -4,8 +4,6 @@ fn main() { JunoCompiler::new() .file_in_src("matmul.jn") .unwrap() - //.schedule_in_src("sched.sch") - //.unwrap() .build() .unwrap(); } diff --git a/juno_samples/matmul/src/sched.sch b/juno_samples/matmul/src/sched.sch deleted file mode 100644 index 3999f923..00000000 --- a/juno_samples/matmul/src/sched.sch +++ /dev/null @@ -1,76 +0,0 @@ -macro juno-setup!(X) { - gvn(X); - dce(X); - phi-elim(X); -} - -macro default!(X) { - dce(X); - crc(X); - dce(X); - slf(X); - dce(X); - inline(X); - ip-sroa(X); - sroa(X); - phi-elim(X); - dce(X); - ccp(X); - dce(X); - gvn(X); - dce(X); - write-predication(X); - phi-elim(X); - dce(X); - crc(X); - dce(X); - slf(X); - dce(X); - predication(X); - dce(X); - ccp(X); - dce(X); - gvn(X); - dce(X); - lift-dc-math(X); - dce(X); - gvn(X); - dce(X); -} - -macro codegen-prep!(X) { - verify(*); - ip-sroa(*); - sroa(*); - infer-schedules(X); - dce(X); - gcm(X); - dce(X); - phi-elim(X); - float-collections(X); - gcm(X); -} - -juno-setup!(*); -default!(*); -// your stuff here. - -fixpoint stop after 13 { - forkify(*); - fork-guard-elim(*); - fork-coalesce(*); - phi-elim(*); - dce(*); -} - -xdot[true](*); -// serialize(*); - -fork-split(*); -unforkify(*); - -gvn(*); -dce(*); - -auto-outline(*); -codegen-prep!(*); diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 9888f3d2..81c46656 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -181,6 +181,7 @@ pub struct PassManager { pub loops: Option<Vec<LoopTree>>, pub reduce_cycles: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, + pub reduce_einsums: Option<Vec<(MathEnv, HashMap<NodeID, MathID>)>>, pub collection_objects: Option<CollectionObjects>, pub callgraph: Option<CallGraph>, pub devices: Option<Vec<Device>>, @@ -216,6 +217,7 @@ impl PassManager { loops: None, reduce_cycles: None, data_nodes_in_fork_joins: None, + reduce_einsums: None, collection_objects: None, callgraph: None, devices: None, @@ -392,6 +394,48 @@ impl PassManager { } } + pub fn make_reduce_einsums(&mut self) { + if self.reduce_einsums.is_none() { + self.make_def_uses(); + self.make_typing(); + self.make_fork_join_maps(); + self.make_fork_join_nests(); + self.make_data_nodes_in_fork_joins(); + let def_uses = self.def_uses.as_ref().unwrap().iter(); + let typing = self.typing.as_ref().unwrap().iter(); + let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter(); + let fork_join_nests = self.fork_join_nests.as_ref().unwrap().iter(); + let data_nodes_in_fork_joins = self.data_nodes_in_fork_joins.as_ref().unwrap().iter(); + self.reduce_einsums = Some( + self.functions + .iter() + .zip(def_uses) + .zip(typing) + .zip(fork_join_maps) + .zip(fork_join_nests) + .zip(data_nodes_in_fork_joins) + .map( + |( + ((((function, def_use), typing), fork_join_map), fork_join_nest), + data_nodes_in_fork_joins, + )| { + einsum( + function, + &self.types.borrow(), + &self.constants.borrow(), + def_use, + typing, + fork_join_map, + fork_join_nest, + data_nodes_in_fork_joins, + ) + }, + ) + .collect(), + ); + } + } + pub fn make_collection_objects(&mut self) { if self.collection_objects.is_none() { self.make_reverse_postorders(); @@ -463,6 +507,7 @@ impl PassManager { self.loops = None; self.reduce_cycles = None; self.data_nodes_in_fork_joins = None; + self.reduce_einsums = None; self.collection_objects = None; self.callgraph = None; self.devices = None; @@ -643,7 +688,7 @@ impl PassManager { } pub fn schedule_codegen( - mut module: Module, + module: Module, schedule: ScheduleStmt, mut stringtab: StringTable, mut env: Env<usize, Value>, @@ -913,8 +958,8 @@ fn add_schedule(pm: &mut PassManager, sched: Schedule, label_ids: Vec<LabelInfo> .labels .iter() .enumerate() - .filter(|(i, ls)| ls.contains(&label)) - .map(|(i, ls)| i) + .filter(|(_, ls)| ls.contains(&label)) + .map(|(i, _)| i) .collect::<Vec<_>>(); for node in nodes { pm.functions[func.idx()].schedules[node].push(sched.clone()); @@ -1350,7 +1395,7 @@ fn run_pass( let Some(mut func) = func else { continue; }; - // TODO: uses direct return from forkify for now instead of + // TODO: uses direct return from forkify for now instead of // func.modified, see comment on top of `forkify` for why. Fix // this eventually. changed |= forkify(&mut func, control_subgraph, fork_join_map, loop_nest); -- GitLab