use std::collections::BTreeMap;
use std::fmt::{Error, Write};
use std::iter::zip;

use hercules_ir::*;

use crate::*;

/*
 * Entry Hercules functions are lowered to async Rust code to achieve easy task
 * level parallelism. This Rust is generated textually, and is included via a
 * procedural macro in the user's Rust code.
 */
pub fn rt_codegen<W: Write>(
    func_id: FunctionID,
    module: &Module,
    typing: &Vec<TypeID>,
    control_subgraph: &Subgraph,
    bbs: &BasicBlocks,
    collection_objects: &CollectionObjects,
    callgraph: &CallGraph,
    devices: &Vec<Device>,
    w: &mut W,
) -> Result<(), Error> {
    // Bundle all per-function analysis results into one context so the
    // individual codegen methods don't each take nine parameters.
    let ctx = RTContext {
        func_id,
        module,
        typing,
        control_subgraph,
        bbs,
        collection_objects,
        callgraph,
        devices,
    };
    ctx.codegen_function(w)
}

// Immutable context threaded through all codegen methods while lowering a
// single function: the function being lowered, its module, and the analysis
// results (typing, control subgraph, basic blocks, collection objects,
// callgraph, device placement) computed for it.
struct RTContext<'a> {
    func_id: FunctionID,
    module: &'a Module,
    typing: &'a Vec<TypeID>,
    control_subgraph: &'a Subgraph,
    bbs: &'a BasicBlocks,
    collection_objects: &'a CollectionObjects,
    callgraph: &'a CallGraph,
    devices: &'a Vec<Device>,
}

