// Skip to content
// Snippets Groups Projects
// pm.rs 57.83 KiB
use crate::ir::*;
use crate::labels::*;
use hercules_cg::*;
use hercules_ir::*;
use hercules_opt::*;

use tempfile::TempDir;

use juno_utils::env::Env;
use juno_utils::stringtab::StringTable;

use std::cell::RefCell;
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::fmt;
use std::fs::File;
use std::io::Write;
use std::iter::zip;
use std::process::{Command, Stdio};

/// A value produced while interpreting a schedule expression.
#[derive(Debug, Clone)]
pub enum Value {
    /// A list of labels, each pairing a function with a label ID.
    Label { labels: Vec<LabelInfo> },
    /// A Juno function; it is resolved to Hercules function IDs through
    /// `JunoFunctions::get_function`.
    JunoFunction { func: JunoFunctionID },
    /// A single Hercules function.
    HerculesFunction { func: FunctionID },
    /// A record mapping field names to values.
    Record { fields: HashMap<String, Value> },
    /// The result of evaluating an "everything" selector.
    Everything {},
    /// A list of values produced by evaluating a selection expression.
    Selection { selection: Vec<Value> },
    /// An integer literal.
    Integer { val: usize },
    /// A boolean literal.
    Boolean { val: bool },
}

/// A place in the code that a pass can be applied to: either one label within a function or
/// an entire function.
#[derive(Debug, Copy, Clone)]
enum CodeLocation {
    /// A specific label (function plus label ID).
    Label(LabelInfo),
    /// A whole function.
    Function(FunctionID),
}

impl Value {
    fn is_everything(&self) -> bool {
        match self {
            Value::Everything {} => true,
            _ => false,
        }
    }

    fn as_labels(&self) -> Result<Vec<LabelInfo>, SchedulerError> {
        match self {
            Value::Label { labels } => Ok(labels.clone()),
            Value::Selection { selection } => {
                let mut result = vec![];
                for val in selection {
                    result.extend(val.as_labels()?);
                }
                Ok(result)
            }
            Value::JunoFunction { .. } | Value::HerculesFunction { .. } => Err(
                SchedulerError::SemanticError("Expected labels, found function".to_string()),
            ),
            Value::Record { .. } => Err(SchedulerError::SemanticError(
                "Expected labels, found record".to_string(),
            )),
            Value::Everything {} => Err(SchedulerError::SemanticError(
                "Expected labels, found everything".to_string(),
            )),
            Value::Integer { .. } => Err(SchedulerError::SemanticError(
                "Expected labels, found integer".to_string(),
            )),
            Value::Boolean { .. } => Err(SchedulerError::SemanticError(
                "Expected labels, found boolean".to_string(),
            )),
        }
    }

    fn as_functions(&self, funcs: &JunoFunctions) -> Result<Vec<FunctionID>, SchedulerError> {
        match self {
            Value::JunoFunction { func } => Ok(funcs.get_function(*func).clone()),
            Value::HerculesFunction { func } => Ok(vec![*func]),
            Value::Selection { selection } => {
                let mut result = vec![];
                for val in selection {
                    result.extend(val.as_functions(funcs)?);
                }
                Ok(result)
            }
            Value::Label { .. } => Err(SchedulerError::SemanticError(
                "Expected functions, found label".to_string(),
            )),
            Value::Record { .. } => Err(SchedulerError::SemanticError(
                "Expected functions, found record".to_string(),
            )),
            Value::Everything {} => Err(SchedulerError::SemanticError(
                "Expected functions, found everything".to_string(),
            )),
            Value::Integer { .. } => Err(SchedulerError::SemanticError(
                "Expected functions, found integer".to_string(),
            )),
            Value::Boolean { .. } => Err(SchedulerError::SemanticError(
                "Expected functions, found boolean".to_string(),
            )),
        }
    }

    fn as_locations(&self, funcs: &JunoFunctions) -> Result<Vec<CodeLocation>, SchedulerError> {
        match self {
            Value::Label { labels } => Ok(labels.iter().map(|l| CodeLocation::Label(*l)).collect()),
            Value::JunoFunction { func } => Ok(funcs
                .get_function(*func)
                .iter()
                .map(|f| CodeLocation::Function(*f))
                .collect()),
            Value::HerculesFunction { func } => Ok(vec![CodeLocation::Function(*func)]),
            Value::Selection { selection } => {
                let mut result = vec![];
                for val in selection {
                    result.extend(val.as_locations(funcs)?);
                }
                Ok(result)
            }
            Value::Record { .. } => Err(SchedulerError::SemanticError(
                "Expected code locations, found record".to_string(),
            )),
            Value::Everything {} => {
                panic!("Internal error, check is_everything() before using as_functions()")
            }
            Value::Integer { .. } => Err(SchedulerError::SemanticError(
                "Expected code locations, found integer".to_string(),
            )),
            Value::Boolean { .. } => Err(SchedulerError::SemanticError(
                "Expected code locations, found boolean".to_string(),
            )),
        }
    }
}

/// Errors reported while interpreting a schedule or generating code.
#[derive(Debug, Clone)]
pub enum SchedulerError {
    /// A variable name that is not bound in the environment.
    UndefinedVariable(String),
    /// A field access that the accessed value does not support or does not contain.
    UndefinedField(String),
    /// A label name that is not present in the module's label table.
    UndefinedLabel(String),
    /// Any other semantic problem; the string is the full message.
    SemanticError(String),
    /// A failure inside a named pass, with the pass's own error text.
    PassError { pass: String, error: String },
    /// A fixpoint whose iteration limit was reached with a `PanicAfter` limit.
    FixpointFailure(),
}

impl fmt::Display for SchedulerError {
    /// Writes the human-readable message for each error variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Variants whose text needs no interpolation are written directly.
            SchedulerError::SemanticError(msg) => f.write_str(msg),
            SchedulerError::FixpointFailure() => {
                f.write_str("Fixpoint did not converge within limit")
            }
            SchedulerError::UndefinedVariable(name) => {
                write!(f, "Undefined variable '{name}'")
            }
            SchedulerError::UndefinedField(name) => write!(f, "No field '{name}'"),
            SchedulerError::UndefinedLabel(name) => write!(f, "No label '{name}'"),
            SchedulerError::PassError { pass, error } => {
                write!(f, "Error in pass {pass}: {error}")
            }
        }
    }
}

/// Owns the contents of a module while a schedule is interpreted, together with lazily
/// computed analysis results.
#[derive(Debug)]
struct PassManager {
    // The module's tables. Everything except `functions` sits in a `RefCell`; references to
    // those cells are handed to `FunctionEditor::new` in `build_editors`.
    functions: Vec<Function>,
    types: RefCell<Vec<Type>>,
    constants: RefCell<Vec<Constant>>,
    dynamic_constants: RefCell<Vec<DynamicConstant>>,
    labels: RefCell<Vec<String>>,

