pm.rs 57.83 KiB
use crate::ir::*;
use crate::labels::*;
use hercules_cg::*;
use hercules_ir::*;
use hercules_opt::*;
use tempfile::TempDir;
use juno_utils::env::Env;
use juno_utils::stringtab::StringTable;
use std::cell::RefCell;
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::fmt;
use std::fs::File;
use std::io::Write;
use std::iter::zip;
use std::process::{Command, Stdio};
/// A value produced by evaluating a schedule expression.
#[derive(Debug, Clone)]
pub enum Value {
    /// One or more labels, each naming a label inside a specific function.
    Label { labels: Vec<LabelInfo> },
    /// A Juno-level function; it maps to a list of Hercules functions
    /// (looked up through `JunoFunctions::get_function`).
    JunoFunction { func: JunoFunctionID },
    /// A single Hercules function.
    HerculesFunction { func: FunctionID },
    /// Named fields, e.g. the result record returned by some passes.
    Record { fields: HashMap<String, Value> },
    /// The `*` selector, standing for the whole module.
    Everything {},
    /// An ordered list of values, e.g. several code locations.
    Selection { selection: Vec<Value> },
    /// An integer literal.
    Integer { val: usize },
    /// A boolean literal.
    Boolean { val: bool },
}
/// A location in the module that a pass can be restricted to.
#[derive(Debug, Copy, Clone)]
enum CodeLocation {
    /// The nodes carrying a given label in a given function.
    Label(LabelInfo),
    /// An entire function.
    Function(FunctionID),
}
impl Value {
fn is_everything(&self) -> bool {
match self {
Value::Everything {} => true,
_ => false,
}
}
fn as_labels(&self) -> Result<Vec<LabelInfo>, SchedulerError> {
match self {
Value::Label { labels } => Ok(labels.clone()),
Value::Selection { selection } => {
let mut result = vec![];
for val in selection {
result.extend(val.as_labels()?);
}
Ok(result)
}
Value::JunoFunction { .. } | Value::HerculesFunction { .. } => Err(
SchedulerError::SemanticError("Expected labels, found function".to_string()),
),
Value::Record { .. } => Err(SchedulerError::SemanticError(
"Expected labels, found record".to_string(),
)),
Value::Everything {} => Err(SchedulerError::SemanticError(
"Expected labels, found everything".to_string(),
)),
Value::Integer { .. } => Err(SchedulerError::SemanticError(
"Expected labels, found integer".to_string(),
)),
Value::Boolean { .. } => Err(SchedulerError::SemanticError(
"Expected labels, found boolean".to_string(),
)),
}
}
fn as_functions(&self, funcs: &JunoFunctions) -> Result<Vec<FunctionID>, SchedulerError> {
match self {
Value::JunoFunction { func } => Ok(funcs.get_function(*func).clone()),
Value::HerculesFunction { func } => Ok(vec![*func]),
Value::Selection { selection } => {
let mut result = vec![];
for val in selection {
result.extend(val.as_functions(funcs)?);
}
Ok(result)
}
Value::Label { .. } => Err(SchedulerError::SemanticError(
"Expected functions, found label".to_string(),
)),
Value::Record { .. } => Err(SchedulerError::SemanticError(
"Expected functions, found record".to_string(),
)),
Value::Everything {} => Err(SchedulerError::SemanticError(
"Expected functions, found everything".to_string(),
)),
Value::Integer { .. } => Err(SchedulerError::SemanticError(
"Expected functions, found integer".to_string(),
)),
Value::Boolean { .. } => Err(SchedulerError::SemanticError(
"Expected functions, found boolean".to_string(),
)),
}
}
fn as_locations(&self, funcs: &JunoFunctions) -> Result<Vec<CodeLocation>, SchedulerError> {
match self {
Value::Label { labels } => Ok(labels.iter().map(|l| CodeLocation::Label(*l)).collect()),
Value::JunoFunction { func } => Ok(funcs
.get_function(*func)
.iter()
.map(|f| CodeLocation::Function(*f))
.collect()),
Value::HerculesFunction { func } => Ok(vec![CodeLocation::Function(*func)]),
Value::Selection { selection } => {
let mut result = vec![];
for val in selection {
result.extend(val.as_locations(funcs)?);
}
Ok(result)
}
Value::Record { .. } => Err(SchedulerError::SemanticError(
"Expected code locations, found record".to_string(),
)),
Value::Everything {} => {
panic!("Internal error, check is_everything() before using as_functions()")
}
Value::Integer { .. } => Err(SchedulerError::SemanticError(
"Expected code locations, found integer".to_string(),
)),
Value::Boolean { .. } => Err(SchedulerError::SemanticError(
"Expected code locations, found boolean".to_string(),
)),
}
}
}
/// Errors produced while interpreting a schedule or generating code.
#[derive(Debug, Clone)]
pub enum SchedulerError {
    /// A variable was read or assigned without having been bound.
    UndefinedVariable(String),
    /// A field access named a field the value does not have.
    UndefinedField(String),
    /// A field access on a function named a label that does not exist.
    UndefinedLabel(String),
    /// A value of the wrong kind was used, or an operation is not allowed.
    SemanticError(String),
    /// A pass (or code generation) reported a failure.
    PassError { pass: String, error: String },
    /// A fixpoint reached its `PanicAfter` iteration limit.
    FixpointFailure(),
}
impl fmt::Display for SchedulerError {
    /// Renders the error as a single human-readable line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UndefinedVariable(name) => write!(f, "Undefined variable '{name}'"),
            Self::UndefinedField(name) => write!(f, "No field '{name}'"),
            Self::UndefinedLabel(name) => write!(f, "No label '{name}'"),
            Self::SemanticError(message) => f.write_str(message),
            Self::PassError { pass, error } => write!(f, "Error in pass {pass}: {error}"),
            Self::FixpointFailure() => f.write_str("Fixpoint did not converge within limit"),
        }
    }
}
/// Owns the module being scheduled, together with lazily computed analyses.
///
/// Cached analyses are `None` until computed (mostly by the `make_*` methods).
/// `clear_analyses` resets all of them to `None`, which passes do after editing
/// the IR.
#[derive(Debug)]
struct PassManager {
    functions: Vec<Function>,
    // Shared module tables. `FunctionEditor`s receive these by shared reference
    // (see `build_editors`), presumably so they can be updated through the
    // `RefCell` while `functions` is borrowed mutably.
    types: RefCell<Vec<Type>>,
    constants: RefCell<Vec<Constant>>,
    dynamic_constants: RefCell<Vec<DynamicConstant>>,
    labels: RefCell<Vec<String>>,
    // Cached analysis results.
    pub def_uses: Option<Vec<ImmutableDefUseMap>>,
    pub reverse_postorders: Option<Vec<Vec<NodeID>>>,
    pub typing: Option<ModuleTyping>,
    pub control_subgraphs: Option<Vec<Subgraph>>,
    pub doms: Option<Vec<DomTree>>,
    pub postdoms: Option<Vec<DomTree>>,
    pub fork_join_maps: Option<Vec<HashMap<NodeID, NodeID>>>,
    pub fork_join_nests: Option<Vec<HashMap<NodeID, Vec<NodeID>>>>,
    pub loops: Option<Vec<LoopTree>>,
    pub reduce_cycles: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
    pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
    pub collection_objects: Option<CollectionObjects>,
    pub callgraph: Option<CallGraph>,
    pub devices: Option<Vec<Device>>,
    pub object_device_demands: Option<ObjectDeviceDemands>,
    // Results of the GCM pass (stored by `run_pass`); `codegen` requires them.
    pub bbs: Option<Vec<BasicBlocks>>,
    pub node_colors: Option<NodeColors>,
    pub backing_allocations: Option<BackingAllocations>,
}
impl PassManager {
    /// Creates a pass manager that owns the contents of `module`.
    ///
    /// All cached analyses start out as `None`.
    fn new(module: Module) -> Self {
        let Module {
            functions,
            types,
            constants,
            dynamic_constants,
            labels,
        } = module;
        PassManager {
            functions,
            types: RefCell::new(types),
            constants: RefCell::new(constants),
            dynamic_constants: RefCell::new(dynamic_constants),
            labels: RefCell::new(labels),
            def_uses: None,
            reverse_postorders: None,
            typing: None,
            control_subgraphs: None,
            doms: None,
            postdoms: None,
            fork_join_maps: None,
            fork_join_nests: None,
            loops: None,
            reduce_cycles: None,
            data_nodes_in_fork_joins: None,
            collection_objects: None,
            callgraph: None,
            devices: None,
            object_device_demands: None,
            bbs: None,
            node_colors: None,
            backing_allocations: None,
        }
    }

