use std::collections::{BTreeMap, HashMap, HashSet};
use std::fmt::{Error, Write};
use std::iter::zip;

use hercules_ir::*;

use crate::*;

/*
 * Entry Hercules functions are lowered to async Rust code to achieve easy
 * task-level parallelism. This Rust is generated textually and is included
 * via a procedural macro in the user's Rust code.
 *
 * Generating Rust that properly handles memory across devices is tricky. In
 * particular, the following elements are challenges:
 *
 * 1. RT functions may return objects that were not first passed in as
 *    parameters, via object constants.
 * 2. We want to allocate as much memory upfront as possible. Our goal is that
 *    for each call to an RT function from host Rust code, there is at most
 *    one memory allocation per device.
 * 3. We want to statically determine when cross-device communication is
 *    necessary - this should be a separate concern from allocation.
 * 4. At the boundary between host Rust / RT functions, we want to encode
 *    lifetime rules so that the Hercules API is memory safe.
 * 5. We want to support efficient composition of Hercules RT functions in both
 *    synchronous and asynchronous contexts.
 *
 * Challenges #1 and #2 require that RT functions themselves do not allocate
 * memory. Instead, for every entry point function, a "runner" type will be
 * generated. The host Rust code must instantiate a runner object to invoke an
 * entry point function. This runner object contains a "backing" memory that is
 * the single memory allocation for this function. The runner object can be
 * used to call the same entry point multiple times, and to the extent possible
 * the backing memory will be re-used. The size of the backing memory depends
 * on the dynamic constants passed in to the entry point, so it's lazily
 * allocated to the needed size on calls to the entry point. To address
 * challenge #4, any returned objects are lifetime-bound to the runner object
 * instance, ensuring that a reference cannot be used after the runner object
 * has deallocated the backing memory. This also ensures that the runner can't
 * be run again while a returned object from a previous call is still live,
 * since the entry point method requires an exclusive reference to the runner.
 *
 * Addressing challenge #3 requires we determine what objects are live on what
 * devices at what times. This can be done fairly easily by coloring nodes by
 * what device they produce their result on and inserting inter-device transfers
 * along edges connecting nodes of different colors. Nodes can only have a
 * single color - this is enforced by the GCM pass.
 *
 * Addressing challenge #5 requires that runner objects for entry points accept
 * and return objects that are not in their own backing memory and that are
 * potentially on any device. For this reason, parameter and return nodes are
 * not necessarily CPU colored. Instead, runners take and return Hercules
 * reference objects that refer to memory of unknown origin on some device.
 * Hercules reference objects have a lifetime parameter, and when a runner may
 * return a Hercules reference that refers to its backing memory, the lifetime
 * of the Hercules reference is the same as the lifetime of the mutable
 * reference of the runner used in the entry point signature. In other words,
 * the RT backend uses the collection objects analysis to infer the proper
 * lifetime bounds on parameter and returned Hercules reference objects in
 * relation to the runner's self reference. There are the following kinds of
 * Hercules reference objects:
 *
 * - HerculesCPURef
 * - HerculesCPURefMut
 * - HerculesCUDARef
 * - HerculesCUDARefMut
 *
 * Essentially, there are two types per device: one for immutable references
 * and one for exclusive references. Mutable references can decay into immutable
 * references, and immutable references can be cloned. The CPU reference types
 * can be created from normal Rust references. The CUDA reference types can't be
 * created from normal Rust references - for that purpose, an additional type is
 * given, CUDABox, which essentially allows the user to manually allocate and
 * set some CUDA memory - the user can then take a CUDA reference to that box.
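 *
 * As a rough sketch of intended host-side usage (the constructor and method
 * names on the reference types below are illustrative assumptions, not
 * necessarily the exact hercules_rt API):
 *
 *     let mut input = vec![0.0f32; 1024];
 *     // CPU references can be created from normal Rust references.
 *     let input_ref = HerculesCPURefMut::from_slice(&mut input);
 *     // Runners are generated per entry point and own the backing memory.
 *     let mut runner = HerculesRunner_foo::new();
 *     // The returned reference is lifetime-bound to `runner`.
 *     let output = runner.run(1024, input_ref).await;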
 */
pub fn rt_codegen<W: Write>(
    module_name: &str,
    func_id: FunctionID,
    module: &Module,
    def_use: &ImmutableDefUseMap,
    typing: &Vec<TypeID>,
    control_subgraph: &Subgraph,
    fork_join_map: &HashMap<NodeID, NodeID>,
    fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
    fork_tree: &HashMap<NodeID, HashSet<NodeID>>,
    nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
    collection_objects: &CollectionObjects,
    callgraph: &CallGraph,
    devices: &Vec<Device>,
    bbs: &BasicBlocks,
    node_colors: &FunctionNodeColors,
    backing_allocations: &BackingAllocations,
    w: &mut W,
) -> Result<(), Error> {
    let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
        .into_iter()
        .map(|(fork, join)| (*join, *fork))
        .collect();
    let ctx = RTContext {
        module_name,
        func_id,
        module,
        def_use,
        typing,
        control_subgraph,
        fork_join_map,
        join_fork_map: &join_fork_map,
        fork_join_nest,
        fork_tree,
        nodes_in_fork_joins,
        collection_objects,
        callgraph,
        devices,
        bbs,
        node_colors,
        backing_allocations,
    };
    ctx.codegen_function(w)
}

struct RTContext<'a> {
    module_name: &'a str,
    func_id: FunctionID,
    module: &'a Module,
    def_use: &'a ImmutableDefUseMap,
    typing: &'a Vec<TypeID>,
    control_subgraph: &'a Subgraph,
    fork_join_map: &'a HashMap<NodeID, NodeID>,
    join_fork_map: &'a HashMap<NodeID, NodeID>,
    fork_join_nest: &'a HashMap<NodeID, Vec<NodeID>>,
    fork_tree: &'a HashMap<NodeID, HashSet<NodeID>>,
    nodes_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>,
    collection_objects: &'a CollectionObjects,
    callgraph: &'a CallGraph,
    devices: &'a Vec<Device>,
    bbs: &'a BasicBlocks,
    node_colors: &'a FunctionNodeColors,
    backing_allocations: &'a BackingAllocations,
}

#[derive(Debug, Clone, Default)]
struct RustBlock {
    prologue: String,
    data: String,
    phi_tmp_assignments: String,
    phi_assignments: String,
    epilogue: String,
    join_epilogue: String,
}

impl<'a> RTContext<'a> {
    fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> {
        // If this is an entry function, generate a corresponding runner object
        // type definition.
        let func = &self.get_func();
        if func.entry {
            self.codegen_runner_object(w)?;
        }

        // Dump the function signature.
        write!(
            w,
            "#[allow(unused_assignments,unused_variables,unused_mut,unused_parens,unused_unsafe,non_snake_case)]async unsafe fn {}_{}(",
            self.module_name,
            func.name
        )?;
        let mut first_param = true;
        // The first set of parameters are pointers to backing memories.
        for (device, _) in self.backing_allocations[&self.func_id].iter() {
            if first_param {
                first_param = false;
            } else {
                write!(w, ", ")?;
            }
            write!(
                w,
                "backing_{}: ::hercules_rt::__RawPtrSendSync",
                device.name()
            )?;
        }
        // The second set of parameters are dynamic constants.
        for idx in 0..func.num_dynamic_constants {
            if first_param {
                first_param = false;
            } else {
                write!(w, ", ")?;
            }
            write!(w, "dc_p{}: u64", idx)?;
        }
        // The third set of parameters are normal parameters.
        for idx in 0..func.param_types.len() {
            if first_param {
                first_param = false;
            } else {
                write!(w, ", ")?;
            }
            write!(w, "p{}: {}", idx, self.get_type(func.param_types[idx]))?;
        }
        write!(w, ") -> ")?;
        self.write_rust_return_type(w, &func.return_types)?;
        write!(w, " {{")?;

        // Dump signatures for called device functions. For single-return
        // functions, we directly expose the device function, while for
        // multi-return functions, we generate a wrapper that handles
        // allocation of the return struct and extraction of values from it.
        // This ensures that device function signatures match what they would
        // be in AsyncRust.
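        // For example, for a two-return device callee `bar` (parameter lists
        // elided, and the f32 return types purely illustrative), the generated
        // text has roughly this shape; the inner extern declaration shadows
        // the wrapper's own name, so the call inside resolves to the extern
        // function:
        //
        //     fn mod_bar(...) -> (f32, f32) {
        //         #[repr(C)] struct ReturnStruct { f0: f32, f1: f32 }
        //         extern "C" { fn mod_bar(..., ret: *mut ReturnStruct); }
        //         let mut ret_struct = ::std::mem::MaybeUninit::uninit();
        //         mod_bar(..., ret_struct.as_mut_ptr());
        //         let ret_struct = ret_struct.assume_init();
        //         (ret_struct.f0, ret_struct.f1)
        //     }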
        for callee_id in self.callgraph.get_callees(self.func_id) {
            if self.devices[callee_id.idx()] == Device::AsyncRust {
                continue;
            }
            let callee = &self.module.functions[callee_id.idx()];
            let is_single_return = callee.return_types.len() == 1;
            if is_single_return {
                write!(w, "extern \"C\" {{")?;
            }
            self.write_device_signature_async(w, *callee_id, !is_single_return)?;
            if is_single_return {
                write!(w, ";}}")?;
            } else {
                // Generate the wrapper function for multi-return device functions
                write!(w, " {{ ")?;
                // Define the return struct
                write!(
                    w,
                    "#[repr(C)] struct ReturnStruct {{ {} }} ",
                    callee
                        .return_types
                        .iter()
                        .enumerate()
                        .map(|(idx, t)| format!("f{}: {}", idx, self.get_type(*t)))
                        .collect::<Vec<_>>()
                        .join(", "),
                )?;
                // Declare the extern function's signature
                write!(w, "extern \"C\" {{ ")?;
                self.write_device_signature(w, *callee_id)?;
                write!(w, "; }}")?;
                // Create the return struct
                write!(w, "let mut ret_struct: ::std::mem::MaybeUninit<ReturnStruct> = ::std::mem::MaybeUninit::uninit();")?;
                // Call the device function
                write!(w, "{}_{}(", self.module_name, callee.name)?;
                if self.backing_allocations[&callee_id].contains_key(&self.devices[callee_id.idx()])
                {
                    write!(w, "backing, ")?;
                }
                for idx in 0..callee.num_dynamic_constants {
                    write!(w, "dc{}, ", idx)?;
                }
                for idx in 0..callee.param_types.len() {
                    write!(w, "p{}, ", idx)?;
                }
                write!(w, "ret_struct.as_mut_ptr());")?;
                // Extract the result into a Rust product
                write!(w, "let ret_struct = ret_struct.assume_init();")?;
                write!(
                    w,
                    "({})",
                    (0..callee.return_types.len())
                        .map(|idx| format!("ret_struct.f{}", idx))
                        .collect::<Vec<_>>()
                        .join(", "),
                )?;
                write!(w, "}}")?;
            }
        }

        // Set up the root environment for the function. An environment is set
        // up for every task created in an async closure, and there needs to be
        // a root environment corresponding to the root control node (the start
        // node).
        self.codegen_open_environment(NodeID::new(0), w)?;

        let mut blocks: BTreeMap<_, _> = (0..func.nodes.len())
            .filter(|idx| func.nodes[*idx].is_control())
            .map(|idx| (NodeID::new(idx), RustBlock::default()))
            .collect();

        // Emit data flow into basic blocks.
        for block in self.bbs.1.iter() {
            for id in block {
                self.codegen_data_node(*id, &mut blocks)?;
            }
        }

        // Emit control flow into basic blocks.
        let rev_po = self.control_subgraph.rev_po(NodeID::new(0));
        for id in rev_po.iter() {
            self.codegen_control_node(*id, &mut blocks)?;
        }

        // Dump the emitted basic blocks. Do this in reverse postorder since
        // fork and join nodes open and close environments, respectively.
        for id in rev_po.iter() {
            let block = &blocks[id];
            if func.nodes[id.idx()].is_join() {
                write!(
                    w,
                    "{}{}{}{}{}{}",
                    block.prologue,
                    block.data,
                    block.epilogue,
                    block.phi_tmp_assignments,
                    block.phi_assignments,
                    block.join_epilogue
                )?;
            } else {
                write!(
                    w,
                    "{}{}{}{}{}",
                    block.prologue,
                    block.data,
                    block.phi_tmp_assignments,
                    block.phi_assignments,
                    block.epilogue
                )?;
            }
        }

        // Close the root environment.
        self.codegen_close_environment(w)?;
        write!(w, "}}")?;
        Ok(())
    }

    /*
     * While control nodes in Hercules IR are predecessor-centric (each takes a
     * control input that defines the predecessor relationship), the Rust loop
     * we generate is successor-centric. This difference requires explicit
     * translation.
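     *
     * For example, an if node 7 with condition node 2 whose true and false
     * projections are nodes 8 and 9 lowers to roughly this match arm:
     *
     *     7 => {control_token = if node_2 {8} else {9};}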
     */
    fn codegen_control_node(
        &self,
        id: NodeID,
        blocks: &mut BTreeMap<NodeID, RustBlock>,
    ) -> Result<(), Error> {
        let func = &self.get_func();
        match func.nodes[id.idx()] {
            // Start, region, and projection control nodes all have exactly one
            // successor and are otherwise simple.
            Node::Start
            | Node::Region { preds: _ }
            | Node::ControlProjection {
                control: _,
                selection: _,
            } => {
                let prologue = &mut blocks.get_mut(&id).unwrap().prologue;
                write!(prologue, "{} => {{", id.idx())?;
                let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue;
                let succ = self.control_subgraph.succs(id).next().unwrap();
                write!(epilogue, "control_token = {};}}", succ.idx())?;
            }
            // If nodes have two successors - examine the projections to
            // determine which branch is which, and branch between them.
            Node::If { control: _, cond } => {
                let prologue = &mut blocks.get_mut(&id).unwrap().prologue;
                write!(prologue, "{} => {{", id.idx())?;
                let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue;
                let mut succs = self.control_subgraph.succs(id);
                let succ1 = succs.next().unwrap();
                let succ2 = succs.next().unwrap();
                let succ1_is_true = func.nodes[succ1.idx()].try_control_projection(1).is_some();
                write!(
                    epilogue,
                    "control_token = if {} {{{}}} else {{{}}};}}",
                    self.get_value(cond, id, false),
                    if succ1_is_true { succ1 } else { succ2 }.idx(),
                    if succ1_is_true { succ2 } else { succ1 }.idx(),
                )?;
            }
            Node::Return {
                control: _,
                ref data,
            } => {
                let prologue = &mut blocks.get_mut(&id).unwrap().prologue;
                write!(prologue, "{} => {{", id.idx())?;
                let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue;
                if data.len() == 1 {
                    write!(epilogue, "return {};}}", self.get_value(data[0], id, false))?;
                } else {
                    write!(
                        epilogue,
                        "return ({});}}",
                        data.iter()
                            .map(|v| self.get_value(*v, id, false))
                            .collect::<Vec<_>>()
                            .join(", "),
                    )?;
                }
            }
            // Fork nodes open a new environment for defining an async closure.
            Node::Fork {
                control: _,
                ref factors,
            } => {
                assert!(func.schedules[id.idx()].contains(&Schedule::ParallelFork));
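
                // For a one-dimensional fork 4 with factor N whose join is
                // node 9, the surrounding environment sees roughly:
                //
                //     4 => {control_token = 9;
                //         for tid_4_0 in 0..N {
                //             /* clones of Arcs used inside the fork-join */
                //             fork_4.push(::async_std::task::spawn(async move {
                //                 /* inner environment; runs until the join */
                //             }));
                //         }
                //     }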

                // Set the outer environment control token to the join.
                let prologue = &mut blocks.get_mut(&id).unwrap().prologue;
                let join = self.fork_join_map[&id];
                write!(
                    prologue,
                    "{} => {{control_token = {};",
                    id.idx(),
                    join.idx()
                )?;

                // Emit loops for the thread IDs.
                for (idx, factor) in factors.into_iter().enumerate() {
                    write!(prologue, "for tid_{}_{} in 0..", id.idx(), idx)?;
                    self.codegen_dynamic_constant(*factor, prologue)?;
                    write!(prologue, " {{")?;
                }

                // Emit clones of arcs used inside the fork-join.
                for other_id in (0..func.nodes.len()).map(NodeID::new) {
                    if self.def_use.get_users(other_id).into_iter().any(|user_id| {
                        self.nodes_in_fork_joins[&id].contains(&self.bbs.0[user_id.idx()])
                    }) && let Some(arc) = self.clone_arc(other_id)
                    {
                        write!(prologue, "{}", arc)?;
                    }
                }

                // Spawn an async closure and push its future to a Vec.
                write!(
                    prologue,
                    "fork_{}.push(::async_std::task::spawn(async move {{",
                    id.idx()
                )?;

                // Open a new environment.
                self.codegen_open_environment(id, prologue)?;

                // Open the branch inside the async closure for the fork.
                let succ = self.control_subgraph.succs(id).next().unwrap();
                write!(
                    prologue,
                    "{} => {{control_token = {};",
                    id.idx(),
                    succ.idx()
                )?;

                // Close the branch for the fork inside the async closure.
                let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue;
                write!(epilogue, "}}")?;
            }
            // Join nodes close the environment opened by their corresponding
            // fork.
            Node::Join { control: _ } => {
                // Emit the branch for the join inside the async closure.
                let prologue = &mut blocks.get_mut(&id).unwrap().prologue;
                write!(prologue, "{} => {{", id.idx())?;

                // Close the branch inside the async closure.
                let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue;
                write!(
                    epilogue,
                    "::std::sync::atomic::fence(::std::sync::atomic::Ordering::Release);return;}}"
                )?;

                // Close the fork's environment.
                self.codegen_close_environment(epilogue)?;

                // Close the async closure and the push statement from the
                // fork.
                write!(epilogue, "}}));")?;

                // Close the loops emitted by the fork node.
                let fork = self.join_fork_map[&id];
                for _ in 0..func.nodes[fork.idx()].try_fork().unwrap().1.len() {
                    write!(epilogue, "}}")?;
                }

                // Close the branch for the fork outside the async closure.
                write!(epilogue, "}}")?;

                // Open the branch in the surrounding context for the join.
                let succ = self.control_subgraph.succs(id).next().unwrap();
                write!(epilogue, "{} => {{", id.idx())?;

                // Await the futures from this fork-join, waiting for the
                // spawned tasks to complete.
                write!(
                    epilogue,
                    "for fut in fork_{}.drain(..) {{ fut.await; }}; ::std::sync::atomic::fence(::std::sync::atomic::Ordering::Acquire);",
                    fork.idx(),
                )?;

                // Emit the assignments to the reduce variables in the
                // surrounding context. It's very unfortunate that we have to do
                // it while lowering the join node (rather than the reduce nodes
                // themselves), but this is the only place we can put these
                // assignments in the correct control location.
                for user in self.def_use.get_users(id) {
                    if let Some((_, init, _)) = func.nodes[user.idx()].try_reduce() {
                        write!(
                            epilogue,
                            "{} = {};",
                            self.get_value(*user, id, true),
                            self.get_value(init, id, false)
                        )?;
                    }
                }

                let join_epilogue = &mut blocks.get_mut(&id).unwrap().join_epilogue;
                // Branch to the successor control node in the surrounding
                // context, and close the branch for the join.
                write!(join_epilogue, "control_token = {};}}", succ.idx())?;
            }
            _ => panic!("PANIC: Can't lower {:?}.", func.nodes[id.idx()]),
        }
        Ok(())
    }

    /*
     * Lower data nodes in Hercules IR into Rust statements.
     */
    fn codegen_data_node(
        &self,
        id: NodeID,
        blocks: &mut BTreeMap<NodeID, RustBlock>,
    ) -> Result<(), Error> {
        let func = &self.get_func();
        let bb = self.bbs.0[id.idx()];
        match func.nodes[id.idx()] {
            Node::Parameter { index } => {
                let block = &mut blocks.get_mut(&bb).unwrap().data;
                write!(block, "{} = p{};", self.get_value(id, bb, true), index)?
            }
            Node::Constant { id: cons_id } => {
                let block = &mut blocks.get_mut(&bb).unwrap().data;
                write!(block, "{} = ", self.get_value(id, bb, true))?;
                let mut size_and_device = None;
                match self.module.constants[cons_id.idx()] {
                    Constant::Boolean(val) => write!(block, "{}", val)?,
                    Constant::Integer8(val) => write!(block, "{}i8", val)?,
                    Constant::Integer16(val) => write!(block, "{}i16", val)?,
                    Constant::Integer32(val) => write!(block, "{}i32", val)?,
                    Constant::Integer64(val) => write!(block, "{}i64", val)?,
                    Constant::UnsignedInteger8(val) => write!(block, "{}u8", val)?,
                    Constant::UnsignedInteger16(val) => write!(block, "{}u16", val)?,
                    Constant::UnsignedInteger32(val) => write!(block, "{}u32", val)?,
                    Constant::UnsignedInteger64(val) => write!(block, "{}u64", val)?,
                    Constant::Float32(val) => {
                        if val == f32::INFINITY {
                            write!(block, "f32::INFINITY")?
                        } else if val == f32::NEG_INFINITY {
                            write!(block, "f32::NEG_INFINITY")?
                        } else {
                            write!(block, "{}f32", val)?
                        }
                    }
                    Constant::Float64(val) => {
                        if val == f64::INFINITY {
                            write!(block, "f64::INFINITY")?
                        } else if val == f64::NEG_INFINITY {
                            write!(block, "f64::NEG_INFINITY")?
                        } else {
                            write!(block, "{}f64", val)?
                        }
                    }
                    Constant::Product(ty, _)
                    | Constant::Summation(ty, _, _)
                    | Constant::Array(ty) => {
                        let (device, (offset, _)) = self.backing_allocations[&self.func_id]
                            .iter()
                            .filter_map(|(device, (_, offsets))| {
                                offsets.get(&id).map(|id| (*device, *id))
                            })
                            .next()
                            .unwrap();
                        write!(block, "backing_{}.byte_add(", device.name())?;
                        self.codegen_dynamic_constant(offset, block)?;
                        write!(block, " as usize)")?;
                        size_and_device = Some((self.codegen_type_size(ty), device));
                    }
                }
                write!(block, ";")?;
                if !func.schedules[id.idx()].contains(&Schedule::NoResetConstant) {
                    if let Some((size, device)) = size_and_device {
                        write!(
                            block,
                            "::hercules_rt::__{}_zero_mem({}.0, {} as usize);",
                            device.name(),
                            self.get_value(id, bb, false),
                            size
                        )?;
                    }
                }
            }
            Node::DynamicConstant { id: dc_id } => {
                let block = &mut blocks.get_mut(&bb).unwrap().data;
                write!(block, "{} = ", self.get_value(id, bb, true))?;
                self.codegen_dynamic_constant(dc_id, block)?;
                write!(block, ";")?;
            }
            Node::ThreadID { control, dimension } => {
                assert_eq!(control, bb);
                let block = &mut blocks.get_mut(&bb).unwrap().data;
                write!(
                    block,
                    "{} = tid_{}_{};",
                    self.get_value(id, bb, true),
                    bb.idx(),
                    dimension
                )?;
            }
            Node::Phi { control, ref data } => {
                assert_eq!(control, bb);
                // Phis aren't executable in their own basic block - predecessor
                // blocks assign the to-be phi values themselves. Assign to
                // temporary values before assigning the phis themselves, since
                // there may be simultaneous, inter-dependent phis.
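                // For example, two phis A and B that swap each other's values
                // generate, in each predecessor block, roughly:
                //
                //     let node_A_tmp = node_B; let node_B_tmp = node_A;
                //     node_A = node_A_tmp; node_B = node_B_tmp;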
                for (data, pred) in zip(data.into_iter(), self.control_subgraph.preds(bb)) {
                    let block = &mut blocks.get_mut(&pred).unwrap().phi_tmp_assignments;
                    write!(
                        block,
                        "let {}_tmp = {};",
                        self.get_value(id, pred, true),
                        self.get_value(*data, pred, false),
                    )?;
                    let block = &mut blocks.get_mut(&pred).unwrap().phi_assignments;
                    write!(
                        block,
                        "{} = {}_tmp;",
                        self.get_value(id, pred, true),
                        self.get_value(id, pred, false),
                    )?;
                }
            }
            Node::Reduce {
                control: _,
                init: _,
                reduct: _,
            } => {
                assert!(func.schedules[id.idx()].contains(&Schedule::ParallelReduce));
            }
            Node::Call {
                control,
                function: callee_id,
                ref dynamic_constants,
                ref args,
            } => {
                assert_eq!(control, bb);
                // The device backends and the wrappers we generated earlier ensure that device
                // functions have the same interface as AsyncRust functions.
                let block = &mut blocks.get_mut(&bb).unwrap();
                let block = &mut block.data;
                let is_async = func.schedules[id.idx()].contains(&Schedule::AsyncCall);
                if is_async {
                    for arg in args {
                        if let Some(arc) = self.clone_arc(*arg) {
                            write!(block, "{}", arc)?;
                        }
                    }
                }
                let device = self.devices[callee_id.idx()];
                let prefix = match (device, is_async) {
                    (Device::AsyncRust, false) | (_, false) => {
                        format!("{} = ", self.get_value(id, bb, true))
                    }
                    (_, true) => {
                        write!(block, "{}", self.clone_arc(id).unwrap())?;
                        format!(
                            "*async_call_{}.lock().await = ::hercules_rt::__FutureSlotWrapper::new(::async_std::task::spawn(async move {{ ",
                            id.idx(),
                        )
                    }
                };
                let postfix = match (device, is_async) {
                    (Device::AsyncRust, false) => ".await",
                    (_, false) => "",
                    (Device::AsyncRust, true) => ".await}))",
                    (_, true) => "}))",
                };
                write!(
                    block,
                    "{}{}_{}(",
                    prefix,
                    self.module_name,
                    self.module.functions[callee_id.idx()].name
                )?;
                for (device, (offset, size)) in self.backing_allocations[&self.func_id]
                    .iter()
                    .filter_map(|(device, (_, offsets))| offsets.get(&id).map(|id| (*device, *id)))
                {
                    write!(block, "backing_{}.byte_add(((", device.name())?;
                    self.codegen_dynamic_constant(offset, block)?;
                    let forks = &self.fork_join_nest[&bb];
                    if !forks.is_empty() {
                        write!(block, ") + ")?;
                        let mut linear_thread = "0".to_string();
                        for fork in forks {
                            let factors = func.nodes[fork.idx()].try_fork().unwrap().1;
                            for (factor_idx, factor) in factors.into_iter().enumerate() {
                                linear_thread = format!("({} *", linear_thread);
                                self.codegen_dynamic_constant(*factor, &mut linear_thread)?;
                                write!(linear_thread, " + tid_{}_{})", fork.idx(), factor_idx)?;
                            }
                        }
                        write!(block, "{} * (", linear_thread)?;
                        self.codegen_dynamic_constant(size, block)?;
                    }
                    write!(block, ")) as usize), ")?
                }
                for dc in dynamic_constants {
                    self.codegen_dynamic_constant(*dc, block)?;
                    write!(block, ", ")?;
                }
                for arg in args {
                    write!(block, "{}, ", self.get_value(*arg, bb, false))?;
                }
                write!(block, "){};", postfix)?;
            }
            Node::DataProjection { data, selection } => {
                let block = &mut blocks.get_mut(&bb).unwrap().data;
                let Node::Call {
                    function: callee_id,
                    ..
                } = func.nodes[data.idx()]
                else {
                    panic!()
                };
                if self.module.functions[callee_id.idx()].return_types.len() == 1 {
                    assert!(selection == 0);
                    write!(
                        block,
                        "{} = {};",
                        self.get_value(id, bb, true),
                        self.get_value(data, bb, false),
                    )?;
                } else {
                    write!(
                        block,
                        "{} = {}.{};",
                        self.get_value(id, bb, true),
                        self.get_value(data, bb, false),
                        selection,
                    )?;
                }
            }
            Node::IntrinsicCall {
                intrinsic,
                ref args,
            } => {
                let block = &mut blocks.get_mut(&bb).unwrap().data;
                write!(
                    block,
                    "{} = {}::{}(",
                    self.get_value(id, bb, true),
                    self.get_type(self.typing[id.idx()]),
                    intrinsic.lower_case_name(),
                )?;
                for arg in args {
                    write!(block, "{}, ", self.get_value(*arg, bb, false))?;
                }
                write!(block, ");")?;
            }
            Node::LibraryCall {
                library_function,
                ref args,
                ty,
                device,
            } => match library_function {
                LibraryFunction::GEMM => {
                    assert_eq!(args.len(), 3);
                    assert_eq!(self.typing[args[0].idx()], ty);
                    let c_ty = &self.module.types[self.typing[args[0].idx()].idx()];
                    let a_ty = &self.module.types[self.typing[args[1].idx()].idx()];
                    let b_ty = &self.module.types[self.typing[args[2].idx()].idx()];
                    let (
                        Type::Array(c_elem, c_dims),
                        Type::Array(a_elem, a_dims),
                        Type::Array(b_elem, b_dims),
                    ) = (c_ty, a_ty, b_ty)
                    else {
                        panic!();
                    };
                    assert_eq!(a_elem, b_elem);
                    assert_eq!(a_elem, c_elem);
                    assert_eq!(c_dims.len(), 2);
                    assert_eq!(a_dims.len(), 2);
                    assert_eq!(b_dims.len(), 2);
                    assert_eq!(a_dims[1], b_dims[0]);
                    assert_eq!(a_dims[0], c_dims[0]);
                    assert_eq!(b_dims[1], c_dims[1]);

                    let block = &mut blocks.get_mut(&bb).unwrap().data;
                    let prim_ty = self.library_prim_ty(*a_elem);
                    write!(block, "::hercules_rt::__library_{}_gemm(", device.name())?;
                    self.codegen_dynamic_constant(a_dims[0], block)?;
                    write!(block, ", ")?;
                    self.codegen_dynamic_constant(a_dims[1], block)?;
                    write!(block, ", ")?;
                    self.codegen_dynamic_constant(b_dims[1], block)?;
                    write!(
                        block,
                        ", {}.0, {}.0, {}.0, {});",
                        self.get_value(args[0], bb, false),
                        self.get_value(args[1], bb, false),
                        self.get_value(args[2], bb, false),
                        prim_ty
                    )?;
                    write!(
                        block,
                        "{} = {};",
                        self.get_value(id, bb, true),
                        self.get_value(args[0], bb, false)
                    )?;
                }
            },
            Node::Unary { op, input } => {
                let block = &mut blocks.get_mut(&bb).unwrap().data;
                match op {
                    UnaryOperator::Not => write!(
                        block,
                        "{} = !{};",
                        self.get_value(id, bb, true),
                        self.get_value(input, bb, false)
                    )?,
                    UnaryOperator::Neg => write!(
                        block,
                        "{} = -{};",
                        self.get_value(id, bb, true),
                        self.get_value(input, bb, false)
                    )?,
                    UnaryOperator::Cast(ty) => write!(
                        block,
                        "{} = {} as {};",
                        self.get_value(id, bb, true),
                        self.get_value(input, bb, false),
                        self.get_type(ty)
                    )?,
                };
            }
            Node::Binary { op, left, right } => {
                let block = &mut blocks.get_mut(&bb).unwrap().data;
                let op = match op {
                    BinaryOperator::Add => "+",
                    BinaryOperator::Sub => "-",
                    BinaryOperator::Mul => "*",
                    BinaryOperator::Div => "/",
                    BinaryOperator::Rem => "%",
                    BinaryOperator::LT => "<",
                    BinaryOperator::LTE => "<=",
                    BinaryOperator::GT => ">",
                    BinaryOperator::GTE => ">=",
                    BinaryOperator::EQ => "==",
                    BinaryOperator::NE => "!=",
                    BinaryOperator::Or => "|",
                    BinaryOperator::And => "&",
                    BinaryOperator::Xor => "^",
                    BinaryOperator::LSh => "<<",
                    BinaryOperator::RSh => ">>",
                };

                write!(
                    block,
                    "{} = {} {} {};",
                    self.get_value(id, bb, true),
                    self.get_value(left, bb, false),
                    op,
                    self.get_value(right, bb, false)
                )?;
            }
            Node::Ternary {
                op,
                first,
                second,
                third,
            } => {
                let block = &mut blocks.get_mut(&bb).unwrap().data;
                match op {
                    TernaryOperator::Select => write!(
                        block,
                        "{} = if {} {{{}}} else {{{}}};",
                        self.get_value(id, bb, true),
                        self.get_value(first, bb, false),
                        self.get_value(second, bb, false),
                        self.get_value(third, bb, false),
                    )?,
                };
            }
            Node::Read {
                collect,
                ref indices,
            } => {
                let block = &mut blocks.get_mut(&bb).unwrap().data;
                let collect_ty = self.typing[collect.idx()];
                let self_ty = self.typing[id.idx()];
                let offset = self.codegen_index_math(collect_ty, indices, bb)?;
                if self.module.types[self_ty.idx()].is_primitive() {
                    write!(
                        block,
                        "{} = ({}.byte_add({} as usize).0 as *mut {}).read();",
                        self.get_value(id, bb, true),
                        self.get_value(collect, bb, false),
                        offset,
                        self.get_type(self_ty)
                    )?;
                } else {
                    write!(
                        block,
                        "{} = {}.byte_add({} as usize);",
                        self.get_value(id, bb, true),
                        self.get_value(collect, bb, false),
                        offset,
                    )?;
                }
            }
            Node::Write {
                collect,
                data,
                ref indices,
            } => {
                let block = &mut blocks.get_mut(&bb).unwrap().data;
                let collect_ty = self.typing[collect.idx()];
                let data_ty = self.typing[data.idx()];
                let data_size = self.codegen_type_size(data_ty);
                let offset = self.codegen_index_math(collect_ty, indices, bb)?;
                if self.module.types[data_ty.idx()].is_primitive() {
                    write!(
                        block,
                        "({}.byte_add({} as usize).0 as *mut {}).write({});",
                        self.get_value(collect, bb, false),
                        offset,
                        self.get_type(data_ty),
                        self.get_value(data, bb, false),
                    )?;
                } else {
                    // If the data item being written is not a primitive type,
                    // then perform a memcpy from the data collection to the
                    // destination collection. Look at the colors of the
                    // `collect` and `data` inputs, since this may be an inter-
                    // device copy.
                    let src_device = self.node_colors.0[&data];
                    let dst_device = self.node_colors.0[&collect];
                    write!(
                        block,
                        "::hercules_rt::__copy_{}_to_{}({}.byte_add({} as usize).0, {}.0, {} as usize);",
                        src_device.name(),
                        dst_device.name(),
                        self.get_value(collect, bb, false),
                        offset,
                        self.get_value(data, bb, false),
                        data_size,
                    )?;
                }
                write!(
                    block,
                    "{} = {};",
                    self.get_value(id, bb, true),
                    self.get_value(collect, bb, false)
                )?;
            }
            _ => panic!(
                "PANIC: Can't lower {:?} in {}.",
                func.nodes[id.idx()],
                func.name
            ),
        }
        Ok(())
    }

    /*
     * Lower a dynamic constant in Hercules IR into a Rust expression.
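     *
     * For example, min(a, b, c) lowers to the nested expression
     * ::core::cmp::min(a,::core::cmp::min(b,c)).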
     */
    fn codegen_dynamic_constant<W: Write>(
        &self,
        id: DynamicConstantID,
        w: &mut W,
    ) -> Result<(), Error> {
        match &self.module.dynamic_constants[id.idx()] {
            DynamicConstant::Constant(val) => write!(w, "{}", val)?,
            DynamicConstant::Parameter(idx) => write!(w, "dc_p{}", idx)?,
            DynamicConstant::Add(xs) => {
                write!(w, "(")?;
                let mut xs = xs.iter();
                self.codegen_dynamic_constant(*xs.next().unwrap(), w)?;
                for x in xs {
                    write!(w, "+")?;
                    self.codegen_dynamic_constant(*x, w)?;
                }
                write!(w, ")")?;
            }
            DynamicConstant::Sub(left, right) => {
                write!(w, "(")?;
                self.codegen_dynamic_constant(*left, w)?;
                write!(w, "-")?;
                self.codegen_dynamic_constant(*right, w)?;
                write!(w, ")")?;
            }
            DynamicConstant::Mul(xs) => {
                write!(w, "(")?;
                let mut xs = xs.iter();
                self.codegen_dynamic_constant(*xs.next().unwrap(), w)?;
                for x in xs {
                    write!(w, "*")?;
                    self.codegen_dynamic_constant(*x, w)?;
                }
                write!(w, ")")?;
            }
            DynamicConstant::Div(left, right) => {
                write!(w, "(")?;
                self.codegen_dynamic_constant(*left, w)?;
                write!(w, "/")?;
                self.codegen_dynamic_constant(*right, w)?;
                write!(w, ")")?;
            }
            DynamicConstant::Rem(left, right) => {
                write!(w, "(")?;
                self.codegen_dynamic_constant(*left, w)?;
                write!(w, "%")?;
                self.codegen_dynamic_constant(*right, w)?;
                write!(w, ")")?;
            }
            DynamicConstant::Min(xs) => {
                let mut xs = xs.iter().peekable();

                // Track the number of parentheses we open that need to be closed later
                let mut opens = 0;
                while let Some(x) = xs.next() {
                    if xs.peek().is_none() {
                        // For the last element, we just print it
                        self.codegen_dynamic_constant(*x, w)?;
                    } else {
                        // Otherwise, we create a new call to min and print the element as the
                        // first argument
                        write!(w, "::core::cmp::min(")?;
                        self.codegen_dynamic_constant(*x, w)?;
                        write!(w, ",")?;
                        opens += 1;
                    }
                }
                for _ in 0..opens {
                    write!(w, ")")?;
                }
            }
            DynamicConstant::Max(xs) => {
                let mut xs = xs.iter().peekable();

                let mut opens = 0;
                while let Some(x) = xs.next() {
                    if xs.peek().is_none() {
                        self.codegen_dynamic_constant(*x, w)?;
                    } else {
                        write!(w, "::core::cmp::max(")?;
                        self.codegen_dynamic_constant(*x, w)?;
                        write!(w, ",")?;
                        opens += 1;
                    }
                }
                for _ in 0..opens {
                    write!(w, ")")?;
                }
            }
        }
        Ok(())
    }

    /*
     * Emit logic to index into a collection.
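     *
     * For example, indexing position (p, q) into an array with dimensions
     * (s1, s2) produces an offset expression of the shape
     * (((0 * s1 + p) * s2 + q) * aligned_elem_size), where aligned_elem_size
     * is the element size rounded up to the element alignment.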
     */
    fn codegen_index_math(
        &self,
        mut collect_ty: TypeID,
        indices: &[Index],
        bb: NodeID,
    ) -> Result<String, Error> {
        let mut acc_offset = "0".to_string();
        for index in indices {
            match index {
                Index::Field(idx) => {
                    let Type::Product(ref fields) = self.module.types[collect_ty.idx()] else {
                        panic!()
                    };

                    // Get the offset of the field at index `idx` by
                    // calculating the product's size up to field `idx`, then
                    // offsetting the base pointer by that amount.
                    for field in &fields[..*idx] {
                        let field_align = get_type_alignment(&self.module.types, *field);
                        let field = self.codegen_type_size(*field);
                        acc_offset = format!(
                            "((({} + {}) & !{}) + {})",
                            acc_offset,
                            field_align - 1,
                            field_align - 1,
                            field
                        );
                    }
                    let last_align = get_type_alignment(&self.module.types, fields[*idx]);
                    acc_offset = format!(
                        "(({} + {}) & !{})",
                        acc_offset,
                        last_align - 1,
                        last_align - 1
                    );
                    collect_ty = fields[*idx];
                }
                Index::Variant(idx) => {
                    // The tag of a summation is at the end of the summation, so
                    // the variant pointer is just the base pointer. Do nothing.
                    let Type::Summation(ref variants) = self.module.types[collect_ty.idx()] else {
                        panic!()
                    };
                    collect_ty = variants[*idx];
                }
                Index::Position(ref pos) => {
                    let Type::Array(elem, ref dims) = self.module.types[collect_ty.idx()] else {
                        panic!()
                    };

                    // The offset of the position into an array is:
                    //
                    //     ((0 * s1 + p1) * s2 + p2) * s3 + p3 ...
                    let elem_size = self.codegen_type_size(elem);
                    let elem_align = get_type_alignment(&self.module.types, elem);
                    let aligned_elem_size = format!(
                        "(({} + {}) & !{})",
                        elem_size,
                        elem_align - 1,
                        elem_align - 1
                    );
                    for (p, s) in zip(pos, dims) {
                        let p = self.get_value(*p, bb, false);
                        acc_offset = format!("{} * ", acc_offset);
                        self.codegen_dynamic_constant(*s, &mut acc_offset)?;
                        acc_offset = format!("({} + {})", acc_offset, p);
                    }

                    // Convert offset in # elements -> # bytes.
                    acc_offset = format!("({} * {})", acc_offset, aligned_elem_size);
                    collect_ty = elem;
                }
            }
        }
        Ok(acc_offset)
    }

    /*
     * Lower the size of a type into a Rust expression.
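     *
     * For example, the size of the product (u32, u8) lowers to
     *
     *     ((((((((0 + 3) & !3) + 4) + 0) & !0) + 1) + 3) & !3)
     *
     * which rounds up to each field's alignment before adding the field's
     * size, then rounds the total up to the product's alignment, evaluating
     * to 8.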
     */
    fn codegen_type_size(&self, ty: TypeID) -> String {
        match self.module.types[ty.idx()] {
            Type::Control | Type::MultiReturn(_) => panic!(),
            Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => {
                "1".to_string()
            }
            Type::Integer16 | Type::UnsignedInteger16 | Type::BFloat16 => "2".to_string(),
            Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => "4".to_string(),
            Type::Integer64 | Type::UnsignedInteger64 | Type::Float64 => "8".to_string(),
            Type::Product(ref fields) => {
                let fields_align = fields
                    .into_iter()
                    .map(|id| get_type_alignment(&self.module.types, *id));
                let fields: Vec<String> = fields
                    .into_iter()
                    .map(|id| self.codegen_type_size(*id))
                    .collect();

                // Emit a Rust expression that rounds up to the alignment of
                // the next field and then adds the size of that field. At the
                // end, round up to the alignment of the whole struct.
                let mut acc_size = "0".to_string();
                for (field_align, field) in zip(fields_align, fields) {
                    acc_size = format!(
                        "(({} + {}) & !{})",
                        acc_size,
                        field_align - 1,
                        field_align - 1
                    );
                    acc_size = format!("({} + {})", acc_size, field);
                }
                let total_align = get_type_alignment(&self.module.types, ty);
                format!(
                    "(({} + {}) & !{})",
                    acc_size,
                    total_align - 1,
                    total_align - 1
                )
            }
            Type::Summation(ref variants) => {
                let variants = variants.into_iter().map(|id| self.codegen_type_size(*id));

                // The size of a summation is the size of the largest variant,
                // plus 1 byte for the discriminant, rounded up to the
                // alignment of the whole summation.
                let mut acc_size = "0".to_string();
                for variant in variants {
                    acc_size = format!("::core::cmp::max({}, {})", acc_size, variant);
                }

                // No alignment is necessary before the 1-byte discriminant.
                let total_align = get_type_alignment(&self.module.types, ty);
                format!(
                    "(({} + 1 + {}) & !{})",
                    acc_size,
                    total_align - 1,
                    total_align - 1
                )
            }
            Type::Array(elem, ref bounds) => {
                // The size of an array is the size of the element multiplied
                // by the dynamic constant bounds.
                let mut acc_size = self.codegen_type_size(elem);
                let elem_align = get_type_alignment(&self.module.types, elem);
                acc_size = format!(
                    "(({} + {}) & !{})",
                    acc_size,
                    elem_align - 1,
                    elem_align - 1
                );
                for dc in bounds {
                    acc_size = format!("{} * ", acc_size);
                    self.codegen_dynamic_constant(*dc, &mut acc_size).unwrap();
                }
                format!("({})", acc_size)
            }
        }
    }

    fn codegen_open_environment<W: Write>(&self, root: NodeID, w: &mut W) -> Result<(), Error> {
        let func = &self.get_func();

        // Declare intermediate variables for every value in this fork-join (or
        // in the whole function, for the start node) that isn't in any child
        // fork-joins.
        for idx in 0..func.nodes.len() {
            let id = NodeID::new(idx);
            let control = func.nodes[idx].is_control();
            let in_root = self.nodes_in_fork_joins[&root].contains(&id);
            let in_child = self.fork_tree[&root]
                .iter()
                .any(|child| self.nodes_in_fork_joins[&child].contains(&id));
            let is_reduce_on_child = func.nodes[idx]
                .try_reduce()
                .map(|(control, _, _)| {
                    self.fork_tree[&root].contains(&self.join_fork_map[&control])
                })
                .unwrap_or(false);
            if (control || !in_root || in_child) && !is_reduce_on_child {
                continue;
            }

            // If the node is a call with an AsyncCall schedule, it should be
            // lowered to an Arc<Mutex<>> over the future.
            let is_async_call =
                func.nodes[idx].is_call() && func.schedules[idx].contains(&Schedule::AsyncCall);
            if is_async_call {
                write!(
                    w,
                    "let mut async_call_{} = ::std::sync::Arc::new(::async_std::sync::Mutex::new(::hercules_rt::__FutureSlotWrapper::empty()));",
                    idx,
                )?;
            } else {
                write!(
                    w,
                    "let mut {}_{}: {} = {};",
                    if is_reduce_on_child { "reduce" } else { "node" },
                    idx,
                    self.get_type(self.typing[idx]),
                    self.get_default_value(self.typing[idx]),
                )?;
            }
        }

        // Declare Vecs for storing futures of fork-joins.
        for fork in self.fork_tree[&root].iter() {
            write!(w, "let mut fork_{} = vec![];", fork.idx())?;
        }

        // The core executor is a Rust loop. We literally run a "control token"
        // as described in the original sea of nodes paper through the basic
        // blocks to drive execution.
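        // The emitted skeleton has the following shape, with one match arm per
        // control node (whitespace added here for readability):
        //
        //     let mut control_token: i8 = <root>;
        //     loop {
        //         match control_token {
        //             0 => { /* start */ control_token = ...; }
        //             /* ... one arm per control node ... */
        //             _ => panic!(),
        //         }
        //     }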
        write!(
            w,
            "let mut control_token: i8 = {};loop {{match control_token {{",
            root.idx(),
        )?;

        Ok(())
    }

    fn codegen_close_environment<W: Write>(&self, w: &mut W) -> Result<(), Error> {
        // Close the match and loop.
        write!(w, "_ => panic!()}}}}")
    }

    /*
     * Generate a runner object for this function. The runner object stores
     * backing memory for a Hercules function and wraps calls to the Hercules
     * function.
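     *
     * For a function foo with a single backing allocation on a device named
     * "cpu" (this device name is an assumption for illustration), the emitted
     * type has roughly this shape:
     *
     *     struct HerculesRunner_foo {
     *         backing_ptr_cpu: *mut u8,
     *         backing_size_cpu: usize,
     *     }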
     */
    fn codegen_runner_object<W: Write>(&self, w: &mut W) -> Result<(), Error> {
        // Figure out the devices for the parameters and the return values if
        // they are collections, and whether they should be immutable or
        // mutable references.
        let func = self.get_func();
        let param_devices = &self.node_colors.1;
        let return_devices = &self.node_colors.2;
        let mut param_muts = vec![false; func.param_types.len()];
        let mut return_muts = vec![true; func.return_types.len()];
        let objects = &self.collection_objects[&self.func_id];
        for idx in 0..func.param_types.len() {
            if let Some(object) = objects.param_to_object(idx)
                && objects.is_mutated(object)
            {
                param_muts[idx] = true;
            }
        }
        let num_returns = func.return_types.len();
        for idx in 0..num_returns {
            for object in objects.returned_objects(idx) {
                if let Some(param_idx) = objects.origin(*object).try_parameter()
                    && !param_muts[param_idx]
                {
                    return_muts[idx] = false;
                }
            }
        }

        // Emit the type definition. A runner object owns its backing memory.
        write!(
            w,
            "#[allow(non_camel_case_types)]struct HerculesRunner_{} {{",
            func.name
        )?;
        for (device, _) in self.backing_allocations[&self.func_id].iter() {
            write!(w, "backing_ptr_{}: *mut u8,", device.name(),)?;
            write!(w, "backing_size_{}: usize,", device.name(),)?;
        }
        write!(w, "}}")?;
        write!(
            w,
            "impl HerculesRunner_{} {{fn new() -> Self {{Self {{",
            func.name
        )?;
        for (device, _) in self.backing_allocations[&self.func_id].iter() {
            write!(
                w,
                "backing_ptr_{}: ::core::ptr::null_mut(),backing_size_{}: 0,",
                device.name(),
                device.name()
            )?;
        }
        write!(w, "}}}}")?;
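
        // Assuming an entry point named foo with a single CPU backing
        // allocation (and that Device::name() yields "cpu"), the emitted text
        // so far looks roughly like:
        //
        //   #[allow(non_camel_case_types)]
        //   struct HerculesRunner_foo {
        //       backing_ptr_cpu: *mut u8,
        //       backing_size_cpu: usize,
        //   }
        //   impl HerculesRunner_foo {
        //       fn new() -> Self {
        //           Self { backing_ptr_cpu: ::core::ptr::null_mut(), backing_size_cpu: 0 }
        //       }
        //       // the impl stays open; run() is emitted into it below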

        // Each returned reference, input reference, and the runner will have
        // its own lifetime. We use lifetime bounds to ensure that the runner
        // and parameters are borrowed for the lifetimes needed by the outputs.
        let returned_origins: Vec<HashSet<_>> = (0..num_returns)
            .map(|idx| {
                objects
                    .returned_objects(idx)
                    .iter()
                    .map(|obj| objects.origin(*obj))
                    .collect()
            })
            .collect();

        write!(w, "async fn run<'runner:")?;
        for (ret_idx, origins) in returned_origins.iter().enumerate() {
            if origins.iter().any(|origin| !origin.is_parameter()) {
                write!(w, " 'r{} +", ret_idx)?;
            }
        }
        for idx in 0..num_returns {
            write!(w, ", 'r{}", idx)?;
        }
        for idx in 0..func.param_types.len() {
            write!(w, ", 'p{}:", idx)?;
            for (ret_idx, origins) in returned_origins.iter().enumerate() {
                if origins.iter().any(|origin| {
                    origin
                        .try_parameter()
                        .map(|oidx| idx == oidx)
                        .unwrap_or(false)
                }) {
                    write!(w, " 'r{} +", ret_idx)?;
                }
            }
        }
        write!(w, ">(&'runner mut self")?;
        for idx in 0..func.num_dynamic_constants {
            write!(w, ", dc_p{}: u64", idx)?;
        }
        for idx in 0..func.param_types.len() {
            if self.module.types[func.param_types[idx].idx()].is_primitive() {
                write!(w, ", p{}: {}", idx, self.get_type(func.param_types[idx]))?;
            } else {
                let device = match param_devices[idx] {
                    Some(Device::LLVM) | None => "CPU",
                    Some(Device::CUDA) => "CUDA",
                    _ => panic!(),
                };
                let mutability = if param_muts[idx] { "Mut" } else { "" };
                write!(
                    w,
                    ", p{}: ::hercules_rt::Hercules{}Ref{}<'p{}>",
                    idx, device, mutability, idx,
                )?;
            }
        }
        write!(
            w,
            ") -> {}{}{} {{",
            if num_returns != 1 { "(" } else { "" },
            func.return_types
                .iter()
                .enumerate()
                .map(
                    |(ret_idx, typ)| if self.module.types[typ.idx()].is_primitive() {
                        self.get_type(*typ)
                    } else {
                        let device = match return_devices[ret_idx] {
                            Some(Device::LLVM) | None => "CPU",
                            Some(Device::CUDA) => "CUDA",
                            _ => panic!(),
                        };
                        let mutability = if return_muts[ret_idx] { "Mut" } else { "" };
                        format!(
                            "::hercules_rt::Hercules{}Ref{}<'r{}>",
                            device, mutability, ret_idx
                        )
                    }
                )
                .collect::<Vec<_>>()
                .join(", "),
            if num_returns != 1 { ")" } else { "" },
        )?;
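
        // As a concrete sketch: for an entry point with one dynamic constant
        // and one CPU collection parameter whose single return may alias only
        // that (unmutated) parameter, the emitted signature looks roughly
        // like:
        //
        //   async fn run<'runner:, 'r0, 'p0: 'r0 +>(
        //       &'runner mut self,
        //       dc_p0: u64,
        //       p0: ::hercules_rt::HerculesCPURef<'p0>,
        //   ) -> ::hercules_rt::HerculesCPURef<'r0> {
        //
        // Rust's grammar accepts both the empty bound list on 'runner and the
        // trailing `+` in 'p0's bound list, so no special-casing is needed
        // when emitting them.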

        // Start with possibly re-allocating the backing memory if it's not
        // large enough.
        write!(w, "unsafe {{")?;
        for (device, (total, _)) in self.backing_allocations[&self.func_id].iter() {
            write!(w, "let size = ")?;
            self.codegen_dynamic_constant(*total, w)?;
            write!(
                w,
                " as usize;if self.backing_size_{} < size {{",
                device.name()
            )?;
            write!(
                w,
                "::hercules_rt::__{}_dealloc(self.backing_ptr_{}, self.backing_size_{});",
                device.name(),
                device.name(),
                device.name()
            )?;
            write!(w, "self.backing_size_{} = size;", device.name())?;
            write!(
                w,
                "self.backing_ptr_{} = ::hercules_rt::__{}_alloc(self.backing_size_{});",
                device.name(),
                device.name(),
                device.name()
            )?;
            write!(w, "}}")?;
        }
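
        // Assuming a single CPU backing allocation, the emitted check looks
        // roughly like:
        //
        //   let size = /* total size dynamic constant expression */ as usize;
        //   if self.backing_size_cpu < size {
        //       ::hercules_rt::__cpu_dealloc(self.backing_ptr_cpu, self.backing_size_cpu);
        //       self.backing_size_cpu = size;
        //       self.backing_ptr_cpu = ::hercules_rt::__cpu_alloc(self.backing_size_cpu);
        //   }
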
        for idx in 0..func.param_types.len() {
            if !self.module.types[func.param_types[idx].idx()].is_primitive() {
                write!(
                    w,
                    "let p{} = ::hercules_rt::__RawPtrSendSync(p{}.__ptr());",
                    idx, idx
                )?;
            }
        }

        // Call the wrapped function.
        write!(w, "let ret = {}_{}(", self.module_name, func.name)?;
        for (device, _) in self.backing_allocations[&self.func_id].iter() {
            write!(
                w,
                "::hercules_rt::__RawPtrSendSync(self.backing_ptr_{}), ",
                device.name()
            )?;
        }
        for idx in 0..func.num_dynamic_constants {
            write!(w, "dc_p{}, ", idx)?;
        }
        for idx in 0..func.param_types.len() {
            write!(w, "p{}, ", idx)?;
        }
        write!(w, ").await;")?;
        // Return the result, wrapping pointers in Hercules reference objects
        // as appropriate.
        if num_returns == 1 {
            if self.module.types[func.return_types[0].idx()].is_primitive() {
                write!(w, "ret")?;
            } else {
                let device = match return_devices[0] {
                    Some(Device::LLVM) | None => "CPU",
                    Some(Device::CUDA) => "CUDA",
                    _ => panic!(),
                };
                let mutability = if return_muts[0] { "Mut" } else { "" };
                write!(
                    w,
                    "::hercules_rt::Hercules{}Ref{}::__from_parts(ret.0, {} as usize)",
                    device,
                    mutability,
                    self.codegen_type_size(func.return_types[0])
                )?;
            }
        } else {
            write!(w, "(")?;
            for (idx, typ) in func.return_types.iter().enumerate() {
                if self.module.types[typ.idx()].is_primitive() {
                    write!(w, "ret.{},", idx)?;
                } else {
                    let device = match return_devices[idx] {
                        Some(Device::LLVM) | None => "CPU",
                        Some(Device::CUDA) => "CUDA",
                        _ => panic!(),
                    };
                    let mutability = if return_muts[idx] { "Mut" } else { "" };
                    write!(
                        w,
                        "::hercules_rt::Hercules{}Ref{}::__from_parts(ret.{}.0, {} as usize),",
                        device,
                        mutability,
                        idx,
                        self.codegen_type_size(func.return_types[idx]),
                    )?;
                }
            }
            write!(w, ")")?;
        }
        write!(w, "}}}}")?;
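
        // Assuming module name "mod", entry point foo, and a single mutable
        // non-primitive CPU return (all hypothetical), the emitted call and
        // return wrapping look roughly like:
        //
        //   let ret = mod_foo(
        //       ::hercules_rt::__RawPtrSendSync(self.backing_ptr_cpu), dc_p0, p0,
        //   ).await;
        //   ::hercules_rt::HerculesCPURefMut::__from_parts(ret.0, /* size */ as usize)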

        // De-allocate the backing memory on drop.
        write!(
            w,
            "}}impl Drop for HerculesRunner_{} {{#[allow(unused_unsafe)]fn drop(&mut self) {{unsafe {{",
            func.name
        )?;
        for (device, _) in self.backing_allocations[&self.func_id].iter() {
            write!(
                w,
                "::hercules_rt::__{}_dealloc(self.backing_ptr_{}, self.backing_size_{});",
                device.name(),
                device.name(),
                device.name()
            )?;
        }
        write!(w, "}}}}}}")?;
        Ok(())
    }
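
    // A typical host-side use of a generated runner (a sketch; the entry
    // point foo with one dynamic constant and one array input is
    // hypothetical):
    //
    //   let mut r = HerculesRunner_foo::new();
    //   let out = r.run(16, input_ref).await;
    //
    // Because out borrows from r, the runner can be neither re-run nor
    // dropped while out is live.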

    fn get_func(&self) -> &Function {
        &self.module.functions[self.func_id.idx()]
    }

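    /*
     * Emit the Rust expression that refers to a node's value. A reduce node
     * is named reduce_N when referenced at its own join or from outside its
     * fork-join; a call with an AsyncCall schedule reads through its
     * Arc<Mutex<>> future slot (and thus can never appear as an assignment
     * target); every other node is a plain node_N local.
     */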
    fn get_value(&self, id: NodeID, bb: NodeID, lhs: bool) -> String {
        let func = self.get_func();
        if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce()
            && (control == bb
                || !self.nodes_in_fork_joins[&self.join_fork_map[&control]].contains(&bb))
        {
            format!("reduce_{}", id.idx())
        } else if func.nodes[id.idx()].is_call()
            && func.schedules[id.idx()].contains(&Schedule::AsyncCall)
        {
            assert!(!lhs);
            format!("async_call_{}.lock().await.inspect().await", id.idx())
        } else {
            format!("node_{}", id.idx())
        }
    }

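    /*
     * For a call with an AsyncCall schedule, emit a statement that clones the
     * Arc around the call's future slot, so the slot can be shared with
     * spawned tasks. Returns None for all other nodes.
     */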
    fn clone_arc(&self, id: NodeID) -> Option<String> {
        let func = self.get_func();
        if func.nodes[id.idx()].is_call() && func.schedules[id.idx()].contains(&Schedule::AsyncCall)
        {
            Some(format!(
                "let async_call_{} = async_call_{}.clone();",
                id.idx(),
                id.idx()
            ))
        } else {
            None
        }
    }

    fn get_type(&self, id: TypeID) -> String {
        convert_type(&self.module.types[id.idx()], &self.module.types)
    }

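    /*
     * Produce the expression used to zero-initialize a node local of the
     * given type: false for booleans, 0 for integers, 0.0 for floats, a tuple
     * of defaults for multi-return types, and a null __RawPtrSendSync for
     * collections.
     */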
    fn get_default_value(&self, idx: TypeID) -> String {
        let typ = &self.module.types[idx.idx()];
        if typ.is_bool() {
            "false".to_string()
        } else if typ.is_integer() {
            "0".to_string()
        } else if typ.is_float() {
            "0.0".to_string()
        } else if let Some(ts) = typ.try_multi_return() {
            format!(
                "({})",
                ts.iter()
                    .map(|t| self.get_default_value(*t))
                    .collect::<Vec<_>>()
                    .join(", ")
            )
        } else {
            "::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())".to_string()
        }
    }

    fn write_rust_return_type<W: Write>(&self, w: &mut W, tys: &[TypeID]) -> Result<(), Error> {
        if tys.len() == 1 {
            write!(w, "{}", self.get_type(tys[0]))
        } else {
            write!(
                w,
                "({})",
                tys.iter()
                    .map(|t| self.get_type(*t))
                    .collect::<Vec<_>>()
                    .join(", "),
            )
        }
    }

    // Writes the signature of a device function as if it were an async
    // function; in particular, if the function is multi-return, the produced
    // Rust code returns a tuple. Writes from the "fn" keyword up to the end
    // of the return type.
    fn write_device_signature_async<W: Write>(
        &self,
        w: &mut W,
        func_id: FunctionID,
        is_unsafe: bool,
    ) -> Result<(), Error> {
        let func = &self.module.functions[func_id.idx()];
        write!(
            w,
            "{}fn {}_{}(",
            if is_unsafe { "unsafe " } else { "" },
            self.module_name,
            func.name
        )?;
        let mut first_param = true;
        if self.backing_allocations[&func_id].contains_key(&self.devices[func_id.idx()]) {
            first_param = false;
            write!(w, "backing: ::hercules_rt::__RawPtrSendSync")?;
        }
        for idx in 0..func.num_dynamic_constants {
            if first_param {
                first_param = false;
            } else {
                write!(w, ", ")?;
            }
            write!(w, "dc{}: u64", idx)?;
        }
        for (idx, ty) in func.param_types.iter().enumerate() {
            if first_param {
                first_param = false;
            } else {
                write!(w, ", ")?;
            }
            write!(w, "p{}: {}", idx, self.get_type(*ty))?;
        }
        write!(w, ") -> ")?;
        self.write_rust_return_type(w, &func.return_types)
    }

    // Writes the true signature of a device function. Compared to the _async
    // version, this converts a multi-return into an out-parameter pointing at
    // a return struct.
    fn write_device_signature<W: Write>(
        &self,
        w: &mut W,
        func_id: FunctionID,
    ) -> Result<(), Error> {
        let func = &self.module.functions[func_id.idx()];
        write!(w, "fn {}_{}(", self.module_name, func.name)?;
        let mut first_param = true;
        if self.backing_allocations[&func_id].contains_key(&self.devices[func_id.idx()]) {
            first_param = false;
            write!(w, "backing: ::hercules_rt::__RawPtrSendSync")?;
        }
        for idx in 0..func.num_dynamic_constants {
            if first_param {
                first_param = false;
            } else {
                write!(w, ", ")?;
            }
            write!(w, "dc{}: u64", idx)?;
        }
        for (idx, ty) in func.param_types.iter().enumerate() {
            if first_param {
                first_param = false;
            } else {
                write!(w, ", ")?;
            }
            write!(w, "p{}: {}", idx, self.get_type(*ty))?;
        }
        if func.return_types.len() == 1 {
            write!(w, ") -> {}", self.get_type(func.return_types[0]))
        } else {
            write!(w, ", ret_ptr: *mut ReturnStruct)")
        }
    }
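
    // For a hypothetical two-return device function, the two forms differ
    // only in how results are conveyed:
    //
    //   // write_device_signature_async:
    //   fn mod_foo(backing: ::hercules_rt::__RawPtrSendSync, dc0: u64) -> (f32, f32)
    //   // write_device_signature:
    //   fn mod_foo(backing: ::hercules_rt::__RawPtrSendSync, dc0: u64, ret_ptr: *mut ReturnStruct)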

    fn library_prim_ty(&self, id: TypeID) -> &'static str {
        match self.module.types[id.idx()] {
            Type::Boolean => "::hercules_rt::PrimTy::Bool",
            Type::Integer8 => "::hercules_rt::PrimTy::I8",
            Type::Integer16 => "::hercules_rt::PrimTy::I16",
            Type::Integer32 => "::hercules_rt::PrimTy::I32",
            Type::Integer64 => "::hercules_rt::PrimTy::I64",
            Type::UnsignedInteger8 => "::hercules_rt::PrimTy::U8",
            Type::UnsignedInteger16 => "::hercules_rt::PrimTy::U16",
            Type::UnsignedInteger32 => "::hercules_rt::PrimTy::U32",
            Type::UnsignedInteger64 => "::hercules_rt::PrimTy::U64",
            Type::Float8 => "::hercules_rt::PrimTy::F8",
            Type::BFloat16 => "::hercules_rt::PrimTy::BF16",
            Type::Float32 => "::hercules_rt::PrimTy::F32",
            Type::Float64 => "::hercules_rt::PrimTy::F64",
            _ => panic!(),
        }
    }
}

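/*
 * Map a Hercules IR type to its Rust spelling: primitives become the
 * corresponding Rust scalar, collections (products, summations, and arrays)
 * are lowered to raw pointers, and multi-return types become tuples.
 */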
fn convert_type(ty: &Type, types: &[Type]) -> String {
    match ty {
        Type::Boolean => "bool".to_string(),
        Type::Integer8 => "i8".to_string(),
        Type::Integer16 => "i16".to_string(),
        Type::Integer32 => "i32".to_string(),
        Type::Integer64 => "i64".to_string(),
        Type::UnsignedInteger8 => "u8".to_string(),
        Type::UnsignedInteger16 => "u16".to_string(),
        Type::UnsignedInteger32 => "u32".to_string(),
        Type::UnsignedInteger64 => "u64".to_string(),
        Type::Float32 => "f32".to_string(),
        Type::Float64 => "f64".to_string(),
        Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => {
            "::hercules_rt::__RawPtrSendSync".to_string()
        }
        Type::MultiReturn(ts) => {
            format!(
                "({})",
                ts.iter()
                    .map(|t| convert_type(&types[t.idx()], types))
                    .collect::<Vec<_>>()
                    .join(", ")
            )
        }
        _ => panic!(),
    }
}