    // Cached analysis results. Each is `None` until it has been computed; `clear_analyses`
    // resets all of them to `None`.
    pub def_uses: Option<Vec<ImmutableDefUseMap>>,
    pub reverse_postorders: Option<Vec<Vec<NodeID>>>,
    pub typing: Option<ModuleTyping>,
    pub control_subgraphs: Option<Vec<Subgraph>>,
    pub doms: Option<Vec<DomTree>>,
    pub postdoms: Option<Vec<DomTree>>,
    pub fork_join_maps: Option<Vec<HashMap<NodeID, NodeID>>>,
    pub fork_join_nests: Option<Vec<HashMap<NodeID, Vec<NodeID>>>>,
    pub loops: Option<Vec<LoopTree>>,
    pub reduce_cycles: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
    pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>,
    pub collection_objects: Option<CollectionObjects>,
    pub callgraph: Option<CallGraph>,
    pub devices: Option<Vec<Device>>,
    pub object_device_demands: Option<ObjectDeviceDemands>,
    // The three fields below have no `make_*` method in this part of the file, but `codegen`
    // requires all of them to be `Some`.
    pub bbs: Option<Vec<BasicBlocks>>,
    pub node_colors: Option<NodeColors>,
    pub backing_allocations: Option<BackingAllocations>,
}

impl PassManager {
    /// Builds a pass manager that takes ownership of every table in `module`.
    ///
    /// The type, constant, dynamic-constant and label tables are wrapped in `RefCell`s, and
    /// every cached analysis starts out as `None`.
    fn new(module: Module) -> Self {
        PassManager {
            // Module contents.
            functions: module.functions,
            types: RefCell::new(module.types),
            constants: RefCell::new(module.constants),
            dynamic_constants: RefCell::new(module.dynamic_constants),
            labels: RefCell::new(module.labels),
            // No analysis has been computed yet.
            def_uses: None,
            reverse_postorders: None,
            typing: None,
            control_subgraphs: None,
            doms: None,
            postdoms: None,
            fork_join_maps: None,
            fork_join_nests: None,
            loops: None,
            reduce_cycles: None,
            data_nodes_in_fork_joins: None,
            collection_objects: None,
            callgraph: None,
            devices: None,
            object_device_demands: None,
            bbs: None,
            node_colors: None,
            backing_allocations: None,
        }
    }

    pub fn make_def_uses(&mut self) {
        if self.def_uses.is_none() {
            self.def_uses = Some(self.functions.iter().map(def_use).collect());
        }
    }

    pub fn make_reverse_postorders(&mut self) {
        if self.reverse_postorders.is_none() {
            self.make_def_uses();
            self.reverse_postorders = Some(
                self.def_uses
                    .as_ref()
                    .unwrap()
                    .iter()
                    .map(reverse_postorder)
                    .collect(),
            );
        }
    }

    /// Runs `typecheck` over the whole module and caches its result, unless already cached.
    ///
    /// The type and dynamic-constant tables are borrowed mutably for the duration of the
    /// call; the constant table is borrowed immutably.
    ///
    /// # Panics
    /// Panics if `typecheck` returns an error, or if one of the tables is already borrowed
    /// incompatibly.
    pub fn make_typing(&mut self) {
        if self.typing.is_some() {
            return;
        }
        self.make_reverse_postorders();
        let reverse_postorders = self.reverse_postorders.as_ref().unwrap();
        let checked = typecheck(
            &self.functions,
            &mut self.types.borrow_mut(),
            &self.constants.borrow(),
            &mut self.dynamic_constants.borrow_mut(),
            reverse_postorders,
        )
        .unwrap();
        self.typing = Some(checked);
    }

    /// Computes the control subgraph of every function (from the function and its def-use
    /// map) unless already cached.
    pub fn make_control_subgraphs(&mut self) {
        if self.control_subgraphs.is_some() {
            return;
        }
        self.make_def_uses();
        let def_uses = self.def_uses.as_ref().unwrap();
        let subgraphs: Vec<_> = self
            .functions
            .iter()
            .zip(def_uses.iter())
            .map(|(function, uses)| control_subgraph(function, uses))
            .collect();
        self.control_subgraphs = Some(subgraphs);
    }

    /// Caches one dominator tree per function, computed by `dominator` over the function's
    /// control subgraph with `NodeID::new(0)` as the root.
    pub fn make_doms(&mut self) {
        if self.doms.is_some() {
            return;
        }
        self.make_control_subgraphs();
        let subgraphs = self.control_subgraphs.as_ref().unwrap();
        let mut trees = Vec::with_capacity(subgraphs.len());
        for subgraph in subgraphs {
            trees.push(dominator(subgraph, NodeID::new(0)));
        }
        self.doms = Some(trees);
    }

    /// Caches one tree per function in `postdoms`, computed by `dominator` over the
    /// function's control subgraph with `NodeID::new(function.nodes.len())` as the root.
    pub fn make_postdoms(&mut self) {
        if self.postdoms.is_some() {
            return;
        }
        self.make_control_subgraphs();
        let subgraphs = self.control_subgraphs.as_ref().unwrap();
        let trees: Vec<_> = subgraphs
            .iter()
            .zip(self.functions.iter())
            .map(|(subgraph, function)| {
                // NOTE(review): the root index is one past the last index of
                // `function.nodes`, and the routine called is `dominator` on the unmodified
                // subgraph — presumably a synthetic exit node is expected; confirm this
                // really yields post-dominators.
                let root = NodeID::new(function.nodes.len());
                dominator(subgraph, root)
            })
            .collect();
        self.postdoms = Some(trees);
    }

    /// Computes the fork-join map of every function (from the function and its control
    /// subgraph) unless already cached.
    pub fn make_fork_join_maps(&mut self) {
        if self.fork_join_maps.is_some() {
            return;
        }
        self.make_control_subgraphs();
        let subgraphs = self.control_subgraphs.as_ref().unwrap();
        let maps: Vec<_> = self
            .functions
            .iter()
            .zip(subgraphs.iter())
            .map(|(function, subgraph)| fork_join_map(function, subgraph))
            .collect();
        self.fork_join_maps = Some(maps);
    }

    /// Computes the fork-join nesting of every function unless already cached.
    ///
    /// Requires (and therefore first builds) the dominator trees and the fork-join maps.
    pub fn make_fork_join_nests(&mut self) {
        if self.fork_join_nests.is_some() {
            return;
        }
        self.make_doms();
        self.make_fork_join_maps();
        let doms = self.doms.as_ref().unwrap();
        let fork_join_maps = self.fork_join_maps.as_ref().unwrap();
        let nests: Vec<_> = self
            .functions
            .iter()
            .zip(doms.iter())
            .zip(fork_join_maps.iter())
            .map(|((function, dom), fork_joins)| {
                compute_fork_join_nesting(function, dom, fork_joins)
            })
            .collect();
        self.fork_join_nests = Some(nests);
    }

    /// Computes the loop tree of every function unless already cached.
    ///
    /// Each tree is built by `loops` from the function's control subgraph, the root
    /// `NodeID::new(0)`, its dominator tree and its fork-join map.
    pub fn make_loops(&mut self) {
        if self.loops.is_some() {
            return;
        }
        self.make_control_subgraphs();
        self.make_doms();
        self.make_fork_join_maps();
        let subgraphs = self.control_subgraphs.as_ref().unwrap();
        let doms = self.doms.as_ref().unwrap();
        let fork_join_maps = self.fork_join_maps.as_ref().unwrap();
        let trees: Vec<_> = subgraphs
            .iter()
            .zip(doms.iter())
            .zip(fork_join_maps.iter())
            .map(|((subgraph, dom), fork_joins)| {
                loops(subgraph, NodeID::new(0), dom, fork_joins)
            })
            .collect();
        self.loops = Some(trees);
    }

