Skip to content
Snippets Groups Projects
lib.rs 20.02 KiB
#![feature(iter_intersperse)]

extern crate anyhow;
extern crate hercules_cg;
extern crate hercules_ir;
extern crate postcard;
extern crate proc_macro;

use std::collections::{HashMap, HashSet};
use std::ffi::OsStr;
use std::fmt::Write;
use std::fs::File;
use std::io::prelude::*;
use std::path::Path;

use proc_macro::*;

use self::hercules_cg::*;
use self::hercules_ir::{DynamicConstant, DynamicConstantID};

/*
 * Spell a schedule IR type as the Rust type used in the generated interface.
 * Scalar types map to the matching Rust primitive, product types map to the
 * mangled name of their emitted #[repr(C)] struct (as produced by
 * generate_type_name), and array references become `*mut` raw pointers to the
 * spelled element type.
 */
fn generate_type_string(ty: &SType) -> String {
    let primitive = match ty {
        SType::Boolean => "bool",
        SType::Integer8 => "i8",
        SType::Integer16 => "i16",
        SType::Integer32 => "i32",
        SType::Integer64 => "i64",
        SType::UnsignedInteger8 => "u8",
        SType::UnsignedInteger16 => "u16",
        SType::UnsignedInteger32 => "u32",
        SType::UnsignedInteger64 => "u64",
        SType::Float32 => "f32",
        SType::Float64 => "f64",
        // Products are referred to by the mangled name of their struct, which
        // is exactly what generate_type_name computes for a product.
        SType::Product(_) => return generate_type_name(ty),
        SType::ArrayRef(elem) => return format!("*mut {}", generate_type_string(elem)),
    };
    String::from(primitive)
}

/*
 * Produce a mangled, identifier-safe name for a schedule IR type. Scalars
 * reuse their Rust primitive spelling. A product is named "__Prod", then "_"
 * and the name of each field in order, then a closing "_" (e.g. a product of
 * i32 and f64 is "__Prod_i32_f64_"). An array reference is named "ArrayRef_"
 * followed by the name of its element type.
 */
fn generate_type_name(ty: &SType) -> String {
    match ty {
        SType::Product(fields) => {
            let mut mangled = String::from("__Prod");
            for field in fields.iter() {
                mangled.push('_');
                mangled.push_str(&generate_type_name(field));
            }
            mangled.push('_');
            mangled
        }
        SType::ArrayRef(elem) => {
            let elem_name = generate_type_name(elem);
            format!("ArrayRef_{}", elem_name)
        }
        // Scalars are already valid identifiers, so reuse their Rust spelling.
        SType::Boolean
        | SType::Integer8
        | SType::Integer16
        | SType::Integer32
        | SType::Integer64
        | SType::UnsignedInteger8
        | SType::UnsignedInteger16
        | SType::UnsignedInteger32
        | SType::UnsignedInteger64
        | SType::Float32
        | SType::Float64 => generate_type_string(ty),
    }
}

/*
 * Emit a Rust expression into `rust_code` that evaluates the dynamic constant
 * `dc` (looked up in `manifest.dynamic_constants`) inside the generated code.
 * Constants are written as literals, parameters as the `dc_<idx>` variable
 * the generated function declares for them, and arithmetic nodes as fully
 * parenthesized binary expressions built recursively from their operands.
 * Returns an error only if writing to `rust_code` fails.
 */
