From cf7e2a0334fa413427ab0897c2e6766d3d82e8d5 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Tue, 4 Feb 2025 22:26:25 -0600 Subject: [PATCH 1/7] wrap pointers internally in RT functions in a send+sync wrapper --- hercules_cg/src/rt.rs | 30 ++++++++++++++++++++++-------- hercules_rt/src/lib.rs | 13 +++++++++++++ 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 35334a14..1d96cff6 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -137,7 +137,11 @@ impl<'a> RTContext<'a> { } else { write!(w, ", ")?; } - write!(w, "backing_{}: *mut u8", device.name())?; + write!( + w, + "backing_{}: ::hercules_rt::__RawPtrSendSync", + device.name() + )?; } // The second set of parameters are dynamic constants. for idx in 0..func.num_dynamic_constants { @@ -170,7 +174,7 @@ impl<'a> RTContext<'a> { let mut first_param = true; if self.backing_allocations[&callee_id].contains_key(&self.devices[callee_id.idx()]) { first_param = false; - write!(w, "backing: *mut u8")?; + write!(w, "backing: ::hercules_rt::__RawPtrSendSync")?; } for idx in 0..callee.num_dynamic_constants { if first_param { @@ -207,7 +211,7 @@ impl<'a> RTContext<'a> { } else if self.module.types[self.typing[idx].idx()].is_float() { "0.0" } else { - "::core::ptr::null_mut()" + "::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())" } )?; } @@ -362,7 +366,7 @@ impl<'a> RTContext<'a> { if let Some((size, device)) = size_and_device { write!( block, - " ::hercules_rt::__{}_zero_mem({}, {} as usize);\n", + " ::hercules_rt::__{}_zero_mem({}.0, {} as usize);\n", device.name(), self.get_value(id), size @@ -876,12 +880,20 @@ impl<'a> RTContext<'a> { } for idx in 0..func.param_types.len() { if !self.module.types[func.param_types[idx].idx()].is_primitive() { - write!(w, " let p{} = p{}.__ptr();\n", idx, idx)?; + write!( + w, + " let p{} = ::hercules_rt::__RawPtrSendSync(p{}.__ptr());\n", + idx, idx + )?; } } write!(w, " let ret = {}(", func.name)?; for (device, _) in self.backing_allocations[&self.func_id].iter() { - write!(w, "self.backing_ptr_{}, ", device.name())?; + write!( + w, + "::hercules_rt::__RawPtrSendSync(self.backing_ptr_{}), ", + device.name() + )?; } for idx in 0..func.num_dynamic_constants { write!(w, "dc_p{}, ", idx)?; @@ -901,7 +913,7 @@ impl<'a> RTContext<'a> { let mutability = if return_mut { "Mut" } else { "" }; write!( w, - " ::hercules_rt::Hercules{}Ref{}::__from_parts(ret, {} as usize)\n", + " ::hercules_rt::Hercules{}Ref{}::__from_parts(ret.0, {} as usize)\n", device, mutability, self.codegen_type_size(func.return_type) @@ -952,7 +964,9 @@ fn convert_type(ty: &Type) -> &'static str { Type::UnsignedInteger64 => "u64", Type::Float32 => "f32", Type::Float64 => "f64", - Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => "*mut u8", + Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => { + "::hercules_rt::__RawPtrSendSync" + } _ => panic!(), } } diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs index 2ad72043..12b64fa3 100644 --- a/hercules_rt/src/lib.rs +++ b/hercules_rt/src/lib.rs @@ -354,3 +354,16 @@ macro_rules! runner { <concat_idents!(HerculesRunner_, $x)>::new() }; } + +#[derive(Debug, Clone, Copy)] +#[repr(transparent)] +pub struct __RawPtrSendSync(pub *mut u8); + +impl __RawPtrSendSync { + pub unsafe fn byte_add(self, add: usize) -> Self { + __RawPtrSendSync(self.0.byte_add(add)) + } +} + +unsafe impl Send for __RawPtrSendSync {} +unsafe impl Sync for __RawPtrSendSync {} -- GitLab From 44c419c0cbb849edfb9bea65b09c046c7ee04e9e Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Tue, 4 Feb 2025 22:53:23 -0600 Subject: [PATCH 2/7] Simplify RT backend by using prettyplease for rust formatting --- Cargo.lock | 12 ++++ hercules_cg/src/rt.rs | 122 +++++++++++++++++++------------------- juno_scheduler/Cargo.toml | 2 + juno_scheduler/src/pm.rs | 17 +++--- 4 files changed, 83 insertions(+), 70 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index af7902c6..e761361b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1222,7 +1222,9 @@ dependencies = [ "lrlex", "lrpar", "postcard", + "prettyplease", "serde", + "syn 2.0.96", "tempfile", ] @@ -1739,6 +1741,16 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "prettyplease" +version = "0.2.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac" +dependencies = [ + "proc-macro2", + "syn 2.0.96", +] + [[package]] name = "proc-macro-error" version = "1.0.4" diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 1d96cff6..34c5acda 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -126,7 +126,7 @@ impl<'a> RTContext<'a> { // Dump the function signature. write!( w, - "#[allow(unused_assignments,unused_variables,unused_mut,unused_parens,unused_unsafe,non_snake_case)]\nasync unsafe fn {}(", + "#[allow(unused_assignments,unused_variables,unused_mut,unused_parens,unused_unsafe,non_snake_case)]async unsafe fn {}(", func.name )?; let mut first_param = true; @@ -161,16 +161,16 @@ impl<'a> RTContext<'a> { } write!(w, "p{}: {}", idx, self.get_type(func.param_types[idx]))?; } - write!(w, ") -> {} {{\n", self.get_type(func.return_type))?; + write!(w, ") -> {} {{", self.get_type(func.return_type))?; // Dump signatures for called device functions. - write!(w, " extern \"C\" {{\n")?; + write!(w, "extern \"C\" {{")?; for callee_id in self.callgraph.get_callees(self.func_id) { if self.devices[callee_id.idx()] == Device::AsyncRust { continue; } let callee = &self.module.functions[callee_id.idx()]; - write!(w, " fn {}(", callee.name)?; + write!(w, "fn {}(", callee.name)?; let mut first_param = true; if self.backing_allocations[&callee_id].contains_key(&self.devices[callee_id.idx()]) { first_param = false; @@ -192,9 +192,9 @@ impl<'a> RTContext<'a> { } write!(w, "p{}: {}", idx, self.get_type(*ty))?; } - write!(w, ") -> {};\n", self.get_type(callee.return_type))?; + write!(w, ") -> {};", self.get_type(callee.return_type))?; } - write!(w, " }}\n")?; + write!(w, "}}")?; // Declare intermediary variables for every value. for idx in 0..func.nodes.len() { @@ -203,7 +203,7 @@ impl<'a> RTContext<'a> { } write!( w, - " let mut node_{}: {} = {};\n", + "let mut node_{}: {} = {};", idx, self.get_type(self.typing[idx]), if self.module.types[self.typing[idx].idx()].is_integer() { @@ -221,7 +221,7 @@ impl<'a> RTContext<'a> { // blocks to drive execution. write!( w, - " let mut control_token: i8 = 0;\n loop {{\n match control_token {{\n", + "let mut control_token: i8 = 0;loop {{match control_token {{", )?; let mut blocks: BTreeMap<_, _> = (0..func.nodes.len()) @@ -246,17 +246,12 @@ impl<'a> RTContext<'a> { // Dump the emitted basic blocks. for (id, block) in blocks { - write!( - w, - " {} => {{\n{} }}\n", - id.idx(), - block - )?; + write!(w, "{} => {{{}}}", id.idx(), block)?; } // Close the match and loop. - write!(w, " _ => panic!()\n }}\n }}\n")?; - write!(w, "}}\n")?; + write!(w, "_ => panic!()}}}}")?; + write!(w, "}}")?; Ok(()) } @@ -283,7 +278,7 @@ impl<'a> RTContext<'a> { } => { let block = &mut blocks.get_mut(&id).unwrap(); let succ = self.control_subgraph.succs(id).next().unwrap(); - write!(block, " control_token = {};\n", succ.idx())? + write!(block, "control_token = {};", succ.idx())? } // If nodes have two successors - examine the projections to // determine which branch is which, and branch between them. @@ -295,7 +290,7 @@ impl<'a> RTContext<'a> { let succ1_is_true = func.nodes[succ1.idx()].try_projection(1).is_some(); write!( block, - " control_token = if {} {{ {} }} else {{ {} }};\n", + "control_token = if {} {{{}}} else {{{}}};", self.get_value(cond), if succ1_is_true { succ1 } else { succ2 }.idx(), if succ1_is_true { succ2 } else { succ1 }.idx(), @@ -303,7 +298,7 @@ impl<'a> RTContext<'a> { } Node::Return { control: _, data } => { let block = &mut blocks.get_mut(&id).unwrap(); - write!(block, " return {};\n", self.get_value(data))? + write!(block, "return {};", self.get_value(data))? } _ => panic!("PANIC: Can't lower {:?}.", func.nodes[id.idx()]), } @@ -322,16 +317,11 @@ impl<'a> RTContext<'a> { match func.nodes[id.idx()] { Node::Parameter { index } => { let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); - write!( - block, - " {} = p{};\n", - self.get_value(id), - index - )? + write!(block, "{} = p{};", self.get_value(id), index)? } Node::Constant { id: cons_id } => { let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); - write!(block, " {} = ", self.get_value(id))?; + write!(block, "{} = ", self.get_value(id))?; let mut size_and_device = None; match self.module.constants[cons_id.idx()] { Constant::Boolean(val) => write!(block, "{}bool", val)?, @@ -361,12 +351,12 @@ impl<'a> RTContext<'a> { size_and_device = Some((self.codegen_type_size(ty), device)); } } - write!(block, ";\n")?; + write!(block, ";")?; if !func.schedules[id.idx()].contains(&Schedule::NoResetConstant) { if let Some((size, device)) = size_and_device { write!( block, - " ::hercules_rt::__{}_zero_mem({}.0, {} as usize);\n", + "::hercules_rt::__{}_zero_mem({}.0, {} as usize);", device.name(), self.get_value(id), size @@ -385,7 +375,7 @@ impl<'a> RTContext<'a> { let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); write!( block, - " {} = {}(", + "{} = {}(", self.get_value(id), self.module.functions[callee_id.idx()].name )?; @@ -406,9 +396,9 @@ impl<'a> RTContext<'a> { } let device = self.devices[callee_id.idx()]; if device == Device::AsyncRust { - write!(block, ").await;\n")?; + write!(block, ").await;")?; } else { - write!(block, ");\n")?; + write!(block, ");")?; } } Node::Unary { op, input } => { @@ -416,19 +406,19 @@ impl<'a> RTContext<'a> { match op { UnaryOperator::Not => write!( block, - " {} = !{};\n", + "{} = !{};", self.get_value(id), self.get_value(input) )?, UnaryOperator::Neg => write!( block, - " {} = -{};\n", + "{} = -{};", self.get_value(id), self.get_value(input) )?, UnaryOperator::Cast(ty) => write!( block, - " {} = {} as {};\n", + "{} = {} as {};", self.get_value(id), self.get_value(input), self.get_type(ty) @@ -458,7 +448,7 @@ impl<'a> RTContext<'a> { write!( block, - " {} = {} {} {};\n", + "{} = {} {} {};", self.get_value(id), self.get_value(left), op, @@ -475,7 +465,7 @@ impl<'a> RTContext<'a> { match op { TernaryOperator::Select => write!( block, - " {} = if {} {{ {} }} else {{ {} }};\n", + "{} = if {} {{{}}} else {{{}}};", self.get_value(id), self.get_value(first), self.get_value(second), @@ -799,29 +789,29 @@ impl<'a> RTContext<'a> { // Emit the type definition. A runner object owns its backing memory. write!( w, - "#[allow(non_camel_case_types)]\nstruct HerculesRunner_{} {{\n", + "#[allow(non_camel_case_types)]struct HerculesRunner_{} {{", func.name )?; for (device, _) in self.backing_allocations[&self.func_id].iter() { - write!(w, " backing_ptr_{}: *mut u8,\n", device.name(),)?; - write!(w, " backing_size_{}: usize,\n", device.name(),)?; + write!(w, "backing_ptr_{}: *mut u8,", device.name(),)?; + write!(w, "backing_size_{}: usize,", device.name(),)?; } - write!(w, "}}\n")?; + write!(w, "}}")?; write!( w, - "impl HerculesRunner_{} {{\n fn new() -> Self {{\n Self {{\n", + "impl HerculesRunner_{} {{fn new() -> Self {{Self {{", func.name )?; for (device, _) in self.backing_allocations[&self.func_id].iter() { write!( w, - " backing_ptr_{}: ::core::ptr::null_mut(),\n backing_size_{}: 0,\n", + "backing_ptr_{}: ::core::ptr::null_mut(),backing_size_{}: 0,", device.name(), device.name() )?; } - write!(w, " }}\n }}\n")?; - write!(w, " async fn run<'a>(&'a mut self")?; + write!(w, "}}}}")?; + write!(w, "async fn run<'a>(&'a mut self")?; for idx in 0..func.num_dynamic_constants { write!(w, ", dc_p{}: u64", idx)?; } @@ -846,7 +836,7 @@ impl<'a> RTContext<'a> { } } if self.module.types[func.return_type.idx()].is_primitive() { - write!(w, ") -> {} {{\n", self.get_type(func.return_type))?; + write!(w, ") -> {} {{", self.get_type(func.return_type))?; } else { let device = match return_device { Some(Device::LLVM) => "CPU", @@ -856,38 +846,46 @@ impl<'a> RTContext<'a> { let mutability = if return_mut { "Mut" } else { "" }; write!( w, - ") -> ::hercules_rt::Hercules{}Ref{}<'a> {{\n", + ") -> ::hercules_rt::Hercules{}Ref{}<'a> {{", device, mutability )?; } - write!(w, " unsafe {{\n")?; + write!(w, "unsafe {{")?; for (device, (total, _)) in self.backing_allocations[&self.func_id].iter() { - write!(w, " let size = ")?; + write!(w, "let size = ")?; self.codegen_dynamic_constant(*total, w)?; write!( w, - " as usize;\n if self.backing_size_{} < size {{\n", + " as usize;if self.backing_size_{} < size {{", device.name() )?; - write!(w, " ::hercules_rt::__{}_dealloc(self.backing_ptr_{}, self.backing_size_{});\n", device.name(), device.name(), device.name())?; write!( w, - " self.backing_size_{} = size;\n", + "::hercules_rt::__{}_dealloc(self.backing_ptr_{}, self.backing_size_{});", + device.name(), + device.name(), + device.name() + )?; + write!(w, "self.backing_size_{} = size;", device.name())?; + write!( + w, + "self.backing_ptr_{} = ::hercules_rt::__{}_alloc(self.backing_size_{});", + device.name(), + device.name(), device.name() )?; - write!(w, " self.backing_ptr_{} = ::hercules_rt::__{}_alloc(self.backing_size_{});\n", device.name(), device.name(), device.name())?; - write!(w, " }}\n")?; + write!(w, "}}")?; } for idx in 0..func.param_types.len() { if !self.module.types[func.param_types[idx].idx()].is_primitive() { write!( w, - " let p{} = ::hercules_rt::__RawPtrSendSync(p{}.__ptr());\n", + "let p{} = ::hercules_rt::__RawPtrSendSync(p{}.__ptr());", idx, idx )?; } } - write!(w, " let ret = {}(", func.name)?; + write!(w, "let ret = {}(", func.name)?; for (device, _) in self.backing_allocations[&self.func_id].iter() { write!( w, @@ -901,9 +899,9 @@ impl<'a> RTContext<'a> { for idx in 0..func.param_types.len() { write!(w, "p{}, ", idx)?; } - write!(w, ").await;\n")?; + write!(w, ").await;")?; if self.module.types[func.return_type.idx()].is_primitive() { - write!(w, " ret\n")?; + write!(w, " ret")?; } else { let device = match return_device { Some(Device::LLVM) => "CPU", @@ -913,28 +911,28 @@ impl<'a> RTContext<'a> { let mutability = if return_mut { "Mut" } else { "" }; write!( w, - " ::hercules_rt::Hercules{}Ref{}::__from_parts(ret.0, {} as usize)\n", + "::hercules_rt::Hercules{}Ref{}::__from_parts(ret.0, {} as usize)", device, mutability, self.codegen_type_size(func.return_type) )?; } - write!(w, " }}\n }}\n")?; + write!(w, "}}}}")?; write!( w, - "}}\nimpl Drop for HerculesRunner_{} {{\n #[allow(unused_unsafe)]\n fn drop(&mut self) {{\n unsafe {{\n", + "}}impl Drop for HerculesRunner_{} {{#[allow(unused_unsafe)]fn drop(&mut self) {{unsafe {{", func.name )?; for (device, _) in self.backing_allocations[&self.func_id].iter() { write!( w, - " ::hercules_rt::__{}_dealloc(self.backing_ptr_{}, self.backing_size_{});\n", + "::hercules_rt::__{}_dealloc(self.backing_ptr_{}, self.backing_size_{});", device.name(), device.name(), device.name() )?; } - write!(w, " }}\n }}\n}}\n")?; + write!(w, "}}}}}}")?; Ok(()) } diff --git a/juno_scheduler/Cargo.toml b/juno_scheduler/Cargo.toml index 26055b03..03a18c83 100644 --- a/juno_scheduler/Cargo.toml +++ b/juno_scheduler/Cargo.toml @@ -17,6 +17,8 @@ cfgrammar = "0.13" lrlex = "0.13" lrpar = "0.13" tempfile = "*" +prettyplease = "0.2.29" +syn = { version = "2.0.96", features = ["full"] } hercules_cg = { path = "../hercules_cg" } hercules_ir = { path = "../hercules_ir" } hercules_opt = { path = "../hercules_opt" } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 9478eb9b..1bb3e536 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -753,8 +753,17 @@ impl PassManager { } println!("{}", llvm_ir); println!("{}", cuda_ir); + let rust_rt = prettyplease::unparse(&syn::parse_file(&rust_rt).unwrap()); println!("{}", rust_rt); + // Write the Rust runtime into a file. + let output_rt = format!("{}/rt_{}.hrt", output_dir, module_name); + println!("{}", output_rt); + let mut file = + File::create(&output_rt).expect("PANIC: Unable to open output Rust runtime file."); + file.write_all(rust_rt.as_bytes()) + .expect("PANIC: Unable to write output Rust runtime file contents."); + let output_archive = format!("{}/lib{}.a", output_dir, module_name); println!("{}", output_archive); @@ -844,14 +853,6 @@ impl PassManager { ); } - // Write the Rust runtime into a file. - let output_rt = format!("{}/rt_{}.hrt", output_dir, module_name); - println!("{}", output_rt); - let mut file = - File::create(&output_rt).expect("PANIC: Unable to open output Rust runtime file."); - file.write_all(rust_rt.as_bytes()) - .expect("PANIC: Unable to write output Rust runtime file contents."); - Ok(()) } } -- GitLab From b870d03bb90785237b2e82c1057c331c5e907f50 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Wed, 5 Feb 2025 10:21:38 -0600 Subject: [PATCH 3/7] Some refactoring in RT backend --- hercules_cg/src/fork_tree.rs | 27 +++--- hercules_cg/src/rt.rs | 121 ++++++++++++++++++-------- hercules_ir/src/fork_join_analysis.rs | 9 ++ juno_scheduler/src/pm.rs | 17 ++-- 4 files changed, 120 insertions(+), 54 deletions(-) diff --git a/hercules_cg/src/fork_tree.rs b/hercules_cg/src/fork_tree.rs index c048f7e3..5bdcdf62 100644 --- a/hercules_cg/src/fork_tree.rs +++ b/hercules_cg/src/fork_tree.rs @@ -3,11 +3,14 @@ use std::collections::{HashMap, HashSet}; use crate::*; /* - * Construct a map from fork node to all control nodes (including itself) satisfying: - * a) domination by F - * b) no domination by F's join - * c) no domination by any other fork that's also dominated by F, where we do count self-domination - * Here too we include the non-fork start node, as key for all controls outside any fork. + * Construct a map from fork node to all control nodes (including itself) + * satisfying: + * 1. Dominated by the fork. + * 2. Not dominated by the fork's join. + * 3. Not dominated by any other fork that's also dominated by the fork, where + * we do count self-domination. + * We include the non-fork start node as the key for all control nodes outside + * any fork. */ pub fn fork_control_map( fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>, @@ -23,11 +26,14 @@ pub fn fork_control_map( fork_control_map } -/* Construct a map from each fork node F to all forks satisfying: - * a) domination by F - * b) no domination by F's join - * c) no domination by any other fork that's also dominated by F, where we don't count self-domination - * Note that the fork_tree also includes the non-fork start node, as unique root node. +/* + * Construct a map from fork node to all fork nodes (including itself) + * satisfying: + * 1. Dominated by the fork. + * 2. Not dominated by the fork's join. + * 3. Not dominated by any other fork that's also dominated by the fork, where + * we do count self-domination. + * Note that the fork tree also includes the start node as the unique root node. */ pub fn fork_tree( function: &Function, @@ -44,5 +50,6 @@ pub fn fork_tree( .insert(*control); } } + fork_tree.entry(NodeID::new(0)).or_default(); fork_tree } diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 34c5acda..ea5c46b7 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::fmt::{Error, Write}; use std::iter::zip; @@ -78,6 +78,9 @@ pub fn rt_codegen<W: Write>( module: &Module, typing: &Vec<TypeID>, control_subgraph: &Subgraph, + fork_control_map: &HashMap<NodeID, HashSet<NodeID>>, + fork_tree: &HashMap<NodeID, HashSet<NodeID>>, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, collection_objects: &CollectionObjects, callgraph: &CallGraph, devices: &Vec<Device>, @@ -91,6 +94,9 @@ pub fn rt_codegen<W: Write>( module, typing, control_subgraph, + fork_control_map, + fork_tree, + nodes_in_fork_joins, collection_objects, callgraph, devices, @@ -106,6 +112,9 @@ struct RTContext<'a> { module: &'a Module, typing: &'a Vec<TypeID>, control_subgraph: &'a Subgraph, + fork_control_map: &'a HashMap<NodeID, HashSet<NodeID>>, + fork_tree: &'a HashMap<NodeID, HashSet<NodeID>>, + nodes_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>, collection_objects: &'a CollectionObjects, callgraph: &'a CallGraph, devices: &'a Vec<Device>, @@ -196,33 +205,10 @@ impl<'a> RTContext<'a> { } write!(w, "}}")?; - // Declare intermediary variables for every value. - for idx in 0..func.nodes.len() { - if func.nodes[idx].is_control() { - continue; - } - write!( - w, - "let mut node_{}: {} = {};", - idx, - self.get_type(self.typing[idx]), - if self.module.types[self.typing[idx].idx()].is_integer() { - "0" - } else if self.module.types[self.typing[idx].idx()].is_float() { - "0.0" - } else { - "::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())" - } - )?; - } - - // The core executor is a Rust loop. We literally run a "control token" - // as described in the original sea of nodes paper through the basic - // blocks to drive execution. - write!( - w, - "let mut control_token: i8 = 0;loop {{match control_token {{", - )?; + // Set up the root environment for the function. An environment is set + // up for every created task in async closures, and there needs to be a + // root environment corresponding to the root control node (start node). + self.codegen_open_environment(NodeID::new(0), w)?; let mut blocks: BTreeMap<_, _> = (0..func.nodes.len()) .filter(|idx| func.nodes[*idx].is_control()) @@ -237,20 +223,20 @@ impl<'a> RTContext<'a> { } // Emit control flow into basic blocks. - for id in (0..func.nodes.len()).map(NodeID::new) { - if !func.nodes[id.idx()].is_control() { - continue; - } - self.codegen_control_node(id, &mut blocks)?; + let rev_po = self.control_subgraph.rev_po(NodeID::new(0)); + for id in rev_po.iter() { + self.codegen_control_node(*id, &mut blocks)?; } - // Dump the emitted basic blocks. - for (id, block) in blocks { + // Dump the emitted basic blocks. Do this in reverse postorder since + // fork and join nodes open and close environments, respectively. + for id in rev_po.iter() { + let block = &blocks[id]; write!(w, "{} => {{{}}}", id.idx(), block)?; } - // Close the match and loop. - write!(w, "_ => panic!()}}}}")?; + // Close the root environment. + self.codegen_close_environment(w)?; write!(w, "}}")?; Ok(()) } @@ -364,6 +350,16 @@ impl<'a> RTContext<'a> { } } } + Node::ThreadID { control, dimension } => { + let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + write!( + block, + "{} = tid_{}_{};", + self.get_value(id), + control.idx(), + dimension + )?; + } Node::Call { control: _, function: callee_id, @@ -743,8 +739,57 @@ impl<'a> RTContext<'a> { } } + fn codegen_open_environment<W: Write>(&self, root: NodeID, w: &mut W) -> Result<(), Error> { + let func = &self.get_func(); + + // Declare intermediary variables for every value in this fork-join (or + // whole function for start) that isn't in any child fork-joins. + for idx in 0..func.nodes.len() { + let id = NodeID::new(idx); + let control = func.nodes[idx].is_control(); + let in_root = self.nodes_in_fork_joins[&root].contains(&id); + let in_child = self.fork_tree[&root] + .iter() + .any(|child| self.nodes_in_fork_joins[&child].contains(&id)); + if control || !in_root || in_child { + continue; + } + + write!( + w, + "let mut node_{}: {} = {};", + idx, + self.get_type(self.typing[idx]), + if self.module.types[self.typing[idx].idx()].is_integer() { + "0" + } else if self.module.types[self.typing[idx].idx()].is_float() { + "0.0" + } else { + "::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())" + } + )?; + } + + // The core executor is a Rust loop. We literally run a "control token" + // as described in the original sea of nodes paper through the basic + // blocks to drive execution. + write!( + w, + "let mut control_token: i8 = 0;loop {{match control_token {{", + )?; + + Ok(()) + } + + fn codegen_close_environment<W: Write>(&self, w: &mut W) -> Result<(), Error> { + // Close the match and loop. + write!(w, "_ => panic!()}}}}") + } + /* - * Generate a runner object for this function. + * Generate a runner object for this function. The runner object stores + * backing memory for a Hercules function and wraps calls to the Hercules + * function. */ fn codegen_runner_object<W: Write>(&self, w: &mut W) -> Result<(), Error> { // Figure out the devices for the parameters and the return value if diff --git a/hercules_ir/src/fork_join_analysis.rs b/hercules_ir/src/fork_join_analysis.rs index ad3125ba..3fcc6af0 100644 --- a/hercules_ir/src/fork_join_analysis.rs +++ b/hercules_ir/src/fork_join_analysis.rs @@ -67,6 +67,8 @@ pub fn compute_fork_join_nesting( .filter(|id| function.nodes[id.idx()].is_fork()) // where its corresponding join doesn't dominate the control // node (if so, then this control is after the fork-join). + // Check for strict dominance since the join itself should + // be nested in the fork. .filter(|fork_id| !dom.does_prop_dom(fork_join_map[&fork_id], id)) .collect(), ) @@ -174,6 +176,7 @@ pub fn nodes_in_fork_joins( ) -> HashMap<NodeID, HashSet<NodeID>> { let mut result = HashMap::new(); + // Iterate users of fork until reaching corresponding join or reduces. for (fork, join) in fork_join_map { let mut worklist = vec![*fork]; let mut set = HashSet::new(); @@ -196,5 +199,11 @@ pub fn nodes_in_fork_joins( result.insert(*fork, set); } + // Add an entry for the start node containing every node. + result.insert( + NodeID::new(0), + (0..function.nodes.len()).map(NodeID::new).collect(), + ); + result } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 1bb3e536..8f67d664 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -647,14 +647,15 @@ impl PassManager { } fn codegen(mut self, output_dir: String, module_name: String) -> Result<(), SchedulerError> { + self.make_def_uses(); self.make_typing(); self.make_control_subgraphs(); - self.make_collection_objects(); - self.make_callgraph(); - self.make_def_uses(); self.make_fork_join_maps(); self.make_fork_control_maps(); self.make_fork_trees(); + self.make_nodes_in_fork_joins(); + self.make_collection_objects(); + self.make_callgraph(); self.make_devices(); let PassManager { @@ -663,14 +664,15 @@ impl PassManager { constants, dynamic_constants, labels, + def_uses: Some(def_uses), typing: Some(typing), control_subgraphs: Some(control_subgraphs), - collection_objects: Some(collection_objects), - callgraph: Some(callgraph), - def_uses: Some(def_uses), fork_join_maps: Some(fork_join_maps), fork_control_maps: Some(fork_control_maps), fork_trees: Some(fork_trees), + nodes_in_fork_joins: Some(nodes_in_fork_joins), + collection_objects: Some(collection_objects), + callgraph: Some(callgraph), devices: Some(devices), bbs: Some(bbs), node_colors: Some(node_colors), @@ -737,6 +739,9 @@ impl PassManager { &module, &typing[idx], &control_subgraphs[idx], + &fork_control_maps[idx], + &fork_trees[idx], + &nodes_in_fork_joins[idx], &collection_objects, &callgraph, &devices, -- GitLab From 6c9158979bf6fcf906a1490aabc2d144b47612ad Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Wed, 5 Feb 2025 17:14:31 -0600 Subject: [PATCH 4/7] Basic emit fork-joins in RT --- hercules_cg/src/rt.rs | 104 ++++++++++++++++++++++++++++++--------- juno_scheduler/src/pm.rs | 1 + 2 files changed, 83 insertions(+), 22 deletions(-) diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index ea5c46b7..aa01ff8b 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -78,6 +78,7 @@ pub fn rt_codegen<W: Write>( module: &Module, typing: &Vec<TypeID>, control_subgraph: &Subgraph, + fork_join_map: &HashMap<NodeID, NodeID>, fork_control_map: &HashMap<NodeID, HashSet<NodeID>>, fork_tree: &HashMap<NodeID, HashSet<NodeID>>, nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, @@ -89,11 +90,17 @@ pub fn rt_codegen<W: Write>( backing_allocations: &BackingAllocations, w: &mut W, ) -> Result<(), Error> { + let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map + .into_iter() + .map(|(fork, join)| (*join, *fork)) + .collect(); let ctx = RTContext { func_id, module, typing, control_subgraph, + fork_join_map, + join_fork_map: &join_fork_map, fork_control_map, fork_tree, nodes_in_fork_joins, @@ -112,6 +119,8 @@ struct RTContext<'a> { module: &'a Module, typing: &'a Vec<TypeID>, control_subgraph: &'a Subgraph, + fork_join_map: &'a HashMap<NodeID, NodeID>, + join_fork_map: &'a HashMap<NodeID, NodeID>, fork_control_map: &'a HashMap<NodeID, HashSet<NodeID>>, fork_tree: &'a HashMap<NodeID, HashSet<NodeID>>, nodes_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>, @@ -123,6 +132,13 @@ struct RTContext<'a> { backing_allocations: &'a BackingAllocations, } +#[derive(Debug, Clone, Default)] +struct RustBlock { + prologue: String, + data: String, + epilogue: String, +} + impl<'a> RTContext<'a> { fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> { // If this is an entry function, generate a corresponding runner object @@ -212,7 +228,7 @@ impl<'a> RTContext<'a> { let mut blocks: BTreeMap<_, _> = (0..func.nodes.len()) .filter(|idx| func.nodes[*idx].is_control()) - .map(|idx| (NodeID::new(idx), String::new())) + .map(|idx| (NodeID::new(idx), RustBlock::default())) .collect(); // Emit data flow into basic blocks. @@ -232,7 +248,7 @@ impl<'a> RTContext<'a> { // fork and join nodes open and close environments, respectively. for id in rev_po.iter() { let block = &blocks[id]; - write!(w, "{} => {{{}}}", id.idx(), block)?; + write!(w, "{}{}{}", block.prologue, block.data, block.epilogue)?; } // Close the root environment. @@ -250,7 +266,7 @@ impl<'a> RTContext<'a> { fn codegen_control_node( &self, id: NodeID, - blocks: &mut BTreeMap<NodeID, String>, + blocks: &mut BTreeMap<NodeID, RustBlock>, ) -> Result<(), Error> { let func = &self.get_func(); match func.nodes[id.idx()] { @@ -262,29 +278,67 @@ impl<'a> RTContext<'a> { control: _, selection: _, } => { - let block = &mut blocks.get_mut(&id).unwrap(); + let prologue = &mut blocks.get_mut(&id).unwrap().prologue; + write!(prologue, "{} => {{", id.idx())?; + let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue; let succ = self.control_subgraph.succs(id).next().unwrap(); - write!(block, "control_token = {};", succ.idx())? + write!(epilogue, "control_token = {};}}", succ.idx())?; } // If nodes have two successors - examine the projections to // determine which branch is which, and branch between them. Node::If { control: _, cond } => { - let block = &mut blocks.get_mut(&id).unwrap(); + let prologue = &mut blocks.get_mut(&id).unwrap().prologue; + write!(prologue, "{} => {{", id.idx())?; + let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue; let mut succs = self.control_subgraph.succs(id); let succ1 = succs.next().unwrap(); let succ2 = succs.next().unwrap(); let succ1_is_true = func.nodes[succ1.idx()].try_projection(1).is_some(); write!( - block, - "control_token = if {} {{{}}} else {{{}}};", + epilogue, + "control_token = if {} {{{}}} else {{{}}};}}", self.get_value(cond), if succ1_is_true { succ1 } else { succ2 }.idx(), if succ1_is_true { succ2 } else { succ1 }.idx(), - )? + )?; } Node::Return { control: _, data } => { - let block = &mut blocks.get_mut(&id).unwrap(); - write!(block, "return {};", self.get_value(data))? + let prologue = &mut blocks.get_mut(&id).unwrap().prologue; + write!(prologue, "{} => {{", id.idx())?; + let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue; + write!(epilogue, "return {};}}", self.get_value(data))?; + } + // Fork nodes open a new environment for defining an async closure. + Node::Fork { + control: _, + ref factors, + } => { + // First, set the outer environment control token to the join. + let prologue = &mut blocks.get_mut(&id).unwrap().prologue; + let join = self.fork_join_map[&id]; + write!(prologue, "control_token = {};", join.idx())?; + + // Second, emit loops for the thread IDs. + for (idx, factor) in factors.into_iter().enumerate() { + write!(prologue, "for tid_{}_{} in 0..", id.idx(), idx)?; + self.codegen_dynamic_constant(*factor, prologue)?; + write!(prologue, " {{")?; + } + } + // Join nodes close the environment opened by its corresponding + // fork. + Node::Join { control: _ } => { + // First, close the loops emitted by the fork node. + let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue; + let fork = self.join_fork_map[&id]; + for _ in 0..func.nodes[fork.idx()].try_fork().unwrap().1.len() { + write!(epilogue, "}}")?; + } + + // Second, branch to the successor control node. + let succ = self.control_subgraph.succs(id).next().unwrap(); + write!(epilogue, "{} => {{", id.idx())?; + write!(epilogue, "control_token = {};}}", succ.idx())?; } _ => panic!("PANIC: Can't lower {:?}.", func.nodes[id.idx()]), } @@ -297,16 +351,16 @@ impl<'a> RTContext<'a> { fn codegen_data_node( &self, id: NodeID, - blocks: &mut BTreeMap<NodeID, String>, + blocks: &mut BTreeMap<NodeID, RustBlock>, ) -> Result<(), Error> { let func = &self.get_func(); match func.nodes[id.idx()] { Node::Parameter { index } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap().data; write!(block, "{} = p{};", self.get_value(id), index)? } Node::Constant { id: cons_id } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap().data; write!(block, "{} = ", self.get_value(id))?; let mut size_and_device = None; match self.module.constants[cons_id.idx()] { @@ -351,7 +405,7 @@ impl<'a> RTContext<'a> { } } Node::ThreadID { control, dimension } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap().data; write!( block, "{} = tid_{}_{};", @@ -368,7 +422,7 @@ impl<'a> RTContext<'a> { } => { // The device backends ensure that device functions have the // same interface as AsyncRust functions. - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap().data; write!( block, "{} = {}(", @@ -398,7 +452,7 @@ impl<'a> RTContext<'a> { } } Node::Unary { op, input } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap().data; match op { UnaryOperator::Not => write!( block, @@ -422,7 +476,7 @@ impl<'a> RTContext<'a> { }; } Node::Binary { op, left, right } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap().data; let op = match op { BinaryOperator::Add => "+", BinaryOperator::Sub => "-", @@ -457,7 +511,7 @@ impl<'a> RTContext<'a> { second, third, } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap().data; match op { TernaryOperator::Select => write!( block, @@ -473,7 +527,7 @@ impl<'a> RTContext<'a> { collect, ref indices, } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap().data; let collect_ty = self.typing[collect.idx()]; let out_size = self.codegen_type_size(self.typing[id.idx()]); let offset = self.codegen_index_math(collect_ty, indices)?; @@ -484,7 +538,7 @@ impl<'a> RTContext<'a> { data, ref indices, } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap().data; let collect_ty = self.typing[collect.idx()]; let data_size = self.codegen_type_size(self.typing[data.idx()]); let offset = self.codegen_index_math(collect_ty, indices)?; @@ -770,12 +824,18 @@ impl<'a> RTContext<'a> { )?; } + // Declare Vec for storing futures of fork-joins. + for fork in self.fork_tree[&root].iter() { + write!(w, "let mut fork_{} = vec![];", fork.idx())?; + } + // The core executor is a Rust loop. We literally run a "control token" // as described in the original sea of nodes paper through the basic // blocks to drive execution. write!( w, - "let mut control_token: i8 = 0;loop {{match control_token {{", + "let mut control_token: i8 = {};loop {{match control_token {{", + root.idx(), )?; Ok(()) diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 8f67d664..91883d0e 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -739,6 +739,7 @@ impl PassManager { &module, &typing[idx], &control_subgraphs[idx], + &fork_join_maps[idx], &fork_control_maps[idx], &fork_trees[idx], &nodes_in_fork_joins[idx], -- GitLab From f053dc86911042ccb8f749e218cb43e84476b4a9 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Wed, 5 Feb 2025 17:29:42 -0600 Subject: [PATCH 5/7] Progress --- hercules_cg/src/rt.rs | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index aa01ff8b..c71dc0fa 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -324,18 +324,35 @@ impl<'a> RTContext<'a> { self.codegen_dynamic_constant(*factor, prologue)?; write!(prologue, " {{")?; } + + // Third, spawn an async closure and push its future to a Vec. + write!( + prologue, + "fork_{}.push(::async_std::task::spawn(async {{", + id.idx() + )?; + + // Fourth, open a new environment. + self.codegen_open_environment(id, prologue)?; } // Join nodes close the environment opened by its corresponding // fork. Node::Join { control: _ } => { - // First, close the loops emitted by the fork node. + // First, close the fork's environment. let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue; + self.codegen_close_environment(epilogue)?; + + // Second, close the async closure and push statement from the + // fork. + write!(epilogue, "}}));")?; + + // Third, close the loops emitted by the fork node. let fork = self.join_fork_map[&id]; for _ in 0..func.nodes[fork.idx()].try_fork().unwrap().1.len() { write!(epilogue, "}}")?; } - // Second, branch to the successor control node. + // Fourth, branch to the successor control node. let succ = self.control_subgraph.succs(id).next().unwrap(); write!(epilogue, "{} => {{", id.idx())?; write!(epilogue, "control_token = {};}}", succ.idx())?; -- GitLab From 260bf5ccfb66ebc208229b778190a866fd460342 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Wed, 5 Feb 2025 21:31:12 -0600 Subject: [PATCH 6/7] It works, but doesn't sync properly yet --- hercules_cg/src/rt.rs | 89 +++++++++++++++++++++--- juno_samples/fork_join_tests/src/cpu.sch | 7 +- juno_scheduler/src/pm.rs | 6 +- 3 files changed, 86 insertions(+), 16 deletions(-) diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index c71dc0fa..e06906a6 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -76,6 +76,7 @@ use crate::*; pub fn rt_codegen<W: Write>( func_id: FunctionID, module: &Module, + def_use: &ImmutableDefUseMap, typing: &Vec<TypeID>, control_subgraph: &Subgraph, fork_join_map: &HashMap<NodeID, NodeID>, @@ -97,6 +98,7 @@ pub fn rt_codegen<W: Write>( let ctx = RTContext { func_id, module, + def_use, typing, control_subgraph, fork_join_map, @@ -117,6 +119,7 @@ pub fn rt_codegen<W: Write>( struct RTContext<'a> { func_id: FunctionID, module: &'a Module, + def_use: &'a ImmutableDefUseMap, typing: &'a Vec<TypeID>, control_subgraph: &'a Subgraph, fork_join_map: &'a HashMap<NodeID, NodeID>, @@ -313,48 +316,97 @@ impl<'a> RTContext<'a> { control: _, ref factors, } => { - // First, set the outer environment control token to the join. + assert!(func.schedules[id.idx()].contains(&Schedule::ParallelFork)); + + // Set the outer environment control token to the join. let prologue = &mut blocks.get_mut(&id).unwrap().prologue; let join = self.fork_join_map[&id]; - write!(prologue, "control_token = {};", join.idx())?; + write!( + prologue, + "{} => {{control_token = {};", + id.idx(), + join.idx() + )?; - // Second, emit loops for the thread IDs. + // Emit loops for the thread IDs. for (idx, factor) in factors.into_iter().enumerate() { write!(prologue, "for tid_{}_{} in 0..", id.idx(), idx)?; self.codegen_dynamic_constant(*factor, prologue)?; write!(prologue, " {{")?; } - // Third, spawn an async closure and push its future to a Vec. + // Spawn an async closure and push its future to a Vec. write!( prologue, - "fork_{}.push(::async_std::task::spawn(async {{", + "fork_{}.push(::async_std::task::spawn(async move {{", id.idx() )?; - // Fourth, open a new environment. + // Open a new environment. self.codegen_open_environment(id, prologue)?; + + // Open the branch inside the async closure for the fork. + let succ = self.control_subgraph.succs(id).next().unwrap(); + write!( + prologue, + "{} => {{control_token = {};", + id.idx(), + succ.idx() + )?; + + // Close the branch for the fork inside the async closure. + let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue; + write!(epilogue, "}}")?; } // Join nodes close the environment opened by its corresponding // fork. Node::Join { control: _ } => { - // First, close the fork's environment. + // Emit the branch for the join inside the async closure. + let prologue = &mut blocks.get_mut(&id).unwrap().prologue; + write!(prologue, "{} => {{", id.idx())?; + + // Close the branch inside the async closure. let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue; + write!(epilogue, "return;}}")?; + + // Close the fork's environment. self.codegen_close_environment(epilogue)?; - // Second, close the async closure and push statement from the + // Close the async closure and push statement from the // fork. write!(epilogue, "}}));")?; - // Third, close the loops emitted by the fork node. + // Close the loops emitted by the fork node. let fork = self.join_fork_map[&id]; for _ in 0..func.nodes[fork.idx()].try_fork().unwrap().1.len() { write!(epilogue, "}}")?; } - // Fourth, branch to the successor control node. + // Close the branch for the fork outside the async closure. + write!(epilogue, "}}")?; + + // Open the branch in the surrounding context for the join. let succ = self.control_subgraph.succs(id).next().unwrap(); write!(epilogue, "{} => {{", id.idx())?; + + // Emit the assignments to the reduce variables in the + // surrounding context. It's very unfortunate that we have to do + // it while lowering the join node (rather than the reduce nodes + // themselves), but this is the only place we can put these + // assignments in the correct control location. + for user in self.def_use.get_users(id) { + if let Some((_, init, _)) = func.nodes[user.idx()].try_reduce() { + write!( + epilogue, + "{} = {};", + self.get_value(*user), + self.get_value(init) + )?; + } + } + + // Branch to the successor control node in the surrounding + // context, and close the branch for the join. write!(epilogue, "control_token = {};}}", succ.idx())?; } _ => panic!("PANIC: Can't lower {:?}.", func.nodes[id.idx()]), @@ -431,6 +483,15 @@ impl<'a> RTContext<'a> { dimension )?; } + Node::Reduce { + control: _, + init, + reduct: _, + } => { + assert!(func.schedules[id.idx()].contains(&Schedule::ParallelReduce)); + let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap().data; + write!(block, "{} = {};", self.get_value(id), self.get_value(init))?; + } Node::Call { control: _, function: callee_id, @@ -822,7 +883,13 @@ impl<'a> RTContext<'a> { let in_child = self.fork_tree[&root] .iter() .any(|child| self.nodes_in_fork_joins[&child].contains(&id)); - if control || !in_root || in_child { + let is_reduce_on_child = func.nodes[idx] + .try_reduce() + .map(|(control, _, _)| { + self.fork_tree[&root].contains(&self.join_fork_map[&control]) + }) + .unwrap_or(false); + if (control || !in_root || in_child) && !is_reduce_on_child { continue; } diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index 38010004..5665e1fa 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -44,12 +44,11 @@ dce(*); fork-tile[32, 0, true](test6@loop); let out = fork-split(test6@loop); -//let out = outline(out.test6.fj1); -let out = auto-outline(test6); -cpu(out.test6); +let out = outline(out.test6.fj1); +cpu(out); ip-sroa(*); sroa(*); -unforkify(out.test6); +unforkify(out); dce(*); ccp(*); gvn(*); diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 91883d0e..901361c6 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -737,6 +737,7 @@ impl PassManager { Device::AsyncRust => rt_codegen( FunctionID::new(idx), &module, + &def_uses[idx], &typing[idx], &control_subgraphs[idx], &fork_join_maps[idx], @@ -759,7 +760,10 @@ impl PassManager { } println!("{}", llvm_ir); println!("{}", cuda_ir); - let rust_rt = prettyplease::unparse(&syn::parse_file(&rust_rt).unwrap()); + let rust_rt = prettyplease::unparse( + &syn::parse_file(&rust_rt) + .expect(&format!("PANIC: Malformed RT Rust code: {}", rust_rt)), + ); println!("{}", rust_rt); // Write the Rust runtime into a file. -- GitLab From a02f005b60b058a9fb0ed7de7be3ded3963bbd51 Mon Sep 17 00:00:00 2001 From: Russel Arbore <russel.jma@gmail.com> Date: Wed, 5 Feb 2025 21:58:07 -0600 Subject: [PATCH 7/7] Properly await on futures --- hercules_cg/src/rt.rs | 101 +++++++++++++++++++++++++----------------- 1 file changed, 60 insertions(+), 41 deletions(-) diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index e06906a6..2c5f7c35 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -300,7 +300,7 @@ impl<'a> RTContext<'a> { write!( epilogue, "control_token = if {} {{{}}} else {{{}}};}}", - self.get_value(cond), + self.get_value(cond, id), if succ1_is_true { succ1 } else { succ2 }.idx(), if succ1_is_true { succ2 } else { succ1 }.idx(), )?; @@ -309,7 +309,7 @@ impl<'a> RTContext<'a> { let prologue = &mut blocks.get_mut(&id).unwrap().prologue; write!(prologue, "{} => {{", id.idx())?; let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue; - write!(epilogue, "return {};}}", self.get_value(data))?; + write!(epilogue, "return {};}}", self.get_value(data, id))?; } // Fork nodes open a new environment for defining an async closure. Node::Fork { @@ -399,8 +399,8 @@ impl<'a> RTContext<'a> { write!( epilogue, "{} = {};", - self.get_value(*user), - self.get_value(init) + self.get_value(*user, id), + self.get_value(init, id) )?; } } @@ -423,14 +423,15 @@ impl<'a> RTContext<'a> { blocks: &mut BTreeMap<NodeID, RustBlock>, ) -> Result<(), Error> { let func = &self.get_func(); + let bb = self.bbs.0[id.idx()]; match func.nodes[id.idx()] { Node::Parameter { index } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap().data; - write!(block, "{} = p{};", self.get_value(id), index)? + let block = &mut blocks.get_mut(&bb).unwrap().data; + write!(block, "{} = p{};", self.get_value(id, bb), index)? } Node::Constant { id: cons_id } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap().data; - write!(block, "{} = ", self.get_value(id))?; + let block = &mut blocks.get_mut(&bb).unwrap().data; + write!(block, "{} = ", self.get_value(id, bb))?; let mut size_and_device = None; match self.module.constants[cons_id.idx()] { Constant::Boolean(val) => write!(block, "{}bool", val)?, @@ -467,30 +468,28 @@ impl<'a> RTContext<'a> { block, "::hercules_rt::__{}_zero_mem({}.0, {} as usize);", device.name(), - self.get_value(id), + self.get_value(id, bb), size )?; } } } Node::ThreadID { control, dimension } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap().data; + let block = &mut blocks.get_mut(&bb).unwrap().data; write!( block, "{} = tid_{}_{};", - self.get_value(id), + self.get_value(id, bb), control.idx(), dimension )?; } Node::Reduce { control: _, - init, + init: _, reduct: _, } => { assert!(func.schedules[id.idx()].contains(&Schedule::ParallelReduce)); - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap().data; - write!(block, "{} = {};", self.get_value(id), self.get_value(init))?; } Node::Call { control: _, @@ -500,11 +499,11 @@ impl<'a> RTContext<'a> { } => { // The device backends ensure that device functions have the // same interface as AsyncRust functions. - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap().data; + let block = &mut blocks.get_mut(&bb).unwrap().data; write!( block, "{} = {}(", - self.get_value(id), + self.get_value(id, bb), self.module.functions[callee_id.idx()].name )?; for (device, offset) in self.backing_allocations[&self.func_id] @@ -520,7 +519,7 @@ impl<'a> RTContext<'a> { write!(block, ", ")?; } for arg in args { - write!(block, "{}, ", self.get_value(*arg))?; + write!(block, "{}, ", self.get_value(*arg, bb))?; } let device = self.devices[callee_id.idx()]; if device == Device::AsyncRust { @@ -530,31 +529,31 @@ impl<'a> RTContext<'a> { } } Node::Unary { op, input } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap().data; + let block = &mut blocks.get_mut(&bb).unwrap().data; match op { UnaryOperator::Not => write!( block, "{} = !{};", - self.get_value(id), - self.get_value(input) + self.get_value(id, bb), + self.get_value(input, bb) )?, UnaryOperator::Neg => write!( block, "{} = -{};", - self.get_value(id), - self.get_value(input) + self.get_value(id, bb), + self.get_value(input, bb) )?, UnaryOperator::Cast(ty) => write!( block, "{} = {} as {};", - self.get_value(id), - self.get_value(input), + self.get_value(id, bb), + self.get_value(input, bb), self.get_type(ty) )?, }; } Node::Binary { op, left, right } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap().data; + let block = &mut blocks.get_mut(&bb).unwrap().data; let op = match op { BinaryOperator::Add => "+", BinaryOperator::Sub => "-", @@ -577,10 +576,10 @@ impl<'a> RTContext<'a> { write!( block, "{} = {} {} {};", - self.get_value(id), - self.get_value(left), + self.get_value(id, bb), + self.get_value(left, bb), op, - self.get_value(right) + self.get_value(right, bb) )?; } Node::Ternary { @@ -589,15 +588,15 @@ impl<'a> RTContext<'a> { second, third, } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap().data; + let block = &mut blocks.get_mut(&bb).unwrap().data; match op { TernaryOperator::Select => write!( block, "{} = if {} {{{}}} else {{{}}};", - self.get_value(id), - self.get_value(first), - self.get_value(second), - self.get_value(third), + self.get_value(id, bb), + self.get_value(first, bb), + self.get_value(second, bb), + self.get_value(third, bb), )?, }; } @@ -605,10 +604,10 @@ impl<'a> RTContext<'a> { collect, ref indices, } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap().data; + let block = &mut blocks.get_mut(&bb).unwrap().data; let collect_ty = self.typing[collect.idx()]; let out_size = self.codegen_type_size(self.typing[id.idx()]); - let offset = self.codegen_index_math(collect_ty, indices)?; + let offset = self.codegen_index_math(collect_ty, indices, bb)?; todo!(); } Node::Write { @@ -616,10 +615,10 @@ impl<'a> RTContext<'a> { data, ref indices, } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap().data; + let block = &mut blocks.get_mut(&bb).unwrap().data; let collect_ty = self.typing[collect.idx()]; let data_size = self.codegen_type_size(self.typing[data.idx()]); - let offset = self.codegen_index_math(collect_ty, indices)?; + let offset = self.codegen_index_math(collect_ty, indices, bb)?; todo!(); } _ => panic!( @@ -734,6 +733,7 @@ impl<'a> RTContext<'a> { &self, mut collect_ty: TypeID, indices: &[Index], + bb: NodeID, ) -> Result<String, Error> { let mut acc_offset = "0".to_string(); for index in indices { @@ -784,7 +784,7 @@ impl<'a> RTContext<'a> { // ((0 * s1 + p1) * s2 + p2) * s3 + p3 ... let elem_size = self.codegen_type_size(elem); for (p, s) in zip(pos, dims) { - let p = self.get_value(*p); + let p = self.get_value(*p, bb); acc_offset = format!("{} * ", acc_offset); self.codegen_dynamic_constant(*s, &mut acc_offset)?; acc_offset = format!("({} + {})", acc_offset, p); @@ -895,7 +895,8 @@ impl<'a> RTContext<'a> { write!( w, - "let mut node_{}: {} = {};", + "let mut {}_{}: {} = {};", + if is_reduce_on_child { "reduce" } else { "node" }, idx, self.get_type(self.typing[idx]), if self.module.types[self.typing[idx].idx()].is_integer() { @@ -1129,8 +1130,26 @@ impl<'a> RTContext<'a> { &self.module.functions[self.func_id.idx()] } - fn get_value(&self, id: NodeID) -> String { - format!("node_{}", id.idx()) + fn get_value(&self, id: NodeID, bb: NodeID) -> String { + let func = self.get_func(); + if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce() + && control == bb + { + format!("reduce_{}", id.idx()) + } else if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce() + && let fork = self.join_fork_map[&control] + && !self.nodes_in_fork_joins[&fork].contains(&bb) + { + // Before using the value of a reduction outside the fork-join, + // await the futures. + format!( + "{{for fut in fork_{}.drain(..) {{ fut.await; }}; reduce_{}}}", + fork.idx(), + id.idx() + ) + } else { + format!("node_{}", id.idx()) + } } fn get_type(&self, id: TypeID) -> &'static str { -- GitLab