From 0792bc64deb3de181a5b0611c2bf961e618ecff4 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Sat, 22 Feb 2025 16:53:15 -0600 Subject: [PATCH 1/2] Add set ops to scheduler --- Cargo.lock | 1 + hercules_opt/src/editor.rs | 17 +- juno_scheduler/Cargo.toml | 3 +- juno_scheduler/src/compile.rs | 5 + juno_scheduler/src/ir.rs | 7 + juno_scheduler/src/lang.l | 4 + juno_scheduler/src/lang.y | 19 ++ juno_scheduler/src/pm.rs | 626 +++++++++++++++++++--------------- 8 files changed, 388 insertions(+), 294 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index fdcbaf84..c872be3a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1459,6 +1459,7 @@ dependencies = [ name = "juno_scheduler" version = "0.0.1" dependencies = [ + "bitvec", "cfgrammar", "hercules_cg", "hercules_ir", diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index 9cf5af72..17cea325 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -108,9 +108,8 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { } } - // Constructs an editor but only makes the nodes with at least one of the - // set of labels as mutable. - pub fn new_labeled( + // Constructs an editor with a specified mask determining which nodes are mutable + pub fn new_mask( function: &'a mut Function, function_id: FunctionID, constants: &'a RefCell<Vec<Constant>>, @@ -118,7 +117,7 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { types: &'a RefCell<Vec<Type>>, labels: &'a RefCell<Vec<String>>, def_use: &ImmutableDefUseMap, - with_labels: &HashSet<LabelID>, + mask: BitVec<u8, Lsb0>, ) -> Self { let mut_def_use = (0..function.nodes.len()) .map(|idx| { @@ -130,14 +129,6 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { }) .collect(); - let mut mutable_nodes = bitvec![u8, Lsb0; 0; function.nodes.len()]; - // Add all nodes which have some label which is in the with_labels set - for (idx, labels) in function.labels.iter().enumerate() { - if !labels.is_disjoint(with_labels) { - mutable_nodes.set(idx, true); - } - } - FunctionEditor { function, function_id, @@ -146,7 +137,7 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { types, labels, mut_def_use, - mutable_nodes, + mutable_nodes: mask, modified: false, } } diff --git a/juno_scheduler/Cargo.toml b/juno_scheduler/Cargo.toml index 03a18c83..174de05b 100644 --- a/juno_scheduler/Cargo.toml +++ b/juno_scheduler/Cargo.toml @@ -24,4 +24,5 @@ hercules_ir = { path = "../hercules_ir" } hercules_opt = { path = "../hercules_opt" } juno_utils = { path = "../juno_utils" } postcard = { version = "*", features = ["alloc"] } -serde = { version = "*", features = ["derive"] } \ No newline at end of file +serde = { version = "*", features = ["derive"] } +bitvec = "*" diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index e9132fd2..86377241 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -473,6 +473,11 @@ fn compile_expr( } Ok(ExprResult::Expr(ir::ScheduleExp::Record { fields: result })) } + parser::Expr::SetOp { span: _, op, lhs, rhs } => { + let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?; + let rhs = compile_exp_as_expr(*rhs, lexer, macrostab, macros)?; + Ok(ExprResult::Expr(ir::ScheduleExp::SetOp { op, lhs: Box::new(lhs), rhs: Box::new(rhs), })) + } } } diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index a888cf74..bbecc6ff 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -1,5 +1,7 @@ use hercules_ir::ir::{Device, Schedule}; +use crate::parser; + #[derive(Debug, Copy, Clone)] pub enum Pass { ArraySLF, @@ -117,6 +119,11 @@ pub enum ScheduleExp { body: Vec<ScheduleStmt>, res: Box<ScheduleExp>, }, + SetOp { + op: parser::SetOp, + lhs: Box<ScheduleExp>, + rhs: Box<ScheduleExp>, + }, // This is used to "box" a selection by evaluating it at one point and then // allowing it to be used as a selector later on Selection { diff --git a/juno_scheduler/src/lang.l b/juno_scheduler/src/lang.l index ca75276e..af154fce 100644 --- a/juno_scheduler/src/lang.l +++ b/juno_scheduler/src/lang.l @@ -39,6 +39,10 @@ false "false" \{ "{" \} "}" +\\ "\\" +\| "|" +& "&" + panic[\t \n\r]+after "panic_after" print[\t \n\r]+iter "print_iter" stop[\t \n\r]+after "stop_after" diff --git a/juno_scheduler/src/lang.y b/juno_scheduler/src/lang.y index 584bf2a4..3b030e1d 100644 --- a/juno_scheduler/src/lang.y +++ b/juno_scheduler/src/lang.y @@ -3,6 +3,11 @@ %avoid_insert "ID" "INT" "STRING" %expect-unused Unmatched 'UNMATCHED' +%left '\\' +%left '|' +%left '&' +%left '.' '@' + %% Schedule -> OperationList @@ -59,6 +64,12 @@ Expr -> Expr { Expr::BlockExpr { span: $span, body: Box::new($2) } } | '<' Fields '>' { Expr::Record { span: $span, fields: rev($2) } } + | Expr '\\' Expr + { Expr::SetOp { span: $span, op: SetOp::Difference, lhs: Box::new($1), rhs: Box::new($3) } } + | Expr '|' Expr + { Expr::SetOp { span: $span, op: SetOp::Union, lhs: Box::new($1), rhs: Box::new($3) } } + | Expr '&' Expr + { Expr::SetOp { span: $span, op: SetOp::Intersection, lhs: Box::new($1), rhs: Box::new($3) } } ; Args -> Vec<Expr> @@ -151,6 +162,13 @@ pub enum FixpointLimit { PrintIter { span: Span }, } +#[derive(Copy, Clone, Debug)] +pub enum SetOp { + Difference, + Union, + Intersection, +} + pub enum Expr { Function { span: Span, name: Span, args: Vec<Expr>, selection: Selector }, Macro { span: Span, name: Span, args: Vec<Expr>, selection: Selector }, @@ -161,6 +179,7 @@ pub enum Expr { Field { span: Span, lhs: Box<Expr>, field: Span }, BlockExpr { span: Span, body: Box<OperationList> }, Record { span: Span, fields: Vec<(Span, Expr)> }, + SetOp { span: Span, op: SetOp, lhs: Box<Expr>, rhs: Box<Expr> }, } pub enum Selector { diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 84b25811..44777301 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1,9 +1,12 @@ use crate::ir::*; use crate::labels::*; +use crate::parser; use hercules_cg::*; use hercules_ir::*; use hercules_opt::*; +use bitvec::prelude::*; + use serde::Deserialize; use serde::Serialize; use tempfile::TempDir; @@ -17,8 +20,233 @@ use std::fmt; use std::fs::File; use std::io::Write; use std::iter::zip; +use std::ops::Not; use std::process::{Command, Stdio}; +#[derive(Debug, Clone)] +enum FunctionSelectionState { + Nothing(), + Everything(), + Selection(BitVec<u8, Lsb0>), +} + +#[derive(Debug, Clone)] +pub struct FunctionSelection { + num_nodes: usize, + selection: FunctionSelectionState, +} + +impl FunctionSelection { + pub fn new(num_nodes: usize) -> Self { + FunctionSelection { + num_nodes, + selection: FunctionSelectionState::Nothing(), + } + } + + pub fn is_everything(&self) -> bool { + match self.selection { + FunctionSelectionState::Everything() => true, + _ => false, + } + } + + pub fn add_everything(&mut self) { + self.selection = FunctionSelectionState::Everything(); + } + + pub fn add_node(&mut self, node: NodeID) { + match &mut self.selection { + FunctionSelectionState::Nothing() => { + let mut selection = bitvec![u8, Lsb0; 0; self.num_nodes]; + selection.set(node.idx(), true); + self.selection = FunctionSelectionState::Selection(selection); + } + FunctionSelectionState::Everything() => {} + FunctionSelectionState::Selection(selection) => { + selection.set(node.idx(), true); + } + } + } + + pub fn add(&mut self, other: Self) { + match &mut self.selection { + FunctionSelectionState::Nothing() => self.selection = other.selection, + FunctionSelectionState::Everything() => (), + FunctionSelectionState::Selection(selection) => + match other.selection { + FunctionSelectionState::Nothing() => (), + FunctionSelectionState::Everything() => self.selection = other.selection, + FunctionSelectionState::Selection(other) => { + for idx in 0..self.num_nodes { + if other[idx] { + selection.set(idx, true); + } + } + } + } + } + } + + pub fn union(mut self, other: Self) -> Self { + self.add(other); + self + } + + pub fn intersection(self, other: Self) -> Self { + let num_nodes = self.num_nodes; + + match self.selection { + FunctionSelectionState::Nothing() => self, + FunctionSelectionState::Everything() => other, + FunctionSelectionState::Selection(mut selection) => + match other.selection { + FunctionSelectionState::Nothing() => other, + FunctionSelectionState::Everything() => Self { num_nodes, selection: FunctionSelectionState::Selection(selection) }, + FunctionSelectionState::Selection(other) => { + for idx in 0..num_nodes { + if !other[idx] { + selection.set(idx, false); + } + } + Self { num_nodes, selection: FunctionSelectionState::Selection(selection) } + } + } + } + } + + pub fn difference(self, other: Self) -> Self { + let num_nodes = self.num_nodes; + + match self.selection { + FunctionSelectionState::Nothing() => self, + FunctionSelectionState::Everything() => + match other.selection { + FunctionSelectionState::Nothing() => self, + FunctionSelectionState::Everything() => { + Self { num_nodes, selection: FunctionSelectionState::Nothing() } + } + FunctionSelectionState::Selection(other) => { + Self { num_nodes, selection: FunctionSelectionState::Selection(other.not()) } + } + } + FunctionSelectionState::Selection(mut selection) => + match other.selection { + FunctionSelectionState::Nothing() => Self { num_nodes, selection: FunctionSelectionState::Selection(selection) }, + FunctionSelectionState::Everything() => { + Self { num_nodes, selection: FunctionSelectionState::Nothing() } + } + FunctionSelectionState::Selection(other) => { + for idx in 0..num_nodes { + if other[idx] { + selection.set(idx, false); + } + } + Self { num_nodes, selection: FunctionSelectionState::Selection(selection) } + } + } + } + } + + pub fn as_nodes(&self) -> Vec<NodeID> { + match &self.selection { + FunctionSelectionState::Nothing() => vec![], + FunctionSelectionState::Everything() => (0..self.num_nodes).map(NodeID::new).collect(), + FunctionSelectionState::Selection(selection) => + (0..self.num_nodes).map(NodeID::new).filter(|node| selection[node.idx()]).collect(), + } + } +} + +#[derive(Debug, Clone)] +pub struct ModuleSelection { + selection: Vec<FunctionSelection>, +} + +impl ModuleSelection { + pub fn new(pm: &PassManager) -> Self { + ModuleSelection { + selection: (0..pm.functions.len()) + .map(|idx| FunctionSelection::new(pm.functions[idx].nodes.len())) + .collect(), + } + } + + pub fn is_everything(&self) -> bool { + self.selection.iter().all(|func| func.is_everything()) + } + + pub fn add_everything(&mut self) { + self.selection.iter_mut().for_each(|func| func.add_everything()); + } + + pub fn add_function(&mut self, func: FunctionID) { + self.selection[func.idx()].add_everything(); + } + + pub fn add_label(&mut self, func: FunctionID, label: LabelID, pm: &PassManager) { + pm.functions[func.idx()].labels.iter().enumerate() + .for_each(|(node_idx, labels)| if labels.contains(&label) { + self.selection[func.idx()].add_node(NodeID::new(node_idx)); + } else { + }); + } + + pub fn add(&mut self, other: Self) { + self.selection.iter_mut().zip(other.selection.into_iter()) + .for_each(|(this, other)| this.add(other)); + } + + pub fn union(mut self, other: Self) -> Self { + self.add(other); + self + } + + pub fn intersection(self, other: Self) -> Self { + Self { + selection: self.selection.into_iter().zip(other.selection.into_iter()) + .map(|(this, other)| this.intersection(other)) + .collect() + } + } + + pub fn difference(self, other: Self) -> Self { + Self { + selection: self.selection.into_iter().zip(other.selection.into_iter()) + .map(|(this, other)| this.difference(other)) + .collect() + } + } + + pub fn as_funcs(&self) -> Result<Vec<FunctionID>, SchedulerError> { + let mut res = vec![]; + for (func_id, func_selection) in self.selection.iter().enumerate() { + match func_selection.selection { + FunctionSelectionState::Nothing() => (), + FunctionSelectionState::Everything() => res.push(FunctionID::new(func_id)), + _ => return Err(SchedulerError::SemanticError( + "Expected selection of functions, found fine-grain selection".to_string() + )), + } + } + Ok(res) + } + + pub fn as_nodes(&self) -> Vec<(FunctionID, NodeID)> { + let mut res = vec![]; + for (func_id, func_selection) in self.selection.iter().enumerate() { + for node_id in func_selection.as_nodes() { + res.push((FunctionID::new(func_id), node_id)); + } + } + res + } + + pub fn as_func_states(&self) -> Vec<&FunctionSelectionState> { + self.selection.iter().map(|selection| &selection.selection).collect() + } +} + #[derive(Debug, Clone)] pub enum Value { Label { labels: Vec<LabelInfo> }, @@ -27,17 +255,12 @@ pub enum Value { Record { fields: HashMap<String, Value> }, Everything {}, Selection { selection: Vec<Value> }, + SetOp { op: parser::SetOp, lhs: Box<Value>, rhs: Box<Value> }, Integer { val: usize }, Boolean { val: bool }, String { val: String }, } -#[derive(Debug, Copy, Clone)] -enum CodeLocation { - Label(LabelInfo), - Function(FunctionID), -} - impl Value { fn is_everything(&self) -> bool { match self { @@ -46,101 +269,66 @@ impl Value { } } - fn as_labels(&self) -> Result<Vec<LabelInfo>, SchedulerError> { - match self { - Value::Label { labels } => Ok(labels.clone()), - Value::Selection { selection } => { - let mut result = vec![]; - for val in selection { - result.extend(val.as_labels()?); - } - Ok(result) - } - Value::JunoFunction { .. } | Value::HerculesFunction { .. } => Err( - SchedulerError::SemanticError("Expected labels, found function".to_string()), - ), - Value::Record { .. } => Err(SchedulerError::SemanticError( - "Expected labels, found record".to_string(), - )), - Value::Everything {} => Err(SchedulerError::SemanticError( - "Expected labels, found everything".to_string(), - )), - Value::Integer { .. } => Err(SchedulerError::SemanticError( - "Expected labels, found integer".to_string(), - )), - Value::Boolean { .. } => Err(SchedulerError::SemanticError( - "Expected labels, found boolean".to_string(), - )), - Value::String { .. } => Err(SchedulerError::SemanticError( - "Expected labels, found string".to_string(), - )), - } + fn as_selection(&self, pm: &PassManager, funcs: &JunoFunctions) + -> Result<ModuleSelection, SchedulerError> + { + let mut selection = ModuleSelection::new(pm); + self.add_to_selection(pm, funcs, &mut selection)?; + Ok(selection) } - fn as_functions(&self, funcs: &JunoFunctions) -> Result<Vec<FunctionID>, SchedulerError> { + fn add_to_selection(&self, pm: &PassManager, funcs: &JunoFunctions, selection: &mut ModuleSelection) -> Result<(), SchedulerError> { match self { - Value::JunoFunction { func } => Ok(funcs.get_function(*func).clone()), - Value::HerculesFunction { func } => Ok(vec![*func]), - Value::Selection { selection } => { - let mut result = vec![]; - for val in selection { - result.extend(val.as_functions(funcs)?); + Value::Label { labels } => { + for LabelInfo { func, label } in labels { + selection.add_label(*func, *label, pm); } - Ok(result) } - Value::Label { .. } => Err(SchedulerError::SemanticError( - "Expected functions, found label".to_string(), - )), - Value::Record { .. } => Err(SchedulerError::SemanticError( - "Expected functions, found record".to_string(), - )), - Value::Everything {} => Err(SchedulerError::SemanticError( - "Expected functions, found everything".to_string(), - )), - Value::Integer { .. } => Err(SchedulerError::SemanticError( - "Expected functions, found integer".to_string(), - )), - Value::Boolean { .. } => Err(SchedulerError::SemanticError( - "Expected functions, found boolean".to_string(), - )), - Value::String { .. } => Err(SchedulerError::SemanticError( - "Expected functions, found string".to_string(), - )), - } - } - - fn as_locations(&self, funcs: &JunoFunctions) -> Result<Vec<CodeLocation>, SchedulerError> { - match self { - Value::Label { labels } => Ok(labels.iter().map(|l| CodeLocation::Label(*l)).collect()), - Value::JunoFunction { func } => Ok(funcs - .get_function(*func) - .iter() - .map(|f| CodeLocation::Function(*f)) - .collect()), - Value::HerculesFunction { func } => Ok(vec![CodeLocation::Function(*func)]), - Value::Selection { selection } => { - let mut result = vec![]; - for val in selection { - result.extend(val.as_locations(funcs)?); + Value::JunoFunction { func } => { + for func in funcs.get_function(*func) { + selection.add_function(*func); } - Ok(result) } - Value::Record { .. } => Err(SchedulerError::SemanticError( - "Expected code locations, found record".to_string(), - )), - Value::Everything {} => { - panic!("Internal error, check is_everything() before using as_functions()") + Value::HerculesFunction { func } => selection.add_function(*func), + Value::Everything {} => selection.add_everything(), + Value::Selection { selection: values } => { + for value in values { + value.add_to_selection(pm, funcs, selection)?; + } + } + Value::SetOp { op, lhs, rhs } => { + let lhs = lhs.as_selection(pm, funcs)?; + let rhs = rhs.as_selection(pm, funcs)?; + let res = + match op { + parser::SetOp::Union => lhs.union(rhs), + parser::SetOp::Intersection => lhs.intersection(rhs), + parser::SetOp::Difference => lhs.difference(rhs), + }; + selection.add(res); + } + Value::Record { .. } => { + return Err(SchedulerError::SemanticError( + "Expected code selection, found record".to_string() + )); + } + Value::Integer { .. } => { + return Err(SchedulerError::SemanticError( + "Expected code selection, found integer".to_string() + )); + } + Value::Boolean { .. } => { + return Err(SchedulerError::SemanticError( + "Expected code selection, found boolean".to_string() + )); + } + Value::String { .. } => { + return Err(SchedulerError::SemanticError( + "Expected code selection, found string".to_string() + )); } - Value::Integer { .. } => Err(SchedulerError::SemanticError( - "Expected code locations, found integer".to_string(), - )), - Value::Boolean { .. } => Err(SchedulerError::SemanticError( - "Expected code locations, found boolean".to_string(), - )), - Value::String { .. } => Err(SchedulerError::SemanticError( - "Expected code locations, found string".to_string(), - )), } + Ok(()) } } @@ -974,7 +1162,7 @@ fn schedule_interpret( for label in selection { let (label, modified) = interp_expr(pm, label, stringtab, env, functions)?; changed |= modified; - add_schedule(pm, sched.clone(), label.as_labels()?); + add_schedule(pm, sched.clone(), label.as_selection(pm, functions)?); } Ok(changed) } @@ -991,7 +1179,7 @@ fn schedule_interpret( for func in selection { let (func, modified) = interp_expr(pm, func, stringtab, env, functions)?; changed |= modified; - add_device(pm, device.clone(), func.as_functions(functions)?); + add_device(pm, device.clone(), func.as_selection(pm, functions)?.as_funcs()?); } Ok(changed) } @@ -1017,6 +1205,11 @@ fn interp_expr( ScheduleExp::Integer { val } => Ok((Value::Integer { val: *val }, false)), ScheduleExp::Boolean { val } => Ok((Value::Boolean { val: *val }, false)), ScheduleExp::String { val } => Ok((Value::String { val: val.clone() }, false)), + ScheduleExp::SetOp { op, lhs, rhs } => { + let (lhs, lhs_mod) = interp_expr(pm, lhs, stringtab, env, functions)?; + let (rhs, rhs_mod) = interp_expr(pm, rhs, stringtab, env, functions)?; + Ok((Value::SetOp { op: *op, lhs: Box::new(lhs), rhs: Box::new(rhs) }, lhs_mod || rhs_mod)) + } ScheduleExp::Field { collect, field } => { let (lhs, changed) = interp_expr(pm, collect, stringtab, env, functions)?; match lhs { @@ -1025,7 +1218,8 @@ fn interp_expr( | Value::Everything { .. } | Value::Integer { .. } | Value::Boolean { .. } - | Value::String { .. } => Err(SchedulerError::UndefinedField(field.clone())), + | Value::String { .. } + | Value::SetOp { .. } => Err(SchedulerError::UndefinedField(field.clone())), Value::JunoFunction { func } => { match pm.labels.borrow().iter().position(|s| s == field) { None => Err(SchedulerError::UndefinedLabel(field.clone())), @@ -1074,26 +1268,17 @@ fn interp_expr( } let selection = match on { - Selector::Everything() => None, + Selector::Everything() => Value::Everything{}, Selector::Selection(selection) => { - let mut locs = vec![]; - let mut everything = false; + let mut vals = vec![]; for loc in selection { let (val, modified) = interp_expr(pm, loc, stringtab, env, functions)?; changed |= modified; - if val.is_everything() { - everything = true; - break; - } - locs.extend(val.as_locations(functions)?); - } - if everything { - None - } else { - Some(locs) + vals.push(val); } + Value::Selection { selection: vals } } - }; + }.as_selection(pm, functions)?; let (res, modified) = run_pass(pm, *pass, arg_vals, selection)?; changed |= modified; @@ -1277,6 +1462,11 @@ fn update_value( func: FunctionID::new(i), }) } + Value::SetOp { op, lhs, rhs } => { + let lhs = update_value(*lhs, func_idx, juno_func_idx)?; + let rhs = update_value(*rhs, func_idx, juno_func_idx)?; + Some(Value::SetOp { op, lhs: Box::new(lhs), rhs: Box::new(rhs) }) + } Value::Everything {} => Some(Value::Everything {}), Value::Integer { val } => Some(Value::Integer { val }), Value::Boolean { val } => Some(Value::Boolean { val }), @@ -1284,18 +1474,9 @@ fn update_value( } } -fn add_schedule(pm: &mut PassManager, sched: Schedule, label_ids: Vec<LabelInfo>) { - for LabelInfo { func, label } in label_ids { - let nodes = pm.functions[func.idx()] - .labels - .iter() - .enumerate() - .filter(|(_, ls)| ls.contains(&label)) - .map(|(i, _)| i) - .collect::<Vec<_>>(); - for node in nodes { - pm.functions[func.idx()].schedules[node].push(sched.clone()); - } +fn add_schedule(pm: &mut PassManager, sched: Schedule, selection: ModuleSelection) { + for (func, node) in selection.as_nodes() { + pm.functions[func.idx()].schedules[node.idx()].push(sched.clone()); } } @@ -1305,31 +1486,7 @@ fn add_device(pm: &mut PassManager, device: Device, funcs: Vec<FunctionID>) { } } -#[derive(Debug, Clone)] -enum FunctionSelection { - Nothing(), - Everything(), - Labels(HashSet<LabelID>), -} - -impl FunctionSelection { - fn add_label(&mut self, label: LabelID) { - match self { - FunctionSelection::Nothing() => { - *self = FunctionSelection::Labels(HashSet::from([label])); - } - FunctionSelection::Everything() => {} - FunctionSelection::Labels(set) => { - set.insert(label); - } - } - } - - fn add_everything(&mut self) { - *self = FunctionSelection::Everything(); - } -} - +// Builds (completely mutable) editors for all functions fn build_editors<'a>(pm: &'a mut PassManager) -> Vec<FunctionEditor<'a>> { pm.make_def_uses(); let def_uses = pm.def_uses.take().unwrap(); @@ -1351,93 +1508,17 @@ fn build_editors<'a>(pm: &'a mut PassManager) -> Vec<FunctionEditor<'a>> { .collect() } -// With a selection, we process it to identify which labels in which functions are to be selected -fn construct_selection(pm: &PassManager, selection: Vec<CodeLocation>) -> Vec<FunctionSelection> { - let mut selected = vec![FunctionSelection::Nothing(); pm.functions.len()]; - for loc in selection { - match loc { - CodeLocation::Label(label) => selected[label.func.idx()].add_label(label.label), - CodeLocation::Function(func) => selected[func.idx()].add_everything(), - } - } - selected -} - -// Given a selection, constructs the set of functions which are selected (and each must be selected -// fully) -fn selection_of_functions( - pm: &PassManager, - selection: Option<Vec<CodeLocation>>, -) -> Option<Vec<FunctionID>> { - if let Some(selection) = selection { - let selection = construct_selection(pm, selection); - - let mut result = vec![]; - - for (idx, selected) in selection.into_iter().enumerate() { - match selected { - FunctionSelection::Nothing() => {} - FunctionSelection::Everything() => result.push(FunctionID::new(idx)), - FunctionSelection::Labels(_) => { - return None; - } - } - } - - Some(result) - } else { - Some( - pm.functions - .iter() - .enumerate() - .map(|(i, _)| FunctionID::new(i)) - .collect(), - ) - } -} - // Given a selection, constructs the set of the nodes selected for a single function, returning the // function's id fn selection_as_set( pm: &PassManager, - selection: Option<Vec<CodeLocation>>, + selection: ModuleSelection, ) -> Option<(BTreeSet<NodeID>, FunctionID)> { - if let Some(selection) = selection { - let selection = construct_selection(pm, selection); - let mut result = None; - - for (idx, (selected, func)) in selection.into_iter().zip(pm.functions.iter()).enumerate() { - match selected { - FunctionSelection::Nothing() => {} - FunctionSelection::Everything() => match result { - Some(_) => { - return None; - } - None => { - result = Some(( - (0..func.nodes.len()).map(|i| NodeID::new(i)).collect(), - FunctionID::new(idx), - )); - } - }, - FunctionSelection::Labels(labels) => match result { - Some(_) => { - return None; - } - None => { - result = Some(( - (0..func.nodes.len()) - .filter(|i| !func.labels[*i].is_disjoint(&labels)) - .map(|i| NodeID::new(i)) - .collect(), - FunctionID::new(idx), - )); - } - }, - } - } - - result + let nodes = selection.as_nodes(); + let mut funcs = nodes.iter().map(|(func, _)| *func).collect::<BTreeSet<_>>(); + if funcs.len() == 1 { + let func = funcs.pop_first().unwrap(); + Some((nodes.into_iter().map(|(_, node)| node).collect(), func)) } else { None } @@ -1445,47 +1526,40 @@ fn selection_as_set( fn build_selection<'a>( pm: &'a mut PassManager, - selection: Option<Vec<CodeLocation>>, + selection: ModuleSelection, create_editors_for_nothing_functions: bool, ) -> Vec<Option<FunctionEditor<'a>>> { // Build def uses, which are needed for the editors we'll construct pm.make_def_uses(); let def_uses = pm.def_uses.take().unwrap(); - if let Some(selection) = selection { - let selected = construct_selection(pm, selection); - - pm.functions - .iter_mut() - .zip(selected.iter()) - .zip(def_uses.iter()) - .enumerate() - .map(|(idx, ((func, selected), def_use))| match selected { - FunctionSelection::Nothing() => { - if create_editors_for_nothing_functions { - Some(FunctionEditor::new_immutable( - func, - FunctionID::new(idx), - &pm.constants, - &pm.dynamic_constants, - &pm.types, - &pm.labels, - def_use, - )) - } else { - None - } - } - FunctionSelection::Everything() => Some(FunctionEditor::new( - func, - FunctionID::new(idx), - &pm.constants, - &pm.dynamic_constants, - &pm.types, - &pm.labels, - def_use, + selection.as_func_states() + .into_iter() + .zip(pm.functions.iter_mut()) + .zip(def_uses.iter()) + .enumerate() + .map(|(idx, ((selection, func), def_use))| match selection { + FunctionSelectionState::Nothing() if create_editors_for_nothing_functions => + Some(FunctionEditor::new_immutable( + func, + FunctionID::new(idx), + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + def_use, + )), + FunctionSelectionState::Nothing() => None, + FunctionSelectionState::Everything() => Some(FunctionEditor::new( + func, + FunctionID::new(idx), + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + def_use, )), - FunctionSelection::Labels(labels) => Some(FunctionEditor::new_labeled( + FunctionSelectionState::Selection(mask) => Some(FunctionEditor::new_mask( func, FunctionID::new(idx), &pm.constants, @@ -1493,23 +1567,17 @@ fn build_selection<'a>( &pm.types, &pm.labels, def_use, - labels, - )), - }) - .collect() - } else { - build_editors(pm) - .into_iter() - .map(|func| Some(func)) - .collect() - } + mask.clone(), + )), + }) + .collect() } fn run_pass( pm: &mut PassManager, pass: Pass, args: Vec<Value>, - selection: Option<Vec<CodeLocation>>, + selection: ModuleSelection, ) -> Result<(Value, bool), SchedulerError> { let mut result = Value::Record { fields: HashMap::new(), @@ -1571,7 +1639,7 @@ fn run_pass( pm.clear_analyses(); } Pass::AutoOutline => { - let Some(funcs) = selection_of_functions(pm, selection) else { + let Some(funcs) = selection.as_funcs().ok() else { return Err(SchedulerError::PassError { pass: "autoOutline".to_string(), error: "must be applied to whole functions".to_string(), @@ -1939,7 +2007,7 @@ fn run_pass( } Pass::GCM => { assert!(args.is_empty()); - if let Some(_) = selection { + if !selection.is_everything() { return Err(SchedulerError::PassError { pass: "gcm".to_string(), error: "must be applied to the entire module".to_string(), @@ -2089,7 +2157,7 @@ fn run_pass( }; let selection = - selection_of_functions(pm, selection).ok_or_else(|| SchedulerError::PassError { + selection.as_funcs().map_err(|_| SchedulerError::PassError { pass: "xdot".to_string(), error: "expected coarse-grained selection (can't partially xdot a function)" .to_string(), @@ -2264,8 +2332,7 @@ fn run_pass( }); } - if let Some(funcs) = selection_of_functions(pm, selection) - && funcs.len() == 1 + if let Some(funcs) = selection.as_funcs().ok() && funcs.len() == 1 { let func = funcs[0]; pm.functions[func.idx()].name = new_name; @@ -2734,8 +2801,7 @@ fn run_pass( None => true, }; - let selection = - selection_of_functions(pm, selection).ok_or_else(|| SchedulerError::PassError { + let selection = selection.as_funcs().map_err(|_| SchedulerError::PassError { pass: "xdot".to_string(), error: "expected coarse-grained selection (can't partially xdot a function)" .to_string(), -- GitLab From 3da44adaa016772c9be59d2addb18de87d4b687d Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Sat, 22 Feb 2025 17:58:02 -0600 Subject: [PATCH 2/2] Format --- juno_scheduler/src/compile.rs | 13 +- juno_scheduler/src/pm.rs | 303 +++++++++++++++++++++------------- 2 files changed, 203 insertions(+), 113 deletions(-) diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index 86377241..13990ef9 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -473,10 +473,19 @@ fn compile_expr( } Ok(ExprResult::Expr(ir::ScheduleExp::Record { fields: result })) } - parser::Expr::SetOp { span: _, op, lhs, rhs } => { + parser::Expr::SetOp { + span: _, + op, + lhs, + rhs, + } => { let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?; let rhs = compile_exp_as_expr(*rhs, lexer, macrostab, macros)?; - Ok(ExprResult::Expr(ir::ScheduleExp::SetOp { op, lhs: Box::new(lhs), rhs: Box::new(rhs), })) + Ok(ExprResult::Expr(ir::ScheduleExp::SetOp { + op, + lhs: Box::new(lhs), + rhs: Box::new(rhs), + })) } } } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 44777301..d5e280b4 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -73,18 +73,17 @@ impl FunctionSelection { match &mut self.selection { FunctionSelectionState::Nothing() => self.selection = other.selection, FunctionSelectionState::Everything() => (), - FunctionSelectionState::Selection(selection) => - match other.selection { - FunctionSelectionState::Nothing() => (), - FunctionSelectionState::Everything() => self.selection = other.selection, - FunctionSelectionState::Selection(other) => { - for idx in 0..self.num_nodes { - if other[idx] { - selection.set(idx, true); - } + FunctionSelectionState::Selection(selection) => match other.selection { + FunctionSelectionState::Nothing() => (), + FunctionSelectionState::Everything() => self.selection = other.selection, + FunctionSelectionState::Selection(other) => { + for idx in 0..self.num_nodes { + if other[idx] { + selection.set(idx, true); } } } + }, } } @@ -99,19 +98,24 @@ impl FunctionSelection { match self.selection { FunctionSelectionState::Nothing() => self, FunctionSelectionState::Everything() => other, - FunctionSelectionState::Selection(mut selection) => - match other.selection { - FunctionSelectionState::Nothing() => other, - FunctionSelectionState::Everything() => Self { num_nodes, selection: FunctionSelectionState::Selection(selection) }, - FunctionSelectionState::Selection(other) => { - for idx in 0..num_nodes { - if !other[idx] { - selection.set(idx, false); - } + FunctionSelectionState::Selection(mut selection) => match other.selection { + FunctionSelectionState::Nothing() => other, + FunctionSelectionState::Everything() => Self { + num_nodes, + selection: FunctionSelectionState::Selection(selection), + }, + FunctionSelectionState::Selection(other) => { + for idx in 0..num_nodes { + if !other[idx] { + selection.set(idx, false); } - Self { num_nodes, selection: FunctionSelectionState::Selection(selection) } + } + Self { + num_nodes, + selection: FunctionSelectionState::Selection(selection), } } + }, } } @@ -120,31 +124,38 @@ impl FunctionSelection { match self.selection { FunctionSelectionState::Nothing() => self, - FunctionSelectionState::Everything() => - match other.selection { - FunctionSelectionState::Nothing() => self, - FunctionSelectionState::Everything() => { - Self { num_nodes, selection: FunctionSelectionState::Nothing() } - } - FunctionSelectionState::Selection(other) => { - Self { num_nodes, selection: FunctionSelectionState::Selection(other.not()) } - } - } - FunctionSelectionState::Selection(mut selection) => - match other.selection { - FunctionSelectionState::Nothing() => Self { num_nodes, selection: FunctionSelectionState::Selection(selection) }, - FunctionSelectionState::Everything() => { - Self { num_nodes, selection: FunctionSelectionState::Nothing() } - } - FunctionSelectionState::Selection(other) => { - for idx in 0..num_nodes { - if other[idx] { - selection.set(idx, false); - } + FunctionSelectionState::Everything() => match other.selection { + FunctionSelectionState::Nothing() => self, + FunctionSelectionState::Everything() => Self { + num_nodes, + selection: FunctionSelectionState::Nothing(), + }, + FunctionSelectionState::Selection(other) => Self { + num_nodes, + selection: FunctionSelectionState::Selection(other.not()), + }, + }, + FunctionSelectionState::Selection(mut selection) => match other.selection { + FunctionSelectionState::Nothing() => Self { + num_nodes, + selection: FunctionSelectionState::Selection(selection), + }, + FunctionSelectionState::Everything() => Self { + num_nodes, + selection: FunctionSelectionState::Nothing(), + }, + FunctionSelectionState::Selection(other) => { + for idx in 0..num_nodes { + if other[idx] { + selection.set(idx, false); } - Self { num_nodes, selection: FunctionSelectionState::Selection(selection) } + } + Self { + num_nodes, + selection: FunctionSelectionState::Selection(selection), } } + }, } } @@ -152,8 +163,10 @@ impl FunctionSelection { match &self.selection { FunctionSelectionState::Nothing() => vec![], FunctionSelectionState::Everything() => (0..self.num_nodes).map(NodeID::new).collect(), - FunctionSelectionState::Selection(selection) => - (0..self.num_nodes).map(NodeID::new).filter(|node| selection[node.idx()]).collect(), + FunctionSelectionState::Selection(selection) => (0..self.num_nodes) + .map(NodeID::new) + .filter(|node| selection[node.idx()]) + .collect(), } } } @@ -177,7 +190,9 @@ impl ModuleSelection { } pub fn add_everything(&mut self) { - self.selection.iter_mut().for_each(|func| func.add_everything()); + self.selection + .iter_mut() + .for_each(|func| func.add_everything()); } pub fn add_function(&mut self, func: FunctionID) { @@ -185,15 +200,22 @@ impl ModuleSelection { } pub fn add_label(&mut self, func: FunctionID, label: LabelID, pm: &PassManager) { - pm.functions[func.idx()].labels.iter().enumerate() - .for_each(|(node_idx, labels)| if labels.contains(&label) { - self.selection[func.idx()].add_node(NodeID::new(node_idx)); - } else { + pm.functions[func.idx()] + .labels + .iter() + .enumerate() + .for_each(|(node_idx, labels)| { + if labels.contains(&label) { + self.selection[func.idx()].add_node(NodeID::new(node_idx)); + } else { + } }); } pub fn add(&mut self, other: Self) { - self.selection.iter_mut().zip(other.selection.into_iter()) + self.selection + .iter_mut() + .zip(other.selection.into_iter()) .for_each(|(this, other)| this.add(other)); } @@ -204,17 +226,23 @@ impl ModuleSelection { pub fn intersection(self, other: Self) -> Self { Self { - selection: self.selection.into_iter().zip(other.selection.into_iter()) + selection: self + .selection + .into_iter() + .zip(other.selection.into_iter()) .map(|(this, other)| this.intersection(other)) - .collect() + .collect(), } } pub fn difference(self, other: Self) -> Self { Self { - selection: self.selection.into_iter().zip(other.selection.into_iter()) + selection: self + .selection + .into_iter() + .zip(other.selection.into_iter()) .map(|(this, other)| this.difference(other)) - .collect() + .collect(), } } @@ -224,9 +252,11 @@ impl ModuleSelection { match func_selection.selection { FunctionSelectionState::Nothing() => (), FunctionSelectionState::Everything() => res.push(FunctionID::new(func_id)), - _ => return Err(SchedulerError::SemanticError( - "Expected selection of functions, found fine-grain selection".to_string() - )), + _ => { + return Err(SchedulerError::SemanticError( + "Expected selection of functions, found fine-grain selection".to_string(), + )) + } } } Ok(res) @@ -243,22 +273,45 @@ impl ModuleSelection { } pub fn as_func_states(&self) -> Vec<&FunctionSelectionState> { - self.selection.iter().map(|selection| &selection.selection).collect() + self.selection + .iter() + .map(|selection| &selection.selection) + .collect() } } #[derive(Debug, Clone)] pub enum Value { - Label { labels: Vec<LabelInfo> }, - JunoFunction { func: JunoFunctionID }, - HerculesFunction { func: FunctionID }, - Record { fields: HashMap<String, Value> }, + Label { + labels: Vec<LabelInfo>, + }, + JunoFunction { + func: JunoFunctionID, + }, + HerculesFunction { + func: FunctionID, + }, + Record { + fields: HashMap<String, Value>, + }, Everything {}, - Selection { selection: Vec<Value> }, - SetOp { op: parser::SetOp, lhs: Box<Value>, rhs: Box<Value> }, - Integer { val: usize }, - Boolean { val: bool }, - String { val: String }, + Selection { + selection: Vec<Value>, + }, + SetOp { + op: parser::SetOp, + lhs: Box<Value>, + rhs: Box<Value>, + }, + Integer { + val: usize, + }, + Boolean { + val: bool, + }, + String { + val: String, + }, } impl Value { @@ -269,15 +322,22 @@ impl Value { } } - fn as_selection(&self, pm: &PassManager, funcs: &JunoFunctions) - -> Result<ModuleSelection, SchedulerError> - { + fn as_selection( + &self, + pm: &PassManager, + funcs: &JunoFunctions, + ) -> Result<ModuleSelection, SchedulerError> { let mut selection = ModuleSelection::new(pm); self.add_to_selection(pm, funcs, &mut selection)?; Ok(selection) } - fn add_to_selection(&self, pm: &PassManager, funcs: &JunoFunctions, selection: &mut ModuleSelection) -> Result<(), SchedulerError> { + fn add_to_selection( + &self, + pm: &PassManager, + funcs: &JunoFunctions, + selection: &mut ModuleSelection, + ) -> Result<(), SchedulerError> { match self { Value::Label { labels } => { for LabelInfo { func, label } in labels { @@ -299,32 +359,31 @@ impl Value { Value::SetOp { op, lhs, rhs } => { let lhs = lhs.as_selection(pm, funcs)?; let rhs = rhs.as_selection(pm, funcs)?; - let res = - match op { - parser::SetOp::Union => lhs.union(rhs), - parser::SetOp::Intersection => lhs.intersection(rhs), - parser::SetOp::Difference => lhs.difference(rhs), - }; + let res = match op { + parser::SetOp::Union => lhs.union(rhs), + parser::SetOp::Intersection => lhs.intersection(rhs), + parser::SetOp::Difference => lhs.difference(rhs), + }; selection.add(res); } Value::Record { .. } => { return Err(SchedulerError::SemanticError( - "Expected code selection, found record".to_string() + "Expected code selection, found record".to_string(), )); } Value::Integer { .. } => { return Err(SchedulerError::SemanticError( - "Expected code selection, found integer".to_string() + "Expected code selection, found integer".to_string(), )); } Value::Boolean { .. } => { return Err(SchedulerError::SemanticError( - "Expected code selection, found boolean".to_string() + "Expected code selection, found boolean".to_string(), )); } Value::String { .. } => { return Err(SchedulerError::SemanticError( - "Expected code selection, found string".to_string() + "Expected code selection, found string".to_string(), )); } } @@ -1179,7 +1238,11 @@ fn schedule_interpret( for func in selection { let (func, modified) = interp_expr(pm, func, stringtab, env, functions)?; changed |= modified; - add_device(pm, device.clone(), func.as_selection(pm, functions)?.as_funcs()?); + add_device( + pm, + device.clone(), + func.as_selection(pm, functions)?.as_funcs()?, + ); } Ok(changed) } @@ -1208,7 +1271,14 @@ fn interp_expr( ScheduleExp::SetOp { op, lhs, rhs } => { let (lhs, lhs_mod) = interp_expr(pm, lhs, stringtab, env, functions)?; let (rhs, rhs_mod) = interp_expr(pm, rhs, stringtab, env, functions)?; - Ok((Value::SetOp { op: *op, lhs: Box::new(lhs), rhs: Box::new(rhs) }, lhs_mod || rhs_mod)) + Ok(( + Value::SetOp { + op: *op, + lhs: Box::new(lhs), + rhs: Box::new(rhs), + }, + lhs_mod || rhs_mod, + )) } ScheduleExp::Field { collect, field } => { let (lhs, changed) = interp_expr(pm, collect, stringtab, env, functions)?; @@ -1268,7 +1338,7 @@ fn interp_expr( } let selection = match on { - Selector::Everything() => Value::Everything{}, + Selector::Everything() => Value::Everything {}, Selector::Selection(selection) => { let mut vals = vec![]; for loc in selection { @@ -1278,7 +1348,8 @@ fn interp_expr( } Value::Selection { selection: vals } } - }.as_selection(pm, functions)?; + } + .as_selection(pm, functions)?; let (res, modified) = run_pass(pm, *pass, arg_vals, selection)?; changed |= modified; @@ -1465,7 +1536,11 @@ fn update_value( Value::SetOp { op, lhs, rhs } => { let lhs = update_value(*lhs, func_idx, juno_func_idx)?; let rhs = update_value(*rhs, func_idx, juno_func_idx)?; - Some(Value::SetOp { op, lhs: Box::new(lhs), rhs: Box::new(rhs) }) + Some(Value::SetOp { + op, + lhs: Box::new(lhs), + rhs: Box::new(rhs), + }) } Value::Everything {} => Some(Value::Everything {}), Value::Integer { val } => Some(Value::Integer { val }), @@ -1533,33 +1608,15 @@ fn build_selection<'a>( pm.make_def_uses(); let def_uses = pm.def_uses.take().unwrap(); - selection.as_func_states() + selection + .as_func_states() .into_iter() .zip(pm.functions.iter_mut()) .zip(def_uses.iter()) .enumerate() .map(|(idx, ((selection, func), def_use))| match selection { - FunctionSelectionState::Nothing() if create_editors_for_nothing_functions => + FunctionSelectionState::Nothing() if create_editors_for_nothing_functions => { Some(FunctionEditor::new_immutable( - func, - FunctionID::new(idx), - &pm.constants, - &pm.dynamic_constants, - &pm.types, - &pm.labels, - def_use, - )), - FunctionSelectionState::Nothing() => None, - FunctionSelectionState::Everything() => Some(FunctionEditor::new( - func, - FunctionID::new(idx), - &pm.constants, - &pm.dynamic_constants, - &pm.types, - &pm.labels, - def_use, - )), - FunctionSelectionState::Selection(mask) => Some(FunctionEditor::new_mask( func, FunctionID::new(idx), &pm.constants, @@ -1567,7 +1624,27 @@ fn build_selection<'a>( &pm.types, &pm.labels, def_use, - mask.clone(), + )) + } + FunctionSelectionState::Nothing() => None, + FunctionSelectionState::Everything() => Some(FunctionEditor::new( + func, + FunctionID::new(idx), + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + def_use, + )), + FunctionSelectionState::Selection(mask) => Some(FunctionEditor::new_mask( + func, + FunctionID::new(idx), + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + def_use, + mask.clone(), )), }) .collect() @@ -2156,8 +2233,9 @@ fn run_pass( None => false, }; - let selection = - selection.as_funcs().map_err(|_| SchedulerError::PassError { + let selection = selection + .as_funcs() + .map_err(|_| SchedulerError::PassError { pass: "xdot".to_string(), error: "expected coarse-grained selection (can't partially xdot a function)" .to_string(), @@ -2332,7 +2410,8 @@ fn run_pass( }); } - if let Some(funcs) = selection.as_funcs().ok() && funcs.len() == 1 + if let Some(funcs) = selection.as_funcs().ok() + && funcs.len() == 1 { let func = funcs[0]; pm.functions[func.idx()].name = new_name; @@ -2801,7 +2880,9 @@ fn run_pass( None => true, }; - let selection = selection.as_funcs().map_err(|_| SchedulerError::PassError { + let selection = selection + .as_funcs() + .map_err(|_| SchedulerError::PassError { pass: "xdot".to_string(), error: "expected coarse-grained selection (can't partially xdot a function)" .to_string(), -- GitLab