diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs
index 25e15d6354d8803a70783b3073c5d0c816e6771d..5dd0fe5d6f7a6fe9743554ed357044b13845dd23 100644
--- a/hercules_ir/src/einsum.rs
+++ b/hercules_ir/src/einsum.rs
@@ -31,11 +31,9 @@ pub enum MathExpr {
     Read(MathID, Box<[MathID]>),
 
     // Math ops.
-    Add(MathID, MathID),
-    Sub(MathID, MathID),
-    Mul(MathID, MathID),
-    Div(MathID, MathID),
-    Rem(MathID, MathID),
+    Unary(UnaryOperator, MathID),
+    Binary(BinaryOperator, MathID, MathID),
+    Ternary(TernaryOperator, MathID, MathID, MathID),
 }
 
 pub type MathEnv = Vec<MathExpr>;
@@ -171,7 +169,7 @@ pub fn einsum(
         // Add the initializer.
         let reduce_expr_id = ctx.intern_math_expr(reduce_expr);
         let init_expr_id = ctx.compute_math_expr(init);
-        let add_expr = MathExpr::Add(init_expr_id, reduce_expr_id);
+        let add_expr = MathExpr::Binary(BinaryOperator::Add, init_expr_id, reduce_expr_id);
         let total_id = ctx.intern_math_expr(add_expr);
         ctx.result_insert(reduce, total_id);
     }