    /// Computes the reduce cycles of every function (from the function and its def-use map)
    /// unless already cached.
    pub fn make_reduce_cycles(&mut self) {
        if self.reduce_cycles.is_some() {
            return;
        }
        self.make_def_uses();
        let def_uses = self.def_uses.as_ref().unwrap();
        let cycles: Vec<_> = self
            .functions
            .iter()
            .zip(def_uses.iter())
            .map(|(function, uses)| reduce_cycles(function, uses))
            .collect();
        self.reduce_cycles = Some(cycles);
    }

    /// Computes, for every function, the data nodes contained in each fork-join, unless
    /// already cached. Requires the def-use maps and the fork-join maps.
    pub fn make_data_nodes_in_fork_joins(&mut self) {
        if self.data_nodes_in_fork_joins.is_some() {
            return;
        }
        self.make_def_uses();
        self.make_fork_join_maps();
        let def_uses = self.def_uses.as_ref().unwrap();
        let fork_join_maps = self.fork_join_maps.as_ref().unwrap();
        let data_nodes: Vec<_> = self
            .functions
            .iter()
            .zip(def_uses.iter())
            .zip(fork_join_maps.iter())
            .map(|((function, uses), fork_joins)| {
                data_nodes_in_fork_joins(function, uses, fork_joins)
            })
            .collect();
        self.data_nodes_in_fork_joins = Some(data_nodes);
    }

    /// Computes the module's collection objects unless already cached.
    ///
    /// Requires the reverse postorders, the typing and the call graph, all of which are
    /// built first if missing.
    pub fn make_collection_objects(&mut self) {
        if self.collection_objects.is_some() {
            return;
        }
        self.make_reverse_postorders();
        self.make_typing();
        self.make_callgraph();
        let objects = collection_objects(
            &self.functions,
            &self.types.borrow(),
            self.reverse_postorders.as_ref().unwrap(),
            self.typing.as_ref().unwrap(),
            self.callgraph.as_ref().unwrap(),
        );
        self.collection_objects = Some(objects);
    }

    /// Computes the call graph over all functions unless already cached.
    pub fn make_callgraph(&mut self) {
        if self.callgraph.is_some() {
            return;
        }
        let graph = callgraph(&self.functions);
        self.callgraph = Some(graph);
    }

    /// Runs `device_placement` over the functions and the call graph, caching the resulting
    /// per-function device list, unless already cached.
    pub fn make_devices(&mut self) {
        if self.devices.is_some() {
            return;
        }
        self.make_callgraph();
        let graph = self.callgraph.as_ref().unwrap();
        let placement = device_placement(&self.functions, graph);
        self.devices = Some(placement);
    }

    /// Computes the object device demands unless already cached.
    ///
    /// Requires the typing, call graph, collection objects and device placement, all of
    /// which are built first if missing.
    pub fn make_object_device_demands(&mut self) {
        if self.object_device_demands.is_some() {
            return;
        }
        self.make_typing();
        self.make_callgraph();
        self.make_collection_objects();
        self.make_devices();
        let demands = object_device_demands(
            &self.functions,
            &self.types.borrow(),
            self.typing.as_ref().unwrap(),
            self.callgraph.as_ref().unwrap(),
            self.collection_objects.as_ref().unwrap(),
            self.devices.as_ref().unwrap(),
        );
        self.object_device_demands = Some(demands);
    }

    /// Calls `delete_gravestones` on every function in the module.
    pub fn delete_gravestones(&mut self) {
        self.functions.iter_mut().for_each(|function| {
            function.delete_gravestones();
        });
    }

    /// Drops every cached analysis result by resetting each cache field to `None`, so the
    /// next `make_*` call recomputes it.
    fn clear_analyses(&mut self) {
        self.def_uses = None;
        self.reverse_postorders = None;
        self.typing = None;
        self.control_subgraphs = None;
        self.doms = None;
        self.postdoms = None;
        self.fork_join_maps = None;
        self.fork_join_nests = None;
        self.loops = None;
        self.reduce_cycles = None;
        self.data_nodes_in_fork_joins = None;
        self.collection_objects = None;
        self.callgraph = None;
        self.devices = None;
        self.object_device_demands = None;
        self.bbs = None;
        self.node_colors = None;
        self.backing_allocations = None;
    }

    /// Runs `f` on a temporary `Module` assembled from this manager's tables and returns
    /// whatever `f` returns.
    ///
    /// The tables are moved out of the manager (leaving defaults behind) for the duration of
    /// the call and moved back afterwards, including any changes `f` made. The cached
    /// analyses are not touched by this method.
    fn with_mod<B, F>(&mut self, mut f: F) -> B
    where
        F: FnMut(&mut Module) -> B,
    {
        // Move every table out of the manager, in the same order as the module's fields.
        let functions = std::mem::take(&mut self.functions);
        let types = self.types.take();
        let constants = self.constants.take();
        let dynamic_constants = self.dynamic_constants.take();
        let labels = self.labels.take();
        let mut module = Module {
            functions,
            types,
            constants,
            dynamic_constants,
            labels,
        };

        let result = f(&mut module);

        // Put the (possibly modified) tables back.
        self.functions = module.functions;
        self.types.replace(module.types);
        self.constants.replace(module.constants);
        self.dynamic_constants.replace(module.dynamic_constants);
        self.labels.replace(module.labels);

        result
    }

    fn codegen(mut self, output_dir: String, module_name: String) -> Result<(), SchedulerError> {
        self.make_typing();
        self.make_control_subgraphs();
        self.make_collection_objects();
        self.make_callgraph();
        self.make_devices();

        let PassManager {
            functions,
            types,
            constants,
            dynamic_constants,
            labels,
            typing: Some(typing),
            control_subgraphs: Some(control_subgraphs),
            collection_objects: Some(collection_objects),
            callgraph: Some(callgraph),
            devices: Some(devices),
            bbs: Some(bbs),
            node_colors: Some(node_colors),
            backing_allocations: Some(backing_allocations),
            ..
        } = self
        else {
            return Err(SchedulerError::PassError {
                pass: "codegen".to_string(),
                error: "Missing basic blocks or backing allocations".to_string(),
            });
        };

        let module = Module {
            functions,
            types: types.into_inner(),
            constants: constants.into_inner(),
            dynamic_constants: dynamic_constants.into_inner(),
            labels: labels.into_inner(),
        };

        let mut rust_rt = String::new();
        let mut llvm_ir = String::new();
        for idx in 0..module.functions.len() {
            match devices[idx] {
                Device::LLVM => cpu_codegen(
                    &module.functions[idx],
                    &module.types,
                    &module.constants,
                    &module.dynamic_constants,
                    &typing[idx],
                    &control_subgraphs[idx],
                    &bbs[idx],
                    &mut llvm_ir,
                )
                .map_err(|e| SchedulerError::PassError {
                    pass: "cpu codegen".to_string(),
                    error: format!("{}", e),
                })?,
                Device::AsyncRust => rt_codegen(
                    FunctionID::new(idx),
                    &module,
                    &typing[idx],
                    &control_subgraphs[idx],
                    &collection_objects,
                    &callgraph,
                    &devices,
                    &bbs[idx],
                    &node_colors[idx],
                    &backing_allocations[&FunctionID::new(idx)],
                    &mut rust_rt,
                )
                .map_err(|e| SchedulerError::PassError {
                    pass: "rust codegen".to_string(),
                    error: format!("{}", e),
                })?,
                _ => todo!(),
            }
        }
        println!("{}", llvm_ir);
        println!("{}", rust_rt);

        // Write the LLVM IR into a temporary file.
        let tmp_dir = TempDir::new().unwrap();
        let mut tmp_path = tmp_dir.path().to_path_buf();
        tmp_path.push(format!("{}.ll", module_name));
        println!("{}", tmp_path.display());
        let mut file = File::create(&tmp_path).expect("PANIC: Unable to open output LLVM IR file.");
        file.write_all(llvm_ir.as_bytes())
            .expect("PANIC: Unable to write output LLVM IR file contents.");

        // Compile LLVM IR into an ELF object file.
        let output_archive = format!("{}/lib{}.a", output_dir, module_name);
        let mut clang_process = Command::new("clang")
            .arg(&tmp_path)
            .arg("--emit-static-lib")
            .arg("-O3")
            .arg("-march=native")
            .arg("-o")
            .arg(&output_archive)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .spawn()
            .expect("Error running clang. Is it installed?");
        assert!(clang_process.wait().unwrap().success());

        // Write the Rust runtime into a file.
        let output_rt = format!("{}/rt_{}.hrt", output_dir, module_name);
        println!("{}", output_rt);
        let mut file =
            File::create(&output_rt).expect("PANIC: Unable to open output Rust runtime file.");
        file.write_all(rust_rt.as_bytes())
            .expect("PANIC: Unable to write output Rust runtime file contents.");

        Ok(())
    }
}