    /// Computes each function's def-use map, unless already cached.
    pub fn make_def_uses(&mut self) {
        if self.def_uses.is_none() {
            self.def_uses = Some(self.functions.iter().map(def_use).collect());
        }
    }

    /// Computes each function's reverse postorder from its def-use map, unless
    /// already cached.
    pub fn make_reverse_postorders(&mut self) {
        if self.reverse_postorders.is_none() {
            self.make_def_uses();
            self.reverse_postorders = Some(
                self.def_uses
                    .as_ref()
                    .unwrap()
                    .iter()
                    .map(reverse_postorder)
                    .collect(),
            );
        }
    }

    /// Type checks the module, unless already cached.
    ///
    /// `typecheck` is given mutable access to the type and dynamic-constant
    /// tables.
    ///
    /// # Panics
    /// Panics if type checking fails.
    pub fn make_typing(&mut self) {
        if self.typing.is_none() {
            self.make_reverse_postorders();
            self.typing = Some(
                typecheck(
                    &self.functions,
                    &mut self.types.borrow_mut(),
                    &self.constants.borrow(),
                    &mut self.dynamic_constants.borrow_mut(),
                    self.reverse_postorders.as_ref().unwrap(),
                )
                .unwrap(),
            );
        }
    }

    /// Computes each function's control subgraph, unless already cached.
    pub fn make_control_subgraphs(&mut self) {
        if self.control_subgraphs.is_none() {
            self.make_def_uses();
            self.control_subgraphs = Some(
                zip(&self.functions, self.def_uses.as_ref().unwrap())
                    .map(|(function, def_use)| control_subgraph(function, def_use))
                    .collect(),
            );
        }
    }

    /// Computes each function's dominator tree rooted at node 0, unless
    /// already cached.
    pub fn make_doms(&mut self) {
        if self.doms.is_none() {
            self.make_control_subgraphs();
            self.doms = Some(
                self.control_subgraphs
                    .as_ref()
                    .unwrap()
                    .iter()
                    .map(|subgraph| dominator(subgraph, NodeID::new(0)))
                    .collect(),
            );
        }
    }

    /// Computes each function's post-dominator tree, unless already cached.
    ///
    /// NOTE(review): this calls `dominator` with root
    /// `NodeID::new(function.nodes.len())`, one past the last node. Presumably
    /// that is meant as a virtual exit node; confirm that `dominator` handles a
    /// root outside the subgraph.
    pub fn make_postdoms(&mut self) {
        if self.postdoms.is_none() {
            self.make_control_subgraphs();
            self.postdoms = Some(
                zip(
                    self.control_subgraphs.as_ref().unwrap().iter(),
                    self.functions.iter(),
                )
                .map(|(subgraph, function)| dominator(subgraph, NodeID::new(function.nodes.len())))
                .collect(),
            );
        }
    }

    /// Computes each function's fork-to-join map, unless already cached.
    pub fn make_fork_join_maps(&mut self) {
        if self.fork_join_maps.is_none() {
            self.make_control_subgraphs();
            self.fork_join_maps = Some(
                zip(
                    self.functions.iter(),
                    self.control_subgraphs.as_ref().unwrap().iter(),
                )
                .map(|(function, subgraph)| fork_join_map(function, subgraph))
                .collect(),
            );
        }
    }

    /// Computes each function's fork-join nesting, unless already cached.
    pub fn make_fork_join_nests(&mut self) {
        if self.fork_join_nests.is_none() {
            self.make_doms();
            self.make_fork_join_maps();
            self.fork_join_nests = Some(
                zip(
                    self.functions.iter(),
                    zip(
                        self.doms.as_ref().unwrap().iter(),
                        self.fork_join_maps.as_ref().unwrap().iter(),
                    ),
                )
                .map(|(function, (dom, fork_join_map))| {
                    compute_fork_join_nesting(function, dom, fork_join_map)
                })
                .collect(),
            );
        }
    }

    /// Computes each function's loop tree (rooted at node 0), unless already
    /// cached.
    pub fn make_loops(&mut self) {
        if self.loops.is_none() {
            self.make_control_subgraphs();
            self.make_doms();
            self.make_fork_join_maps();
            let control_subgraphs = self.control_subgraphs.as_ref().unwrap().iter();
            let doms = self.doms.as_ref().unwrap().iter();
            let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter();
            self.loops = Some(
                zip(control_subgraphs, zip(doms, fork_join_maps))
                    .map(|(control_subgraph, (dom, fork_join_map))| {
                        loops(control_subgraph, NodeID::new(0), dom, fork_join_map)
                    })
                    .collect(),
            );
        }
    }

    /// Computes each function's reduce cycles, unless already cached.
    pub fn make_reduce_cycles(&mut self) {
        if self.reduce_cycles.is_none() {
            self.make_def_uses();
            let def_uses = self.def_uses.as_ref().unwrap().iter();
            self.reduce_cycles = Some(
                zip(self.functions.iter(), def_uses)
                    .map(|(function, def_use)| reduce_cycles(function, def_use))
                    .collect(),
            );
        }
    }

