From d3763dc2de4d0d9d1f12059391374544a49a4e95 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Wed, 29 Jan 2025 17:09:03 -0600 Subject: [PATCH 1/8] Skeleton --- hercules_ir/src/einsum.rs | 86 +++++++++++++++++++++++++++++++++++++++ hercules_ir/src/lib.rs | 2 + 2 files changed, 88 insertions(+) create mode 100644 hercules_ir/src/einsum.rs diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs new file mode 100644 index 00000000..7c7d1b58 --- /dev/null +++ b/hercules_ir/src/einsum.rs @@ -0,0 +1,86 @@ +use std::collections::HashMap; + +use crate::*; + +/* + * Math expressions are stored as a simple tree. + */ +#[derive(Debug, Clone)] +pub struct ForkDimension(pub usize, pub DynamicConstantID); + +#[derive(Debug, Clone)] +pub enum MathExpr { + // Zero constant of a specific type. + Zero(TypeID), + // One constant of a specific type. + One(TypeID), + // Opaque value corresponding to a particular node in the original IR. + OpaqueNode(NodeID), + + // Thread ID from the fork corresponding to the reduce being expressed. + // Thread IDs from outside forks are considered OpaqueNodes. + ThreadID(ForkDimension), + + // Sum reduction over a dimension of a fork. + SumReduction(MathID, ForkDimension), + // Comprehend a scalar expression into an array over fork dimensions. + Comprehension(MathID, Box<[ForkDimension]>), + + // Read from an array. + Read(MathID, Box<[MathID]>), + + // Math ops. + Add(MathID, MathID), + Sub(MathID, MathID), + Mul(MathID, MathID), + Div(MathID, MathID), +} + +pub type MathEnv = Vec<MathExpr>; + +define_id_type!(MathID); + +/* + * Top level function to run "einsum" analysis on fork-joins. This is a terrible + * name for this analysis, since it's actually more general than identifying + * einsums, but einsum syntax has a passing resemblance to the idea of this + * analysis and it's what we keep calling it, so we're doomed to this bad name. + * The idea of this analysis is to convert some fork-joins into pure math + * expressions that we can rewrite into intrinsic functions for higher level + * operators like matmul. Specifically, this function returns a map from each + * reduce node to a math expression. + */ +pub fn einsum( + function: &Function, + types: &Vec<Type>, + def_use: &ImmutableDefUseMap, + typing: &Vec<TypeID>, + fork_join_map: &HashMap<NodeID, NodeID>, + fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, +) -> (MathEnv, HashMap<NodeID, MathID>) { + let mut env = vec![]; + let mut result = HashMap::new(); + + // Iterate fork-joins bottom-up, since we need to compute the math + // expressions of inner reduces before getting to outer reduces. Since fork- + // joins are strictly nested, we can literally iterate entries of + // `fork_join_nest` in decreasing order of nesting size to accomplish this. + let mut nests: Vec<_> = fork_join_nest.into_iter().collect(); + nests.sort_by(|a, b| b.1.len().cmp(&a.1.len())); + for fork in nests.into_iter().map(|(id, _)| *id) { + // Check that the fork-join has no control flow. + let join = fork_join_map[&fork]; + if function.nodes[join.idx()].try_join().unwrap() != fork { + continue; + } + + // Compute a math expression for each reduce node in the fork-join. + let reduces = def_use + .get_users(join) + .into_iter() + .filter(|id| function.nodes[id.idx()].is_reduce()); + for reduce in reduces {} + } + + (env, result) +} diff --git a/hercules_ir/src/lib.rs b/hercules_ir/src/lib.rs index 85dc277f..adfe030c 100644 --- a/hercules_ir/src/lib.rs +++ b/hercules_ir/src/lib.rs @@ -14,6 +14,7 @@ pub mod def_use; pub mod device; pub mod dom; pub mod dot; +pub mod einsum; pub mod fork_join_analysis; pub mod ir; pub mod loops; @@ -30,6 +31,7 @@ pub use crate::def_use::*; pub use crate::device::*; pub use crate::dom::*; pub use crate::dot::*; +pub use crate::einsum::*; pub use crate::fork_join_analysis::*; pub use crate::ir::*; pub use crate::loops::*; -- GitLab From 869dfa14ba1db4b7f506fcf76430eaa2758d2d41 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Wed, 29 Jan 2025 21:48:12 -0600 Subject: [PATCH 2/8] Setup of accepted pattersn --- hercules_ir/src/einsum.rs | 118 ++++++++++++++++++++++++++-- hercules_samples/matmul/src/cpu.sch | 3 + 2 files changed, 114 insertions(+), 7 deletions(-) diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs index 7c7d1b58..58bb73fb 100644 --- a/hercules_ir/src/einsum.rs +++ b/hercules_ir/src/einsum.rs @@ -1,14 +1,14 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use crate::*; /* * Math expressions are stored as a simple tree. */ -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ForkDimension(pub usize, pub DynamicConstantID); -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum MathExpr { // Zero constant of a specific type. Zero(TypeID), @@ -22,7 +22,7 @@ pub enum MathExpr { ThreadID(ForkDimension), // Sum reduction over a dimension of a fork. - SumReduction(MathID, ForkDimension), + SumReduction(MathID, Box<[ForkDimension]>), // Comprehend a scalar expression into an array over fork dimensions. Comprehension(MathID, Box<[ForkDimension]>), @@ -34,6 +34,7 @@ pub enum MathExpr { Sub(MathID, MathID), Mul(MathID, MathID), Div(MathID, MathID), + Rem(MathID, MathID), } pub type MathEnv = Vec<MathExpr>; @@ -53,12 +54,15 @@ define_id_type!(MathID); pub fn einsum( function: &Function, types: &Vec<Type>, + constants: &Vec<Constant>, def_use: &ImmutableDefUseMap, typing: &Vec<TypeID>, fork_join_map: &HashMap<NodeID, NodeID>, fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, + data_nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, ) -> (MathEnv, HashMap<NodeID, MathID>) { let mut env = vec![]; + let mut rev_env = HashMap::new(); let mut result = HashMap::new(); // Iterate fork-joins bottom-up, since we need to compute the math @@ -73,14 +77,114 @@ pub fn einsum( if function.nodes[join.idx()].try_join().unwrap() != fork { continue; } + let Node::Fork { + control: _, + ref factors, + } = function.nodes[fork.idx()] + else { + panic!() + }; - // Compute a math expression for each reduce node in the fork-join. + // Compute a math expression for each reduce node in the fork-join with + // appropriate schedules. let reduces = def_use .get_users(join) .into_iter() - .filter(|id| function.nodes[id.idx()].is_reduce()); - for reduce in reduces {} + .filter_map(|id| function.nodes[id.idx()].try_reduce().map(|v| (*id, v))); + for (reduce, (_, init, reduct)) in reduces { + // The reduce defines an array where each fork dimension corresponds + // to one array dimension. + if function.schedules[reduce.idx()].contains(&Schedule::ParallelReduce) + && let Node::Write { + collect, + data, + ref indices, + } = function.nodes[reduct.idx()] + && collect == reduce + && indices.len() == 1 + && let Some(indices) = indices[0].try_position() + && let Some(dimension_bounds) = indices + .into_iter() + .map(|id| { + function.nodes[id.idx()] + .try_thread_id() + .filter(|(tid_fork, _)| *tid_fork == fork) + .map(|(_, dim)| dim) + }) + .collect::<Option<Vec<usize>>>() + { + let data_expr = compute_einsum_expression(data); + let reduce_expr = MathExpr::Comprehension( + data_expr, + dimension_bounds + .into_iter() + .map(|dim| ForkDimension(dim, factors[dim])) + .collect(), + ); + result.insert( + reduce, + intern_math_expr(reduce_expr, &mut env, &mut rev_env), + ); + } + // The reduce defines a sum reduction over a set of fork dimensions. + else if function.schedules[reduce.idx()].contains(&Schedule::TightAssociative) + && let Node::Binary { + op: BinaryOperator::Add, + left, + right, + } = function.nodes[reduct.idx()] + && (left == reduce || right == reduce) + { + let data_expr = + compute_einsum_expression(if left == reduce { right } else { left }); + let reduce_expr = MathExpr::SumReduction( + data_expr, + factors + .into_iter() + .enumerate() + .map(|(dim, factor)| ForkDimension(dim, *factor)) + .collect(), + ); + if function.nodes[init.idx()] + .try_constant() + .map(|id| constants[id.idx()].is_zero()) + .unwrap_or(false) + { + // If the initializer is zero, the expression is just the + // sum reduction. + result.insert( + reduce, + intern_math_expr(reduce_expr, &mut env, &mut rev_env), + ); + } else { + // If the initializer is not zero, we need to add it. + let reduce_expr_id = intern_math_expr(reduce_expr, &mut env, &mut rev_env); + let init_expr = MathExpr::OpaqueNode(init); + let init_expr_id = intern_math_expr(init_expr, &mut env, &mut rev_env); + let add_expr = MathExpr::Add(init_expr_id, reduce_expr_id); + result.insert(reduce, intern_math_expr(add_expr, &mut env, &mut rev_env)); + } + } + } } (env, result) } + +fn compute_einsum_expression(id: NodeID) -> MathID { + todo!() +} + +fn intern_math_expr( + expr: MathExpr, + env: &mut MathEnv, + rev_env: &mut HashMap<MathExpr, MathID>, +) -> MathID { + if let Some(id) = rev_env.get(&expr) { + *id + } else { + let id = MathID::new(env.len()); + rev_env.insert(expr, id); + id + } +} diff --git a/hercules_samples/matmul/src/cpu.sch b/hercules_samples/matmul/src/cpu.sch index f7891b9b..b58ac566 100644 --- a/hercules_samples/matmul/src/cpu.sch +++ b/hercules_samples/matmul/src/cpu.sch @@ -6,6 +6,9 @@ auto-outline(*); ip-sroa(*); sroa(*); +dce(*); +infer-schedules(*); +xdot[true](*); fork-split(*); unforkify(*); dce(*); -- GitLab From 9748467ed31aa3c5af849f4bec6a0959bb8489a6 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Thu, 30 Jan 2025 12:07:56 -0600 Subject: [PATCH 3/8] basic einsum analysis --- hercules_ir/src/einsum.rs | 208 ++++++++++++++++++++++++++++++-------- hercules_ir/src/lib.rs | 1 + 2 files changed, 169 insertions(+), 40 deletions(-) diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs index 58bb73fb..ea1f5775 100644 --- a/hercules_ir/src/einsum.rs +++ b/hercules_ir/src/einsum.rs @@ -1,11 +1,12 @@ use std::collections::{HashMap, HashSet}; +use std::iter::zip; use crate::*; /* * Math expressions are stored as a simple tree. */ -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct ForkDimension(pub usize, pub DynamicConstantID); #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -84,6 +85,16 @@ pub fn einsum( else { panic!() }; + let mut ctx = EinsumContext { + function, + typing, + constants, + data_nodes_in_fork_joins, + fork, + factors, + env: &mut env, + rev_env: &mut rev_env, + }; // Compute a math expression for each reduce node in the fork-join with // appropriate schedules. @@ -112,8 +123,11 @@ pub fn einsum( .map(|(_, dim)| dim) }) .collect::<Option<Vec<usize>>>() + && let Type::Array(_, ref array_bounds) = types[typing[reduce.idx()].idx()] + && zip(array_bounds.into_iter(), dimension_bounds.iter()) + .all(|(array, fork)| *array == factors[*fork]) { - let data_expr = compute_einsum_expression(data); + let data_expr = ctx.compute_math_expr(data); let reduce_expr = MathExpr::Comprehension( data_expr, dimension_bounds @@ -121,10 +135,11 @@ pub fn einsum( .map(|dim| ForkDimension(dim, factors[dim])) .collect(), ); - result.insert( - reduce, - intern_math_expr(reduce_expr, &mut env, &mut rev_env), - ); + // We don't need to consider the initializer, since the writes + // cover the whole array. + let total_id = ctx.intern_math_expr(reduce_expr); + ctx.debug_print_expr(total_id); + result.insert(reduce, total_id); } // The reduce defines a sum reduction over a set of fork dimensions. else if function.schedules[reduce.idx()].contains(&Schedule::TightAssociative) @@ -135,8 +150,7 @@ pub fn einsum( } = function.nodes[reduct.idx()] && (left == reduce || right == reduce) { - let data_expr = - compute_einsum_expression(if left == reduce { right } else { left }); + let data_expr = ctx.compute_math_expr(if left == reduce { right } else { left }); let reduce_expr = MathExpr::SumReduction( data_expr, factors @@ -145,25 +159,13 @@ pub fn einsum( .map(|(dim, factor)| ForkDimension(dim, *factor)) .collect(), ); - if function.nodes[init.idx()] - .try_constant() - .map(|id| constants[id.idx()].is_zero()) - .unwrap_or(false) - { - // If the initializer is zero, the expression is just the - // sum reduction. - result.insert( - reduce, - intern_math_expr(reduce_expr, &mut env, &mut rev_env), - ); - } else { - // If the initializer is not zero, we need to add it. - let reduce_expr_id = intern_math_expr(reduce_expr, &mut env, &mut rev_env); - let init_expr = MathExpr::OpaqueNode(init); - let init_expr_id = intern_math_expr(init_expr, &mut env, &mut rev_env); - let add_expr = MathExpr::Add(init_expr_id, reduce_expr_id); - result.insert(reduce, intern_math_expr(add_expr, &mut env, &mut rev_env)); - } + // Add the initializer. + let reduce_expr_id = ctx.intern_math_expr(reduce_expr); + let init_expr_id = ctx.compute_math_expr(init); + let add_expr = MathExpr::Add(init_expr_id, reduce_expr_id); + let total_id = ctx.intern_math_expr(add_expr); + ctx.debug_print_expr(total_id); + result.insert(reduce, total_id); } } } @@ -171,20 +173,146 @@ pub fn einsum( (env, result) } -fn compute_einsum_expression(id: NodeID) -> MathID { - todo!() +struct EinsumContext<'a> { + function: &'a Function, + typing: &'a Vec<TypeID>, + constants: &'a Vec<Constant>, + data_nodes_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>, + fork: NodeID, + factors: &'a [DynamicConstantID], + env: &'a mut MathEnv, + rev_env: &'a mut HashMap<MathExpr, MathID>, +} + +impl<'a> EinsumContext<'a> { + fn compute_math_expr(&mut self, id: NodeID) -> MathID { + let math_expr = match self.function.nodes[id.idx()] { + Node::Constant { id: cons_id } if self.constants[cons_id.idx()].is_zero() => { + MathExpr::Zero(self.typing[id.idx()]) + } + Node::Constant { id: cons_id } if self.constants[cons_id.idx()].is_one() => { + MathExpr::One(self.typing[id.idx()]) + } + Node::ThreadID { control, dimension } if control == self.fork => { + MathExpr::ThreadID(ForkDimension(dimension, self.factors[dimension])) + } + Node::Binary { op, left, right } if representable(op) => { + let left = self.compute_math_expr(left); + let right = self.compute_math_expr(right); + match op { + BinaryOperator::Add => MathExpr::Add(left, right), + BinaryOperator::Sub => MathExpr::Sub(left, right), + BinaryOperator::Mul => MathExpr::Mul(left, right), + BinaryOperator::Div => MathExpr::Div(left, right), + BinaryOperator::Rem => MathExpr::Rem(left, right), + _ => unreachable!(), + } + } + Node::Read { + collect, + ref indices, + } if indices.len() == 1 + && let Some(indices) = indices[0].try_position() => + { + let collect = self.compute_math_expr(collect); + let indices = indices + .into_iter() + .map(|id| self.compute_math_expr(*id)) + .collect(); + MathExpr::Read(collect, indices) + } + _ => MathExpr::OpaqueNode(id), + }; + self.intern_math_expr(math_expr) + } + + fn intern_math_expr(&mut self, expr: MathExpr) -> MathID { + if let Some(id) = self.rev_env.get(&expr) { + *id + } else { + let id = MathID::new(self.env.len()); + self.rev_env.insert(expr, id); + id + } + } + + fn debug_print_expr(&self, id: MathID) { + match self.env[id.idx()] { + MathExpr::Zero(_) => println!("0"), + MathExpr::One(_) => println!("1"), + MathExpr::OpaqueNode(id) => println!("{:?}", id), + MathExpr::ThreadID(dim) => println!("#{}", dim.0), + MathExpr::SumReduction(id, ref dims) => { + println!("Sum ("); + for dim in dims { + println!("#{}/{:?},", dim.0, dim.1); + } + println!(") "); + self.debug_print_expr(id); + } + MathExpr::Comprehension(id, ref dims) => { + println!("["); + for dim in dims { + println!("#{}/{:?},", dim.0, dim.1); + } + println!("] "); + self.debug_print_expr(id); + } + MathExpr::Read(id, ref pos) => { + println!("read("); + self.debug_print_expr(id); + for pos in pos { + println!(", "); + self.debug_print_expr(*pos); + } + println!(")"); + } + MathExpr::Add(left, right) => { + println!("+("); + self.debug_print_expr(left); + println!(", "); + self.debug_print_expr(right); + println!(")"); + } + MathExpr::Sub(left, right) => { + println!("-("); + self.debug_print_expr(left); + println!(", "); + self.debug_print_expr(right); + println!(")"); + } + MathExpr::Mul(left, right) => { + println!("*("); + self.debug_print_expr(left); + println!(", "); + self.debug_print_expr(right); + println!(")"); + } + MathExpr::Div(left, right) => { + println!("/("); + self.debug_print_expr(left); + println!(", "); + self.debug_print_expr(right); + println!(")"); + } + MathExpr::Rem(left, right) => { + println!("%("); + self.debug_print_expr(left); + println!(", "); + self.debug_print_expr(right); + println!(")"); + } + } + } } -fn intern_math_expr( - expr: MathExpr, - env: &mut MathEnv, - rev_env: &mut HashMap<MathExpr, MathID>, -) -> MathID { - if let Some(id) = rev_env.get(&expr) { - *id - } else { - let id = MathID::new(env.len()); - rev_env.insert(expr, id); - id +fn representable(op: BinaryOperator) -> bool { + match op { + BinaryOperator::Add + | BinaryOperator::Sub + | BinaryOperator::Mul + | BinaryOperator::Div + | BinaryOperator::Rem => true, + _ => false, } } diff --git a/hercules_ir/src/lib.rs b/hercules_ir/src/lib.rs index adfe030c..185a2888 100644 --- a/hercules_ir/src/lib.rs +++ b/hercules_ir/src/lib.rs @@ -2,6 +2,7 @@ coroutines, coroutine_trait, let_chains, + if_let_guard, stmt_expr_attributes, iter_intersperse )] -- GitLab From 072cbe7a6a583edced53b088333873e6545b44eb Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Thu, 30 Jan 2025 15:18:39 -0600 Subject: [PATCH 4/8] Einsum works w/ one reduce --- hercules_ir/src/einsum.rs | 70 ++++++++++++++++++++------------------- juno_scheduler/src/pm.rs | 51 ++++++++++++++++++++++++++-- 2 files changed, 84 insertions(+), 37 deletions(-) diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs index ea1f5775..90b2f07d 100644 --- a/hercules_ir/src/einsum.rs +++ b/hercules_ir/src/einsum.rs @@ -70,14 +70,12 @@ pub fn einsum( // expressions of inner reduces before getting to outer reduces. Since fork- // joins are strictly nested, we can literally iterate entries of // `fork_join_nest` in decreasing order of nesting size to accomplish this. - let mut nests: Vec<_> = fork_join_nest.into_iter().collect(); + let mut nests: Vec<_> = fork_join_nest + .into_iter() + .filter(|(id, _)| function.nodes[id.idx()].is_fork()) + .collect(); nests.sort_by(|a, b| b.1.len().cmp(&a.1.len())); for fork in nests.into_iter().map(|(id, _)| *id) { - // Check that the fork-join has no control flow. - let join = fork_join_map[&fork]; - if function.nodes[join.idx()].try_join().unwrap() != fork { - continue; - } let Node::Fork { control: _, ref factors, @@ -85,6 +83,7 @@ pub fn einsum( else { panic!() }; + let join = fork_join_map[&fork]; let mut ctx = EinsumContext { function, typing, @@ -139,6 +138,7 @@ pub fn einsum( // cover the whole array. let total_id = ctx.intern_math_expr(reduce_expr); ctx.debug_print_expr(total_id); + println!(""); result.insert(reduce, total_id); } // The reduce defines a sum reduction over a set of fork dimensions. @@ -165,6 +165,7 @@ pub fn einsum( let add_expr = MathExpr::Add(init_expr_id, reduce_expr_id); let total_id = ctx.intern_math_expr(add_expr); ctx.debug_print_expr(total_id); + println!(""); result.insert(reduce, total_id); } } @@ -231,6 +232,7 @@ impl<'a> EinsumContext<'a> { *id } else { let id = MathID::new(self.env.len()); + self.env.push(expr.clone()); self.rev_env.insert(expr, id); id } @@ -238,69 +240,69 @@ impl<'a> EinsumContext<'a> { fn debug_print_expr(&self, id: MathID) { match self.env[id.idx()] { - MathExpr::Zero(_) => println!("0"), - MathExpr::One(_) => println!("1"), - MathExpr::OpaqueNode(id) => println!("{:?}", id), - MathExpr::ThreadID(dim) => println!("#{}", dim.0), + MathExpr::Zero(_) => print!("0"), + MathExpr::One(_) => print!("1"), + MathExpr::OpaqueNode(id) => print!("{:?}", id), + MathExpr::ThreadID(dim) => print!("#{}", dim.0), MathExpr::SumReduction(id, ref dims) => { - println!("Sum ("); + print!("Sum ("); for dim in dims { - println!("#{}/{:?},", dim.0, dim.1); + print!("#{}/{:?},", dim.0, dim.1); } - println!(") "); + print!(") "); self.debug_print_expr(id); } MathExpr::Comprehension(id, ref dims) => { - println!("["); + print!("["); for dim in dims { - println!("#{}/{:?},", dim.0, dim.1); + print!("#{}/{:?},", dim.0, dim.1); } - println!("] "); + print!("] "); self.debug_print_expr(id); } MathExpr::Read(id, ref pos) => { - println!("read("); + print!("read("); self.debug_print_expr(id); for pos in pos { - println!(", "); + print!(", "); self.debug_print_expr(*pos); } - println!(")"); + print!(")"); } MathExpr::Add(left, right) => { - println!("+("); + print!("+("); self.debug_print_expr(left); - println!(", "); + print!(", "); self.debug_print_expr(right); - println!(")"); + print!(")"); } MathExpr::Sub(left, right) => { - println!("-("); + print!("-("); self.debug_print_expr(left); - println!(", "); + print!(", "); self.debug_print_expr(right); - println!(")"); + print!(")"); } MathExpr::Mul(left, right) => { - println!("*("); + print!("*("); self.debug_print_expr(left); - println!(", "); + print!(", "); self.debug_print_expr(right); - println!(")"); + print!(")"); } MathExpr::Div(left, right) => { - println!("/("); + print!("/("); self.debug_print_expr(left); - println!(", "); + print!(", "); self.debug_print_expr(right); - println!(")"); + print!(")"); } MathExpr::Rem(left, right) => { - println!("%("); + print!("%("); self.debug_print_expr(left); - println!(", "); + print!(", "); self.debug_print_expr(right); - println!(")"); + print!(")"); } } } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 9888f3d2..910d3226 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -181,6 +181,7 @@ pub struct PassManager { pub loops: Option<Vec<LoopTree>>, pub reduce_cycles: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, + pub reduce_einsums: Option<Vec<(MathEnv, HashMap<NodeID, MathID>)>>, pub collection_objects: Option<CollectionObjects>, pub callgraph: Option<CallGraph>, pub devices: Option<Vec<Device>>, @@ -216,6 +217,7 @@ impl PassManager { loops: None, reduce_cycles: None, data_nodes_in_fork_joins: None, + reduce_einsums: None, collection_objects: None, callgraph: None, devices: None, @@ -392,6 +394,48 @@ impl PassManager { } } + pub fn make_reduce_einsums(&mut self) { + if self.reduce_einsums.is_none() { + self.make_def_uses(); + self.make_typing(); + self.make_fork_join_maps(); + self.make_fork_join_nests(); + self.make_data_nodes_in_fork_joins(); + let def_uses = self.def_uses.as_ref().unwrap().iter(); + let typing = self.typing.as_ref().unwrap().iter(); + let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter(); + let fork_join_nests = self.fork_join_nests.as_ref().unwrap().iter(); + let data_nodes_in_fork_joins = self.data_nodes_in_fork_joins.as_ref().unwrap().iter(); + self.reduce_einsums = Some( + self.functions + .iter() + .zip(def_uses) + .zip(typing) + .zip(fork_join_maps) + .zip(fork_join_nests) + .zip(data_nodes_in_fork_joins) + .map( + |( + ((((function, def_use), typing), fork_join_map), fork_join_nest), + data_nodes_in_fork_joins, + )| { + einsum( + function, + &self.types.borrow(), + &self.constants.borrow(), + def_use, + typing, + fork_join_map, + fork_join_nest, + data_nodes_in_fork_joins, + ) + }, + ) + .collect(), + ); + } + } + pub fn make_collection_objects(&mut self) { if self.collection_objects.is_none() { self.make_reverse_postorders(); @@ -463,6 +507,7 @@ impl PassManager { self.loops = None; self.reduce_cycles = None; self.data_nodes_in_fork_joins = None; + self.reduce_einsums = None; self.collection_objects = None; self.callgraph = None; self.devices = None; @@ -643,7 +688,7 @@ impl PassManager { } pub fn schedule_codegen( - mut module: Module, + module: Module, schedule: ScheduleStmt, mut stringtab: StringTable, mut env: Env<usize, Value>, @@ -913,8 +958,8 @@ fn add_schedule(pm: &mut PassManager, sched: Schedule, label_ids: Vec<LabelInfo> .labels .iter() .enumerate() - .filter(|(i, ls)| ls.contains(&label)) - .map(|(i, ls)| i) + .filter(|(_, ls)| ls.contains(&label)) + .map(|(i, _)| i) .collect::<Vec<_>>(); for node in nodes { pm.functions[func.idx()].schedules[node].push(sched.clone()); -- GitLab From 84c2ff659747e21301e67394084ab94be39cd775 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Thu, 30 Jan 2025 17:14:32 -0600 Subject: [PATCH 5/8] einsum across reduces of inner fork-joins --- hercules_ir/src/einsum.rs | 109 +++++++++++++++++++++++++++++++++++--- juno_scheduler/src/pm.rs | 3 +- 2 files changed, 105 insertions(+), 7 deletions(-) diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs index 90b2f07d..bf91dc6d 100644 --- a/hercules_ir/src/einsum.rs +++ b/hercules_ir/src/einsum.rs @@ -84,6 +84,19 @@ pub fn einsum( panic!() }; let join = fork_join_map[&fork]; + let thread_ids: Vec<_> = def_use + .get_users(fork) + .into_iter() + .filter_map(|id| { + function.nodes[id.idx()] + .try_thread_id() + .map(|(_, dim)| (*id, dim)) + }) + .collect(); + let reduces = def_use + .get_users(join) + .into_iter() + .filter_map(|id| function.nodes[id.idx()].try_reduce().map(|v| (*id, v))); let mut ctx = EinsumContext { function, typing, @@ -91,16 +104,14 @@ pub fn einsum( data_nodes_in_fork_joins, fork, factors, + thread_ids: &thread_ids, + so_far: &mut result, env: &mut env, rev_env: &mut rev_env, }; // Compute a math expression for each reduce node in the fork-join with // appropriate schedules. - let reduces = def_use - .get_users(join) - .into_iter() - .filter_map(|id| function.nodes[id.idx()].try_reduce().map(|v| (*id, v))); for (reduce, (_, init, reduct)) in reduces { // The reduce defines an array where each fork dimension corresponds // to one array dimension. @@ -139,7 +150,7 @@ pub fn einsum( let total_id = ctx.intern_math_expr(reduce_expr); ctx.debug_print_expr(total_id); println!(""); - result.insert(reduce, total_id); + ctx.result_insert(reduce, total_id); } // The reduce defines a sum reduction over a set of fork dimensions. else if function.schedules[reduce.idx()].contains(&Schedule::TightAssociative) @@ -166,7 +177,7 @@ pub fn einsum( let total_id = ctx.intern_math_expr(add_expr); ctx.debug_print_expr(total_id); println!(""); - result.insert(reduce, total_id); + ctx.result_insert(reduce, total_id); } } } @@ -181,6 +192,8 @@ struct EinsumContext<'a> { data_nodes_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>, fork: NodeID, factors: &'a [DynamicConstantID], + thread_ids: &'a Vec<(NodeID, usize)>, + so_far: &'a mut HashMap<NodeID, MathID>, env: &'a mut MathEnv, rev_env: &'a mut HashMap<MathExpr, MathID>, } @@ -222,6 +235,17 @@ impl<'a> EinsumContext<'a> { .collect(); MathExpr::Read(collect, indices) } + Node::Reduce { + control: _, + init: _, + reduct: _, + } => { + let reduce = self.so_far[&id]; + // Substitute opaque uses of thread ID nodes in inner expression + // with thread ID math expression, and increment inner-fork + // dimensions (alpha renaming). + return self.substitute_new_dims(reduce); + } _ => MathExpr::OpaqueNode(id), }; self.intern_math_expr(math_expr) @@ -238,6 +262,79 @@ impl<'a> EinsumContext<'a> { } } + fn result_insert(&mut self, node: NodeID, math: MathID) { + self.so_far.insert(node, math); + } + + fn substitute_new_dims(&mut self, id: MathID) -> MathID { + match self.env[id.idx()] { + MathExpr::OpaqueNode(opaque) + if let Some((_, dim)) = self + .thread_ids + .into_iter() + .filter(|(node, _)| *node == opaque) + .next() => + { + self.intern_math_expr(MathExpr::ThreadID(ForkDimension(*dim, self.factors[*dim]))) + } + MathExpr::ThreadID(dim) => self.intern_math_expr(MathExpr::ThreadID(ForkDimension( + dim.0 + self.factors.len(), + dim.1, + ))), + MathExpr::SumReduction(id, ref dims) => { + let dims = dims + .into_iter() + .map(|dim| ForkDimension(dim.0 + self.factors.len(), dim.1)) + .collect(); + let id = self.substitute_new_dims(id); + self.intern_math_expr(MathExpr::SumReduction(id, dims)) + } + MathExpr::Comprehension(id, ref dims) => { + let dims = dims + .into_iter() + .map(|dim| ForkDimension(dim.0 + self.factors.len(), dim.1)) + .collect(); + let id = self.substitute_new_dims(id); + self.intern_math_expr(MathExpr::Comprehension(id, dims)) + } + MathExpr::Read(array, ref indices) => { + let indices = indices + .clone() + .iter() + .map(|id| self.substitute_new_dims(*id)) + .collect(); + let array = self.substitute_new_dims(array); + self.intern_math_expr(MathExpr::Read(array, indices)) + } + MathExpr::Add(left, right) => { + let left = self.substitute_new_dims(left); + let right = self.substitute_new_dims(right); + self.intern_math_expr(MathExpr::Add(left, right)) + } + MathExpr::Sub(left, right) => { + let left = self.substitute_new_dims(left); + let right = self.substitute_new_dims(right); + self.intern_math_expr(MathExpr::Sub(left, right)) + } + MathExpr::Mul(left, right) => { + let left = self.substitute_new_dims(left); + let right = self.substitute_new_dims(right); + self.intern_math_expr(MathExpr::Mul(left, right)) + } + MathExpr::Div(left, right) => { + let left = self.substitute_new_dims(left); + let right = self.substitute_new_dims(right); + self.intern_math_expr(MathExpr::Div(left, right)) + } + MathExpr::Rem(left, right) => { + let left = self.substitute_new_dims(left); + let right = self.substitute_new_dims(right); + self.intern_math_expr(MathExpr::Rem(left, right)) + } + _ => id, + } + } + fn debug_print_expr(&self, id: MathID) { match self.env[id.idx()] { MathExpr::Zero(_) => print!("0"), diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 910d3226..3f43aca8 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1351,6 +1351,7 @@ fn run_pass( } Pass::ForkSplit => { assert!(args.is_empty()); + pm.make_reduce_einsums(); loop { let mut inner_changed = false; pm.make_fork_join_maps(); @@ -1395,7 +1396,7 @@ fn run_pass( let Some(mut func) = func else { continue; }; - // TODO: uses direct return from forkify for now instead of + // TODO: uses direct return from forkify for now instead of // func.modified, see comment on top of `forkify` for why. Fix // this eventually. changed |= forkify(&mut func, control_subgraph, fork_join_map, loop_nest); -- GitLab From fbc200a853cb35fe3265dfda37f682c28dd9c6cb Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Thu, 30 Jan 2025 17:17:58 -0600 Subject: [PATCH 6/8] remove debug stuff --- hercules_ir/src/einsum.rs | 4 ---- juno_scheduler/src/pm.rs | 1 - 2 files changed, 5 deletions(-) diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs index bf91dc6d..c2521c22 100644 --- a/hercules_ir/src/einsum.rs +++ b/hercules_ir/src/einsum.rs @@ -148,8 +148,6 @@ pub fn einsum( // We don't need to consider the initializer, since the writes // cover the whole array. let total_id = ctx.intern_math_expr(reduce_expr); - ctx.debug_print_expr(total_id); - println!(""); ctx.result_insert(reduce, total_id); } // The reduce defines a sum reduction over a set of fork dimensions. @@ -175,8 +173,6 @@ pub fn einsum( let init_expr_id = ctx.compute_math_expr(init); let add_expr = MathExpr::Add(init_expr_id, reduce_expr_id); let total_id = ctx.intern_math_expr(add_expr); - ctx.debug_print_expr(total_id); - println!(""); ctx.result_insert(reduce, total_id); } } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 3f43aca8..81c46656 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1351,7 +1351,6 @@ fn run_pass( } Pass::ForkSplit => { assert!(args.is_empty()); - pm.make_reduce_einsums(); loop { let mut inner_changed = false; pm.make_fork_join_maps(); -- GitLab From a92c8613365dff8ce7763767088cc9465191777d Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Thu, 30 Jan 2025 17:20:44 -0600 Subject: [PATCH 7/8] whoops --- hercules_samples/matmul/src/cpu.sch | 1 - 1 file changed, 1 deletion(-) diff --git a/hercules_samples/matmul/src/cpu.sch b/hercules_samples/matmul/src/cpu.sch index b58ac566..c00a3314 100644 --- a/hercules_samples/matmul/src/cpu.sch +++ b/hercules_samples/matmul/src/cpu.sch @@ -8,7 +8,6 @@ ip-sroa(*); sroa(*); dce(*); infer-schedules(*); -xdot[true](*); fork-split(*); unforkify(*); dce(*); -- GitLab From 1acde31d7535c76d70c4cc8025d386d03bdccfa0 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Thu, 30 Jan 2025 17:21:25 -0600 Subject: [PATCH 8/8] clean --- juno_samples/matmul/build.rs | 2 - juno_samples/matmul/src/sched.sch | 76 ------------------------------- 2 files changed, 78 deletions(-) delete mode 100644 juno_samples/matmul/src/sched.sch diff --git a/juno_samples/matmul/build.rs b/juno_samples/matmul/build.rs index 511bf483..926fbc33 100644 --- a/juno_samples/matmul/build.rs +++ b/juno_samples/matmul/build.rs @@ -4,8 +4,6 @@ fn main() { JunoCompiler::new() .file_in_src("matmul.jn") .unwrap() - //.schedule_in_src("sched.sch") - //.unwrap() .build() .unwrap(); } diff --git a/juno_samples/matmul/src/sched.sch b/juno_samples/matmul/src/sched.sch deleted file mode 100644 index 3999f923..00000000 --- a/juno_samples/matmul/src/sched.sch +++ /dev/null @@ -1,76 +0,0 @@ -macro juno-setup!(X) { - gvn(X); - dce(X); - phi-elim(X); -} - -macro default!(X) { - dce(X); - crc(X); - dce(X); - slf(X); - dce(X); - inline(X); - ip-sroa(X); - sroa(X); - phi-elim(X); - dce(X); - ccp(X); - dce(X); - gvn(X); - dce(X); - write-predication(X); - phi-elim(X); - dce(X); - crc(X); - dce(X); - slf(X); - dce(X); - predication(X); - dce(X); - ccp(X); - dce(X); - gvn(X); - dce(X); - lift-dc-math(X); - dce(X); - gvn(X); - dce(X); -} - -macro codegen-prep!(X) { - verify(*); - ip-sroa(*); - sroa(*); - infer-schedules(X); - dce(X); - gcm(X); - dce(X); - phi-elim(X); - float-collections(X); - gcm(X); -} - -juno-setup!(*); -default!(*); -// your stuff here. - -fixpoint stop after 13 { - forkify(*); - fork-guard-elim(*); - fork-coalesce(*); - phi-elim(*); - dce(*); -} - -xdot[true](*); -// serialize(*); - -fork-split(*); -unforkify(*); - -gvn(*); -dce(*); - -auto-outline(*); -codegen-prep!(*); -- GitLab