fn compute_dynamic_constant<W: Write>(
    dc: DynamicConstantID,
    manifest: &Manifest,
    rust_code: &mut W,
) -> Result<(), anyhow::Error> {
    // Leaves are emitted directly; every other node is a binary operation, so
    // pick its operands and operator token and share one emission path.
    let (lhs, operator, rhs) = match manifest.dynamic_constants[dc.idx()] {
        DynamicConstant::Constant(cons) => {
            write!(rust_code, "{}", cons)?;
            return Ok(());
        }
        DynamicConstant::Parameter(idx) => {
            write!(rust_code, "dc_{}", idx)?;
            return Ok(());
        }
        DynamicConstant::Add(left, right) => (left, "+", right),
        DynamicConstant::Sub(left, right) => (left, "-", right),
        DynamicConstant::Mul(left, right) => (left, "*", right),
        DynamicConstant::Div(left, right) => (left, "/", right),
        DynamicConstant::Rem(left, right) => (left, "%", right),
    };

    write!(rust_code, "(")?;
    compute_dynamic_constant(lhs, manifest, rust_code)?;
    write!(rust_code, " {} ", operator)?;
    compute_dynamic_constant(rhs, manifest, rust_code)?;
    write!(rust_code, ")")?;
    Ok(())
}

/*
 * Generate async Rust code orchestrating partition execution.
 *
 * The emitted code consists of:
 * - a SendSyncWrapper<T> newtype that unsafely implements Send and Sync, used
 *   to move parameters and results into spawned tasks;
 * - one #[repr(C)] tuple struct per product type visible in any manifest;
 * - for each Hercules function, an `async unsafe fn` that declares its
 *   partition functions in an `extern "C"` block and drives them with a
 *   control-token loop containing one match arm per partition.
 *
 * `manifests` maps each Hercules function name to its Manifest. Returns the
 * generated Rust source text; errors from formatting into the output String
 * (including from compute_dynamic_constant) are propagated.
 *
 * Panics if a Hercules function parameter has kind DataInput, if the data
 * inputs and data outputs across a function's partitions differ, if a
 * parallel launch has zero tiles, or if a partition targets a non-CPU device
 * (not yet implemented).
 */