/// Applies `schedule` to `module` and then generates code for the result.
///
/// A `PassManager` is built from the module, the schedule is interpreted against it using
/// the given string table, environment and Juno function table, and finally the manager's
/// `codegen` writes the artifacts for `module_name` into `output_dir`.
///
/// # Errors
/// Propagates any `SchedulerError` raised while interpreting the schedule or during codegen.
pub fn schedule_codegen(
    module: Module,
    schedule: ScheduleStmt,
    mut stringtab: StringTable,
    mut env: Env<usize, Value>,
    functions: JunoFunctions,
    output_dir: String,
    module_name: String,
) -> Result<(), SchedulerError> {
    let mut manager = PassManager::new(module);
    // Whether the schedule changed the IR is irrelevant here; only failures matter.
    schedule_interpret(&mut manager, &schedule, &mut stringtab, &mut env, &functions)?;
    manager.codegen(output_dir, module_name)
}

/// Interprets one schedule statement against the pass manager.
///
/// Returns a bool indicating whether any optimization ran and changed the IR. This is used
/// for implementing the fixpoint: a fixpoint body is re-run until it reports no change.
///
/// # Errors
/// Propagates errors from expression evaluation. Also returns
/// `SchedulerError::UndefinedVariable` when assigning to an unbound variable,
/// `SchedulerError::SemanticError` when a schedule is applied to "everything", and
/// `SchedulerError::FixpointFailure` when a `PanicAfter` limit is reached.
fn schedule_interpret(
    pm: &mut PassManager,
    schedule: &ScheduleStmt,
    stringtab: &mut StringTable,
    env: &mut Env<usize, Value>,
    functions: &JunoFunctions,
) -> Result<bool, SchedulerError> {
    match schedule {
        ScheduleStmt::Fixpoint { body, limit } => {
            // Counts the iterations of the body that reported a change.
            let mut i = 0;
            loop {
                // If no change was made, we've reached the fixpoint and are done
                if !schedule_interpret(pm, body, stringtab, env, functions)? {
                    break;
                }
                // Otherwise, increase the iteration count and check the limit
                i += 1;
                match limit {
                    FixpointLimit::NoLimit() => {}
                    FixpointLimit::PrintIter() => {
                        println!("Finished Iteration {}", i - 1)
                    }
                    FixpointLimit::StopAfter(n) => {
                        if i >= *n {
                            break;
                        }
                    }
                    FixpointLimit::PanicAfter(n) => {
                        if i >= *n {
                            return Err(SchedulerError::FixpointFailure());
                        }
                    }
                }
            }
            // `i` is only incremented after an iteration that changed the IR (the final,
            // unchanged iteration breaks out before the increment), so any non-zero count
            // means changes were made.
            Ok(i > 0)
        }
        ScheduleStmt::Block { body } => {
            let mut modified = false;
            // Variables introduced inside the block are scoped to it.
            env.open_scope();
            for command in body {
                modified |= schedule_interpret(pm, command, stringtab, env, functions)?;
            }
            env.close_scope();
            Ok(modified)
        }
        ScheduleStmt::Let { var, exp } => {
            let (res, modified) = interp_expr(pm, exp, stringtab, env, functions)?;
            let var_id = stringtab.lookup_string(var.clone());
            env.insert(var_id, res);
            Ok(modified)
        }
        ScheduleStmt::Assign { var, exp } => {
            // The right-hand side is evaluated before the variable is looked up.
            let (res, modified) = interp_expr(pm, exp, stringtab, env, functions)?;
            let var_id = stringtab.lookup_string(var.clone());
            match env.lookup_mut(&var_id) {
                None => {
                    return Err(SchedulerError::UndefinedVariable(var.clone()));
                }
                Some(val) => {
                    *val = res;
                }
            }
            Ok(modified)
        }
        ScheduleStmt::AddSchedule { sched, on } => match on {
            Selector::Everything() => Err(SchedulerError::SemanticError(
                "Cannot apply schedule to everything".to_string(),
            )),
            Selector::Selection(selection) => {
                // Adding a schedule is not itself counted as a change; only changes made
                // while evaluating the selected expressions are reported.
                let mut changed = false;
                for label in selection {
                    let (label, modified) = interp_expr(pm, label, stringtab, env, functions)?;
                    changed |= modified;
                    add_schedule(pm, sched.clone(), label.as_labels()?);
                }
                Ok(changed)
            }
        },
        ScheduleStmt::AddDevice { device, on } => match on {
            Selector::Everything() => {
                for func in pm.functions.iter_mut() {
                    func.device = Some(device.clone());
                }
                Ok(false)
            }
            Selector::Selection(selection) => {
                // As with schedules, setting a device is not itself counted as a change.
                let mut changed = false;
                for func in selection {
                    let (func, modified) = interp_expr(pm, func, stringtab, env, functions)?;
                    changed |= modified;
                    add_device(pm, device.clone(), func.as_functions(functions)?);
                }
                Ok(changed)
            }
        },
    }
}