    /// Computes, per function, the data nodes inside each fork-join, unless
    /// already cached.
    pub fn make_data_nodes_in_fork_joins(&mut self) {
        if self.data_nodes_in_fork_joins.is_none() {
            self.make_def_uses();
            self.make_fork_join_maps();
            self.data_nodes_in_fork_joins = Some(
                zip(
                    self.functions.iter(),
                    zip(
                        self.def_uses.as_ref().unwrap().iter(),
                        self.fork_join_maps.as_ref().unwrap().iter(),
                    ),
                )
                .map(|(function, (def_use, fork_join_map))| {
                    data_nodes_in_fork_joins(function, def_use, fork_join_map)
                })
                .collect(),
            );
        }
    }

    /// Computes the module's collection objects, unless already cached.
    pub fn make_collection_objects(&mut self) {
        if self.collection_objects.is_none() {
            self.make_reverse_postorders();
            self.make_typing();
            self.make_callgraph();
            let reverse_postorders = self.reverse_postorders.as_ref().unwrap();
            let typing = self.typing.as_ref().unwrap();
            let callgraph = self.callgraph.as_ref().unwrap();
            self.collection_objects = Some(collection_objects(
                &self.functions,
                &self.types.borrow(),
                reverse_postorders,
                typing,
                callgraph,
            ));
        }
    }

    /// Computes the module's call graph, unless already cached.
    pub fn make_callgraph(&mut self) {
        if self.callgraph.is_none() {
            self.callgraph = Some(callgraph(&self.functions));
        }
    }

    /// Computes each function's device placement, unless already cached.
    pub fn make_devices(&mut self) {
        if self.devices.is_none() {
            self.make_callgraph();
            let callgraph = self.callgraph.as_ref().unwrap();
            self.devices = Some(device_placement(&self.functions, callgraph));
        }
    }

    /// Computes the device demands of collection objects, unless already
    /// cached.
    pub fn make_object_device_demands(&mut self) {
        if self.object_device_demands.is_none() {
            self.make_typing();
            self.make_callgraph();
            self.make_collection_objects();
            self.make_devices();
            let typing = self.typing.as_ref().unwrap();
            let callgraph = self.callgraph.as_ref().unwrap();
            let collection_objects = self.collection_objects.as_ref().unwrap();
            let devices = self.devices.as_ref().unwrap();
            self.object_device_demands = Some(object_device_demands(
                &self.functions,
                &self.types.borrow(),
                typing,
                callgraph,
                collection_objects,
                devices,
            ));
        }
    }

    /// Calls `Function::delete_gravestones` on every function.
    pub fn delete_gravestones(&mut self) {
        for func in self.functions.iter_mut() {
            func.delete_gravestones();
        }
    }

    /// Drops every cached analysis, including the GCM results, so that they
    /// are recomputed after the IR changes.
    fn clear_analyses(&mut self) {
        self.def_uses = None;
        self.reverse_postorders = None;
        self.typing = None;
        self.control_subgraphs = None;
        self.doms = None;
        self.postdoms = None;
        self.fork_join_maps = None;
        self.fork_join_nests = None;
        self.loops = None;
        self.reduce_cycles = None;
        self.data_nodes_in_fork_joins = None;
        self.collection_objects = None;
        self.callgraph = None;
        self.devices = None;
        self.object_device_demands = None;
        self.bbs = None;
        self.node_colors = None;
        self.backing_allocations = None;
    }

    /// Temporarily reassembles a `Module` from the pass manager's contents,
    /// runs `f` on it, then moves the (possibly modified) contents back.
    ///
    /// Cached analyses are not touched.
    fn with_mod<B, F>(&mut self, mut f: F) -> B
    where
        F: FnMut(&mut Module) -> B,
    {
        let mut module = Module {
            functions: std::mem::take(&mut self.functions),
            types: self.types.take(),
            constants: self.constants.take(),
            dynamic_constants: self.dynamic_constants.take(),
            labels: self.labels.take(),
        };
        let res = f(&mut module);
        let Module {
            functions,
            types,
            constants,
            dynamic_constants,
            labels,
        } = module;
        self.functions = functions;
        self.types.replace(types);
        self.constants.replace(constants);
        self.dynamic_constants.replace(dynamic_constants);
        self.labels.replace(labels);
        res
    }

