// codegen.rs
use std::collections::{HashMap, VecDeque};

use hercules_ir::ir;
use hercules_ir::ir::*;

use crate::labeled_builder::LabeledBuilder;
use crate::semant;
use crate::semant::{BinaryOp, Expr, Function, Literal, Prg, Stmt, UnaryOp};
use crate::ssa::SSA;
use crate::types::{Either, Primitive, TypeSolver, TypeSolverInst};

use juno_scheduler::labels::*;

// Loop info is a stack of the enclosing loop levels (the innermost loop is the last element),
// recording the latch block and the exit block of each loop, in that order
type LoopInfo = Vec<(NodeID, NodeID)>;

// Entry point of code generation: consumes the program and returns the generated module along
// with the JunoInfo collected while generating it. All of the work is delegated to
// CodeGenerator::build.
pub fn codegen_program(prg: Prg) -> (Module, JunoInfo) {
    CodeGenerator::build(prg)
}

// The state carried through code generation of a whole program
struct CodeGenerator<'a> {
    // The builder used to create functions and nodes
    builder: LabeledBuilder<'a>,
    // The type solver of the program, used to create per-function type instantiations
    types: &'a TypeSolver,
    // The functions of the program, indexed by function index
    funcs: &'a Vec<Function>,
    // A counter that is incremented for every function instance created; it is used as a prefix
    // in the generated names of non-entry functions
    uid: usize,
    // The function map tracks a map from function index and list of type arguments to its function
    // id in the builder
    functions: HashMap<(usize, Vec<TypeID>), FunctionID>,
    // The worklist tracks a list of functions to codegen, tracking the function's index (into
    // funcs), its type-solving instantiation (accounting for the type parameters), the function
    // id in the builder, and the entry block id
    worklist: VecDeque<(usize, TypeSolverInst<'a>, FunctionID, NodeID)>,

    // The JunoInfo needed for scheduling, which tracks the Juno function names and their
    // associated FunctionIDs.
    juno_info: JunoInfo,
}