fn interp_expr(
    pm: &mut PassManager,
    expr: &ScheduleExp,
    stringtab: &mut StringTable,
    env: &mut Env<usize, Value>,
    functions: &JunoFunctions,
) -> Result<(Value, bool), SchedulerError> {
    match expr {
        ScheduleExp::Variable { var } => {
            let var_id = stringtab.lookup_string(var.clone());
            match env.lookup(&var_id) {
                None => Err(SchedulerError::UndefinedVariable(var.clone())),
                Some(v) => Ok((v.clone(), false)),
            }
        }
        ScheduleExp::Integer { val } => Ok((Value::Integer { val: *val }, false)),
        ScheduleExp::Boolean { val } => Ok((Value::Boolean { val: *val }, false)),
        ScheduleExp::Field { collect, field } => {
            let (lhs, changed) = interp_expr(pm, collect, stringtab, env, functions)?;
            match lhs {
                Value::Label { .. }
                | Value::Selection { .. }
                | Value::Everything { .. }
                | Value::Integer { .. }
                | Value::Boolean { .. } => Err(SchedulerError::UndefinedField(field.clone())),
                Value::JunoFunction { func } => {
                    match pm.labels.borrow().iter().position(|s| s == field) {
                        None => Err(SchedulerError::UndefinedLabel(field.clone())),
                        Some(label_idx) => Ok((
                            Value::Label {
                                labels: functions
                                    .get_function(func)
                                    .iter()
                                    .map(|f| LabelInfo {
                                        func: *f,
                                        label: LabelID::new(label_idx),
                                    })
                                    .collect(),
                            },
                            changed,
                        )),
                    }
                }
                Value::HerculesFunction { func } => {
                    match pm.labels.borrow().iter().position(|s| s == field) {
                        None => Err(SchedulerError::UndefinedLabel(field.clone())),
                        Some(label_idx) => Ok((
                            Value::Label {
                                labels: vec![LabelInfo {
                                    func: func,
                                    label: LabelID::new(label_idx),
                                }],
                            },
                            changed,
                        )),
                    }
                }
                Value::Record { fields } => match fields.get(field) {
                    None => Err(SchedulerError::UndefinedField(field.clone())),
                    Some(v) => Ok((v.clone(), changed)),
                },
            }
        }
        ScheduleExp::RunPass { pass, args, on } => {
            let mut changed = false;
            let mut arg_vals = vec![];
            for arg in args {
                let (val, modified) = interp_expr(pm, arg, stringtab, env, functions)?;
                arg_vals.push(val);
                changed |= modified;
            }

            let selection = match on {
                Selector::Everything() => None,
                Selector::Selection(selection) => {
                    let mut locs = vec![];
                    let mut everything = false;
                    for loc in selection {
                        let (val, modified) = interp_expr(pm, loc, stringtab, env, functions)?;
                        changed |= modified;
                        if val.is_everything() {
                            everything = true;
                            break;
                        }
                        locs.extend(val.as_locations(functions)?);
                    }
                    if everything {
                        None
                    } else {
                        Some(locs)
                    }
                }
            };

            let (res, modified) = run_pass(pm, *pass, arg_vals, selection)?;
            changed |= modified;
            Ok((res, changed))
        }
        ScheduleExp::Record { fields } => {
            let mut result = HashMap::new();
            let mut changed = false;
            for (field, val) in fields {
                let (val, modified) = interp_expr(pm, val, stringtab, env, functions)?;
                result.insert(field.clone(), val);
                changed |= modified;
            }
            Ok((Value::Record { fields: result }, changed))
        }
        ScheduleExp::Block { body, res } => {
            let mut changed = false;

            env.open_scope();
            for command in body {
                changed |= schedule_interpret(pm, command, stringtab, env, functions)?;
            }
            let (res, modified) = interp_expr(pm, res, stringtab, env, functions)?;
            env.close_scope();

            Ok((res, changed || modified))
        }
        ScheduleExp::Selection { selection } => match selection {
            Selector::Everything() => Ok((Value::Everything {}, false)),
            Selector::Selection(selection) => {
                let mut values = vec![];
                let mut changed = false;
                for e in selection {
                    let (val, modified) = interp_expr(pm, e, stringtab, env, functions)?;
                    values.push(val);
                    changed |= modified;
                }
                Ok((Value::Selection { selection: values }, changed))
            }
        },
    }
}

/// Attaches `sched` to every node carrying one of the given labels.
///
/// For each `(function, label)` pair, every node index whose label set contains the label
/// gets a clone of `sched` appended to its schedule list.
fn add_schedule(pm: &mut PassManager, sched: Schedule, label_ids: Vec<LabelInfo>) {
    for LabelInfo { func, label } in label_ids {
        let function = &mut pm.functions[func.idx()];
        // Collect the matching node indices first so the label table is no longer borrowed
        // while the schedules are modified.
        let labelled_nodes: Vec<usize> = function
            .labels
            .iter()
            .enumerate()
            .filter_map(|(node, node_labels)| node_labels.contains(&label).then_some(node))
            .collect();
        for node in labelled_nodes {
            function.schedules[node].push(sched.clone());
        }
    }
}

/// Sets the device of each listed function to a clone of `device`.
fn add_device(pm: &mut PassManager, device: Device, funcs: Vec<FunctionID>) {
    funcs.into_iter().for_each(|func| {
        pm.functions[func.idx()].device = Some(device.clone());
    });
}

/// Which part of a single function has been selected.
#[derive(Debug, Clone)]
enum FunctionSelection {
    /// Nothing in the function is selected.
    Nothing(),
    /// The whole function is selected.
    Everything(),
    /// Only the listed labels are selected.
    Labels(HashSet<LabelID>),
}

impl FunctionSelection {
    // Adds one label to the selection. A selection that already covers everything is left as is;
    // an empty selection turns into a label selection holding just this label.
    fn add_label(&mut self, label: LabelID) {
        match self {
            FunctionSelection::Everything() => {}
            FunctionSelection::Labels(existing) => {
                existing.insert(label);
            }
            FunctionSelection::Nothing() => {
                let mut labels = HashSet::new();
                labels.insert(label);
                *self = FunctionSelection::Labels(labels);
            }
        }
    }

    // Widens the selection to the whole function, discarding any labels recorded so far.
    fn add_everything(&mut self) {
        *self = Self::Everything();
    }
}

// Creates an editor for every function of the module, in function order. The def-use information
// is (re)built first and taken out of the pass manager, since each editor is constructed from the
// def-use map of its function.
fn build_editors<'a>(pm: &'a mut PassManager) -> Vec<FunctionEditor<'a>> {
    pm.make_def_uses();
    let def_uses = pm.def_uses.take().unwrap();

    let mut editors = Vec::with_capacity(pm.functions.len());
    for (idx, (func, def_use)) in pm.functions.iter_mut().zip(def_uses.iter()).enumerate() {
        let editor = FunctionEditor::new(
            func,
            FunctionID::new(idx),
            &pm.constants,
            &pm.dynamic_constants,
            &pm.types,
            &pm.labels,
            def_use,
        );
        editors.push(editor);
    }
    editors
}

// With a selection, we process it to identify which labels in which functions are to be selected
fn construct_selection(pm: &PassManager, selection: Vec<CodeLocation>) -> Vec<FunctionSelection> {
    let mut selected = vec![FunctionSelection::Nothing(); pm.functions.len()];
    for loc in selection {
        match loc {
            CodeLocation::Label(label) => selected[label.func.idx()].add_label(label.label),
            CodeLocation::Function(func) => selected[func.idx()].add_everything(),
        }
    }
    selected
}

// Resolves a selection to the list of functions it names, where every selected function has to be
// selected as a whole. If some function is selected only through labels the result is `None`; if
// no selection is given at all, every function of the module is returned.
fn selection_of_functions(
    pm: &PassManager,
    selection: Option<Vec<CodeLocation>>,
) -> Option<Vec<FunctionID>> {
    let Some(selection) = selection else {
        return Some((0..pm.functions.len()).map(|idx| FunctionID::new(idx)).collect());
    };

    // Collecting `Option` items into `Option<Vec<_>>` stops at the first `None`, which is how a
    // label-only selection makes the whole result `None`. Unselected functions are filtered out
    // before that and so contribute nothing.
    construct_selection(pm, selection)
        .into_iter()
        .enumerate()
        .filter_map(|(idx, selected)| match selected {
            FunctionSelection::Nothing() => None,
            FunctionSelection::Everything() => Some(Some(FunctionID::new(idx))),
            FunctionSelection::Labels(_) => Some(None),
        })
        .collect()
}