@@ -206,17 +204,25 @@ impl<'a> EinsumContext<'a> {
             Node::ThreadID { control, dimension } if control == self.fork => {
                 MathExpr::ThreadID(ForkDimension(dimension, self.factors[dimension]))
             }
-            Node::Binary { op, left, right } if representable(op) => {
+            Node::Unary { op, input } => {
+                let input = self.compute_math_expr(input);
+                MathExpr::Unary(op, input)
+            }
+            Node::Binary { op, left, right } => {
                 let left = self.compute_math_expr(left);
                 let right = self.compute_math_expr(right);
-                match op {
-                    BinaryOperator::Add => MathExpr::Add(left, right),
-                    BinaryOperator::Sub => MathExpr::Sub(left, right),
-                    BinaryOperator::Mul => MathExpr::Mul(left, right),
-                    BinaryOperator::Div => MathExpr::Div(left, right),
-                    BinaryOperator::Rem => MathExpr::Rem(left, right),
-                    _ => unreachable!(),
-                }
+                MathExpr::Binary(op, left, right)
+            }
+            Node::Ternary {
+                op,
+                first,
+                second,
+                third,
+            } => {
+                let first = self.compute_math_expr(first);
+                let second = self.compute_math_expr(second);
+                let third = self.compute_math_expr(third);
+                MathExpr::Ternary(op, first, second, third)
             }
             Node::Read {
                 collect,
@@ -301,103 +307,24 @@ impl<'a> EinsumContext<'a> {
                 let array = self.substitute_new_dims(array);
                 self.intern_math_expr(MathExpr::Read(array, indices))
             }
-            MathExpr::Add(left, right) => {
-                let left = self.substitute_new_dims(left);
-                let right = self.substitute_new_dims(right);
-                self.intern_math_expr(MathExpr::Add(left, right))
+            MathExpr::Unary(op, input) => {
+                let input = self.substitute_new_dims(input);
+                self.intern_math_expr(MathExpr::Unary(op, input))
             }
-            MathExpr::Sub(left, right) => {
+            MathExpr::Binary(op, left, right) => {
                 let left = self.substitute_new_dims(left);
                 let right = self.substitute_new_dims(right);
-                self.intern_math_expr(MathExpr::Sub(left, right))
+                self.intern_math_expr(MathExpr::Binary(op, left, right))
             }
-            MathExpr::Mul(left, right) => {
-                let left = self.substitute_new_dims(left);
-                let right = self.substitute_new_dims(right);
-                self.intern_math_expr(MathExpr::Mul(left, right))
-            }
-            MathExpr::Div(left, right) => {
-                let left = self.substitute_new_dims(left);
-                let right = self.substitute_new_dims(right);
-                self.intern_math_expr(MathExpr::Div(left, right))
-            }
-            MathExpr::Rem(left, right) => {
-                let left = self.substitute_new_dims(left);
-                let right = self.substitute_new_dims(right);
-                self.intern_math_expr(MathExpr::Rem(left, right))
+            MathExpr::Ternary(op, first, second, third) => {
+                let first = self.substitute_new_dims(first);
+                let second = self.substitute_new_dims(second);
+                let third = self.substitute_new_dims(third);
+                self.intern_math_expr(MathExpr::Ternary(op, first, second, third))
             }
             _ => id,
         }
     }
-
-    fn debug_print_expr(&self, id: MathID) {
-        match self.env[id.idx()] {
-            MathExpr::Zero(_) => print!("0"),
-            MathExpr::One(_) => print!("1"),
-            MathExpr::OpaqueNode(id) => print!("{:?}", id),
-            MathExpr::ThreadID(dim) => print!("#{}", dim.0),
-            MathExpr::SumReduction(id, ref dims) => {
-                print!("Sum (");
-                for dim in dims {
-                    print!("#{}/{:?},", dim.0, dim.1);
-                }
-                print!(") ");
-                self.debug_print_expr(id);
-            }
-            MathExpr::Comprehension(id, ref dims) => {
-                print!("[");
-                for dim in dims {
-                    print!("#{}/{:?},", dim.0, dim.1);
-                }
-                print!("] ");
-                self.debug_print_expr(id);
-            }
-            MathExpr::Read(id, ref pos) => {
-                print!("read(");
-                self.debug_print_expr(id);
-                for pos in pos {
-                    print!(", ");
-                    self.debug_print_expr(*pos);
-                }
-                print!(")");
-            }
-            MathExpr::Add(left, right) => {
-                print!("+(");
-                self.debug_print_expr(left);
-                print!(", ");
-                self.debug_print_expr(right);
-                print!(")");
-            }
-            MathExpr::Sub(left, right) => {
-                print!("-(");
-                self.debug_print_expr(left);
-                print!(", ");
-                self.debug_print_expr(right);
-                print!(")");
-            }
-            MathExpr::Mul(left, right) => {
-                print!("*(");
-                self.debug_print_expr(left);
-                print!(", ");
-                self.debug_print_expr(right);
-                print!(")");
-            }
-            MathExpr::Div(left, right) => {
-                print!("/(");
-                self.debug_print_expr(left);
-                print!(", ");
-                self.debug_print_expr(right);
-                print!(")");
-            }
-            MathExpr::Rem(left, right) => {
-                print!("%(");
-                self.debug_print_expr(left);
-                print!(", ");
-                self.debug_print_expr(right);
-                print!(")");
-            }
-        }
-    }
 }
 
 pub fn opaque_nodes_in_expr(env: &MathEnv, id: MathID) -> HashSet<NodeID> {
@@ -416,26 +343,74 @@ pub fn opaque_nodes_in_expr(env: &MathEnv, id: MathID) -> HashSet<NodeID> {
                 stack.push(id);
                 stack.extend(ids);
             }
-            MathExpr::Add(left, right)
-            | MathExpr::Sub(left, right)
-            | MathExpr::Mul(left, right)
-            | MathExpr::Div(left, right)
-            | MathExpr::Rem(left, right) => {
+            MathExpr::Unary(_, input) => {
+                stack.push(input);
+            }
+            MathExpr::Binary(_, left, right) => {
                 stack.push(left);
                 stack.push(right);
             }
+            MathExpr::Ternary(_, first, second, third) => {
+                stack.push(first);
+                stack.push(second);
+                stack.push(third);
+            }
        }
    }
    set
 }
 
-fn representable(op: BinaryOperator) -> bool {
-    match op {
-        BinaryOperator::Add
-        | BinaryOperator::Sub
-        | BinaryOperator::Mul
-        | BinaryOperator::Div
-        | BinaryOperator::Rem => true,
-        _ => false,
+pub fn debug_print_math_expr(id: MathID, env: &MathEnv) {
+    match env[id.idx()] {
+        MathExpr::Zero(_) => print!("0"),
+        MathExpr::One(_) => print!("1"),
+        MathExpr::OpaqueNode(id) => print!("{:?}", id),
+        MathExpr::ThreadID(dim) => print!("#{}", dim.0),
+        MathExpr::SumReduction(id, ref dims) => {
+            print!("Sum (");
+            for dim in dims {
+                print!("#{}/{:?},", dim.0, dim.1);
+            }
+            print!(") ");
+            debug_print_math_expr(id, env);
+        }
+        MathExpr::Comprehension(id, ref dims) => {
+            print!("[");
+            for dim in dims {
+                print!("#{}/{:?},", dim.0, dim.1);
+            }
+            print!("] ");
+            debug_print_math_expr(id, env);
+        }
+        MathExpr::Read(id, ref pos) => {
+            print!("read(");
+            debug_print_math_expr(id, env);
+            for pos in pos {
+                print!(", ");
+                debug_print_math_expr(*pos, env);
+            }
+            print!(")");
+        }
+        MathExpr::Unary(op, input) => {
+            print!("{}(", op.lower_case_name());
+            debug_print_math_expr(input, env);
+            print!(")");
+        }
+        MathExpr::Binary(op, left, right) => {
+            print!("{}(", op.lower_case_name());
+            debug_print_math_expr(left, env);
+            print!(", ");
+            debug_print_math_expr(right, env);
+            print!(")");
+        }
+        MathExpr::Ternary(op, first, second, third) => {
+            print!("{}(", op.lower_case_name());
+            debug_print_math_expr(first, env);
+            print!(", ");
+            debug_print_math_expr(second, env);
+            print!(", ");
+            debug_print_math_expr(third, env);
+            print!(")");
+        }
     }
 }
diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs
index 5c575ea1dded3ec314f6d5d9aa8deed025e6d532..c2ebd86ca07f7ada3d2cb582a1154644bc3f978d 100644
--- a/hercules_ir/src/ir.rs
+++ b/hercules_ir/src/ir.rs
@@ -1383,6 +1383,14 @@ impl Node {
         }
     }
 
+    pub fn try_read(&self) -> Option<(NodeID, &[Index])> {
+        if let Node::Read { collect, indices } = self {
+            Some((*collect, indices))
+        } else {
+            None
+        }
+    }
+
     pub fn try_write(&self) -> Option<(NodeID, NodeID, &[Index])> {
         if let Node::Write {
             collect,
diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs
index 43a88b7c5a4e617c22a911d3761a39df7b9a7443..c134a972f6d7440c25f763a843379e1724fee6ca 100644
--- a/hercules_opt/src/editor.rs
+++ b/hercules_opt/src/editor.rs
@@ -720,7 +720,7 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
             Type::UnsignedInteger64 => Constant::UnsignedInteger64(0),
             Type::Float32 => Constant::Float32(ordered_float::OrderedFloat(0.0)),
             Type::Float64 => Constant::Float64(ordered_float::OrderedFloat(0.0)),
-            Type::Control => panic!("Tried to get zero control element"),
+            Type::Control => panic!("PANIC: Can't create zero constant for the control type."),
             Type::Product(tys) => {
                 let dummy_elems: Vec<_> =
                     tys.iter().map(|ty| self.add_zero_constant(*ty)).collect();
@@ -732,6 +732,28 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
         self.add_constant(constant_to_construct)
     }
 
+    pub fn add_one_constant(&mut self, id: TypeID) -> ConstantID {
+        let ty = self.get_type(id).clone();
+        let constant_to_construct = match ty {
+            Type::Boolean => Constant::Boolean(true),
+            Type::Integer8 => Constant::Integer8(1),
+            Type::Integer16 => Constant::Integer16(1),
+            Type::Integer32 => Constant::Integer32(1),
+            Type::Integer64 => Constant::Integer64(1),
+            Type::UnsignedInteger8 => Constant::UnsignedInteger8(1),
+            Type::UnsignedInteger16 => Constant::UnsignedInteger16(1),
+            Type::UnsignedInteger32 => Constant::UnsignedInteger32(1),
+            Type::UnsignedInteger64 => Constant::UnsignedInteger64(1),
+            Type::Float32 => Constant::Float32(ordered_float::OrderedFloat(1.0)),
+            Type::Float64 => Constant::Float64(ordered_float::OrderedFloat(1.0)),
+            Type::Control => panic!("PANIC: Can't create one constant for the control type."),
+            Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => {
+                panic!("PANIC: Can't create one constant of a collection type.")
+            }
+        };
+        self.add_constant(constant_to_construct)
+    }
+
     pub fn get_constant(&self, id: ConstantID) -> impl Deref<Target = Constant> + '_ {
         if id.idx() < self.editor.constants.borrow().len() {
             Either::Left(Ref::map(self.editor.constants.borrow(), |constants| {
diff --git a/hercules_opt/src/simplify_cfg.rs b/hercules_opt/src/simplify_cfg.rs
index 2e19e6c0b148607e2a4460b0d3cfa9b5311349b1..d579012e3fd0c49709b62c22a24a95b9d491f6df 100644
--- a/hercules_opt/src/simplify_cfg.rs
+++ b/hercules_opt/src/simplify_cfg.rs
@@ -1,4 +1,4 @@
-use std::collections::HashMap;
+use std::collections::{HashMap, HashSet};
 
 use hercules_ir::*;
 
@@ -7,12 +7,16 @@ use crate::*;
 /*
  * Top level function to simplify control flow in a Hercules function.
  */
-pub fn simplify_cfg(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
+pub fn simplify_cfg(
+    editor: &mut FunctionEditor,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
+) {
     // Collapse region chains.
     collapse_region_chains(editor);
 
     // Get rid of unnecessary fork-joins.
-    remove_useless_fork_joins(editor, fork_join_map);
+    remove_useless_fork_joins(editor, fork_join_map, reduce_cycles);
 }
 
 /*
@@ -78,9 +82,52 @@ fn collapse_region_chains(editor: &mut FunctionEditor) {
 /*
  * Function to remove unused fork-joins. A fork-join is unused if there are no
  * reduce users of the join node. In such situations, it is asserted there are
- * no thread ID users of the fork as well.
+ * no thread ID users of the fork as well. Also look for reduces that weren't
+ * eliminated by DCE because they have a user, but the only users are in the
+ * corresponding reduce cycle, so the reduce has no user outside the fork-join.
  */
-fn remove_useless_fork_joins(editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>) {
+fn remove_useless_fork_joins(
+    editor: &mut FunctionEditor,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
+) {
+    // First, try to get rid of reduces where possible. We can only delete all
+    // the reduces or none of the reduces in a particular fork-join, since even
+    // if one reduce may have no users outside the reduction cycle, it may be
+    // used by a reduce that is used outside the cycle, so it shouldn't be
+    // deleted. The reduction cycle may contain every reduce in a fork-join.
+    for (_, join) in fork_join_map {
+        let nodes = &editor.func().nodes;
+        let reduces: Vec<_> = editor
+            .get_users(*join)
+            .filter(|id| nodes[id.idx()].is_reduce())
+            .collect();
+
+        // If every reduce has users only in the reduce cycle, then all the
+        // reduces can be deleted, along with every node in the reduce cycles.
+        if reduces.iter().all(|reduce| {
+            editor
+                .get_users(*reduce)
+                .all(|user| reduce_cycles[reduce].contains(&user))
+        }) {
+            let mut all_the_nodes = HashSet::new();
+            for reduce in reduces {
+                all_the_nodes.insert(reduce);
+                all_the_nodes.extend(&reduce_cycles[&reduce]);
+            }
+            editor.edit(|mut edit| {
+                for id in all_the_nodes {
+                    edit = edit.delete_node(id)?;
+                }
+                Ok(edit)
+            });
+        }
+    }
+
+    // Second, run DCE to get rid of thread IDs.
+    dce(editor);
+
+    // Third, get rid of fork-joins.
     for (fork, join) in fork_join_map {
         if editor.get_users(*join).len() == 1 {
             assert_eq!(editor.get_users(*fork).len(), 1);
diff --git a/hercules_opt/src/slf.rs b/hercules_opt/src/slf.rs
index 92acb2a85fca916803849725edc2b0eabd9b793a..024a442e006d23a9b9bc445a702734b828998177 100644
--- a/hercules_opt/src/slf.rs
+++ b/hercules_opt/src/slf.rs
@@ -1,4 +1,4 @@
-use std::collections::BTreeMap;
+use std::collections::{BTreeMap, HashMap, HashSet};
 
 use hercules_ir::*;
 
@@ -158,3 +158,74 @@ pub fn slf(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, typing:
         }
     }
 }
+
+/*
+ * Top level function to run fork-join level store-to-load forwarding on a
+ * function. Looks for reduce nodes holding arrays that have nice einsum
+ * expressions, and replaces reads of that array with the sub-expression of the
+ * einsum array comprehension.
+ */
+pub fn array_slf(
+    editor: &mut FunctionEditor,
+    fork_join_map: &HashMap<NodeID, NodeID>,
+    reduce_einsum: &(MathEnv, HashMap<NodeID, MathID>),
+    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
+) {
+    let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
+        .into_iter()
+        .map(|(fork, join)| (*join, *fork))
+        .collect();
+
+    let (env, einsums) = reduce_einsum;
+    for (reduce, einsum) in einsums {
+        // Check that the expression is an array comprehension.
+        let MathExpr::Comprehension(elem, _) = env[einsum.idx()] else {
+            continue;
+        };
+
+        // If any of the opaque nodes are "in" the fork-join of the reduce, then
+        // they depend on the thread IDs of the fork-join in a way that's not
+        // modeled in the einsum expression, and therefore those thread IDs
+        // can't be substituted with read position indices. We need to skip
+        // applying SLF to these arrays.
+        let nodes = &editor.func().nodes;
+        let join = nodes[reduce.idx()].try_reduce().unwrap().0;
+        let fork = join_fork_map[&join];
+        let nodes_in_this_fork_join = &nodes_in_fork_joins[&fork];
+        let opaque_nodes = opaque_nodes_in_expr(env, *einsum);
+        if opaque_nodes
+            .into_iter()
+            .any(|id| nodes_in_this_fork_join.contains(&id))
+        {
+            continue;
+        }
+
+        // Look for read users of the reduce. They can be replaced with
+        // substituting the read indices of the array into the einsum expression
+        // to compute, rather than read, the needed value.
+        let reads: Vec<(NodeID, Box<[NodeID]>)> = editor
+            .get_users(*reduce)
+            .filter_map(|id| {
+                nodes[id.idx()].try_read().map(|(_, indices)| {
+                    // The indices list should just be a single position index, since
+                    // einsum expressions are only derived for arrays of primitives.
+                    assert_eq!(indices.len(), 1);
+                    let Index::Position(indices) = &indices[0] else {
+                        panic!()
+                    };
+                    (id, indices.clone())
+                })
+            })
+            .collect();
+        for (read, indices) in reads {
+            editor.edit(|mut edit| {
+                // Create the expression equivalent to the read.
+                let id = materialize_simple_einsum_expr(&mut edit, elem, env, &indices);
+
+                // Replace and delete the read.
+                edit = edit.replace_all_uses(read, id)?;
+                edit.delete_node(read)
+            });
+        }
+    }
+}
diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs
index 0c6c2fac7e96e712f0c41a8ffa56c68fb339319c..022da629cf7267316af1c4a44007f8ed424c5e31 100644
--- a/hercules_opt/src/utils.rs
+++ b/hercules_opt/src/utils.rs
@@ -1,10 +1,9 @@
 use std::collections::{HashMap, HashSet};
-use std::iter::zip;
 
-use hercules_ir::def_use::*;
-use hercules_ir::ir::*;
 use nestify::nest;
 
+use hercules_ir::*;
+
 use crate::*;
 
 /*
@@ -423,3 +422,58 @@ impl<'a> Iterator for NodeIterator<'a> {
         None
     }
 }
+
+/*
+ * Materializes an einsum expression into an IR node tree. Replaces thread IDs
+ * with provided node IDs. Doesn't materialize reductions or comprehensions.
+ */
+pub fn materialize_simple_einsum_expr(
+    edit: &mut FunctionEdit,
+    id: MathID,
+    env: &MathEnv,
+    dim_substs: &[NodeID],
+) -> NodeID {
+    match env[id.idx()] {
+        MathExpr::Zero(ty) => {
+            let cons_id = edit.add_zero_constant(ty);
+            edit.add_node(Node::Constant { id: cons_id })
+        }
+        MathExpr::One(ty) => {
+            let cons_id = edit.add_one_constant(ty);
+            edit.add_node(Node::Constant { id: cons_id })
+        }
+        MathExpr::OpaqueNode(id) => id,
+        MathExpr::ThreadID(dim) => dim_substs[dim.0],
+        MathExpr::Read(collect, ref indices) => {
+            let collect = materialize_simple_einsum_expr(edit, collect, env, dim_substs);
+            let indices = Box::new([Index::Position(
+                indices
+                    .into_iter()
+                    .map(|idx| materialize_simple_einsum_expr(edit, *idx, env, dim_substs))
+                    .collect(),
+            )]);
+            edit.add_node(Node::Read { collect, indices })
+        }
+        MathExpr::Unary(op, input) => {
+            let input = materialize_simple_einsum_expr(edit, input, env, dim_substs);
+            edit.add_node(Node::Unary { op, input })
+        }
+        MathExpr::Binary(op, left, right) => {
+            let left = materialize_simple_einsum_expr(edit, left, env, dim_substs);
+            let right = materialize_simple_einsum_expr(edit, right, env, dim_substs);
+            edit.add_node(Node::Binary { op, left, right })
+        }
+        MathExpr::Ternary(op, first, second, third) => {
+            let first = materialize_simple_einsum_expr(edit, first, env, dim_substs);
+            let second = materialize_simple_einsum_expr(edit, second, env, dim_substs);
+            let third = materialize_simple_einsum_expr(edit, third, env, dim_substs);
+            edit.add_node(Node::Ternary {
+                op,
+                first,
+                second,
+                third,
+            })
+        }
+        _ => panic!(),
+    }
+}
diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch
index 5665e1faef4f55f49a26545be73e0a8ba81f89a8..8543d23d7ff9da8f6f0eecc844dc6bb0a77bcc2d 100644
--- a/juno_samples/fork_join_tests/src/cpu.sch
+++ b/juno_samples/fork_join_tests/src/cpu.sch
@@ -5,12 +5,13 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
-let out = auto-outline(test1, test2, test3, test4, test5);
-cpu(out.test1);
-cpu(out.test2);
-cpu(out.test3);
-cpu(out.test4);
-cpu(out.test5);
+let auto = auto-outline(test1, test2, test3, test4, test5, test7);
+cpu(auto.test1);
+cpu(auto.test2);
+cpu(auto.test3);
+cpu(auto.test4);
+cpu(auto.test5);
+cpu(auto.test7);
 
 ip-sroa(*);
 sroa(*);
@@ -32,11 +33,12 @@ dce(*);
 fixpoint panic after 20 {
   infer-schedules(*);
 }
-fork-split(out.test1, out.test2, out.test3, out.test4, out.test5);
+
+fork-split(auto.test1, auto.test2, auto.test3, auto.test4, auto.test5);
 gvn(*);
 phi-elim(*);
 dce(*);
-unforkify(out.test1, out.test2, out.test3, out.test4, out.test5);
+unforkify(auto.test1, auto.test2, auto.test3, auto.test4, auto.test5);
 ccp(*);
 gvn(*);
 phi-elim(*);
@@ -54,4 +56,12 @@ ccp(*);
 gvn(*);
 phi-elim(*);
 dce(*);
+
+unforkify(auto.test7@loop2);
+array-slf(auto.test7);
+ccp(auto.test7);
+dce(auto.test7);
+simplify-cfg(auto.test7);
+dce(auto.test7);
+
 gcm(*);
diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn
index 806cb0f122f768783bde9f373b569f4a2c4eebed..3f63f8209742a3c5487366783170ba9f38c4a147 100644
--- a/juno_samples/fork_join_tests/src/fork_join_tests.jn
+++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn
@@ -81,3 +81,16 @@ fn test6(input: i32) -> i32[1024] {
   }
   return arr;
 }
+
+#[entry]
+fn test7(input: i32) -> i32 {
+  let arr : i32[32];
+  for i = 0 to 32 {
+    arr[i] = input + i as i32;
+  }
+  let sum : i32;
+  @loop2 for i = 0 to 32 {
+    sum += arr[i];
+  }
+  return sum;
+}
diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch
index f108e2c1e0d150f12a7889d9966faf57d57d33ec..bd1fd1d8465ba27778556987429757b0c47705ec 100644
--- a/juno_samples/fork_join_tests/src/gpu.sch
+++ b/juno_samples/fork_join_tests/src/gpu.sch
@@ -10,12 +10,13 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
-let auto = auto-outline(test1, test2, test3, test4, test5);
+let auto = auto-outline(test1, test2, test3, test4, test5, test7);
 gpu(auto.test1);
 gpu(auto.test2);
 gpu(auto.test3);
 gpu(auto.test4);
 gpu(auto.test5);
+gpu(auto.test7);
 
 ip-sroa(*);
 sroa(*);
@@ -42,6 +43,13 @@ fork-tile[32, 0, true](test6@loop);
 let out = fork-split(test6@loop);
 let out = auto-outline(test6);
 gpu(out.test6);
+
+array-slf(auto.test7);
+ccp(auto.test7);
+dce(auto.test7);
+simplify-cfg(auto.test7);
+dce(auto.test7);
+
 ip-sroa(*);
 sroa(*);
 dce(*);
diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs
index 19838fd7741d7bf366c23c36038a66ab636d96c9..e08a8b1b405ca367c7c40242679e7950430db532 100644
--- a/juno_samples/fork_join_tests/src/main.rs
+++ b/juno_samples/fork_join_tests/src/main.rs
@@ -47,6 +47,11 @@ fn main() {
         let output = r.run(73).await;
         let correct = (73i32..73i32+1024i32).collect();
         assert(&correct, output);
+
+        let mut r = runner!(test7);
+        let output = r.run(42).await;
+        let correct: i32 = (42i32..42i32+32i32).sum();
+        assert_eq!(correct, output);
     });
 }
 
diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs
index 123111b874c97110350b6e45ab9151f59a3aa1f0..1d5044d86479f68e86c74d864a6f976604c4c6fe 100644
--- a/juno_scheduler/src/compile.rs
+++ b/juno_scheduler/src/compile.rs
@@ -93,6 +93,7 @@ impl FromStr for Appliable {
 
     fn from_str(s: &str) -> Result<Self, Self::Err> {
         match s {
+            "array-slf" => Ok(Appliable::Pass(ir::Pass::ArraySLF)),
             "auto-outline" => Ok(Appliable::Pass(ir::Pass::AutoOutline)),
             "ccp" => Ok(Appliable::Pass(ir::Pass::CCP)),
             "crc" | "collapse-read-chains" => Ok(Appliable::Pass(ir::Pass::CRC)),
diff --git a/juno_scheduler/src/default.rs b/juno_scheduler/src/default.rs
index 2cd2c122aaaa7c06622bc30c3e3c662f673af9a1..3f4af107c45d87e6a89b198483b25abae6156a78 100644
--- a/juno_scheduler/src/default.rs
+++ b/juno_scheduler/src/default.rs
@@ -74,6 +74,10 @@ pub fn default_schedule() -> ScheduleStmt {
         DCE,
         SimplifyCFG,
         DCE,
+        ArraySLF,
+        DCE,
+        SLF,
+        DCE,
         ForkSplit,
         Unforkify,
         CCP,
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
index e8cc2d39ce43bed831acc0bf7ef4231a4ebd0149..88aed007d2106e7f77d2425b326278557175de01 100644
--- a/juno_scheduler/src/ir.rs
+++ b/juno_scheduler/src/ir.rs
@@ -2,6 +2,7 @@ use hercules_ir::ir::{Device, Schedule};
 
 #[derive(Debug, Copy, Clone)]
 pub enum Pass {
+    ArraySLF,
     AutoOutline,
     CCP,
     CRC,
diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs
index 7d9687f644c47a54ad657f322e00c527973b49af..5f4452177f4ffd03c1bd1154314b768ebd76206c 100644
--- a/juno_scheduler/src/pm.rs
+++ b/juno_scheduler/src/pm.rs
@@ -1350,6 +1350,31 @@ fn run_pass(
     let mut changed = false;
 
     match pass {
+        Pass::ArraySLF => {
+            assert!(args.is_empty());
+            pm.make_fork_join_maps();
+            pm.make_reduce_einsums();
+            pm.make_nodes_in_fork_joins();
+            let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            let reduce_einsums = pm.reduce_einsums.take().unwrap();
+            let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap();
+
+            for (((func, fork_join_map), reduce_einsum), nodes_in_fork_joins) in
+                build_selection(pm, selection)
+                    .into_iter()
+                    .zip(fork_join_maps.iter())
+                    .zip(reduce_einsums.iter())
+                    .zip(nodes_in_fork_joins.iter())
+            {
+                let Some(mut func) = func else {
+                    continue;
+                };
+                array_slf(&mut func, fork_join_map, reduce_einsum, nodes_in_fork_joins);
+                changed |= func.modified();
+            }
+            pm.delete_gravestones();
+            pm.clear_analyses();
+        }
         Pass::AutoOutline => {
             let Some(funcs) = selection_of_functions(pm, selection) else {
                 return Err(SchedulerError::PassError {
@@ -1949,16 +1974,19 @@ fn run_pass(
         Pass::SimplifyCFG => {
             assert!(args.is_empty());
             pm.make_fork_join_maps();
+            pm.make_reduce_cycles();
             let fork_join_maps = pm.fork_join_maps.take().unwrap();
+            let reduce_cycles = pm.reduce_cycles.take().unwrap();
 
-            for (func, fork_join_map) in build_selection(pm, selection)
+            for ((func, fork_join_map), reduce_cycles) in build_selection(pm, selection)
                 .into_iter()
                 .zip(fork_join_maps.iter())
+                .zip(reduce_cycles.iter())
             {
                 let Some(mut func) = func else {
                     continue;
                 };
-                simplify_cfg(&mut func, fork_join_map);
+                simplify_cfg(&mut func, fork_join_map, reduce_cycles);
                 changed |= func.modified();
             }
             pm.delete_gravestones();