    /// Consumes the pass manager and generates code for every function.
    ///
    /// - `Device::LLVM` functions are lowered to LLVM IR. The IR is written to
    ///   a temporary `.ll` file and compiled by `clang` into
    ///   `{output_dir}/lib{module_name}.a`.
    /// - `Device::AsyncRust` functions are lowered to Rust runtime code, which
    ///   is written to `{output_dir}/rt_{module_name}.hrt`.
    ///
    /// The generated IR, the runtime code and the file paths are also printed
    /// to stdout. `bbs`, `node_colors` and `backing_allocations` must already
    /// be populated; the GCM pass stores them.
    ///
    /// # Errors
    /// Returns `SchedulerError::PassError` in any of these cases:
    /// - those results are missing;
    /// - a backend fails;
    /// - a function is placed on a device without a backend here;
    /// - a file cannot be written;
    /// - `clang` cannot be run or exits unsuccessfully.
    fn codegen(mut self, output_dir: String, module_name: String) -> Result<(), SchedulerError> {
        /// Wraps a lower-level failure as a codegen `PassError`.
        fn codegen_failure(context: &str, err: impl fmt::Display) -> SchedulerError {
            SchedulerError::PassError {
                pass: "codegen".to_string(),
                error: format!("{}: {}", context, err),
            }
        }

        self.make_typing();
        self.make_control_subgraphs();
        self.make_collection_objects();
        self.make_callgraph();
        self.make_devices();
        let PassManager {
            functions,
            types,
            constants,
            dynamic_constants,
            labels,
            typing: Some(typing),
            control_subgraphs: Some(control_subgraphs),
            collection_objects: Some(collection_objects),
            callgraph: Some(callgraph),
            devices: Some(devices),
            bbs: Some(bbs),
            node_colors: Some(node_colors),
            backing_allocations: Some(backing_allocations),
            ..
        } = self
        else {
            return Err(SchedulerError::PassError {
                pass: "codegen".to_string(),
                error: "Missing basic blocks or backing allocations".to_string(),
            });
        };
        let module = Module {
            functions,
            types: types.into_inner(),
            constants: constants.into_inner(),
            dynamic_constants: dynamic_constants.into_inner(),
            labels: labels.into_inner(),
        };
        let mut rust_rt = String::new();
        let mut llvm_ir = String::new();
        for idx in 0..module.functions.len() {
            match devices[idx] {
                Device::LLVM => cpu_codegen(
                    &module.functions[idx],
                    &module.types,
                    &module.constants,
                    &module.dynamic_constants,
                    &typing[idx],
                    &control_subgraphs[idx],
                    &bbs[idx],
                    &mut llvm_ir,
                )
                .map_err(|e| SchedulerError::PassError {
                    pass: "cpu codegen".to_string(),
                    error: format!("{}", e),
                })?,
                Device::AsyncRust => rt_codegen(
                    FunctionID::new(idx),
                    &module,
                    &typing[idx],
                    &control_subgraphs[idx],
                    &collection_objects,
                    &callgraph,
                    &devices,
                    &bbs[idx],
                    &node_colors[idx],
                    &backing_allocations[&FunctionID::new(idx)],
                    &mut rust_rt,
                )
                .map_err(|e| SchedulerError::PassError {
                    pass: "rust codegen".to_string(),
                    error: format!("{}", e),
                })?,
                // No backend is wired up here for other devices yet. Report it
                // instead of aborting the whole compiler with `todo!()`.
                _ => {
                    return Err(SchedulerError::PassError {
                        pass: "codegen".to_string(),
                        error: format!(
                            "No code generator available for device {:?} (function {})",
                            devices[idx], module.functions[idx].name
                        ),
                    });
                }
            }
        }
        println!("{}", llvm_ir);
        println!("{}", rust_rt);

        // Write the LLVM IR into a temporary file. `tmp_dir` is kept alive
        // until the end of this function, i.e. past the clang invocation.
        let tmp_dir = TempDir::new()
            .map_err(|e| codegen_failure("Unable to create temporary directory", e))?;
        let mut tmp_path = tmp_dir.path().to_path_buf();
        tmp_path.push(format!("{}.ll", module_name));
        println!("{}", tmp_path.display());
        // The file handle is dropped (closed) at the end of this statement,
        // before clang reads the file.
        File::create(&tmp_path)
            .and_then(|mut file| file.write_all(llvm_ir.as_bytes()))
            .map_err(|e| {
                codegen_failure(
                    &format!("Unable to write output LLVM IR file {}", tmp_path.display()),
                    e,
                )
            })?;

        // Compile the LLVM IR into a static library. Nothing is fed to clang's
        // stdin and its stdout is discarded. A piped but never-read stdout
        // could fill up and block the child. stderr stays inherited so that
        // diagnostics remain visible.
        let output_archive = format!("{}/lib{}.a", output_dir, module_name);
        let clang_status = Command::new("clang")
            .arg(&tmp_path)
            .arg("--emit-static-lib")
            .arg("-O3")
            .arg("-march=native")
            .arg("-o")
            .arg(&output_archive)
            .stdin(Stdio::null())
            .stdout(Stdio::null())
            .status()
            .map_err(|e| codegen_failure("Error running clang. Is it installed?", e))?;
        if !clang_status.success() {
            return Err(codegen_failure(
                &format!("clang failed to produce {}", output_archive),
                clang_status,
            ));
        }

        // Write the Rust runtime into a file.
        let output_rt = format!("{}/rt_{}.hrt", output_dir, module_name);
        println!("{}", output_rt);
        File::create(&output_rt)
            .and_then(|mut file| file.write_all(rust_rt.as_bytes()))
            .map_err(|e| {
                codegen_failure(
                    &format!("Unable to write output Rust runtime file {}", output_rt),
                    e,
                )
            })?;
        Ok(())
    }
}
/// Interprets `schedule` over `module`, then generates code for the result.
///
/// Interpreting the schedule runs passes and attaches schedules and device
/// placements. Code generation then writes the compiled outputs into
/// `output_dir`, using `module_name` to name them.
///
/// # Errors
/// Returns any `SchedulerError` raised while interpreting the schedule or
/// during code generation.
pub fn schedule_codegen(
    module: Module,
    schedule: ScheduleStmt,
    mut stringtab: StringTable,
    mut env: Env<usize, Value>,
    functions: JunoFunctions,
    output_dir: String,
    module_name: String,
) -> Result<(), SchedulerError> {
    let mut pm = PassManager::new(module);
    // The top-level "did anything change" flag is not needed here.
    schedule_interpret(&mut pm, &schedule, &mut stringtab, &mut env, &functions)?;
    pm.codegen(output_dir, module_name)
}
/// Interprets a schedule statement, running passes on `pm` as directed.
///
/// Returns `Ok(true)` if any pass run while interpreting the statement reported
/// that it changed the IR, and `Ok(false)` otherwise. `ScheduleStmt::Fixpoint`
/// uses this flag to decide when its body has converged.
///
/// # Errors
/// Propagates errors from nested statements and expressions. Also returns:
/// - `UndefinedVariable` when assigning to an unbound variable;
/// - a `SemanticError` when a schedule is applied to everything;
/// - `FixpointFailure` when a `PanicAfter` fixpoint limit is reached.
fn schedule_interpret(
    pm: &mut PassManager,
    schedule: &ScheduleStmt,
    stringtab: &mut StringTable,
    env: &mut Env<usize, Value>,
    functions: &JunoFunctions,
) -> Result<bool, SchedulerError> {
    match schedule {
        ScheduleStmt::Fixpoint { body, limit } => {
            // Number of iterations of `body` that changed the IR. The final,
            // non-changing iteration breaks out before this is incremented.
            let mut i = 0;
            loop {
                // If no change was made, we've reached the fixpoint and are done
                if !schedule_interpret(pm, body, stringtab, env, functions)? {
                    break;
                }
                // Otherwise, increase the iteration count and check the limit
                i += 1;
                match limit {
                    FixpointLimit::NoLimit() => {}
                    FixpointLimit::PrintIter() => {
                        println!("Finished Iteration {}", i - 1)
                    }
                    FixpointLimit::StopAfter(n) => {
                        if i >= *n {
                            break;
                        }
                    }
                    FixpointLimit::PanicAfter(n) => {
                        if i >= *n {
                            return Err(SchedulerError::FixpointFailure());
                        }
                    }
                }
            }
            // `i` counts only the iterations that changed the IR, so the
            // fixpoint as a whole made changes iff at least one iteration did.
            Ok(i > 0)
        }
        ScheduleStmt::Block { body } => {
            let mut modified = false;
            env.open_scope();
            for command in body {
                modified |= schedule_interpret(pm, command, stringtab, env, functions)?;
            }
            env.close_scope();
            Ok(modified)
        }
        ScheduleStmt::Let { var, exp } => {
            let (res, modified) = interp_expr(pm, exp, stringtab, env, functions)?;
            let var_id = stringtab.lookup_string(var.clone());
            env.insert(var_id, res);
            Ok(modified)
        }
        ScheduleStmt::Assign { var, exp } => {
            let (res, modified) = interp_expr(pm, exp, stringtab, env, functions)?;
            let var_id = stringtab.lookup_string(var.clone());
            match env.lookup_mut(&var_id) {
                None => {
                    return Err(SchedulerError::UndefinedVariable(var.clone()));
                }
                Some(val) => {
                    *val = res;
                }
            }
            Ok(modified)
        }
        ScheduleStmt::AddSchedule { sched, on } => match on {
            Selector::Everything() => Err(SchedulerError::SemanticError(
                "Cannot apply schedule to everything".to_string(),
            )),
            Selector::Selection(selection) => {
                let mut changed = false;
                for label in selection {
                    let (label, modified) = interp_expr(pm, label, stringtab, env, functions)?;
                    changed |= modified;
                    add_schedule(pm, sched.clone(), label.as_labels()?);
                }
                Ok(changed)
            }
        },
        ScheduleStmt::AddDevice { device, on } => match on {
            Selector::Everything() => {
                for func in pm.functions.iter_mut() {
                    func.device = Some(device.clone());
                }
                Ok(false)
            }
            Selector::Selection(selection) => {
                let mut changed = false;
                for func in selection {
                    let (func, modified) = interp_expr(pm, func, stringtab, env, functions)?;
                    changed |= modified;
                    add_device(pm, device.clone(), func.as_functions(functions)?);
                }
                Ok(changed)
            }
        },
    }
}
fn interp_expr(
pm: &mut PassManager,
expr: &ScheduleExp,
stringtab: &mut StringTable,
env: &mut Env<usize, Value>,
functions: &JunoFunctions,
) -> Result<(Value, bool), SchedulerError> {
match expr {
ScheduleExp::Variable { var } => {
let var_id = stringtab.lookup_string(var.clone());
match env.lookup(&var_id) {
None => Err(SchedulerError::UndefinedVariable(var.clone())),
Some(v) => Ok((v.clone(), false)),
}
}
ScheduleExp::Integer { val } => Ok((Value::Integer { val: *val }, false)),
ScheduleExp::Boolean { val } => Ok((Value::Boolean { val: *val }, false)),
ScheduleExp::Field { collect, field } => {
let (lhs, changed) = interp_expr(pm, collect, stringtab, env, functions)?;
match lhs {
Value::Label { .. }
| Value::Selection { .. }
| Value::Everything { .. }
| Value::Integer { .. }
| Value::Boolean { .. } => Err(SchedulerError::UndefinedField(field.clone())),
Value::JunoFunction { func } => {
match pm.labels.borrow().iter().position(|s| s == field) {
None => Err(SchedulerError::UndefinedLabel(field.clone())),
Some(label_idx) => Ok((
Value::Label {
labels: functions
.get_function(func)
.iter()
.map(|f| LabelInfo {
func: *f,
label: LabelID::new(label_idx),
})
.collect(),
},
changed,
)),
}
}
Value::HerculesFunction { func } => {
match pm.labels.borrow().iter().position(|s| s == field) {
None => Err(SchedulerError::UndefinedLabel(field.clone())),
Some(label_idx) => Ok((
Value::Label {
labels: vec![LabelInfo {
func: func,
label: LabelID::new(label_idx),
}],
},
changed,
)),
}
}
Value::Record { fields } => match fields.get(field) {
None => Err(SchedulerError::UndefinedField(field.clone())),
Some(v) => Ok((v.clone(), changed)),
},
}
}
ScheduleExp::RunPass { pass, args, on } => {
let mut changed = false;
let mut arg_vals = vec![];
for arg in args {
let (val, modified) = interp_expr(pm, arg, stringtab, env, functions)?;
arg_vals.push(val);
changed |= modified;
}
let selection = match on {
Selector::Everything() => None,
Selector::Selection(selection) => {
let mut locs = vec![];
let mut everything = false;
for loc in selection {
let (val, modified) = interp_expr(pm, loc, stringtab, env, functions)?;
changed |= modified;
if val.is_everything() {
everything = true;
break;
}
locs.extend(val.as_locations(functions)?);
}
if everything {
None
} else {
Some(locs)
}
}
};
let (res, modified) = run_pass(pm, *pass, arg_vals, selection)?;
changed |= modified;
Ok((res, changed))
}
ScheduleExp::Record { fields } => {
let mut result = HashMap::new();
let mut changed = false;
for (field, val) in fields {
let (val, modified) = interp_expr(pm, val, stringtab, env, functions)?;
result.insert(field.clone(), val);
changed |= modified;
}
Ok((Value::Record { fields: result }, changed))
}
ScheduleExp::Block { body, res } => {
let mut changed = false;
env.open_scope();
for command in body {
changed |= schedule_interpret(pm, command, stringtab, env, functions)?;
}
let (res, modified) = interp_expr(pm, res, stringtab, env, functions)?;
env.close_scope();
Ok((res, changed || modified))
}
ScheduleExp::Selection { selection } => match selection {
Selector::Everything() => Ok((Value::Everything {}, false)),
Selector::Selection(selection) => {
let mut values = vec![];
let mut changed = false;
for e in selection {
let (val, modified) = interp_expr(pm, e, stringtab, env, functions)?;
values.push(val);
changed |= modified;
}
Ok((Value::Selection { selection: values }, changed))
}
},
}
}
/// Appends `sched` to the schedules of every node carrying one of the given
/// labels.
///
/// Each `LabelInfo` names a function and a label. Every node of that function
/// whose label set contains the label gets a copy of `sched`.
fn add_schedule(pm: &mut PassManager, sched: Schedule, label_ids: Vec<LabelInfo>) {
    for LabelInfo { func, label } in label_ids {
        let function = &mut pm.functions[func.idx()];
        // `labels` and `schedules` are both indexed by node. Borrowing the two
        // fields separately avoids collecting the matching node indices into
        // a temporary vector first.
        for (node, node_labels) in function.labels.iter().enumerate() {
            if node_labels.contains(&label) {
                function.schedules[node].push(sched.clone());
            }
        }
    }
}
/// Places every function in `funcs` on `device`, replacing any earlier
/// placement.
fn add_device(pm: &mut PassManager, device: Device, funcs: Vec<FunctionID>) {
    funcs
        .into_iter()
        .for_each(|id| pm.functions[id.idx()].device = Some(device.clone()));
}
/// How much of a single function a selection covers.
#[derive(Debug, Clone)]
enum FunctionSelection {
    /// Nothing in the function is selected.
    Nothing(),
    /// The whole function is selected.
    Everything(),
    /// Only nodes carrying at least one of these labels are selected.
    Labels(HashSet<LabelID>),
}
impl FunctionSelection {
    /// Adds `label` to the selection.
    ///
    /// An empty selection becomes a label selection. A whole-function
    /// selection is left unchanged.
    fn add_label(&mut self, label: LabelID) {
        if let FunctionSelection::Labels(selected) = self {
            selected.insert(label);
        } else if matches!(self, FunctionSelection::Nothing()) {
            *self = FunctionSelection::Labels(HashSet::from([label]));
        }
    }

