// codegen.rs
// Author: Aaron Councilman
use std::collections::{HashMap, VecDeque};
use hercules_ir::ir;
use hercules_ir::ir::*;
use crate::labeled_builder::LabeledBuilder;
use crate::semant;
use crate::semant::{BinaryOp, Expr, Function, Literal, Prg, Stmt, UnaryOp};
use crate::ssa::SSA;
use crate::types::{Either, Primitive, TypeSolver, TypeSolverInst};
use juno_scheduler::labels::*;
// Loop info is a stack of the enclosing loop levels (innermost loop last), recording the latch
// and exit block of each as `(latch, exit)`. A `continue` adds its block as a predecessor of the
// innermost latch and a `break` adds its block as a predecessor of the innermost exit.
type LoopInfo = Vec<(NodeID, NodeID)>;
/// Generate a Hercules IR module from a semantically analyzed Juno program.
///
/// Returns the built module together with the `JunoInfo` that records, for each Juno function,
/// the `FunctionID`s that were generated for it (needed later for scheduling).
pub fn codegen_program(prg: Prg) -> (Module, JunoInfo) {
    let (module, juno_info) = CodeGenerator::build(prg);
    (module, juno_info)
}
/// State for lowering a Juno program into Hercules IR: the IR builder, the program's types and
/// functions, and the bookkeeping needed to instantiate generic functions on demand.
struct CodeGenerator<'a> {
/// Builder for the IR module; also tracks the currently active labels.
builder: LabeledBuilder<'a>,
/// The type solver for the whole program (instantiated per function via `create_instance`).
types: &'a TypeSolver,
/// All functions of the program, indexed by their function index.
funcs: &'a Vec<Function>,
/// Counter used to give every generated non-entry function a unique name.
uid: usize,
/// The function map tracks a map from function index and set of type variables to its
/// function id in the builder.
functions: HashMap<(usize, Vec<TypeID>), FunctionID>,
/// The worklist tracks a list of functions to codegen, tracking the function's index, its
/// type-solving instantiation (accounting for the type parameters), the function id, and the
/// entry block id.
worklist: VecDeque<(usize, TypeSolverInst<'a>, FunctionID, NodeID)>,
/// The JunoInfo needed for scheduling, which tracks the Juno function names and their
/// associated FunctionIDs.
juno_info: JunoInfo,
}
impl CodeGenerator<'_> {
    /// Lower a type-checked Juno program into a Hercules IR module.
    ///
    /// Every function that takes no type arguments is requested up front; generic functions are
    /// instantiated lazily (through `get_function`) when a call to them is generated. Returns the
    /// finished module together with the `JunoInfo` mapping Juno functions to generated
    /// `FunctionID`s.
    fn build(
        Prg {
            types,
            funcs,
            labels,
        }: Prg,
    ) -> (Module, JunoInfo) {
        // Identify the functions (by index) which have no type arguments, these are the ones we
        // ask for code to be generated for
        let func_idx = funcs
            .iter()
            .enumerate()
            .filter_map(|(i, f)| (f.num_type_args == 0).then_some(i));

        let juno_info = JunoInfo::new(funcs.iter().map(|f| f.name.clone()));

        let mut codegen = CodeGenerator {
            builder: LabeledBuilder::create(labels),
            types: &types,
            funcs: &funcs,
            uid: 0,
            functions: HashMap::new(),
            worklist: VecDeque::new(),
            juno_info,
        };

        // Add the identified functions to the list to code-gen
        func_idx.for_each(|i| {
            let _ = codegen.get_function(i, vec![]);
        });

        codegen.finish()
    }

    /// Drain the worklist, generating the body of every requested function. Generating a body
    /// may enqueue further (generic) instantiations, which are processed by the same loop.
    /// Returns the built module and the scheduling information.
    fn finish(mut self) -> (Module, JunoInfo) {
        while let Some((idx, mut type_inst, func, entry)) = self.worklist.pop_front() {
            self.builder.set_function(func);
            self.codegen_function(&self.funcs[idx], &mut type_inst, entry);
        }

        let CodeGenerator {
            builder, juno_info, ..
        } = self;

        (builder.finish(), juno_info)
    }

    /// Return the `FunctionID` of function `func_idx` instantiated with `ty_args`.
    ///
    /// The first time an instantiation is requested the function is declared in the builder
    /// (with its parameter and return types lowered under the instantiation), recorded in
    /// `juno_info`, and queued on the worklist so its body is generated by `finish`; later
    /// requests for the same instantiation return the cached id.
    fn get_function(&mut self, func_idx: usize, ty_args: Vec<TypeID>) -> FunctionID {
        let func_info = (func_idx, ty_args);
        match self.functions.get(&func_info) {
            Some(func_id) => *func_id,
            None => {
                let ty_args = func_info.1;

                let func = &self.funcs[func_idx];
                let mut solver_inst = self.types.create_instance(ty_args.clone());

                // TODO: Ideally we would write out the type arguments, but now that they're
                // lowered to TypeID we can't do that as far as I can tell
                let name =
                    // For entry functions, we preserve the name, which is safe
                    // since they have no type variables and this is necessary
                    // to ensure easy ingestion into Rust
                    if func.entry { func.name.clone() }
                    else { format!("_{}_{}", self.uid, func.name) };
                self.uid += 1;

                let mut param_types = vec![];
                for (_, ty) in func.arguments.iter() {
                    // Because we're building types, we can extract the builder
                    param_types.push(solver_inst.lower_type(&mut self.builder.builder, *ty));
                }

                let return_types = func
                    .return_types
                    .iter()
                    .map(|t| solver_inst.lower_type(&mut self.builder.builder, *t))
                    .collect::<Vec<_>>();

                let (func_id, entry) = self
                    .builder
                    .create_function(
                        &name,
                        param_types,
                        return_types,
                        func.num_dyn_consts as u32,
                        func.entry,
                    )
                    .unwrap();

                self.juno_info.func_info.func_ids[func_idx].push(func_id);
                self.functions.insert((func_idx, ty_args), func_id);
                self.worklist
                    .push_back((func_idx, solver_inst, func_id, entry));
                func_id
            }
        }
    }

    /// Generate the body of `func` starting at its `entry` block.
    ///
    /// Each argument becomes a parameter node written into the SSA state at the entry block.
    ///
    /// # Panics
    /// Panics if control can fall off the end of the body (i.e. some path lacks a return).
    fn codegen_function(&mut self, func: &Function, types: &mut TypeSolverInst, entry: NodeID) {
        // Setup the SSA construction data structure
        let mut ssa = SSA::new(entry);

        // Create nodes for the arguments
        for (idx, (var, _)) in func.arguments.iter().enumerate() {
            let mut node_builder = self.builder.allocate_node();
            ssa.write_variable(*var, entry, node_builder.id());
            node_builder.build_parameter(idx);
            self.builder.add_node(node_builder);
        }

        // Generate code for the body
        let None = self.codegen_stmt(&func.body, types, &mut ssa, entry, &mut vec![]) else {
            panic!("Generated code for a function missing a return")
        };
    }

    /// Generate code for `stmt`, starting in `cur_block`.
    ///
    /// Returns the block in which control continues after the statement, or `None` when
    /// control cannot fall through it (a `return`, `break`, or `continue`, or a compound
    /// statement all of whose paths end in one). `loops` is the stack of enclosing loops used
    /// to resolve `break` and `continue`.
    fn codegen_stmt(
        &mut self,
        stmt: &Stmt,
        types: &mut TypeSolverInst,
        ssa: &mut SSA,
        cur_block: NodeID,
        loops: &mut LoopInfo,
    ) -> Option<NodeID> {
        match stmt {
            Stmt::AssignStmt { var, val } => {
                let (val, block) = self.codegen_expr(val, types, ssa, cur_block);
                ssa.write_variable(*var, block, val);
                Some(block)
            }
            Stmt::IfStmt { cond, thn, els } => {
                let (val_cond, block_cond) = self.codegen_expr(cond, types, ssa, cur_block);
                let (mut if_node, block_then, block_else) =
                    ssa.create_cond(&mut self.builder, block_cond);

                let then_end = self.codegen_stmt(thn, types, ssa, block_then, loops);
                let else_end = match els {
                    None => Some(block_else),
                    Some(els_stmt) => self.codegen_stmt(els_stmt, types, ssa, block_else, loops),
                };

                if_node.build_if(block_cond, val_cond);
                self.builder.add_node(if_node);

                // Only create a join when both branches can fall through
                match (then_end, else_end) {
                    (None, els) => els,
                    (thn, None) => thn,
                    (Some(then_term), Some(else_term)) => {
                        let block_join = ssa.create_block(&mut self.builder);
                        ssa.add_pred(block_join, then_term);
                        ssa.add_pred(block_join, else_term);
                        ssa.seal_block(block_join, &mut self.builder);
                        Some(block_join)
                    }
                }
            }
            Stmt::LoopStmt { cond, update, body } => {
                // We generate guarded loops, so the first step is to create
                // a conditional branch, branching on the condition
                let (val_guard, block_guard) = self.codegen_expr(cond, types, ssa, cur_block);
                let (mut if_node, true_guard, false_proj) =
                    ssa.create_cond(&mut self.builder, block_guard);
                if_node.build_if(block_guard, val_guard);
                self.builder.add_node(if_node);

                // We then create a region for the exit (since there may be breaks)
                let block_exit = ssa.create_block(&mut self.builder);
                ssa.add_pred(block_exit, false_proj);

                // Now, create a block for the loop's latch, we don't (currently) know any of its
                // predecessors
                let block_latch = ssa.create_block(&mut self.builder);

                // Code-gen any update into the latch and then code-gen the condition
                let block_updated = match update {
                    None => block_latch,
                    Some(stmt) => self
                        .codegen_stmt(stmt, types, ssa, block_latch, loops)
                        .expect("Loop update should return control"),
                };
                let (val_cond, block_cond) = self.codegen_expr(cond, types, ssa, block_updated);
                let (mut if_node, true_proj, false_proj) =
                    ssa.create_cond(&mut self.builder, block_cond);
                if_node.build_if(block_cond, val_cond);
                self.builder.add_node(if_node);

                // Add the false projection from the latch as a predecessor of the exit
                ssa.add_pred(block_exit, false_proj);

                // Create a block for the loop header, and add the true branches from the guard and
                // latch as its only predecessors
                let body_block = ssa.create_block(&mut self.builder);
                ssa.add_pred(body_block, true_guard);
                ssa.add_pred(body_block, true_proj);
                ssa.seal_block(body_block, &mut self.builder);

                // Generate code for the body; breaks/continues inside it target this loop
                loops.push((block_latch, block_exit));
                let body_res = self.codegen_stmt(body, types, ssa, body_block, loops);
                loops.pop();

                // If the body of the loop can reach some block, we add that block as a predecessor
                // of the latch
                if let Some(block) = body_res {
                    ssa.add_pred(block_latch, block);
                }

                // Seal remaining open blocks (all their predecessors are now known)
                ssa.seal_block(block_exit, &mut self.builder);
                ssa.seal_block(block_latch, &mut self.builder);

                // It is always assumed a loop may be skipped and so control can reach after the
                // loop
                Some(block_exit)
            }
            Stmt::ReturnStmt { exprs } => {
                let mut vals = vec![];
                let mut block = cur_block;
                for expr in exprs {
                    let (val_ret, block_ret) = self.codegen_expr(expr, types, ssa, block);
                    vals.push(val_ret);
                    block = block_ret;
                }

                let mut return_node = self.builder.allocate_node();
                return_node.build_return(block, vals);
                self.builder.add_node(return_node);

                None
            }
            Stmt::BreakStmt {} => {
                let &(_latch, exit) = loops
                    .last()
                    .expect("Internal error: break statement outside of a loop");
                // The block that contains this break now leads to the exit
                ssa.add_pred(exit, cur_block);
                None
            }
            Stmt::ContinueStmt {} => {
                let &(latch, _exit) = loops
                    .last()
                    .expect("Internal error: continue statement outside of a loop");
                // The block that contains this continue now leads to the latch
                ssa.add_pred(latch, cur_block);
                None
            }
            Stmt::BlockStmt { body } => {
                let mut block = cur_block;
                for stmt in body.iter() {
                    match self.codegen_stmt(stmt, types, ssa, block, loops) {
                        Some(next) => block = next,
                        // Control cannot fall through this statement, so any remaining
                        // statements are unreachable; stop rather than generating code for them
                        // in a non-existent block
                        None => return None,
                    }
                }
                Some(block)
            }
            Stmt::ExprStmt { expr } => {
                let (_val, block) = self.codegen_expr(expr, types, ssa, cur_block);
                Some(block)
            }
            Stmt::LabeledStmt { label, stmt } => {
                self.builder.push_label(*label);
                let res = self.codegen_stmt(&*stmt, types, ssa, cur_block, loops);
                self.builder.pop_label();
                res
            }
        }
    }

    // The codegen_expr function returns a pair of node IDs, the first is the node whose value is
    // the given expression and the second is the node of a control node at which the value is
    // available
    fn codegen_expr(
        &mut self,
        expr: &Expr,
        types: &mut TypeSolverInst,
        ssa: &mut SSA,
        cur_block: NodeID,
    ) -> (NodeID, NodeID) {
        match expr {
            Expr::Variable { var, .. } => (
                ssa.read_variable(*var, cur_block, &mut self.builder),
                cur_block,
            ),
            Expr::DynConst { val, .. } => {
                let mut node = self.builder.allocate_node();
                let node_id = node.id();
                let dyn_const = val.build(&mut self.builder.builder);
                node.build_dynamicconstant(dyn_const);
                self.builder.add_node(node);
                (node_id, cur_block)
            }
            Expr::Read { index, val, .. } => {
                let (collection, block) = self.codegen_expr(val, types, ssa, cur_block);
                let (indices, end_block) = self.codegen_indices(index, types, ssa, block);

                let mut node = self.builder.allocate_node();
                let node_id = node.id();
                node.build_read(collection, indices.into());
                self.builder.add_node(node);
                (node_id, end_block)
            }
            Expr::Write {
                index, val, rep, ..
            } => {
                let (collection, block) = self.codegen_expr(val, types, ssa, cur_block);
                let (indices, idx_block) = self.codegen_indices(index, types, ssa, block);
                let (replace, end_block) = self.codegen_expr(rep, types, ssa, idx_block);

                let mut node = self.builder.allocate_node();
                let node_id = node.id();
                node.build_write(collection, replace, indices.into());
                self.builder.add_node(node);
                (node_id, end_block)
            }
            Expr::Tuple { vals, typ } => {
                let mut block = cur_block;
                let mut values = vec![];
                for expr in vals {
                    let (val_expr, block_expr) = self.codegen_expr(expr, types, ssa, block);
                    block = block_expr;
                    values.push(val_expr);
                }
                let tuple_type = types.lower_type(&mut self.builder.builder, *typ);
                (self.build_tuple(values, tuple_type), block)
            }
            Expr::Union { tag, val, typ } => {
                let (value, block) = self.codegen_expr(val, types, ssa, cur_block);
                let union_type = types.lower_type(&mut self.builder.builder, *typ);
                (self.build_union(*tag, value, union_type), block)
            }
            Expr::Constant { val, .. } => {
                let const_id = self.build_constant(val, types);

                let mut val = self.builder.allocate_node();
                let val_node = val.id();
                val.build_constant(const_id);
                self.builder.add_node(val);

                (val_node, cur_block)
            }
            Expr::Zero { typ } => {
                let type_id = types.lower_type(&mut self.builder.builder, *typ);
                let zero_const = self.builder.builder.create_constant_zero(type_id);
                let mut zero = self.builder.allocate_node();
                let zero_val = zero.id();
                zero.build_constant(zero_const);
                self.builder.add_node(zero);

                (zero_val, cur_block)
            }
            Expr::UnaryExp { op, expr, .. } => {
                let (val, block) = self.codegen_expr(expr, types, ssa, cur_block);

                let mut expr = self.builder.allocate_node();
                let expr_id = expr.id();
                expr.build_unary(
                    val,
                    match op {
                        UnaryOp::Negation => UnaryOperator::Neg,
                        UnaryOp::BitwiseNot => UnaryOperator::Not,
                    },
                );
                self.builder.add_node(expr);

                (expr_id, block)
            }
            Expr::BinaryExp { op, lhs, rhs, .. } => {
                let (val_lhs, block_lhs) = self.codegen_expr(lhs, types, ssa, cur_block);
                let (val_rhs, block_rhs) = self.codegen_expr(rhs, types, ssa, block_lhs);

                let mut expr = self.builder.allocate_node();
                let expr_id = expr.id();
                expr.build_binary(
                    val_lhs,
                    val_rhs,
                    match op {
                        BinaryOp::Add => BinaryOperator::Add,
                        BinaryOp::Sub => BinaryOperator::Sub,
                        BinaryOp::Mul => BinaryOperator::Mul,
                        BinaryOp::Div => BinaryOperator::Div,
                        BinaryOp::Mod => BinaryOperator::Rem,
                        BinaryOp::BitAnd => BinaryOperator::And,
                        BinaryOp::BitOr => BinaryOperator::Or,
                        BinaryOp::Xor => BinaryOperator::Xor,
                        BinaryOp::Lt => BinaryOperator::LT,
                        BinaryOp::Le => BinaryOperator::LTE,
                        BinaryOp::Gt => BinaryOperator::GT,
                        BinaryOp::Ge => BinaryOperator::GTE,
                        BinaryOp::Eq => BinaryOperator::EQ,
                        BinaryOp::Neq => BinaryOperator::NE,
                        BinaryOp::LShift => BinaryOperator::LSh,
                        BinaryOp::RShift => BinaryOperator::RSh,
                    },
                );
                let _ = self.builder.add_node(expr);

                (expr_id, block_rhs)
            }
            Expr::CastExpr { expr, typ } => {
                let type_id = types.lower_type(&mut self.builder.builder, *typ);
                let (val, block) = self.codegen_expr(expr, types, ssa, cur_block);

                let mut expr = self.builder.allocate_node();
                let expr_id = expr.id();
                expr.build_unary(val, UnaryOperator::Cast(type_id));
                self.builder.add_node(expr);

                (expr_id, block)
            }
            Expr::CondExpr { cond, thn, els, .. } => {
                // Code-gen the condition
                let (val_cond, block_cond) = self.codegen_expr(cond, types, ssa, cur_block);

                // Create the if
                let (mut if_builder, then_block, else_block) =
                    ssa.create_cond(&mut self.builder, block_cond);
                if_builder.build_if(block_cond, val_cond);
                self.builder.add_node(if_builder);

                // Code-gen the branches
                let (then_val, block_then) = self.codegen_expr(thn, types, ssa, then_block);
                let (else_val, block_else) = self.codegen_expr(els, types, ssa, else_block);

                // Create the join in the control-flow
                let join = ssa.create_block(&mut self.builder);
                ssa.add_pred(join, block_then);
                ssa.add_pred(join, block_else);
                ssa.seal_block(join, &mut self.builder);

                // Create a phi that joins the two branches; its operands are listed in the same
                // order the predecessors were added to the join
                let mut phi = self.builder.allocate_node();
                let phi_id = phi.id();
                phi.build_phi(join, vec![then_val, else_val].into());
                self.builder.add_node(phi);

                (phi_id, join)
            }
            Expr::CallExpr {
                func,
                ty_args,
                dyn_consts,
                args,
                num_returns, // number of non-inout returns (which are first)
                ..
            } => {
                // We start by lowering the type arguments to TypeIDs
                let mut type_params = vec![];
                for typ in ty_args {
                    type_params.push(types.lower_type(&mut self.builder.builder, *typ));
                }

                // With the type arguments, we can now lookup the function
                let call_func = self.get_function(*func, type_params);

                // We then build the dynamic constants
                let dynamic_constants =
                    TypeSolverInst::build_dyn_consts(&mut self.builder.builder, dyn_consts);

                // Code gen for each argument in order
                // For inouts, this becomes an ssa.read_variable
                // We also record the variables which are our inouts
                let mut block = cur_block;
                let mut arg_vals = vec![];
                let mut inouts = vec![];
                for arg in args {
                    match arg {
                        Either::Left(exp) => {
                            let (val, new_block) = self.codegen_expr(exp, types, ssa, block);
                            block = new_block;
                            arg_vals.push(val);
                        }
                        Either::Right(var) => {
                            inouts.push(*var);
                            arg_vals.push(ssa.read_variable(*var, block, &mut self.builder));
                        }
                    }
                }

                // Create the call expression, a region specifically for it, and a region after that.
                let call_region = ssa.create_block(&mut self.builder);
                ssa.add_pred(call_region, block);
                ssa.seal_block(call_region, &mut self.builder);

                let after_call_region = ssa.create_block(&mut self.builder);
                ssa.add_pred(after_call_region, call_region);
                ssa.seal_block(after_call_region, &mut self.builder);

                let mut call = self.builder.allocate_node();
                let call_id = call.id();
                call.build_call(
                    call_region,
                    call_func,
                    dynamic_constants.into(),
                    arg_vals.into(),
                );
                let _ = self.builder.add_node(call);

                block = after_call_region;

                // Read each of the "inout values" and perform the SSA update. The inout results
                // follow the `num_returns` ordinary returns, so the i-th inout is selection
                // `num_returns + i` of the call.
                for (idx, var) in inouts.into_iter().enumerate() {
                    let mut proj = self.builder.allocate_node();
                    let proj_id = proj.id();
                    proj.build_data_projection(call_id, *num_returns + idx);
                    self.builder.add_node(proj);

                    ssa.write_variable(var, block, proj_id);
                }

                (call_id, block)
            }
            Expr::CallExtract { call, index, .. } => {
                let (call, block) = self.codegen_expr(call, types, ssa, cur_block);

                let mut proj = self.builder.allocate_node();
                let proj_id = proj.id();
                proj.build_data_projection(call, *index);
                self.builder.add_node(proj);

                (proj_id, block)
            }
            Expr::Intrinsic {
                id,
                ty_args: _,
                args,
                ..
            } => {
                // Code gen for each argument in order
                let mut block = cur_block;
                let mut arg_vals = vec![];
                for arg in args {
                    let (val, new_block) = self.codegen_expr(arg, types, ssa, block);
                    block = new_block;
                    arg_vals.push(val);
                }

                // Create the intrinsic call expression
                let mut call = self.builder.allocate_node();
                let call_id = call.id();
                call.build_intrinsic(*id, arg_vals.into());
                let _ = self.builder.add_node(call);

                (call_id, block)
            }
        }
    }

    // Convert a list of Index from the semantic analysis into a list of indices for the builder.
    // Note that this takes and returns a block since expressions may involve control flow
    fn codegen_indices(
        &mut self,
        index: &Vec<semant::Index>,
        types: &mut TypeSolverInst,
        ssa: &mut SSA,
        cur_block: NodeID,
    ) -> (Vec<ir::Index>, NodeID) {
        let mut block = cur_block;
        let mut built_index = vec![];
        for idx in index {
            match idx {
                semant::Index::Field(idx) => {
                    built_index.push(self.builder.builder.create_field_index(*idx));
                }
                semant::Index::Variant(idx) => {
                    built_index.push(self.builder.builder.create_variant_index(*idx));
                }
                semant::Index::Array(exps) => {
                    let mut expr_vals = vec![];
                    for exp in exps {
                        let (val, new_block) = self.codegen_expr(exp, types, ssa, block);
                        block = new_block;
                        expr_vals.push(val);
                    }
                    built_index.push(self.builder.builder.create_position_index(expr_vals.into()));
                }
            }
        }

        (built_index, block)
    }

    /// Build a tuple value of type `typ`: start from a zero constant of that type and write
    /// each of `exprs` into field `i` in order. Returns the node of the final write (or of the
    /// zero constant when `exprs` is empty).
    fn build_tuple(&mut self, exprs: Vec<NodeID>, typ: TypeID) -> NodeID {
        let zero_const = self.builder.builder.create_constant_zero(typ);
        let mut zero = self.builder.allocate_node();
        let zero_val = zero.id();
        zero.build_constant(zero_const);
        self.builder.add_node(zero);

        let mut val = zero_val;
        for (idx, exp) in exprs.into_iter().enumerate() {
            let mut write = self.builder.allocate_node();
            let write_id = write.id();
            let index = self.builder.builder.create_field_index(idx);

            write.build_write(val, exp, vec![index].into());
            self.builder.add_node(write);
            val = write_id;
        }

        val
    }

    /// Build a union value of type `typ` holding `val` in variant `tag`: a zero constant of the
    /// type with `val` written at the variant index. Returns the write node.
    fn build_union(&mut self, tag: usize, val: NodeID, typ: TypeID) -> NodeID {
        let zero_const = self.builder.builder.create_constant_zero(typ);
        let mut zero = self.builder.allocate_node();
        let zero_val = zero.id();
        zero.build_constant(zero_const);
        self.builder.add_node(zero);

        let mut write = self.builder.allocate_node();
        let write_id = write.id();
        let index = self.builder.builder.create_variant_index(tag);
        write.build_write(zero_val, val, vec![index].into());
        self.builder.add_node(write);

        write_id
    }

    /// Create an IR constant for the literal `lit` of (source) type `typ`.
    ///
    /// Integer literals are converted with `as` casts to the numeric type that `typ` resolves
    /// to (floating-point types included); float literals must resolve to `f32`/`f64`. Tuples
    /// and sums are built recursively.
    ///
    /// # Panics
    /// Panics if a numeric literal's type does not resolve to a supported primitive.
    fn build_constant<'a>(
        &mut self,
        (lit, typ): &semant::Constant,
        types: &mut TypeSolverInst<'a>,
    ) -> ConstantID {
        match lit {
            Literal::Unit => self.builder.builder.create_constant_prod(vec![].into()),
            Literal::Bool(val) => self.builder.builder.create_constant_bool(*val),
            Literal::Integer(val) => {
                let p = types.as_numeric_type(&mut self.builder.builder, *typ);
                match p {
                    Primitive::I8 => self.builder.builder.create_constant_i8(*val as i8),
                    Primitive::I16 => self.builder.builder.create_constant_i16(*val as i16),
                    Primitive::I32 => self.builder.builder.create_constant_i32(*val as i32),
                    Primitive::I64 => self.builder.builder.create_constant_i64(*val as i64),
                    Primitive::U8 => self.builder.builder.create_constant_u8(*val as u8),
                    Primitive::U16 => self.builder.builder.create_constant_u16(*val as u16),
                    Primitive::U32 => self.builder.builder.create_constant_u32(*val as u32),
                    Primitive::U64 => self.builder.builder.create_constant_u64(*val as u64),
                    Primitive::F32 => self.builder.builder.create_constant_f32(*val as f32),
                    Primitive::F64 => self.builder.builder.create_constant_f64(*val as f64),
                    _ => panic!("Internal error in build_constant for integer"),
                }
            }
            Literal::Float(val) => {
                let p = types.as_numeric_type(&mut self.builder.builder, *typ);
                match p {
                    Primitive::F32 => self.builder.builder.create_constant_f32(*val as f32),
                    Primitive::F64 => self.builder.builder.create_constant_f64(*val as f64),
                    _ => panic!("Internal error in build_constant for float"),
                }
            }
            Literal::Tuple(vals) => {
                let mut constants = vec![];
                for val in vals {
                    constants.push(self.build_constant(val, types));
                }
                self.builder.builder.create_constant_prod(constants.into())
            }
            Literal::Sum(tag, val) => {
                let constant = self.build_constant(val, types);
                let type_id = types.lower_type(&mut self.builder.builder, *typ);
                self.builder
                    .builder
                    .create_constant_sum(type_id, *tag as u32, constant)
                    .unwrap()
            }
        }
    }
}