diff --git a/Cargo.lock b/Cargo.lock index fdcbaf8426dd64fca782b34240256cf149657303..c872be3ad7993859e29e1482866ce49865e3cc29 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1459,6 +1459,7 @@ dependencies = [ name = "juno_scheduler" version = "0.0.1" dependencies = [ + "bitvec", "cfgrammar", "hercules_cg", "hercules_ir", diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index 9cf5af72e2945b35337a7e1da871d9a4c5e06fb5..17cea32500ab86779c01e37e2b03db842f9f3712 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -108,9 +108,8 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { } } - // Constructs an editor but only makes the nodes with at least one of the - // set of labels as mutable. - pub fn new_labeled( + // Constructs an editor with a specified mask determining which nodes are mutable + pub fn new_mask( function: &'a mut Function, function_id: FunctionID, constants: &'a RefCell<Vec<Constant>>, @@ -118,7 +117,7 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { types: &'a RefCell<Vec<Type>>, labels: &'a RefCell<Vec<String>>, def_use: &ImmutableDefUseMap, - with_labels: &HashSet<LabelID>, + mask: BitVec<u8, Lsb0>, ) -> Self { let mut_def_use = (0..function.nodes.len()) .map(|idx| { @@ -130,14 +129,6 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { }) .collect(); - let mut mutable_nodes = bitvec![u8, Lsb0; 0; function.nodes.len()]; - // Add all nodes which have some label which is in the with_labels set - for (idx, labels) in function.labels.iter().enumerate() { - if !labels.is_disjoint(with_labels) { - mutable_nodes.set(idx, true); - } - } - FunctionEditor { function, function_id, @@ -146,7 +137,7 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { types, labels, mut_def_use, - mutable_nodes, + mutable_nodes: mask, modified: false, } } diff --git a/juno_scheduler/Cargo.toml b/juno_scheduler/Cargo.toml index 03a18c8350ebc6a86ece1cc3e65681faaddae4e5..174de05b62b57d08c84026ed1bd221546d7b76a5 100644 --- a/juno_scheduler/Cargo.toml +++ b/juno_scheduler/Cargo.toml @@ -24,4 +24,5 @@ hercules_ir = { path = "../hercules_ir" } hercules_opt = { path = "../hercules_opt" } juno_utils = { path = "../juno_utils" } postcard = { version = "*", features = ["alloc"] } -serde = { version = "*", features = ["derive"] } \ No newline at end of file +serde = { version = "*", features = ["derive"] } +bitvec = "*" diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs index e9132fd20650d1c06a9b13f8a8f0815f16f83337..13990ef9df632e2a01914d2690152a3d1462e739 100644 --- a/juno_scheduler/src/compile.rs +++ b/juno_scheduler/src/compile.rs @@ -473,6 +473,20 @@ fn compile_expr( } Ok(ExprResult::Expr(ir::ScheduleExp::Record { fields: result })) } + parser::Expr::SetOp { + span: _, + op, + lhs, + rhs, + } => { + let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?; + let rhs = compile_exp_as_expr(*rhs, lexer, macrostab, macros)?; + Ok(ExprResult::Expr(ir::ScheduleExp::SetOp { + op, + lhs: Box::new(lhs), + rhs: Box::new(rhs), + })) + } } } diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs index a888cf74dc223a8e52466daa1bacdec533809de4..bbecc6ff190af7077f646c525291f5da7655f0fa 100644 --- a/juno_scheduler/src/ir.rs +++ b/juno_scheduler/src/ir.rs @@ -1,5 +1,7 @@ use hercules_ir::ir::{Device, Schedule}; +use crate::parser; + #[derive(Debug, Copy, Clone)] pub enum Pass { ArraySLF, @@ -117,6 +119,11 @@ pub enum ScheduleExp { body: Vec<ScheduleStmt>, res: Box<ScheduleExp>, }, + SetOp { + op: parser::SetOp, + lhs: Box<ScheduleExp>, + rhs: Box<ScheduleExp>, + }, // This is used to "box" a selection by evaluating it at one point and then // allowing it to be used as a selector later on Selection { diff --git a/juno_scheduler/src/lang.l b/juno_scheduler/src/lang.l index ca75276e326e79eeb60aa1f3ea8b1169ddc96e0a..af154fce3d489b7b078607e52888d76d08f8e5fe 100644 --- a/juno_scheduler/src/lang.l +++ b/juno_scheduler/src/lang.l @@ -39,6 +39,10 @@ false "false" \{ "{" \} "}" +\\ "\\" +\| "|" +& "&" + panic[\t \n\r]+after "panic_after" print[\t \n\r]+iter "print_iter" stop[\t \n\r]+after "stop_after" diff --git a/juno_scheduler/src/lang.y b/juno_scheduler/src/lang.y index 584bf2a4ef1a476669f10de115a5dda38213a695..3b030e1d42bdb970cdfa67d21c4198dc89edea9e 100644 --- a/juno_scheduler/src/lang.y +++ b/juno_scheduler/src/lang.y @@ -3,6 +3,11 @@ %avoid_insert "ID" "INT" "STRING" %expect-unused Unmatched 'UNMATCHED' +%left '\\' +%left '|' +%left '&' +%left '.' '@' + %% Schedule -> OperationList @@ -59,6 +64,12 @@ Expr -> Expr { Expr::BlockExpr { span: $span, body: Box::new($2) } } | '<' Fields '>' { Expr::Record { span: $span, fields: rev($2) } } + | Expr '\\' Expr + { Expr::SetOp { span: $span, op: SetOp::Difference, lhs: Box::new($1), rhs: Box::new($3) } } + | Expr '|' Expr + { Expr::SetOp { span: $span, op: SetOp::Union, lhs: Box::new($1), rhs: Box::new($3) } } + | Expr '&' Expr + { Expr::SetOp { span: $span, op: SetOp::Intersection, lhs: Box::new($1), rhs: Box::new($3) } } ; Args -> Vec<Expr> @@ -151,6 +162,13 @@ pub enum FixpointLimit { PrintIter { span: Span }, } +#[derive(Copy, Clone, Debug)] +pub enum SetOp { + Difference, + Union, + Intersection, +} + pub enum Expr { Function { span: Span, name: Span, args: Vec<Expr>, selection: Selector }, Macro { span: Span, name: Span, args: Vec<Expr>, selection: Selector }, @@ -161,6 +179,7 @@ pub enum Expr { Field { span: Span, lhs: Box<Expr>, field: Span }, BlockExpr { span: Span, body: Box<OperationList> }, Record { span: Span, fields: Vec<(Span, Expr)> }, + SetOp { span: Span, op: SetOp, lhs: Box<Expr>, rhs: Box<Expr> }, } pub enum Selector { diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 84b25811f7f387fe5822ea514a732842a0e77b02..d5e280b442de3a6571bdfa960cef652a4dac117d 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -1,9 +1,12 @@ use crate::ir::*; use crate::labels::*; +use crate::parser; use hercules_cg::*; use hercules_ir::*; use hercules_opt::*; +use bitvec::prelude::*; + use serde::Deserialize; use serde::Serialize; use tempfile::TempDir; @@ -17,130 +20,374 @@ use std::fmt; use std::fs::File; use std::io::Write; use std::iter::zip; +use std::ops::Not; use std::process::{Command, Stdio}; #[derive(Debug, Clone)] -pub enum Value { - Label { labels: Vec<LabelInfo> }, - JunoFunction { func: JunoFunctionID }, - HerculesFunction { func: FunctionID }, - Record { fields: HashMap<String, Value> }, - Everything {}, - Selection { selection: Vec<Value> }, - Integer { val: usize }, - Boolean { val: bool }, - String { val: String }, +enum FunctionSelectionState { + Nothing(), + Everything(), + Selection(BitVec<u8, Lsb0>), } -#[derive(Debug, Copy, Clone)] -enum CodeLocation { - Label(LabelInfo), - Function(FunctionID), +#[derive(Debug, Clone)] +pub struct FunctionSelection { + num_nodes: usize, + selection: FunctionSelectionState, } -impl Value { - fn is_everything(&self) -> bool { - match self { - Value::Everything {} => true, +impl FunctionSelection { + pub fn new(num_nodes: usize) -> Self { + FunctionSelection { + num_nodes, + selection: FunctionSelectionState::Nothing(), + } + } + + pub fn is_everything(&self) -> bool { + match self.selection { + FunctionSelectionState::Everything() => true, _ => false, } } - fn as_labels(&self) -> Result<Vec<LabelInfo>, SchedulerError> { - match self { - Value::Label { labels } => Ok(labels.clone()), - Value::Selection { selection } => { - let mut result = vec![]; - for val in selection { - result.extend(val.as_labels()?); - } - Ok(result) + pub fn add_everything(&mut self) { + self.selection = FunctionSelectionState::Everything(); + } + + pub fn add_node(&mut self, node: NodeID) { + match &mut self.selection { + FunctionSelectionState::Nothing() => { + let mut selection = bitvec![u8, Lsb0; 0; self.num_nodes]; + selection.set(node.idx(), true); + self.selection = FunctionSelectionState::Selection(selection); + } + FunctionSelectionState::Everything() => {} + FunctionSelectionState::Selection(selection) => { + selection.set(node.idx(), true); } - Value::JunoFunction { .. } | Value::HerculesFunction { .. } => Err( - SchedulerError::SemanticError("Expected labels, found function".to_string()), - ), - Value::Record { .. } => Err(SchedulerError::SemanticError( - "Expected labels, found record".to_string(), - )), - Value::Everything {} => Err(SchedulerError::SemanticError( - "Expected labels, found everything".to_string(), - )), - Value::Integer { .. } => Err(SchedulerError::SemanticError( - "Expected labels, found integer".to_string(), - )), - Value::Boolean { .. } => Err(SchedulerError::SemanticError( - "Expected labels, found boolean".to_string(), - )), - Value::String { .. } => Err(SchedulerError::SemanticError( - "Expected labels, found string".to_string(), - )), } } - fn as_functions(&self, funcs: &JunoFunctions) -> Result<Vec<FunctionID>, SchedulerError> { - match self { - Value::JunoFunction { func } => Ok(funcs.get_function(*func).clone()), - Value::HerculesFunction { func } => Ok(vec![*func]), - Value::Selection { selection } => { - let mut result = vec![]; - for val in selection { - result.extend(val.as_functions(funcs)?); + pub fn add(&mut self, other: Self) { + match &mut self.selection { + FunctionSelectionState::Nothing() => self.selection = other.selection, + FunctionSelectionState::Everything() => (), + FunctionSelectionState::Selection(selection) => match other.selection { + FunctionSelectionState::Nothing() => (), + FunctionSelectionState::Everything() => self.selection = other.selection, + FunctionSelectionState::Selection(other) => { + for idx in 0..self.num_nodes { + if other[idx] { + selection.set(idx, true); + } + } + } + }, + } + } + + pub fn union(mut self, other: Self) -> Self { + self.add(other); + self + } + + pub fn intersection(self, other: Self) -> Self { + let num_nodes = self.num_nodes; + + match self.selection { + FunctionSelectionState::Nothing() => self, + FunctionSelectionState::Everything() => other, + FunctionSelectionState::Selection(mut selection) => match other.selection { + FunctionSelectionState::Nothing() => other, + FunctionSelectionState::Everything() => Self { + num_nodes, + selection: FunctionSelectionState::Selection(selection), + }, + FunctionSelectionState::Selection(other) => { + for idx in 0..num_nodes { + if !other[idx] { + selection.set(idx, false); + } + } + Self { + num_nodes, + selection: FunctionSelectionState::Selection(selection), + } + } + }, + } + } + + pub fn difference(self, other: Self) -> Self { + let num_nodes = self.num_nodes; + + match self.selection { + FunctionSelectionState::Nothing() => self, + FunctionSelectionState::Everything() => match other.selection { + FunctionSelectionState::Nothing() => self, + FunctionSelectionState::Everything() => Self { + num_nodes, + selection: FunctionSelectionState::Nothing(), + }, + FunctionSelectionState::Selection(other) => Self { + num_nodes, + selection: FunctionSelectionState::Selection(other.not()), + }, + }, + FunctionSelectionState::Selection(mut selection) => match other.selection { + FunctionSelectionState::Nothing() => Self { + num_nodes, + selection: FunctionSelectionState::Selection(selection), + }, + FunctionSelectionState::Everything() => Self { + num_nodes, + selection: FunctionSelectionState::Nothing(), + }, + FunctionSelectionState::Selection(other) => { + for idx in 0..num_nodes { + if other[idx] { + selection.set(idx, false); + } + } + Self { + num_nodes, + selection: FunctionSelectionState::Selection(selection), + } + } + }, + } + } + + pub fn as_nodes(&self) -> Vec<NodeID> { + match &self.selection { + FunctionSelectionState::Nothing() => vec![], + FunctionSelectionState::Everything() => (0..self.num_nodes).map(NodeID::new).collect(), + FunctionSelectionState::Selection(selection) => (0..self.num_nodes) + .map(NodeID::new) + .filter(|node| selection[node.idx()]) + .collect(), + } + } +} + +#[derive(Debug, Clone)] +pub struct ModuleSelection { + selection: Vec<FunctionSelection>, +} + +impl ModuleSelection { + pub fn new(pm: &PassManager) -> Self { + ModuleSelection { + selection: (0..pm.functions.len()) + .map(|idx| FunctionSelection::new(pm.functions[idx].nodes.len())) + .collect(), + } + } + + pub fn is_everything(&self) -> bool { + self.selection.iter().all(|func| func.is_everything()) + } + + pub fn add_everything(&mut self) { + self.selection + .iter_mut() + .for_each(|func| func.add_everything()); + } + + pub fn add_function(&mut self, func: FunctionID) { + self.selection[func.idx()].add_everything(); + } + + pub fn add_label(&mut self, func: FunctionID, label: LabelID, pm: &PassManager) { + pm.functions[func.idx()] + .labels + .iter() + .enumerate() + .for_each(|(node_idx, labels)| { + if labels.contains(&label) { + self.selection[func.idx()].add_node(NodeID::new(node_idx)); + } else { + } + }); + } + + pub fn add(&mut self, other: Self) { + self.selection + .iter_mut() + .zip(other.selection.into_iter()) + .for_each(|(this, other)| this.add(other)); + } + + pub fn union(mut self, other: Self) -> Self { + self.add(other); + self + } + + pub fn intersection(self, other: Self) -> Self { + Self { + selection: self + .selection + .into_iter() + .zip(other.selection.into_iter()) + .map(|(this, other)| this.intersection(other)) + .collect(), + } + } + + pub fn difference(self, other: Self) -> Self { + Self { + selection: self + .selection + .into_iter() + .zip(other.selection.into_iter()) + .map(|(this, other)| this.difference(other)) + .collect(), + } + } + + pub fn as_funcs(&self) -> Result<Vec<FunctionID>, SchedulerError> { + let mut res = vec![]; + for (func_id, func_selection) in self.selection.iter().enumerate() { + match func_selection.selection { + FunctionSelectionState::Nothing() => (), + FunctionSelectionState::Everything() => res.push(FunctionID::new(func_id)), + _ => { + return Err(SchedulerError::SemanticError( + "Expected selection of functions, found fine-grain selection".to_string(), + )) } - Ok(result) } - Value::Label { .. } => Err(SchedulerError::SemanticError( - "Expected functions, found label".to_string(), - )), - Value::Record { .. } => Err(SchedulerError::SemanticError( - "Expected functions, found record".to_string(), - )), - Value::Everything {} => Err(SchedulerError::SemanticError( - "Expected functions, found everything".to_string(), - )), - Value::Integer { .. } => Err(SchedulerError::SemanticError( - "Expected functions, found integer".to_string(), - )), - Value::Boolean { .. } => Err(SchedulerError::SemanticError( - "Expected functions, found boolean".to_string(), - )), - Value::String { .. } => Err(SchedulerError::SemanticError( - "Expected functions, found string".to_string(), - )), } + Ok(res) + } + + pub fn as_nodes(&self) -> Vec<(FunctionID, NodeID)> { + let mut res = vec![]; + for (func_id, func_selection) in self.selection.iter().enumerate() { + for node_id in func_selection.as_nodes() { + res.push((FunctionID::new(func_id), node_id)); + } + } + res + } + + pub fn as_func_states(&self) -> Vec<&FunctionSelectionState> { + self.selection + .iter() + .map(|selection| &selection.selection) + .collect() + } +} + +#[derive(Debug, Clone)] +pub enum Value { + Label { + labels: Vec<LabelInfo>, + }, + JunoFunction { + func: JunoFunctionID, + }, + HerculesFunction { + func: FunctionID, + }, + Record { + fields: HashMap<String, Value>, + }, + Everything {}, + Selection { + selection: Vec<Value>, + }, + SetOp { + op: parser::SetOp, + lhs: Box<Value>, + rhs: Box<Value>, + }, + Integer { + val: usize, + }, + Boolean { + val: bool, + }, + String { + val: String, + }, +} + +impl Value { + fn is_everything(&self) -> bool { + match self { + Value::Everything {} => true, + _ => false, + } + } + + fn as_selection( + &self, + pm: &PassManager, + funcs: &JunoFunctions, + ) -> Result<ModuleSelection, SchedulerError> { + let mut selection = ModuleSelection::new(pm); + self.add_to_selection(pm, funcs, &mut selection)?; + Ok(selection) } - fn as_locations(&self, funcs: &JunoFunctions) -> Result<Vec<CodeLocation>, SchedulerError> { + fn add_to_selection( + &self, + pm: &PassManager, + funcs: &JunoFunctions, + selection: &mut ModuleSelection, + ) -> Result<(), SchedulerError> { match self { - Value::Label { labels } => Ok(labels.iter().map(|l| CodeLocation::Label(*l)).collect()), - Value::JunoFunction { func } => Ok(funcs - .get_function(*func) - .iter() - .map(|f| CodeLocation::Function(*f)) - .collect()), - Value::HerculesFunction { func } => Ok(vec![CodeLocation::Function(*func)]), - Value::Selection { selection } => { - let mut result = vec![]; - for val in selection { - result.extend(val.as_locations(funcs)?); + Value::Label { labels } => { + for LabelInfo { func, label } in labels { + selection.add_label(*func, *label, pm); } - Ok(result) } - Value::Record { .. } => Err(SchedulerError::SemanticError( - "Expected code locations, found record".to_string(), - )), - Value::Everything {} => { - panic!("Internal error, check is_everything() before using as_functions()") + Value::JunoFunction { func } => { + for func in funcs.get_function(*func) { + selection.add_function(*func); + } + } + Value::HerculesFunction { func } => selection.add_function(*func), + Value::Everything {} => selection.add_everything(), + Value::Selection { selection: values } => { + for value in values { + value.add_to_selection(pm, funcs, selection)?; + } + } + Value::SetOp { op, lhs, rhs } => { + let lhs = lhs.as_selection(pm, funcs)?; + let rhs = rhs.as_selection(pm, funcs)?; + let res = match op { + parser::SetOp::Union => lhs.union(rhs), + parser::SetOp::Intersection => lhs.intersection(rhs), + parser::SetOp::Difference => lhs.difference(rhs), + }; + selection.add(res); + } + Value::Record { .. } => { + return Err(SchedulerError::SemanticError( + "Expected code selection, found record".to_string(), + )); + } + Value::Integer { .. } => { + return Err(SchedulerError::SemanticError( + "Expected code selection, found integer".to_string(), + )); + } + Value::Boolean { .. } => { + return Err(SchedulerError::SemanticError( + "Expected code selection, found boolean".to_string(), + )); + } + Value::String { .. } => { + return Err(SchedulerError::SemanticError( + "Expected code selection, found string".to_string(), + )); } - Value::Integer { .. } => Err(SchedulerError::SemanticError( - "Expected code locations, found integer".to_string(), - )), - Value::Boolean { .. } => Err(SchedulerError::SemanticError( - "Expected code locations, found boolean".to_string(), - )), - Value::String { .. } => Err(SchedulerError::SemanticError( - "Expected code locations, found string".to_string(), - )), } + Ok(()) } } @@ -974,7 +1221,7 @@ fn schedule_interpret( for label in selection { let (label, modified) = interp_expr(pm, label, stringtab, env, functions)?; changed |= modified; - add_schedule(pm, sched.clone(), label.as_labels()?); + add_schedule(pm, sched.clone(), label.as_selection(pm, functions)?); } Ok(changed) } @@ -991,7 +1238,11 @@ fn schedule_interpret( for func in selection { let (func, modified) = interp_expr(pm, func, stringtab, env, functions)?; changed |= modified; - add_device(pm, device.clone(), func.as_functions(functions)?); + add_device( + pm, + device.clone(), + func.as_selection(pm, functions)?.as_funcs()?, + ); } Ok(changed) } @@ -1017,6 +1268,18 @@ fn interp_expr( ScheduleExp::Integer { val } => Ok((Value::Integer { val: *val }, false)), ScheduleExp::Boolean { val } => Ok((Value::Boolean { val: *val }, false)), ScheduleExp::String { val } => Ok((Value::String { val: val.clone() }, false)), + ScheduleExp::SetOp { op, lhs, rhs } => { + let (lhs, lhs_mod) = interp_expr(pm, lhs, stringtab, env, functions)?; + let (rhs, rhs_mod) = interp_expr(pm, rhs, stringtab, env, functions)?; + Ok(( + Value::SetOp { + op: *op, + lhs: Box::new(lhs), + rhs: Box::new(rhs), + }, + lhs_mod || rhs_mod, + )) + } ScheduleExp::Field { collect, field } => { let (lhs, changed) = interp_expr(pm, collect, stringtab, env, functions)?; match lhs { @@ -1025,7 +1288,8 @@ fn interp_expr( | Value::Everything { .. } | Value::Integer { .. } | Value::Boolean { .. } - | Value::String { .. } => Err(SchedulerError::UndefinedField(field.clone())), + | Value::String { .. } + | Value::SetOp { .. } => Err(SchedulerError::UndefinedField(field.clone())), Value::JunoFunction { func } => { match pm.labels.borrow().iter().position(|s| s == field) { None => Err(SchedulerError::UndefinedLabel(field.clone())), @@ -1074,26 +1338,18 @@ fn interp_expr( } let selection = match on { - Selector::Everything() => None, + Selector::Everything() => Value::Everything {}, Selector::Selection(selection) => { - let mut locs = vec![]; - let mut everything = false; + let mut vals = vec![]; for loc in selection { let (val, modified) = interp_expr(pm, loc, stringtab, env, functions)?; changed |= modified; - if val.is_everything() { - everything = true; - break; - } - locs.extend(val.as_locations(functions)?); - } - if everything { - None - } else { - Some(locs) + vals.push(val); } + Value::Selection { selection: vals } } - }; + } + .as_selection(pm, functions)?; let (res, modified) = run_pass(pm, *pass, arg_vals, selection)?; changed |= modified; @@ -1277,6 +1533,15 @@ fn update_value( func: FunctionID::new(i), }) } + Value::SetOp { op, lhs, rhs } => { + let lhs = update_value(*lhs, func_idx, juno_func_idx)?; + let rhs = update_value(*rhs, func_idx, juno_func_idx)?; + Some(Value::SetOp { + op, + lhs: Box::new(lhs), + rhs: Box::new(rhs), + }) + } Value::Everything {} => Some(Value::Everything {}), Value::Integer { val } => Some(Value::Integer { val }), Value::Boolean { val } => Some(Value::Boolean { val }), @@ -1284,18 +1549,9 @@ fn update_value( } } -fn add_schedule(pm: &mut PassManager, sched: Schedule, label_ids: Vec<LabelInfo>) { - for LabelInfo { func, label } in label_ids { - let nodes = pm.functions[func.idx()] - .labels - .iter() - .enumerate() - .filter(|(_, ls)| ls.contains(&label)) - .map(|(i, _)| i) - .collect::<Vec<_>>(); - for node in nodes { - pm.functions[func.idx()].schedules[node].push(sched.clone()); - } +fn add_schedule(pm: &mut PassManager, sched: Schedule, selection: ModuleSelection) { + for (func, node) in selection.as_nodes() { + pm.functions[func.idx()].schedules[node.idx()].push(sched.clone()); } } @@ -1305,31 +1561,7 @@ fn add_device(pm: &mut PassManager, device: Device, funcs: Vec<FunctionID>) { } } -#[derive(Debug, Clone)] -enum FunctionSelection { - Nothing(), - Everything(), - Labels(HashSet<LabelID>), -} - -impl FunctionSelection { - fn add_label(&mut self, label: LabelID) { - match self { - FunctionSelection::Nothing() => { - *self = FunctionSelection::Labels(HashSet::from([label])); - } - FunctionSelection::Everything() => {} - FunctionSelection::Labels(set) => { - set.insert(label); - } - } - } - - fn add_everything(&mut self) { - *self = FunctionSelection::Everything(); - } -} - +// Builds (completely mutable) editors for all functions fn build_editors<'a>(pm: &'a mut PassManager) -> Vec<FunctionEditor<'a>> { pm.make_def_uses(); let def_uses = pm.def_uses.take().unwrap(); @@ -1351,93 +1583,17 @@ fn build_editors<'a>(pm: &'a mut PassManager) -> Vec<FunctionEditor<'a>> { .collect() } -// With a selection, we process it to identify which labels in which functions are to be selected -fn construct_selection(pm: &PassManager, selection: Vec<CodeLocation>) -> Vec<FunctionSelection> { - let mut selected = vec![FunctionSelection::Nothing(); pm.functions.len()]; - for loc in selection { - match loc { - CodeLocation::Label(label) => selected[label.func.idx()].add_label(label.label), - CodeLocation::Function(func) => selected[func.idx()].add_everything(), - } - } - selected -} - -// Given a selection, constructs the set of functions which are selected (and each must be selected -// fully) -fn selection_of_functions( - pm: &PassManager, - selection: Option<Vec<CodeLocation>>, -) -> Option<Vec<FunctionID>> { - if let Some(selection) = selection { - let selection = construct_selection(pm, selection); - - let mut result = vec![]; - - for (idx, selected) in selection.into_iter().enumerate() { - match selected { - FunctionSelection::Nothing() => {} - FunctionSelection::Everything() => result.push(FunctionID::new(idx)), - FunctionSelection::Labels(_) => { - return None; - } - } - } - - Some(result) - } else { - Some( - pm.functions - .iter() - .enumerate() - .map(|(i, _)| FunctionID::new(i)) - .collect(), - ) - } -} - // Given a selection, constructs the set of the nodes selected for a single function, returning the // function's id fn selection_as_set( pm: &PassManager, - selection: Option<Vec<CodeLocation>>, + selection: ModuleSelection, ) -> Option<(BTreeSet<NodeID>, FunctionID)> { - if let Some(selection) = selection { - let selection = construct_selection(pm, selection); - let mut result = None; - - for (idx, (selected, func)) in selection.into_iter().zip(pm.functions.iter()).enumerate() { - match selected { - FunctionSelection::Nothing() => {} - FunctionSelection::Everything() => match result { - Some(_) => { - return None; - } - None => { - result = Some(( - (0..func.nodes.len()).map(|i| NodeID::new(i)).collect(), - FunctionID::new(idx), - )); - } - }, - FunctionSelection::Labels(labels) => match result { - Some(_) => { - return None; - } - None => { - result = Some(( - (0..func.nodes.len()) - .filter(|i| !func.labels[*i].is_disjoint(&labels)) - .map(|i| NodeID::new(i)) - .collect(), - FunctionID::new(idx), - )); - } - }, - } - } - - result + let nodes = selection.as_nodes(); + let mut funcs = nodes.iter().map(|(func, _)| *func).collect::<BTreeSet<_>>(); + if funcs.len() == 1 { + let func = funcs.pop_first().unwrap(); + Some((nodes.into_iter().map(|(_, node)| node).collect(), func)) } else { None } @@ -1445,47 +1601,22 @@ fn selection_as_set( fn build_selection<'a>( pm: &'a mut PassManager, - selection: Option<Vec<CodeLocation>>, + selection: ModuleSelection, create_editors_for_nothing_functions: bool, ) -> Vec<Option<FunctionEditor<'a>>> { // Build def uses, which are needed for the editors we'll construct pm.make_def_uses(); let def_uses = pm.def_uses.take().unwrap(); - if let Some(selection) = selection { - let selected = construct_selection(pm, selection); - - pm.functions - .iter_mut() - .zip(selected.iter()) - .zip(def_uses.iter()) - .enumerate() - .map(|(idx, ((func, selected), def_use))| match selected { - FunctionSelection::Nothing() => { - if create_editors_for_nothing_functions { - Some(FunctionEditor::new_immutable( - func, - FunctionID::new(idx), - &pm.constants, - &pm.dynamic_constants, - &pm.types, - &pm.labels, - def_use, - )) - } else { - None - } - } - FunctionSelection::Everything() => Some(FunctionEditor::new( - func, - FunctionID::new(idx), - &pm.constants, - &pm.dynamic_constants, - &pm.types, - &pm.labels, - def_use, - )), - FunctionSelection::Labels(labels) => Some(FunctionEditor::new_labeled( + selection + .as_func_states() + .into_iter() + .zip(pm.functions.iter_mut()) + .zip(def_uses.iter()) + .enumerate() + .map(|(idx, ((selection, func), def_use))| match selection { + FunctionSelectionState::Nothing() if create_editors_for_nothing_functions => { + Some(FunctionEditor::new_immutable( func, FunctionID::new(idx), &pm.constants, @@ -1493,23 +1624,37 @@ fn build_selection<'a>( &pm.types, &pm.labels, def_use, - labels, - )), - }) - .collect() - } else { - build_editors(pm) - .into_iter() - .map(|func| Some(func)) - .collect() - } + )) + } + FunctionSelectionState::Nothing() => None, + FunctionSelectionState::Everything() => Some(FunctionEditor::new( + func, + FunctionID::new(idx), + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + def_use, + )), + FunctionSelectionState::Selection(mask) => Some(FunctionEditor::new_mask( + func, + FunctionID::new(idx), + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + def_use, + mask.clone(), + )), + }) + .collect() } fn run_pass( pm: &mut PassManager, pass: Pass, args: Vec<Value>, - selection: Option<Vec<CodeLocation>>, + selection: ModuleSelection, ) -> Result<(Value, bool), SchedulerError> { let mut result = Value::Record { fields: HashMap::new(), @@ -1571,7 +1716,7 @@ fn run_pass( pm.clear_analyses(); } Pass::AutoOutline => { - let Some(funcs) = selection_of_functions(pm, selection) else { + let Some(funcs) = selection.as_funcs().ok() else { return Err(SchedulerError::PassError { pass: "autoOutline".to_string(), error: "must be applied to whole functions".to_string(), @@ -1939,7 +2084,7 @@ fn run_pass( } Pass::GCM => { assert!(args.is_empty()); - if let Some(_) = selection { + if !selection.is_everything() { return Err(SchedulerError::PassError { pass: "gcm".to_string(), error: "must be applied to the entire module".to_string(), @@ -2088,8 +2233,9 @@ fn run_pass( None => false, }; - let selection = - selection_of_functions(pm, selection).ok_or_else(|| SchedulerError::PassError { + let selection = selection + .as_funcs() + .map_err(|_| SchedulerError::PassError { pass: "xdot".to_string(), error: "expected coarse-grained selection (can't partially xdot a function)" .to_string(), @@ -2264,7 +2410,7 @@ fn run_pass( }); } - if let Some(funcs) = selection_of_functions(pm, selection) + if let Some(funcs) = selection.as_funcs().ok() && funcs.len() == 1 { let func = funcs[0]; @@ -2734,8 +2880,9 @@ fn run_pass( None => true, }; - let selection = - selection_of_functions(pm, selection).ok_or_else(|| SchedulerError::PassError { + let selection = selection + .as_funcs() + .map_err(|_| SchedulerError::PassError { pass: "xdot".to_string(), error: "expected coarse-grained selection (can't partially xdot a function)" .to_string(),