use crate::ir::*; use crate::labels::*; use crate::parser; use hercules_cg::*; use hercules_ir::*; use hercules_opt::*; use bitvec::prelude::*; use serde::Deserialize; use serde::Serialize; use tempfile::TempDir; use juno_utils::env::Env; use juno_utils::stringtab::StringTable; use std::cell::RefCell; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::env; use std::fmt; use std::fs::File; use std::io::Write; use std::iter::zip; use std::ops::Not; use std::process::{Command, Stdio}; #[derive(Debug, Clone)] enum FunctionSelectionState { Nothing(), Everything(), Selection(BitVec<u8, Lsb0>), } #[derive(Debug, Clone)] pub struct FunctionSelection { num_nodes: usize, selection: FunctionSelectionState, } impl FunctionSelection { pub fn new(num_nodes: usize) -> Self { FunctionSelection { num_nodes, selection: FunctionSelectionState::Nothing(), } } pub fn is_everything(&self) -> bool { match self.selection { FunctionSelectionState::Everything() => true, _ => false, } } pub fn add_everything(&mut self) { self.selection = FunctionSelectionState::Everything(); } pub fn add_node(&mut self, node: NodeID) { match &mut self.selection { FunctionSelectionState::Nothing() => { let mut selection = bitvec![u8, Lsb0; 0; self.num_nodes]; selection.set(node.idx(), true); self.selection = FunctionSelectionState::Selection(selection); } FunctionSelectionState::Everything() => {} FunctionSelectionState::Selection(selection) => { selection.set(node.idx(), true); } } } pub fn add(&mut self, other: Self) { match &mut self.selection { FunctionSelectionState::Nothing() => self.selection = other.selection, FunctionSelectionState::Everything() => (), FunctionSelectionState::Selection(selection) => match other.selection { FunctionSelectionState::Nothing() => (), FunctionSelectionState::Everything() => self.selection = other.selection, FunctionSelectionState::Selection(other) => { for idx in 0..self.num_nodes { if other[idx] { selection.set(idx, true); } } } }, } } pub fn union(mut self, other: Self) -> Self { self.add(other); self } pub fn intersection(self, other: Self) -> Self { let num_nodes = self.num_nodes; match self.selection { FunctionSelectionState::Nothing() => self, FunctionSelectionState::Everything() => other, FunctionSelectionState::Selection(mut selection) => match other.selection { FunctionSelectionState::Nothing() => other, FunctionSelectionState::Everything() => Self { num_nodes, selection: FunctionSelectionState::Selection(selection), }, FunctionSelectionState::Selection(other) => { for idx in 0..num_nodes { if !other[idx] { selection.set(idx, false); } } Self { num_nodes, selection: FunctionSelectionState::Selection(selection), } } }, } } pub fn difference(self, other: Self) -> Self { let num_nodes = self.num_nodes; match self.selection { FunctionSelectionState::Nothing() => self, FunctionSelectionState::Everything() => match other.selection { FunctionSelectionState::Nothing() => self, FunctionSelectionState::Everything() => Self { num_nodes, selection: FunctionSelectionState::Nothing(), }, FunctionSelectionState::Selection(other) => Self { num_nodes, selection: FunctionSelectionState::Selection(other.not()), }, }, FunctionSelectionState::Selection(mut selection) => match other.selection { FunctionSelectionState::Nothing() => Self { num_nodes, selection: FunctionSelectionState::Selection(selection), }, FunctionSelectionState::Everything() => Self { num_nodes, selection: FunctionSelectionState::Nothing(), }, FunctionSelectionState::Selection(other) => { for idx in 0..num_nodes { if 
other[idx] { selection.set(idx, false); } } Self { num_nodes, selection: FunctionSelectionState::Selection(selection), } } }, } } pub fn as_nodes(&self) -> Vec<NodeID> { match &self.selection { FunctionSelectionState::Nothing() => vec![], FunctionSelectionState::Everything() => (0..self.num_nodes).map(NodeID::new).collect(), FunctionSelectionState::Selection(selection) => (0..self.num_nodes) .map(NodeID::new) .filter(|node| selection[node.idx()]) .collect(), } } } #[derive(Debug, Clone)] pub struct ModuleSelection { selection: Vec<FunctionSelection>, } impl ModuleSelection { pub fn new(pm: &PassManager) -> Self { ModuleSelection { selection: (0..pm.functions.len()) .map(|idx| FunctionSelection::new(pm.functions[idx].nodes.len())) .collect(), } } pub fn is_everything(&self) -> bool { self.selection.iter().all(|func| func.is_everything()) } pub fn add_everything(&mut self) { self.selection .iter_mut() .for_each(|func| func.add_everything()); } pub fn add_function(&mut self, func: FunctionID) { self.selection[func.idx()].add_everything(); } pub fn add_label(&mut self, func: FunctionID, label: LabelID, pm: &PassManager) { pm.functions[func.idx()] .labels .iter() .enumerate() .for_each(|(node_idx, labels)| { if labels.contains(&label) { self.selection[func.idx()].add_node(NodeID::new(node_idx)); } else { } }); } pub fn add(&mut self, other: Self) { self.selection .iter_mut() .zip(other.selection.into_iter()) .for_each(|(this, other)| this.add(other)); } pub fn union(mut self, other: Self) -> Self { self.add(other); self } pub fn intersection(self, other: Self) -> Self { Self { selection: self .selection .into_iter() .zip(other.selection.into_iter()) .map(|(this, other)| this.intersection(other)) .collect(), } } pub fn difference(self, other: Self) -> Self { Self { selection: self .selection .into_iter() .zip(other.selection.into_iter()) .map(|(this, other)| this.difference(other)) .collect(), } } pub fn as_funcs(&self) -> Result<Vec<FunctionID>, SchedulerError> { let mut res = vec![]; for (func_id, func_selection) in self.selection.iter().enumerate() { match func_selection.selection { FunctionSelectionState::Nothing() => (), FunctionSelectionState::Everything() => res.push(FunctionID::new(func_id)), _ => { return Err(SchedulerError::SemanticError( "Expected selection of functions, found fine-grain selection".to_string(), )) } } } Ok(res) } pub fn as_nodes(&self) -> Vec<(FunctionID, NodeID)> { let mut res = vec![]; for (func_id, func_selection) in self.selection.iter().enumerate() { for node_id in func_selection.as_nodes() { res.push((FunctionID::new(func_id), node_id)); } } res } pub fn as_func_states(&self) -> Vec<&FunctionSelectionState> { self.selection .iter() .map(|selection| &selection.selection) .collect() } } #[derive(Debug, Clone)] pub enum Value { Label { labels: Vec<LabelInfo>, }, JunoFunction { func: JunoFunctionID, }, HerculesFunction { func: FunctionID, }, Record { fields: HashMap<String, Value>, }, Tuple { values: Vec<Value>, }, Everything {}, Selection { selection: Vec<Value>, }, SetOp { op: parser::SetOp, lhs: Box<Value>, rhs: Box<Value>, }, Integer { val: usize, }, Boolean { val: bool, }, String { val: String, }, } impl Value { fn is_everything(&self) -> bool { match self { Value::Everything {} => true, _ => false, } } fn as_selection( &self, pm: &PassManager, funcs: &JunoFunctions, ) -> Result<ModuleSelection, SchedulerError> { let mut selection = ModuleSelection::new(pm); self.add_to_selection(pm, funcs, &mut selection)?; Ok(selection) } fn add_to_selection( &self, pm: 
&PassManager, funcs: &JunoFunctions, selection: &mut ModuleSelection, ) -> Result<(), SchedulerError> { match self { Value::Label { labels } => { for LabelInfo { func, label } in labels { selection.add_label(*func, *label, pm); } } Value::JunoFunction { func } => { for func in funcs.get_function(*func) { selection.add_function(*func); } } Value::HerculesFunction { func } => selection.add_function(*func), Value::Everything {} => selection.add_everything(), Value::Selection { selection: values } => { for value in values { value.add_to_selection(pm, funcs, selection)?; } } Value::SetOp { op, lhs, rhs } => { let lhs = lhs.as_selection(pm, funcs)?; let rhs = rhs.as_selection(pm, funcs)?; let res = match op { parser::SetOp::Union => lhs.union(rhs), parser::SetOp::Intersection => lhs.intersection(rhs), parser::SetOp::Difference => lhs.difference(rhs), }; selection.add(res); } Value::Record { .. } => { return Err(SchedulerError::SemanticError( "Expected code selection, found record".to_string(), )); } Value::Tuple { .. } => { return Err(SchedulerError::SemanticError( "Expected code selection, found tuple".to_string(), )); } Value::Integer { .. } => { return Err(SchedulerError::SemanticError( "Expected code selection, found integer".to_string(), )); } Value::Boolean { .. } => { return Err(SchedulerError::SemanticError( "Expected code selection, found boolean".to_string(), )); } Value::String { .. } => { return Err(SchedulerError::SemanticError( "Expected code selection, found string".to_string(), )); } } Ok(()) } } #[derive(Debug, Clone, Serialize, Deserialize)] pub enum SchedulerError { UndefinedVariable(String), UndefinedField(String), UndefinedLabel(String), SemanticError(String), PassError { pass: String, error: String }, FixpointFailure(), } impl fmt::Display for SchedulerError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { SchedulerError::UndefinedVariable(nm) => write!(f, "Undefined variable '{}'", nm), SchedulerError::UndefinedField(nm) => write!(f, "No field '{}'", nm), SchedulerError::UndefinedLabel(nm) => write!(f, "No label '{}'", nm), SchedulerError::SemanticError(msg) => write!(f, "{}", msg), SchedulerError::PassError { pass, error } => { write!(f, "Error in pass {}: {}", pass, error) } SchedulerError::FixpointFailure() => { write!(f, "Fixpoint did not converge within limit") } } } } #[derive(Debug, Clone)] pub struct PassManager { functions: Vec<Function>, types: RefCell<Vec<Type>>, constants: RefCell<Vec<Constant>>, dynamic_constants: RefCell<Vec<DynamicConstant>>, labels: RefCell<Vec<String>>, // Cached analysis results. 
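// Each of these results is computed on demand by the matching `make_*` method below,
// which is a no-op if the result is already cached, and all of them are dropped by
// `clear_analyses` once a pass has mutated the IR.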
pub def_uses: Option<Vec<ImmutableDefUseMap>>, pub reverse_postorders: Option<Vec<Vec<NodeID>>>, pub typing: Option<ModuleTyping>, pub control_subgraphs: Option<Vec<Subgraph>>, pub doms: Option<Vec<DomTree>>, pub postdoms: Option<Vec<DomTree>>, pub fork_join_maps: Option<Vec<HashMap<NodeID, NodeID>>>, pub fork_join_nests: Option<Vec<HashMap<NodeID, Vec<NodeID>>>>, pub fork_control_maps: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, pub fork_trees: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, pub loops: Option<Vec<LoopTree>>, pub reduce_cycles: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, pub nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, pub reduce_einsums: Option<Vec<(MathEnv, HashMap<NodeID, MathID>)>>, pub no_reset_constants: Option<Vec<BTreeSet<NodeID>>>, pub collection_objects: Option<CollectionObjects>, pub callgraph: Option<CallGraph>, pub devices: Option<Vec<Device>>, pub bbs: Option<Vec<BasicBlocks>>, pub node_colors: Option<NodeColors>, pub backing_allocations: Option<BackingAllocations>, } impl PassManager { pub fn new(module: Module) -> Self { let Module { functions, types, constants, dynamic_constants, labels, } = module; PassManager { functions, types: RefCell::new(types), constants: RefCell::new(constants), dynamic_constants: RefCell::new(dynamic_constants), labels: RefCell::new(labels), def_uses: None, reverse_postorders: None, typing: None, control_subgraphs: None, doms: None, postdoms: None, fork_join_maps: None, fork_join_nests: None, fork_control_maps: None, fork_trees: None, loops: None, reduce_cycles: None, nodes_in_fork_joins: None, reduce_einsums: None, no_reset_constants: None, collection_objects: None, callgraph: None, devices: None, bbs: None, node_colors: None, backing_allocations: None, } } pub fn make_def_uses(&mut self) { if self.def_uses.is_none() { self.def_uses = Some(self.functions.iter().map(def_use).collect()); } } pub fn make_reverse_postorders(&mut self) { if self.reverse_postorders.is_none() { self.make_def_uses(); self.reverse_postorders = Some( self.def_uses .as_ref() .unwrap() .iter() .map(reverse_postorder) .collect(), ); } } pub fn make_typing(&mut self) { if self.typing.is_none() { self.make_reverse_postorders(); self.typing = Some( typecheck( &self.functions, &mut self.types.borrow_mut(), &self.constants.borrow(), &mut self.dynamic_constants.borrow_mut(), self.reverse_postorders.as_ref().unwrap(), ) .unwrap(), ); } } pub fn make_control_subgraphs(&mut self) { if self.control_subgraphs.is_none() { self.make_def_uses(); self.control_subgraphs = Some( zip(&self.functions, self.def_uses.as_ref().unwrap()) .map(|(function, def_use)| control_subgraph(function, def_use)) .collect(), ); } } pub fn make_doms(&mut self) { if self.doms.is_none() { self.make_control_subgraphs(); self.doms = Some( self.control_subgraphs .as_ref() .unwrap() .iter() .map(|subgraph| dominator(subgraph, NodeID::new(0))) .collect(), ); } } pub fn make_postdoms(&mut self) { if self.postdoms.is_none() { self.make_control_subgraphs(); self.postdoms = Some( zip( self.control_subgraphs.as_ref().unwrap().iter(), self.functions.iter(), ) .map(|(subgraph, function)| dominator(subgraph, NodeID::new(function.nodes.len()))) .collect(), ); } } pub fn make_fork_join_maps(&mut self) { if self.fork_join_maps.is_none() { self.make_control_subgraphs(); self.fork_join_maps = Some( zip( self.functions.iter(), self.control_subgraphs.as_ref().unwrap().iter(), ) .map(|(function, subgraph)| fork_join_map(function, subgraph)) .collect(), ); } } pub fn 
make_fork_join_nests(&mut self) { if self.fork_join_nests.is_none() { self.make_doms(); self.make_fork_join_maps(); self.fork_join_nests = Some( zip( self.functions.iter(), zip( self.doms.as_ref().unwrap().iter(), self.fork_join_maps.as_ref().unwrap().iter(), ), ) .map(|(function, (dom, fork_join_map))| { compute_fork_join_nesting(function, dom, fork_join_map) }) .collect(), ); } } pub fn make_fork_control_maps(&mut self) { if self.fork_control_maps.is_none() { self.make_fork_join_nests(); self.fork_control_maps = Some( self.fork_join_nests .as_ref() .unwrap() .iter() .map(fork_control_map) .collect(), ); } } pub fn make_fork_trees(&mut self) { if self.fork_trees.is_none() { self.make_fork_join_nests(); self.fork_trees = Some( zip( self.functions.iter(), self.fork_join_nests.as_ref().unwrap().iter(), ) .map(|(function, fork_join_nesting)| fork_tree(function, fork_join_nesting)) .collect(), ); } } pub fn make_loops(&mut self) { if self.loops.is_none() { self.make_control_subgraphs(); self.make_doms(); self.make_fork_join_maps(); let control_subgraphs = self.control_subgraphs.as_ref().unwrap().iter(); let doms = self.doms.as_ref().unwrap().iter(); let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter(); self.loops = Some( zip(control_subgraphs, zip(doms, fork_join_maps)) .map(|(control_subgraph, (dom, fork_join_map))| { loops(control_subgraph, NodeID::new(0), dom, fork_join_map) }) .collect(), ); } } pub fn make_reduce_cycles(&mut self) { if self.reduce_cycles.is_none() { self.make_def_uses(); self.make_fork_join_maps(); self.make_fork_join_nests(); let def_uses = self.def_uses.as_ref().unwrap().iter(); let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter(); let fork_join_nests = self.fork_join_nests.as_ref().unwrap().iter(); self.reduce_cycles = Some( self.functions .iter() .zip(fork_join_maps) .zip(fork_join_nests) .zip(def_uses) .map(|(((function, fork_join_map), fork_join_nests), def_use)| { reduce_cycles(function, def_use, fork_join_map, fork_join_nests) }) .collect(), ); } } pub fn make_nodes_in_fork_joins(&mut self) { if self.nodes_in_fork_joins.is_none() { self.make_def_uses(); self.make_fork_join_maps(); self.make_reduce_cycles(); self.nodes_in_fork_joins = Some( zip( self.functions.iter(), zip( self.def_uses.as_ref().unwrap().iter(), zip( self.fork_join_maps.as_ref().unwrap().iter(), self.reduce_cycles.as_ref().unwrap().iter(), ), ), ) .map(|(function, (def_use, (fork_join_map, reduce_cycles)))| { nodes_in_fork_joins(function, def_use, fork_join_map, reduce_cycles) }) .collect(), ); } } pub fn make_reduce_einsums(&mut self) { if self.reduce_einsums.is_none() { self.make_def_uses(); self.make_typing(); self.make_fork_join_maps(); self.make_fork_join_nests(); self.make_nodes_in_fork_joins(); let def_uses = self.def_uses.as_ref().unwrap().iter(); let typing = self.typing.as_ref().unwrap().iter(); let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter(); let fork_join_nests = self.fork_join_nests.as_ref().unwrap().iter(); let nodes_in_fork_joins = self.nodes_in_fork_joins.as_ref().unwrap().iter(); self.reduce_einsums = Some( self.functions .iter() .zip(def_uses) .zip(typing) .zip(fork_join_maps) .zip(fork_join_nests) .zip(nodes_in_fork_joins) .map( |( ((((function, def_use), typing), fork_join_map), fork_join_nest), nodes_in_fork_joins, )| { einsum( function, &self.types.borrow(), &self.constants.borrow(), def_use, typing, fork_join_map, fork_join_nest, nodes_in_fork_joins, ) }, ) .collect(), ); } } pub fn make_no_reset_constants(&mut self) { if 
self.no_reset_constants.is_none() { self.make_reverse_postorders(); self.make_typing(); self.make_collection_objects(); self.make_reduce_einsums(); let reverse_postorders = self.reverse_postorders.as_ref().unwrap().iter(); let typing = self.typing.as_ref().unwrap().iter(); let collection_objects = self.collection_objects.as_ref().unwrap().iter(); let reduce_einsums = self.reduce_einsums.as_ref().unwrap().iter(); self.no_reset_constants = Some( self.functions .iter() .zip(reverse_postorders) .zip(typing) .zip(collection_objects) .zip(reduce_einsums) .map( |( (((function, reverse_postorder), typing), collection_object), reduce_einsum, )| { no_reset_constant_collections( function, &self.types.borrow(), reverse_postorder, typing, collection_object.1, reduce_einsum, ) }, ) .collect(), ); } } pub fn make_collection_objects(&mut self) { if self.collection_objects.is_none() { self.make_reverse_postorders(); self.make_typing(); self.make_callgraph(); let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); let typing = self.typing.as_ref().unwrap(); let callgraph = self.callgraph.as_ref().unwrap(); self.collection_objects = Some(collection_objects( &self.functions, &self.types.borrow(), reverse_postorders, typing, callgraph, )); } } pub fn make_callgraph(&mut self) { if self.callgraph.is_none() { self.callgraph = Some(callgraph(&self.functions)); } } pub fn make_devices(&mut self) { if self.devices.is_none() { self.make_callgraph(); let callgraph = self.callgraph.as_ref().unwrap(); self.devices = Some(device_placement(&self.functions, callgraph)); } } pub fn delete_gravestones(&mut self) { for func in self.functions.iter_mut() { func.delete_gravestones(); } } pub fn fix_deleted_functions(&mut self, id_mapping: &[Option<usize>]) { let mut idx = 0; self.functions.retain(|_| { idx += 1; id_mapping[idx - 1].is_some() }); } fn clear_analyses(&mut self) { self.def_uses = None; self.reverse_postorders = None; self.typing = None; self.control_subgraphs = None; self.doms = None; self.postdoms = None; self.fork_join_maps = None; self.fork_join_nests = None; self.fork_control_maps = None; self.fork_trees = None; self.loops = None; self.reduce_cycles = None; self.nodes_in_fork_joins = None; self.reduce_einsums = None; self.no_reset_constants = None; self.collection_objects = None; self.callgraph = None; self.devices = None; self.bbs = None; self.node_colors = None; self.backing_allocations = None; } fn with_mod<B, F>(&mut self, mut f: F) -> B where F: FnMut(&mut Module) -> B, { let mut module = Module { functions: std::mem::take(&mut self.functions), types: self.types.take(), constants: self.constants.take(), dynamic_constants: self.dynamic_constants.take(), labels: self.labels.take(), }; let res = f(&mut module); let Module { functions, types, constants, dynamic_constants, labels, } = module; self.functions = functions; self.types.replace(types); self.constants.replace(constants); self.dynamic_constants.replace(dynamic_constants); self.labels.replace(labels); res } pub fn get_module(&self) -> Module { let PassManager { functions, types, constants, dynamic_constants, labels, typing: _, control_subgraphs: _, bbs: _, collection_objects: _, callgraph: _, .. 
} = self; let module = Module { functions: functions.to_vec(), types: types.clone().into_inner(), constants: constants.clone().into_inner(), dynamic_constants: dynamic_constants.clone().into_inner(), labels: labels.clone().into_inner(), }; module } fn codegen(mut self, output_dir: String, module_name: String) -> Result<(), SchedulerError> { self.make_def_uses(); self.make_typing(); self.make_control_subgraphs(); self.make_fork_join_maps(); self.make_fork_join_nests(); self.make_fork_control_maps(); self.make_fork_trees(); self.make_nodes_in_fork_joins(); self.make_collection_objects(); self.make_callgraph(); self.make_devices(); let PassManager { functions, types, constants, dynamic_constants, labels, def_uses: Some(def_uses), typing: Some(typing), control_subgraphs: Some(control_subgraphs), fork_join_maps: Some(fork_join_maps), fork_join_nests: Some(fork_join_nests), fork_control_maps: Some(fork_control_maps), fork_trees: Some(fork_trees), nodes_in_fork_joins: Some(nodes_in_fork_joins), collection_objects: Some(collection_objects), callgraph: Some(callgraph), devices: Some(devices), bbs: Some(bbs), node_colors: Some(node_colors), backing_allocations: Some(backing_allocations), .. } = self else { return Err(SchedulerError::PassError { pass: "codegen".to_string(), error: "Missing basic blocks or backing allocations".to_string(), }); }; let module = Module { functions, types: types.into_inner(), constants: constants.into_inner(), dynamic_constants: dynamic_constants.into_inner(), labels: labels.into_inner(), }; let mut rust_rt = String::new(); let mut llvm_ir = String::new(); let mut cuda_ir = String::new(); for idx in 0..module.functions.len() { match devices[idx] { Device::LLVM => cpu_codegen( &module_name, &module.functions[idx], &module.types, &module.constants, &module.dynamic_constants, &typing[idx], &control_subgraphs[idx], &bbs[idx], &backing_allocations[&FunctionID::new(idx)], &mut llvm_ir, ) .map_err(|e| SchedulerError::PassError { pass: "cpu codegen".to_string(), error: format!("{}", e), })?, Device::CUDA => gpu_codegen( &module_name, &module.functions[idx], &module.types, &module.constants, &module.dynamic_constants, &typing[idx], &control_subgraphs[idx], &bbs[idx], &backing_allocations[&FunctionID::new(idx)], &collection_objects[&FunctionID::new(idx)], &def_uses[idx], &fork_join_maps[idx], &fork_control_maps[idx], &fork_trees[idx], &mut cuda_ir, ) .map_err(|e| SchedulerError::PassError { pass: "cuda codegen".to_string(), error: format!("{}", e), })?, Device::AsyncRust => rt_codegen( &module_name, FunctionID::new(idx), &module, &def_uses[idx], &typing[idx], &control_subgraphs[idx], &fork_join_maps[idx], &fork_join_nests[idx], &fork_trees[idx], &nodes_in_fork_joins[idx], &collection_objects, &callgraph, &devices, &bbs[idx], &node_colors[&FunctionID::new(idx)], &backing_allocations, &mut rust_rt, ) .map_err(|e| SchedulerError::PassError { pass: "rust codegen".to_string(), error: format!("{}", e), })?, } } println!("{}", llvm_ir); println!("{}", cuda_ir); let rust_rt = prettyplease::unparse( &syn::parse_file(&rust_rt) .expect(&format!("PANIC: Malformed RT Rust code: {}", rust_rt)), ); println!("{}", rust_rt); // Write the Rust runtime into a file. 
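// The runtime goes to `<output_dir>/rt_<module_name>.hrt`; the object code produced
// below is archived into `<output_dir>/lib<module_name>.a`.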
let output_rt = format!("{}/rt_{}.hrt", output_dir, module_name); println!("{}", output_rt); let mut file = File::create(&output_rt).expect("PANIC: Unable to open output Rust runtime file."); file.write_all(rust_rt.as_bytes()) .expect("PANIC: Unable to write output Rust runtime file contents."); let output_archive = format!("{}/lib{}.a", output_dir, module_name); println!("{}", output_archive);
// Write the LLVM IR into a temporary file.
let tmp_dir = TempDir::new().unwrap(); let mut llvm_path = tmp_dir.path().to_path_buf(); llvm_path.push(format!("{}.ll", module_name)); println!("{}", llvm_path.display()); let mut file = File::create(&llvm_path).expect("PANIC: Unable to open output LLVM IR file."); file.write_all(llvm_ir.as_bytes()) .expect("PANIC: Unable to write output LLVM IR file contents.");
// Compile LLVM IR into an ELF object file.
let llvm_object = format!("{}/{}_cpu.o", tmp_dir.path().to_str().unwrap(), module_name); let mut clang_process = Command::new("clang") .arg(&llvm_path) .arg("-c") .arg("-O3") .arg("-ffast-math") .arg("-march=native") .arg("-o") .arg(&llvm_object) .stdin(Stdio::piped()) .stdout(Stdio::piped()) .spawn() .expect("PANIC: Error running clang. Is it installed?"); if clang_process .wait() .map(|status| !status.success()) .unwrap_or(false) { let path = tmp_dir.into_path(); panic!("PANIC: Clang failed to compile the LLVM IR module. Persisting temporary directory ({}).", path.display()); } let mut ar_args = vec!["crus", &output_archive, &llvm_object]; let cuda_object = format!( "{}/{}_cuda.o", tmp_dir.path().to_str().unwrap(), module_name ); if cfg!(feature = "cuda") {
// Write the CUDA IR into a temporary file.
let mut cuda_path = tmp_dir.path().to_path_buf(); cuda_path.push(format!("{}.cu", module_name)); let mut file = File::create(&cuda_path).expect("PANIC: Unable to open output CUDA IR file."); file.write_all(cuda_ir.as_bytes()) .expect("PANIC: Unable to write output CUDA IR file contents."); let mut nvcc_process = Command::new("nvcc") .arg("-c") .arg("-Xptxas") .arg("-O3") .arg("-use_fast_math") .arg("-diag-suppress") .arg("177") .arg("-o") .arg(&cuda_object) .arg(&cuda_path) .spawn() .expect("PANIC: Error running NVCC. Is it installed?"); if nvcc_process .wait() .map(|status| !status.success()) .unwrap_or(false) { let path = tmp_dir.into_path(); panic!("PANIC: NVCC failed to compile the CUDA module. Persisting temporary directory ({}).", path.display()); } ar_args.push(&cuda_object); } let mut ar_process = Command::new("ar") .args(&ar_args) .spawn() .expect("Error running ar. Is it installed?"); if ar_process .wait() .map(|status| !status.success()) .unwrap_or(false) { let path = tmp_dir.into_path(); panic!( "PANIC: Ar failed to create a static library. Persisting temporary directory ({}).", path.display() ); } Ok(()) } } pub fn schedule_codegen( module: Module, schedule: ScheduleStmt, mut stringtab: StringTable, mut env: Env<usize, Value>, mut functions: JunoFunctions, output_dir: String, module_name: String, ) -> Result<(), SchedulerError> { let mut pm = PassManager::new(module); let _ = schedule_interpret(&mut pm, &schedule, &mut stringtab, &mut env, &mut functions)?; pm.codegen(output_dir, module_name) } pub fn schedule_module( module: Module, schedule: ScheduleStmt, mut stringtab: StringTable, mut env: Env<usize, Value>, mut functions: JunoFunctions, ) -> Result<Module, SchedulerError> { let mut pm = PassManager::new(module); let _ = schedule_interpret(&mut pm, &schedule, &mut stringtab, &mut env, &mut functions)?; Ok(pm.get_module()) }
// Interpreter for statements and expressions returns a bool indicating whether
// any optimization ran and changed the IR. This is used for implementing
// the fixpoint
fn schedule_interpret( pm: &mut PassManager, schedule: &ScheduleStmt, stringtab: &mut StringTable, env: &mut Env<usize, Value>, functions: &mut JunoFunctions, ) -> Result<bool, SchedulerError> { match schedule { ScheduleStmt::Fixpoint { body, limit } => { let mut i = 0; loop {
// If no change was made, we've reached the fixpoint and are done
if !schedule_interpret(pm, body, stringtab, env, functions)? { break; }
// Otherwise, increase the iteration count and check the limit
i += 1; match limit { FixpointLimit::NoLimit() => {} FixpointLimit::PrintIter() => { println!("Finished Iteration {}", i - 1) } FixpointLimit::StopAfter(n) => { if i >= *n { break; } } FixpointLimit::PanicAfter(n) => { if i >= *n { return Err(SchedulerError::FixpointFailure()); } } } }
// If we ran just 1 iteration, then no changes were made and otherwise some changes
// were made
Ok(i > 1) } ScheduleStmt::IfThenElse { cond, thn, els } => { let (cond, modified) = interp_expr(pm, cond, stringtab, env, functions)?; let Value::Boolean { val: cond } = cond else { return Err(SchedulerError::SemanticError( "Condition must be a boolean value".to_string(), )); }; let changed = schedule_interpret( pm, if cond { &*thn } else { &*els }, stringtab, env, functions, )?; Ok(modified || changed) } ScheduleStmt::Block { body } => { let mut modified = false; env.open_scope(); for command in body { modified |= schedule_interpret(pm, command, stringtab, env, functions)?; } env.close_scope(); Ok(modified) } ScheduleStmt::Let { var, exp } => { let (res, modified) = interp_expr(pm, exp, stringtab, env, functions)?; let var_id = stringtab.lookup_string(var.clone()); env.insert(var_id, res); Ok(modified) } ScheduleStmt::Assign { var, exp } => { let (res, modified) = interp_expr(pm, exp, stringtab, env, functions)?; let var_id = stringtab.lookup_string(var.clone()); match env.lookup_mut(&var_id) { None => { return Err(SchedulerError::UndefinedVariable(var.clone())); } Some(val) => { *val = res; } } Ok(modified) } ScheduleStmt::AddSchedule { sched, on } => match on { Selector::Everything() => Err(SchedulerError::SemanticError( "Cannot apply schedule to everything".to_string(), )), Selector::Selection(selection) => { let mut changed = false; for label in selection { let (label, modified) = interp_expr(pm, label, stringtab, env, functions)?; changed |= modified; add_schedule(pm, sched.clone(), label.as_selection(pm, functions)?); } Ok(changed) } }, ScheduleStmt::AddDevice { device, on } => match on { Selector::Everything() => { for func in pm.functions.iter_mut() { func.device =
Some(device.clone()); } Ok(false) } Selector::Selection(selection) => { let mut changed = false; for func in selection { let (func, modified) = interp_expr(pm, func, stringtab, env, functions)?; changed |= modified; add_device( pm, device.clone(), func.as_selection(pm, functions)?.as_funcs()?, ); } Ok(changed) } }, } } fn interp_expr( pm: &mut PassManager, expr: &ScheduleExp, stringtab: &mut StringTable, env: &mut Env<usize, Value>, functions: &mut JunoFunctions, ) -> Result<(Value, bool), SchedulerError> { match expr { ScheduleExp::Variable { var } => { let var_id = stringtab.lookup_string(var.clone()); match env.lookup(&var_id) { None => Err(SchedulerError::UndefinedVariable(var.clone())), Some(v) => Ok((v.clone(), false)), } } ScheduleExp::Integer { val } => Ok((Value::Integer { val: *val }, false)), ScheduleExp::Boolean { val } => Ok((Value::Boolean { val: *val }, false)), ScheduleExp::String { val } => Ok((Value::String { val: val.clone() }, false)), ScheduleExp::SetOp { op, lhs, rhs } => { let (lhs, lhs_mod) = interp_expr(pm, lhs, stringtab, env, functions)?; let (rhs, rhs_mod) = interp_expr(pm, rhs, stringtab, env, functions)?; Ok(( Value::SetOp { op: *op, lhs: Box::new(lhs), rhs: Box::new(rhs), }, lhs_mod || rhs_mod, )) } ScheduleExp::Field { collect, field } => { let (lhs, changed) = interp_expr(pm, collect, stringtab, env, functions)?; match lhs { Value::Label { .. } | Value::Selection { .. } | Value::Everything { .. } | Value::Integer { .. } | Value::Boolean { .. } | Value::String { .. } | Value::Tuple { .. } | Value::SetOp { .. } => Err(SchedulerError::UndefinedField(field.clone())), Value::JunoFunction { func } => { match pm.labels.borrow().iter().position(|s| s == field) { None => Err(SchedulerError::UndefinedLabel(field.clone())), Some(label_idx) => Ok(( Value::Label { labels: functions .get_function(func) .iter() .map(|f| LabelInfo { func: *f, label: LabelID::new(label_idx), }) .collect(), }, changed, )), } } Value::HerculesFunction { func } => { match pm.labels.borrow().iter().position(|s| s == field) { None => Err(SchedulerError::UndefinedLabel(field.clone())), Some(label_idx) => Ok(( Value::Label { labels: vec![LabelInfo { func: func, label: LabelID::new(label_idx), }], }, changed, )), } } Value::Record { fields } => match fields.get(field) { None => Err(SchedulerError::UndefinedField(field.clone())), Some(v) => Ok((v.clone(), changed)), }, } } ScheduleExp::RunPass { pass, args, on } => { let mut changed = false; let mut arg_vals = vec![]; for arg in args { let (val, modified) = interp_expr(pm, arg, stringtab, env, functions)?; arg_vals.push(val); changed |= modified; } let selection = match on { Selector::Everything() => Value::Everything {}, Selector::Selection(selection) => { let mut vals = vec![]; for loc in selection { let (val, modified) = interp_expr(pm, loc, stringtab, env, functions)?; changed |= modified; vals.push(val); } Value::Selection { selection: vals } } } .as_selection(pm, functions)?; let (res, modified) = run_pass(pm, *pass, arg_vals, selection)?; changed |= modified; Ok((res, changed)) } ScheduleExp::DeleteUncalled { on } => { let Selector::Everything() = on else { return Err(SchedulerError::PassError { pass: "DeleteUncalled".to_string(), error: "must be applied to the entire module".to_string(), }); }; pm.make_callgraph(); pm.make_def_uses(); let callgraph = pm.callgraph.take().unwrap(); let def_uses = pm.def_uses.take().unwrap(); let mut editors: Vec<_> = pm .functions .iter_mut() .enumerate() .zip(def_uses.iter()) .map(|((idx, func), def_use)| { 
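// Build a mutable editor for every function, since delete_uncalled may edit any function in the module.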
FunctionEditor::new( func, FunctionID::new(idx), &pm.constants, &pm.dynamic_constants, &pm.types, &pm.labels, def_use, ) }) .collect(); let new_idx = delete_uncalled(&mut editors, &callgraph); let changed = new_idx.iter().any(|i| i.is_none()); pm.fix_deleted_functions(&new_idx); pm.delete_gravestones(); pm.clear_analyses(); assert!(pm.functions.len() > 0, "PANIC: There are no entry functions in the Hercules module being compiled. Please mark at least one function as an entry!");
// Update all FunctionIDs contained in both the environment and
// "functions" data structure to point to the new values. If there
// is no new value (the function referred to no longer exists) then
// we drop the value from the environment/functions list
// Updating Juno functions may result in all instances of a function being deleted
// which can cause renumbering of the Juno functions as well, so we do that first
let mut new_juno_idx = vec![]; let mut new_juno_funcs = vec![]; for funcs in std::mem::take(&mut functions.func_ids).into_iter() { let new_funcs = funcs .into_iter() .filter_map(|f| new_idx[f.idx()].map(|i| FunctionID::new(i))) .collect::<Vec<_>>(); if !new_funcs.is_empty() { new_juno_idx.push(Some(new_juno_funcs.len())); new_juno_funcs.push(new_funcs); } else { new_juno_idx.push(None); } } functions.func_ids = new_juno_funcs;
// Now, we update both the FunctionIDs and JunoFunctionIDs in the environment
env.filter_map(|val| update_value(val, &new_idx, &new_juno_idx)); Ok(( Value::Record { fields: HashMap::new(), }, changed, )) } ScheduleExp::Feature { feature } => { let (feature, modified) = interp_expr(pm, &*feature, stringtab, env, functions)?; let Value::String { val } = feature else { return Err(SchedulerError::SemanticError( "Feature expects a single string argument (instead of a selection)".to_string(), )); };
// To test for features, the scheduler needs to be invoked from a build script so that
// Cargo provides the enabled features via environment variables
let key = "CARGO_FEATURE_".to_string() + &val.to_uppercase().replace("-", "_"); Ok(( Value::Boolean { val: env::var(key).is_ok(), }, modified, )) } ScheduleExp::Record { fields } => { let mut result = HashMap::new(); let mut changed = false; for (field, val) in fields { let (val, modified) = interp_expr(pm, val, stringtab, env, functions)?; result.insert(field.clone(), val); changed |= modified; } Ok((Value::Record { fields: result }, changed)) } ScheduleExp::Block { body, res } => { let mut changed = false; env.open_scope(); for command in body { changed |= schedule_interpret(pm, command, stringtab, env, functions)?; } let (res, modified) = interp_expr(pm, res, stringtab, env, functions)?; env.close_scope(); Ok((res, changed || modified)) } ScheduleExp::Selection { selection } => match selection { Selector::Everything() => Ok((Value::Everything {}, false)), Selector::Selection(selection) => { let mut values = vec![]; let mut changed = false; for e in selection { let (val, modified) = interp_expr(pm, e, stringtab, env, functions)?; values.push(val); changed |= modified; } Ok((Value::Selection { selection: values }, changed)) } }, ScheduleExp::Tuple { exprs } => { let mut vals = vec![]; let mut changed = false; for exp in exprs { let (val, change) = interp_expr(pm, exp, stringtab, env, functions)?; vals.push(val); changed = changed || change; } Ok((Value::Tuple { values: vals }, changed)) } ScheduleExp::TupleField { lhs, field } => { let (val, changed) = interp_expr(pm, lhs, stringtab, env, functions)?; match val { Value::Tuple { values } if *field < values.len() => { Ok((vec_take(values, *field), changed)) } _ => Err(SchedulerError::SemanticError(format!( "No field at index {}", field ))), } } } } fn update_value( val: Value, func_idx: &[Option<usize>], juno_func_idx: &[Option<usize>], ) -> Option<Value> { match val {
// For a label (which may refer to labels in multiple functions) we update our labels to
// point to the new functions (and eliminate any which referred to now-gone functions). If
// there are no labels left, remove this value since it refers to nothing
Value::Label { labels } => { let new_labels = labels .into_iter() .filter_map(|LabelInfo { func, label }| { func_idx[func.idx()].clone().map(|i| LabelInfo { func: FunctionID::new(i), label, }) }) .collect::<Vec<_>>(); if new_labels.is_empty() { None } else { Some(Value::Label { labels: new_labels }) } }
// Similar approach for selections, update each value and if nothing remains just drop the
// whole value
Value::Selection { selection } => { let new_selection = selection .into_iter() .filter_map(|v| update_value(v, func_idx, juno_func_idx)) .collect::<Vec<_>>(); if new_selection.is_empty() { None } else { Some(Value::Selection { selection: new_selection, }) } }
// And similarly for records (this one might seem a little odd, but it means that if an
// optimization returned data for multiple functions we'll delete the fields that referred
// to those functions but keep around fields that still hold useful values)
Value::Record { fields } => { let new_fields = fields .into_iter() .filter_map(|(f, v)| update_value(v, func_idx, juno_func_idx).map(|v| (f, v))) .collect::<HashMap<_, _>>(); if new_fields.is_empty() { None } else { Some(Value::Record { fields: new_fields }) } }
// For tuples, if we deleted values like we do for records this would mess up the indices
// which would behave very strangely. Instead, if any field cannot be updated then we
// eliminate the entire value
Value::Tuple { values } => values .into_iter() .map(|v| update_value(v, func_idx, juno_func_idx)) .collect::<Option<Vec<_>>>() .map(|values| Value::Tuple { values }), Value::JunoFunction { func } => { juno_func_idx[func.idx] .clone() .map(|i| Value::JunoFunction { func: JunoFunctionID::new(i), }) } Value::HerculesFunction { func } => { func_idx[func.idx()] .clone() .map(|i| Value::HerculesFunction { func: FunctionID::new(i), }) } Value::SetOp { op, lhs, rhs } => { let lhs = update_value(*lhs, func_idx, juno_func_idx)?; let rhs = update_value(*rhs, func_idx, juno_func_idx)?; Some(Value::SetOp { op, lhs: Box::new(lhs), rhs: Box::new(rhs), }) } Value::Everything {} => Some(Value::Everything {}), Value::Integer { val } => Some(Value::Integer { val }), Value::Boolean { val } => Some(Value::Boolean { val }), Value::String { val } => Some(Value::String { val }), } } fn add_schedule(pm: &mut PassManager, sched: Schedule, selection: ModuleSelection) { for (func, node) in selection.as_nodes() { pm.functions[func.idx()].schedules[node.idx()].push(sched.clone()); } } fn add_device(pm: &mut PassManager, device: Device, funcs: Vec<FunctionID>) { for func in funcs { pm.functions[func.idx()].device = Some(device.clone()); } }
// Builds (completely mutable) editors for all functions
fn build_editors<'a>(pm: &'a mut PassManager) -> Vec<FunctionEditor<'a>> { pm.make_def_uses(); let def_uses = pm.def_uses.take().unwrap(); pm.functions .iter_mut() .zip(def_uses.iter()) .enumerate() .map(|(idx, (func, def_use))| { FunctionEditor::new( func, FunctionID::new(idx), &pm.constants, &pm.dynamic_constants, &pm.types, &pm.labels, def_use, ) }) .collect() }
// Given a selection, constructs the set of the nodes selected for a single function, returning the
// function's id
fn selection_as_set( pm: &PassManager, selection: ModuleSelection, ) -> Option<(BTreeSet<NodeID>, FunctionID)> { let nodes = selection.as_nodes(); let mut funcs = nodes.iter().map(|(func, _)| *func).collect::<BTreeSet<_>>(); if funcs.len() == 1 { let func = funcs.pop_first().unwrap(); Some((nodes.into_iter().map(|(_, node)| node).collect(), func)) } else { None } } fn build_selection<'a>( pm: &'a mut PassManager, selection: ModuleSelection, create_editors_for_nothing_functions: bool, ) -> Vec<Option<FunctionEditor<'a>>> {
// Build def uses, which are needed for the editors we'll construct
pm.make_def_uses(); let def_uses = pm.def_uses.take().unwrap(); selection .as_func_states() .into_iter() .zip(pm.functions.iter_mut()) .zip(def_uses.iter()) .enumerate() .map(|(idx, ((selection, func), def_use))| match selection { FunctionSelectionState::Nothing() if create_editors_for_nothing_functions => { Some(FunctionEditor::new_immutable( func, FunctionID::new(idx), &pm.constants, &pm.dynamic_constants, &pm.types, &pm.labels, def_use, )) } FunctionSelectionState::Nothing() => None, FunctionSelectionState::Everything() => Some(FunctionEditor::new( func, FunctionID::new(idx), &pm.constants, &pm.dynamic_constants, &pm.types, &pm.labels, def_use, )), FunctionSelectionState::Selection(mask) => Some(FunctionEditor::new_mask( func, FunctionID::new(idx), &pm.constants, &pm.dynamic_constants, &pm.types, &pm.labels, def_use, mask.clone(), )), }) .collect() } fn run_pass( pm: &mut PassManager, pass: Pass, args: Vec<Value>, selection: ModuleSelection, ) -> Result<(Value, bool), SchedulerError> { let mut result = Value::Record { fields: HashMap::new(), }; let mut changed = false; match pass {
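// Most arms below follow the same shape: compute and take the analyses the pass needs,
// build editors for the selected code, run the pass, record whether anything changed,
// and then clear the now-stale analyses.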
Pass::ArraySLF => { assert!(args.is_empty()); pm.make_fork_join_maps(); pm.make_reduce_einsums(); pm.make_nodes_in_fork_joins(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); let reduce_einsums = pm.reduce_einsums.take().unwrap(); let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); for (((func, fork_join_map), reduce_einsum), nodes_in_fork_joins) in build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) .zip(reduce_einsums.iter()) .zip(nodes_in_fork_joins.iter()) { let Some(mut func) = func else { continue; }; array_slf(&mut func, fork_join_map, reduce_einsum, nodes_in_fork_joins); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::ArrayToProduct => { assert!(args.len() <= 1); let max_size = match args.get(0) { Some(Value::Integer { val }) => Some(*val), Some(_) => { return Err(SchedulerError::PassError { pass: "array-to-product".to_string(), error: "expected integer argument".to_string(), }); } None => None, }; pm.make_typing(); let typing = pm.typing.take().unwrap(); for (func, types) in build_selection(pm, selection, false) .into_iter() .zip(typing.iter()) { let Some(mut func) = func else { continue; }; array_to_product(&mut func, types, max_size); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::AutoOutline => { let Some(funcs) = selection.as_funcs().ok() else { return Err(SchedulerError::PassError { pass: "autoOutline".to_string(), error: "must be applied to whole functions".to_string(), }); }; pm.make_def_uses(); let def_uses = pm.def_uses.take().unwrap(); for func in funcs.iter() { let mut editor = FunctionEditor::new( &mut pm.functions[func.idx()], *func, &pm.constants, &pm.dynamic_constants, &pm.types, &pm.labels, &def_uses[func.idx()], ); collapse_returns(&mut editor); ensure_between_control_flow(&mut editor); changed |= editor.modified(); } pm.clear_analyses(); pm.make_def_uses(); pm.make_typing(); pm.make_control_subgraphs(); pm.make_doms(); let def_uses = pm.def_uses.take().unwrap(); let typing = pm.typing.take().unwrap(); let control_subgraphs = pm.control_subgraphs.take().unwrap(); let doms = pm.doms.take().unwrap(); let old_num_funcs = pm.functions.len(); let mut new_funcs = vec![];
// Track the names of the old functions and the new function IDs for returning
let mut new_func_ids = HashMap::new(); for func in funcs { let mut editor = FunctionEditor::new( &mut pm.functions[func.idx()], func, &pm.constants, &pm.dynamic_constants, &pm.types, &pm.labels, &def_uses[func.idx()], ); let new_func_id = FunctionID::new(old_num_funcs + new_funcs.len()); let new_func = dumb_outline( &mut editor, &typing[func.idx()], &control_subgraphs[func.idx()], &doms[func.idx()], new_func_id, ); changed |= editor.modified(); if let Some(new_func) = new_func { new_func_ids.insert( editor.func().name.clone(), Value::HerculesFunction { func: new_func_id }, ); new_funcs.push(new_func); } pm.functions[func.idx()].delete_gravestones(); } pm.functions.extend(new_funcs); pm.clear_analyses(); result = Value::Record { fields: new_func_ids, }; } Pass::CCP => { assert!(args.is_empty()); pm.make_reverse_postorders(); let reverse_postorders = pm.reverse_postorders.take().unwrap(); for (func, reverse_postorder) in build_selection(pm, selection, false) .into_iter() .zip(reverse_postorders.iter()) { let Some(mut func) = func else { continue; }; ccp(&mut func, reverse_postorder); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::CleanMonoidReduces => { assert!(args.is_empty()); pm.make_typing(); let typing = pm.typing.take().unwrap(); for (func, typing) in build_selection(pm, selection, false) .into_iter() .zip(typing.iter()) { let Some(mut func) = func else { continue; }; clean_monoid_reduces(&mut func, typing); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::ConstInline => { let inline_collections = match args.get(0) { Some(Value::Boolean { val }) => *val, Some(_) => { return Err(SchedulerError::PassError { pass: "constInline".to_string(), error: "expected boolean argument".to_string(), }); } None => true, }; pm.make_callgraph(); let callgraph = pm.callgraph.take().unwrap(); let mut editors: Vec<_> = build_selection(pm, selection, true) .into_iter() .map(|editor| editor.unwrap()) .collect(); const_inline(&mut editors, &callgraph, inline_collections); for func in editors { changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::CRC => { assert!(args.is_empty()); for func in build_selection(pm, selection, false) { let Some(mut func) = func else { continue; }; crc(&mut func); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::DCE => { assert!(args.is_empty()); for func in build_selection(pm, selection, false) { let Some(mut func) = func else { continue; }; dce(&mut func); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::FloatCollections => { assert!(args.is_empty()); pm.make_typing(); pm.make_callgraph(); pm.make_devices(); let typing = pm.typing.take().unwrap(); let callgraph = pm.callgraph.take().unwrap(); let devices = pm.devices.take().unwrap();
// Modify the selection to include callers of selected functions.
let mut editors = build_selection(pm, selection, false) .into_iter() .filter_map(|editor| editor.map(|editor| (editor.func_id(), editor))) .collect(); float_collections(&mut editors, &typing, &callgraph, &devices); for func in editors { changed |= func.1.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::ForkGuardElim => { assert!(args.is_empty()); pm.make_fork_join_maps(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); for (func, fork_join_map) in build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) { let Some(mut func) = func else { continue; }; fork_guard_elim(&mut func, fork_join_map); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::Serialize => {
// FIXME: How to get module name here?
let output_file = "out.hbin"; let module = pm.clone().get_module().clone(); let module_contents: Vec<u8> = postcard::to_allocvec(&module).unwrap(); let mut file = File::create(&output_file).expect("PANIC: Unable to open output module file."); file.write_all(&module_contents) .expect("PANIC: Unable to write output module file contents."); } Pass::ForkSplit => { assert!(args.is_empty()); let mut created_fork_joins = vec![vec![vec![]]; pm.functions.len()]; loop { let mut inner_changed = false; pm.make_fork_join_maps(); pm.make_reduce_cycles(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); let reduce_cycles = pm.reduce_cycles.take().unwrap(); for ((func, fork_join_map), reduce_cycles) in build_selection(pm, selection.clone(), false) .into_iter() .zip(fork_join_maps.iter()) .zip(reduce_cycles.iter()) { let Some(mut func) = func else { continue; }; if let Some((forks, joins)) = split_any_fork(&mut func, fork_join_map, reduce_cycles) { let created_fork_joins = &mut created_fork_joins[func.func_id().idx()]; if forks.len() > created_fork_joins.len() { created_fork_joins.resize(forks.len(), vec![]); } for (idx, (fork, join)) in zip(forks, joins).enumerate() { created_fork_joins[idx].push((fork, join)); } } changed |= func.modified(); inner_changed |= func.modified(); } pm.clear_analyses(); if !inner_changed { break; } } pm.make_nodes_in_fork_joins(); let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); let mut new_fork_joins = HashMap::new(); for (mut func, created_fork_joins) in build_editors(pm).into_iter().zip(created_fork_joins) {
// For every function, create a label for every level of fork-
// joins resulting from the split.
let name = func.func().name.clone(); let func_id = func.func_id(); let labels = create_labels_for_node_sets( &mut func, created_fork_joins.into_iter().map(|level_fork_joins| { level_fork_joins .into_iter() .map(|(fork, _)| { nodes_in_fork_joins[func_id.idx()][&fork] .iter() .map(|id| *id) }) .flatten() }), );
// Assemble those labels into a record for this function. The
// format of the records is <function>.<fjN>, where N is the
// level of the split fork-joins being referred to.
let mut func_record = HashMap::new(); for (idx, label) in labels { func_record.insert( format!("fj{}", idx), Value::Label { labels: vec![LabelInfo { func: func_id, label: label, }], }, ); }
// Try to avoid creating unnecessary record entries.
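// Functions in which no fork was split contribute no labels and are left out of the returned record.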
if !func_record.is_empty() { new_fork_joins.entry(name).insert_entry(Value::Record { fields: func_record, }); } } pm.delete_gravestones(); pm.clear_analyses(); result = Value::Record { fields: new_fork_joins, }; } Pass::ForkInterchange => { assert_eq!(args.len(), 2); let Some(Value::Integer { val: first_dim }) = args.get(0) else { return Err(SchedulerError::PassError { pass: "forkInterchange".to_string(), error: "expected integer argument".to_string(), }); }; let Some(Value::Integer { val: second_dim }) = args.get(1) else { return Err(SchedulerError::PassError { pass: "forkInterchange".to_string(), error: "expected integer argument".to_string(), }); }; pm.make_fork_join_maps(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); for (func, fork_join_map) in build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) { let Some(mut func) = func else { continue; }; fork_interchange_all_forks(&mut func, fork_join_map, *first_dim, *second_dim); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::ForkUnroll => { assert_eq!(args.len(), 0); pm.make_fork_join_maps(); pm.make_nodes_in_fork_joins(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); for ((func, fork_join_map), nodes_in_fork_joins) in build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) .zip(nodes_in_fork_joins.iter()) { let Some(mut func) = func else { continue; }; fork_unroll_all_forks(&mut func, fork_join_map, nodes_in_fork_joins); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::Forkify => { assert!(args.is_empty()); loop { let mut inner_changed = false; pm.make_fork_join_maps(); pm.make_control_subgraphs(); pm.make_loops(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); let loops = pm.loops.take().unwrap(); let control_subgraphs = pm.control_subgraphs.take().unwrap(); for (((func, fork_join_map), loop_nest), control_subgraph) in build_selection(pm, selection.clone(), false) .into_iter() .zip(fork_join_maps.iter()) .zip(loops.iter()) .zip(control_subgraphs.iter()) { let Some(mut func) = func else { continue; }; let c = forkify(&mut func, control_subgraph, fork_join_map, loop_nest); changed |= c; inner_changed |= c; } pm.delete_gravestones(); pm.clear_analyses(); if !inner_changed { break; } } } Pass::GCM => { assert!(args.is_empty()); if !selection.is_everything() { return Err(SchedulerError::PassError { pass: "gcm".to_string(), error: "must be applied to the entire module".to_string(), }); } // Iterate functions in reverse topological order, since inter- // device copies introduced in a callee may affect demands in a // caller, and the object allocation of a callee affects the object // allocation of its callers. 
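// GCM can itself insert nodes (e.g. for inter-device copies), which invalidates the cached
// analyses, so the pipeline below re-runs from scratch until every function is placed successfully.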
pm.make_callgraph(); let callgraph = pm.callgraph.take().unwrap(); let topo = callgraph.topo(); loop { pm.make_def_uses(); pm.make_reverse_postorders(); pm.make_typing(); pm.make_control_subgraphs(); pm.make_doms(); pm.make_fork_join_maps(); pm.make_fork_join_nests(); pm.make_loops(); pm.make_reduce_cycles(); pm.make_collection_objects(); pm.make_devices(); let def_uses = pm.def_uses.take().unwrap(); let reverse_postorders = pm.reverse_postorders.take().unwrap(); let typing = pm.typing.take().unwrap(); let doms = pm.doms.take().unwrap(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); let fork_join_nests = pm.fork_join_nests.take().unwrap(); let loops = pm.loops.take().unwrap(); let reduce_cycles = pm.reduce_cycles.take().unwrap(); let control_subgraphs = pm.control_subgraphs.take().unwrap(); let collection_objects = pm.collection_objects.take().unwrap(); let devices = pm.devices.take().unwrap(); let mut bbs = vec![(vec![], vec![]); topo.len()]; let mut node_colors = BTreeMap::new(); let mut backing_allocations = BTreeMap::new(); let mut editors = build_editors(pm); let mut any_failed = false; for id in topo.iter() { let editor = &mut editors[id.idx()]; if let Some((bb, function_node_colors, backing_allocation)) = gcm( editor, &def_uses[id.idx()], &reverse_postorders[id.idx()], &typing[id.idx()], &control_subgraphs[id.idx()], &doms[id.idx()], &fork_join_maps[id.idx()], &fork_join_nests[id.idx()], &loops[id.idx()], &reduce_cycles[id.idx()], &collection_objects, &devices, &node_colors, &backing_allocations, ) { bbs[id.idx()] = bb; node_colors.insert(*id, function_node_colors); backing_allocations.insert(*id, backing_allocation); } else { any_failed = true; } changed |= editor.modified(); if any_failed { break; } } pm.delete_gravestones(); pm.clear_analyses(); if !any_failed { pm.bbs = Some(bbs); pm.node_colors = Some(node_colors); pm.backing_allocations = Some(backing_allocations); break; } } } Pass::GVN => { assert!(args.is_empty()); for func in build_selection(pm, selection, false) { let Some(mut func) = func else { continue; }; gvn(&mut func, false); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::InferSchedules => { assert!(args.is_empty()); pm.make_fork_join_maps(); pm.make_reduce_cycles(); pm.make_no_reset_constants(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); let reduce_cycles = pm.reduce_cycles.take().unwrap(); let no_reset_constants = pm.no_reset_constants.take().unwrap(); for (((func, fork_join_map), reduce_cycles), no_reset_constants) in build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) .zip(reduce_cycles.iter()) .zip(no_reset_constants.iter()) { let Some(mut func) = func else { continue; }; infer_parallel_reduce(&mut func, fork_join_map, reduce_cycles); infer_parallel_fork(&mut func, fork_join_map); infer_monoid_reduce(&mut func, reduce_cycles); infer_no_reset_constants(&mut func, no_reset_constants); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::Inline => { assert!(args.is_empty()); pm.make_callgraph(); let callgraph = pm.callgraph.take().unwrap(); let mut editors: Vec<_> = build_selection(pm, selection, true) .into_iter() .map(|editor| editor.unwrap()) .collect(); inline(&mut editors, &callgraph); for func in editors { changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::InterproceduralSROA => { let sroa_with_arrays = match args.get(0) { Some(Value::Boolean { val }) => *val, Some(_) => { return Err(SchedulerError::PassError 
{ pass: "sroa".to_string(), error: "expected boolean argument".to_string(), }); } None => false, }; let selection = selection .as_funcs() .map_err(|_| SchedulerError::PassError { pass: "xdot".to_string(), error: "expected coarse-grained selection (can't partially xdot a function)" .to_string(), })?; let mut bool_selection = vec![false; pm.functions.len()]; selection .into_iter() .for_each(|func| bool_selection[func.idx()] = true); pm.make_typing(); let typing = pm.typing.take().unwrap(); let mut editors = build_editors(pm); interprocedural_sroa(&mut editors, &typing, &bool_selection, sroa_with_arrays); for func in editors { changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::LiftDCMath => { assert!(args.is_empty()); for func in build_selection(pm, selection, false) { let Some(mut func) = func else { continue; }; lift_dc_math(&mut func); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::Outline => { let Some((nodes, func)) = selection_as_set(pm, selection) else { return Err(SchedulerError::PassError { pass: "outline".to_string(), error: "must be applied to nodes in a single function".to_string(), }); }; pm.make_def_uses(); let def_uses = pm.def_uses.take().unwrap(); let mut editor = FunctionEditor::new( &mut pm.functions[func.idx()], func, &pm.constants, &pm.dynamic_constants, &pm.types, &pm.labels, &def_uses[func.idx()], ); collapse_returns(&mut editor); ensure_between_control_flow(&mut editor); pm.clear_analyses(); pm.make_def_uses(); pm.make_typing(); pm.make_control_subgraphs(); pm.make_doms(); let def_uses = pm.def_uses.take().unwrap(); let typing = pm.typing.take().unwrap(); let control_subgraphs = pm.control_subgraphs.take().unwrap(); let doms = pm.doms.take().unwrap(); let new_func_id = FunctionID::new(pm.functions.len()); let mut editor = FunctionEditor::new( &mut pm.functions[func.idx()], func, &pm.constants, &pm.dynamic_constants, &pm.types, &pm.labels, &def_uses[func.idx()], ); let new_func = outline( &mut editor, &typing[func.idx()], &control_subgraphs[func.idx()], &doms[func.idx()], nodes, new_func_id, false, ); let Some(new_func) = new_func else { return Err(SchedulerError::PassError { pass: "outlining".to_string(), error: "failed to outline".to_string(), }); }; pm.functions.push(new_func); changed = true; pm.functions[func.idx()].delete_gravestones(); pm.clear_analyses(); result = Value::HerculesFunction { func: new_func_id }; } Pass::PhiElim => { assert!(args.is_empty()); for func in build_selection(pm, selection, false) { let Some(mut func) = func else { continue; }; phi_elim(&mut func); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::Predication => { assert!(args.is_empty()); pm.make_typing(); let typing = pm.typing.take().unwrap(); for (func, types) in build_selection(pm, selection, false) .into_iter() .zip(typing.iter()) { let Some(mut func) = func else { continue; }; predication(&mut func, types); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::ReduceSLF => { assert!(args.is_empty()); pm.make_fork_join_maps(); pm.make_reduce_cycles(); pm.make_nodes_in_fork_joins(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); let reduce_cycles = pm.reduce_cycles.take().unwrap(); let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); for (((func, fork_join_map), reduce_cycles), nodes_in_fork_joins) in build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) .zip(reduce_cycles.iter()) 
.zip(nodes_in_fork_joins.iter()) { let Some(mut func) = func else { continue; }; reduce_slf(&mut func, fork_join_map, reduce_cycles, nodes_in_fork_joins); changed |= func.modified(); } } Pass::Rename => { assert!(args.len() == 1); let new_name = match args[0] { Value::String { ref val } => val.clone(), _ => { return Err(SchedulerError::PassError { pass: "rename".to_string(), error: "expected string argument".to_string(), }); } }; if pm.functions.iter().any(|f| f.name == new_name) { return Err(SchedulerError::PassError { pass: "rename".to_string(), error: format!("function with name {} already exists", new_name), }); } if let Some(funcs) = selection.as_funcs().ok() && funcs.len() == 1 { let func = funcs[0]; pm.functions[func.idx()].name = new_name; } else { return Err(SchedulerError::PassError { pass: "rename".to_string(), error: "must be applied to the entirety of a single function".to_string(), }); }; } Pass::ReuseProducts => { assert!(args.is_empty()); pm.make_reverse_postorders(); pm.make_typing(); let reverse_postorders = pm.reverse_postorders.take().unwrap(); let typing = pm.typing.take().unwrap(); for ((func, reverse_postorder), types) in build_selection(pm, selection, false) .into_iter() .zip(reverse_postorders.iter()) .zip(typing.iter()) { let Some(mut func) = func else { continue; }; reuse_products(&mut func, reverse_postorder, types); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::RewriteMathExpressions => { assert!(args.is_empty()); pm.make_typing(); pm.make_fork_join_maps(); pm.make_nodes_in_fork_joins(); pm.make_reduce_einsums(); let typing = pm.typing.take().unwrap(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); let reduce_einsums = pm.reduce_einsums.take().unwrap(); for ((((func, typing), fork_join_map), nodes_in_fork_joins), reduce_einsums) in build_selection(pm, selection, false) .into_iter() .zip(typing.iter()) .zip(fork_join_maps.iter()) .zip(nodes_in_fork_joins.iter()) .zip(reduce_einsums.iter()) { let Some(mut func) = func else { continue; }; rewrite_math_expressions( &mut func, Device::CUDA, typing, fork_join_map, nodes_in_fork_joins, reduce_einsums, ); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::SLF => { assert!(args.is_empty()); pm.make_reverse_postorders(); pm.make_typing(); let reverse_postorders = pm.reverse_postorders.take().unwrap(); let typing = pm.typing.take().unwrap(); for ((func, reverse_postorder), types) in build_selection(pm, selection, false) .into_iter() .zip(reverse_postorders.iter()) .zip(typing.iter()) { let Some(mut func) = func else { continue; }; slf(&mut func, reverse_postorder, types); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::SimplifyCFG => { assert!(args.is_empty()); pm.make_fork_join_maps(); pm.make_reduce_cycles(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); let reduce_cycles = pm.reduce_cycles.take().unwrap(); for ((func, fork_join_map), reduce_cycles) in build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) .zip(reduce_cycles.iter()) { let Some(mut func) = func else { continue; }; simplify_cfg(&mut func, fork_join_map, reduce_cycles); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::SROA => { let sroa_with_arrays = match args.get(0) { Some(Value::Boolean { val }) => *val, Some(_) => { return Err(SchedulerError::PassError { pass: "sroa".to_string(), error: "expected boolean 
argument".to_string(), }); } None => false, }; pm.make_reverse_postorders(); pm.make_typing(); let reverse_postorders = pm.reverse_postorders.take().unwrap(); let typing = pm.typing.take().unwrap(); for ((func, reverse_postorder), types) in build_selection(pm, selection, false) .into_iter() .zip(reverse_postorders.iter()) .zip(typing.iter()) { let Some(mut func) = func else { continue; }; sroa(&mut func, reverse_postorder, types, sroa_with_arrays); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::Unforkify => { assert!(args.is_empty()); pm.make_fork_join_maps(); pm.make_loops(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); let loops = pm.loops.take().unwrap(); for ((func, fork_join_map), loop_tree) in build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) .zip(loops.iter()) { let Some(mut func) = func else { continue; }; unforkify_all(&mut func, fork_join_map, loop_tree); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::UnforkifyOne => { assert!(args.is_empty()); pm.make_fork_join_maps(); pm.make_loops(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); let loops = pm.loops.take().unwrap(); for ((func, fork_join_map), loop_tree) in build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) .zip(loops.iter()) { let Some(mut func) = func else { continue; }; unforkify_one(&mut func, fork_join_map, loop_tree); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::ForkChunk => { assert_eq!(args.len(), 4); let Some(Value::Integer { val: tile_size }) = args.get(0) else { return Err(SchedulerError::PassError { pass: "forkChunk".to_string(), error: "expected integer argument".to_string(), }); }; let Some(Value::Integer { val: dim_idx }) = args.get(1) else { return Err(SchedulerError::PassError { pass: "forkChunk".to_string(), error: "expected integer argument".to_string(), }); }; let Some(Value::Boolean { val: guarded_flag }) = args.get(2) else { return Err(SchedulerError::PassError { pass: "forkChunk".to_string(), error: "expected boolean argument".to_string(), }); }; let Some(Value::Boolean { val: tile_order }) = args.get(3) else { return Err(SchedulerError::PassError { pass: "forkChunk".to_string(), error: "expected boolean argument".to_string(), }); }; assert!(!*guarded_flag); pm.make_fork_join_maps(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); for (func, fork_join_map) in build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) { let Some(mut func) = func else { continue; }; chunk_all_forks_unguarded( &mut func, fork_join_map, *dim_idx, *tile_size, *tile_order, ); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::ForkExtend => { assert_eq!(args.len(), 1); let Some(Value::Integer { val: multiple }) = args.get(0) else { return Err(SchedulerError::PassError { pass: "forkExtend".to_string(), error: "expected integer argument".to_string(), }); }; pm.make_fork_join_maps(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); for (func, fork_join_map) in build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) { let Some(mut func) = func else { continue; }; extend_all_forks(&mut func, fork_join_map, *multiple); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::ForkFissionBufferize => { assert!(args.len() == 1 || args.len() == 2); let Some(Value::Label { labels: fork_labels, }) = args.get(0) else { return 
Err(SchedulerError::PassError { pass: "forkFissionBufferize".to_string(), error: "expected label argument".to_string(), }); }; let mut created_fork_joins = vec![vec![]; pm.functions.len()]; pm.make_fork_join_maps(); pm.make_typing(); pm.make_loops(); pm.make_reduce_cycles(); pm.make_nodes_in_fork_joins(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); let typing = pm.typing.take().unwrap(); let loops = pm.loops.take().unwrap(); let reduce_cycles = pm.reduce_cycles.take().unwrap(); let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); // Assert that at most one function is in the selection. let num_functions = build_selection(pm, selection.clone(), false) .iter() .filter(|func| func.is_some()) .count(); assert!(num_functions <= 1); assert_eq!(fork_labels.len(), 1); let fork_label = fork_labels[0].label; for ( ((((func, fork_join_map), loop_tree), typing), reduce_cycles), nodes_in_fork_joins, ) in build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) .zip(loops.iter()) .zip(typing.iter()) .zip(reduce_cycles.iter()) .zip(nodes_in_fork_joins.iter()) { let Some(mut func) = func else { continue; }; let data_label = if let Some(Value::Label { labels: fork_data_labels, }) = args.get(1) { assert_eq!(fork_data_labels.len(), 1); Some(fork_data_labels[0].label) } else { None }; if let Some((fork1, fork2)) = ff_bufferize_any_fork( &mut func, loop_tree, fork_join_map, reduce_cycles, nodes_in_fork_joins, typing, fork_label, data_label, ) { let created_fork_joins = &mut created_fork_joins[func.func_id().idx()]; created_fork_joins.push(fork1); created_fork_joins.push(fork2); } changed |= func.modified(); } pm.clear_analyses(); pm.make_nodes_in_fork_joins(); let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); let mut new_fork_joins = HashMap::new(); let _fork_label_name = &pm.labels.borrow()[fork_label.idx()].clone(); for (mut func, created_fork_joins) in build_editors(pm).into_iter().zip(created_fork_joins) { // For every function, create a label for every level of fork- // joins resulting from the split. let name = func.func().name.clone(); let func_id = func.func_id(); let labels = create_labels_for_node_sets( &mut func, created_fork_joins.into_iter().map(|fork| { nodes_in_fork_joins[func_id.idx()][&fork] .iter() .map(|id| *id) }), ); // Assemble those labels into a record for this function. Even-indexed label // sets correspond to the top fork-join produced by a split and odd-indexed // sets to the bottom one, so the record exposes them as <function>.fj_top // and <function>.fj_bottom. let mut func_record = HashMap::new(); for (idx, label) in labels { let fmt = if idx % 2 == 0 { "fj_top" } else { "fj_bottom" }; func_record.insert( fmt.to_string(), Value::Label { labels: vec![LabelInfo { func: func_id, label: label, }], }, ); } // Try to avoid creating unnecessary record entries.
if !func_record.is_empty() { new_fork_joins.entry(name).insert_entry(Value::Record { fields: func_record, }); } } pm.delete_gravestones(); pm.clear_analyses(); result = Value::Record { fields: new_fork_joins, }; } Pass::ForkFusion => { assert!(args.is_empty()); pm.make_fork_join_maps(); pm.make_nodes_in_fork_joins(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); for ((func, fork_join_map), nodes_in_fork_joins) in build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) .zip(nodes_in_fork_joins.iter()) { let Some(mut func) = func else { continue; }; fork_fusion_all_forks(&mut func, fork_join_map, nodes_in_fork_joins); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::ForkDimMerge => { assert!(args.is_empty()); pm.make_fork_join_maps(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); for (func, fork_join_map) in build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) { let Some(mut func) = func else { continue; }; merge_all_fork_dims(&mut func, fork_join_map); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::ForkCoalesce => { assert!(args.is_empty()); pm.make_fork_join_maps(); pm.make_control_subgraphs(); pm.make_loops(); pm.make_reduce_cycles(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); let loops = pm.loops.take().unwrap(); let control_subgraphs = pm.control_subgraphs.take().unwrap(); for (((func, fork_join_map), loop_nest), control_subgraph) in build_selection(pm, selection, false) .into_iter() .zip(fork_join_maps.iter()) .zip(loops.iter()) .zip(control_subgraphs.iter()) { let Some(mut func) = func else { continue; }; fork_coalesce(&mut func, loop_nest, fork_join_map); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::ForkReshape => { let mut shape = vec![]; let mut loops = BTreeSet::new(); let mut fork_count = 0; for arg in args { let Value::Tuple { values } = arg else { return Err(SchedulerError::PassError { pass: "fork-reshape".to_string(), error: "expected each argument to be a list of integers".to_string(), }); }; let mut indices = vec![]; for val in values { let Value::Integer { val: idx } = val else { return Err(SchedulerError::PassError { pass: "fork-reshape".to_string(), error: "expected each argument to be a list of integers".to_string(), }); }; indices.push(idx); loops.insert(idx); fork_count += 1; } shape.push(indices); } if loops != (0..fork_count).collect() { return Err(SchedulerError::PassError { pass: "fork-reshape".to_string(), error: "expected forks to be numbered sequentially from 0 and used exactly once" .to_string(), }); } let Some((nodes, func_id)) = selection_as_set(pm, selection) else { return Err(SchedulerError::PassError { pass: "fork-reshape".to_string(), error: "must be applied to nodes in a single function".to_string(), }); }; let func = func_id.idx(); pm.make_def_uses(); pm.make_fork_join_maps(); pm.make_loops(); pm.make_reduce_cycles(); let def_uses = pm.def_uses.take().unwrap(); let mut fork_join_maps = pm.fork_join_maps.take().unwrap(); let loops = pm.loops.take().unwrap(); let reduce_cycles = pm.reduce_cycles.take().unwrap(); let def_use = &def_uses[func]; let fork_join_map = &mut fork_join_maps[func]; let loops = &loops[func]; let reduce_cycles = &reduce_cycles[func]; let mut editor = FunctionEditor::new( &mut pm.functions[func], func_id, &pm.constants, &pm.dynamic_constants, &pm.types, &pm.labels, def_use, ); // 
There should be exactly one fork nest in the selection and it should contain // exactly fork_count forks (counting each dimension of each fork) // We determine the loops (ordered top-down) that are contained in the selection // (in particular, the header is in the selection) and that are fork-joins (the header // is a fork) let mut loops = loops .bottom_up_loops() .into_iter() .rev() .filter(|(header, _)| nodes.contains(header) && editor.node(header).is_fork()); let Some((top_fork_head, top_fork_body)) = loops.next() else { return Err(SchedulerError::PassError { pass: "fork-reshape".to_string(), error: format!( "expected {} forks, found 0 in {}", fork_count, editor.func().name ), }); }; // All the remaining forks need to be contained in the top fork body let mut forks = vec![top_fork_head]; let mut num_dims = editor.node(top_fork_head).try_fork().unwrap().1.len(); for (head, _) in loops { if !top_fork_body[head.idx()] { return Err(SchedulerError::PassError { pass: "fork-reshape".to_string(), error: "selection includes multiple non-nested forks".to_string(), }); } else { forks.push(head); num_dims += editor.node(head).try_fork().unwrap().1.len(); } } if num_dims != fork_count { return Err(SchedulerError::PassError { pass: "fork-reshape".to_string(), error: format!( "expected {} forks, found {} in {}", fork_count, num_dims, pm.functions[func].name ), }); } // Now, we coalesce all of these forks into one so that we can interchange them let mut forks = forks.into_iter(); let top_fork = forks.next().unwrap(); let mut cur_fork = top_fork; for next_fork in forks { let Some((new_fork, new_join)) = fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map) else { return Err(SchedulerError::PassError { pass: "fork-reshape".to_string(), error: "failed to coalesce forks".to_string(), }); }; cur_fork = new_fork; fork_join_map.insert(new_fork, new_join); } let join = *fork_join_map.get(&cur_fork).unwrap(); // Now we have just one fork and we can perform the interchanges we need // To do this, we track two maps: from original index to current index and from // current index to original index (see the bookkeeping sketch at the end of this file) let mut orig_to_cur = (0..fork_count).collect::<Vec<_>>(); let mut cur_to_orig = (0..fork_count).collect::<Vec<_>>(); // Now, starting from the first (outermost) index we move the desired fork bound // into place for (idx, original_idx) in shape.iter().flat_map(|idx| idx.iter()).enumerate() { let cur_idx = orig_to_cur[*original_idx]; let swapping = cur_to_orig[idx]; // If the desired factor is already in the correct place, do nothing if cur_idx == idx { continue; } assert!(idx < cur_idx); let Some(fork_res) = fork_interchange(&mut editor, cur_fork, join, idx, cur_idx) else { return Err(SchedulerError::PassError { pass: "fork-reshape".to_string(), error: "failed to interchange forks".to_string(), }); }; cur_fork = fork_res; // Update our maps orig_to_cur[*original_idx] = idx; orig_to_cur[swapping] = cur_idx; cur_to_orig[idx] = *original_idx; cur_to_orig[cur_idx] = swapping; } // Finally, we split the fork into the desired pieces.
We do this by first splitting // the fork into individual forks and then coalescing the chunks together // Not sure how split_fork could fail, so if it does, panicking is fine let (forks, joins) = split_fork(&mut editor, cur_fork, join, &reduce_cycles).unwrap(); for (fork, join) in forks.iter().zip(joins.iter()) { fork_join_map.insert(*fork, *join); } // Finally coalesce the chunks together let mut fork_idx = 0; let mut final_forks = vec![]; for chunk in shape.iter() { let chunk_len = chunk.len(); let mut cur_fork = forks[fork_idx]; for i in 1..chunk_len { let next_fork = forks[fork_idx + i]; // Again, not sure at this point how coalesce could fail, so panic if it // does let (new_fork, new_join) = fork_coalesce_helper(&mut editor, cur_fork, next_fork, &fork_join_map) .unwrap(); cur_fork = new_fork; fork_join_map.insert(new_fork, new_join); } fork_idx += chunk_len; final_forks.push(cur_fork); } // Label each fork and return the labels // We've trashed our analyses at this point, so rerun them so that we can determine the // nodes in each of the resulting fork-joins pm.clear_analyses(); pm.make_def_uses(); pm.make_nodes_in_fork_joins(); let def_uses = pm.def_uses.take().unwrap(); let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); let def_use = &def_uses[func]; let nodes_in_fork_joins = &nodes_in_fork_joins[func]; let mut editor = FunctionEditor::new( &mut pm.functions[func], func_id, &pm.constants, &pm.dynamic_constants, &pm.types, &pm.labels, def_use, ); let labels = create_labels_for_node_sets( &mut editor, final_forks .into_iter() .map(|fork| nodes_in_fork_joins[&fork].iter().copied()), ) .into_iter() .map(|(_, label)| Value::Label { labels: vec![LabelInfo { func: func_id, label, }], }) .collect(); result = Value::Tuple { values: labels }; changed = true; pm.delete_gravestones(); pm.clear_analyses(); } Pass::WritePredication => { assert!(args.is_empty()); for func in build_selection(pm, selection, false) { let Some(mut func) = func else { continue; }; write_predication(&mut func); changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); } Pass::Verify => { assert!(args.is_empty()); let (def_uses, reverse_postorders, typing, subgraphs, doms, postdoms, fork_join_maps) = pm.with_mod(|module| verify(module)) .map_err(|msg| SchedulerError::PassError { pass: "verify".to_string(), error: format!("failed: {}", msg), })?; // Verification produces a bunch of analysis results that // may be useful for later passes.
pm.def_uses = Some(def_uses); pm.reverse_postorders = Some(reverse_postorders); pm.typing = Some(typing); pm.control_subgraphs = Some(subgraphs); pm.doms = Some(doms); pm.postdoms = Some(postdoms); pm.fork_join_maps = Some(fork_join_maps); } Pass::Xdot => { let force_analyses = match args.get(0) { Some(Value::Boolean { val }) => *val, Some(_) => { return Err(SchedulerError::PassError { pass: "xdot".to_string(), error: "expected boolean argument".to_string(), }); } None => true, }; let selection = selection .as_funcs() .map_err(|_| SchedulerError::PassError { pass: "xdot".to_string(), error: "expected coarse-grained selection (can't partially xdot a function)" .to_string(), })?; let mut bool_selection = vec![false; pm.functions.len()]; selection .into_iter() .for_each(|func| bool_selection[func.idx()] = true); pm.make_reverse_postorders(); if force_analyses { pm.make_typing(); pm.make_doms(); pm.make_fork_join_maps(); pm.make_devices(); } let reverse_postorders = pm.reverse_postorders.take().unwrap(); let typing = pm.typing.take(); let doms = pm.doms.take(); let fork_join_maps = pm.fork_join_maps.take(); let devices = pm.devices.take(); let bbs = pm.bbs.take(); pm.with_mod(|module| { xdot_module( module, &bool_selection, &reverse_postorders, typing.as_ref(), doms.as_ref(), fork_join_maps.as_ref(), devices.as_ref(), bbs.as_ref(), ) }); // Put BasicBlocks back, since it's needed for Codegen. pm.bbs = bbs; } Pass::Print => { println!("{:?}", args.get(0)); } Pass::LoopBoundCanon => { assert_eq!(args.len(), 0); loop { let mut inner_changed = false; pm.make_fork_join_maps(); pm.make_loops(); pm.make_control_subgraphs(); let fork_join_maps = pm.fork_join_maps.take().unwrap(); let loops = pm.loops.take().unwrap(); let control_subgraphs = pm.control_subgraphs.take().unwrap(); for (((func, fork_join_map), loops), control_subgraph) in build_selection(pm, selection.clone(), false) .into_iter() .zip(fork_join_maps.iter()) .zip(loops.iter()) .zip(control_subgraphs.iter()) { let Some(mut func) = func else { continue; }; loop_bound_canon_toplevel(&mut func, fork_join_map, control_subgraph, loops); changed |= func.modified(); inner_changed |= func.modified(); } pm.delete_gravestones(); pm.clear_analyses(); if !inner_changed { break; } } } } println!("Ran Pass: {:?}", pass); Ok((result, changed)) } fn create_labels_for_node_sets<I, J>( editor: &mut FunctionEditor, node_sets: I, ) -> Vec<(usize, LabelID)> where I: Iterator<Item = J>, J: Iterator<Item = NodeID>, { let mut labels = vec![]; editor.edit(|mut edit| { for (set_idx, node_set) in node_sets.enumerate() { let mut node_set = node_set.peekable(); if node_set.peek().is_some() { let label = edit.fresh_label(); for node in node_set { edit = edit.add_label(node, label).unwrap(); } labels.push((set_idx, label)); } } Ok(edit) }); labels } fn vec_take<T>(mut v: Vec<T>, index: usize) -> T { v.swap_remove(index) }
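// The fork-reshape case above explains its interchange step in terms of two
// permutation maps, orig_to_cur and cur_to_orig. The module below is a small
// illustrative sketch (not part of the scheduler itself) that mirrors that
// bookkeeping on a plain Vec, standing in for fork_interchange with a swap, to
// show how a requested shape such as [[1], [0, 2]] is reached. All names in
// this sketch are hypothetical.
#[cfg(test)]
mod fork_reshape_bookkeeping_sketch {
    #[test]
    fn permutation_maps_reach_requested_shape() {
        // Requested shape: the outer chunk uses original dimension 1, the inner
        // chunk uses original dimensions 0 and 2 (flattened order: [1, 0, 2]).
        let shape: Vec<Vec<usize>> = vec![vec![1], vec![0, 2]];
        let fork_count: usize = 3;
        // dims[i] records which original dimension currently sits at position i.
        let mut dims: Vec<usize> = (0..fork_count).collect();
        let mut orig_to_cur: Vec<usize> = (0..fork_count).collect();
        let mut cur_to_orig: Vec<usize> = (0..fork_count).collect();
        for (idx, original_idx) in shape.iter().flat_map(|chunk| chunk.iter()).enumerate() {
            let cur_idx = orig_to_cur[*original_idx];
            let swapping = cur_to_orig[idx];
            // If the desired dimension is already in place, do nothing.
            if cur_idx == idx {
                continue;
            }
            // Positions to the left of idx are already final, so the wanted
            // dimension can only be further to the right.
            assert!(idx < cur_idx);
            // Stand-in for fork_interchange: swap the two dimensions.
            dims.swap(idx, cur_idx);
            // Update both maps, exactly as the pass does.
            orig_to_cur[*original_idx] = idx;
            orig_to_cur[swapping] = cur_idx;
            cur_to_orig[idx] = *original_idx;
            cur_to_orig[cur_idx] = swapping;
        }
        // The final dimension order matches the flattened requested shape.
        assert_eq!(dims, vec![1usize, 0, 2]);
    }
}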