// lib.rs
#![feature(iter_intersperse)]
extern crate anyhow;
extern crate hercules_cg;
extern crate hercules_ir;
extern crate postcard;
extern crate proc_macro;
use std::collections::{HashMap, HashSet};
use std::ffi::OsStr;
use std::fmt::Write;
use std::fs::File;
use std::io::prelude::*;
use std::path::Path;
use proc_macro::*;
use self::hercules_cg::*;
use self::hercules_ir::{DynamicConstant, DynamicConstantID};
/*
 * Convert schedule IR types to the Rust types generated in the interface.
 * Scalars lower to the matching Rust primitives, products lower to the
 * mangled struct names emitted for them, and array references lower to raw
 * mutable pointers to the element type.
 */
fn generate_type_string(ty: &SType) -> String {
    match ty {
        SType::Boolean => String::from("bool"),
        SType::Integer8 => String::from("i8"),
        SType::Integer16 => String::from("i16"),
        SType::Integer32 => String::from("i32"),
        SType::Integer64 => String::from("i64"),
        SType::UnsignedInteger8 => String::from("u8"),
        SType::UnsignedInteger16 => String::from("u16"),
        SType::UnsignedInteger32 => String::from("u32"),
        SType::UnsignedInteger64 => String::from("u64"),
        SType::Float32 => String::from("f32"),
        SType::Float64 => String::from("f64"),
        SType::Product(fields) => {
            // Mangled struct name: __Prod_<field-name>_<field-name>..._
            let mut name = String::from("__Prod");
            for field in fields {
                name.push('_');
                name.push_str(&generate_type_name(field));
            }
            name.push('_');
            name
        }
        SType::ArrayRef(elem) => format!("*mut {}", generate_type_string(elem)),
    }
}
/*
 * Convert a schedule IR type to an identifier-safe name usable inside mangled
 * type names (no `*`, spaces, or punctuation). For scalars this coincides
 * with the Rust type string itself.
 */
fn generate_type_name(ty: &SType) -> String {
    match ty {
        SType::Boolean
        | SType::Integer8
        | SType::Integer16
        | SType::Integer32
        | SType::Integer64
        | SType::UnsignedInteger8
        | SType::UnsignedInteger16
        | SType::UnsignedInteger32
        | SType::UnsignedInteger64
        | SType::Float32
        | SType::Float64 => generate_type_string(ty),
        SType::Product(fields) => {
            // Same mangling scheme as the struct names emitted by codegen.
            let mut name = String::from("__Prod");
            for field in fields {
                name.push('_');
                name.push_str(&generate_type_name(field));
            }
            name.push('_');
            name
        }
        SType::ArrayRef(elem) => format!("ArrayRef_{}", generate_type_name(elem)),
    }
}
/*
 * Emit a Rust expression that evaluates the dynamic constant `dc` at runtime.
 * Constants print literally, parameters reference the `dc_{idx}` function
 * arguments, and arithmetic nodes recurse into fully parenthesized binary
 * expressions.
 */