// Resolves a selection to the set of selected nodes of exactly one function, together with that
// function's id. The result is `None` when no selection is given, when the selection touches no
// function, or when it touches more than one function.
fn selection_as_set(
    pm: &PassManager,
    selection: Option<Vec<CodeLocation>>,
) -> Option<(BTreeSet<NodeID>, FunctionID)> {
    let per_function = construct_selection(pm, selection?);
    let mut found: Option<(BTreeSet<NodeID>, FunctionID)> = None;

    for (idx, (selected, func)) in per_function.into_iter().zip(pm.functions.iter()).enumerate() {
        // `None` here means "no label filter": every node of the function is selected.
        let label_filter = match selected {
            FunctionSelection::Nothing() => continue,
            FunctionSelection::Everything() => None,
            FunctionSelection::Labels(labels) => Some(labels),
        };

        // A second selected function means the selection is not confined to a single function.
        if found.is_some() {
            return None;
        }

        let nodes: BTreeSet<NodeID> = (0..func.nodes.len())
            .filter(|i| match &label_filter {
                None => true,
                // Keep a node when it shares at least one label with the selection.
                Some(labels) => !func.labels[*i].is_disjoint(labels),
            })
            .map(|i| NodeID::new(i))
            .collect();
        found = Some((nodes, FunctionID::new(idx)));
    }

    found
}

// Builds one optional editor per function of the module, in function order.
//
// Without a selection every function gets a plain editor. With a selection, functions it does not
// mention yield `None`, functions selected as a whole get a plain editor, and functions selected
// through labels get an editor created by `FunctionEditor::new_labeled` with the selected labels.
fn build_selection<'a>(
    pm: &'a mut PassManager,
    selection: Option<Vec<CodeLocation>>,
) -> Vec<Option<FunctionEditor<'a>>> {
    // Without a selection, defer entirely to `build_editors`. It builds and takes the def-use
    // information itself, so doing that here first would only compute it to throw it away.
    // NOTE(review): this relies on `make_def_uses` having no effect besides filling `pm.def_uses`.
    let Some(selection) = selection else {
        return build_editors(pm).into_iter().map(Some).collect();
    };

    // Build def uses, which are needed for the editors we'll construct
    pm.make_def_uses();
    let def_uses = pm.def_uses.take().unwrap();

    let selected = construct_selection(pm, selection);

    pm.functions
        .iter_mut()
        .zip(selected.iter())
        .zip(def_uses.iter())
        .enumerate()
        .map(|(idx, ((func, selected), def_use))| match selected {
            FunctionSelection::Nothing() => None,
            FunctionSelection::Everything() => Some(FunctionEditor::new(
                func,
                FunctionID::new(idx),
                &pm.constants,
                &pm.dynamic_constants,
                &pm.types,
                &pm.labels,
                def_use,
            )),
            FunctionSelection::Labels(labels) => Some(FunctionEditor::new_labeled(
                func,
                FunctionID::new(idx),
                &pm.constants,
                &pm.dynamic_constants,
                &pm.types,
                &pm.labels,
                def_use,
                labels,
            )),
        })
        .collect()
}