fn codegen(manifests: &HashMap<String, Manifest>) -> Result<String, anyhow::Error> {
    // Write to a String containing all of the Rust code.
    let mut rust_code = "".to_string();

    // Rust doesn't allow you to send pointers between threads. In order to send
    // pointers between threads, we need to wrap them in a struct that unsafely
    // implements Send and Sync. This passes the responsibility of
    // synchronization onto us, which we do by being careful with how we lower
    // parallel code. Make this type generic so that we actually wrap all
    // arguments in it for ease of macro codegen.
    write!(
        rust_code,
        "#[derive(Clone, Copy, Debug)]\nstruct SendSyncWrapper<T: Copy>(T);\nunsafe impl<T: Copy> Send for SendSyncWrapper<T> {{}}\nunsafe impl<T: Copy> Sync for SendSyncWrapper<T> {{}}\n"
    )?;

    // Emit the product types used in this module. We can't just use Rust
    // tuples for product types, since tuples have no guaranteed layout and we
    // need #[repr(C)] to interact with LLVM. Each product becomes a tuple
    // struct whose name is given by generate_type_string.
    let all_stypes = manifests
        .values()
        .flat_map(|manifest| manifest.all_visible_types())
        .collect::<HashSet<SType>>();
    for stype in all_stypes.iter() {
        if let Some(fields) = stype.try_product() {
            let field_types = fields
                .iter()
                .map(|field| generate_type_string(field))
                .collect::<Vec<_>>()
                .join(", ");
            write!(
                rust_code,
                "#[derive(Clone, Copy, Debug)]\n#[repr(C)]\nstruct {}({});\n",
                generate_type_string(stype),
                field_types
            )?;
        }
    }

    // Emit the async Rust functions implementing each Hercules function.
    for (function_name, manifest) in manifests.iter() {
        // Emit the function signature.
        write!(rust_code, "async unsafe fn {}(", function_name)?;
        for (param_ty, param_kind) in manifest.param_types.iter() {
            match param_kind {
                ParameterKind::HerculesParameter(idx) => write!(rust_code, "param_{}", idx)?,
                ParameterKind::DataInput(_) => panic!(
                    "PANIC: Parameter kind for Hercules function parameter cannot be DataInput."
                ),
                ParameterKind::DynamicConstant(idx) => write!(rust_code, "dc_{}", idx)?,
                ParameterKind::ArrayConstant(array_id) => {
                    write!(rust_code, "array_{}", array_id.idx())?
                }
            }
            write!(rust_code, ": {}, ", generate_type_string(param_ty))?;
        }
        write!(
            rust_code,
            ") -> {} {{\n",
            generate_type_string(&manifest.return_type)
        )?;

        // Compute the signature for each partition function and emit the extern
        // function signatures.
        write!(rust_code, "    extern \"C\" {{\n")?;
        for partition in manifest.partitions.iter() {
            write!(rust_code, "        fn {}(", partition.name)?;

            // Add parameters for SFunction signature.
            for (param_stype, kind) in partition.parameters.iter() {
                match kind {
                    ParameterKind::HerculesParameter(idx) => write!(rust_code, "param_{}: ", idx)?,
                    ParameterKind::DataInput(id) => write!(rust_code, "data_{}: ", id.idx())?,
                    ParameterKind::DynamicConstant(idx) => write!(rust_code, "dc_{}: ", idx)?,
                    ParameterKind::ArrayConstant(id) => write!(rust_code, "array_{}: ", id.idx())?,
                }
                write!(rust_code, "{}, ", generate_type_string(param_stype))?;
            }

            // Add parameters for device specific lowering details: each
            // parallel launch dimension gets a (low, len) pair of u64s.
            if let DeviceManifest::CPU { parallel_launch } = &partition.device {
                for parallel_launch_dim in 0..parallel_launch.len() {
                    write!(
                        rust_code,
                        "parallel_launch_low_{}: u64, parallel_launch_len_{}: u64, ",
                        parallel_launch_dim, parallel_launch_dim
                    )?;
                }
            }

            // Add the return product of the SFunction signature. A single
            // return is passed through as-is; multiple returns are bundled
            // into a product type.
            let return_stype = if partition.returns.len() == 1 {
                partition.returns[0].0.clone()
            } else {
                SType::Product(
                    partition
                        .returns
                        .iter()
                        .map(|(return_stype, _)| return_stype.clone())
                        .collect(),
                )
            };
            write!(rust_code, ") -> {};\n", generate_type_string(&return_stype),)?;
        }
        write!(rust_code, "    }}\n")?;

        // Declare all of the intermediary data input / output variables. They
        // are declared as MaybeUninit, since they get assigned after running a
        // partition. MaybeUninits should always be defined before assume_init()
        // is called on them, assuming a valid partitioning.
        let mut data_inputs = HashSet::new();
        let mut data_outputs = HashSet::new();
        for partition in manifest.partitions.iter() {
            data_inputs.extend(partition.data_inputs());
            data_outputs.extend(partition.data_outputs());
        }
        assert_eq!(data_inputs, data_outputs);
        for (node, stype) in data_inputs {
            write!(rust_code, "    let mut node_{}: ::core::mem::MaybeUninit<{}> = ::core::mem::MaybeUninit::uninit();\n", node.idx(), generate_type_string(stype))?;
        }

        // The core executor is a Rust loop. We literally run a "control token"
        // as described in the original sea of nodes paper through the
        // partitions to drive execution.
        write!(
            rust_code,
            "    let mut control_token: i8 = 0;\n    loop {{\n",
        )?;

        // Match on the control token position to determine which partition to
        // execute.
        write!(rust_code, "        match control_token {{\n")?;

        // Emit the match arm per partition.
        for (idx, partition) in manifest.partitions.iter().enumerate() {
            // Open the arm.
            write!(rust_code, "            {} => {{\n", idx)?;

            match partition.device {
                DeviceManifest::CPU {
                    ref parallel_launch,
                } => {
                    // Wrap every argument so it can be moved into spawned
                    // tasks below.
                    for (param_idx, (_, kind)) in partition.parameters.iter().enumerate() {
                        write!(
                            rust_code,
                            "                let local_param_{} = SendSyncWrapper(",
                            param_idx
                        )?;
                        match kind {
                            ParameterKind::HerculesParameter(idx) => {
                                write!(rust_code, "param_{}", idx)?
                            }
                            ParameterKind::DataInput(id) => {
                                write!(rust_code, "node_{}.assume_init()", id.idx())?
                            }
                            ParameterKind::DynamicConstant(idx) => write!(rust_code, "dc_{}", idx)?,
                            ParameterKind::ArrayConstant(id) => {
                                write!(rust_code, "array_{}", id.idx())?
                            }
                        }
                        write!(rust_code, ");\n")?;
                    }

                    if parallel_launch.is_empty() {
                        // Call the partition function.
                        write!(
                            rust_code,
                            "                let output = {}(",
                            partition.name
                        )?;
                        for param_idx in 0..partition.parameters.len() {
                            write!(rust_code, "local_param_{}.0, ", param_idx)?;
                        }
                        write!(rust_code, ");\n")?;
                    } else {
                        // Compute the dynamic constant bounds.
                        for (dim, (_, dc)) in parallel_launch.iter().enumerate() {
                            write!(rust_code, "                let bound_{} = ", dim)?;
                            compute_dynamic_constant(*dc, manifest, &mut rust_code)?;
                            write!(rust_code, ";\n                let low_{} = 0;\n", dim)?;
                        }

                        // Simultaneously calculate the tiles lows and lens and
                        // spawn the tiles. Emit the launches unrolled.
                        // `tile_coords` is the current tile's index along each
                        // dimension, advanced like an odometer (last dimension
                        // fastest).
                        let mut tile_coords = vec![0; parallel_launch.len()];
                        let total_num_tiles = parallel_launch
                            .iter()
                            .fold(1, |acc, (num_tiles, _)| acc * num_tiles);
                        // With zero tiles nothing would be spawned and the
                        // emitted `join!()` would not compile in the user's
                        // crate; report the bad manifest here instead.
                        assert!(
                            total_num_tiles > 0,
                            "PANIC: Parallel launch of partition {} has zero tiles.",
                            partition.name
                        );
                        for tile_num in 0..total_num_tiles {
                            // Calculate the lows and lens for this tile. The
                            // first (bound % num_tiles) tiles of a dimension
                            // get one extra element.
                            for (dim, coord) in tile_coords.iter().enumerate() {
                                let num_tiles = parallel_launch[dim].0;
                                write!(
                                    rust_code,
                                    "                let len_{} = bound_{} / {} + ({} < bound_{} % {}) as u64;\n",
                                    dim, dim, num_tiles, coord, dim, num_tiles
                                )?;
                            }

                            // Spawn the tile. We need to explicitly copy the
                            // SendSyncWrappers, or else the path expression for
                            // the parameters get interpreted as what needs to
                            // be moved, when we want the wrapper itself to be
                            // what gets moved. Ugh.
                            write!(
                                rust_code,
                                "                let tile_{} = async_std::task::spawn(async move {{ ",
                                tile_num,
                            )?;
                            for param_idx in 0..partition.parameters.len() {
                                write!(
                                    rust_code,
                                    "let local_param_{} = local_param_{}; ",
                                    param_idx, param_idx
                                )?;
                            }
                            write!(rust_code, "SendSyncWrapper({}(", partition.name)?;
                            for param_idx in 0..partition.parameters.len() {
                                write!(rust_code, "local_param_{}.0, ", param_idx)?;
                            }
                            for dim in 0..parallel_launch.len() {
                                write!(rust_code, "low_{}, len_{}, ", dim, dim)?;
                            }
                            write!(rust_code, ")) }});\n")?;

                            // Go to the next tile: bump the innermost
                            // dimension, carrying into outer dimensions and
                            // resetting their lows to 0 on wrap-around.
                            for dim in (0..parallel_launch.len()).rev() {
                                tile_coords[dim] += 1;
                                let num_tiles = parallel_launch[dim].0;
                                if tile_coords[dim] < num_tiles {
                                    write!(
                                        rust_code,
                                        "                let low_{} = low_{} + len_{};\n",
                                        dim, dim, dim
                                    )?;
                                    break;
                                } else {
                                    tile_coords[dim] = 0;
                                    write!(rust_code, "                let low_{} = 0;\n", dim)?;
                                }
                            }
                        }

                        // Join the JoinHandles, and get the output from one of
                        // them.
                        write!(
                            rust_code,
                            "                let output = ::core::future::join!(",
                        )?;
                        for tile_num in 0..total_num_tiles {
                            write!(rust_code, "tile_{}, ", tile_num)?;
                        }
                        // join! unhelpfully returns either a tuple or a single
                        // value, but never a singleton tuple.
                        if total_num_tiles == 1 {
                            write!(rust_code, ").await.0;\n")?;
                        } else {
                            write!(rust_code, ").await.0.0;\n")?;
                        }
                    }

                    // Assign the outputs. A single return is the output value
                    // itself; multiple returns are fields of a product.
                    for (output_idx, (_, kind)) in partition.returns.iter().enumerate() {
                        let output_ref = if partition.returns.len() == 1 {
                            "output".to_string()
                        } else {
                            format!("output.{}", output_idx)
                        };
                        match kind {
                            ReturnKind::HerculesReturn => {
                                write!(rust_code, "                return {};\n", output_ref)?
                            }
                            ReturnKind::DataOutput(id) => write!(
                                rust_code,
                                "                node_{}.write({});\n",
                                id.idx(),
                                output_ref
                            )?,
                            ReturnKind::NextPartition => write!(
                                rust_code,
                                "                control_token = {};\n",
                                output_ref
                            )?,
                        }
                    }
                }
                _ => todo!("Code generation for non-CPU partitions is not yet implemented."),
            }

            // If there's only one partition successor, then an explicit
            // NextPartition isn't returned - emit the new control token here.
            if partition.successors.len() == 1 {
                write!(
                    rust_code,
                    "                control_token = {};\n",
                    partition.successors[0].idx()
                )?;
            }

            // Close the arm.
            write!(rust_code, "            }}\n")?;
        }

        // Close the match, and handle invalid control token values.
        write!(
            rust_code,
            "            _ => panic!(\"PANIC: Invalid control token value.\"),\n        }}\n"
        )?;

        // Close the loop.
        write!(rust_code, "    }}\n")?;

        // Close the function.
        write!(rust_code, "}}\n")?;
    }

    Ok(rust_code)
}