impl CodeGenerator<'_> {
    // Builds the module for a whole program. Every function that takes no type arguments is
    // requested up front (which queues it for code generation); the worklist is then processed
    // to completion by finish.
    fn build(
        Prg {
            types,
            funcs,
            labels,
        }: Prg,
    ) -> (Module, JunoInfo) {
        let juno_info = JunoInfo::new(funcs.iter().map(|f| f.name.clone()));

        let mut generator = CodeGenerator {
            builder: LabeledBuilder::create(labels),
            types: &types,
            funcs: &funcs,
            uid: 0,
            functions: HashMap::new(),
            worklist: VecDeque::new(),
            juno_info,
        };

        // Only functions without type arguments are requested here, each with an empty list of
        // type arguments
        for (idx, func) in funcs.iter().enumerate() {
            if func.num_type_args == 0 {
                let _ = generator.get_function(idx, vec![]);
            }
        }

        generator.finish()
    }

    // Processes the worklist until it is empty, generating the body of each queued function,
    // and then finalizes the builder. Generating a function may queue further functions
    // (get_function pushes onto the worklist), which is why the queue is re-checked on every
    // iteration rather than iterated once.
    //
    // Returns the finished module and the JunoInfo accumulated during code generation.
    fn finish(mut self) -> (Module, JunoInfo) {
        while let Some((idx, mut type_inst, func, entry)) = self.worklist.pop_front() {
            self.builder.set_function(func);
            self.codegen_function(&self.funcs[idx], &mut type_inst, entry);
        }

        // Only the builder and the scheduling info contribute to the result
        let CodeGenerator {
            builder, juno_info, ..
        } = self;

        (builder.finish(), juno_info)
    }

    // Returns the builder function id for the function at index func_idx instantiated with the
    // type arguments ty_args, creating it on first request. A newly created instance is recorded
    // in the function map and in the JunoInfo, and is pushed onto the worklist so that its body
    // is generated later.
    fn get_function(&mut self, func_idx: usize, ty_args: Vec<TypeID>) -> FunctionID {
        let key = (func_idx, ty_args);
        if let Some(&existing) = self.functions.get(&key) {
            return existing;
        }

        let func = &self.funcs[func_idx];
        let mut solver_inst = self.types.create_instance(key.1.clone());

        // TODO: Ideally we would write out the type arguments, but now that they're
        // lowered to TypeID we can't do that as far as I can tell
        //
        // For entry functions, we preserve the name, which is safe since they have no type
        // variables and this is necessary to ensure easy ingestion into Rust
        let name = if func.entry {
            func.name.clone()
        } else {
            format!("_{}_{}", self.uid, func.name)
        };
        self.uid += 1;

        // Because we're building types, we can extract the builder
        let param_types = func
            .arguments
            .iter()
            .map(|(_, ty)| solver_inst.lower_type(&mut self.builder.builder, *ty))
            .collect::<Vec<_>>();

        let return_types = func
            .return_types
            .iter()
            .map(|ty| solver_inst.lower_type(&mut self.builder.builder, *ty))
            .collect::<Vec<_>>();

        let (func_id, entry) = self
            .builder
            .create_function(
                &name,
                param_types,
                return_types,
                func.num_dyn_consts as u32,
                func.entry,
            )
            .unwrap();

        self.juno_info.func_info.func_ids[func_idx].push(func_id);
        self.functions.insert(key, func_id);
        self.worklist
            .push_back((func_idx, solver_inst, func_id, entry));

        func_id
    }

    // Generates the body of one function instance whose entry block is entry. Each argument gets
    // a parameter node, recorded as the definition of the argument's variable in the entry block,
    // and then the body statement is generated with an empty loop stack.
    //
    // Panics if code generation of the body reports that control can fall off its end.
    fn codegen_function(&mut self, func: &Function, types: &mut TypeSolverInst, entry: NodeID) {
        // The SSA construction state for this function
        let mut ssa = SSA::new(entry);

        // One parameter node per argument, in argument order
        for (position, (variable, _)) in func.arguments.iter().enumerate() {
            let mut param = self.builder.allocate_node();
            ssa.write_variable(*variable, entry, param.id());
            param.build_parameter(position);
            self.builder.add_node(param);
        }

        // The body must never return control, i.e. every path has to end in a return
        let mut loops = Vec::new();
        if self
            .codegen_stmt(&func.body, types, &mut ssa, entry, &mut loops)
            .is_some()
        {
            panic!("Generated code for a function missing a return")
        }
    }

    // Generates code for a statement whose execution begins in the control block cur_block.
    //
    // The loops argument is the stack of enclosing loops (innermost last) and is used to resolve
    // break and continue statements. The result is Some(block), where block is the control block
    // in which execution continues after the statement, or None if control never falls out of
    // the statement (every path through it ends in a return, break, or continue).
    //
    // Panics if a break or continue appears with an empty loop stack, if a loop update does not
    // return control, or if a block contains a statement after one that does not return control.
    fn codegen_stmt(
        &mut self,
        stmt: &Stmt,
        types: &mut TypeSolverInst,
        ssa: &mut SSA,
        cur_block: NodeID,
        loops: &mut LoopInfo,
    ) -> Option<NodeID> {
        match stmt {
            Stmt::AssignStmt { var, val } => {
                // The variable is written in the block where the value becomes available
                let (val, block) = self.codegen_expr(val, types, ssa, cur_block);
                ssa.write_variable(*var, block, val);
                Some(block)
            }
            Stmt::IfStmt { cond, thn, els } => {
                let (val_cond, block_cond) = self.codegen_expr(cond, types, ssa, cur_block);
                let (mut if_node, block_then, block_else) =
                    ssa.create_cond(&mut self.builder, block_cond);

                // A missing else branch simply falls through its (empty) block
                let then_end = self.codegen_stmt(thn, types, ssa, block_then, loops);
                let else_end = match els {
                    None => Some(block_else),
                    Some(els_stmt) => self.codegen_stmt(els_stmt, types, ssa, block_else, loops),
                };

                if_node.build_if(block_cond, val_cond);
                self.builder.add_node(if_node);

                // A join block is only needed when both branches return control; otherwise
                // control continues from whichever branch does (if any)
                match (then_end, else_end) {
                    (None, els) => els,
                    (thn, None) => thn,
                    (Some(then_term), Some(else_term)) => {
                        let block_join = ssa.create_block(&mut self.builder);
                        ssa.add_pred(block_join, then_term);
                        ssa.add_pred(block_join, else_term);
                        ssa.seal_block(block_join, &mut self.builder);
                        Some(block_join)
                    }
                }
            }
            Stmt::LoopStmt { cond, update, body } => {
                // We generate guarded loops, so the first step is to create
                // a conditional branch, branching on the condition
                let (val_guard, block_guard) = self.codegen_expr(cond, types, ssa, cur_block);
                let (mut if_node, true_guard, false_proj) =
                    ssa.create_cond(&mut self.builder, block_guard);

                if_node.build_if(block_guard, val_guard);
                self.builder.add_node(if_node);

                // We then create a region for the exit (since there may be breaks)
                let block_exit = ssa.create_block(&mut self.builder);
                ssa.add_pred(block_exit, false_proj);

                // Now, create a block for the loop's latch, we don't (currently) know any of its
                // predecessors
                let block_latch = ssa.create_block(&mut self.builder);

                // Code-gen any update into the latch and then code-gen the condition
                let block_updated = match update {
                    None => block_latch,
                    Some(stmt) => self
                        .codegen_stmt(stmt, types, ssa, block_latch, loops)
                        .expect("Loop update should return control"),
                };
                let (val_cond, block_cond) = self.codegen_expr(cond, types, ssa, block_updated);

                let (mut if_node, true_proj, false_proj) =
                    ssa.create_cond(&mut self.builder, block_cond);
                if_node.build_if(block_cond, val_cond);
                self.builder.add_node(if_node);

                // Add the false projection from the latch as a predecessor of the exit
                ssa.add_pred(block_exit, false_proj);

                // Create a block for the loop header, and add the true branches from the guard and
                // latch as its only predecessors
                let body_block = ssa.create_block(&mut self.builder);
                ssa.add_pred(body_block, true_guard);
                ssa.add_pred(body_block, true_proj);
                ssa.seal_block(body_block, &mut self.builder);

                // Generate code for the body, with this loop as the innermost loop
                loops.push((block_latch, block_exit));
                let body_res = self.codegen_stmt(body, types, ssa, body_block, loops);
                loops.pop();

                // If the body of the loop can reach some block, we add that block as a predecessor
                // of the latch
                if let Some(block) = body_res {
                    ssa.add_pred(block_latch, block);
                }

                // Seal remaining open blocks
                ssa.seal_block(block_exit, &mut self.builder);
                ssa.seal_block(block_latch, &mut self.builder);

                // It is always assumed a loop may be skipped and so control can reach after the
                // loop
                Some(block_exit)
            }
            Stmt::ReturnStmt { exprs } => {
                // The returned expressions are generated in order, threading the control block
                // through them
                let mut vals = vec![];
                let mut block = cur_block;
                for expr in exprs {
                    let (val_ret, block_ret) = self.codegen_expr(expr, types, ssa, block);
                    vals.push(val_ret);
                    block = block_ret;
                }
                let mut return_node = self.builder.allocate_node();
                return_node.build_return(block, vals);
                self.builder.add_node(return_node);
                None
            }
            Stmt::BreakStmt {} => {
                // The block that contains this break now leads to the exit of the innermost loop
                let &(_, exit) = loops
                    .last()
                    .expect("Break statement outside of a loop");
                ssa.add_pred(exit, cur_block);
                None
            }
            Stmt::ContinueStmt {} => {
                // The block that contains this continue now leads to the latch of the innermost
                // loop
                let &(latch, _) = loops
                    .last()
                    .expect("Continue statement outside of a loop");
                ssa.add_pred(latch, cur_block);
                None
            }
            Stmt::BlockStmt { body } => {
                // Each statement starts in the block where the previous one ended
                let mut block = Some(cur_block);
                for stmt in body.iter() {
                    let reachable = block.expect("Unreachable statement in a block");
                    block = self.codegen_stmt(stmt, types, ssa, reachable, loops);
                }
                block
            }
            Stmt::ExprStmt { expr } => {
                // The value is discarded, only the resulting control block matters
                let (_val, block) = self.codegen_expr(expr, types, ssa, cur_block);
                Some(block)
            }
            Stmt::LabeledStmt { label, stmt } => {
                // The label is active in the builder while the inner statement is generated
                self.builder.push_label(*label);
                let res = self.codegen_stmt(&*stmt, types, ssa, cur_block, loops);
                self.builder.pop_label();
                res
            }
        }
    }

    // The codegen_expr function returns a pair of node IDs, the first is the node whose value is
    // the given expression and the second is the node of a control node at which the value is
    // available
    fn codegen_expr(
        &mut self,
        expr: &Expr,
        types: &mut TypeSolverInst,
        ssa: &mut SSA,
        cur_block: NodeID,
    ) -> (NodeID, NodeID) {
        match expr {
            Expr::Variable { var, .. } => (
                ssa.read_variable(*var, cur_block, &mut self.builder),
                cur_block,
            ),
            Expr::DynConst { val, .. } => {
                let mut node = self.builder.allocate_node();
                let node_id = node.id();
                let dyn_const = val.build(&mut self.builder.builder);
                node.build_dynamicconstant(dyn_const);
                self.builder.add_node(node);
                (node_id, cur_block)
            }
            Expr::Read { index, val, .. } => {
                let (collection, block) = self.codegen_expr(val, types, ssa, cur_block);
                let (indices, end_block) = self.codegen_indices(index, types, ssa, block);

                let mut node = self.builder.allocate_node();
                let node_id = node.id();
                node.build_read(collection, indices.into());
                self.builder.add_node(node);
                (node_id, end_block)
            }
            Expr::Write {
                index, val, rep, ..
            } => {
                let (collection, block) = self.codegen_expr(val, types, ssa, cur_block);
                let (indices, idx_block) = self.codegen_indices(index, types, ssa, block);
                let (replace, end_block) = self.codegen_expr(rep, types, ssa, idx_block);

                let mut node = self.builder.allocate_node();
                let node_id = node.id();
                node.build_write(collection, replace, indices.into());
                self.builder.add_node(node);
                (node_id, end_block)
            }
            Expr::Tuple { vals, typ } => {
                let mut block = cur_block;
                let mut values = vec![];
                for expr in vals {
                    let (val_expr, block_expr) = self.codegen_expr(expr, types, ssa, block);
                    block = block_expr;
                    values.push(val_expr);
                }

                let tuple_type = types.lower_type(&mut self.builder.builder, *typ);
                (self.build_tuple(values, tuple_type), block)
            }
            Expr::Union { tag, val, typ } => {
                let (value, block) = self.codegen_expr(val, types, ssa, cur_block);

                let union_type = types.lower_type(&mut self.builder.builder, *typ);
                (self.build_union(*tag, value, union_type), block)
            }
            Expr::Constant { val, .. } => {
                let const_id = self.build_constant(val, types);

                let mut val = self.builder.allocate_node();
                let val_node = val.id();
                val.build_constant(const_id);
                self.builder.add_node(val);

                (val_node, cur_block)
            }
            Expr::Zero { typ } => {
                let type_id = types.lower_type(&mut self.builder.builder, *typ);
                let zero_const = self.builder.builder.create_constant_zero(type_id);
                let mut zero = self.builder.allocate_node();
                let zero_val = zero.id();
                zero.build_constant(zero_const);
                self.builder.add_node(zero);

                (zero_val, cur_block)
            }
            Expr::UnaryExp { op, expr, .. } => {
                let (val, block) = self.codegen_expr(expr, types, ssa, cur_block);

                let mut expr = self.builder.allocate_node();
                let expr_id = expr.id();
                expr.build_unary(
                    val,
                    match op {
                        UnaryOp::Negation => UnaryOperator::Neg,
                        UnaryOp::BitwiseNot => UnaryOperator::Not,
                    },
                );
                self.builder.add_node(expr);

                (expr_id, block)
            }
            Expr::BinaryExp { op, lhs, rhs, .. } => {
                let (val_lhs, block_lhs) = self.codegen_expr(lhs, types, ssa, cur_block);
                let (val_rhs, block_rhs) = self.codegen_expr(rhs, types, ssa, block_lhs);

                let mut expr = self.builder.allocate_node();
                let expr_id = expr.id();
                expr.build_binary(
                    val_lhs,
                    val_rhs,
                    match op {
                        BinaryOp::Add => BinaryOperator::Add,
                        BinaryOp::Sub => BinaryOperator::Sub,
                        BinaryOp::Mul => BinaryOperator::Mul,
                        BinaryOp::Div => BinaryOperator::Div,
                        BinaryOp::Mod => BinaryOperator::Rem,
                        BinaryOp::BitAnd => BinaryOperator::And,
                        BinaryOp::BitOr => BinaryOperator::Or,
                        BinaryOp::Xor => BinaryOperator::Xor,
                        BinaryOp::Lt => BinaryOperator::LT,
                        BinaryOp::Le => BinaryOperator::LTE,
                        BinaryOp::Gt => BinaryOperator::GT,
                        BinaryOp::Ge => BinaryOperator::GTE,
                        BinaryOp::Eq => BinaryOperator::EQ,
                        BinaryOp::Neq => BinaryOperator::NE,
                        BinaryOp::LShift => BinaryOperator::LSh,
                        BinaryOp::RShift => BinaryOperator::RSh,
                    },
                );
                let _ = self.builder.add_node(expr);

                (expr_id, block_rhs)
            }
            Expr::CastExpr { expr, typ } => {
                let type_id = types.lower_type(&mut self.builder.builder, *typ);
                let (val, block) = self.codegen_expr(expr, types, ssa, cur_block);

                let mut expr = self.builder.allocate_node();
                let expr_id = expr.id();
                expr.build_unary(val, UnaryOperator::Cast(type_id));
                self.builder.add_node(expr);

                (expr_id, block)
            }
            Expr::CondExpr { cond, thn, els, .. } => {
                // Code-gen the condition
                let (val_cond, block_cond) = self.codegen_expr(cond, types, ssa, cur_block);

                // Create the if
                let (mut if_builder, then_block, else_block) =
                    ssa.create_cond(&mut self.builder, block_cond);
                if_builder.build_if(block_cond, val_cond);
                self.builder.add_node(if_builder);

                // Code-gen the branches
                let (then_val, block_then) = self.codegen_expr(thn, types, ssa, then_block);
                let (else_val, block_else) = self.codegen_expr(els, types, ssa, else_block);

                // Create the join in the control-flow
                let join = ssa.create_block(&mut self.builder);
                ssa.add_pred(join, block_then);
                ssa.add_pred(join, block_else);
                ssa.seal_block(join, &mut self.builder);

                // Create a phi that joins the two branches
                let mut phi = self.builder.allocate_node();
                let phi_id = phi.id();
                phi.build_phi(join, vec![then_val, else_val].into());
                self.builder.add_node(phi);

                (phi_id, join)
            }
            Expr::CallExpr {
                func,
                ty_args,
                dyn_consts,
                args,
                num_returns, // number of non-inout returns (which are first)
                ..
            } => {
                // We start by lowering the type arguments to TypeIDs
                let mut type_params = vec![];
                for typ in ty_args {
                    type_params.push(types.lower_type(&mut self.builder.builder, *typ));
                }

                // With the type arguments, we can now lookup the function
                let call_func = self.get_function(*func, type_params);

                // We then build the dynamic constants
                let dynamic_constants =
                    TypeSolverInst::build_dyn_consts(&mut self.builder.builder, dyn_consts);

                // Code gen for each argument in order
                // For inouts, this becomes an ssa.read_variable
                // We also record the variables which are our inouts
                let mut block = cur_block;
                let mut arg_vals = vec![];
                let mut inouts = vec![];
                for arg in args {
                    match arg {
                        Either::Left(exp) => {
                            let (val, new_block) = self.codegen_expr(exp, types, ssa, block);
                            block = new_block;
                            arg_vals.push(val);
                        }
                        Either::Right(var) => {
                            inouts.push(*var);
                            arg_vals.push(ssa.read_variable(*var, block, &mut self.builder));
                        }
                    }
                }

                // Create the call expression, a region specifically for it, and a region after that.
                let call_region = ssa.create_block(&mut self.builder);
                ssa.add_pred(call_region, block);
                ssa.seal_block(call_region, &mut self.builder);

                let after_call_region = ssa.create_block(&mut self.builder);
                ssa.add_pred(after_call_region, call_region);
                ssa.seal_block(after_call_region, &mut self.builder);

                let mut call = self.builder.allocate_node();
                let call_id = call.id();

                call.build_call(
                    call_region,
                    call_func,
                    dynamic_constants.into(),
                    arg_vals.into(),
                );
                let _ = self.builder.add_node(call);

                block = after_call_region;

                // Read each of the "inout values" and perform the SSA update
                let has_inouts = !inouts.is_empty();
                for (idx, var) in inouts.into_iter().enumerate() {
                    let index = self.builder.builder.create_field_index(num_returns + idx);
                    let mut proj = self.builder.allocate_node();
                    let proj_id = proj.id();
                    proj.build_data_projection(call_id, index);
                    self.builder.add_node(proj);
                    ssa.write_variable(var, block, proj_id);
                }

                (call_id, block)
            }
            Expr::CallExtract { call, index, .. } => {
                let (call, block) = self.codegen_expr(call, types, ssa, cur_block);

                let mut proj = self.builder.allocate_node();
                let proj_id = proj.id();
                proj.build_data_projection(call, index);
                self.builder.add_node(proj);

                (proj_id, block)
            }
            Expr::Intrinsic {
                id,
                ty_args: _,
                args,
                ..
            } => {
                // Code gen for each argument in order
                let mut block = cur_block;
                let mut arg_vals = vec![];
                for arg in args {
                    let (val, new_block) = self.codegen_expr(arg, types, ssa, block);
                    block = new_block;
                    arg_vals.push(val);
                }

                // Create the intrinsic call expression
                let mut call = self.builder.allocate_node();
                let call_id = call.id();
                call.build_intrinsic(*id, arg_vals.into());
                let _ = self.builder.add_node(call);

                (call_id, block)
            }
        }
    }

    // Lowers a list of Index values from the semantic analysis into the builder's indices.
    // Array positions are expressions, which are generated in order starting from cur_block;
    // since expressions may involve control flow, the block at which all of them are available
    // is returned alongside the lowered indices
    fn codegen_indices(
        &mut self,
        index: &Vec<semant::Index>,
        types: &mut TypeSolverInst,
        ssa: &mut SSA,
        cur_block: NodeID,
    ) -> (Vec<ir::Index>, NodeID) {
        let mut current = cur_block;
        let mut lowered = Vec::with_capacity(index.len());

        for component in index.iter() {
            let built = match component {
                semant::Index::Field(field) => self.builder.builder.create_field_index(*field),
                semant::Index::Variant(variant) => {
                    self.builder.builder.create_variant_index(*variant)
                }
                semant::Index::Array(positions) => {
                    // Each position expression continues from where the previous one ended
                    let mut position_vals = vec![];
                    for position in positions {
                        let (val, next) = self.codegen_expr(position, types, ssa, current);
                        current = next;
                        position_vals.push(val);
                    }
                    self.builder
                        .builder
                        .create_position_index(position_vals.into())
                }
            };
            lowered.push(built);
        }

        (lowered, current)
    }

    // Builds a tuple value of type typ from the given field values. The construction starts from
    // a zero constant of the type and chains one write per field, in field order; the id of the
    // last write is returned (or of the zero constant node when there are no fields)
    fn build_tuple(&mut self, exprs: Vec<NodeID>, typ: TypeID) -> NodeID {
        let zero_const = self.builder.builder.create_constant_zero(typ);
        let mut zero_node = self.builder.allocate_node();
        let zero_id = zero_node.id();
        zero_node.build_constant(zero_const);
        self.builder.add_node(zero_node);

        // Each write takes the tuple built so far and produces the tuple with one more field set
        exprs
            .into_iter()
            .enumerate()
            .fold(zero_id, |tuple_so_far, (field, field_val)| {
                let mut write = self.builder.allocate_node();
                let write_id = write.id();
                let field_index = self.builder.builder.create_field_index(field);

                write.build_write(tuple_so_far, field_val, vec![field_index].into());
                self.builder.add_node(write);
                write_id
            })
    }

    // Builds a value of the union type typ holding val in the variant selected by tag: a zero
    // constant of the type is created and val is written into it at the variant index. Returns
    // the id of the write node
    fn build_union(&mut self, tag: usize, val: NodeID, typ: TypeID) -> NodeID {
        // The zero constant that serves as the base of the write
        let base_const = self.builder.builder.create_constant_zero(typ);
        let mut base = self.builder.allocate_node();
        let base_id = base.id();
        base.build_constant(base_const);
        self.builder.add_node(base);

        // The write of the payload into the selected variant
        let mut variant_write = self.builder.allocate_node();
        let result = variant_write.id();
        let variant = self.builder.builder.create_variant_index(tag);
        variant_write.build_write(base_id, val, vec![variant].into());
        self.builder.add_node(variant_write);

        result
    }

    // Creates the builder constant for a literal of the given type and returns its id. Numeric
    // literals are converted (with as casts) to the primitive type that the literal's type
    // resolves to; tuple and sum literals are built recursively from their components.
    //
    // Panics if an integer or float literal resolves to a primitive type that is not handled
    // below, or if creating a sum constant fails
    fn build_constant<'a>(
        &mut self,
        (lit, typ): &semant::Constant,
        types: &mut TypeSolverInst<'a>,
    ) -> ConstantID {
        match lit {
            Literal::Unit => self.builder.builder.create_constant_prod(vec![].into()),
            Literal::Bool(flag) => self.builder.builder.create_constant_bool(*flag),
            Literal::Integer(int) => {
                let int = *int;
                let prim = types.as_numeric_type(&mut self.builder.builder, *typ);
                let b = &mut self.builder.builder;
                // Integer literals may also be used at floating point types
                match prim {
                    Primitive::I8 => b.create_constant_i8(int as i8),
                    Primitive::I16 => b.create_constant_i16(int as i16),
                    Primitive::I32 => b.create_constant_i32(int as i32),
                    Primitive::I64 => b.create_constant_i64(int as i64),
                    Primitive::U8 => b.create_constant_u8(int as u8),
                    Primitive::U16 => b.create_constant_u16(int as u16),
                    Primitive::U32 => b.create_constant_u32(int as u32),
                    Primitive::U64 => b.create_constant_u64(int as u64),
                    Primitive::F32 => b.create_constant_f32(int as f32),
                    Primitive::F64 => b.create_constant_f64(int as f64),
                    _ => panic!("Internal error in build_constant for integer"),
                }
            }
            Literal::Float(float) => {
                let float = *float;
                let prim = types.as_numeric_type(&mut self.builder.builder, *typ);
                let b = &mut self.builder.builder;
                match prim {
                    Primitive::F32 => b.create_constant_f32(float as f32),
                    Primitive::F64 => b.create_constant_f64(float as f64),
                    _ => panic!("Internal error in build_constant for float"),
                }
            }
            Literal::Tuple(elements) => {
                // The fields are built in order and combined into a product constant
                let fields = elements
                    .iter()
                    .map(|element| self.build_constant(element, types))
                    .collect::<Vec<_>>();
                self.builder.builder.create_constant_prod(fields.into())
            }
            Literal::Sum(tag, payload) => {
                // The payload is built before the sum type is lowered
                let payload_const = self.build_constant(payload, types);
                let sum_type = types.lower_type(&mut self.builder.builder, *typ);
                self.builder
                    .builder
                    .create_constant_sum(sum_type, *tag as u32, payload_const)
                    .unwrap()
            }
        }
    }
}