// Runs a single pass over the module held by `pm`.
//
// `args` holds the argument values for the pass; most passes take none and assert that the list is
// empty. `selection` restricts the pass to particular functions or labels, `None` meaning that no
// restriction was given. Passes that only work on whole functions, on nodes of one function, or on
// the entire module return a `SchedulerError::PassError` if the selection does not fit.
//
// Returns the value produced by the pass together with a flag telling whether anything was
// modified. The value is an empty record unless the pass sets one: `AutoOutline` returns a record
// from the name of each outlined-from function to the new function, `Outline` returns the new
// function.
//
// Panics through `todo!` for `DeleteUncalled`, `ForkGuardElim` and `Forkify`, and through
// `assert!` when arguments are given to a pass that checks for an empty argument list.
fn run_pass(
    pm: &mut PassManager,
    pass: Pass,
    args: Vec<Value>,
    selection: Option<Vec<CodeLocation>>,
) -> Result<(Value, bool), SchedulerError> {
    let mut result = Value::Record {
        fields: HashMap::new(),
    };
    let mut changed = false;

    match pass {
        Pass::AutoOutline => {
            let Some(funcs) = selection_of_functions(pm, selection) else {
                return Err(SchedulerError::PassError {
                    pass: "autoOutline".to_string(),
                    error: "must be applied to whole functions".to_string(),
                });
            };

            pm.make_def_uses();
            let def_uses = pm.def_uses.take().unwrap();

            // Preparation: run `collapse_returns` and `ensure_between_control_flow` on every
            // selected function before outlining.
            for func in funcs.iter() {
                let mut editor = FunctionEditor::new(
                    &mut pm.functions[func.idx()],
                    *func,
                    &pm.constants,
                    &pm.dynamic_constants,
                    &pm.types,
                    &pm.labels,
                    &def_uses[func.idx()],
                );
                collapse_returns(&mut editor);
                ensure_between_control_flow(&mut editor);
                changed |= editor.modified();
            }
            pm.clear_analyses();

            // Analyses were cleared after the edits above; rebuild the ones `dumb_outline` takes.
            pm.make_def_uses();
            pm.make_typing();
            pm.make_control_subgraphs();
            pm.make_doms();

            let def_uses = pm.def_uses.take().unwrap();
            let typing = pm.typing.take().unwrap();
            let control_subgraphs = pm.control_subgraphs.take().unwrap();
            let doms = pm.doms.take().unwrap();
            let old_num_funcs = pm.functions.len();

            let mut new_funcs = vec![];
            // Track the names of the old functions and the new function IDs for returning
            let mut new_func_ids = HashMap::new();

            for func in funcs {
                let mut editor = FunctionEditor::new(
                    &mut pm.functions[func.idx()],
                    func,
                    &pm.constants,
                    &pm.dynamic_constants,
                    &pm.types,
                    &pm.labels,
                    &def_uses[func.idx()],
                );

                // New functions are appended after all existing ones once the loop is done, so
                // the next free id is the old count plus the number of functions outlined so far.
                let new_func_id = FunctionID::new(old_num_funcs + new_funcs.len());

                let new_func = dumb_outline(
                    &mut editor,
                    &typing[func.idx()],
                    &control_subgraphs[func.idx()],
                    &doms[func.idx()],
                    new_func_id,
                );
                changed |= editor.modified();

                if let Some(new_func) = new_func {
                    new_func_ids.insert(
                        editor.func().name.clone(),
                        Value::HerculesFunction { func: new_func_id },
                    );
                    new_funcs.push(new_func);
                }

                pm.functions[func.idx()].delete_gravestones();
            }

            pm.functions.extend(new_funcs);
            pm.clear_analyses();

            result = Value::Record {
                fields: new_func_ids,
            };
        }
        Pass::CCP => {
            assert!(args.is_empty());
            pm.make_reverse_postorders();
            let reverse_postorders = pm.reverse_postorders.take().unwrap();
            for (func, reverse_postorder) in build_selection(pm, selection)
                .into_iter()
                .zip(reverse_postorders.iter())
            {
                // Functions outside the selection have no editor and are skipped.
                let Some(mut func) = func else {
                    continue;
                };
                ccp(&mut func, reverse_postorder);
                changed |= func.modified();
            }
            pm.delete_gravestones();
            pm.clear_analyses();
        }
        Pass::CRC => {
            assert!(args.is_empty());
            for func in build_selection(pm, selection) {
                let Some(mut func) = func else {
                    continue;
                };
                crc(&mut func);
                changed |= func.modified();
            }
            pm.delete_gravestones();
            pm.clear_analyses();
        }
        Pass::DCE => {
            assert!(args.is_empty());
            for func in build_selection(pm, selection) {
                let Some(mut func) = func else {
                    continue;
                };
                dce(&mut func);
                changed |= func.modified();
            }
            pm.delete_gravestones();
            pm.clear_analyses();
        }
        Pass::DeleteUncalled => {
            todo!("Delete Uncalled changes FunctionIDs, a bunch of bookkeeping is needed for the pass manager to address this")
        }
        Pass::FloatCollections => {
            assert!(args.is_empty());
            if selection.is_some() {
                return Err(SchedulerError::PassError {
                    pass: "floatCollections".to_string(),
                    error: "must be applied to the entire module".to_string(),
                });
            }

            pm.make_typing();
            pm.make_callgraph();
            pm.make_devices();
            let typing = pm.typing.take().unwrap();
            let callgraph = pm.callgraph.take().unwrap();
            let devices = pm.devices.take().unwrap();

            let mut editors = build_editors(pm);
            float_collections(&mut editors, &typing, &callgraph, &devices);

            for func in editors {
                changed |= func.modified();
            }

            pm.delete_gravestones();
            pm.clear_analyses();
        }
        Pass::ForkGuardElim => {
            todo!("Fork Guard Elim doesn't use editor")
        }
        Pass::ForkSplit => {
            assert!(args.is_empty());
            pm.make_fork_join_maps();
            pm.make_reduce_cycles();
            let fork_join_maps = pm.fork_join_maps.take().unwrap();
            let reduce_cycles = pm.reduce_cycles.take().unwrap();
            for ((func, fork_join_map), reduce_cycles) in build_selection(pm, selection)
                .into_iter()
                .zip(fork_join_maps.iter())
                .zip(reduce_cycles.iter())
            {
                let Some(mut func) = func else {
                    continue;
                };
                fork_split(&mut func, fork_join_map, reduce_cycles);
                changed |= func.modified();
            }
            pm.delete_gravestones();
            pm.clear_analyses();
        }
        Pass::Forkify => {
            todo!("Forkify doesn't use editor")
        }
        Pass::GCM => {
            assert!(args.is_empty());
            if selection.is_some() {
                return Err(SchedulerError::PassError {
                    pass: "gcm".to_string(),
                    error: "must be applied to the entire module".to_string(),
                });
            }

            // Iterate functions in reverse topological order, since inter-
            // device copies introduced in a callee may affect demands in a
            // caller, and the object allocation of a callee affects the object
            // allocation of its callers.
            pm.make_callgraph();
            let callgraph = pm.callgraph.take().unwrap();
            let topo = callgraph.topo();
            // Each iteration is one sweep over all functions. A sweep in which `gcm` fails for
            // some function is abandoned: its partial results are dropped and, since analyses are
            // cleared at the end of every sweep, the next one starts from fresh analyses. The loop
            // ends with the first sweep in which `gcm` succeeds for every function.
            loop {
                pm.make_def_uses();
                pm.make_reverse_postorders();
                pm.make_typing();
                pm.make_control_subgraphs();
                pm.make_doms();
                pm.make_fork_join_maps();
                pm.make_loops();
                pm.make_collection_objects();
                pm.make_devices();
                pm.make_object_device_demands();

                let def_uses = pm.def_uses.take().unwrap();
                let reverse_postorders = pm.reverse_postorders.take().unwrap();
                let typing = pm.typing.take().unwrap();
                let doms = pm.doms.take().unwrap();
                let fork_join_maps = pm.fork_join_maps.take().unwrap();
                let loops = pm.loops.take().unwrap();
                let control_subgraphs = pm.control_subgraphs.take().unwrap();
                let collection_objects = pm.collection_objects.take().unwrap();
                let devices = pm.devices.take().unwrap();
                let object_device_demands = pm.object_device_demands.take().unwrap();

                let mut bbs = vec![(vec![], vec![]); topo.len()];
                let mut node_colors = vec![BTreeMap::new(); topo.len()];
                let mut backing_allocations = BTreeMap::new();
                let mut editors = build_editors(pm);
                let mut any_failed = false;
                for id in topo.iter() {
                    let editor = &mut editors[id.idx()];
                    if let Some((bb, function_node_colors, backing_allocation)) = gcm(
                        editor,
                        &def_uses[id.idx()],
                        &reverse_postorders[id.idx()],
                        &typing[id.idx()],
                        &control_subgraphs[id.idx()],
                        &doms[id.idx()],
                        &fork_join_maps[id.idx()],
                        &loops[id.idx()],
                        &collection_objects,
                        &devices,
                        &object_device_demands[id.idx()],
                        &backing_allocations,
                    ) {
                        bbs[id.idx()] = bb;
                        node_colors[id.idx()] = function_node_colors;
                        backing_allocations.insert(*id, backing_allocation);
                    } else {
                        any_failed = true;
                    }
                    // Record modifications before possibly abandoning the sweep.
                    changed |= editor.modified();
                    if any_failed {
                        break;
                    }
                }
                pm.delete_gravestones();
                pm.clear_analyses();
                if !any_failed {
                    // Stored after `clear_analyses` so that the results stay available.
                    pm.bbs = Some(bbs);
                    pm.node_colors = Some(node_colors);
                    pm.backing_allocations = Some(backing_allocations);
                    break;
                }
            }
        }
        Pass::GVN => {
            assert!(args.is_empty());
            for func in build_selection(pm, selection) {
                let Some(mut func) = func else {
                    continue;
                };
                gvn(&mut func, false);
                changed |= func.modified();
            }
            pm.delete_gravestones();
            pm.clear_analyses();
        }
        Pass::InferSchedules => {
            assert!(args.is_empty());
            pm.make_fork_join_maps();
            pm.make_reduce_cycles();
            let fork_join_maps = pm.fork_join_maps.take().unwrap();
            let reduce_cycles = pm.reduce_cycles.take().unwrap();
            for ((func, fork_join_map), reduce_cycles) in build_selection(pm, selection)
                .into_iter()
                .zip(fork_join_maps.iter())
                .zip(reduce_cycles.iter())
            {
                let Some(mut func) = func else {
                    continue;
                };
                infer_parallel_reduce(&mut func, fork_join_map, reduce_cycles);
                infer_parallel_fork(&mut func, fork_join_map);
                infer_vectorizable(&mut func, fork_join_map);
                infer_tight_associative(&mut func, reduce_cycles);
                changed |= func.modified();
            }
            pm.delete_gravestones();
            pm.clear_analyses();
        }
        Pass::Inline => {
            assert!(args.is_empty());
            if selection.is_some() {
                return Err(SchedulerError::PassError {
                    pass: "inline".to_string(),
                    error: "must be applied to the entire module (currently)".to_string(),
                });
            }

            pm.make_callgraph();
            let callgraph = pm.callgraph.take().unwrap();

            let mut editors = build_editors(pm);
            inline(&mut editors, &callgraph);

            for func in editors {
                changed |= func.modified();
            }

            pm.delete_gravestones();
            pm.clear_analyses();
        }
        Pass::InterproceduralSROA => {
            assert!(args.is_empty());
            if selection.is_some() {
                return Err(SchedulerError::PassError {
                    pass: "interproceduralSROA".to_string(),
                    error: "must be applied to the entire module".to_string(),
                });
            }

            let mut editors = build_editors(pm);
            interprocedural_sroa(&mut editors);

            for func in editors {
                changed |= func.modified();
            }

            pm.delete_gravestones();
            pm.clear_analyses();
        }
        Pass::LiftDCMath => {
            assert!(args.is_empty());
            for func in build_selection(pm, selection) {
                let Some(mut func) = func else {
                    continue;
                };
                lift_dc_math(&mut func);
                changed |= func.modified();
            }
            pm.delete_gravestones();
            pm.clear_analyses();
        }
        Pass::Outline => {
            // The node set is resolved from the selection before the function is edited below.
            let Some((nodes, func)) = selection_as_set(pm, selection) else {
                return Err(SchedulerError::PassError {
                    pass: "outline".to_string(),
                    error: "must be applied to nodes in a single function".to_string(),
                });
            };

            pm.make_def_uses();
            let def_uses = pm.def_uses.take().unwrap();

            let mut editor = FunctionEditor::new(
                &mut pm.functions[func.idx()],
                func,
                &pm.constants,
                &pm.dynamic_constants,
                &pm.types,
                &pm.labels,
                &def_uses[func.idx()],
            );

            // Preparation: same two steps as in `AutoOutline`, on the one selected function.
            collapse_returns(&mut editor);
            ensure_between_control_flow(&mut editor);
            pm.clear_analyses();

            // Analyses were cleared after the edits above; rebuild the ones `outline` takes.
            pm.make_def_uses();
            pm.make_typing();
            pm.make_control_subgraphs();
            pm.make_doms();

            let def_uses = pm.def_uses.take().unwrap();
            let typing = pm.typing.take().unwrap();
            let control_subgraphs = pm.control_subgraphs.take().unwrap();
            let doms = pm.doms.take().unwrap();
            // The outlined function is pushed at the end, so its id is the current count.
            let new_func_id = FunctionID::new(pm.functions.len());

            let mut editor = FunctionEditor::new(
                &mut pm.functions[func.idx()],
                func,
                &pm.constants,
                &pm.dynamic_constants,
                &pm.types,
                &pm.labels,
                &def_uses[func.idx()],
            );

            let new_func = outline(
                &mut editor,
                &typing[func.idx()],
                &control_subgraphs[func.idx()],
                &doms[func.idx()],
                &nodes,
                new_func_id,
            );
            let Some(new_func) = new_func else {
                // Report under the same pass name as the selection error above.
                return Err(SchedulerError::PassError {
                    pass: "outline".to_string(),
                    error: "failed to outline".to_string(),
                });
            };

            pm.functions.push(new_func);
            changed = true;
            pm.functions[func.idx()].delete_gravestones();
            pm.clear_analyses();

            result = Value::HerculesFunction { func: new_func_id };
        }
        Pass::PhiElim => {
            assert!(args.is_empty());
            for func in build_selection(pm, selection) {
                let Some(mut func) = func else {
                    continue;
                };
                phi_elim(&mut func);
                changed |= func.modified();
            }
            pm.delete_gravestones();
            pm.clear_analyses();
        }
        Pass::Predication => {
            assert!(args.is_empty());
            pm.make_typing();
            let typing = pm.typing.take().unwrap();

            for (func, types) in build_selection(pm, selection)
                .into_iter()
                .zip(typing.iter())
            {
                let Some(mut func) = func else {
                    continue;
                };
                predication(&mut func, types);
                changed |= func.modified();
            }
            pm.delete_gravestones();
            pm.clear_analyses();
        }
        Pass::SLF => {
            assert!(args.is_empty());
            pm.make_reverse_postorders();
            pm.make_typing();
            let reverse_postorders = pm.reverse_postorders.take().unwrap();
            let typing = pm.typing.take().unwrap();

            for ((func, reverse_postorder), types) in build_selection(pm, selection)
                .into_iter()
                .zip(reverse_postorders.iter())
                .zip(typing.iter())
            {
                let Some(mut func) = func else {
                    continue;
                };
                slf(&mut func, reverse_postorder, types);
                changed |= func.modified();
            }
            pm.delete_gravestones();
            pm.clear_analyses();
        }
        Pass::SROA => {
            assert!(args.is_empty());
            pm.make_reverse_postorders();
            pm.make_typing();
            let reverse_postorders = pm.reverse_postorders.take().unwrap();
            let typing = pm.typing.take().unwrap();

            for ((func, reverse_postorder), types) in build_selection(pm, selection)
                .into_iter()
                .zip(reverse_postorders.iter())
                .zip(typing.iter())
            {
                let Some(mut func) = func else {
                    continue;
                };
                sroa(&mut func, reverse_postorder, types);
                changed |= func.modified();
            }
            pm.delete_gravestones();
            pm.clear_analyses();
        }
        Pass::Unforkify => {
            assert!(args.is_empty());
            pm.make_fork_join_maps();
            let fork_join_maps = pm.fork_join_maps.take().unwrap();

            for (func, fork_join_map) in build_selection(pm, selection)
                .into_iter()
                .zip(fork_join_maps.iter())
            {
                let Some(mut func) = func else {
                    continue;
                };
                unforkify(&mut func, fork_join_map);
                changed |= func.modified();
            }
            pm.delete_gravestones();
            pm.clear_analyses();
        }
        Pass::WritePredication => {
            assert!(args.is_empty());
            for func in build_selection(pm, selection) {
                let Some(mut func) = func else {
                    continue;
                };
                write_predication(&mut func);
                changed |= func.modified();
            }
            pm.delete_gravestones();
            pm.clear_analyses();
        }
        Pass::Verify => {
            assert!(args.is_empty());
            let (def_uses, reverse_postorders, typing, subgraphs, doms, postdoms, fork_join_maps) =
                pm.with_mod(|module| verify(module))
                    .map_err(|msg| SchedulerError::PassError {
                        pass: "verify".to_string(),
                        error: format!("failed: {}", msg),
                    })?;

            // Verification produces a bunch of analysis results that
            // may be useful for later passes.
            pm.def_uses = Some(def_uses);
            pm.reverse_postorders = Some(reverse_postorders);
            pm.typing = Some(typing);
            pm.control_subgraphs = Some(subgraphs);
            pm.doms = Some(doms);
            pm.postdoms = Some(postdoms);
            pm.fork_join_maps = Some(fork_join_maps);
        }
        Pass::Xdot => {
            // An optional boolean argument decides whether the extra analyses are built for the
            // drawing; without an argument they are.
            let force_analyses = match args.first() {
                Some(Value::Boolean { val }) => *val,
                Some(_) => {
                    return Err(SchedulerError::PassError {
                        pass: "xdot".to_string(),
                        error: "expected boolean argument".to_string(),
                    });
                }
                None => true,
            };

            pm.make_reverse_postorders();
            if force_analyses {
                pm.make_doms();
                pm.make_fork_join_maps();
                pm.make_devices();
            }

            // Only the reverse postorders are required; the other analyses are passed along if
            // they happen to be present.
            let reverse_postorders = pm.reverse_postorders.take().unwrap();
            let doms = pm.doms.take();
            let fork_join_maps = pm.fork_join_maps.take();
            let devices = pm.devices.take();
            let bbs = pm.bbs.take();
            pm.with_mod(|module| {
                xdot_module(
                    module,
                    &reverse_postorders,
                    doms.as_ref(),
                    fork_join_maps.as_ref(),
                    devices.as_ref(),
                    bbs.as_ref(),
                )
            });

            // Put BasicBlocks back, since it's needed for Codegen.
            pm.bbs = bbs;
        }
    }
    println!("Ran Pass: {:?}", pass);

    Ok((result, changed))
}