/*
 * Generate the async Rust runtime from the manifest of a Hercules module.
 */
#[proc_macro]
pub fn use_hman(path: TokenStream) -> TokenStream {
    use TokenTree::Literal;

    // Get the path as a Rust path object, and make sure it's a .hman file.
    let mut tokens_iter = path.into_iter();
    let token = tokens_iter
        .next()
        .expect("Please provide a path to a .hman file to the use_hman! macro.");
    assert!(tokens_iter.next().is_none(), "Too many tokens provided to the use_hman! macro. Please provide only one path to a .hman file.");
    let literal = if let Literal(literal) = token {
        literal
    } else {
        panic!("Please provide a string literal containing the path to a .hman file to the use_hman! macro.");
    };
    let literal_string = literal.to_string();
    let path = Path::new(&literal_string[1..(literal_string.len() - 1)]);
    assert_eq!(
        path.extension(),
        Some(OsStr::new("hman")),
        "Please provide only .hman files to the use_hman! macro."
    );
    assert_eq!(
        path.try_exists().ok(),
        Some(true),
        "Please provide a valid path to a .hman file to the use_hman! macro."
    );

    // Load manifest from path.
    let mut f = File::open(path).unwrap();
    let mut buffer = vec![];
    f.read_to_end(&mut buffer).unwrap();
    let manifests = postcard::from_bytes(&buffer).unwrap();

    // Generate Rust code.
    let rust_code = codegen(&manifests).unwrap();
    eprintln!("{}", rust_code);
    rust_code.parse().unwrap()
}