diff --git a/hercules_ir/src/fork_join_analysis.rs b/hercules_ir/src/fork_join_analysis.rs index c584a3fd01da993f95ecc07f8e4a251834053faf..bba6ac42a4a2479f8309b575fe1fb1030f5b5a21 100644 --- a/hercules_ir/src/fork_join_analysis.rs +++ b/hercules_ir/src/fork_join_analysis.rs @@ -85,6 +85,7 @@ pub fn compute_fork_join_nesting( */ pub fn reduce_cycles( function: &Function, + def_use: &ImmutableDefUseMap, fork_join_map: &HashMap<NodeID, NodeID>, nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, ) -> HashMap<NodeID, HashSet<NodeID>> { @@ -101,60 +102,50 @@ pub fn reduce_cycles( let (join, _, reduct) = function.nodes[reduce.idx()].try_reduce().unwrap(); let fork = join_fork_map[&join]; - // DFS the uses of `reduct` until finding the reduce itself. - let mut current_visited = HashSet::new(); - let mut in_cycle = HashSet::new(); - reduce_cycle_dfs_helper( - function, - reduct, - fork, - reduce, - &mut current_visited, - &mut in_cycle, - nodes_in_fork_joins, - ); - result.insert(reduce, in_cycle); - } + // Find nodes in the fork-join that the reduce can reach through uses. + let mut reachable_uses = HashSet::new(); + let mut workset = vec![]; + reachable_uses.insert(reduct); + workset.push(reduct); + while let Some(pop) = workset.pop() { + for u in get_uses(&function.nodes[pop.idx()]).as_ref() { + if !reachable_uses.contains(u) + && nodes_in_fork_joins[&fork].contains(u) + && *u != reduce + { + reachable_uses.insert(*u); + workset.push(*u); + } + } + } - result -} + // Find nodes in the fork-join that the reduce can reach through users. + let mut reachable_users = HashSet::new(); + workset.clear(); + reachable_users.insert(reduce); + workset.push(reduce); + while let Some(pop) = workset.pop() { + for u in def_use.get_users(pop) { + if !reachable_users.contains(u) + && nodes_in_fork_joins[&fork].contains(u) + && *u != reduce + { + reachable_users.insert(*u); + workset.push(*u); + } + } + } -fn reduce_cycle_dfs_helper( - function: &Function, - iter: NodeID, - fork: NodeID, - reduce: NodeID, - current_visited: &mut HashSet<NodeID>, - in_cycle: &mut HashSet<NodeID>, - nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, -) -> bool { - if iter == reduce || in_cycle.contains(&iter) { - return true; + // The reduce cycle is the insersection of nodes reachable through uses + // and users. + let intersection = reachable_uses + .intersection(&reachable_users) + .map(|id| *id) + .collect(); + result.insert(reduce, intersection); } - current_visited.insert(iter); - let mut found_reduce = false; - - // This doesn't short circuit on purpose. - for u in get_uses(&function.nodes[iter.idx()]).as_ref() { - found_reduce |= !current_visited.contains(u) - && !function.nodes[u.idx()].is_control() - && nodes_in_fork_joins[&fork].contains(u) - && reduce_cycle_dfs_helper( - function, - *u, - fork, - reduce, - current_visited, - in_cycle, - nodes_in_fork_joins, - ) - } - if found_reduce { - in_cycle.insert(iter); - } - current_visited.remove(&iter); - found_reduce + result } /* diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index abc2fde0fc3888d033dd2d2dc082aa8a9680f76a..7c416e904ad5d43a5297496b6de40037f5b9b553 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -14,6 +14,8 @@ cpu(auto.test5); cpu(auto.test7); cpu(auto.test8); +let test1_cpu = auto.test1; +rename["test1_cpu"](test1_cpu); ip-sroa(*); sroa(*); @@ -39,7 +41,10 @@ dce(*); fixpoint panic after 20 { infer-schedules(*); } -fork-split(auto.test1); + +let out = fork-split(test1_cpu); +let first_fork = out.test1_cpu.fj1; + fixpoint panic after 20 { unroll(auto.test1); } diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index b0a85f593ca8721bd2fa541ec17818a27888d616..912cc91f7f5fd968374a50d304dfb4e5cce1654f 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -134,6 +134,7 @@ impl FromStr for Appliable { "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)), "predication" => Ok(Appliable::Pass(ir::Pass::Predication)), "reduce-slf" => Ok(Appliable::Pass(ir::Pass::ReduceSLF)), + "rename" => Ok(Appliable::Pass(ir::Pass::Rename)), "reuse-products" => Ok(Appliable::Pass(ir::Pass::ReuseProducts)), "simplify-cfg" => Ok(Appliable::Pass(ir::Pass::SimplifyCFG)), "slf" | "store-load-forward" => Ok(Appliable::Pass(ir::Pass::SLF)), @@ -435,6 +436,11 @@ fn compile_expr( parser::Expr::Boolean { span: _, val } => { Ok(ExprResult::Expr(ir::ScheduleExp::Boolean { val })) } + parser::Expr::String { span } => { + let string = lexer.span_str(span); + let val = string[1..string.len() - 1].to_string(); + Ok(ExprResult::Expr(ir::ScheduleExp::String { val })) + } parser::Expr::Field { span: _, lhs, diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index b490ef71b595d6a612797dcda007361f3d5dc608..8ad923242a3b73f8b9379f1f9b51070245d2030c 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -29,6 +29,7 @@ pub enum Pass { Predication, Print, ReduceSLF, + Rename, ReuseProducts, SLF, SROA, @@ -48,6 +49,7 @@ impl Pass { Pass::ForkFissionBufferize => num == 2, Pass::ForkInterchange => num == 2, Pass::Print => num == 1, + Pass::Rename => num == 1, Pass::Xdot => num == 0 || num == 1, _ => num == 0, } @@ -60,6 +62,7 @@ impl Pass { Pass::ForkFissionBufferize => "2", Pass::ForkInterchange => "2", Pass::Print => "1", + Pass::Rename => "1", Pass::Xdot => "0 or 1", _ => "0", } @@ -83,6 +86,9 @@ pub enum ScheduleExp { Boolean { val: bool, }, + String { + val: String, + }, Field { collect: Box<ScheduleExp>, field: String, diff --git a/juno_scheduler/src/lang.l b/juno_scheduler/src/lang.l index 2f34f01f89d88e4863e829a9384a48006fc28d53..ca75276e326e79eeb60aa1f3ea8b1169ddc96e0a 100644 --- a/juno_scheduler/src/lang.l +++ b/juno_scheduler/src/lang.l @@ -46,5 +46,6 @@ stop[\t \n\r]+after "stop_after" [a-zA-Z_][a-zA-Z0-9_\-]*! "MACRO" [a-zA-Z_][a-zA-Z0-9_\-]* "ID" [0-9]+ "INT" +\"[a-zA-Z0-9_\-\s\.]*\" "STRING" . "UNMATCHED" diff --git a/juno_scheduler/src/lang.y b/juno_scheduler/src/lang.y index 9cb728428e08b0f77b77be526e2f3e0aa9c86a37..584bf2a4ef1a476669f10de115a5dda38213a695 100644 --- a/juno_scheduler/src/lang.y +++ b/juno_scheduler/src/lang.y @@ -1,6 +1,6 @@ %start Schedule -%avoid_insert "ID" "INT" +%avoid_insert "ID" "INT" "STRING" %expect-unused Unmatched 'UNMATCHED' %% @@ -47,6 +47,8 @@ Expr -> Expr { Expr::Boolean { span: $span, val: true } } | 'false' { Expr::Boolean { span: $span, val: false } } + | 'STRING' + { Expr::String { span: $span } } | Expr '.' 'ID' { Expr::Field { span: $span, lhs: Box::new($1), field: span_of_tok($3) } } | Expr '@' 'ID' @@ -155,6 +157,7 @@ pub enum Expr { Variable { span: Span }, Integer { span: Span }, Boolean { span: Span, val: bool }, + String { span: Span }, Field { span: Span, lhs: Box<Expr>, field: Span }, BlockExpr { span: Span, body: Box<OperationList> }, Record { span: Span, fields: Vec<(Span, Expr)> }, diff --git a/juno_scheduler/src/lib.rs b/juno_scheduler/src/lib.rs index d4ab432aa9bbd75f67e8e3d07d74a43f03f0967f..03ad0111069673bec3f5e081e6d6e1e15eb5d749 100644 --- a/juno_scheduler/src/lib.rs +++ b/juno_scheduler/src/lib.rs @@ -1,6 +1,6 @@ #![feature(exact_size_is_empty)] +#![feature(let_chains)] -use std::collections::{HashMap, HashSet}; use std::fs::File; use std::io::Read; diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index b1d4c93feb54facd63eba6a7825dbab60c405fb5..b26d1720fc192378a214e3951d15e957c317b77f 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -29,6 +29,7 @@ pub enum Value { Selection { selection: Vec<Value> }, Integer { val: usize }, Boolean { val: bool }, + String { val: String }, } #[derive(Debug, Copy, Clone)] @@ -70,6 +71,9 @@ impl Value { Value::Boolean { .. } => Err(SchedulerError::SemanticError( "Expected labels, found boolean".to_string(), )), + Value::String { .. } => Err(SchedulerError::SemanticError( + "Expected labels, found string".to_string(), + )), } } @@ -99,6 +103,9 @@ impl Value { Value::Boolean { .. } => Err(SchedulerError::SemanticError( "Expected functions, found boolean".to_string(), )), + Value::String { .. } => Err(SchedulerError::SemanticError( + "Expected functions, found string".to_string(), + )), } } @@ -130,6 +137,9 @@ impl Value { Value::Boolean { .. } => Err(SchedulerError::SemanticError( "Expected code locations, found boolean".to_string(), )), + Value::String { .. } => Err(SchedulerError::SemanticError( + "Expected code locations, found string".to_string(), + )), } } } @@ -389,8 +399,10 @@ impl PassManager { pub fn make_reduce_cycles(&mut self) { if self.reduce_cycles.is_none() { + self.make_def_uses(); self.make_fork_join_maps(); self.make_nodes_in_fork_joins(); + let def_uses = self.def_uses.as_ref().unwrap().iter(); let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter(); let nodes_in_fork_joins = self.nodes_in_fork_joins.as_ref().unwrap().iter(); self.reduce_cycles = Some( @@ -398,9 +410,12 @@ impl PassManager { .iter() .zip(fork_join_maps) .zip(nodes_in_fork_joins) - .map(|((function, fork_join_map), nodes_in_fork_joins)| { - reduce_cycles(function, fork_join_map, nodes_in_fork_joins) - }) + .zip(def_uses) + .map( + |(((function, fork_join_map), nodes_in_fork_joins), def_use)| { + reduce_cycles(function, def_use, fork_join_map, nodes_in_fork_joins) + }, + ) .collect(), ); } @@ -998,6 +1013,7 @@ fn interp_expr( } ScheduleExp::Integer { val } => Ok((Value::Integer { val: *val }, false)), ScheduleExp::Boolean { val } => Ok((Value::Boolean { val: *val }, false)), + ScheduleExp::String { val } => Ok((Value::String { val: val.clone() }, false)), ScheduleExp::Field { collect, field } => { let (lhs, changed) = interp_expr(pm, collect, stringtab, env, functions)?; match lhs { @@ -1005,7 +1021,8 @@ fn interp_expr( | Value::Selection { .. } | Value::Everything { .. } | Value::Integer { .. } - | Value::Boolean { .. } => Err(SchedulerError::UndefinedField(field.clone())), + | Value::Boolean { .. } + | Value::String { .. } => Err(SchedulerError::UndefinedField(field.clone())), Value::JunoFunction { func } => { match pm.labels.borrow().iter().position(|s| s == field) { None => Err(SchedulerError::UndefinedLabel(field.clone())), @@ -1260,6 +1277,7 @@ fn update_value( Value::Everything {} => Some(Value::Everything {}), Value::Integer { val } => Some(Value::Integer { val }), Value::Boolean { val } => Some(Value::Boolean { val }), + Value::String { val } => Some(Value::String { val }), } } @@ -2191,6 +2209,36 @@ fn run_pass( changed |= func.modified(); } } + Pass::Rename => { + assert!(args.len() == 1); + let new_name = match args[0] { + Value::String { ref val } => val.clone(), + _ => { + return Err(SchedulerError::PassError { + pass: "rename".to_string(), + error: "expected string argument".to_string(), + }); + } + }; + if pm.functions.iter().any(|f| f.name == new_name) { + return Err(SchedulerError::PassError { + pass: "rename".to_string(), + error: format!("function with name {} already exists", new_name), + }); + } + + if let Some(funcs) = selection_of_functions(pm, selection) + && funcs.len() == 1 + { + let func = funcs[0]; + pm.functions[func.idx()].name = new_name; + } else { + return Err(SchedulerError::PassError { + pass: "rename".to_string(), + error: "must be applied to the entirety of a single function".to_string(), + }); + }; + } Pass::ReuseProducts => { assert!(args.is_empty()); pm.make_reverse_postorders();