extern crate bitvec;
extern crate hercules_ir;

use std::collections::{BTreeMap, HashMap, HashSet};
use std::fmt::{Error, Write};

use self::hercules_ir::*;

use crate::*;

/*
 * The top level function to compile a Hercules IR function into a CUDA
 * kernel for execution on the GPU. We generate CUDA C textually; it shares
 * much of its structure with the CPU LLVM generation, plus custom GPU
 * parallelization.
 */
pub fn gpu_codegen<W: Write>(
    module_name: &str,
    function: &Function,
    types: &Vec<Type>,
    constants: &Vec<Constant>,
    dynamic_constants: &Vec<DynamicConstant>,
    typing: &Vec<TypeID>,
    control_subgraph: &Subgraph,
    bbs: &BasicBlocks,
    backing_allocation: &FunctionBackingAllocation,
    collection_objects: &FunctionCollectionObjects,
    def_use_map: &ImmutableDefUseMap,
    fork_join_map: &HashMap<NodeID, NodeID>,
    fork_control_map: &HashMap<NodeID, HashSet<NodeID>>,
    fork_tree: &HashMap<NodeID, HashSet<NodeID>>,
    w: &mut W,
) -> Result<(), Error> {
    /*
     * We assert the following:
     * - Fork node must have >= 1 Reduce nodes
     * - (Later in code) If the returned data type is a collection, it must have
     *   originated from (potentially multiple) parameter(s).
     *
     * We don't assert but assume the following:
     * - max_num_blocks in KernelParams is within constraint of 1D grid size. This
     *   can be relaxed if we want to support larger grids.
     * - Product types are packed with padding inserted for each element to
     *   be aligned for its type and for full product to be aligned to its
     *   largest element
     * - Summation types must be aligned to their largest element
     *
     * Notes on GPU parallelization strategy and tips for IR transformations:
     * - The top level block fork and any lower thread forks require a known Fork
     *   size. Thus for an otherwise parallelizable Fork with unknown size,
     *   consider splitting it into two Forks with one of known size. For block
     *   level, the known fork has to be the (only) top-most fork.
     * - The thread-level strategy is determined by starting at the most nested
     *   Forks and working outwards in a greedy manner, with caps by GPU spec.
     *   Thus, to ensure some outer Fork is parallelized, ensure the inner
     *   parallelizable Forks aren't too large or consider removing schedule
     *   annotations.
     * - Tight-Associative reductions can only be implemented efficiently if
     *   different Hercules ThreadIDs correspond to consecutive CUDA threads. But
     *   this prevents nested parallelization, since each parallel group must always
     *   be a contiguous tile of threads. When this causes a conflict between a Fork
     *   and its subtree, we use the heuristic of choosing the larger factor, but
     *   this choice may not be optimal.
     * - A given Fork (as opposed to its children) can only be parallelized
     *   if all its Reduces are Parallel-Reduce or Tight-Associative. So if the
     *   Fork contains expensive parallelizable operations, ensure all reductions
     *   are parallelizable, or if they aren't, try pulling them out into a
     *   different Fork.
     * - We do nothing to mitigate intra-warp divergence. To avoid it, the IR
     *   should, for example, ensure the innermost parallelizable Forks have a
     *   factor >= warp size (32), or else remove the Fork/Reduce node schedule
     *   annotations.
     *
     * Main TODOs:
     * - Fix dynamic shared memory allocation to reuse old shmem. The main case
     *   for improvement is when we have serialized forks with unused intermediate
     *   values from previous iterations.
     * - Add a mapping from Region node to Fork node for reduces whose control
     *   is a Region rather than a Join.
     * - Matmul/Conv detection
     * - Add float8, float16, bfloat16 dtypes if they come
     */
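    /*
     * Concretely (all names here are illustrative), the generated CUDA C for a
     * function "foo" in module "mod" consists of a return struct, one kernel,
     * and a host launch wrapper, in that order:
     *
     *   struct return_foo { float f0; };
     *   __global__ void __launch_bounds__(1024) mod_foo_gpu(...) { ... }
     *   extern "C" float mod_foo(...) { ... mod_foo_gpu<<<blocks, threads, shmem>>>(...); ... }
     */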

    let reduce_nodes: Vec<NodeID> = (0..function.nodes.len())
        .filter(|idx| function.nodes[*idx].is_reduce())
        .map(NodeID::new)
        .collect();

    let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
        .iter()
        .map(|(fork, join)| (*join, *fork))
        .collect();
    // Fork Reduce map should have all reduces contained in some key
    let mut fork_reduce_map: HashMap<NodeID, Vec<NodeID>> = HashMap::new();
    // Reduct Reduce map should have all non-parallel and non-associative reduces
    // contained in some key. Unlike Fork, Reduct is not involved in any assertions.
    // It's placed here for convenience but can be moved.
    let mut reduct_reduce_map: HashMap<NodeID, Vec<NodeID>> = HashMap::new();
    for reduce_node in &reduce_nodes {
        if let Node::Reduce {
            control,
            init: _,
            reduct,
        } = &function.nodes[reduce_node.idx()]
        {
            match function.nodes[control.idx()] {
                Node::Join { .. } => {
                    let fork_node = join_fork_map[control];
                    fork_reduce_map
                        .entry(fork_node)
                        .or_default()
                        .push(*reduce_node);
                }
                Node::Region { preds: _ } => {
                    // TODO: map region node to fork node
                }
                _ => {
                    panic!("Reduce's control must be a join or region node");
                }
            }
            reduct_reduce_map
                .entry(*reduct)
                .or_default()
                .push(*reduce_node);
        }
    }
    for idx in 0..function.nodes.len() {
        if function.nodes[idx].is_fork() {
            assert!(
                fork_reduce_map
                    .get(&NodeID::new(idx))
                    .is_some_and(|reduces| !reduces.is_empty()),
                "Fork node {} has no reduce nodes",
                idx
            );
        }
    }

    // Map from control to pairs of data to update phi
    // For each phi, we go to its region and get region's controls
    let mut control_data_phi_map: HashMap<NodeID, Vec<(NodeID, NodeID)>> = HashMap::new();
    for (idx, node) in function.nodes.iter().enumerate() {
        if let Node::Phi { control, data } = node {
            let Node::Region { preds } = &function.nodes[control.idx()] else {
                panic!("Phi's control must be a region node");
            };
            for (i, &pred) in preds.iter().enumerate() {
                control_data_phi_map
                    .entry(pred)
                    .or_default()
                    .push((data[i], NodeID::new(idx)));
            }
        }
    }

    // Tracks for each return value whether it is always the same parameter
    // collection
    let return_parameters = (0..function.return_types.len())
        .map(|idx| {
            if collection_objects.returned_objects(idx).len() == 1 {
                collection_objects
                    .origin(*collection_objects.returned_objects(idx).first().unwrap())
                    .try_parameter()
            } else {
                None
            }
        })
        .collect::<Vec<_>>();

    let kernel_params = &GPUKernelParams {
        max_num_threads: 1024,
        threads_per_warp: 32,
    };

    let ctx = GPUContext {
        module_name,
        function,
        types,
        constants,
        dynamic_constants,
        typing,
        control_subgraph,
        bbs,
        backing_allocation,
        collection_objects,
        def_use_map,
        fork_join_map,
        fork_control_map,
        fork_tree,
        join_fork_map,
        fork_reduce_map,
        reduct_reduce_map,
        control_data_phi_map,
        return_parameters,
        kernel_params,
    };
    ctx.codegen_function(w)
}

struct GPUKernelParams {
    max_num_threads: usize,
    threads_per_warp: usize,
}

struct GPUContext<'a> {
    module_name: &'a str,
    function: &'a Function,
    types: &'a Vec<Type>,
    constants: &'a Vec<Constant>,
    dynamic_constants: &'a Vec<DynamicConstant>,
    typing: &'a Vec<TypeID>,
    control_subgraph: &'a Subgraph,
    bbs: &'a BasicBlocks,
    backing_allocation: &'a FunctionBackingAllocation,
    collection_objects: &'a FunctionCollectionObjects,
    def_use_map: &'a ImmutableDefUseMap,
    fork_join_map: &'a HashMap<NodeID, NodeID>,
    fork_control_map: &'a HashMap<NodeID, HashSet<NodeID>>,
    fork_tree: &'a HashMap<NodeID, HashSet<NodeID>>,
    join_fork_map: HashMap<NodeID, NodeID>,
    fork_reduce_map: HashMap<NodeID, Vec<NodeID>>,
    reduct_reduce_map: HashMap<NodeID, Vec<NodeID>>,
    control_data_phi_map: HashMap<NodeID, Vec<(NodeID, NodeID)>>,
    return_parameters: Vec<Option<usize>>,
    kernel_params: &'a GPUKernelParams,
}

/*
 * For all control nodes besides forks, Init, Body, and Term compose the whole
 * basic block: Init and Term are populated by control flow (Init is used only by
 * Fork and Join nodes) and Body is populated by data flow.
 * Serialized Fork nodes may be jumped back to by their corresponding Join node,
 * so init and post_init separate one-time code (currently just cooperative group
 * creation) from code that runs on every iteration.
 */
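/*
 * As an illustration (block label names are placeholders), codegen_goto emits
 * each CudaGoto roughly as:
 *
 *   bb_fork_3:
 *       // init: one-time control-flow setup, e.g. cooperative group creation
 *   bb_fork_3_post:
 *       // post_init: setup repeated on every serialized iteration
 *       // body: data-node code assigned to this basic block
 *       // term: terminating control flow, e.g. goto bb_join_4;
 */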
#[derive(Default, Debug)]
struct CudaGoto {
    init: String,
    post_init: String,
    body: String,
    term: String,
}

/*
 * KernelState is used for data and control node organization and generation.
 * We define a block fork as one with each ThreadID being a block, and a thread
 * fork as one with each ThreadID being a subset of threads within a block.
 * OutBlock is outside a potential block fork at the full grid level, InBlock
 * is inside a block fork but outside any thread forks, and InThread is inside
 * a thread fork.
 */
#[derive(Clone, Copy, PartialEq, Debug)]
enum KernelState {
    OutBlock,
    InBlock,
    InThread,
}
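
/*
 * For intuition, and with illustrative numbers only, a ThreadID under a block
 * fork lowers to an expression over blockIdx.x, while a ThreadID under a
 * parallelized thread fork lowers to an expression over threadIdx.x (here with
 * use_thread_quota = 64 and parallel_factor = 8):
 *
 *   id_in_block  = (blockIdx.x / (dc1)) % dc0;
 *   id_in_thread = (((threadIdx.x % 64) / 8) / (dc1)) % (dc0);
 */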

/*
 * CGType is used to track cooperative group types. UsePerId is the group of (CUDA)
 * threads for a current ThreadID, Use is the union of such threads for all ThreadIDs
 * in the current innermost Fork, and Available is Use plus additional threads not
 * used in the current Fork.
 */
#[derive(Clone, Copy, PartialEq, Debug)]
enum CGType {
    UsePerId,
    Use,
    Available,
}
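
/*
 * As a concrete (illustrative) example: for a parallelized Fork with
 * use_thread_quota = 64 and parallel_factor = 8, UsePerId is a tile of
 * 64 / 8 = 8 consecutive threads serving one ThreadID, Use is the tile of all
 * 64 threads covering every ThreadID of the Fork, and Available additionally
 * includes any threads in the block's quota left unused by this Fork.
 */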

impl GPUContext<'_> {
    fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> {
        // Emit all code up to the "goto" to Start's block
        let mut top = String::new();
        self.codegen_kernel_preamble(&mut top)?;
        self.codegen_return_struct(&mut top)?;
        self.codegen_kernel_begin(&mut top)?;
        let mut dynamic_shared_offset = "0".to_string();
        self.codegen_dynamic_constants(&mut top)?;
        self.codegen_declare_data(&mut top)?;
        self.codegen_helpers(&mut top)?;
        write!(w, "{}", top)?;

        // Setup for CUDA's "goto" for control flow between basic blocks.
        let mut gotos: BTreeMap<_, _> = (0..self.function.nodes.len())
            .filter(|idx| self.function.nodes[*idx].is_control())
            .map(|idx| {
                let node_id = NodeID::new(idx);
                let goto = CudaGoto::default();
                (node_id, goto)
            })
            .collect();
        let mut thread_block_tiles = String::new();

        // If there are no forks, fast forward to single-block, single-thread codegen
        let (num_blocks, num_threads) = if self.fork_join_map.is_empty() {
            self.codegen_data_control_no_forks(
                &mut dynamic_shared_offset,
                &mut thread_block_tiles,
                &mut gotos,
            )?;
            ("1".to_string(), "1".to_string())
        } else {
            // Create structures and determine block and thread parallelization strategy
            let (root_forks, num_blocks, is_block_parallel) =
                self.get_root_forks_and_num_blocks(self.fork_tree);
            let (thread_root_root_fork, thread_root_forks) =
                self.get_thread_root_forks(&root_forks, self.fork_tree, is_block_parallel);
            let (fork_thread_quota_map, num_threads) =
                self.get_thread_quotas(self.fork_tree, thread_root_root_fork);

            // Core function for the CUDA code of all data and control nodes.
            self.codegen_data_control(
                if is_block_parallel {
                    Some(thread_root_root_fork)
                } else {
                    None
                },
                &thread_root_forks,
                &fork_thread_quota_map,
                &mut dynamic_shared_offset,
                is_block_parallel,
                num_threads,
                &mut thread_block_tiles,
                &mut gotos,
            )?;
            (num_blocks, num_threads.to_string())
        };

        // Emit all GPU kernel code from previous steps
        self.codegen_goto_start(&mut thread_block_tiles)?;
        write!(w, "{}", thread_block_tiles)?;
        let mut kernel_body = String::new();
        let rev_po = self.control_subgraph.rev_po(NodeID::new(0));
        write!(w, "\n")?;
        self.codegen_goto(false, &mut gotos, NodeID::new(0), &mut kernel_body)?;
        self.codegen_gotos(false, &mut gotos, &rev_po, NodeID::new(0), &mut kernel_body)?;
        write!(w, "{}", kernel_body)?;
        write!(w, "}}\n")?;

        // Emit host launch code
        let mut host_launch = String::new();
        self.codegen_launch_code(
            num_blocks,
            num_threads,
            &dynamic_shared_offset,
            &mut host_launch,
        )?;
        write!(w, "{}", host_launch)?;

        Ok(())
    }

    fn codegen_kernel_preamble<W: Write>(&self, w: &mut W) -> Result<(), Error> {
        write!(
            w,
            "
#include <assert.h>
#include <stdio.h>
#include <stddef.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math_constants.h>
#include <mma.h>

#if (CUDA_VERSION >= 12000)
#else
#define _CG_ABI_EXPERIMENTAL
#endif

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

#if (CUDA_VERSION >= 12000)
namespace cg = cooperative_groups;
namespace cge = cooperative_groups;
#else
namespace cg = cooperative_groups;
namespace cge = cooperative_groups::experimental;
#endif

#include <cuda_bf16.h>

#define uabs(a) (a)
#define umin(a, b) ((a) < (b) ? (a) : (b))
#define umax(a, b) ((a) > (b) ? (a) : (b))
#define powi(a, b) ({{ int res = 1; for(int i = 0; i < b; ++i) res *= a; res; }})
#define roundi(a) (a)
#define isqrt(a) ((int)sqrtf((float)(a)))

",
        )
    }

    fn codegen_return_struct<W: Write>(&self, w: &mut W) -> Result<(), Error> {
        write!(
            w,
            "struct return_{} {{ {} }};\n",
            self.function.name,
            self.function
                .return_types
                .iter()
                .enumerate()
                .map(|(idx, typ)| format!("{} f{};", self.get_type(*typ, false), idx))
                .collect::<Vec<_>>()
                .join(" "),
        )
    }

    /*
     * Emit kernel signature, arguments, and dynamic shared memory declaration
     */
    fn codegen_kernel_begin<W: Write>(&self, w: &mut W) -> Result<(), Error> {
        write!(
            w,
            "__global__ void __launch_bounds__({}) {}_{}_gpu(",
            self.kernel_params.max_num_threads, self.module_name, self.function.name
        )?;
        let mut first_param = true;
        // The first parameter is a pointer to GPU backing memory, if it's
        // needed.
        if self.backing_allocation.contains_key(&Device::CUDA) {
            first_param = false;
            write!(w, "char* backing")?;
        }
        // The second set of parameters are dynamic constants.
        for idx in 0..self.function.num_dynamic_constants {
            if first_param {
                first_param = false;
            } else {
                write!(w, ", ")?;
            }
            write!(w, "unsigned long long dc_p{}", idx)?;
        }
        // The third set of parameters are normal arguments.
        for (idx, ty) in self.function.param_types.iter().enumerate() {
            if first_param {
                first_param = false;
            } else {
                write!(w, ", ")?;
            }
            let param_type = if self.types[ty.idx()].is_primitive() {
                self.get_type(*ty, false)
            } else {
                format!("{} __restrict__", self.get_type(*ty, false))
            };
            write!(w, "{} p{}", param_type, idx)?;
        }
        let ret_fields = self
            .return_parameters
            .iter()
            .enumerate()
            .filter_map(|(idx, param)| {
                if param.is_some() {
                    None
                } else {
                    Some((idx, self.function.return_types[idx]))
                }
            })
            .collect::<Vec<(usize, TypeID)>>();
        if !ret_fields.is_empty() {
            if !first_param {
                write!(w, ", ")?;
            }
            write!(w, "return_{}* __restrict__ ret", self.function.name)?;
        }

        // The buffer is typed as char since single-byte indexing is simplest
        // and is required for heterogeneous Product and Summation types.
        // Casts convert to other types like int later.
        write!(w, ") {{\n")?;
        write!(w, "\textern __shared__ char dynamic_shared[];\n")?;
        // This is only used by thread rank 0 in each block, since it
        // does all shared memory "allocation". The actual state is tracked
        // in a Rust string; this offset is assigned only for readability
        // of the generated code.
        write!(w, "\tuint64_t dynamic_shared_offset;\n")?;

        Ok(())
    }
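
    /*
     * For reference, a sketch of what this emits for a hypothetical function
     * "foo" in module "mod" with GPU backing memory, two dynamic constant
     * parameters, one array parameter, and one device-computed return value
     * (types and names are illustrative):
     *
     *   __global__ void __launch_bounds__(1024) mod_foo_gpu(char* backing,
     *       unsigned long long dc_p0, unsigned long long dc_p1,
     *       float* __restrict__ p0, return_foo* __restrict__ ret) {
     *       extern __shared__ char dynamic_shared[];
     *       uint64_t dynamic_shared_offset;
     */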

    /*
     * Emit calculation of all dynamic constants
     */
    fn codegen_dynamic_constants(&self, w: &mut String) -> Result<(), Error> {
        for dc in dynamic_constants_bottom_up(&self.dynamic_constants) {
            let dc_val = format!("unsigned long long dc{}", dc.idx());
            match &self.dynamic_constants[dc.idx()] {
                DynamicConstant::Constant(val) => write!(w, "\t{} = {}ull;\n", dc_val, val)?,
                DynamicConstant::Parameter(idx) => {
                    if *idx < self.function.num_dynamic_constants as usize {
                        write!(w, "\t{} = dc_p{};\n", dc_val, idx)?
                    } else {
                        write!(w, "\t{} = 0;\n", dc_val)?
                    }
                }
                DynamicConstant::Add(args) => {
                    let rhs = args
                        .iter()
                        .map(|arg| format!("dc{}", arg.idx()))
                        .collect::<Vec<_>>()
                        .join(" + ");
                    write!(w, "\t{} = {};\n", dc_val, rhs)?
                }
                DynamicConstant::Mul(args) => {
                    let rhs = args
                        .iter()
                        .map(|arg| format!("dc{}", arg.idx()))
                        .collect::<Vec<_>>()
                        .join(" * ");
                    write!(w, "\t{} = {};\n", dc_val, rhs)?
                }
                DynamicConstant::Min(args) => {
                    let rhs_but_last: String = args
                        .iter()
                        .take(args.len() - 1)
                        .map(|arg| format!("min(dc{}, ", arg.idx()))
                        .collect();
                    let rhs_last = format!("dc{}", args.last().unwrap().idx());
                    let rhs_end: String = std::iter::repeat(")").take(args.len() - 1).collect();
                    write!(
                        w,
                        "\t{} = {}{}{};\n",
                        dc_val, rhs_but_last, rhs_last, rhs_end
                    )?
                }
                DynamicConstant::Max(args) => {
                    let rhs_but_last: String = args
                        .iter()
                        .take(args.len() - 1)
                        .map(|arg| format!("max(dc{}, ", arg.idx()))
                        .collect();
                    let rhs_last = format!("dc{}", args.last().unwrap().idx());
                    let rhs_end: String = std::iter::repeat(")").take(args.len() - 1).collect();
                    write!(
                        w,
                        "\t{} = {}{}{};\n",
                        dc_val, rhs_but_last, rhs_last, rhs_end
                    )?
                }
                DynamicConstant::Sub(left, right) => {
                    write!(w, "\t{} = dc{} - dc{};\n", dc_val, left.idx(), right.idx())?
                }
                DynamicConstant::Div(left, right) => {
                    write!(w, "\t{} = dc{} / dc{};\n", dc_val, left.idx(), right.idx())?
                }
                DynamicConstant::Rem(left, right) => {
                    write!(w, "\t{} = dc{} % dc{};\n", dc_val, left.idx(), right.idx())?
                }
            }
        }
        Ok(())
    }
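
    /*
     * For example (dc numbering is illustrative), a Constant, a Parameter, and
     * a Mul dynamic constant are emitted as:
     *
     *   unsigned long long dc0 = 16ull;
     *   unsigned long long dc1 = dc_p0;
     *   unsigned long long dc2 = dc0 * dc1;
     */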

    /*
     * To abide by C++ declaration rules in the presence of gotos (a jump may
     * not cross an initialization), we declare all data values upfront.
     */
     */
    fn codegen_declare_data(&self, w: &mut String) -> Result<(), Error> {
        for id in (0..self.function.nodes.len()).map(NodeID::new) {
            if !self.function.nodes[id.idx()].is_control()
                && !self.function.nodes[id.idx()].is_dynamic_constant()
                && !self.function.nodes[id.idx()].is_parameter()
            {
                write!(w, "\t{};\n", self.get_value(id, true, false))?;
            }
            if self.function.nodes[id.idx()].is_phi() {
                write!(w, "\t{}_tmp;\n", self.get_value(id, true, false))?;
            }
        }
        Ok(())
    }

    /*
     * Emit helper registers that are used throughout the kernel. grid and block
     * are from CUDA's cooperative groups API and are used specifically for reads
     * and writes.
     */
    fn codegen_helpers(&self, w: &mut String) -> Result<(), Error> {
        write!(
            w,
            "\t__shared__ cge::block_tile_memory<1024> block_sync_shared;\n"
        )?;
        write!(w, "\tcg::grid_group grid = cg::this_grid();\n")?;
        write!(
            w,
            "\tcg::thread_block block = cge::this_thread_block(block_sync_shared);\n"
        )?;
        Ok(())
    }

    fn codegen_goto_start(&self, w: &mut String) -> Result<(), Error> {
        let block_start = self.get_block_name(NodeID::new(0), false);
        write!(w, "\tgoto {};\n", block_start)?;
        Ok(())
    }

    fn codegen_gotos(
        &self,
        goto_debug: bool,
        gotos: &mut BTreeMap<NodeID, CudaGoto>,
        rev_po: &Vec<NodeID>,
        root: NodeID,
        w: &mut String,
    ) -> Result<(), Error> {
        // Print the blocks in a kind of silly order to avoid errors around
        // initialization of fork variables and gotos.
        let mut not_forks = vec![];
        let mut forks = vec![];
        let not_fork_controls = &self.fork_control_map[&root];
        for bb in rev_po
            .into_iter()
            .filter(|id| not_fork_controls.contains(id) && **id != root)
        {
            not_forks.push(*bb);
        }
        if let Some(fork_controls) = &self.fork_tree.get(&root) {
            for bb in rev_po
                .into_iter()
                .filter(|id| fork_controls.contains(id) && **id != root)
            {
                forks.push(*bb);
            }
        }
        for id in not_forks {
            self.codegen_goto(goto_debug, gotos, id, w)?;
        }
        for root in forks {
            self.codegen_goto(goto_debug, gotos, root, w)?;
            self.codegen_gotos(goto_debug, gotos, rev_po, root, w)?;
        }
        Ok(())
    }

    fn codegen_goto(
        &self,
        goto_debug: bool,
        gotos: &mut BTreeMap<NodeID, CudaGoto>,
        bb: NodeID,
        w: &mut String,
    ) -> Result<(), Error> {
        let goto = &gotos[&bb];
        let goto_block = self.get_block_name(bb, false);
        write!(w, "{}:\n", goto_block)?;
        if goto_debug {
            write!(w, "\tprintf(\"goto {}\\n\");\n", goto_block)?;
        }
        write!(w, "{}", goto.init)?;
        if !goto.post_init.is_empty() {
            let goto_block = self.get_block_name(bb, true);
            write!(w, "{}:\n", goto_block)?;
            write!(w, "{}", goto.post_init)?;
        }
        write!(w, "{}", goto.body)?;
        write!(w, "{}\n", goto.term)
    }

    fn codegen_launch_code(
        &self,
        num_blocks: String,
        num_threads: String,
        dynamic_shared_offset: &str,
        w: &mut String,
    ) -> Result<(), Error> {
        let mut pass_args = String::new();

        let is_multi_return = self.function.return_types.len() != 1;
        write!(w, "extern \"C\" ")?;
        if is_multi_return {
            write!(w, "void")?;
        } else {
            write!(w, "{}", self.get_type(self.function.return_types[0], false))?;
        }
        write!(w, " {}_{}(", self.module_name, self.function.name)?;

        let mut first_param = true;
        // The first parameter is a pointer to GPU backing memory, if it's
        // needed.
        if self.backing_allocation.contains_key(&Device::CUDA) {
            first_param = false;
            write!(w, "char* backing")?;
            write!(pass_args, "backing")?;
        }
        // The second set of parameters are dynamic constants.
        for idx in 0..self.function.num_dynamic_constants {
            if first_param {
                first_param = false;
            } else {
                write!(w, ", ")?;
                write!(pass_args, ", ")?;
            }
            write!(w, "unsigned long long dc_p{}", idx)?;
            write!(pass_args, "dc_p{}", idx)?;
        }
        // The third set of parameters are normal arguments.
        for (idx, ty) in self.function.param_types.iter().enumerate() {
            if first_param {
                first_param = false;
            } else {
                write!(w, ", ")?;
                write!(pass_args, ", ")?;
            }
            let param_type = self.get_type(*ty, false);
            write!(w, "{} p{}", param_type, idx)?;
            write!(pass_args, "p{}", idx)?;
        }
        // If the function is multi-return, the last argument is the return pointer.
        // This is a CPU pointer; we will allocate a separate device pointer for the
        // kernel's return arguments (if any).
        if is_multi_return {
            if !first_param {
                write!(w, ", ")?;
            }
            write!(w, "return_{}* ret_ptr", self.function.name)?;
        }
        write!(w, ") {{\n")?;
        // For the case of a dynamic block count
        self.codegen_dynamic_constants(w)?;

        let (kernel_returns, param_returns) = self.return_parameters.iter().enumerate().fold(
            (vec![], vec![]),
            |(mut kernel_returns, mut param_returns), (idx, param)| {
                if let Some(param_idx) = param {
                    param_returns.push((idx, param_idx));
                } else {
                    kernel_returns.push((idx, self.function.return_types[idx]));
                }
                (kernel_returns, param_returns)
            },
        );

        if !kernel_returns.is_empty() {
            // Allocate kernel return struct
            write!(w, "\treturn_{}* ret_cuda;\n", self.function.name)?;
            write!(
                w,
                "\tcudaMalloc((void**)&ret_cuda, sizeof(return_{}));\n",
                self.function.name
            )?;
            // Add the return pointer to the kernel arguments
            if !first_param {
                write!(pass_args, ", ")?;
            }
            write!(pass_args, "ret_cuda")?;
        }

        write!(w, "\tcudaError_t err;\n")?;
        write!(
            w,
            "\t{}_{}_gpu<<<{}, {}, {}>>>({});\n",
            self.module_name,
            self.function.name,
            num_blocks,
            num_threads,
            dynamic_shared_offset,
            pass_args
        )?;
        write!(w, "\terr = cudaGetLastError();\n")?;
        write!(
            w,
            "\tif (cudaSuccess != err) {{ printf(\"Error1: %s\\n\", cudaGetErrorString(err)); }}\n"
        )?;

        if !is_multi_return {
            if kernel_returns.is_empty() {
                // A single return of a parameter: we can just return it directly
                write!(w, "\treturn p{};\n", param_returns[0].1)?;
            } else {
                // A single return of a value computed on the device, we create a stack allocation
                // and retrieve the value from the device and then return it
                write!(w, "\t return_{} ret_host;\n", self.function.name)?;
                write!(w,
                    "\tcudaMemcpy(&ret_host, ret_cuda, sizeof(return_{}), cudaMemcpyDeviceToHost);\n",
                    self.function.name,
                )?;
                write!(w, "\treturn ret_host.f0;\n")?;
            }
        } else {
            // Multi-return is handled via an output pointer provided to this function.
            // If there are kernel returns, we copy those back from the device and then
            // fill in the parameter returns.
            if !kernel_returns.is_empty() {
                // Copy from the device directly into the output struct
                write!(
                    w,
                    "\tcudaMemcpy(ret_ptr, ret_cuda, sizeof(return_{}), cudaMemcpyDeviceToHost);\n",
                    self.function.name,
                )?;
            }
            for (field_idx, param_idx) in param_returns {
                write!(w, "\tret_ptr->f{} = p{};\n", field_idx, param_idx)?;
            }
            write!(w, "\treturn;\n")?;
        }

        write!(w, "}}\n")?;
        Ok(())
    }
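
    /*
     * A sketch of the emitted host wrapper for a single device-computed return
     * value (names, the launch configuration, and the shared memory size are
     * illustrative):
     *
     *   extern "C" float mod_foo(unsigned long long dc_p0, float* p0) {
     *       ...dynamic constant computation...
     *       return_foo* ret_cuda;
     *       cudaMalloc((void**)&ret_cuda, sizeof(return_foo));
     *       cudaError_t err;
     *       mod_foo_gpu<<<1, 256, 0>>>(dc_p0, p0, ret_cuda);
     *       err = cudaGetLastError();
     *       if (cudaSuccess != err) { printf("Error1: %s\n", cudaGetErrorString(err)); }
     *       return_foo ret_host;
     *       cudaMemcpy(&ret_host, ret_cuda, sizeof(return_foo), cudaMemcpyDeviceToHost);
     *       return ret_host.f0;
     *   }
     */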

    /*
     * If the fork tree has a single root fork of known size s (assumed to fit
     * within the 1D grid size limit) with a parallel-fork schedule, then set
     * num_blocks to s, else set num_blocks to 1. Also return the root fork(s),
     * which seed the threadblock-level parallelization strategy for threads and
     * their eventual code generation.
     */
    fn get_root_forks_and_num_blocks(
        &self,
        fork_tree: &HashMap<NodeID, HashSet<NodeID>>,
    ) -> (HashSet<NodeID>, String, bool) {
        let root_forks: HashSet<NodeID> = fork_tree.get(&NodeID::new(0)).unwrap().clone();
        if root_forks.len() != 1 {
            return (root_forks, "1".to_string(), false);
        }

        let root_fork = root_forks.iter().next().unwrap();
        let Node::Fork { factors, .. } = &self.function.nodes[root_fork.idx()] else {
            panic!("Expected fork node");
        };
        let _reduces = &self.fork_reduce_map[root_fork];
        if self.function.schedules[root_fork.idx()].contains(&Schedule::ParallelFork) {
            let fork_size = factors
                .iter()
                .map(|dc| format!("dc{}", dc.idx()))
                .collect::<Vec<_>>()
                .join(" * ");
            (root_forks, fork_size, true)
        } else {
            (root_forks, "1".to_string(), false)
        }
    }

    /*
     * If there's a block fork, then the thread root forks are its child forks. If
     * not, the thread root forks are the root forks. This will be used to begin the
     * thread fork tree traversal for codegen.
     */
     */
    fn get_thread_root_forks(
        &self,
        root_forks: &HashSet<NodeID>,
        fork_tree: &HashMap<NodeID, HashSet<NodeID>>,
        is_block_parallel: bool,
    ) -> (NodeID, HashSet<NodeID>) {
        if is_block_parallel {
            let root_fork = root_forks.iter().next().unwrap();
            (
                *root_fork,
                fork_tree.get(&root_fork).unwrap().iter().copied().collect(),
            )
        } else {
            (NodeID::new(0), root_forks.clone())
        }
    }

    /*
     * This analysis determines the parallelization strategy within threadblocks.
     * We run a post-order traversal on the fork tree to get the thread quota per
     * subtree. In particular, each fork starts with a base factor equal to the
     * maximum over its descendants (leaves have base 1). We traverse up (details
     * in the helper) and pass from each node to its parent the factor and a map
     * from fork node to a tuple of (max quota among its siblings including itself,
     * its quota, its fork factor). The parent then compares the received quota
     * of its subtree against its own. If it's associative, it chooses the larger
     * of the two; if not, it can parallelize both if applicable and if they fit.
     *
     * Finally, the map is returned such that a node is in the map IFF it will
     * be parallelized. If not, the fork will use the parent's quota and serialize
     * over the Fork's ThreadIDs. Nodes may be removed from the map when traversing
     * up the tree due to a conflicting ancestor (because of associativity or the
     * thread limit) with a larger factor.
     */
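    /*
     * A minimal worked example of the heuristic above (numbers are illustrative,
     * with max_num_threads = 1024): consider an outer fork of factor 8 whose only
     * child fork has factor 64. If all reduces are Parallel-Reduce, both forks are
     * parallelized since 8 <= 1024 / 64, giving a subtree quota of 8 * 64 = 512.
     * If the outer fork instead carries an associative (Monoid) reduce, only the
     * larger factor is kept: 64 > 8, so the inner fork keeps its quota of 64 and
     * the outer fork is serialized over its ThreadIDs.
     */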
    fn get_thread_quotas(
        &self,
        fork_tree: &HashMap<NodeID, HashSet<NodeID>>,
        root_fork: NodeID,
    ) -> (HashMap<NodeID, (usize, usize, usize)>, usize) {
        let (tree_map, tree_quota, _) = self.recurse_thread_quotas(root_fork, fork_tree, true);
        (tree_map, tree_quota)
    }

    fn recurse_thread_quotas(
        &self,
        curr_fork: NodeID,
        fork_tree: &HashMap<NodeID, HashSet<NodeID>>,
        is_root: bool,
    ) -> (HashMap<NodeID, (usize, usize, usize)>, usize, bool) {
        // subsubtree_map is the union of the maps for all grandchildren and lower
        // nodes. children_quota_map maps each parallelized child to its quota, and
        // is used to promote the subsubtree map (grandchildren level) to the
        // subtree map (children level). subtree_quota is the cumulative factor of
        // the subtree, which is then compared against this fork's own factor.
        let (mut subsubtree_map, children_quota_map, subtree_quota) = fork_tree
            .get(&curr_fork)
            .unwrap()
            .iter()
            .map(|child| (child, self.recurse_thread_quotas(*child, fork_tree, false)))
            .fold(
                (HashMap::new(), HashMap::new(), 1),
                |(mut subsubtree_map, mut children_quota_map, subtree_quota),
                 (child, (curr_map, curr_quota, use_curr))| {
                    subsubtree_map.extend(curr_map);
                    if use_curr {
                        children_quota_map.insert(child, curr_quota);
                    }
                    (
                        subsubtree_map,
                        children_quota_map,
                        subtree_quota.max(curr_quota),
                    )
                },
            );
        // First update children_quota_map items with full information and add
        // to subsubtree_map
        for (&child, quota) in children_quota_map.iter() {
            let Node::Fork { factors, .. } = &self.function.nodes[child.idx()] else {
                panic!("Expected fork node");
            };
            let fork_size = self.multiply_fork_factors(factors).unwrap();
            subsubtree_map.insert(*child, (subtree_quota, *quota, fork_size));
        }
        let subtree_map = subsubtree_map;
        if is_root {
            return (subtree_map, subtree_quota, true);
        }
        // A node can only be considered for parallelization if:
        // a) it has statically known size
        // b) the known size is less than or equal to the max_num_threads
        // c) the known size is a power of 2
        // d) all reduces are parallel-reduce or associative
        //
        // If not, just take the max cumulative factor of its subtree
        let reduces = &self.fork_reduce_map[&curr_fork];
        if let Node::Fork { factors, .. } = &self.function.nodes[curr_fork.idx()]
            && let Some(fork_size) = self.multiply_fork_factors(factors)
            && fork_size <= self.kernel_params.max_num_threads
            && fork_size.is_power_of_two()
            && reduces.iter().all(|&reduce| {
                self.function.schedules[reduce.idx()].contains(&Schedule::ParallelReduce)
                    || self.function.schedules[reduce.idx()].contains(&Schedule::MonoidReduce)
            })
        {
            // If there's an associative Reduce, parallelize the larger of the
            // Fork's factor and the subtree's.
            // Otherwise all Reduces must be parallel-reduce, so parallelize
            // both if they fit and only the larger if not.
            // The reason for this distinction is that we only perform associative
            // Reduces over ThreadID-based values on consecutive CUDA threads, so
            // there's no opportunity for further nested parallelization. In
            // contrast, this restriction isn't needed for parallel Writes, so
            // nested parallelization is possible.
            if reduces.iter().any(|&reduce| {
                self.function.schedules[reduce.idx()].contains(&Schedule::MonoidReduce)
            }) || fork_size > self.kernel_params.max_num_threads / subtree_quota
            {
                if fork_size >= subtree_quota {
                    (HashMap::new(), fork_size, true)
                } else {
                    (subtree_map, subtree_quota, false)
                }
            } else {
                (subtree_map, fork_size * subtree_quota, true)
            }
        } else {
            (subtree_map, subtree_quota, false)
        }
    }

    fn codegen_data_control_no_forks(
        &self,
        dynamic_shared_offset: &mut String,
        thread_block_tiles: &mut String,
        gotos: &mut BTreeMap<NodeID, CudaGoto>,
    ) -> Result<(), Error> {
        (0..self.function.nodes.len())
            .filter(|idx| self.function.nodes[*idx].is_control())
            .try_for_each(|idx| -> Result<(), Error> {
                let control = NodeID::new(idx);
                let goto = gotos.get_mut(&control).unwrap();
                let init = &mut goto.init;
                let post_init = &mut goto.post_init;
                let body = &mut goto.body;
                let term = &mut goto.term;
                let mut tabs = self.codegen_control_node(
                    control,
                    None,
                    None,
                    None,
                    thread_block_tiles,
                    init,
                    post_init,
                    term,
                )?;
                for data in self.bbs.1[control.idx()].iter() {
                    self.codegen_data_node(
                        *data,
                        KernelState::OutBlock,
                        Some(false),
                        None,
                        None,
                        None,
                        false,
                        dynamic_shared_offset,
                        body,
                        &mut tabs,
                    )?;
                }
                Ok(())
            })
    }

    /*
     * Codegen for all control and data nodes.
     */
    fn codegen_data_control(
        &self,
        block_fork: Option<NodeID>,
        thread_root_forks: &HashSet<NodeID>,
        fork_thread_quota_map: &HashMap<NodeID, (usize, usize, usize)>,
        dynamic_shared_offset: &mut String,
        is_block_parallel: bool,
        num_threads: usize,
        thread_block_tiles: &mut String,
        gotos: &mut BTreeMap<NodeID, CudaGoto>,
    ) -> Result<(), Error> {
        // First emit data and control gen for each control node outside any fork.
        // Recall that this was tracked through a fake fork node with NodeID 0.
        let mut state = KernelState::OutBlock;
        for control in self.fork_control_map.get(&NodeID::new(0)).unwrap() {
            let goto = gotos.get_mut(control).unwrap();
            let init = &mut goto.init;
            let post_init = &mut goto.post_init;
            let body = &mut goto.body;
            let term = &mut goto.term;
            let mut tabs = self.codegen_control_node(
                *control,
                None,
                None,
                None,
                thread_block_tiles,
                init,
                post_init,
                term,
            )?;
            for data in self.bbs.1[control.idx()].iter() {
                self.codegen_data_node(
                    *data,
                    state,
                    Some(is_block_parallel),
                    None,
                    None,
                    None,
                    false,
                    dynamic_shared_offset,
                    body,
                    &mut tabs,
                )?;
            }
        }
        // Then generate data and control for the single block fork if it exists
        if block_fork.is_some() {
            state = KernelState::InBlock;
            for control in self.fork_control_map.get(&block_fork.unwrap()).unwrap() {
                let goto = gotos.get_mut(control).unwrap();
                let init = &mut goto.init;
                let post_init = &mut goto.post_init;
                let body = &mut goto.body;
                let term = &mut goto.term;
                let mut tabs = self.codegen_control_node(
                    *control,
                    Some(num_threads),
                    Some(num_threads),
                    Some(1),
                    thread_block_tiles,
                    init,
                    post_init,
                    term,
                )?;
                for data in self.bbs.1[control.idx()].iter() {
                    self.codegen_data_node(
                        *data,
                        state,
                        None,
                        Some(num_threads),
                        None,
                        Some(block_fork.unwrap()),
                        false,
                        dynamic_shared_offset,
                        body,
                        &mut tabs,
                    )?;
                }
            }
        }
        // Then generate for the thread fork tree through Fork node traversal.
        state = KernelState::InThread;
        for &root_fork in thread_root_forks {
            self.codegen_data_control_traverse(
                root_fork,
                state,
                fork_thread_quota_map,
                1,
                num_threads,
                dynamic_shared_offset,
                thread_block_tiles,
                gotos,
            )?;
        }
        Ok(())
    }

    /*
     * The important feature of this traversal is that we update the available
     * thread quota, use thread quota, and parallel factor for each Fork node.
     * Either this information is in the precomputed map, or we use the parent's
     * quota with no parallel factor.
     */
    fn codegen_data_control_traverse(
        &self,
        curr_fork: NodeID,
        state: KernelState,
        fork_thread_quota_map: &HashMap<NodeID, (usize, usize, usize)>,
        parent_quota: usize,
        num_threads: usize,
        dynamic_shared_offset: &mut String,
        thread_block_tiles: &mut String,
        gotos: &mut BTreeMap<NodeID, CudaGoto>,
    ) -> Result<(), Error> {
        let (available_thread_quota, use_thread_quota, parallel_factor) = fork_thread_quota_map
            .get(&curr_fork)
            .map(|(a, u, f)| (*a, *u, Some(*f)))
            .unwrap_or((parent_quota, parent_quota, None));
        let reduces = &self.fork_reduce_map[&curr_fork];
        let reducts = if parallel_factor.is_some() {
            reduces
                .iter()
                .map(|&reduce| {
                    let Node::Reduce {
                        control: _,
                        init: _,
                        reduct,
                    } = &self.function.nodes[reduce.idx()]
                    else {
                        panic!("Expected reduce node");
                    };
                    *reduct
                })
                .collect()
        } else {
            HashSet::new()
        };
        for control in self.fork_control_map.get(&curr_fork).unwrap() {
            let goto = gotos.get_mut(control).unwrap();
            let init = &mut goto.init;
            let post_init = &mut goto.post_init;
            let body = &mut goto.body;
            let term = &mut goto.term;
            let mut tabs = self.codegen_control_node(
                *control,
                Some(available_thread_quota),
                Some(use_thread_quota),
                parallel_factor,
                thread_block_tiles,
                init,
                post_init,
                term,
            )?;
            for data in self.bbs.1[control.idx()].iter() {
                self.codegen_data_node(
                    *data,
                    state,
                    None,
                    Some(use_thread_quota),
                    parallel_factor,
                    Some(curr_fork),
                    reducts.contains(data),
                    dynamic_shared_offset,
                    body,
                    &mut tabs,
                )?;
            }
        }
        for child in self.fork_tree.get(&curr_fork).unwrap() {
            self.codegen_data_control_traverse(
                *child,
                state,
                fork_thread_quota_map,
                use_thread_quota,
                num_threads,
                dynamic_shared_offset,
                thread_block_tiles,
                gotos,
            )?;
        }
        Ok(())
    }

    fn codegen_data_node(
        &self,
        id: NodeID,
        state: KernelState,
        is_block_parallel: Option<bool>,
        use_thread_quota: Option<usize>,
        parallel_factor: Option<usize>,
        nesting_fork: Option<NodeID>,
        is_special_reduct: bool,
        dynamic_shared_offset: &mut String,
        w: &mut String,
        num_tabs: &mut usize,
    ) -> Result<(), Error> {
        let define_variable = self.get_value(id, false, false).to_string();
        let tabs = "\t".repeat(*num_tabs);
        match &self.function.nodes[id.idx()] {
            Node::Phi {
                control: _,
                data: _,
            } => {
                write!(
                    w,
                    "{}{} = {}_tmp;\n",
                    tabs, define_variable, define_variable
                )?;
            }
            Node::ThreadID { control, dimension } => {
                let Node::Fork { factors, .. } = &self.function.nodes[control.idx()] else {
                    panic!("Expected ThreadID's control to be a fork node");
                };
                let divide = multiply_dcs(&factors[dimension + 1..]);
                let modulo = format!("dc{}", factors[*dimension].idx());
                match state {
                    KernelState::InBlock => {
                        write!(
                            w,
                            "{}{} = (blockIdx.x / ({})) % {};\n",
                            tabs, define_variable, divide, modulo
                        )?;
                    }
                    KernelState::InThread => {
                        if parallel_factor.is_none() {
                            // No dependence on threadIdx.x because each used thread
                            // will run this Fork serially
                            let fork_iter = self.get_fork_iter(*control, false);
                            write!(
                                w,
                                "{}{} = ({} / {}) % {};\n",
                                tabs, define_variable, fork_iter, divide, modulo
                            )?;
                        } else {
                            // We can directly use use_thread_quota and not worry about available
                            // because the Fork basic block's init section already does the gating
                            write!(
                                w,
                                "{}{} = (((threadIdx.x % {}) / {}) / ({})) % ({});\n",
                                tabs,
                                define_variable,
                                use_thread_quota.unwrap(),
                                use_thread_quota.unwrap() / parallel_factor.unwrap(),
                                divide,
                                modulo,
                            )?;
                        }
                    }
                    _ => {
                        panic!("Unsupported state for ThreadID")
                    }
                }
            }
            // The Reduce node only generates its initialization, as the reduct will
            // perform the update. If serialized, add a gate to prevent re-assignment
            // when we hit this reduce again due to the control flow loop between
            // the Fork and Join.
            Node::Reduce {
                control: _,
                init,
                reduct: _,
            } => {
                let init_val = self.get_value(*init, false, false);
                if parallel_factor.is_none() && KernelState::InThread == state {
                    let Some(nesting_fork) = nesting_fork else {
                        panic!("Expected reduce to be nested in a fork node");
                    };
                    let fork_iter = self.get_fork_iter(nesting_fork, false);
                    write!(w, "{}if ({} == 0) {{\n", tabs, fork_iter)?;
                    write!(w, "{}\t{} = {};\n", tabs, define_variable, init_val)?;
                    write!(w, "{}}}\n", tabs)?;
                } else {
                    write!(w, "{}{} = {};\n", tabs, define_variable, init_val)?;
                }
            }
            // Parameters emitted at top
            Node::Parameter { index: _ } => {}
            // If the constant is primitive, it's stored in a register, so we repeat
            // the write for all threads. Otherwise, it can be inside or outside the
            // block fork. If inside, it's stored in shared memory, so we "allocate"
            // it once and parallelize the memset to 0. If outside, we initialize it
            // as an offset into backing memory, but with a multi-block grid we don't
            // memset, to avoid a grid-level sync.
            Node::Constant { id: cons_id } => {
                let is_primitive = self.types[self.typing[id.idx()].idx()].is_primitive();
                let cg_tile = match state {
                    KernelState::OutBlock | KernelState::InBlock => "block".to_string(),
                    KernelState::InThread => {
                        self.get_cg_tile(nesting_fork.unwrap(), CGType::UsePerId)
                    }
                };
                if !is_primitive && state != KernelState::OutBlock {
                    write!(w, "{}if ({}.thread_rank() == 0) {{\n", tabs, cg_tile)?;
                    *num_tabs += 1;
                }
                if is_primitive || state != KernelState::OutBlock {
                    self.codegen_constant(
                        define_variable.clone(),
                        *cons_id,
                        true,
                        dynamic_shared_offset,
                        w,
                        *num_tabs,
                    )?;
                }
                if !is_primitive && state != KernelState::OutBlock {
                    write!(w, "{}}}\n", tabs)?;
                    //write!(w, "{}{}.sync();\n", tabs, cg_tile)?;
                    *num_tabs -= 1;
                }
                if !is_primitive && state == KernelState::OutBlock {
                    assert!(self.function.schedules[id.idx()].contains(&Schedule::NoResetConstant), "PANIC: The CUDA backend cannot lower a global memory constant that has to be reset to zero. This is because we cannot efficiently implement a memset to the underlying memory of the constant due to the need for a grid level sync. Consider floating this collection outside the CUDA function and into an AsyncRust function, or attaching the NoResetConstant schedule to indicate that no memset is semantically necessary.");
                    let (_, offsets) = &self.backing_allocation[&Device::CUDA];
                    let offset = offsets[&id].0;
                    write!(
                        w,
                        "{}{} = backing + dc{};\n",
                        tabs,
                        define_variable,
                        offset.idx()
                    )?;
                }
                if !is_primitive
                    && (state != KernelState::OutBlock || !is_block_parallel.unwrap_or(false))
                    && !self.function.schedules[id.idx()].contains(&Schedule::NoResetConstant)
                {
                    let data_size = self.get_size(self.typing[id.idx()], None);
                    write!(
                        w,
                        "{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{\n",
                        tabs, cg_tile, data_size, cg_tile
                    )?;
                    write!(w, "{}\t*({} + i) = 0;\n", tabs, define_variable)?;
                    write!(w, "{}}}\n", tabs)?;
                    write!(w, "{}{}.sync();\n", tabs, cg_tile)?;
                    //write!(w, "__syncthreads\n")?;
                }
            }
            // Dynamic constants emitted at top
            Node::DynamicConstant { id: _ } => {}
            Node::Unary { op, input } => match op {
                UnaryOperator::Not => match &self.types[self.typing[input.idx()].idx()] {
                    Type::Boolean => {
                        write!(
                            w,
                            "{}{} = !{};\n",
                            tabs,
                            define_variable,
                            self.get_value(*input, false, false),
                        )?;
                    }
                    ty if ty.is_fixed() => {
                        write!(
                            w,
                            "{}{} = ~{};\n",
                            tabs,
                            define_variable,
                            self.get_value(*input, false, false),
                        )?;
                    }
                    _ => panic!("Unsupported type for not operator"),
                },
                UnaryOperator::Neg => match &self.types[self.typing[input.idx()].idx()] {
                    ty if ty.is_signed() || ty.is_float() => {
                        write!(
                            w,
                            "{}{} = -{};\n",
                            tabs,
                            define_variable,
                            self.get_value(*input, false, false),
                        )?;
                    }
                    _ => {
                        panic!("Unsupported type for neg operator")
                    }
                },
                UnaryOperator::Cast(dst_ty_id) => {
                    write!(
                        w,
                        "{}{} = static_cast<{}>({});\n",
                        tabs,
                        define_variable,
                        self.get_type(*dst_ty_id, false),
                        self.get_value(*input, false, false),
                    )?;
                }
            },
            Node::Binary { op, left, right } => {
                let mut left_val = self.get_value(*left, false, false);
                let mut right_val = self.get_value(*right, false, false);
                let id_type = self.typing[id.idx()];
                if matches!(
                    op,
                    BinaryOperator::Add
                        | BinaryOperator::Or
                        | BinaryOperator::And
                        | BinaryOperator::Xor
                ) && is_special_reduct
                {
                    // For parallelized associative Reduces, use the cooperative
                    // groups reduce API. Associative multiplication is not
                    // supported. We need to use CGType::Use not CGType::UsePerId
                    // because for parallelized reduction we only have one thread
                    // per ThreadID and the reduction is over Use, not UsePerId.
                    let (reduce_val, non_reduce_val) = if let Node::Reduce {
                        control: _,
                        init: _,
                        reduct: _,
                    } = &self.function.nodes[left.idx()]
                    {
                        (left_val, right_val)
                    } else {
                        (right_val, left_val)
                    };
                    // Special reduct is only enabled for thread parallelization,
                    // so we don't need to worry about the grid and block cases.
                    let cg_tile = self.get_cg_tile(nesting_fork.unwrap(), CGType::Use);
                    #[allow(unreachable_patterns)]
                    let cg_op = match op {
                        BinaryOperator::Add => "plus",
                        BinaryOperator::Or => "bit_or",
                        BinaryOperator::And => "bit_and",
                        BinaryOperator::Xor => "bit_xor",
                        _ => unreachable!(),
                    };
                    let id_type_name = self.get_type(id_type, false);
                    write!(
                        w,
                        "{}{} = cg::reduce({}, {}, cg::{}<{}>());\n",
                        tabs, define_variable, cg_tile, non_reduce_val, cg_op, id_type_name
                    )?;
                    // Set up the binop between the Reduce's init and the reduced
                    // reduct. Since the op is associative, we can reorder the operands.
                    left_val = define_variable.clone();
                    right_val = reduce_val;
                }
                match (op, &self.types[id_type.idx()]) {
                    (BinaryOperator::Or, Type::Boolean) => write!(
                        w,
                        "{}{} = {} || {};\n",
                        tabs, define_variable, left_val, right_val,
                    )?,
                    (BinaryOperator::And, Type::Boolean) => write!(
                        w,
                        "{}{} = {} && {};\n",
                        tabs, define_variable, left_val, right_val,
                    )?,
                    (BinaryOperator::Rem, Type::Float32) => write!(
                        w,
                        "{}{} = fmodf({}, {});\n",
                        tabs, define_variable, left_val, right_val,
                    )?,
                    (BinaryOperator::Rem, Type::Float64) => write!(
                        w,
                        "{}{} = fmod({}, {});\n",
                        tabs, define_variable, left_val, right_val,
                    )?,
                    (op, _) => write!(
                        w,
                        "{}{} = {} {} {};\n",
                        tabs,
                        define_variable,
                        left_val,
                        match op {
                            BinaryOperator::Add => "+",
                            BinaryOperator::Sub => "-",
                            BinaryOperator::Mul => "*",
                            BinaryOperator::Div => "/",
                            BinaryOperator::Rem => "%",
                            BinaryOperator::LT => "<",
                            BinaryOperator::LTE => "<=",
                            BinaryOperator::GT => ">",
                            BinaryOperator::GTE => ">=",
                            BinaryOperator::EQ => "==",
                            BinaryOperator::NE => "!=",
                            BinaryOperator::Or => "|",
                            BinaryOperator::And => "&",
                            BinaryOperator::Xor => "^",
                            BinaryOperator::LSh => "<<",
                            BinaryOperator::RSh => ">>",
                        },
                        right_val,
                    )?,
                };
            }
            Node::Ternary {
                op,
                first,
                second,
                third,
            } => match op {
                TernaryOperator::Select => {
                    write!(
                        w,
                        "{}{} = {} ? {} : {};\n",
                        tabs,
                        define_variable,
                        self.get_value(*first, false, false),
                        self.get_value(*second, false, false),
                        self.get_value(*third, false, false),
                    )?;
                }
            },
            Node::IntrinsicCall { intrinsic, args } => {
                let id_type = self.typing[id.idx()];
                if matches!(intrinsic, Intrinsic::Max | Intrinsic::Min) && is_special_reduct {
                    // Similar to associative Binops
                    let non_reduce_arg = if let Node::Reduce {
                        control: _,
                        init: _,
                        reduct: _,
                    } = &self.function.nodes[args[0].idx()]
                    {
                        self.get_value(args[1], false, false)
                    } else {
                        self.get_value(args[0], false, false)
                    };
                    let cg_tile = self.get_cg_tile(nesting_fork.unwrap(), CGType::Use);
                    #[allow(unreachable_patterns)]
                    let cg_op = match intrinsic {
                        Intrinsic::Max => "greater",
                        Intrinsic::Min => "less",
                        _ => unreachable!(),
                    };
                    let id_type_name = self.get_type(id_type, false);
                    write!(
                        w,
                        "{}{} = cg::reduce({}, {}, cg::{}<{}>());\n",
                        tabs, define_variable, cg_tile, non_reduce_arg, cg_op, id_type_name
                    )?;
                } else {
                    let ty = &self.types[id_type.idx()];
                    let intrinsic = self.codegen_intrinsic(intrinsic, ty);
                    let args = args
                        .iter()
                        .map(|arg| self.get_value(*arg, false, false))
                        .collect::<Vec<_>>()
                        .join(", ");
                    write!(
                        w,
                        "{}{} = {}({});\n",
                        tabs, define_variable, intrinsic, args,
                    )?;
                }
            }
            // Read of primitive requires load after pointer math.
            Node::Read { collect, indices } => {
                let collect_with_indices = self.codegen_collect(*collect, indices);
                let data_type_id = self.typing[id.idx()];
                if self.types[data_type_id.idx()].is_primitive() {
                    let type_name = self.get_type(data_type_id, true);
                    write!(
                        w,
                        "{}{} = *reinterpret_cast<{}>({});\n",
                        tabs, define_variable, type_name, collect_with_indices
                    )?;
                } else {
                    write!(
                        w,
                        "{}{} = {};\n",
                        tabs, define_variable, collect_with_indices
                    )?;
                }
            }
            // A write of a primitive needs a thread-rank gate for safety. A write
            // of a collection is a memcpy that we distribute among the threads.
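            // As an illustration (names are hypothetical), the gated primitive
            // write is emitted roughly as:
            //   if (cg_fork7.thread_rank() == 0) {
            //     *reinterpret_cast<float*>(<collection + byte offset>) = data5;
            //   }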
            Node::Write {
                collect,
                data,
                indices,
            } => {
                let collect_with_indices = self.codegen_collect(*collect, indices);
                let data_variable = self.get_value(*data, false, false);
                let data_type_id = self.typing[data.idx()];
                let cg_tile = match state {
                    KernelState::OutBlock | KernelState::InBlock => "block".to_string(),
                    KernelState::InThread => {
                        self.get_cg_tile(nesting_fork.unwrap(), CGType::UsePerId)
                    }
                };
                if self.types[data_type_id.idx()].is_primitive() {
                    write!(w, "{}if ({}.thread_rank() == 0) {{\n", tabs, cg_tile)?;
                    let type_name = self.get_type(data_type_id, true);
                    write!(
                        w,
                        "{}\t*reinterpret_cast<{}>({}) = {};\n",
                        tabs, type_name, collect_with_indices, data_variable
                    )?;
                    write!(w, "{}}}\n", tabs)?;
                } else {
                    let data_size = self.get_size(data_type_id, None);
                    write!(
                        w,
                        "{}for (int i = {}.thread_rank(); i < {}; i += {}.size()) {{\n",
                        tabs, cg_tile, data_size, cg_tile
                    )?;
                    write!(
                        w,
                        "{}\t*({} + i) = *({} + i);\n",
                        tabs, collect_with_indices, data_variable
                    )?;
                    write!(w, "{}}}\n", tabs)?;
                    write!(
                        w,
                        "{}if ({}.thread_rank() < {} % {}.size()) {{\n",
                        tabs, cg_tile, data_size, cg_tile
                    )?;
                    write!(w, "{}\t*({} + {}.size() * ({} / {}.size()) + {}.thread_rank()) = *({} + {}.size() * ({} / {}.size()) + {}.thread_rank());\n", tabs, collect_with_indices, cg_tile, data_size, cg_tile, cg_tile, data_variable, cg_tile, data_size, cg_tile, cg_tile)?;
                    write!(w, "{}}}\n", tabs)?;
                }
                //write!(w, "{}{}.sync();\n", tabs, cg_tile)?;
                let collect_variable = self.get_value(*collect, false, false);
                write!(w, "{}{} = {};\n", tabs, define_variable, collect_variable)?;
            }
            // Undef nodes never need to be assigned to.
            Node::Undef { ty: _ } => {}
            _ => {
                panic!(
                    "Unsupported data node type: {:?}",
                    self.function.nodes[id.idx()]
                )
            }
        }
        // Since reducts are responsible for updating Reduce nodes, we check for
        // and emit those updates after each data node.
        if let Some(reduces) = self.reduct_reduce_map.get(&id) {
            let val = self.get_value(id, false, false);
            for reduce in reduces {
                let reduce_val = self.get_value(*reduce, false, false);
                write!(w, "{}{} = {};\n", tabs, reduce_val, val)?;
            }
        }
        Ok(())
    }

    fn codegen_control_node(
        &self,
        id: NodeID,
        available_thread_quota: Option<usize>,
        use_thread_quota: Option<usize>,
        parallel_factor: Option<usize>,
        thread_block_tiles: &mut String,
        w_init: &mut String,
        w_post_init: &mut String,
        w_term: &mut String,
    ) -> Result<usize, Error> {
        for (data, phi) in self.control_data_phi_map.get(&id).unwrap_or(&vec![]).iter() {
            let data = self.get_value(*data, false, false);
            let phi = self.get_value(*phi, false, false);
            write!(w_term, "\t{}_tmp = {};\n", phi, data)?;
        }
        let tabs = match &self.function.nodes[id.idx()] {
            Node::Start
            | Node::Region { preds: _ }
            | Node::ControlProjection {
                control: _,
                selection: _,
            } => {
                let succ = self.control_subgraph.succs(id).next().unwrap();
                write!(w_term, "\tgoto {};\n", self.get_block_name(succ, false))?;
                1
            }
            Node::If { control: _, cond } => {
                let mut succs = self.control_subgraph.succs(id);
                let succ1 = succs.next().unwrap();
                let succ2 = succs.next().unwrap();
                let succ1_is_true = self.function.nodes[succ1.idx()]
                    .try_control_projection(1)
                    .is_some();
                let succ1_block_name = self.get_block_name(succ1, false);
                let succ2_block_name = self.get_block_name(succ2, false);
                write!(
                    w_term,
                    "\tif ({}) {{\n",
                    self.get_value(*cond, false, false)
                )?;
                write!(
                    w_term,
                    "\t\tgoto {};\n",
                    if succ1_is_true {
                        succ1_block_name.clone()
                    } else {
                        succ2_block_name.clone()
                    }
                )?;
                write!(w_term, "\t}} else {{\n")?;
                write!(
                    w_term,
                    "\t\tgoto {};\n",
                    if succ1_is_true {
                        succ2_block_name
                    } else {
                        succ1_block_name
                    }
                )?;
                write!(w_term, "\t}}\n")?;
                1
            }
            Node::Fork {
                control: _,
                factors: _,
            } => {
                // We create a cooperative group tile for each of: the used threads
                // per ThreadID (for reads and writes), the used threads across all
                // ThreadIDs (for parallelized reductions), and the available
                // threads (to synchronize between used and unused threads). We want
                // to create these only once, so we emit two goto sections for each
                // Fork: one that runs only once to create the groups, and another
                // that may run multiple times if the Fork is serialized and the
                // Join jumps back to it.
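                // As an illustration (tile names and sizes are hypothetical), for
                // a Fork with 8 used threads per ThreadID, 32 used threads in
                // total, and 64 available threads, the emitted tile declarations
                // look roughly like:
                //   cg::thread_block_tile<8> cg_fork7 = cg::tiled_partition<8>(block);
                //   cg::thread_block_tile<32> cg_fork7_use = cg::tiled_partition<32>(block);
                //   cg::thread_block_tile<64> cg_fork7_available = cg::tiled_partition<64>(block);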
                let cg_tile = self.get_cg_tile(id, CGType::UsePerId);
                if use_thread_quota.is_some() {
                    let use_thread_quota = use_thread_quota.unwrap();
                    let use_thread_per_id = if parallel_factor.is_some() {
                        use_thread_quota / parallel_factor.unwrap()
                    } else {
                        use_thread_quota
                    };
                    write!(
                        thread_block_tiles,
                        "\tcg::thread_block_tile<{}> {} = cge::tiled_partition<{}>(block);\n",
                        use_thread_per_id, cg_tile, use_thread_per_id
                    )?;
                    let cg_tile_use = self.get_cg_tile(id, CGType::Use);
                    write!(
                        thread_block_tiles,
                        "\tcg::thread_block_tile<{}> {} = cge::tiled_partition<{}>(block);\n",
                        use_thread_quota, cg_tile_use, use_thread_quota
                    )?;
                    let available_thread_quota = available_thread_quota.unwrap();
                    let cg_tile_available = self.get_cg_tile(id, CGType::Available);
                    write!(
                        thread_block_tiles,
                        "\tcg::thread_block_tile<{}> {} = cge::tiled_partition<{}>(block);\n",
                        available_thread_quota, cg_tile_available, available_thread_quota
                    )?;
                    if parallel_factor.is_none() {
                        write!(thread_block_tiles, "\t{};\n", self.get_fork_iter(id, true))?;
                        write!(w_init, "\t{} = 0;\n", self.get_fork_iter(id, false))?;
                        write!(w_init, "\tgoto {};\n", self.get_block_name(id, true))?;
                    }
                }
                // Fork nodes gate the used vs unused threads out of all available
                // threads. If unused, we jump straight to the Join, and if used,
                // we jump to successor like normal.
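                // As an illustration (block names are hypothetical), gating 32
                // used threads out of 64 available is emitted roughly as:
                //   if (threadIdx.x % 64 < 32) {
                //     goto bb_region9;
                //   }
                //   else {
                //     goto bb_join10;
                //   }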
                let succ = self.control_subgraph.succs(id).next().unwrap();
                if let Some(available_thread_quota) = available_thread_quota
                    && let Some(use_thread_quota) = use_thread_quota
                    && use_thread_quota < available_thread_quota
                {
                    let w_target = if parallel_factor.is_none() {
                        w_post_init
                    } else {
                        w_init
                    };
                    write!(
                        w_target,
                        "\tif (threadIdx.x % {} < {}) {{\n",
                        available_thread_quota, use_thread_quota
                    )?;
                    write!(w_term, "\t\tgoto {};\n", self.get_block_name(succ, false))?;
                    write!(w_term, "\t}}\n")?;
                    write!(w_term, "\telse {{\n")?;
                    let join = self.fork_join_map.get(&id).unwrap();
                    write!(w_term, "\t\tgoto {};\n", self.get_block_name(*join, false))?;
                    write!(w_term, "\t}}\n")?;
                    2
                } else {
                    // Make sure post-init isn't empty so that its goto header gets generated.
                    if use_thread_quota.is_some() && parallel_factor.is_none() {
                        write!(w_post_init, " ")?;
                    }
                    write!(w_term, "\tgoto {};\n", self.get_block_name(succ, false))?;
                    1
                }
            }
            Node::Join { control: _ } => {
                // Join nodes also gate the used vs unused threads with a tile
                // sync after the body.
                let succ = self.control_subgraph.succs(id).next().unwrap();
                let has_thread_quota = available_thread_quota.is_some();
                let mut tabs = 1;
                if has_thread_quota {
                    let available_thread_quota = available_thread_quota.unwrap();
                    let use_thread_quota = use_thread_quota.unwrap();
                    if use_thread_quota < available_thread_quota {
                        write!(
                            w_init,
                            "\tif (threadIdx.x % {} < {}) {{\n",
                            available_thread_quota, use_thread_quota
                        )?;
                        write!(w_term, "\t}}\n")?;
                        tabs += 1;
                    }
                    let fork = self.join_fork_map.get(&id).unwrap();
                    let cg_tile_available = self.get_cg_tile(*fork, CGType::Available);
                    write!(w_term, "\t{}.sync();\n", cg_tile_available)?;
                    //write!(w_term, "\t__syncthreads;\n")?;
                }
                // If the Fork was parallelized, each thread or UsePerId tile of
                // threads only runs one ThreadID, so we can jump straight to the
                // successor. Otherwise, we jump back to the Fork until all
                // ThreadIDs, or equivalently the Fork's full factor count of
                // iterations, have been completed.
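                // As an illustration (names are hypothetical), the serialized case
                // for a Fork with factor dc3 is emitted roughly as:
                //   iter_fork7 += 1;
                //   if (iter_fork7 == dc3) {
                //     goto bb_region11;
                //   }
                //   else {
                //     goto bb_fork7_post;
                //   }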
                if parallel_factor.is_some() {
                    write!(w_term, "\tgoto {};\n", self.get_block_name(succ, false))?;
                } else {
                    let fork = self.join_fork_map.get(&id).unwrap();
                    let Node::Fork { factors, .. } = &self.function.nodes[fork.idx()] else {
                        panic!("Expected join_fork_map to point to a fork node");
                    };
                    let fork_size = multiply_dcs(factors);
                    let fork_iter = self.get_fork_iter(*fork, false);
                    write!(w_term, "\t{} += 1;\n", fork_iter)?;
                    write!(w_term, "\tif ({} == {}) {{\n", fork_iter, fork_size)?;
                    write!(w_term, "\t\tgoto {};\n", self.get_block_name(succ, false))?;
                    write!(w_term, "\t}}\n")?;
                    write!(w_term, "\telse {{\n")?;
                    write!(w_term, "\t\tgoto {};\n", self.get_block_name(*fork, true))?;
                    write!(w_term, "\t}}\n")?;
                }
                tabs
            }
            Node::Return {
                control: _,
                ref data,
            } => {
                write!(w_term, "\tif (grid.thread_rank() == 0) {{\n")?;
                for (idx, (data, param)) in
                    data.iter().zip(self.return_parameters.iter()).enumerate()
                {
                    // For return values that are not identical to some parameter,
                    // we write them into the output struct.
                    if param.is_none() {
                        write!(
                            w_term,
                            "\t\tret->f{} = {};\n",
                            idx,
                            self.get_value(*data, false, false)
                        )?;
                    }
                }
                write!(w_term, "\t}}\n")?;
                write!(w_term, "\treturn;\n")?;
                1
            }
            _ => {
                panic!("Unsupported control node type")
            }
        };
        Ok(tabs)
    }

    /*
     * This function emits collection name + pointer math for the provided indices.
     * All collection types use char pointers.
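     * As an illustration (parameter and field types are assumptions), indexing
     * field 2 of a product parameter p1 whose first two fields are a u32 and a
     * u64 emits roughly "p1 + 0 + ((4 + 8 - 1) / 8 * 8 + 8)", i.e. p1 + 16.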
     */
    fn codegen_collect(&self, collect: NodeID, indices: &[Index]) -> String {
        let mut index_ptr = "0".to_string();
        let mut type_id = self.typing[collect.idx()];
        for index in indices {
            match index {
                Index::Field(field) => {
                    index_ptr.push_str(&format!(" + ({})", self.get_size(type_id, Some(*field))));
                    type_id = if let Type::Product(fields) = &self.types[type_id.idx()] {
                        fields[*field]
                    } else {
                        panic!("Expected product type")
                    };
                }
                // Variants of summations have zero offset
                Index::Variant(index) => {
                    type_id = if let Type::Summation(variants) = &self.types[type_id.idx()] {
                        variants[*index]
                    } else {
                        panic!("Expected summation type")
                    };
                }
                // Convert the multi-dimensional array index to a 1-D index, then
                // convert it to a byte offset by multiplying by the aligned element size.
                Index::Position(array_indices) => {
                    let Type::Array(element_type, extents) = &self.types[type_id.idx()] else {
                        panic!("Expected array type")
                    };
                    let mut cumulative_offset = multiply_dcs(&extents[array_indices.len()..]);
                    let max_left_array_index = array_indices.len() - 1;
                    for (i, index) in array_indices.iter().rev().enumerate() {
                        cumulative_offset = format!(
                            "{} * ({}{}",
                            cumulative_offset,
                            self.get_value(*index, false, false),
                            if i != max_left_array_index {
                                format!(" + dc{}", extents[max_left_array_index - i].idx())
                            } else {
                                "".to_string()
                            }
                        );
                    }
                    index_ptr.push_str(&format!(
                        " + {}{}",
                        cumulative_offset,
                        ")".repeat(array_indices.len())
                    ));
                    let element_size = self.get_size(*element_type, None);
                    let element_align = self.get_alignment(*element_type);
                    index_ptr.push_str(&format!(
                        " * (({} + {} - 1) / {} * {})",
                        element_size, element_align, element_align, element_align
                    ));
                    type_id = *element_type;
                }
            }
        }
        let name = self.get_value(collect, false, false);
        format!("{} + {}", name, index_ptr)
    }

    /*
     * The outlined codegen for constants allows us to handle recursive initialization
     * of collections. We perform "allocation" by bumping an offset into dynamic
     * shared memory; CUDA's support for dynamic shared memory is limited to a
     * single extern array. Dynamic allocation is required here because not all
     * dynamic constants, and therefore not all array sizes, are known at codegen
     * time. This approach will need further work, as currently we keep allocating
     * new shmem and never reuse old, no-longer-used allocations. `allow_allocate`
     * prevents unnecessary shared memory allocations for nested product and
     * summation collections, since the outermost allocation covers the full
     * collection. Since array collections are not initialized, they don't need to
     * be recursed into.
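     * As an illustration (the constant's name and alignment are hypothetical),
     * allocating a collection with 8-byte alignment when the running offset
     * string is currently "0" emits roughly:
     *   dynamic_shared_offset = ((0 + 8 - 1) / 8) * 8;
     *   constant8 = dynamic_shared + dynamic_shared_offset;
     * after which the offset string grows by "+ <size expression>".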
     */
    fn codegen_constant(
        &self,
        name: String,
        cons_id: ConstantID,
        allow_allocate: bool,
        dynamic_shared_offset: &mut String,
        w: &mut String,
        num_tabs: usize,
    ) -> Result<(), Error> {
        let tabs = "\t".repeat(num_tabs);
        match &self.constants[cons_id.idx()] {
            Constant::Boolean(val) => write!(w, "{}{} = {};\n", tabs, name, val)?,
            Constant::Integer8(val) => write!(w, "{}{} = {};\n", tabs, name, val)?,
            Constant::UnsignedInteger8(val) => write!(w, "{}{} = {};\n", tabs, name, val)?,
            Constant::Integer16(val) => write!(w, "{}{} = {};\n", tabs, name, val)?,
            Constant::UnsignedInteger16(val) => write!(w, "{}{} = {};\n", tabs, name, val)?,
            Constant::Integer32(val) => write!(w, "{}{} = {};\n", tabs, name, val)?,
            Constant::UnsignedInteger32(val) => write!(w, "{}{} = {}ul;\n", tabs, name, val)?,
            Constant::Integer64(val) => write!(w, "{}{} = {}ll;\n", tabs, name, val)?,
            Constant::UnsignedInteger64(val) => write!(w, "{}{} = {}ull;\n", tabs, name, val)?,
            Constant::Float32(val) => {
                write!(w, "{}{} = {};\n", tabs, name, format_float(**val as f64))?
            }
            Constant::Float64(val) => write!(w, "{}{} = {};\n", tabs, name, format_float(**val))?,
            // All three of the following collection cases align, then allocate from
            // the single dynamic shared memory buffer by using and updating the offset.
            Constant::Product(type_id, constant_fields) => {
                if allow_allocate {
                    let alignment = self.get_alignment(*type_id);
                    let size = self.get_size(*type_id, None);
                    *dynamic_shared_offset = format!(
                        "(({} + {} - 1) / {}) * {}",
                        dynamic_shared_offset, alignment, alignment, alignment
                    );
                    write!(
                        w,
                        "{}dynamic_shared_offset = {};\n",
                        tabs, dynamic_shared_offset
                    )?;
                    write!(
                        w,
                        "{}{} = dynamic_shared + dynamic_shared_offset;\n",
                        tabs, name
                    )?;
                    *dynamic_shared_offset = format!("{} + {}", dynamic_shared_offset, size);
                }
                let Type::Product(type_fields) = &self.types[type_id.idx()] else {
                    panic!("Product constant should have product type")
                };
                for i in 0..constant_fields.len() {
                    // For each field, compute its byte offset within the product
                    // and issue a recursive call at that offset.
                    let offset = self.get_size(*type_id, Some(i));
                    let field_constant = &self.constants[constant_fields[i].idx()];
                    if field_constant.is_scalar() {
                        let field_type = self.get_type(type_fields[i], true);
                        self.codegen_constant(
                            format!("*reinterpret_cast<{}>({}+{})", field_type, name, offset),
                            constant_fields[i],
                            false,
                            dynamic_shared_offset,
                            w,
                            num_tabs,
                        )?;
                    } else if !field_constant.is_array() {
                        self.codegen_constant(
                            format!("{}+{}", name, offset),
                            constant_fields[i],
                            false,
                            dynamic_shared_offset,
                            w,
                            num_tabs,
                        )?;
                    }
                }
            }
            Constant::Summation(type_id, variant, field) => {
                if allow_allocate {
                    let alignment = self.get_alignment(*type_id);
                    let size = self.get_size(*type_id, None);
                    *dynamic_shared_offset = format!(
                        "(({} + {} - 1) / {}) * {}",
                        dynamic_shared_offset, alignment, alignment, alignment
                    );
                    write!(
                        w,
                        "{}dynamic_shared_offset = {};\n",
                        tabs, dynamic_shared_offset
                    )?;
                    write!(
                        w,
                        "{}{} = dynamic_shared + dynamic_shared_offset;\n",
                        tabs, name
                    )?;
                    *dynamic_shared_offset = format!("{} + {}", dynamic_shared_offset, size);
                }
                // No offset updating needed since all variants start at 0
                let Type::Summation(variants) = &self.types[type_id.idx()] else {
                    panic!("Summation constant should have summation type")
                };
                let variant_constant = &self.constants[field.idx()];
                if variant_constant.is_scalar() {
                    let variant_type = self.get_type(variants[*variant as usize], true);
                    self.codegen_constant(
                        format!("*reinterpret_cast<{}>({})", variant_type, name),
                        *field,
                        false,
                        dynamic_shared_offset,
                        w,
                        num_tabs,
                    )?;
                } else if !variant_constant.is_array() {
                    self.codegen_constant(name, *field, false, dynamic_shared_offset, w, num_tabs)?;
                };
            }
            Constant::Array(type_id) => {
                if !allow_allocate {
                    panic!("Nested array constant should not be re-allocated");
                }
                let alignment = self.get_alignment(*type_id);
                let size = self.get_size(*type_id, None);
                *dynamic_shared_offset = format!(
                    "(({} + {} - 1) / {}) * {}",
                    dynamic_shared_offset, alignment, alignment, alignment
                );
                write!(
                    w,
                    "{}dynamic_shared_offset = {};\n",
                    tabs, dynamic_shared_offset
                )?;
                write!(
                    w,
                    "{}{} = dynamic_shared + dynamic_shared_offset;\n",
                    tabs, name
                )?;
                *dynamic_shared_offset = format!("{} + {}", dynamic_shared_offset, size);
            }
        }
        Ok(())
    }

    /*
     * Emit code to calculate data size. For Product types, setting `num_fields = n`
     * gives the size of the first n fields (with inter-field padding), which equals
     * the byte offset of the field at index n. This is useful for constant
     * initialization and read/write index math.
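     * As a worked example (field types are assumptions), for a product of a u32
     * and a u64, `num_fields = None` yields "(4 + 8 - 1) / 8 * 8 + 8" (16 bytes),
     * while `num_fields = Some(1)` yields "4", the byte offset of the u64 field.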
     */
    fn get_size(&self, type_id: TypeID, num_fields: Option<usize>) -> String {
        match &self.types[type_id.idx()] {
            Type::Array(element_type, extents) => {
                assert!(num_fields.is_none());
                let array_size = multiply_dcs(extents);
                let elem_align = self.get_alignment(*element_type);
                format!(
                    "(({} + {} - 1) / {} * {}) * {}",
                    self.get_size(*element_type, None),
                    elem_align,
                    elem_align,
                    elem_align,
                    array_size
                )
            }
            Type::Product(fields) => {
                let num_fields = num_fields.unwrap_or(fields.len());
                fields
                    .iter()
                    .take(num_fields)
                    .map(|id| (self.get_size(*id, None), self.get_alignment(*id)))
                    .fold(String::from("0"), |acc, (size, align)| {
                        if acc == "0" {
                            size
                        } else {
                            format!(
                                "({} + {} - 1) / {} * {} + {}",
                                acc, align, align, align, size
                            )
                        }
                    })
            }
            Type::Summation(variants) => {
                assert!(num_fields.is_none());
                // The variant with the largest size is not guaranteed to also have
                // the largest alignment, e.g. a product of three 4-byte primitives
                // vs. one 8-byte primitive, so we need to calculate both.
                let max_size = variants.iter().map(|id| self.get_size(*id, None)).fold(
                    String::from("0"),
                    |acc, x| {
                        if acc == "0" {
                            x
                        } else {
                            format!("umax({}, {})", acc, x)
                        }
                    },
                );
                let max_alignment = variants
                    .iter()
                    .map(|id| self.get_alignment(*id))
                    .max()
                    .unwrap_or(0);
                format!(
                    "({} + {} - 1) / {} * {}",
                    max_size, max_alignment, max_alignment, max_alignment
                )
            }
            _ => {
                assert!(num_fields.is_none());
                format!("{}", self.get_alignment(type_id))
            }
        }
    }

    fn get_alignment(&self, type_id: TypeID) -> usize {
        get_type_alignment(&self.types, type_id)
    }

    fn codegen_intrinsic(&self, intrinsic: &Intrinsic, ty: &Type) -> String {
        let func_name = match intrinsic {
            Intrinsic::Abs => match ty {
                Type::Float32 => "fabsf",
                Type::Float64 => "__fabs",
                ty if ty.is_signed() => "abs",
                ty if ty.is_unsigned() => "uabs",
                _ => panic!("Unsupported type for Abs"),
            },
            Intrinsic::ACos => match ty {
                ty if ty.is_float() => "__acosf",
                _ => "acos",
            },
            Intrinsic::ASin => match ty {
                ty if ty.is_float() => "__asinf",
                _ => "asin",
            },
            Intrinsic::ATan => match ty {
                ty if ty.is_float() => "__atanf",
                _ => "atan",
            },
            Intrinsic::ATan2 => match ty {
                ty if ty.is_float() => "__atan2f",
                _ => "atan2",
            },
            Intrinsic::Ceil => match ty {
                ty if ty.is_float() => "__ceilf",
                _ => "ceil",
            },
            Intrinsic::Cos => match ty {
                ty if ty.is_float() => "__cosf",
                _ => "cos",
            },
            Intrinsic::Cosh => match ty {
                ty if ty.is_float() => "coshf",
                _ => "cosh",
            },
            Intrinsic::Exp => match ty {
                ty if ty.is_float() => "__expf",
                _ => "exp",
            },
            Intrinsic::Exp2 => match ty {
                ty if ty.is_float() => "__exp2f",
                _ => "exp2",
            },
            Intrinsic::Floor => match ty {
                ty if ty.is_float() => "__floorf",
                _ => "floor",
            },
            Intrinsic::Ln => match ty {
                ty if ty.is_float() => "__logf",
                _ => "log",
            },
            Intrinsic::Log10 => match ty {
                ty if ty.is_float() => "__log10f",
                _ => "log10",
            },
            Intrinsic::Log2 => match ty {
                ty if ty.is_float() => "__log2f",
                _ => "log2",
            },
            Intrinsic::Max => match ty {
                Type::Float32 => "fmaxf",
                Type::Float64 => "fmax",
                ty if ty.is_signed() => "smax",
                ty if ty.is_unsigned() => "umax",
                _ => "max",
            },
            Intrinsic::Min => match ty {
                Type::Float32 => "fminf",
                Type::Float64 => "fmin",
                ty if ty.is_signed() => "smin",
                ty if ty.is_unsigned() => "umin",
                _ => "min",
            },
            Intrinsic::Pow | Intrinsic::Powf => match ty {
                Type::Float32 => "__powf",
                Type::Float64 => "pow",
                _ => panic!("Unsupported type for Pow"),
            },
            Intrinsic::Powi => match ty {
                ty if ty.is_signed() || ty.is_unsigned() => "powi",
                _ => panic!("Unsupported type for Powi"),
            },
            Intrinsic::Round => match ty {
                ty if ty.is_float() => "__roundf",
                ty if ty.is_signed() || ty.is_unsigned() => "roundi",
                _ => "round",
            },
            Intrinsic::Sin => match ty {
                ty if ty.is_float() => "__sinf",
                _ => "sin",
            },
            Intrinsic::Sinh => match ty {
                ty if ty.is_float() => "sinhf",
                _ => "sinh",
            },
            Intrinsic::Sqrt => match ty {
                Type::Float32 => "sqrtf",
                ty if ty.is_signed() || ty.is_unsigned() => "isqrt",
                _ => "sqrt",
            },
            Intrinsic::Tan => match ty {
                ty if ty.is_float() => "__tanf",
                _ => "tan",
            },
            Intrinsic::Tanh => match ty {
                ty if ty.is_float() => "tanhf",
                _ => "tanh",
            },
            _ => panic!("Unsupported intrinsic {:?}", intrinsic),
        };
        func_name.to_string()
    }

    fn get_cg_tile(&self, fork: NodeID, cg_type: CGType) -> String {
        format!(
            "cg_{}{}",
            self.get_value(fork, false, false),
            if cg_type == CGType::Use {
                "_use"
            } else if cg_type == CGType::Available {
                "_available"
            } else {
                ""
            }
        )
    }

    fn get_fork_iter(&self, fork: NodeID, ty: bool) -> String {
        if ty {
            format!("unsigned int iter_{}", self.get_value(fork, false, false))
        } else {
            format!("iter_{}", self.get_value(fork, false, false))
        }
    }

    fn get_block_name(&self, id: NodeID, post: bool) -> String {
        format!(
            "bb_{}{}",
            self.get_value(id, false, false),
            if post { "_post" } else { "" }
        )
    }

    /*
     * Setting `ty = true` returns the value with its type, in declaration format.
     * `make_pointer` is only considered if `ty = true` and is only relevant for
     * primitive types; otherwise it makes no difference because collections are
     * already pointers.
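     * For example, a dynamic constant node over dynamic constant 3 yields "dc3",
     * a parameter at index 2 yields "p2", and other nodes yield their lower-case
     * name plus node id (roughly "add5" for a hypothetical Add node with id 5),
     * prefixed by the C type (e.g. "int add5") when `ty = true`.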
     */
    fn get_value(&self, id: NodeID, ty: bool, make_pointer: bool) -> String {
        if let Node::DynamicConstant { id: dc_id } = &self.function.nodes[id.idx()] {
            if ty {
                panic!("Dynamic constants shouldn't be re-initialized")
            }
            format!("dc{}", dc_id.idx())
        } else if let Node::Parameter { index } = &self.function.nodes[id.idx()] {
            if ty {
                panic!("Parameters shouldn't be re-initialized")
            }
            format!("p{}", index)
        } else if ty {
            format!(
                "{} {}",
                self.get_type(self.typing[id.idx()], make_pointer),
                self.get_value(id, false, false)
            )
        } else {
            format!(
                "{}{}",
                self.function.nodes[id.idx()].lower_case_name(),
                id.idx()
            )
        }
    }

    fn get_type(&self, id: TypeID, make_pointer: bool) -> String {
        let ty = &self.types[id.idx()];
        if ty.is_primitive() {
            convert_type(ty, make_pointer)
        } else {
            format!("char*{}", if make_pointer { "*" } else { "" })
        }
    }

    fn multiply_fork_factors(&self, factors: &[DynamicConstantID]) -> Option<usize> {
        factors.iter().try_fold(1usize, |acc, &factor_id| {
            evaluate_dynamic_constant(factor_id, self.dynamic_constants)
                .map(|val| acc.saturating_mul(val))
        })
    }
}

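// Joins the given dynamic constant IDs into a C multiplication expression over
// the emitted `dc<idx>` variables: e.g. IDs with indices 2 and 5 yield
// "dc2 * dc5", and an empty slice yields "1".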
fn multiply_dcs(dcs: &[DynamicConstantID]) -> String {
    if dcs.is_empty() {
        "1".to_string()
    } else {
        dcs.iter()
            .map(|dc| format!("dc{}", dc.idx()))
            .collect::<Vec<_>>()
            .join(" * ")
    }
}

fn convert_type(ty: &Type, make_pointer: bool) -> String {
    let mut result = match ty {
        Type::Boolean => "bool".to_string(),
        Type::Integer8 => "int8_t".to_string(),
        Type::UnsignedInteger8 => "uint8_t".to_string(),
        Type::Integer16 => "short".to_string(),
        Type::UnsignedInteger16 => "unsigned short".to_string(),
        Type::Integer32 => "int".to_string(),
        Type::UnsignedInteger32 => "unsigned int".to_string(),
        Type::Integer64 => "long long".to_string(),
        Type::UnsignedInteger64 => "unsigned long long".to_string(),
        Type::Float8 => "__nv_fp8_e4m3".to_string(),
        Type::BFloat16 => "nv_bfloat16".to_string(),
        Type::Float32 => "float".to_string(),
        Type::Float64 => "double".to_string(),
        _ => panic!("Unsupported type"),
    };
    if make_pointer {
        result.push('*');
    }
    result
}

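// Formats a float literal for emission into CUDA C: infinities map to INFINITY
// and -INFINITY, and a ".0" suffix is appended to integral values so they are
// not parsed as integer literals (e.g. 1.0 becomes "1.0" while 2.5 stays "2.5").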
fn format_float(val: f64) -> String {
    if val == f64::INFINITY {
        "INFINITY".to_string()
    } else if val == f64::NEG_INFINITY {
        "-INFINITY".to_string()
    } else {
        let mut s = val.to_string();
        if !s.contains('.') && !s.contains('e') && !s.contains('E') {
            s.push_str(".0");
        }
        s
    }
}