Skip to content
Snippets Groups Projects
Commit 260f6f47 authored by Aaron Councilman's avatar Aaron Councilman
Browse files

Fixes for cpu and rt

parent 44f38882
No related branches found
No related tags found
1 merge request!196Multi return
......@@ -88,6 +88,11 @@ impl<'a> CPUContext<'a> {
.collect::<Vec<_>>()
.join(", "),
)?;
write!(
w,
"define dso_local void @{}(",
self.function.name,
)?;
}
let mut first_param = true;
// The first parameter is a pointer to CPU backing memory, if it's
......
......@@ -211,7 +211,7 @@ impl<'a> RTContext<'a> {
if is_single_return {
write!(w, "extern \"C\" {{")?;
}
self.write_device_signature_async(w, *callee_id)?;
self.write_device_signature_async(w, *callee_id, !is_single_return)?;
if is_single_return {
write!(w, ";}}")?;
} else {
......@@ -1200,9 +1200,9 @@ impl<'a> RTContext<'a> {
// references.
let func = self.get_func();
let param_devices = &self.node_colors.1;
let return_device = self.node_colors.2;
let return_devices = &self.node_colors.2;
let mut param_muts = vec![false; func.param_types.len()];
let mut return_mut = true;
let mut return_muts = vec![true; func.return_types.len()];
let objects = &self.collection_objects[&self.func_id];
for idx in 0..func.param_types.len() {
if let Some(object) = objects.param_to_object(idx)
......@@ -1211,11 +1211,14 @@ impl<'a> RTContext<'a> {
param_muts[idx] = true;
}
}
for object in objects.returned_objects() {
if let Some(idx) = objects.origin(*object).try_parameter()
&& !param_muts[idx]
{
return_mut = false;
let num_returns = func.return_types.len();
for idx in 0..num_returns {
for object in objects.returned_objects(idx) {
if let Some(param_idx) = objects.origin(*object).try_parameter()
&& !param_muts[param_idx]
{
return_muts[idx] = false;
}
}
}
......@@ -1245,27 +1248,38 @@ impl<'a> RTContext<'a> {
}
write!(w, "}}}}")?;
// Every reference that may be returned has the same lifetime. Every
// other reference gets its own unique lifetime.
let returned_origins: HashSet<_> = self.collection_objects[&self.func_id]
.returned_objects()
.into_iter()
.map(|obj| self.collection_objects[&self.func_id].origin(*obj))
.collect();
write!(w, "async fn run<'runner, 'returned")?;
for idx in 0..func.param_types.len() {
write!(w, ", 'p{}", idx)?;
// Each returned reference, input reference, and the runner will have
// its own lifetime. We use lifetime bounds to ensure that the runner
// and parameters are borrowed for the lifetimes needed by the outputs
let returned_origins: Vec<HashSet<_>> = (0..num_returns)
.map(|idx| objects.returned_objects(idx)
.iter()
.map(|obj| objects.origin(*obj))
.collect()
).collect();
write!(w, "async fn run<'runner:")?;
for (ret_idx, origins) in returned_origins.iter().enumerate() {
if origins.iter().any(|origin| !origin.is_parameter()) {
write!(w, " 'r{} +", ret_idx)?;
}
}
write!(
w,
">(&'{} mut self",
if returned_origins.iter().any(|origin| !origin.is_parameter()) {
"returned"
} else {
"runner"
for idx in 0..num_returns {
write!(w, ", 'r{}", idx)?;
}
for idx in 0..func.param_types.len() {
write!(w, ", 'p{}:", idx)?;
for (ret_idx, origins) in returned_origins.iter().enumerate() {
if origins.iter().any(|origin| origin
.try_parameter()
.map(|oidx| idx == oidx)
.unwrap_or(false))
{
write!(w, " 'r{} +", ret_idx)?;
}
}
)?;
}
write!(w, ">(&'runner mut self")?;
for idx in 0..func.num_dynamic_constants {
write!(w, ", dc_p{}: u64", idx)?;
}
......@@ -1281,37 +1295,35 @@ impl<'a> RTContext<'a> {
let mutability = if param_muts[idx] { "Mut" } else { "" };
write!(
w,
", p{}: ::hercules_rt::Hercules{}Ref{}<'{}>",
", p{}: ::hercules_rt::Hercules{}Ref{}<'p{}>",
idx,
device,
mutability,
if returned_origins.iter().any(|origin| origin
.try_parameter()
.map(|oidx| idx == oidx)
.unwrap_or(false))
{
"returned".to_string()
} else {
format!("p{}", idx)
}
idx,
)?;
}
}
if self.module.types[func.return_type.idx()].is_primitive() {
write!(w, ") -> {} {{", self.get_type(func.return_type))?;
} else {
let device = match return_device {
Some(Device::LLVM) | None => "CPU",
Some(Device::CUDA) => "CUDA",
_ => panic!(),
};
let mutability = if return_mut { "Mut" } else { "" };
write!(
w,
") -> ::hercules_rt::Hercules{}Ref{}<'returned> {{",
device, mutability
)?;
}
write!(w, ") -> {}{}{} {{",
if num_returns != 1 { "(" } else { "" },
func.return_types.iter().enumerate()
.map(|(ret_idx, typ)|
if self.module.types[typ.idx()].is_primitive() {
self.get_type(*typ)
} else {
let device = match return_devices[ret_idx] {
Some(Device::LLVM) | None => "CPU",
Some(Device::CUDA) => "CUDA",
_ => panic!(),
};
let mutability = if return_muts[ret_idx] { "Mut" } else { "" };
format!("::hercules_rt::Hercules{}Ref{}<'r{}>",
device, mutability, ret_idx)
}
)
.collect::<Vec<_>>()
.join(", "),
if num_returns != 1 { ")" } else { "" },
)?;
// Start with possibly re-allocating the backing memory if it's not
// large enough.
......@@ -1367,22 +1379,48 @@ impl<'a> RTContext<'a> {
write!(w, "p{}, ", idx)?;
}
write!(w, ").await;")?;
if self.module.types[func.return_type.idx()].is_primitive() {
write!(w, " ret")?;
// Return the result, appropriately wrapping pointers
if num_returns == 1 {
if self.module.types[func.return_types[0].idx()].is_primitive() {
write!(w, "ret")?;
} else {
let device = match return_devices[0] {
Some(Device::LLVM) | None => "CPU",
Some(Device::CUDA) => "CUDA",
_ => panic!(),
};
let mutability = if return_muts[0] { "Mut" } else { "" };
write!(
w,
"::hercules_rt::Hercules{}Ref{}::__from_parts(ret.0, {} as usize)",
device,
mutability,
self.codegen_type_size(func.return_types[0])
)?;
}
} else {
let device = match return_device {
Some(Device::LLVM) | None => "CPU",
Some(Device::CUDA) => "CUDA",
_ => panic!(),
};
let mutability = if return_mut { "Mut" } else { "" };
write!(
w,
"::hercules_rt::Hercules{}Ref{}::__from_parts(ret.0, {} as usize)",
device,
mutability,
self.codegen_type_size(func.return_type)
)?;
write!(w, "(")?;
for (idx, typ) in func.return_types.iter().enumerate() {
if self.module.types[typ.idx()].is_primitive() {
write!(w, "ret.{},", idx)?;
} else {
let device = match return_devices[idx] {
Some(Device::LLVM) | None => "CPU",
Some(Device::CUDA) => "CUDA",
_ => panic!(),
};
let mutability = if return_muts[idx] { "Mut" } else { "" };
write!(
w,
"::hercules_rt::Hercules{}Ref{}::__from_parts(ret.{}.0, {} as usize),",
device,
mutability,
idx,
self.codegen_type_size(func.return_types[idx]),
)?;
}
}
write!(w, ")")?;
}
write!(w, "}}}}")?;
......@@ -1476,9 +1514,9 @@ impl<'a> RTContext<'a> {
// this means that if the function is multi-return it will return a product in the produced
// Rust code
// Writes from the "fn" keyword up to the end of the return type
fn write_device_signature_async<W: Write>(&self, w: &mut W, func_id: FunctionID) -> Result<(), Error> {
fn write_device_signature_async<W: Write>(&self, w: &mut W, func_id: FunctionID, is_unsafe: bool) -> Result<(), Error> {
let func = &self.module.functions[func_id.idx()];
write!(w, "fn {}(", func.name)?;
write!(w, "{}fn {}(", if is_unsafe { "unsafe " } else { "" }, func.name)?;
let mut first_param = true;
if self.backing_allocations[&func_id].contains_key(&self.devices[func_id.idx()]) {
first_param = false;
......
......@@ -33,4 +33,3 @@ ccp(*);
dce(*);
gcm(*);
xdot[true](*);
#![feature(concat_idents)]
juno_build::juno!("median");
juno_build::juno!("multi_return");
use hercules_rt::{runner, HerculesImmBox, HerculesImmBoxTo};
fn main() {
let m = vec![
86, 72, 14, 5, 55, 25, 98, 89, 3, 66, 44, 81, 27, 3, 40, 18, 4, 57, 93, 34, 70, 50, 50, 18,
34,
];
let m = HerculesImmBox::from(m.as_slice());
const N: usize = 32;
let a: Box<[f32]> = (1..=N).map(|i| i as f32).collect();
let a = HerculesImmBox::from(a.as_ref());
let mut r = runner!(rolling_sum_prod);
let (sums, prods) = async_std::task::block_on(async { r.run(N as u64, a.to()).await });
let mut r = runner!(median_window);
let res = async_std::task::block_on(async { r.run(m.to()).await });
assert_eq!(res, 57);
println!("Partial Sums: {:?}", sums.as_slice::<f32>());
println!("Partial Prods: {:?}", prods.as_slice::<f32>());
}
#[test]
fn test_median_window() {
fn test_multi_return() {
main()
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment