Skip to content
Snippets Groups Projects
rt.rs 70.5 KiB
Newer Older
  • Learn to ignore specific revisions
  • use std::collections::{BTreeMap, HashMap, HashSet};
    
    rarbore2's avatar
    rarbore2 committed
    use std::fmt::{Error, Write};
    
    use std::iter::zip;
    
    rarbore2's avatar
    rarbore2 committed
    
    
    rarbore2's avatar
    rarbore2 committed
    use hercules_ir::*;
    
    rarbore2's avatar
    rarbore2 committed
    
    use crate::*;
    
    /*
     * Entry Hercules functions are lowered to async Rust code to achieve easy task
     * level parallelism. This Rust is generated textually, and is included via a
     * procedural macro in the user's Rust code.
     *
     * Generating Rust that properly handles memory across devices is tricky. In
     * particular, the following elements are challenges:
     *
     * 1. RT functions may return objects that were not first passed in as
     *    parameters, via object constants.
     * 2. We want to allocate as much memory upfront as possible. Our goal is that,
     *    for each call to an RT function from host Rust code, there is at most
     *    one memory allocation per device.
     * 3. We want to statically determine when cross-device communication is
     *    necessary - this should be a separate concern from allocation.
     * 4. At the boundary between host Rust / RT functions, we want to encode
     *    lifetime rules so that the Hercules API is memory safe.
     * 5. We want to support efficient composition of Hercules RT functions in both
     *    synchronous and asynchronous contexts.
     *
     * Challenges #1 and #2 require that RT functions themselves do not allocate
     * memory. Instead, for every entry point function, a "runner" type will be
     * generated. The host Rust code must instantiate a runner object to invoke an
     * entry point function. This runner object contains a "backing" memory, which
     * is the single memory allocation for this function. The runner object can be
     * used to call the same entry point multiple times, and to the extent possible
     * the backing memory will be re-used. The size of the backing memory depends on
     * the dynamic constants passed in to the entry point, so it is lazily allocated
     * to the needed size on calls to the entry point. To address challenge #4, any
     * returned objects will be lifetime-bound to the runner object instance,
     * ensuring that the reference cannot be used after the runner object has
     * deallocated the backing memory. This also ensures that the runner can't be run
     * again while a returned object from a previous iteration is still live, since
     * the entry point method requires an exclusive reference to the runner.
     *
     * Addressing challenge #3 requires we determine what objects are live on what
     * devices at what times. This can be done fairly easily by coloring nodes by
     * what device they produce their result on and inserting inter-device transfers
     * along edges connecting nodes of different colors. Nodes can only have a
     * single color - this is enforced by the GCM pass.
     *
     * Addressing challenge #5 requires that runner objects for entry points accept
     * and return objects that are not in their own backing memory and that may
     * reside on any device. For this reason, parameter and return nodes are not
     * necessarily CPU colored. Instead, runners take and return Hercules reference
     * objects that refer to memory of unknown origin on some device. Hercules reference
     * objects have a lifetime parameter, and when a runner may return a Hercules
     * reference that refers to its backing memory, the lifetime of the Hercules
     * reference is the same as the lifetime of the mutable reference of the runner
     * used in the entry point signature. In other words, the RT backend infers the
     * proper lifetime bounds on parameter and returned Hercules reference objects
     * in relation to the runner's self reference using the collection objects
     * analysis. There are the following kinds of Hercules reference objects:
     *
     * - HerculesCPURef
     * - HerculesCPURefMut
     * - HerculesCUDARef
     * - HerculesCUDARefMut
     *
     * Essentially, there are types for each device, one for immutable references
     * and one for exclusive references. Mutable references can decay into immutable
     * references, and immutable references can be cloned. The CPU reference types
     * can be created from normal Rust references. The CUDA reference types can't be
     * created from normal Rust references - for that purpose, an additional type is
     * given, CUDABox, which essentially allows the user to manually allocate and
     * set some CUDA memory - the user can then take a CUDA reference to that box.
     */
    pub fn rt_codegen<W: Write>(
    
        module_name: &str,
    
    rarbore2's avatar
    rarbore2 committed
        func_id: FunctionID,
        module: &Module,
    
        def_use: &ImmutableDefUseMap,
    
    rarbore2's avatar
    rarbore2 committed
        typing: &Vec<TypeID>,
        control_subgraph: &Subgraph,
    
        fork_join_map: &HashMap<NodeID, NodeID>,
    
    rarbore2's avatar
    rarbore2 committed
        fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
    
        fork_tree: &HashMap<NodeID, HashSet<NodeID>>,
        nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>,
    
        collection_objects: &CollectionObjects,
    
    rarbore2's avatar
    rarbore2 committed
        callgraph: &CallGraph,
    
        devices: &Vec<Device>,
    
        bbs: &BasicBlocks,
        node_colors: &FunctionNodeColors,
    
        backing_allocations: &BackingAllocations,
    
    rarbore2's avatar
    rarbore2 committed
        w: &mut W,
    ) -> Result<(), Error> {
    
        let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
            .into_iter()
            .map(|(fork, join)| (*join, *fork))
            .collect();
    
    rarbore2's avatar
    rarbore2 committed
        let ctx = RTContext {
    
            module_name,
    
    rarbore2's avatar
    rarbore2 committed
            func_id,
            module,
    
            def_use,
    
    rarbore2's avatar
    rarbore2 committed
            typing,
            control_subgraph,
    
            fork_join_map,
            join_fork_map: &join_fork_map,
    
    rarbore2's avatar
    rarbore2 committed
            fork_join_nest,
    
            fork_tree,
            nodes_in_fork_joins,
    
            collection_objects,
    
    rarbore2's avatar
    rarbore2 committed
            callgraph,
    
            devices,
    
            bbs,
            node_colors,
    
    rarbore2's avatar
    rarbore2 committed
        };
        ctx.codegen_function(w)
    }
    
    struct RTContext<'a> {
    
        module_name: &'a str,
    
    rarbore2's avatar
    rarbore2 committed
        func_id: FunctionID,
        module: &'a Module,
    
        def_use: &'a ImmutableDefUseMap,
    
    rarbore2's avatar
    rarbore2 committed
        typing: &'a Vec<TypeID>,
        control_subgraph: &'a Subgraph,
    
        fork_join_map: &'a HashMap<NodeID, NodeID>,
        join_fork_map: &'a HashMap<NodeID, NodeID>,
    
    rarbore2's avatar
    rarbore2 committed
        fork_join_nest: &'a HashMap<NodeID, Vec<NodeID>>,
    
        fork_tree: &'a HashMap<NodeID, HashSet<NodeID>>,
        nodes_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>,
    
        collection_objects: &'a CollectionObjects,
    
    rarbore2's avatar
    rarbore2 committed
        callgraph: &'a CallGraph,
    
        devices: &'a Vec<Device>,
    
        bbs: &'a BasicBlocks,
        node_colors: &'a FunctionNodeColors,
    
        backing_allocations: &'a BackingAllocations,
    
    rarbore2's avatar
    rarbore2 committed
    }
    
    
    /// Text fragments emitted for one control node's arm of the generated
    /// control-token dispatch (arms are written as `<node idx> => { ... }`).
    ///
    /// Fragments are dumped in this order for ordinary control nodes:
    /// `prologue`, `data`, `phi_tmp_assignments`, `phi_assignments`,
    /// `epilogue`. For join nodes, `epilogue` is dumped before the phi
    /// assignments, and `join_epilogue` comes last.
    #[derive(Debug, Clone, Default)]
    struct RustBlock {
        /// Opens the arm. For forks, it also opens the thread-ID loops and the
        /// spawned async closure.
        prologue: String,
        /// Statements for the data nodes scheduled in this block.
        data: String,
        /// `let <phi>_tmp = ...;` temporaries for phis of successor blocks.
        /// They are written first so that inter-dependent phis read old values.
        phi_tmp_assignments: String,
        /// Copies of the `_tmp` temporaries into the phi variables.
        phi_assignments: String,
        /// Branches to the successor (or returns) and closes the arm.
        epilogue: String,
        /// Join nodes only: the branch to the join's successor in the
        /// surrounding context, emitted after the phi assignments.
        join_epilogue: String,
    }

    impl<'a> RTContext<'a> {
        /// Emits the complete Rust source for the current function into `w`.
        ///
        /// For entry functions, a runner object type is emitted first. The
        /// function itself is emitted as an `async unsafe fn` named
        /// `<module_name>_<function name>`. Its parameters are, in order:
        /// - one `backing_<device>` pointer per device in this function's
        ///   backing allocations;
        /// - `dc_p<i>: u64` for each dynamic constant;
        /// - `p<i>` for each ordinary parameter.
        ///
        /// Inside the body it then:
        /// 1. declares the non-AsyncRust callees;
        /// 2. opens the root environment;
        /// 3. emits one block per control node in reverse postorder;
        /// 4. closes the environment and the function body.
        ///
        /// # Errors
        /// Propagates any `fmt::Error` from writes to `w`.
        fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> {
    
            // If this is an entry function, generate a corresponding runner object
            // type definition.
    
    rarbore2's avatar
    rarbore2 committed
            let func = &self.get_func();
    
            if func.entry {
                self.codegen_runner_object(w)?;
            }
    
    rarbore2's avatar
    rarbore2 committed
    
            // Dump the function signature.
            write!(
                w,
    
                "#[allow(unused_assignments,unused_variables,unused_mut,unused_parens,unused_unsafe,non_snake_case)]async unsafe fn {}_{}(",
                self.module_name,
    
    rarbore2's avatar
    rarbore2 committed
                func.name
            )?;
            // True until the first parameter is written; controls the ", "
            // separator across all three parameter groups below.
            let mut first_param = true;
    
            // The first set of parameters are pointers to backing memories.
            // NOTE(review): the parameter order follows the iteration order of
            // `self.backing_allocations[&self.func_id]`, and every caller must
            // pass pointers in that same order. Confirm that the map type
            // guarantees a stable order.
    
            for (device, _) in self.backing_allocations[&self.func_id].iter() {
    
                if first_param {
                    first_param = false;
                } else {
                    write!(w, ", ")?;
                }
    
                write!(
                    w,
                    "backing_{}: ::hercules_rt::__RawPtrSendSync",
                    device.name()
                )?;
    
            }
            // The second set of parameters are dynamic constants.
    
    rarbore2's avatar
    rarbore2 committed
            for idx in 0..func.num_dynamic_constants {
                if first_param {
                    first_param = false;
                } else {
                    write!(w, ", ")?;
                }
                write!(w, "dc_p{}: u64", idx)?;
            }
    
            // The third set of parameters are normal parameters.
    
    rarbore2's avatar
    rarbore2 committed
            for idx in 0..func.param_types.len() {
                if first_param {
                    first_param = false;
                } else {
                    write!(w, ", ")?;
                }
    
    rarbore2's avatar
    rarbore2 committed
                write!(w, "p{}: {}", idx, self.get_type(func.param_types[idx]))?;
    
    rarbore2's avatar
    rarbore2 committed
            }
    
    Aaron Councilman's avatar
    Aaron Councilman committed
            // Close the parameter list, write the return type, and open the
            // function body. The body is closed by the final `write!` below.
            write!(w, ") -> ")?;
            self.write_rust_return_type(w, &func.return_types)?;
            write!(w, " {{")?;
    
    rarbore2's avatar
    rarbore2 committed
    
    
    rarbore2's avatar
    rarbore2 committed
            // Dump signatures for called device functions.
    
    Aaron Councilman's avatar
    Aaron Councilman committed
            // For single-return functions we directly expose the device function
            // while for multi-return functions we generate a wrapper which handles
            // allocation of the return struct and extracting values from it. This
            // ensures that device function signatures match what they would be in
            // AsyncRust
    
            for callee_id in self.callgraph.get_callees(self.func_id) {
                // Only non-AsyncRust (device) callees need a declaration here.
                if self.devices[callee_id.idx()] == Device::AsyncRust {
    
                    continue;
                }
    
                let callee = &self.module.functions[callee_id.idx()];
    
    Aaron Councilman's avatar
    Aaron Councilman committed
                let is_single_return = callee.return_types.len() == 1;
                if is_single_return {
                    write!(w, "extern \"C\" {{")?;
                // NOTE(review): the `if is_single_return {` opened above has no
                // visible closing `}` in this view, so the braces of this loop do
                // not balance. Compare against the full source.
    
    Aaron Councilman's avatar
    Aaron Councilman committed
                // The last argument asks for a wrapper-style signature when the
                // callee has multiple return values.
                self.write_device_signature_async(w, *callee_id, !is_single_return)?;
                if is_single_return {
                    write!(w, ";}}")?;
                } else {
                    // Generate the wrapper function for multi-return device functions
                    write!(w, " {{ ")?;
                    // Define the return struct
                    write!(
                        w,
                        "#[repr(C)] struct ReturnStruct {{ {} }} ",
                        callee
                            .return_types
                            .iter()
                            .enumerate()
                            .map(|(idx, t)| format!("f{}: {}", idx, self.get_type(*t)))
                            .collect::<Vec<_>>()
                            .join(", "),
                    )?;
                    // Declare the extern function's signature
                    write!(w, "extern \"C\" {{ ")?;
                    self.write_device_signature(w, *callee_id)?;
                    write!(w, "; }}")?;
                    // Create the return struct
                    write!(w, "let mut ret_struct: ::std::mem::MaybeUninit<ReturnStruct> = ::std::mem::MaybeUninit::uninit();")?;
                    // Call the device function
                    // NOTE(review): the arguments below use the names `backing`,
                    // `dc<i>` and `p<i>`. These presumably match the wrapper's
                    // own parameters as written by `write_device_signature_async`
                    // (not this function's `backing_<device>` / `dc_p<i>`) —
                    // confirm.
    
                    write!(w, "{}_{}(", self.module_name, callee.name)?;
    
    Aaron Councilman's avatar
    Aaron Councilman committed
                    if self.backing_allocations[&callee_id].contains_key(&self.devices[callee_id.idx()])
                    {
                        write!(w, "backing, ")?;
    
    rarbore2's avatar
    rarbore2 committed
                    }
    
    Aaron Councilman's avatar
    Aaron Councilman committed
                    for idx in 0..callee.num_dynamic_constants {
                        write!(w, "dc{}, ", idx)?;
                    }
                    for idx in 0..callee.param_types.len() {
                        write!(w, "p{}, ", idx)?;
    
    rarbore2's avatar
    rarbore2 committed
                    }
    
    Aaron Councilman's avatar
    Aaron Councilman committed
                    // The device function returns its values through this
                    // out-pointer as the final argument.
                    write!(w, "ret_struct.as_mut_ptr());")?;
                    // Extract the result into a Rust product
                    // NOTE(review): `assume_init` is sound only if the device
                    // function writes every field of `ReturnStruct` — confirm.
                    write!(w, "let ret_struct = ret_struct.assume_init();")?;
                    write!(
                        w,
                        "({})",
                        (0..callee.return_types.len())
                            .map(|idx| format!("ret_struct.f{}", idx))
                            .collect::<Vec<_>>()
                            .join(", "),
                    )?;
                    // Closes the wrapper body opened by `write!(w, " {{ ")` above.
                    write!(w, "}}")?;
                // NOTE(review): no closing `}` is visible for the `else` branch
                // or the callee loop here. Compare against the full source.
    
            // Set up the root environment for the function. An environment is set
            // up for every created task in async closures, and there needs to be a
            // root environment corresponding to the root control node (start node).
            self.codegen_open_environment(NodeID::new(0), w)?;
    
    rarbore2's avatar
    rarbore2 committed
    
            // One empty `RustBlock` per control node, keyed (and therefore
            // iterable) in `NodeID` order.
            let mut blocks: BTreeMap<_, _> = (0..func.nodes.len())
                .filter(|idx| func.nodes[*idx].is_control())
    
                .map(|idx| (NodeID::new(idx), RustBlock::default()))
    
    rarbore2's avatar
    rarbore2 committed
                .collect();
    
            // Emit data flow into basic blocks.
            // `self.bbs.1` lists, per basic block, the nodes in emission order.
    
            for block in self.bbs.1.iter() {
                for id in block {
    
    rarbore2's avatar
    rarbore2 committed
                    self.codegen_data_node(*id, &mut blocks)?;
                }
            }
    
            // Emit control flow into basic blocks.
            // `rev_po` is computed once and reused for the dump below.
    
            let rev_po = self.control_subgraph.rev_po(NodeID::new(0));
            for id in rev_po.iter() {
                self.codegen_control_node(*id, &mut blocks)?;
    
    rarbore2's avatar
    rarbore2 committed
            }
    
    
            // Dump the emitted basic blocks. Do this in reverse postorder since
            // fork and join nodes open and close environments, respectively.
            for id in rev_po.iter() {
                let block = &blocks[id];
    
    rarbore2's avatar
    rarbore2 committed
                // For a join, `epilogue` closes the fork's async closure and
                // opens the join's arm in the surrounding context. The phi
                // assignments therefore go after it, followed by `join_epilogue`.
                if func.nodes[id.idx()].is_join() {
                    write!(
                        w,
                        "{}{}{}{}{}{}",
                        block.prologue,
                        block.data,
                        block.epilogue,
                        block.phi_tmp_assignments,
                        block.phi_assignments,
                        block.join_epilogue
                    )?;
                } else {
                    write!(
                        w,
                        "{}{}{}{}{}",
                        block.prologue,
                        block.data,
                        block.phi_tmp_assignments,
                        block.phi_assignments,
                        block.epilogue
                    )?;
                }
    
    rarbore2's avatar
    rarbore2 committed
            }
    
    
            // Close the root environment.
            self.codegen_close_environment(w)?;
            // Closes the function body opened after the return type.
            write!(w, "}}")?;
    
    rarbore2's avatar
    rarbore2 committed
            Ok(())
        }
    
        /*
         * While control nodes in Hercules IR are predecessor-centric (each take a
         * control input that defines the predecessor relationship), the Rust loop
         * we generate is successor-centric. This difference requires explicit
         * translation.
         */
        /// Writes the control-flow text for control node `id` into that node's
        /// `RustBlock`. This covers `prologue` and `epilogue`, plus
        /// `join_epilogue` for joins.
        ///
        /// Each arm opens with `<id> => {` and ends by assigning the successor's
        /// index to `control_token`, or by returning. Panics for any node kind
        /// other than Start, Region, ControlProjection, If, Return, Fork, or
        /// Join. Also panics if a Fork is not scheduled `ParallelFork`.
        fn codegen_control_node(
            &self,
            id: NodeID,
    
            blocks: &mut BTreeMap<NodeID, RustBlock>,
    
    rarbore2's avatar
    rarbore2 committed
        ) -> Result<(), Error> {
            let func = &self.get_func();
            match func.nodes[id.idx()] {
                // Start, region, and projection control nodes all have exactly one
                // successor and are otherwise simple.
                Node::Start
                | Node::Region { preds: _ }
    
    Aaron Councilman's avatar
    Aaron Councilman committed
                | Node::ControlProjection {
    
    rarbore2's avatar
    rarbore2 committed
                    control: _,
                    selection: _,
                } => {
    
                    let prologue = &mut blocks.get_mut(&id).unwrap().prologue;
                    write!(prologue, "{} => {{", id.idx())?;
                    let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue;
    
    rarbore2's avatar
    rarbore2 committed
                    let succ = self.control_subgraph.succs(id).next().unwrap();
    
                    write!(epilogue, "control_token = {};}}", succ.idx())?;
    
    rarbore2's avatar
    rarbore2 committed
                }
                // If nodes have two successors - examine the projections to
                // determine which branch is which, and branch between them.
                Node::If { control: _, cond } => {
    
                    let prologue = &mut blocks.get_mut(&id).unwrap().prologue;
                    write!(prologue, "{} => {{", id.idx())?;
                    let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue;
    
    rarbore2's avatar
    rarbore2 committed
                    let mut succs = self.control_subgraph.succs(id);
                    let succ1 = succs.next().unwrap();
                    let succ2 = succs.next().unwrap();
    
    Aaron Councilman's avatar
    Aaron Councilman committed
                    // `succ1` is the true branch iff it is the projection with
                    // selection 1.
                    let succ1_is_true = func.nodes[succ1.idx()].try_control_projection(1).is_some();
    
    rarbore2's avatar
    rarbore2 committed
                    // NOTE(review): this `write!(` has no visible closing `)?;`
                    // in this view. Compare against the full source.
                    write!(
    
                        epilogue,
                        "control_token = if {} {{{}}} else {{{}}};}}",
    
                        self.get_value(cond, id, false),
    
    rarbore2's avatar
    rarbore2 committed
                        if succ1_is_true { succ1 } else { succ2 }.idx(),
                        if succ1_is_true { succ2 } else { succ1 }.idx(),
    
    rarbore2's avatar
    rarbore2 committed
                }
    
    Aaron Councilman's avatar
    Aaron Councilman committed
                // A single return value is returned bare; multiple values are
                // returned as a tuple.
                Node::Return {
                    control: _,
                    ref data,
                } => {
    
                    let prologue = &mut blocks.get_mut(&id).unwrap().prologue;
                    write!(prologue, "{} => {{", id.idx())?;
                    let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue;
    
    Aaron Councilman's avatar
    Aaron Councilman committed
                    if data.len() == 1 {
                        write!(epilogue, "return {};}}", self.get_value(data[0], id, false))?;
                    } else {
                        write!(
                            epilogue,
                            "return ({});}}",
                            data.iter()
                                .map(|v| self.get_value(*v, id, false))
                                .collect::<Vec<_>>()
                                .join(", "),
                        )?;
                    }
    
                }
                // Fork nodes open a new environment for defining an async closure.
                Node::Fork {
                    control: _,
                    ref factors,
                } => {
                    // Only parallel forks are lowered here.
                    assert!(func.schedules[id.idx()].contains(&Schedule::ParallelFork));
    
                    // Set the outer environment control token to the join.
                    let prologue = &mut blocks.get_mut(&id).unwrap().prologue;
                    let join = self.fork_join_map[&id];
                    write!(
                        prologue,
                        "{} => {{control_token = {};",
                        id.idx(),
                        join.idx()
                    )?;
    
                    // Emit loops for the thread IDs.
                    // One nested `for tid_<fork>_<dim>` loop per fork factor.
                    // These loops are closed by the matching join.
                    for (idx, factor) in factors.into_iter().enumerate() {
                        write!(prologue, "for tid_{}_{} in 0..", id.idx(), idx)?;
                        self.codegen_dynamic_constant(*factor, prologue)?;
                        write!(prologue, " {{")?;
                    }
    
    
                    // Emit clones of arcs used inside the fork-join.
                    // A clone is emitted for every node that has a user whose
                    // basic block is inside this fork-join, presumably so that each
                    // spawned `async move` task captures its own copy. `clone_arc`
                    // is not visible here — confirm. NOTE: the `&& let` (let-chain)
                    // syntax requires the Rust 2024 edition.
                    for other_id in (0..func.nodes.len()).map(NodeID::new) {
                        if self.def_use.get_users(other_id).into_iter().any(|user_id| {
                            self.nodes_in_fork_joins[&id].contains(&self.bbs.0[user_id.idx()])
                        }) && let Some(arc) = self.clone_arc(other_id)
                        {
                            write!(prologue, "{}", arc)?;
                        }
                    }
    
    
                    // Spawn an async closure and push its future to a Vec.
                    write!(
                        prologue,
                        "fork_{}.push(::async_std::task::spawn(async move {{",
                        id.idx()
                    )?;
    
                    // Open a new environment.
                    self.codegen_open_environment(id, prologue)?;
    
                    // Open the branch inside the async closure for the fork.
                    let succ = self.control_subgraph.succs(id).next().unwrap();
                    write!(
                        prologue,
                        "{} => {{control_token = {};",
                        id.idx(),
                        succ.idx()
                    )?;
    
                    // Close the branch for the fork inside the async closure.
                    let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue;
                    write!(epilogue, "}}")?;
                }
                // Join nodes close the environment opened by its corresponding
                // fork.
                Node::Join { control: _ } => {
                    // Emit the branch for the join inside the async closure.
                    let prologue = &mut blocks.get_mut(&id).unwrap().prologue;
                    write!(prologue, "{} => {{", id.idx())?;
    
                    // Close the branch inside the async closure.
                    let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue;
    
    rarbore2's avatar
    rarbore2 committed
                    // The Release fence before the task returns pairs with the
                    // Acquire fence emitted after awaiting the tasks below.
                    write!(
                        epilogue,
                        "::std::sync::atomic::fence(::std::sync::atomic::Ordering::Release);return;}}"
                    )?;
    
    
                    // Close the fork's environment.
                    self.codegen_close_environment(epilogue)?;
    
                    // Close the async closure and push statement from the
                    // fork.
                    write!(epilogue, "}}));")?;
    
                    // Close the loops emitted by the fork node.
                    let fork = self.join_fork_map[&id];
                    for _ in 0..func.nodes[fork.idx()].try_fork().unwrap().1.len() {
                        write!(epilogue, "}}")?;
                    }
    
                    // Close the branch for the fork outside the async closure.
                    write!(epilogue, "}}")?;
    
                    // Open the branch in the surrounding context for the join.
                    // The fork's outer arm already set `control_token` to this join.
                    let succ = self.control_subgraph.succs(id).next().unwrap();
                    write!(epilogue, "{} => {{", id.idx())?;
    
    
                    // Await the empty futures for the fork-joins, waiting for them
                    // to complete.
                    write!(
                        epilogue,
                        "for fut in fork_{}.drain(..) {{ fut.await; }}; ::std::sync::atomic::fence(::std::sync::atomic::Ordering::Acquire);",
                        fork.idx(),
                    )?;
    
    
                    // Emit the assignments to the reduce variables in the
                    // surrounding context. It's very unfortunate that we have to do
                    // it while lowering the join node (rather than the reduce nodes
                    // themselves), but this is the only place we can put these
                    // assignments in the correct control location.
                    // NOTE(review): the `write!` call, the `if let`, and the `for`
                    // below have no visible closing tokens in this view. Compare
                    // against the full source.
                    for user in self.def_use.get_users(id) {
                        if let Some((_, init, _)) = func.nodes[user.idx()].try_reduce() {
                            write!(
                                epilogue,
                                "{} = {};",
    
                                self.get_value(*user, id, true),
                                self.get_value(init, id, false)
    
    rarbore2's avatar
    rarbore2 committed
                    // `join_epilogue` is dumped after the join's phi assignments.
                    let join_epilogue = &mut blocks.get_mut(&id).unwrap().join_epilogue;
    
                    // Branch to the successor control node in the surrounding
                    // context, and close the branch for the join.
    
    rarbore2's avatar
    rarbore2 committed
                    write!(join_epilogue, "control_token = {};}}", succ.idx())?;
    
    rarbore2's avatar
    rarbore2 committed
                }
                _ => panic!("PANIC: Can't lower {:?}.", func.nodes[id.idx()]),
            }
            Ok(())
        }
    
        /*
         * Lower data nodes in Hercules IR into Rust statements.
         */
        fn codegen_data_node(
            &self,
            id: NodeID,
    
            blocks: &mut BTreeMap<NodeID, RustBlock>,
    
    rarbore2's avatar
    rarbore2 committed
        ) -> Result<(), Error> {
            let func = &self.get_func();
    
            let bb = self.bbs.0[id.idx()];
    
    rarbore2's avatar
    rarbore2 committed
            match func.nodes[id.idx()] {
                Node::Parameter { index } => {
    
                    let block = &mut blocks.get_mut(&bb).unwrap().data;
    
                    write!(block, "{} = p{};", self.get_value(id, bb, true), index)?
    
    rarbore2's avatar
    rarbore2 committed
                }
                Node::Constant { id: cons_id } => {
    
                    let block = &mut blocks.get_mut(&bb).unwrap().data;
    
                    write!(block, "{} = ", self.get_value(id, bb, true))?;
    
                    let mut size_and_device = None;
    
    rarbore2's avatar
    rarbore2 committed
                    match self.module.constants[cons_id.idx()] {
    
    rarbore2's avatar
    rarbore2 committed
                        Constant::Boolean(val) => write!(block, "{}", val)?,
    
    rarbore2's avatar
    rarbore2 committed
                        Constant::Integer8(val) => write!(block, "{}i8", val)?,
                        Constant::Integer16(val) => write!(block, "{}i16", val)?,
                        Constant::Integer32(val) => write!(block, "{}i32", val)?,
                        Constant::Integer64(val) => write!(block, "{}i64", val)?,
                        Constant::UnsignedInteger8(val) => write!(block, "{}u8", val)?,
                        Constant::UnsignedInteger16(val) => write!(block, "{}u16", val)?,
                        Constant::UnsignedInteger32(val) => write!(block, "{}u32", val)?,
                        Constant::UnsignedInteger64(val) => write!(block, "{}u64", val)?,
    
    rarbore2's avatar
    rarbore2 committed
                        Constant::Float32(val) => {
                            if val == f32::INFINITY {
                                write!(block, "f32::INFINITY")?
                            } else if val == f32::NEG_INFINITY {
                                write!(block, "f32::NEG_INFINITY")?
                            } else {
                                write!(block, "{}f32", val)?
                            }
                        }
                        Constant::Float64(val) => {
                            if val == f64::INFINITY {
                                write!(block, "f64::INFINITY")?
                            } else if val == f64::NEG_INFINITY {
                                write!(block, "f64::NEG_INFINITY")?
                            } else {
                                write!(block, "{}f64", val)?
                            }
                        }
    
                        Constant::Product(ty, _)
                        | Constant::Summation(ty, _, _)
                        | Constant::Array(ty) => {
    
    rarbore2's avatar
    rarbore2 committed
                            let (device, (offset, _)) = self.backing_allocations[&self.func_id]
    
                                .filter_map(|(device, (_, offsets))| {
                                    offsets.get(&id).map(|id| (*device, *id))
                                })
                                .next()
                                .unwrap();
                            write!(block, "backing_{}.byte_add(", device.name())?;
                            self.codegen_dynamic_constant(offset, block)?;
                            write!(block, " as usize)")?;
    
                            size_and_device = Some((self.codegen_type_size(ty), device));
    
    rarbore2's avatar
    rarbore2 committed
                        }
                    }
    
                    write!(block, ";")?;
    
                    if !func.schedules[id.idx()].contains(&Schedule::NoResetConstant) {
                        if let Some((size, device)) = size_and_device {
                            write!(
                                block,
    
                                "::hercules_rt::__{}_zero_mem({}.0, {} as usize);",
    
                                self.get_value(id, bb, false),
    
    rarbore2's avatar
    rarbore2 committed
                }
    
                Node::DynamicConstant { id: dc_id } => {
                    let block = &mut blocks.get_mut(&bb).unwrap().data;
                    write!(block, "{} = ", self.get_value(id, bb, true))?;
                    self.codegen_dynamic_constant(dc_id, block)?;
                    write!(block, ";")?;
                }
    
                Node::ThreadID { control, dimension } => {
    
    rarbore2's avatar
    rarbore2 committed
                    assert_eq!(control, bb);
    
                    let block = &mut blocks.get_mut(&bb).unwrap().data;
                    write!(
                        block,
                        "{} = tid_{}_{};",
    
                        self.get_value(id, bb, true),
    
    rarbore2's avatar
    rarbore2 committed
                        bb.idx(),
    
    rarbore2's avatar
    rarbore2 committed
                Node::Phi { control, ref data } => {
                    assert_eq!(control, bb);
                    // Phis aren't executable in their own basic block - predecessor
                    // blocks assign the to-be phi values themselves. Assign
                    // temporary values first before assigning the phi itself, since
                    // there may be simultaneous inter-dependent phis.
                    for (data, pred) in zip(data.into_iter(), self.control_subgraph.preds(bb)) {
                        let block = &mut blocks.get_mut(&pred).unwrap().phi_tmp_assignments;
                        write!(
                            block,
                            "let {}_tmp = {};",
                            self.get_value(id, pred, true),
                            self.get_value(*data, pred, false),
                        )?;
                        let block = &mut blocks.get_mut(&pred).unwrap().phi_assignments;
                        write!(
                            block,
                            "{} = {}_tmp;",
                            self.get_value(id, pred, true),
                            self.get_value(id, pred, false),
                        )?;
                    }
                }
    
                Node::Reduce {
                    control: _,
                    init: _,
                    reduct: _,
                } => {
                    assert!(func.schedules[id.idx()].contains(&Schedule::ParallelReduce));
                }
    
    rarbore2's avatar
    rarbore2 committed
                Node::Call {
    
    rarbore2's avatar
    rarbore2 committed
                    control,
    
    rarbore2's avatar
    rarbore2 committed
                    function: callee_id,
                    ref dynamic_constants,
                    ref args,
                } => {
    
    rarbore2's avatar
    rarbore2 committed
                    assert_eq!(control, bb);
    
    Aaron Councilman's avatar
    Aaron Councilman committed
                    // The device backends and the wrappers we generated earlier ensure that device
                    // functions have the same interface as AsyncRust functions.
    
                    let block = &mut blocks.get_mut(&bb).unwrap();
                    let block = &mut block.data;
    
                    let is_async = func.schedules[id.idx()].contains(&Schedule::AsyncCall);
    
    rarbore2's avatar
    rarbore2 committed
                    if is_async {
                        for arg in args {
    
                            if let Some(arc) = self.clone_arc(*arg) {
    
    rarbore2's avatar
    rarbore2 committed
                                write!(block, "{}", arc)?;
                            }
                        }
                    }
    
                    let device = self.devices[callee_id.idx()];
                    let prefix = match (device, is_async) {
    
    rarbore2's avatar
    rarbore2 committed
                        (Device::AsyncRust, false) | (_, false) => {
                            format!("{} = ", self.get_value(id, bb, true))
                        }
    
                        (_, true) => {
                            write!(block, "{}", self.clone_arc(id).unwrap())?;
                            format!(
                                "*async_call_{}.lock().await = ::hercules_rt::__FutureSlotWrapper::new(::async_std::task::spawn(async move {{ ",
                                id.idx(),
                            )
                        }
    
                    };
                    let postfix = match (device, is_async) {
                        (Device::AsyncRust, false) => ".await",
                        (_, false) => "",
    
                        (Device::AsyncRust, true) => ".await}))",
                        (_, true) => "}))",
    
                    write!(
                        block,
    
                        "{}{}_{}(",
    
                        prefix,
    
                        self.module_name,
    
                        self.module.functions[callee_id.idx()].name
                    )?;
    
    rarbore2's avatar
    rarbore2 committed
                    for (device, (offset, size)) in self.backing_allocations[&self.func_id]
    
                        .filter_map(|(device, (_, offsets))| offsets.get(&id).map(|id| (*device, *id)))
                    {
    
    rarbore2's avatar
    rarbore2 committed
                        write!(block, "backing_{}.byte_add(((", device.name())?;
    
                        self.codegen_dynamic_constant(offset, block)?;
    
    rarbore2's avatar
    rarbore2 committed
                        let forks = &self.fork_join_nest[&bb];
                        if !forks.is_empty() {
                            write!(block, ") + ")?;
                            let mut linear_thread = "0".to_string();
                            for fork in forks {
                                let factors = func.nodes[fork.idx()].try_fork().unwrap().1;
                                for (factor_idx, factor) in factors.into_iter().enumerate() {
                                    linear_thread = format!("({} *", linear_thread);
                                    self.codegen_dynamic_constant(*factor, &mut linear_thread)?;
                                    write!(linear_thread, " + tid_{}_{})", fork.idx(), factor_idx)?;
                                }
                            }
                            write!(block, "{} * (", linear_thread)?;
                            self.codegen_dynamic_constant(size, block)?;
                        }
                        write!(block, ")) as usize), ")?
    
                    }
                    for dc in dynamic_constants {
                        self.codegen_dynamic_constant(*dc, block)?;
                        write!(block, ", ")?;
                    }
                    for arg in args {
    
                        write!(block, "{}, ", self.get_value(*arg, bb, false))?;
    
    rarbore2's avatar
    rarbore2 committed
                    }
    
                    write!(block, "){};", postfix)?;
    
    rarbore2's avatar
    rarbore2 committed
                }
    
    Aaron Councilman's avatar
    Aaron Councilman committed
                Node::DataProjection { data, selection } => {
                    let block = &mut blocks.get_mut(&bb).unwrap().data;
                    let Node::Call {
                        function: callee_id,
                        ..
                    } = func.nodes[data.idx()]
                    else {
                        panic!()
                    };
                    if self.module.functions[callee_id.idx()].return_types.len() == 1 {
                        assert!(selection == 0);
                        write!(
                            block,
                            "{} = {};",
                            self.get_value(id, bb, true),
                            self.get_value(data, bb, false),
                        )?;
                    } else {
                        write!(
                            block,
                            "{} = {}.{};",
                            self.get_value(id, bb, true),
                            self.get_value(data, bb, false),
                            selection,
                        )?;
                    }
                }
    
    rarbore2's avatar
    rarbore2 committed
                Node::IntrinsicCall {
                    intrinsic,
                    ref args,
                } => {
                    let block = &mut blocks.get_mut(&bb).unwrap().data;
                    write!(
                        block,
                        "{} = {}::{}(",
                        self.get_value(id, bb, true),
                        self.get_type(self.typing[id.idx()]),
                        intrinsic.lower_case_name(),
                    )?;
                    for arg in args {
                        write!(block, "{}, ", self.get_value(*arg, bb, false))?;
                    }
                    write!(block, ");")?;
                }
    
                Node::LibraryCall {
                    library_function,
                    ref args,
                    ty,
                    device,
                } => match library_function {
                    LibraryFunction::GEMM => {
                        assert_eq!(args.len(), 3);
                        assert_eq!(self.typing[args[0].idx()], ty);
                        let c_ty = &self.module.types[self.typing[args[0].idx()].idx()];
                        let a_ty = &self.module.types[self.typing[args[1].idx()].idx()];
                        let b_ty = &self.module.types[self.typing[args[2].idx()].idx()];
                        let (
                            Type::Array(c_elem, c_dims),
                            Type::Array(a_elem, a_dims),
                            Type::Array(b_elem, b_dims),
                        ) = (c_ty, a_ty, b_ty)
                        else {
                            panic!();
                        };
                        assert_eq!(a_elem, b_elem);
                        assert_eq!(a_elem, c_elem);
                        assert_eq!(c_dims.len(), 2);
                        assert_eq!(a_dims.len(), 2);
                        assert_eq!(b_dims.len(), 2);
                        assert_eq!(a_dims[1], b_dims[0]);
                        assert_eq!(a_dims[0], c_dims[0]);
                        assert_eq!(b_dims[1], c_dims[1]);
    
                        let block = &mut blocks.get_mut(&bb).unwrap().data;
                        let prim_ty = self.library_prim_ty(*a_elem);
                        write!(block, "::hercules_rt::__library_{}_gemm(", device.name())?;
                        self.codegen_dynamic_constant(a_dims[0], block)?;
                        write!(block, ", ")?;
                        self.codegen_dynamic_constant(a_dims[1], block)?;
                        write!(block, ", ")?;
                        self.codegen_dynamic_constant(b_dims[1], block)?;
                        write!(
                            block,
                            ", {}.0, {}.0, {}.0, {});",
                            self.get_value(args[0], bb, false),
                            self.get_value(args[1], bb, false),
                            self.get_value(args[2], bb, false),
                            prim_ty
                        )?;
                        write!(
                            block,
                            "{} = {};",
                            self.get_value(id, bb, true),
                            self.get_value(args[0], bb, false)
                        )?;
                    }
                },
    
                Node::Unary { op, input } => {
    
                    let block = &mut blocks.get_mut(&bb).unwrap().data;
    
                    match op {
                        UnaryOperator::Not => write!(
                            block,
    
                            "{} = !{};",
    
                            self.get_value(id, bb, true),
                            self.get_value(input, bb, false)
    
                        )?,
                        UnaryOperator::Neg => write!(
                            block,
    
                            "{} = -{};",
    
                            self.get_value(id, bb, true),
                            self.get_value(input, bb, false)
    
                        )?,
                        UnaryOperator::Cast(ty) => write!(
                            block,
    
                            "{} = {} as {};",
    
                            self.get_value(id, bb, true),
                            self.get_value(input, bb, false),
    
                            self.get_type(ty)
                        )?,
                    };
                }
                Node::Binary { op, left, right } => {
    
                    let block = &mut blocks.get_mut(&bb).unwrap().data;
    
                    let op = match op {
                        BinaryOperator::Add => "+",
                        BinaryOperator::Sub => "-",
                        BinaryOperator::Mul => "*",
                        BinaryOperator::Div => "/",
                        BinaryOperator::Rem => "%",
                        BinaryOperator::LT => "<",
                        BinaryOperator::LTE => "<=",
                        BinaryOperator::GT => ">",
                        BinaryOperator::GTE => ">=",
                        BinaryOperator::EQ => "==",
                        BinaryOperator::NE => "!=",
                        BinaryOperator::Or => "|",
                        BinaryOperator::And => "&",
                        BinaryOperator::Xor => "^",
                        BinaryOperator::LSh => "<<",
                        BinaryOperator::RSh => ">>",
                    };
    
                    write!(
                        block,
    
                        "{} = {} {} {};",
    
                        self.get_value(id, bb, true),
                        self.get_value(left, bb, false),
    
                        self.get_value(right, bb, false)
    
                    )?;
                }
                Node::Ternary {
                    op,
                    first,
                    second,
                    third,
                } => {
    
                    let block = &mut blocks.get_mut(&bb).unwrap().data;
    
                    match op {
                        TernaryOperator::Select => write!(
                            block,
    
                            "{} = if {} {{{}}} else {{{}}};",
    
                            self.get_value(id, bb, true),
                            self.get_value(first, bb, false),
                            self.get_value(second, bb, false),
                            self.get_value(third, bb, false),
    
                Node::Read {
                    collect,
                    ref indices,
                } => {
    
                    let block = &mut blocks.get_mut(&bb).unwrap().data;
    
                    let collect_ty = self.typing[collect.idx()];
    
    rarbore2's avatar
    rarbore2 committed
                    let self_ty = self.typing[id.idx()];
    
                    let offset = self.codegen_index_math(collect_ty, indices, bb)?;
    
    rarbore2's avatar
    rarbore2 committed
                    if self.module.types[self_ty.idx()].is_primitive() {
                        write!(
                            block,
                            "{} = ({}.byte_add({} as usize).0 as *mut {}).read();",
                            self.get_value(id, bb, true),
                            self.get_value(collect, bb, false),
                            offset,
                            self.get_type(self_ty)
                        )?;
                    } else {
                        write!(
                            block,
                            "{} = {}.byte_add({} as usize);",
                            self.get_value(id, bb, true),
                            self.get_value(collect, bb, false),
                            offset,
                        )?;
                    }
    
                }
                Node::Write {
                    collect,
                    data,
                    ref indices,
                } => {
    
                    let block = &mut blocks.get_mut(&bb).unwrap().data;
    
                    let collect_ty = self.typing[collect.idx()];
    
    rarbore2's avatar
    rarbore2 committed
                    let data_ty = self.typing[data.idx()];
    
    rarbore2's avatar
    rarbore2 committed
                    let data_size = self.codegen_type_size(data_ty);
                    let offset = self.codegen_index_math(collect_ty, indices, bb)?;
    
    rarbore2's avatar
    rarbore2 committed
                    if self.module.types[data_ty.idx()].is_primitive() {
    
    rarbore2's avatar
    rarbore2 committed
                        write!(
                            block,
                            "({}.byte_add({} as usize).0 as *mut {}).write({});",
                            self.get_value(collect, bb, false),
                            offset,
                            self.get_type(data_ty),
                            self.get_value(data, bb, false),
                        )?;
    
    rarbore2's avatar
    rarbore2 committed
                    } else {
                        // If the data item being written is not a primitive type,
                        // then perform a memcpy from the data collection to the
                        // destination collection. Look at the colors of the
                        // `collect` and `data` inputs, since this may be an inter-
                        // device copy.
                        let src_device = self.node_colors.0[&data];
                        let dst_device = self.node_colors.0[&collect];
                        write!(
                            block,
    
    rarbore2's avatar
    rarbore2 committed
                            "::hercules_rt::__copy_{}_to_{}({}.byte_add({} as usize).0, {}.0, {} as usize);",
    
    rarbore2's avatar
    rarbore2 committed
                            src_device.name(),
                            dst_device.name(),
    
                            self.get_value(collect, bb, false),
    
    rarbore2's avatar
    rarbore2 committed
                            offset,
    
                            self.get_value(data, bb, false),
    
    rarbore2's avatar
    rarbore2 committed
                            data_size,
                        )?;
                    }
                    write!(
                        block,
                        "{} = {};",
    
                        self.get_value(id, bb, true),
                        self.get_value(collect, bb, false)
    
    rarbore2's avatar
    rarbore2 committed
                    )?;
    
                _ => panic!(
                    "PANIC: Can't lower {:?} in {}.",
                    func.nodes[id.idx()],
                    func.name
                ),
    
    rarbore2's avatar
    rarbore2 committed
            }
            Ok(())
        }
    
        /*
         * Lower a dynamic constant in Hercules IR into a Rust expression, writing it to `w`.
         */
        fn codegen_dynamic_constant<W: Write>(
            &self,
            id: DynamicConstantID,
            w: &mut W,
        ) -> Result<(), Error> {
    
            match &self.module.dynamic_constants[id.idx()] {
    
    rarbore2's avatar
    rarbore2 committed
                DynamicConstant::Constant(val) => write!(w, "{}", val)?,
                DynamicConstant::Parameter(idx) => write!(w, "dc_p{}", idx)?,
    
                DynamicConstant::Add(xs) => {
    
    rarbore2's avatar
    rarbore2 committed
                    write!(w, "(")?;
    
                    let mut xs = xs.iter();
                    self.codegen_dynamic_constant(*xs.next().unwrap(), w)?;
                    for x in xs {
                        write!(w, "+")?;
                        self.codegen_dynamic_constant(*x, w)?;
                    }
    
    rarbore2's avatar
    rarbore2 committed
                    write!(w, ")")?;
                }
                DynamicConstant::Sub(left, right) => {
                    write!(w, "(")?;
    
                    self.codegen_dynamic_constant(*left, w)?;
    
    rarbore2's avatar
    rarbore2 committed
                    write!(w, "-")?;
    
                    self.codegen_dynamic_constant(*right, w)?;
    
    rarbore2's avatar
    rarbore2 committed
                    write!(w, ")")?;
                }
    
                DynamicConstant::Mul(xs) => {
    
    rarbore2's avatar
    rarbore2 committed
                    write!(w, "(")?;
    
                    let mut xs = xs.iter();
                    self.codegen_dynamic_constant(*xs.next().unwrap(), w)?;
                    for x in xs {
                        write!(w, "*")?;
                        self.codegen_dynamic_constant(*x, w)?;