fn compute_dynamic_constant<W: Write>(
    dc: DynamicConstantID,
    manifest: &Manifest,
    rust_code: &mut W,
) -> Result<(), anyhow::Error> {
    // Emit "(<left> <op> <right>)" for a binary dynamic constant node.
    fn binary<W: Write>(
        left: DynamicConstantID,
        op: &str,
        right: DynamicConstantID,
        manifest: &Manifest,
        rust_code: &mut W,
    ) -> Result<(), anyhow::Error> {
        write!(rust_code, "(")?;
        compute_dynamic_constant(left, manifest, rust_code)?;
        write!(rust_code, " {} ", op)?;
        compute_dynamic_constant(right, manifest, rust_code)?;
        write!(rust_code, ")")?;
        Ok(())
    }
    match manifest.dynamic_constants[dc.idx()] {
        DynamicConstant::Constant(cons) => write!(rust_code, "{}", cons)?,
        DynamicConstant::Parameter(idx) => write!(rust_code, "dc_{}", idx)?,
        DynamicConstant::Add(left, right) => binary(left, "+", right, manifest, rust_code)?,
        DynamicConstant::Sub(left, right) => binary(left, "-", right, manifest, rust_code)?,
        DynamicConstant::Mul(left, right) => binary(left, "*", right, manifest, rust_code)?,
        DynamicConstant::Div(left, right) => binary(left, "/", right, manifest, rust_code)?,
        DynamicConstant::Rem(left, right) => binary(left, "%", right, manifest, rust_code)?,
    }
    Ok(())
}
/*
* Generate async Rust code orchestrating partition execution.
*/
fn codegen(manifests: &HashMap<String, Manifest>) -> Result<String, anyhow::Error> {
    // Write to a String containing all of the Rust code.
    let mut rust_code = "".to_string();
    // Rust doesn't allow you to send pointers between threads. In order to send
    // pointers between threads, we need to wrap them in a struct that unsafely
    // implements Send and Sync. This passes the responsibility of
    // synchronization onto us, which we do by being careful with how we lower
    // parallel code. Make this type generic so that we actually wrap all
    // arguments in it for ease of macro codegen.
    write!(
        rust_code,
        "#[derive(Clone, Copy, Debug)]\nstruct SendSyncWrapper<T: Copy>(T);\nunsafe impl<T: Copy> Send for SendSyncWrapper<T> {{}}\nunsafe impl<T: Copy> Sync for SendSyncWrapper<T> {{}}\n"
    )?;
    // Emit the product types used in this module. We can't just emit product
    // types, since we need #[repr(C)] to interact with LLVM.
    let all_stypes = manifests
        .into_iter()
        .map(|(_, manifest)| manifest.all_visible_types())
        .flatten()
        .collect::<HashSet<SType>>();
    for stype in all_stypes.iter() {
        if let Some(fields) = stype.try_product() {
            write!(
                rust_code,
                "#[derive(Clone, Copy, Debug)]\n#[repr(C)]\nstruct {}({});\n",
                generate_type_string(stype),
                fields
                    .iter()
                    .map(|field| generate_type_string(field))
                    .intersperse(", ".to_string())
                    .fold("".to_string(), |acc, token| acc + &token)
            )?;
        }
    }
    // Emit the async Rust functions implementing each Hercules function.
    for (function_name, manifest) in manifests.into_iter() {
        // Emit the function signature. Each parameter is named after its kind
        // (param_N / dc_N / array_N) so the generated body can reference them
        // unambiguously.
        write!(rust_code, "async unsafe fn {}(", function_name)?;
        for (param_ty, param_kind) in manifest.param_types.iter() {
            match param_kind {
                ParameterKind::HerculesParameter(idx) => write!(rust_code, "param_{}", idx)?,
                ParameterKind::DataInput(_) => panic!(
                    "PANIC: Parameter kind for Hercules function parameter cannot be DataInput."
                ),
                ParameterKind::DynamicConstant(idx) => write!(rust_code, "dc_{}", idx)?,
                ParameterKind::ArrayConstant(array_id) => {
                    write!(rust_code, "array_{}", array_id.idx())?
                }
            }
            write!(rust_code, ": {}, ", generate_type_string(param_ty))?
        }
        write!(
            rust_code,
            ") -> {} {{\n",
            generate_type_string(&manifest.return_type)
        )?;
        // Compute the signature for each partition function and emit the extern
        // function signatures.
        write!(rust_code, " extern \"C\" {{\n")?;
        for partition in manifest.partitions.iter() {
            write!(rust_code, " fn {}(", partition.name)?;
            // Add parameters for SFunction signature.
            for (param_stype, kind) in partition.parameters.iter() {
                match kind {
                    ParameterKind::HerculesParameter(idx) => write!(rust_code, "param_{}: ", idx)?,
                    ParameterKind::DataInput(id) => write!(rust_code, "data_{}: ", id.idx())?,
                    ParameterKind::DynamicConstant(idx) => write!(rust_code, "dc_{}: ", idx)?,
                    ParameterKind::ArrayConstant(id) => write!(rust_code, "array_{}: ", id.idx())?,
                }
                write!(rust_code, "{}, ", generate_type_string(param_stype))?;
            }
            // Add parameters for device specific lowering details. Parallel
            // CPU partitions take an extra (low, len) u64 pair per launch
            // dimension describing the tile this invocation covers.
            if let DeviceManifest::CPU { parallel_launch } = &partition.device {
                for parallel_launch_dim in 0..parallel_launch.len() {
                    write!(
                        rust_code,
                        "parallel_launch_low_{}: u64, parallel_launch_len_{}: u64, ",
                        parallel_launch_dim, parallel_launch_dim
                    )?;
                }
            }
            // Add the return product of the SFunction signature. A single
            // return value is passed through directly; multiple returns are
            // packed into a product.
            let return_stype = if partition.returns.len() == 1 {
                partition.returns[0].0.clone()
            } else {
                SType::Product(
                    partition
                        .returns
                        .iter()
                        .map(|(return_stype, _)| return_stype.clone())
                        .collect(),
                )
            };
            write!(rust_code, ") -> {};\n", generate_type_string(&return_stype),)?;
        }
        write!(rust_code, " }}\n")?;
        // Declare all of the intermediary data input / output variables. They
        // are declared as MaybeUninit, since they get assigned after running a
        // partition. MaybeUninits should always be defined before assume_init()
        // is called on them, assuming a valid partitioning.
        let mut data_inputs = HashSet::new();
        let mut data_outputs = HashSet::new();
        for partition in manifest.partitions.iter() {
            data_inputs.extend(partition.data_inputs());
            data_outputs.extend(partition.data_outputs());
        }
        // Every value consumed across partitions must be produced by some
        // partition, and vice versa - otherwise the partitioning is invalid.
        assert_eq!(data_inputs, data_outputs);
        for (node, stype) in data_inputs {
            write!(rust_code, " let mut node_{}: ::core::mem::MaybeUninit<{}> = ::core::mem::MaybeUninit::uninit();\n", node.idx(), generate_type_string(stype))?;
        }
        // The core executor is a Rust loop. We literally run a "control token"
        // as described in the original sea of nodes paper through the
        // partitions to drive execution.
        write!(
            rust_code,
            " let mut control_token: i8 = 0;\n loop {{\n",
        )?;
        // Match on the control token position to determine which partition to
        // execute.
        write!(rust_code, " match control_token {{\n")?;
        // Emit the match arm per partition.
        for (idx, partition) in manifest.partitions.iter().enumerate() {
            // Open the arm.
            write!(rust_code, " {} => {{\n", idx)?;
            match partition.device {
                DeviceManifest::CPU {
                    ref parallel_launch,
                } => {
                    // Wrap every argument in SendSyncWrapper so it can be
                    // captured by the spawned tasks below.
                    for (idx, (_, kind)) in partition.parameters.iter().enumerate() {
                        write!(
                            rust_code,
                            " let local_param_{} = SendSyncWrapper(",
                            idx
                        )?;
                        match kind {
                            ParameterKind::HerculesParameter(idx) => {
                                write!(rust_code, "param_{}", idx)?
                            }
                            ParameterKind::DataInput(id) => {
                                write!(rust_code, "node_{}.assume_init()", id.idx())?
                            }
                            ParameterKind::DynamicConstant(idx) => write!(rust_code, "dc_{}", idx)?,
                            ParameterKind::ArrayConstant(id) => {
                                write!(rust_code, "array_{}", id.idx())?
                            }
                        }
                        write!(rust_code, ");\n")?;
                    }
                    if parallel_launch.is_empty() {
                        // Call the partition function.
                        write!(
                            rust_code,
                            " let output = {}(",
                            partition.name
                        )?;
                        for idx in 0..partition.parameters.len() {
                            write!(rust_code, "local_param_{}.0, ", idx)?;
                        }
                        write!(rust_code, ");\n")?;
                    } else {
                        // Compute the dynamic constant bounds.
                        for (dim, (_, dc)) in parallel_launch.into_iter().enumerate() {
                            write!(rust_code, " let bound_{} = ", dim)?;
                            compute_dynamic_constant(*dc, manifest, &mut rust_code)?;
                            write!(rust_code, ";\n let low_{} = 0;\n", dim)?;
                        }
                        // Simultaneously calculate the tiles lows and lens and
                        // spawn the tiles. Emit the launches unrolled.
                        // `tile` is a multi-dimensional counter over the tile
                        // grid, advanced odometer-style at the bottom of the
                        // loop.
                        let mut tile = vec![0; parallel_launch.len()];
                        let total_num_tiles = parallel_launch
                            .into_iter()
                            .fold(1, |acc, (num_tiles, _)| acc * num_tiles);
                        for tile_num in 0..total_num_tiles {
                            // Calculate the lows and lens for this tile. Note
                            // that the loop variable shadows the outer `tile`
                            // vector - each iteration sees one dimension's
                            // current tile index.
                            for (dim, tile) in tile.iter().enumerate() {
                                let num_tiles = parallel_launch[dim].0;
                                write!(
                                    rust_code,
                                    " let len_{} = bound_{} / {} + ({} < bound_{} % {}) as u64;\n",
                                    dim, dim, num_tiles, tile, dim, num_tiles
                                )?;
                            }
                            // Spawn the tile. We need to explicitly copy the
                            // SendSyncWrappers, or else the path expression for
                            // the parameters get interpreted as what needs to
                            // be moved, when we want the wrapper itself to be
                            // what gets moved. Ugh.
                            write!(
                                rust_code,
                                " let tile_{} = async_std::task::spawn(async move {{ ",
                                tile_num,
                            )?;
                            for idx in 0..partition.parameters.len() {
                                write!(
                                    rust_code,
                                    "let local_param_{} = local_param_{}; ",
                                    idx, idx
                                )?;
                            }
                            write!(rust_code, "SendSyncWrapper({}(", partition.name)?;
                            for idx in 0..partition.parameters.len() {
                                write!(rust_code, "local_param_{}.0, ", idx)?;
                            }
                            for dim in 0..parallel_launch.len() {
                                write!(rust_code, "low_{}, len_{}, ", dim, dim)?;
                            }
                            write!(rust_code, ")) }});\n")?;
                            // Go to the next tile: increment the innermost
                            // dimension, carrying into outer dimensions on
                            // wrap-around, and emit the matching `low_` update
                            // in the generated code.
                            for dim in (0..parallel_launch.len()).rev() {
                                tile[dim] += 1;
                                let num_tiles = parallel_launch[dim].0;
                                if tile[dim] < num_tiles {
                                    write!(
                                        rust_code,
                                        " let low_{} = low_{} + len_{};\n",
                                        dim, dim, dim
                                    )?;
                                    break;
                                } else {
                                    tile[dim] = 0;
                                    write!(rust_code, " let low_{} = 0;\n", dim)?;
                                }
                            }
                        }
                        // Join the JoinHandles, and get the output from one of
                        // them.
                        write!(
                            rust_code,
                            " let output = ::core::future::join!(",
                        )?;
                        for tile_num in 0..total_num_tiles {
                            write!(rust_code, "tile_{}, ", tile_num)?;
                        }
                        // join! unhelpfully returns either a tuple or a single
                        // value, but never a singleton tuple.
                        if total_num_tiles == 1 {
                            write!(rust_code, ").await.0;\n")?;
                        } else {
                            write!(rust_code, ").await.0.0;\n")?;
                        }
                    }
                    // Assign the outputs. `output_ref` names either the whole
                    // return value or one field of the returned product.
                    for (output_idx, (_, kind)) in partition.returns.iter().enumerate() {
                        let output_ref = if partition.returns.len() == 1 {
                            "output".to_string()
                        } else {
                            format!("output.{}", output_idx)
                        };
                        match kind {
                            ReturnKind::HerculesReturn => {
                                write!(rust_code, " return {};\n", output_ref)?
                            }
                            ReturnKind::DataOutput(id) => write!(
                                rust_code,
                                " node_{}.write({});\n",
                                id.idx(),
                                output_ref
                            )?,
                            ReturnKind::NextPartition => write!(
                                rust_code,
                                " control_token = {};\n",
                                output_ref
                            )?,
                        }
                    }
                }
                // NOTE(review): only CPU partitions are lowered so far; other
                // device manifests are unimplemented.
                _ => todo!(),
            }
            // If there's only one partition successor, then an explicit
            // NextPartition isn't returned - emit the new control token here.
            if partition.successors.len() == 1 {
                write!(
                    rust_code,
                    " control_token = {};\n",
                    partition.successors[0].idx()
                )?;
            }
            // Close the arm.
            write!(rust_code, " }}\n")?;
        }
        // Close the match, and handle invalid control token values.
        write!(
            rust_code,
            " _ => panic!(\"PANIC: Invalid control token value.\"),\n }}\n"
        )?;
        // Close the loop.
        write!(rust_code, " }}\n")?;
        // Close the function.
        write!(rust_code, "}}\n")?;
    }
    Ok(rust_code)
}
/*
 * Generate the async Rust runtime from the manifest of a Hercules module.
 * Takes a single string literal naming an existing .hman file, deserializes
 * the manifests from it, and expands to the generated runtime code. Panics
 * (at macro expansion time) on any malformed input, since a proc macro has no
 * other error channel here.
 */