impl<'a> RTContext<'a> {
    // Emit the whole async Rust function for `self.func_id`: signature,
    // collection-constant allocation, extern declarations for device callees,
    // one local per data node, the control-token interpreter loop, and an
    // epilogue that releases collection objects before returning.
    fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> {
        let func = &self.get_func();

        // Dump the function signature.
        write!(
            w,
            "#[allow(unused_variables,unused_mut,unused_parens,unused_unsafe)]\nasync fn {}<'a>(",
            func.name
        )?;
        let mut first_param = true;
        // The first set of parameters are dynamic constants.
        for idx in 0..func.num_dynamic_constants {
            if first_param {
                first_param = false;
            } else {
                write!(w, ", ")?;
            }
            write!(w, "dc_p{}: u64", idx)?;
        }
        // The second set of parameters are normal parameters.
        for idx in 0..func.param_types.len() {
            if first_param {
                first_param = false;
            } else {
                write!(w, ", ")?;
            }
            // Non-primitive parameters are HerculesBox values that the body
            // mutates (e.g. `__forget` in the epilogue below), so bind `mut`.
            if !self.module.types[func.param_types[idx].idx()].is_primitive() {
                write!(w, "mut ")?;
            }
            write!(w, "p{}: {}", idx, self.get_type(func.param_types[idx]))?;
        }
        write!(w, ") -> {} {{\n", self.get_type(func.return_type))?;

        // Allocate collection constants.
        // Each collection object originating from a constant gets a
        // zero-initialized HerculesBox sized by the constant's type.
        for object in self.collection_objects[&self.func_id].iter_objects() {
            if let CollectionObjectOrigin::Constant(id) =
                self.collection_objects[&self.func_id].origin(object)
            {
                let size = self.codegen_type_size(self.typing[id.idx()]);
                write!(
                    w,
                    " let mut obj{}: ::hercules_rt::HerculesBox = unsafe {{ ::hercules_rt::HerculesBox::__zeros({}) }};\n",
                    object.idx(),
                    size
                )?
            }
        }

        // Dump signatures for called device functions.
        // Only non-async-Rust callees (CPU/CUDA device functions) need an
        // extern "C" declaration; async Rust callees are called directly.
        write!(w, " extern \"C\" {{\n")?;
        for callee in self.callgraph.get_callees(self.func_id) {
            if self.devices[callee.idx()] == Device::AsyncRust {
                continue;
            }
            let callee = &self.module.functions[callee.idx()];
            write!(w, " fn {}(", callee.name)?;
            let mut first_param = true;
            for idx in 0..callee.num_dynamic_constants {
                if first_param {
                    first_param = false;
                } else {
                    write!(w, ", ")?;
                }
                write!(w, "dc{}: u64", idx)?;
            }
            for (idx, ty) in callee.param_types.iter().enumerate() {
                if first_param {
                    first_param = false;
                } else {
                    write!(w, ", ")?;
                }
                write!(w, "p{}: {}", idx, self.device_get_type(*ty))?;
            }
            write!(w, ") -> {};\n", self.device_get_type(callee.return_type))?;
        }
        write!(w, " }}\n")?;

        // Declare intermediary variables for every value.
        // Integers start at 0, floats at 0.0, and collections as a null
        // HerculesBox; the interpreter loop assigns real values later.
        for idx in 0..func.nodes.len() {
            if func.nodes[idx].is_control() {
                continue;
            }
            write!(
                w,
                " let mut node_{}: {} = {};\n",
                idx,
                self.get_type(self.typing[idx]),
                if self.module.types[self.typing[idx].idx()].is_integer() {
                    "0"
                } else if self.module.types[self.typing[idx].idx()].is_float() {
                    "0.0"
                } else {
                    "unsafe { ::hercules_rt::HerculesBox::__null() }"
                }
            )?;
        }

        // The core executor is a Rust loop. We literally run a "control token"
        // as described in the original sea of nodes paper through the basic
        // blocks to drive execution.
        write!(
            w,
            " let mut control_token: i8 = 0;\n let return_value = loop {{\n match control_token {{\n",
        )?;
        // One emitted-code buffer per control node; data and control codegen
        // below append statements into these, keyed by control NodeID.
        let mut blocks: BTreeMap<_, _> = (0..func.nodes.len())
            .filter(|idx| func.nodes[*idx].is_control())
            .map(|idx| (NodeID::new(idx), String::new()))
            .collect();

        // Emit data flow into basic blocks.
        // bbs.1 lists, per basic block, the data nodes scheduled into it, in
        // execution order.
        for block in self.bbs.1.iter() {
            for id in block {
                self.codegen_data_node(*id, &mut blocks)?;
            }
        }

        // Emit control flow into basic blocks.
        for id in (0..func.nodes.len()).map(NodeID::new) {
            if !func.nodes[id.idx()].is_control() {
                continue;
            }
            self.codegen_control_node(id, &mut blocks)?;
        }

        // Dump the emitted basic blocks.
        // Each control node becomes one match arm, selected by control_token.
        for (id, block) in blocks {
            write!(
                w,
                " {} => {{\n{} }}\n",
                id.idx(),
                block
            )?;
        }

        // Close the match and loop.
        write!(w, " _ => panic!()\n }}\n }};\n")?;

        // Emit the epilogue of the function.
        // NOTE(review): `__forget` presumably relinquishes ownership so the
        // HerculesBox destructor doesn't free memory still in use — confirm
        // against hercules_rt.
        write!(w, " unsafe {{\n")?;
        for idx in 0..func.param_types.len() {
            if !self.module.types[func.param_types[idx].idx()].is_primitive() {
                write!(w, " p{}.__forget();\n", idx)?;
            }
        }
        // If the function returns a collection, any constant-originated object
        // that is the returned object must not be freed here.
        if !self.module.types[func.return_type.idx()].is_primitive() {
            for object in self.collection_objects[&self.func_id].iter_objects() {
                if let CollectionObjectOrigin::Constant(_) =
                    self.collection_objects[&self.func_id].origin(object)
                {
                    write!(
                        w,
                        " if obj{}.__cmp_ids(&return_value) {{\n",
                        object.idx()
                    )?;
                    write!(w, " obj{}.__forget();\n", object.idx())?;
                    write!(w, " }}\n")?;
                }
            }
        }
        for idx in 0..func.nodes.len() {
            if !func.nodes[idx].is_control()
                && !self.module.types[self.typing[idx].idx()].is_primitive()
            {
                write!(w, " node_{}.__forget();\n", idx)?;
            }
        }
        write!(w, " }}\n")?;
        write!(w, " return_value\n")?;
        write!(w, "}}\n")?;
        Ok(())
    }

