use std::collections::BTreeMap;
use std::fmt::{Error, Write};
use std::iter::zip;
use hercules_ir::*;
use crate::*;
/*
 * Entry Hercules functions are lowered to async Rust code to achieve easy
 * task-level parallelism. This Rust is generated textually and is included in
 * the user's Rust code via a procedural macro.
 */
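/*
 * As a rough sketch (illustrative only, with a hypothetical function name and
 * node numbering), an entry function taking one dynamic constant and one i32
 * parameter gets a signature and prologue along these lines:
 *
 *     #[allow(unused_variables,unused_mut,unused_parens,unused_unsafe)]
 *     async fn foo<'a>(dc_p0: u64, p0: i32) -> i32 {
 *         let mut node_2: i32 = 0;
 *         ...
 *     }
 *
 * The body is the control-token loop emitted by codegen_function below.
 */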
pub fn rt_codegen<W: Write>(
func_id: FunctionID,
module: &Module,
typing: &Vec<TypeID>,
control_subgraph: &Subgraph,
bbs: &BasicBlocks,
collection_objects: &CollectionObjects,
callgraph: &CallGraph,
devices: &Vec<Device>,
w: &mut W,
) -> Result<(), Error> {
let ctx = RTContext {
func_id,
module,
typing,
control_subgraph,
bbs,
collection_objects,
callgraph,
devices,
};
ctx.codegen_function(w)
}
struct RTContext<'a> {
func_id: FunctionID,
module: &'a Module,
typing: &'a Vec<TypeID>,
control_subgraph: &'a Subgraph,
bbs: &'a BasicBlocks,
collection_objects: &'a CollectionObjects,
callgraph: &'a CallGraph,
devices: &'a Vec<Device>,
}
impl<'a> RTContext<'a> {
fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> {
let func = &self.get_func();
// Dump the function signature.
write!(
w,
"#[allow(unused_variables,unused_mut,unused_parens,unused_unsafe)]\nasync fn {}<'a>(",
func.name
)?;
let mut first_param = true;
// The first set of parameters are dynamic constants.
for idx in 0..func.num_dynamic_constants {
if first_param {
first_param = false;
} else {
write!(w, ", ")?;
}
write!(w, "dc_p{}: u64", idx)?;
}
// The second set of parameters are normal parameters.
for idx in 0..func.param_types.len() {
if first_param {
first_param = false;
} else {
write!(w, ", ")?;
}
if !self.module.types[func.param_types[idx].idx()].is_primitive() {
write!(w, "mut ")?;
}
write!(w, "p{}: {}", idx, self.get_type(func.param_types[idx]))?;
}
write!(w, ") -> {} {{\n", self.get_type(func.return_type))?;
// Allocate collection constants.
for object in self.collection_objects[&self.func_id].iter_objects() {
if let CollectionObjectOrigin::Constant(id) =
self.collection_objects[&self.func_id].origin(object)
{
let size = self.codegen_type_size(self.typing[id.idx()]);
write!(
w,
" let mut obj{}: ::hercules_rt::HerculesBox = unsafe {{ ::hercules_rt::HerculesBox::__zeros({}) }};\n",
object.idx(),
size
)?
}
}
// Dump signatures for called device functions.
write!(w, " extern \"C\" {{\n")?;
for callee in self.callgraph.get_callees(self.func_id) {
if self.devices[callee.idx()] == Device::AsyncRust {
continue;
}
let callee = &self.module.functions[callee.idx()];
write!(w, " fn {}(", callee.name)?;
let mut first_param = true;
for idx in 0..callee.num_dynamic_constants {
if first_param {
first_param = false;
} else {
write!(w, ", ")?;
}
write!(w, "dc{}: u64", idx)?;
}
for (idx, ty) in callee.param_types.iter().enumerate() {
if first_param {
first_param = false;
} else {
write!(w, ", ")?;
}
write!(w, "p{}: {}", idx, self.device_get_type(*ty))?;
}
write!(w, ") -> {};\n", self.device_get_type(callee.return_type))?;
}
write!(w, " }}\n")?;
// Declare intermediary variables for every value.
for idx in 0..func.nodes.len() {
if func.nodes[idx].is_control() {
continue;
}
write!(
w,
" let mut node_{}: {} = {};\n",
idx,
self.get_type(self.typing[idx]),
if self.module.types[self.typing[idx].idx()].is_integer() {
"0"
} else if self.module.types[self.typing[idx].idx()].is_float() {
"0.0"
} else {
"unsafe { ::hercules_rt::HerculesBox::__null() }"
}
)?;
}
        // The core executor is a Rust loop. We run a "control token", as
        // described in the original sea-of-nodes paper, through the basic
        // blocks to drive execution.
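        // As a rough illustration (with hypothetical block and node numbering),
        // a function whose start block falls through to a single return block
        // ends up with a loop along these lines:
        //
        //     let mut control_token: i8 = 0;
        //     let return_value = loop {
        //         match control_token {
        //             0 => {
        //                 node_2 = p0;
        //                 control_token = 1;
        //             }
        //             1 => {
        //                 break node_2;
        //             }
        //             _ => panic!()
        //         }
        //     };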
write!(
w,
" let mut control_token: i8 = 0;\n let return_value = loop {{\n match control_token {{\n",
)?;
let mut blocks: BTreeMap<_, _> = (0..func.nodes.len())
.filter(|idx| func.nodes[*idx].is_control())
.map(|idx| (NodeID::new(idx), String::new()))
.collect();
// Emit data flow into basic blocks.
for block in self.bbs.1.iter() {
for id in block {
self.codegen_data_node(*id, &mut blocks)?;
}
}
// Emit control flow into basic blocks.
for id in (0..func.nodes.len()).map(NodeID::new) {
if !func.nodes[id.idx()].is_control() {
continue;
}
self.codegen_control_node(id, &mut blocks)?;
}
// Dump the emitted basic blocks.
for (id, block) in blocks {
write!(
w,
" {} => {{\n{} }}\n",
id.idx(),
block
)?;
}
// Close the match and loop.
write!(w, " _ => panic!()\n }}\n }};\n")?;
// Emit the epilogue of the function.
write!(w, " unsafe {{\n")?;
for idx in 0..func.param_types.len() {
if !self.module.types[func.param_types[idx].idx()].is_primitive() {
write!(w, " p{}.__forget();\n", idx)?;
}
}
if !self.module.types[func.return_type.idx()].is_primitive() {
for object in self.collection_objects[&self.func_id].iter_objects() {
if let CollectionObjectOrigin::Constant(_) =
self.collection_objects[&self.func_id].origin(object)
{
write!(
w,
" if obj{}.__cmp_ids(&return_value) {{\n",
object.idx()
)?;
write!(w, " obj{}.__forget();\n", object.idx())?;
write!(w, " }}\n")?;
}
}
}
for idx in 0..func.nodes.len() {
if !func.nodes[idx].is_control()
&& !self.module.types[self.typing[idx].idx()].is_primitive()
{
write!(w, " node_{}.__forget();\n", idx)?;
}
}
write!(w, " }}\n")?;
write!(w, " return_value\n")?;
write!(w, "}}\n")?;
Ok(())
}
/*
     * While control nodes in Hercules IR are predecessor-centric (each takes a
     * control input that defines the predecessor relationship), the Rust loop
     * we generate is successor-centric. This difference requires explicit
* translation.
*/
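    /*
     * As a rough illustration (hypothetical block numbering): a region node
     * living in block 3 whose single successor is block 5 appends
     * `control_token = 5;` to its block, while an if node branching between
     * blocks 5 and 8 appends roughly
     * `control_token = if node_4 { 5 } else { 8 };`.
     */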
fn codegen_control_node(
&self,
id: NodeID,
blocks: &mut BTreeMap<NodeID, String>,
) -> Result<(), Error> {
let func = &self.get_func();
match func.nodes[id.idx()] {
// Start, region, and projection control nodes all have exactly one
// successor and are otherwise simple.
Node::Start
| Node::Region { preds: _ }
| Node::Projection {
control: _,
selection: _,
} => {
let block = &mut blocks.get_mut(&id).unwrap();
let succ = self.control_subgraph.succs(id).next().unwrap();
write!(block, " control_token = {};\n", succ.idx())?
}
            // If nodes have two successors; examine the projections to
            // determine which successor is the true branch, and branch between
            // them based on the condition.
Node::If { control: _, cond } => {
let block = &mut blocks.get_mut(&id).unwrap();
let mut succs = self.control_subgraph.succs(id);
let succ1 = succs.next().unwrap();
let succ2 = succs.next().unwrap();
let succ1_is_true = func.nodes[succ1.idx()].try_projection(1).is_some();
write!(
block,
" control_token = if {} {{ {} }} else {{ {} }};\n",
self.get_value(cond),
if succ1_is_true { succ1 } else { succ2 }.idx(),
if succ1_is_true { succ2 } else { succ1 }.idx(),
)?
}
Node::Return { control: _, data } => {
let block = &mut blocks.get_mut(&id).unwrap();
if self.module.types[self.typing[data.idx()].idx()].is_primitive() {
write!(block, " break {};\n", self.get_value(data))?
} else {
write!(
block,
" break unsafe {{ {}.__clone() }};\n",
self.get_value(data)
)?
}
}
_ => panic!("PANIC: Can't lower {:?}.", func.nodes[id.idx()]),
}
Ok(())
}
/*
* Lower data nodes in Hercules IR into Rust statements.
*/
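    /*
     * As a rough illustration (hypothetical node numbering): a constant node
     * holding 42i32 appends `node_7 = 42i32;` to its basic block, while a
     * non-primitive parameter appends a clone of the underlying HerculesBox,
     * e.g. `node_3 = unsafe { p1.__clone() };`.
     */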
fn codegen_data_node(
&self,
id: NodeID,
blocks: &mut BTreeMap<NodeID, String>,
) -> Result<(), Error> {
let func = &self.get_func();
match func.nodes[id.idx()] {
Node::Parameter { index } => {
let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
if self.module.types[self.typing[id.idx()].idx()].is_primitive() {
write!(
block,
" {} = p{};\n",
self.get_value(id),
index
)?
} else {
write!(
block,
" {} = unsafe {{ p{}.__clone() }};\n",
self.get_value(id),
index
)?
}
}
Node::Constant { id: cons_id } => {
let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
write!(block, " {} = ", self.get_value(id))?;
match self.module.constants[cons_id.idx()] {
Constant::Boolean(val) => write!(block, "{}bool", val)?,
Constant::Integer8(val) => write!(block, "{}i8", val)?,
Constant::Integer16(val) => write!(block, "{}i16", val)?,
Constant::Integer32(val) => write!(block, "{}i32", val)?,
Constant::Integer64(val) => write!(block, "{}i64", val)?,
Constant::UnsignedInteger8(val) => write!(block, "{}u8", val)?,
Constant::UnsignedInteger16(val) => write!(block, "{}u16", val)?,
Constant::UnsignedInteger32(val) => write!(block, "{}u32", val)?,
Constant::UnsignedInteger64(val) => write!(block, "{}u64", val)?,
Constant::Float32(val) => write!(block, "{}f32", val)?,
Constant::Float64(val) => write!(block, "{}f64", val)?,
Constant::Product(_, _) | Constant::Summation(_, _, _) | Constant::Array(_) => {
let objects = self.collection_objects[&self.func_id].objects(id);
assert_eq!(objects.len(), 1);
let object = objects[0];
write!(block, "unsafe {{ obj{}.__clone() }}", object.idx())?
}
}
write!(block, ";\n")?
}
Node::Call {
control: _,
function: callee_id,
ref dynamic_constants,
ref args,
} => {
let device = self.devices[callee_id.idx()];
match device {
// The device backends ensure that device functions have the
// same C interface.
Device::LLVM | Device::CUDA => {
let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
let device = match device {
Device::LLVM => "cpu",
Device::CUDA => "cuda",
_ => panic!(),
};
// First, get the raw pointers to collections that the
// device function takes as input.
let callee_objs = &self.collection_objects[&callee_id];
for (idx, arg) in args.into_iter().enumerate() {
if let Some(obj) = callee_objs.param_to_object(idx) {
// Extract a raw pointer from the HerculesBox.
if callee_objs.is_mutated(obj) {
write!(
block,
" let arg_tmp{} = unsafe {{ {}.__{}_ptr_mut() }};\n",
idx,
self.get_value(*arg),
device
)?;
} else {
write!(
block,
" let arg_tmp{} = unsafe {{ {}.__{}_ptr() }};\n",
idx,
self.get_value(*arg),
device
)?;
}
} else {
write!(
block,
" let arg_tmp{} = {};\n",
idx,
self.get_value(*arg)
)?;
}
}
// Emit the call.
write!(
block,
" let call_tmp = unsafe {{ {}(",
self.module.functions[callee_id.idx()].name
)?;
for dc in dynamic_constants {
self.codegen_dynamic_constant(*dc, block)?;
write!(block, ", ")?;
}
for idx in 0..args.len() {
write!(block, "arg_tmp{}, ", idx)?;
}
write!(block, ") }};\n")?;
// When a device function is called that returns a
// collection object, that object must have come from
// one of its parameters. Dynamically figure out which
// one it came from, so that we can move it to the slot
// of the output object.
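                        // As a sketch (hypothetical numbering), a call with
                        // two collection arguments emits roughly:
                        //
                        //     if call_tmp == arg_tmp0 {
                        //         node_9 = unsafe { node_4.__clone() };
                        //     } else
                        //     if call_tmp == arg_tmp1 {
                        //         node_9 = unsafe { node_5.__clone() };
                        //     } else {
                        //         panic!("HERCULES PANIC: ...");
                        //     }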
let caller_objects = self.collection_objects[&self.func_id].objects(id);
if !caller_objects.is_empty() {
for (idx, arg) in args.into_iter().enumerate() {
if idx != 0 {
write!(block, " else\n")?;
}
write!(
block,
" if call_tmp == arg_tmp{} {{\n",
idx
)?;
write!(
block,
" {} = unsafe {{ {}.__clone() }};\n",
self.get_value(id),
self.get_value(*arg)
)?;
write!(block, " }}")?;
}
write!(block, " else {{\n")?;
write!(block, " panic!(\"HERCULES PANIC: Pointer returned from device function doesn't match an argument pointer.\");\n")?;
write!(block, " }}\n")?;
} else {
write!(
block,
" {} = call_tmp;\n",
self.get_value(id)
)?;
}
}
Device::AsyncRust => {
let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
write!(
block,
" {} = {}(",
self.get_value(id),
self.module.functions[callee_id.idx()].name
)?;
for dc in dynamic_constants {
self.codegen_dynamic_constant(*dc, block)?;
write!(block, ", ")?;
}
for arg in args {
if self.module.types[self.typing[arg.idx()].idx()].is_primitive() {
write!(block, "{}, ", self.get_value(*arg))?;
} else {
write!(block, "unsafe {{ {}.__clone() }}, ", self.get_value(*arg))?;
}
}
write!(block, ").await;\n")?;
}
}
}
Node::Read {
collect,
ref indices,
} => {
let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
let collect_ty = self.typing[collect.idx()];
let out_size = self.codegen_type_size(self.typing[id.idx()]);
let offset = self.codegen_index_math(collect_ty, indices)?;
write!(
block,
" let mut read_offset_obj = unsafe {{ {}.__clone() }};\n",
self.get_value(collect)
)?;
write!(
block,
" unsafe {{ read_offset_obj.__offset({}, {}) }};\n",
offset, out_size,
)?;
if self.module.types[self.typing[id.idx()].idx()].is_primitive() {
write!(
block,
" {} = unsafe {{ *(read_offset_obj.__cpu_ptr() as *const _) }};\n",
self.get_value(id)
)?;
write!(
block,
" unsafe {{ read_offset_obj.__forget() }};\n",
)?;
} else {
write!(
block,
" {} = read_offset_obj;\n",
self.get_value(id)
)?;
}
}
Node::Write {
collect,
data,
ref indices,
} => {
let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
let collect_ty = self.typing[collect.idx()];
let data_size = self.codegen_type_size(self.typing[data.idx()]);
let offset = self.codegen_index_math(collect_ty, indices)?;
write!(
block,
" let mut write_offset_obj = unsafe {{ {}.__clone() }};\n",
self.get_value(collect)
)?;
write!(block, " let write_offset_ptr = unsafe {{ write_offset_obj.__cpu_ptr_mut().byte_add({}) }};\n", offset)?;
if self.module.types[self.typing[data.idx()].idx()].is_primitive() {
write!(
block,
" unsafe {{ *(write_offset_ptr as *mut _) = {} }};\n",
self.get_value(data)
)?;
} else {
write!(
block,
" unsafe {{ ::core::ptr::copy_nonoverlapping({}.__cpu_ptr(), write_offset_ptr as *mut _, {} as usize) }};\n",
self.get_value(data),
data_size,
)?;
}
write!(
block,
" {} = write_offset_obj;\n",
self.get_value(id),
)?;
}
_ => panic!(
"PANIC: Can't lower {:?} in {}.",
func.nodes[id.idx()],
func.name
),
}
Ok(())
}
/*
* Lower dynamic constant in Hercules IR into a Rust expression.
*/
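    /*
     * For example, a DynamicConstant::Mul whose operands are dynamic constant
     * parameter 1 and the literal 4 is printed (roughly) as the Rust
     * expression `(dc_p1*4)`.
     */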
fn codegen_dynamic_constant<W: Write>(
&self,
id: DynamicConstantID,
w: &mut W,
) -> Result<(), Error> {
match self.module.dynamic_constants[id.idx()] {
DynamicConstant::Constant(val) => write!(w, "{}", val)?,
DynamicConstant::Parameter(idx) => write!(w, "dc_p{}", idx)?,
DynamicConstant::Add(left, right) => {
write!(w, "(")?;
self.codegen_dynamic_constant(left, w)?;
write!(w, "+")?;
self.codegen_dynamic_constant(right, w)?;
write!(w, ")")?;
}
DynamicConstant::Sub(left, right) => {
write!(w, "(")?;
self.codegen_dynamic_constant(left, w)?;
write!(w, "-")?;
self.codegen_dynamic_constant(right, w)?;
write!(w, ")")?;
}
DynamicConstant::Mul(left, right) => {
write!(w, "(")?;
self.codegen_dynamic_constant(left, w)?;
write!(w, "*")?;
self.codegen_dynamic_constant(right, w)?;
write!(w, ")")?;
}
DynamicConstant::Div(left, right) => {
write!(w, "(")?;
self.codegen_dynamic_constant(left, w)?;
write!(w, "/")?;
self.codegen_dynamic_constant(right, w)?;
write!(w, ")")?;
}
DynamicConstant::Rem(left, right) => {
write!(w, "(")?;
self.codegen_dynamic_constant(left, w)?;
write!(w, "%")?;
self.codegen_dynamic_constant(right, w)?;
write!(w, ")")?;
}
DynamicConstant::Min(left, right) => {
write!(w, "::core::cmp::min(")?;
self.codegen_dynamic_constant(left, w)?;
write!(w, ",")?;
self.codegen_dynamic_constant(right, w)?;
write!(w, ")")?;
}
DynamicConstant::Max(left, right) => {
write!(w, "::core::cmp::max(")?;
self.codegen_dynamic_constant(left, w)?;
write!(w, ",")?;
self.codegen_dynamic_constant(right, w)?;
write!(w, ")")?;
}
}
Ok(())
}
/*
     * Emit logic to index into a collection.
*/
fn codegen_index_math(
&self,
mut collect_ty: TypeID,
indices: &[Index],
) -> Result<String, Error> {
let mut acc_offset = "0".to_string();
for index in indices {
match index {
Index::Field(idx) => {
let Type::Product(ref fields) = self.module.types[collect_ty.idx()] else {
panic!()
};
// Get the offset of the field at index `idx` by calculating
                    // the product's size up to field `idx`, then offsetting the
// base pointer by that amount.
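                    // For example (a sketch, assuming a C-like layout), for a
                    // product of (u8, u32) with Index::Field(1), the size of
                    // the leading u8 is rounded up to the u32 alignment of 4,
                    // so the emitted expression evaluates to an offset of 4.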
for field in &fields[..*idx] {
let field_align = get_type_alignment(&self.module.types, *field);
let field = self.codegen_type_size(*field);
acc_offset = format!(
"((({} + {}) & !{}) + {})",
acc_offset,
field_align - 1,
field_align - 1,
field
);
}
let last_align = get_type_alignment(&self.module.types, fields[*idx]);
acc_offset = format!(
"(({} + {}) & !{})",
acc_offset,
last_align - 1,
last_align - 1
);
collect_ty = fields[*idx];
}
Index::Variant(idx) => {
// The tag of a summation is at the end of the summation, so
// the variant pointer is just the base pointer. Do nothing.
let Type::Summation(ref variants) = self.module.types[collect_ty.idx()] else {
panic!()
};
collect_ty = variants[*idx];
}
Index::Position(ref pos) => {
let Type::Array(elem, ref dims) = self.module.types[collect_ty.idx()] else {
panic!()
};
// The offset of the position into an array is:
//
// ((0 * s1 + p1) * s2 + p2) * s3 + p3 ...
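                    // For example (a sketch with hypothetical node and dynamic
                    // constant numbering), indexing a 2D array of f32 with
                    // positions node_5 and node_6 over bounds dc_p0 x dc_p1
                    // emits roughly:
                    //
                    //     (((0 * dc_p0 + node_5) * dc_p1 + node_6) * 4)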
let elem_size = self.codegen_type_size(elem);
for (p, s) in zip(pos, dims) {
let p = self.get_value(*p);
acc_offset = format!("{} * ", acc_offset);
self.codegen_dynamic_constant(*s, &mut acc_offset)?;
acc_offset = format!("({} + {})", acc_offset, p);
}
// Convert offset in # elements -> # bytes.
acc_offset = format!("({} * {})", acc_offset, elem_size);
collect_ty = elem;
}
}
}
Ok(acc_offset)
}
/*
* Lower the size of a type into a Rust expression.
*/
fn codegen_type_size(&self, ty: TypeID) -> String {
match self.module.types[ty.idx()] {
Type::Control => panic!(),
Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 => "1".to_string(),
Type::Integer16 | Type::UnsignedInteger16 => "2".to_string(),
Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => "4".to_string(),
Type::Integer64 | Type::UnsignedInteger64 | Type::Float64 => "8".to_string(),
Type::Product(ref fields) => {
let fields_align = fields
.into_iter()
.map(|id| get_type_alignment(&self.module.types, *id));
let fields: Vec<String> = fields
.into_iter()
.map(|id| self.codegen_type_size(*id))
.collect();
                // Emit Rust code to round up to the alignment of the next field,
// and then add the size of that field. At the end, round up to
// the alignment of the whole struct.
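                // For example (a sketch), for a product of (u8, u32) the
                // emitted expression rounds 1 up to the u32 alignment of 4,
                // adds 4, and rounds the total up to the struct alignment of
                // 4, evaluating to a size of 8.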
let mut acc_size = "0".to_string();
for (field_align, field) in zip(fields_align, fields) {
acc_size = format!(
"(({} + {}) & !{})",
acc_size,
field_align - 1,
field_align - 1
);
acc_size = format!("({} + {})", acc_size, field);
}
let total_align = get_type_alignment(&self.module.types, ty);
format!(
"(({} + {}) & !{})",
acc_size,
total_align - 1,
total_align - 1
)
}
Type::Summation(ref variants) => {
let variants = variants.into_iter().map(|id| self.codegen_type_size(*id));
                // The size of a summation is the size of the largest variant,
// plus 1 byte and alignment for the discriminant.
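                // For example (a sketch), a summation of (u32, f64) has 8
                // bytes of payload (the larger variant), plus 1 byte for the
                // discriminant, rounded up to an 8-byte alignment, for a total
                // size of 16.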
let mut acc_size = "0".to_string();
for variant in variants {
acc_size = format!("::core::cmp::max({}, {})", acc_size, variant);
}
                // No alignment is necessary for the 1-byte discriminant.
let total_align = get_type_alignment(&self.module.types, ty);
format!(
"(({} + 1 + {}) & !{})",
acc_size,
total_align - 1,
total_align - 1
)
}
Type::Array(elem, ref bounds) => {
                // The size of an array is the size of the element multiplied by
// the dynamic constant bounds.
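                // For example (a sketch with hypothetical dynamic constants),
                // an array of f32 with bounds dc_p0 x dc_p1 emits roughly:
                //
                //     (4 * dc_p0 * dc_p1)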
let mut acc_size = self.codegen_type_size(elem);
for dc in bounds {
acc_size = format!("{} * ", acc_size);
self.codegen_dynamic_constant(*dc, &mut acc_size).unwrap();
}
format!("({})", acc_size)
}
}
}
fn get_func(&self) -> &Function {
&self.module.functions[self.func_id.idx()]
}
fn get_value(&self, id: NodeID) -> String {
format!("node_{}", id.idx())
}
fn get_type(&self, id: TypeID) -> &'static str {
convert_type(&self.module.types[id.idx()])
}
fn device_get_type(&self, id: TypeID) -> &'static str {
device_convert_type(&self.module.types[id.idx()])
}
}
fn convert_type(ty: &Type) -> &'static str {
match ty {
Type::Boolean => "bool",
Type::Integer8 => "i8",
Type::Integer16 => "i16",
Type::Integer32 => "i32",
Type::Integer64 => "i64",
Type::UnsignedInteger8 => "u8",
Type::UnsignedInteger16 => "u16",
Type::UnsignedInteger32 => "u32",
Type::UnsignedInteger64 => "u64",
Type::Float32 => "f32",
Type::Float64 => "f64",
Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => {
"::hercules_rt::HerculesBox<'a>"
}
_ => panic!(),
}
}
fn device_convert_type(ty: &Type) -> &'static str {
match ty {
Type::Boolean => "bool",
Type::Integer8 => "i8",
Type::Integer16 => "i16",
Type::Integer32 => "i32",
Type::Integer64 => "i64",
Type::UnsignedInteger8 => "u8",
Type::UnsignedInteger16 => "u16",
Type::UnsignedInteger32 => "u32",
Type::UnsignedInteger64 => "u64",
Type::Float32 => "f32",
Type::Float64 => "f64",
Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => "*mut u8",
_ => panic!(),
}
}