#[proc_macro]
pub fn use_hman(path: TokenStream) -> TokenStream {
    use TokenTree::Literal;
    // Get the path as a Rust path object, and make sure it's a .hman file.
    let mut tokens_iter = path.into_iter();
    let token = tokens_iter
        .next()
        .expect("Please provide a path to a .hman file to the use_hman! macro.");
    assert!(tokens_iter.next().is_none(), "Too many tokens provided to the use_hman! macro. Please provide only one path to a .hman file.");
    let literal = if let Literal(literal) = token {
        literal
    } else {
        panic!("Please provide a string literal containing the path to a .hman file to the use_hman! macro.");
    };
    // Strip the surrounding quotes from the literal's token text. Raw string
    // literals (r"...") and suffixed literals are rejected loudly instead of
    // being silently mangled by blind slicing.
    let literal_string = literal.to_string();
    let path_str = literal_string
        .strip_prefix('"')
        .and_then(|stripped| stripped.strip_suffix('"'))
        .expect("Please provide a plain string literal (\"...\") to the use_hman! macro.");
    let path = Path::new(path_str);
    assert_eq!(
        path.extension(),
        Some(OsStr::new("hman")),
        "Please provide only .hman files to the use_hman! macro."
    );
    assert_eq!(
        path.try_exists().ok(),
        Some(true),
        "Please provide a valid path to a .hman file to the use_hman! macro."
    );
    // Load and deserialize the manifests from the .hman file.
    let buffer = std::fs::read(path).expect("PANIC: Unable to read the provided .hman file.");
    let manifests =
        postcard::from_bytes(&buffer).expect("PANIC: Unable to deserialize the provided .hman file.");
    // Generate Rust code and parse it into a TokenStream for expansion.
    let rust_code = codegen(&manifests).unwrap();
    rust_code.parse().unwrap()
}