    /*
     * While control nodes in Hercules IR are predecessor-centric (each take a
     * control input that defines the predecessor relationship), the Rust loop
     * we generate is successor centric. This difference requires explicit
     * translation.
     */
    fn codegen_control_node(
        &self,
        id: NodeID,
        blocks: &mut BTreeMap<NodeID, String>,
    ) -> Result<(), Error> {
        let func = &self.get_func();
        match func.nodes[id.idx()] {
            // Start, region, and projection control nodes all have exactly one
            // successor and are otherwise simple.
Node::Start | Node::Region { preds: _ } | Node::Projection { control: _, selection: _, } => { let block = &mut blocks.get_mut(&id).unwrap(); let succ = self.control_subgraph.succs(id).next().unwrap(); write!(block, " control_token = {};\n", succ.idx())? } // If nodes have two successors - examine the projections to // determine which branch is which, and branch between them. Node::If { control: _, cond } => { let block = &mut blocks.get_mut(&id).unwrap(); let mut succs = self.control_subgraph.succs(id); let succ1 = succs.next().unwrap(); let succ2 = succs.next().unwrap(); let succ1_is_true = func.nodes[succ1.idx()].try_projection(1).is_some(); write!( block, " control_token = if {} {{ {} }} else {{ {} }};\n", self.get_value(cond), if succ1_is_true { succ1 } else { succ2 }.idx(), if succ1_is_true { succ2 } else { succ1 }.idx(), )? } Node::Return { control: _, data } => { let block = &mut blocks.get_mut(&id).unwrap(); if self.module.types[self.typing[data.idx()].idx()].is_primitive() { write!(block, " break {};\n", self.get_value(data))? } else { write!( block, " break unsafe {{ {}.__clone() }};\n", self.get_value(data) )? } } _ => panic!("PANIC: Can't lower {:?}.", func.nodes[id.idx()]), } Ok(()) } /* * Lower data nodes in Hercules IR into Rust statements. */ fn codegen_data_node( &self, id: NodeID, blocks: &mut BTreeMap<NodeID, String>, ) -> Result<(), Error> { let func = &self.get_func(); match func.nodes[id.idx()] { Node::Parameter { index } => { let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); if self.module.types[self.typing[id.idx()].idx()].is_primitive() { write!( block, " {} = p{};\n", self.get_value(id), index )? } else { write!( block, " {} = unsafe {{ p{}.__clone() }};\n", self.get_value(id), index )? 
} } Node::Constant { id: cons_id } => { let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); write!(block, " {} = ", self.get_value(id))?; match self.module.constants[cons_id.idx()] { Constant::Boolean(val) => write!(block, "{}bool", val)?, Constant::Integer8(val) => write!(block, "{}i8", val)?, Constant::Integer16(val) => write!(block, "{}i16", val)?, Constant::Integer32(val) => write!(block, "{}i32", val)?, Constant::Integer64(val) => write!(block, "{}i64", val)?, Constant::UnsignedInteger8(val) => write!(block, "{}u8", val)?, Constant::UnsignedInteger16(val) => write!(block, "{}u16", val)?, Constant::UnsignedInteger32(val) => write!(block, "{}u32", val)?, Constant::UnsignedInteger64(val) => write!(block, "{}u64", val)?, Constant::Float32(val) => write!(block, "{}f32", val)?, Constant::Float64(val) => write!(block, "{}f64", val)?, Constant::Product(_, _) | Constant::Summation(_, _, _) | Constant::Array(_) => { let objects = self.collection_objects[&self.func_id].objects(id); assert_eq!(objects.len(), 1); let object = objects[0]; write!(block, "unsafe {{ obj{}.__clone() }}", object.idx())? } } write!(block, ";\n")? } Node::Call { control: _, function: callee_id, ref dynamic_constants, ref args, } => { let device = self.devices[callee_id.idx()]; match device { // The device backends ensure that device functions have the // same C interface. Device::LLVM | Device::CUDA => { let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); let device = match device { Device::LLVM => "cpu", Device::CUDA => "cuda", _ => panic!(), }; // First, get the raw pointers to collections that the // device function takes as input. let callee_objs = &self.collection_objects[&callee_id]; for (idx, arg) in args.into_iter().enumerate() { if let Some(obj) = callee_objs.param_to_object(idx) { // Extract a raw pointer from the HerculesBox. 
if callee_objs.is_mutated(obj) { write!( block, " let arg_tmp{} = unsafe {{ {}.__{}_ptr_mut() }};\n", idx, self.get_value(*arg), device )?; } else { write!( block, " let arg_tmp{} = unsafe {{ {}.__{}_ptr() }};\n", idx, self.get_value(*arg), device )?; } } else { write!( block, " let arg_tmp{} = {};\n", idx, self.get_value(*arg) )?; } } // Emit the call. write!( block, " let call_tmp = unsafe {{ {}(", self.module.functions[callee_id.idx()].name )?; for dc in dynamic_constants { self.codegen_dynamic_constant(*dc, block)?; write!(block, ", ")?; } for idx in 0..args.len() { write!(block, "arg_tmp{}, ", idx)?; } write!(block, ") }};\n")?; // When a device function is called that returns a // collection object, that object must have come from // one of its parameters. Dynamically figure out which // one it came from, so that we can move it to the slot // of the output object. let caller_objects = self.collection_objects[&self.func_id].objects(id); if !caller_objects.is_empty() { for (idx, arg) in args.into_iter().enumerate() { if idx != 0 { write!(block, " else\n")?; } write!( block, " if call_tmp == arg_tmp{} {{\n", idx )?; write!( block, " {} = unsafe {{ {}.__clone() }};\n", self.get_value(id), self.get_value(*arg) )?; write!(block, " }}")?; } write!(block, " else {{\n")?; write!(block, " panic!(\"HERCULES PANIC: Pointer returned from device function doesn't match an argument pointer.\");\n")?; write!(block, " }}\n")?; } else { write!( block, " {} = call_tmp;\n", self.get_value(id) )?; } } Device::AsyncRust => { let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); write!( block, " {} = {}(", self.get_value(id), self.module.functions[callee_id.idx()].name )?; for dc in dynamic_constants { self.codegen_dynamic_constant(*dc, block)?; write!(block, ", ")?; } for arg in args { if self.module.types[self.typing[arg.idx()].idx()].is_primitive() { write!(block, "{}, ", self.get_value(*arg))?; } else { write!(block, "unsafe {{ {}.__clone() }}, ", 
self.get_value(*arg))?; } } write!(block, ").await;\n")?; } } } Node::Read { collect, ref indices, } => { let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); let collect_ty = self.typing[collect.idx()]; let out_size = self.codegen_type_size(self.typing[id.idx()]); let offset = self.codegen_index_math(collect_ty, indices)?; write!( block, " let mut read_offset_obj = unsafe {{ {}.__clone() }};\n", self.get_value(collect) )?; write!( block, " unsafe {{ read_offset_obj.__offset({}, {}) }};\n", offset, out_size, )?; if self.module.types[self.typing[id.idx()].idx()].is_primitive() { write!( block, " {} = unsafe {{ *(read_offset_obj.__cpu_ptr() as *const _) }};\n", self.get_value(id) )?; write!( block, " unsafe {{ read_offset_obj.__forget() }};\n", )?; } else { write!( block, " {} = read_offset_obj;\n", self.get_value(id) )?; } } Node::Write { collect, data, ref indices, } => { let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); let collect_ty = self.typing[collect.idx()]; let data_size = self.codegen_type_size(self.typing[data.idx()]); let offset = self.codegen_index_math(collect_ty, indices)?; write!( block, " let mut write_offset_obj = unsafe {{ {}.__clone() }};\n", self.get_value(collect) )?; write!(block, " let write_offset_ptr = unsafe {{ write_offset_obj.__cpu_ptr_mut().byte_add({}) }};\n", offset)?; if self.module.types[self.typing[data.idx()].idx()].is_primitive() { write!( block, " unsafe {{ *(write_offset_ptr as *mut _) = {} }};\n", self.get_value(data) )?; } else { write!( block, " unsafe {{ ::core::ptr::copy_nonoverlapping({}.__cpu_ptr(), write_offset_ptr as *mut _, {} as usize) }};\n", self.get_value(data), data_size, )?; } write!( block, " {} = write_offset_obj;\n", self.get_value(id), )?; } _ => panic!( "PANIC: Can't lower {:?} in {}.", func.nodes[id.idx()], func.name ), } Ok(()) } /* * Lower dynamic constant in Hercules IR into a Rust expression. 
*/ fn codegen_dynamic_constant<W: Write>( &self, id: DynamicConstantID, w: &mut W, ) -> Result<(), Error> { match self.module.dynamic_constants[id.idx()] { DynamicConstant::Constant(val) => write!(w, "{}", val)?, DynamicConstant::Parameter(idx) => write!(w, "dc_p{}", idx)?, DynamicConstant::Add(left, right) => { write!(w, "(")?; self.codegen_dynamic_constant(left, w)?; write!(w, "+")?; self.codegen_dynamic_constant(right, w)?; write!(w, ")")?; } DynamicConstant::Sub(left, right) => { write!(w, "(")?; self.codegen_dynamic_constant(left, w)?; write!(w, "-")?; self.codegen_dynamic_constant(right, w)?; write!(w, ")")?; } DynamicConstant::Mul(left, right) => { write!(w, "(")?; self.codegen_dynamic_constant(left, w)?; write!(w, "*")?; self.codegen_dynamic_constant(right, w)?; write!(w, ")")?; } DynamicConstant::Div(left, right) => { write!(w, "(")?; self.codegen_dynamic_constant(left, w)?; write!(w, "/")?; self.codegen_dynamic_constant(right, w)?; write!(w, ")")?; } DynamicConstant::Rem(left, right) => { write!(w, "(")?; self.codegen_dynamic_constant(left, w)?; write!(w, "%")?; self.codegen_dynamic_constant(right, w)?; write!(w, ")")?; } DynamicConstant::Min(left, right) => { write!(w, "::core::cmp::min(")?; self.codegen_dynamic_constant(left, w)?; write!(w, ",")?; self.codegen_dynamic_constant(right, w)?; write!(w, ")")?; } DynamicConstant::Max(left, right) => { write!(w, "::core::cmp::max(")?; self.codegen_dynamic_constant(left, w)?; write!(w, ",")?; self.codegen_dynamic_constant(right, w)?; write!(w, ")")?; } } Ok(()) } /* * Emit logic to index into an collection. 
     */
    // Returns a Rust expression (as a string) computing the byte offset of the
    // indexed element inside the collection of type `collect_ty`. Each index
    // step also narrows `collect_ty` to the type being indexed into, so
    // successive indices compose.
    fn codegen_index_math(
        &self,
        mut collect_ty: TypeID,
        indices: &[Index],
    ) -> Result<String, Error> {
        let mut acc_offset = "0".to_string();
        for index in indices {
            match index {
                Index::Field(idx) => {
                    let Type::Product(ref fields) = self.module.types[collect_ty.idx()] else {
                        panic!()
                    };

                    // Get the offset of the field at index `idx` by calculating
                    // the product's size up to field `idx`, then offsetting the
                    // base pointer by that amount.
                    // `(x + (a-1)) & !(a-1)` rounds x up to a multiple of the
                    // (power-of-two) alignment a.
                    for field in &fields[..*idx] {
                        let field_align = get_type_alignment(&self.module.types, *field);
                        let field = self.codegen_type_size(*field);
                        acc_offset = format!(
                            "((({} + {}) & !{}) + {})",
                            acc_offset,
                            field_align - 1,
                            field_align - 1,
                            field
                        );
                    }
                    // Finally, round up to the alignment of the target field
                    // itself.
                    let last_align = get_type_alignment(&self.module.types, fields[*idx]);
                    acc_offset = format!(
                        "(({} + {}) & !{})",
                        acc_offset,
                        last_align - 1,
                        last_align - 1
                    );
                    collect_ty = fields[*idx];
                }
                Index::Variant(idx) => {
                    // The tag of a summation is at the end of the summation, so
                    // the variant pointer is just the base pointer. Do nothing.
                    let Type::Summation(ref variants) = self.module.types[collect_ty.idx()] else {
                        panic!()
                    };
                    collect_ty = variants[*idx];
                }
                Index::Position(ref pos) => {
                    let Type::Array(elem, ref dims) = self.module.types[collect_ty.idx()] else {
                        panic!()
                    };

                    // The offset of the position into an array is:
                    //
                    // ((0 * s1 + p1) * s2 + p2) * s3 + p3 ...
                    //
                    // i.e. Horner's scheme over the row-major dimensions, where
                    // the s's are dynamic constant bounds and the p's are the
                    // runtime position values.
                    let elem_size = self.codegen_type_size(elem);
                    for (p, s) in zip(pos, dims) {
                        let p = self.get_value(*p);
                        acc_offset = format!("{} * ", acc_offset);
                        self.codegen_dynamic_constant(*s, &mut acc_offset)?;
                        acc_offset = format!("({} + {})", acc_offset, p);
                    }

                    // Convert offset in # elements -> # bytes.
                    acc_offset = format!("({} * {})", acc_offset, elem_size);
                    collect_ty = elem;
                }
            }
        }
        Ok(acc_offset)
    }

    /*
     * Lower the size of a type into a Rust expression.
*/ fn codegen_type_size(&self, ty: TypeID) -> String { match self.module.types[ty.idx()] { Type::Control => panic!(), Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 => "1".to_string(), Type::Integer16 | Type::UnsignedInteger16 => "2".to_string(), Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => "4".to_string(), Type::Integer64 | Type::UnsignedInteger64 | Type::Float64 => "8".to_string(), Type::Product(ref fields) => { let fields_align = fields .into_iter() .map(|id| get_type_alignment(&self.module.types, *id)); let fields: Vec<String> = fields .into_iter() .map(|id| self.codegen_type_size(*id)) .collect(); // Emit LLVM IR to round up to the alignment of the next field, // and then add the size of that field. At the end, round up to // the alignment of the whole struct. let mut acc_size = "0".to_string(); for (field_align, field) in zip(fields_align, fields) { acc_size = format!( "(({} + {}) & !{})", acc_size, field_align - 1, field_align - 1 ); acc_size = format!("({} + {})", acc_size, field); } let total_align = get_type_alignment(&self.module.types, ty); format!( "(({} + {}) & !{})", acc_size, total_align - 1, total_align - 1 ) } Type::Summation(ref variants) => { let variants = variants.into_iter().map(|id| self.codegen_type_size(*id)); // The size of a summation is the size of the largest field, // plus 1 byte and alignment for the discriminant. let mut acc_size = "0".to_string(); for variant in variants { acc_size = format!("::core::cmp::max({}, {})", acc_size, variant); } // No alignment is necessary for the 1 byte discriminant. let total_align = get_type_alignment(&self.module.types, ty); format!( "(({} + 1 + {}) & !{})", acc_size, total_align - 1, total_align - 1 ) } Type::Array(elem, ref bounds) => { // The size of an array is the size of the element multipled by // the dynamic constant bounds. 
let mut acc_size = self.codegen_type_size(elem); for dc in bounds { acc_size = format!("{} * ", acc_size); self.codegen_dynamic_constant(*dc, &mut acc_size).unwrap(); } format!("({})", acc_size) } } } fn get_func(&self) -> &Function { &self.module.functions[self.func_id.idx()] } fn get_value(&self, id: NodeID) -> String { format!("node_{}", id.idx()) } fn get_type(&self, id: TypeID) -> &'static str { convert_type(&self.module.types[id.idx()]) } fn device_get_type(&self, id: TypeID) -> &'static str { device_convert_type(&self.module.types[id.idx()]) } } fn convert_type(ty: &Type) -> &'static str { match ty { Type::Boolean => "bool", Type::Integer8 => "i8", Type::Integer16 => "i16", Type::Integer32 => "i32", Type::Integer64 => "i64", Type::UnsignedInteger8 => "u8", Type::UnsignedInteger16 => "u16", Type::UnsignedInteger32 => "u32", Type::UnsignedInteger64 => "u64", Type::Float32 => "f32", Type::Float64 => "f64", Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => { "::hercules_rt::HerculesBox<'a>" } _ => panic!(), } } fn device_convert_type(ty: &Type) -> &'static str { match ty { Type::Boolean => "bool", Type::Integer8 => "i8", Type::Integer16 => "i16", Type::Integer32 => "i32", Type::Integer64 => "i64", Type::UnsignedInteger8 => "u8", Type::UnsignedInteger16 => "u16", Type::UnsignedInteger32 => "u32", Type::UnsignedInteger64 => "u64", Type::Float32 => "f32", Type::Float64 => "f64", Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => "*mut u8", _ => panic!(), } }