    /// Selects the entire function, discarding any previously selected labels.
    fn add_everything(&mut self) {
        *self = FunctionSelection::Everything();
    }
}
/// Builds an unrestricted `FunctionEditor` for every function, in function-ID
/// order.
///
/// The def-use maps are computed if needed and then moved out of the pass
/// manager, so `pm.def_uses` is `None` afterwards.
fn build_editors<'a>(pm: &'a mut PassManager) -> Vec<FunctionEditor<'a>> {
    pm.make_def_uses();
    let def_uses = pm.def_uses.take().unwrap();
    let mut editors = Vec::with_capacity(pm.functions.len());
    for (idx, (function, function_def_use)) in
        pm.functions.iter_mut().zip(def_uses.iter()).enumerate()
    {
        editors.push(FunctionEditor::new(
            function,
            FunctionID::new(idx),
            &pm.constants,
            &pm.dynamic_constants,
            &pm.types,
            &pm.labels,
            function_def_use,
        ));
    }
    editors
}
// With a selection, we process it to identify which labels in which functions are to be selected
fn construct_selection(pm: &PassManager, selection: Vec<CodeLocation>) -> Vec<FunctionSelection> {
let mut selected = vec![FunctionSelection::Nothing(); pm.functions.len()];
for loc in selection {
match loc {
CodeLocation::Label(label) => selected[label.func.idx()].add_label(label.label),
CodeLocation::Function(func) => selected[func.idx()].add_everything(),
}
}
selected
}
/// Resolves a selection into the functions it covers. Every selected function
/// must be selected in its entirety.
///
/// A `selection` of `None` means the whole module and yields every function
/// ID. Returns `None` if any function is only partially selected (by labels).
fn selection_of_functions(
    pm: &PassManager,
    selection: Option<Vec<CodeLocation>>,
) -> Option<Vec<FunctionID>> {
    let Some(selection) = selection else {
        return Some((0..pm.functions.len()).map(|i| FunctionID::new(i)).collect());
    };
    let mut whole_functions = vec![];
    for (idx, coverage) in construct_selection(pm, selection).into_iter().enumerate() {
        match coverage {
            FunctionSelection::Nothing() => continue,
            FunctionSelection::Everything() => whole_functions.push(FunctionID::new(idx)),
            FunctionSelection::Labels(_) => return None,
        }
    }
    Some(whole_functions)
}
/// Resolves a selection confined to a single function into the set of
/// selected nodes, paired with that function's ID.
///
/// Returns `None` in any of these cases:
/// - the selection is `None` (the whole module);
/// - more than one function is selected;
/// - nothing is selected.
fn selection_as_set(
    pm: &PassManager,
    selection: Option<Vec<CodeLocation>>,
) -> Option<(BTreeSet<NodeID>, FunctionID)> {
    let per_function = construct_selection(pm, selection?);
    let mut found: Option<(BTreeSet<NodeID>, FunctionID)> = None;
    for (idx, (coverage, function)) in per_function.into_iter().zip(&pm.functions).enumerate() {
        let nodes: BTreeSet<NodeID> = match coverage {
            FunctionSelection::Nothing() => continue,
            // A second selected function makes the selection ambiguous.
            _ if found.is_some() => return None,
            FunctionSelection::Everything() => {
                (0..function.nodes.len()).map(|i| NodeID::new(i)).collect()
            }
            FunctionSelection::Labels(labels) => (0..function.nodes.len())
                .filter(|&i| !function.labels[i].is_disjoint(&labels))
                .map(|i| NodeID::new(i))
                .collect(),
        };
        found = Some((nodes, FunctionID::new(idx)));
    }
    found
}
/// Builds function editors for the part of the module covered by `selection`.
///
/// The result has one entry per function:
/// - `None` for functions the selection does not touch;
/// - an unrestricted editor for wholly selected functions;
/// - a label-restricted editor (`FunctionEditor::new_labeled`) where only some
///   labels were selected.
///
/// A `selection` of `None` selects every function in full. In all cases the
/// def-use maps are consumed, so `pm.def_uses` is `None` afterwards.
fn build_selection<'a>(
    pm: &'a mut PassManager,
    selection: Option<Vec<CodeLocation>>,
) -> Vec<Option<FunctionEditor<'a>>> {
    let Some(selection) = selection else {
        // `build_editors` computes and consumes the def-use maps itself, so
        // don't build (and discard) them here first; that computed them twice.
        return build_editors(pm).into_iter().map(Some).collect();
    };
    let selected = construct_selection(pm, selection);
    // Build def uses, which are needed for the editors we'll construct
    pm.make_def_uses();
    let def_uses = pm.def_uses.take().unwrap();
    pm.functions
        .iter_mut()
        .zip(selected.iter())
        .zip(def_uses.iter())
        .enumerate()
        .map(|(idx, ((func, selected), def_use))| match selected {
            FunctionSelection::Nothing() => None,
            FunctionSelection::Everything() => Some(FunctionEditor::new(
                func,
                FunctionID::new(idx),
                &pm.constants,
                &pm.dynamic_constants,
                &pm.types,
                &pm.labels,
                def_use,
            )),
            FunctionSelection::Labels(labels) => Some(FunctionEditor::new_labeled(
                func,
                FunctionID::new(idx),
                &pm.constants,
                &pm.dynamic_constants,
                &pm.types,
                &pm.labels,
                def_use,
                labels,
            )),
        })
        .collect()
}
fn run_pass(
pm: &mut PassManager,
pass: Pass,
args: Vec<Value>,
selection: Option<Vec<CodeLocation>>,
) -> Result<(Value, bool), SchedulerError> {
let mut result = Value::Record {
fields: HashMap::new(),
};
let mut changed = false;
match pass {
Pass::AutoOutline => {
let Some(funcs) = selection_of_functions(pm, selection) else {
return Err(SchedulerError::PassError {
pass: "autoOutline".to_string(),
error: "must be applied to whole functions".to_string(),
});
};
pm.make_def_uses();
let def_uses = pm.def_uses.take().unwrap();
for func in funcs.iter() {
let mut editor = FunctionEditor::new(
&mut pm.functions[func.idx()],
*func,
&pm.constants,
&pm.dynamic_constants,
&pm.types,
&pm.labels,
&def_uses[func.idx()],
);
collapse_returns(&mut editor);
ensure_between_control_flow(&mut editor);
changed |= editor.modified();
}
pm.clear_analyses();
pm.make_def_uses();
pm.make_typing();
pm.make_control_subgraphs();
pm.make_doms();
let def_uses = pm.def_uses.take().unwrap();
let typing = pm.typing.take().unwrap();
let control_subgraphs = pm.control_subgraphs.take().unwrap();
let doms = pm.doms.take().unwrap();
let old_num_funcs = pm.functions.len();
let mut new_funcs = vec![];
// Track the names of the old functions and the new function IDs for returning
let mut new_func_ids = HashMap::new();
for func in funcs {
let mut editor = FunctionEditor::new(
&mut pm.functions[func.idx()],
func,
&pm.constants,
&pm.dynamic_constants,
&pm.types,
&pm.labels,
&def_uses[func.idx()],
);
let new_func_id = FunctionID::new(old_num_funcs + new_funcs.len());
let new_func = dumb_outline(
&mut editor,
&typing[func.idx()],
&control_subgraphs[func.idx()],
&doms[func.idx()],
new_func_id,
);
changed |= editor.modified();
if let Some(new_func) = new_func {
new_func_ids.insert(
editor.func().name.clone(),
Value::HerculesFunction { func: new_func_id },
);
new_funcs.push(new_func);
}
pm.functions[func.idx()].delete_gravestones();
}
pm.functions.extend(new_funcs);
pm.clear_analyses();
result = Value::Record {
fields: new_func_ids,
};
}
Pass::CCP => {
assert!(args.is_empty());
pm.make_reverse_postorders();
let reverse_postorders = pm.reverse_postorders.take().unwrap();
for (func, reverse_postorder) in build_selection(pm, selection)
.into_iter()
.zip(reverse_postorders.iter())
{
let Some(mut func) = func else {
continue;
};
ccp(&mut func, reverse_postorder);
changed |= func.modified();
}
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::CRC => {
assert!(args.is_empty());
for func in build_selection(pm, selection) {
let Some(mut func) = func else {
continue;
};
crc(&mut func);
changed |= func.modified();
}
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::DCE => {
assert!(args.is_empty());
for func in build_selection(pm, selection) {
let Some(mut func) = func else {
continue;
};
dce(&mut func);
changed |= func.modified();
}
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::DeleteUncalled => {
todo!("Delete Uncalled changes FunctionIDs, a bunch of bookkeeping is needed for the pass manager to address this")
}
Pass::FloatCollections => {
assert!(args.is_empty());
if let Some(_) = selection {
return Err(SchedulerError::PassError {
pass: "floatCollections".to_string(),
error: "must be applied to the entire module".to_string(),
});
}
pm.make_typing();
pm.make_callgraph();
pm.make_devices();
let typing = pm.typing.take().unwrap();
let callgraph = pm.callgraph.take().unwrap();
let devices = pm.devices.take().unwrap();
let mut editors = build_editors(pm);
float_collections(&mut editors, &typing, &callgraph, &devices);
for func in editors {
changed |= func.modified();
}
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::ForkGuardElim => {
todo!("Fork Guard Elim doesn't use editor")
}
Pass::ForkSplit => {
assert!(args.is_empty());
pm.make_fork_join_maps();
pm.make_reduce_cycles();
let fork_join_maps = pm.fork_join_maps.take().unwrap();
let reduce_cycles = pm.reduce_cycles.take().unwrap();
for ((func, fork_join_map), reduce_cycles) in build_selection(pm, selection)
.into_iter()
.zip(fork_join_maps.iter())
.zip(reduce_cycles.iter())
{
let Some(mut func) = func else {
continue;
};
fork_split(&mut func, fork_join_map, reduce_cycles);
changed |= func.modified();
}
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::Forkify => {
todo!("Forkify doesn't use editor")
}
Pass::GCM => {
assert!(args.is_empty());
if let Some(_) = selection {
return Err(SchedulerError::PassError {
pass: "gcm".to_string(),
error: "must be applied to the entire module".to_string(),
});
}
// Iterate functions in reverse topological order, since inter-
// device copies introduced in a callee may affect demands in a
// caller, and the object allocation of a callee affects the object
// allocation of its callers.
pm.make_callgraph();
let callgraph = pm.callgraph.take().unwrap();
let topo = callgraph.topo();
loop {
pm.make_def_uses();
pm.make_reverse_postorders();
pm.make_typing();
pm.make_control_subgraphs();
pm.make_doms();
pm.make_fork_join_maps();
pm.make_loops();
pm.make_collection_objects();
pm.make_devices();
pm.make_object_device_demands();
let def_uses = pm.def_uses.take().unwrap();
let reverse_postorders = pm.reverse_postorders.take().unwrap();
let typing = pm.typing.take().unwrap();
let doms = pm.doms.take().unwrap();
let fork_join_maps = pm.fork_join_maps.take().unwrap();
let loops = pm.loops.take().unwrap();
let control_subgraphs = pm.control_subgraphs.take().unwrap();
let collection_objects = pm.collection_objects.take().unwrap();
let devices = pm.devices.take().unwrap();
let object_device_demands = pm.object_device_demands.take().unwrap();
let mut bbs = vec![(vec![], vec![]); topo.len()];
let mut node_colors = vec![BTreeMap::new(); topo.len()];
let mut backing_allocations = BTreeMap::new();
let mut editors = build_editors(pm);
let mut any_failed = false;
for id in topo.iter() {
let editor = &mut editors[id.idx()];
if let Some((bb, function_node_colors, backing_allocation)) = gcm(
editor,
&def_uses[id.idx()],
&reverse_postorders[id.idx()],
&typing[id.idx()],
&control_subgraphs[id.idx()],
&doms[id.idx()],
&fork_join_maps[id.idx()],
&loops[id.idx()],
&collection_objects,
&devices,
&object_device_demands[id.idx()],
&backing_allocations,
) {
bbs[id.idx()] = bb;
node_colors[id.idx()] = function_node_colors;
backing_allocations.insert(*id, backing_allocation);
} else {
any_failed = true;
}
changed |= editor.modified();
if any_failed {
break;
}
}
pm.delete_gravestones();
pm.clear_analyses();
if !any_failed {
pm.bbs = Some(bbs);
pm.node_colors = Some(node_colors);
pm.backing_allocations = Some(backing_allocations);
break;
}
}
}
Pass::GVN => {
assert!(args.is_empty());
for func in build_selection(pm, selection) {
let Some(mut func) = func else {
continue;
};
gvn(&mut func, false);
changed |= func.modified();
}
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::InferSchedules => {
assert!(args.is_empty());
pm.make_fork_join_maps();
pm.make_reduce_cycles();
let fork_join_maps = pm.fork_join_maps.take().unwrap();
let reduce_cycles = pm.reduce_cycles.take().unwrap();
for ((func, fork_join_map), reduce_cycles) in build_selection(pm, selection)
.into_iter()
.zip(fork_join_maps.iter())
.zip(reduce_cycles.iter())
{
let Some(mut func) = func else {
continue;
};
infer_parallel_reduce(&mut func, fork_join_map, reduce_cycles);
infer_parallel_fork(&mut func, fork_join_map);
infer_vectorizable(&mut func, fork_join_map);
infer_tight_associative(&mut func, reduce_cycles);
changed |= func.modified();
}
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::Inline => {
assert!(args.is_empty());
if let Some(_) = selection {
return Err(SchedulerError::PassError {
pass: "inline".to_string(),
error: "must be applied to the entire module (currently)".to_string(),
});
}
pm.make_callgraph();
let callgraph = pm.callgraph.take().unwrap();
let mut editors = build_editors(pm);
inline(&mut editors, &callgraph);
for func in editors {
changed |= func.modified();
}
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::InterproceduralSROA => {
assert!(args.is_empty());
if let Some(_) = selection {
return Err(SchedulerError::PassError {
pass: "interproceduralSROA".to_string(),
error: "must be applied to the entire module".to_string(),
});
}
let mut editors = build_editors(pm);
interprocedural_sroa(&mut editors);
for func in editors {
changed |= func.modified();
}
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::LiftDCMath => {
assert!(args.is_empty());
for func in build_selection(pm, selection) {
let Some(mut func) = func else {
continue;
};
lift_dc_math(&mut func);
changed |= func.modified();
}
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::Outline => {
let Some((nodes, func)) = selection_as_set(pm, selection) else {
return Err(SchedulerError::PassError {
pass: "outline".to_string(),
error: "must be applied to nodes in a single function".to_string(),
});
};
pm.make_def_uses();
let def_uses = pm.def_uses.take().unwrap();
let mut editor = FunctionEditor::new(
&mut pm.functions[func.idx()],
func,
&pm.constants,
&pm.dynamic_constants,
&pm.types,
&pm.labels,
&def_uses[func.idx()],
);
collapse_returns(&mut editor);
ensure_between_control_flow(&mut editor);
pm.clear_analyses();
pm.make_def_uses();
pm.make_typing();
pm.make_control_subgraphs();
pm.make_doms();
let def_uses = pm.def_uses.take().unwrap();
let typing = pm.typing.take().unwrap();
let control_subgraphs = pm.control_subgraphs.take().unwrap();
let doms = pm.doms.take().unwrap();
let new_func_id = FunctionID::new(pm.functions.len());
let mut editor = FunctionEditor::new(
&mut pm.functions[func.idx()],
func,
&pm.constants,
&pm.dynamic_constants,
&pm.types,
&pm.labels,
&def_uses[func.idx()],
);
let new_func = outline(
&mut editor,
&typing[func.idx()],
&control_subgraphs[func.idx()],
&doms[func.idx()],
&nodes,
new_func_id,
);
let Some(new_func) = new_func else {
return Err(SchedulerError::PassError {
pass: "outlining".to_string(),
error: "failed to outline".to_string(),
});
};
pm.functions.push(new_func);
changed = true;
pm.functions[func.idx()].delete_gravestones();
pm.clear_analyses();
result = Value::HerculesFunction { func: new_func_id };
}
Pass::PhiElim => {
assert!(args.is_empty());
for func in build_selection(pm, selection) {
let Some(mut func) = func else {
continue;
};
phi_elim(&mut func);
changed |= func.modified();
}
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::Predication => {
assert!(args.is_empty());
pm.make_typing();
let typing = pm.typing.take().unwrap();
for (func, types) in build_selection(pm, selection)
.into_iter()
.zip(typing.iter())
{
let Some(mut func) = func else {
continue;
};
predication(&mut func, types);
changed |= func.modified();
}
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::SLF => {
assert!(args.is_empty());
pm.make_reverse_postorders();
pm.make_typing();
let reverse_postorders = pm.reverse_postorders.take().unwrap();
let typing = pm.typing.take().unwrap();
for ((func, reverse_postorder), types) in build_selection(pm, selection)
.into_iter()
.zip(reverse_postorders.iter())
.zip(typing.iter())
{
let Some(mut func) = func else {
continue;
};
slf(&mut func, reverse_postorder, types);
changed |= func.modified();
}
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::SROA => {
assert!(args.is_empty());
pm.make_reverse_postorders();
pm.make_typing();
let reverse_postorders = pm.reverse_postorders.take().unwrap();
let typing = pm.typing.take().unwrap();
for ((func, reverse_postorder), types) in build_selection(pm, selection)
.into_iter()
.zip(reverse_postorders.iter())
.zip(typing.iter())
{
let Some(mut func) = func else {
continue;
};
sroa(&mut func, reverse_postorder, types);
changed |= func.modified();
}
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::Unforkify => {
assert!(args.is_empty());
pm.make_fork_join_maps();
let fork_join_maps = pm.fork_join_maps.take().unwrap();
for (func, fork_join_map) in build_selection(pm, selection)
.into_iter()
.zip(fork_join_maps.iter())
{
let Some(mut func) = func else {
continue;
};
unforkify(&mut func, fork_join_map);
changed |= func.modified();
}
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::WritePredication => {
assert!(args.is_empty());
for func in build_selection(pm, selection) {
let Some(mut func) = func else {
continue;
};
write_predication(&mut func);
changed |= func.modified();
}
pm.delete_gravestones();
pm.clear_analyses();
}
Pass::Verify => {
assert!(args.is_empty());
let (def_uses, reverse_postorders, typing, subgraphs, doms, postdoms, fork_join_maps) =
pm.with_mod(|module| verify(module))
.map_err(|msg| SchedulerError::PassError {
pass: "verify".to_string(),
error: format!("failed: {}", msg),
})?;
// Verification produces a bunch of analysis results that
// may be useful for later passes.
pm.def_uses = Some(def_uses);
pm.reverse_postorders = Some(reverse_postorders);
pm.typing = Some(typing);
pm.control_subgraphs = Some(subgraphs);
pm.doms = Some(doms);
pm.postdoms = Some(postdoms);
pm.fork_join_maps = Some(fork_join_maps);
}
Pass::Xdot => {
let force_analyses = match args.get(0) {
Some(Value::Boolean { val }) => *val,
Some(_) => {
return Err(SchedulerError::PassError {
pass: "xdot".to_string(),
error: "expected boolean argument".to_string(),
});
}
None => true,
};
pm.make_reverse_postorders();
if force_analyses {
pm.make_doms();
pm.make_fork_join_maps();
pm.make_devices();
}
let reverse_postorders = pm.reverse_postorders.take().unwrap();
let doms = pm.doms.take();
let fork_join_maps = pm.fork_join_maps.take();
let devices = pm.devices.take();
let bbs = pm.bbs.take();
pm.with_mod(|module| {
xdot_module(
module,
&reverse_postorders,
doms.as_ref(),
fork_join_maps.as_ref(),
devices.as_ref(),
bbs.as_ref(),
)
});
// Put BasicBlocks back, since it's needed for Codegen.
pm.bbs = bbs;
}
}
println!("Ran Pass: {:?}", pass);
Ok((result, changed))
}