diff --git a/Cargo.lock b/Cargo.lock index af7902c692bbeb11c941025593928a110d91e671..e761361bee6d76f6b08d0680d4be986176121ad2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1222,7 +1222,9 @@ dependencies = [ "lrlex", "lrpar", "postcard", + "prettyplease", "serde", + "syn 2.0.96", "tempfile", ] @@ -1739,6 +1741,16 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "prettyplease" +version = "0.2.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac" +dependencies = [ + "proc-macro2", + "syn 2.0.96", +] + [[package]] name = "proc-macro-error" version = "1.0.4" diff --git a/hercules_cg/src/fork_tree.rs b/hercules_cg/src/fork_tree.rs index c048f7e388a7eb4f73c0c866896fadfa5d2556a1..5bdcdf62c8ee29f6f766eeb7439960e24c559830 100644 --- a/hercules_cg/src/fork_tree.rs +++ b/hercules_cg/src/fork_tree.rs @@ -3,11 +3,14 @@ use std::collections::{HashMap, HashSet}; use crate::*; /* - * Construct a map from fork node to all control nodes (including itself) satisfying: - * a) domination by F - * b) no domination by F's join - * c) no domination by any other fork that's also dominated by F, where we do count self-domination - * Here too we include the non-fork start node, as key for all controls outside any fork. + * Construct a map from fork node to all control nodes (including itself) + * satisfying: + * 1. Dominated by the fork. + * 2. Not dominated by the fork's join. + * 3. Not dominated by any other fork that's also dominated by the fork, where + * we do count self-domination. + * We include the non-fork start node as the key for all control nodes outside + * any fork. */ pub fn fork_control_map( fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>, @@ -23,11 +26,14 @@ pub fn fork_control_map( fork_control_map } -/* Construct a map from each fork node F to all forks satisfying: - * a) domination by F - * b) no domination by F's join - * c) no domination by any other fork that's also dominated by F, where we don't count self-domination - * Note that the fork_tree also includes the non-fork start node, as unique root node. +/* + * Construct a map from fork node to all child fork nodes (excluding itself) + * satisfying: + * 1. Dominated by the fork. + * 2. Not dominated by the fork's join. + * 3. Not dominated by any other fork that's also dominated by the fork, where + * we don't count self-domination. + * Note that the fork tree also includes the start node as the unique root node.
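+ * For example (an illustrative sketch with hypothetical node IDs, not drawn from a real module): with start node S, an outer fork F1, and a fork F2 nested inside F1, fork_tree maps S -> {F1} and F1 -> {F2}, while fork_control_map assigns every control node to its innermost enclosing fork (or to S when it sits outside any fork).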
*/ pub fn fork_tree( function: &Function, @@ -44,5 +50,6 @@ pub fn fork_tree( .insert(*control); } } + fork_tree.entry(NodeID::new(0)).or_default(); fork_tree } diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 35334a147b310dabfeeb7f9f7560d361fbaefad0..2c5f7c351f28e75c0fff2483742e76c3660bcacb 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::fmt::{Error, Write}; use std::iter::zip; @@ -76,8 +76,13 @@ use crate::*; pub fn rt_codegen<W: Write>( func_id: FunctionID, module: &Module, + def_use: &ImmutableDefUseMap, typing: &Vec<TypeID>, control_subgraph: &Subgraph, + fork_join_map: &HashMap<NodeID, NodeID>, + fork_control_map: &HashMap<NodeID, HashSet<NodeID>>, + fork_tree: &HashMap<NodeID, HashSet<NodeID>>, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, collection_objects: &CollectionObjects, callgraph: &CallGraph, devices: &Vec<Device>, @@ -86,11 +91,21 @@ pub fn rt_codegen<W: Write>( backing_allocations: &BackingAllocations, w: &mut W, ) -> Result<(), Error> { + let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map + .into_iter() + .map(|(fork, join)| (*join, *fork)) + .collect(); let ctx = RTContext { func_id, module, + def_use, typing, control_subgraph, + fork_join_map, + join_fork_map: &join_fork_map, + fork_control_map, + fork_tree, + nodes_in_fork_joins, collection_objects, callgraph, devices, @@ -104,8 +119,14 @@ pub fn rt_codegen<W: Write>( struct RTContext<'a> { func_id: FunctionID, module: &'a Module, + def_use: &'a ImmutableDefUseMap, typing: &'a Vec<TypeID>, control_subgraph: &'a Subgraph, + fork_join_map: &'a HashMap<NodeID, NodeID>, + join_fork_map: &'a HashMap<NodeID, NodeID>, + fork_control_map: &'a HashMap<NodeID, HashSet<NodeID>>, + fork_tree: &'a HashMap<NodeID, HashSet<NodeID>>, + nodes_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>, collection_objects: &'a CollectionObjects, callgraph: &'a CallGraph, devices: &'a Vec<Device>, @@ -114,6 +135,13 @@ struct RTContext<'a> { backing_allocations: &'a BackingAllocations, } +#[derive(Debug, Clone, Default)] +struct RustBlock { + prologue: String, + data: String, + epilogue: String, +} + impl<'a> RTContext<'a> { fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> { // If this is an entry function, generate a corresponding runner object @@ -126,7 +154,7 @@ impl<'a> RTContext<'a> { // Dump the function signature. write!( w, - "#[allow(unused_assignments,unused_variables,unused_mut,unused_parens,unused_unsafe,non_snake_case)]\nasync unsafe fn {}(", + "#[allow(unused_assignments,unused_variables,unused_mut,unused_parens,unused_unsafe,non_snake_case)]async unsafe fn {}(", func.name )?; let mut first_param = true; @@ -137,7 +165,11 @@ impl<'a> RTContext<'a> { } else { write!(w, ", ")?; } - write!(w, "backing_{}: *mut u8", device.name())?; + write!( + w, + "backing_{}: ::hercules_rt::__RawPtrSendSync", + device.name() + )?; } // The second set of parameters are dynamic constants. for idx in 0..func.num_dynamic_constants { @@ -157,20 +189,20 @@ impl<'a> RTContext<'a> { } write!(w, "p{}: {}", idx, self.get_type(func.param_types[idx]))?; } - write!(w, ") -> {} {{\n", self.get_type(func.return_type))?; + write!(w, ") -> {} {{", self.get_type(func.return_type))?; // Dump signatures for called device functions. 
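+        // As an illustrative sketch (the callee name, parameter names, and signature are hypothetical, not actual emitted output): a CPU callee `foo` with backing memory, one dynamic constant, and one f32 parameter returning f32 would be declared here roughly as: extern "C" { fn foo(backing: ::hercules_rt::__RawPtrSendSync, dc_p0: u64, p0: f32) -> f32; }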
- write!(w, " extern \"C\" {{\n")?; + write!(w, "extern \"C\" {{")?; for callee_id in self.callgraph.get_callees(self.func_id) { if self.devices[callee_id.idx()] == Device::AsyncRust { continue; } let callee = &self.module.functions[callee_id.idx()]; - write!(w, " fn {}(", callee.name)?; + write!(w, "fn {}(", callee.name)?; let mut first_param = true; if self.backing_allocations[&callee_id].contains_key(&self.devices[callee_id.idx()]) { first_param = false; - write!(w, "backing: *mut u8")?; + write!(w, "backing: ::hercules_rt::__RawPtrSendSync")?; } for idx in 0..callee.num_dynamic_constants { if first_param { @@ -188,41 +220,18 @@ impl<'a> RTContext<'a> { } write!(w, "p{}: {}", idx, self.get_type(*ty))?; } - write!(w, ") -> {};\n", self.get_type(callee.return_type))?; + write!(w, ") -> {};", self.get_type(callee.return_type))?; } - write!(w, " }}\n")?; + write!(w, "}}")?; - // Declare intermediary variables for every value. - for idx in 0..func.nodes.len() { - if func.nodes[idx].is_control() { - continue; - } - write!( - w, - " let mut node_{}: {} = {};\n", - idx, - self.get_type(self.typing[idx]), - if self.module.types[self.typing[idx].idx()].is_integer() { - "0" - } else if self.module.types[self.typing[idx].idx()].is_float() { - "0.0" - } else { - "::core::ptr::null_mut()" - } - )?; - } - - // The core executor is a Rust loop. We literally run a "control token" - // as described in the original sea of nodes paper through the basic - // blocks to drive execution. - write!( - w, - " let mut control_token: i8 = 0;\n loop {{\n match control_token {{\n", - )?; + // Set up the root environment for the function. An environment is set + // up for every created task in async closures, and there needs to be a + // root environment corresponding to the root control node (start node). + self.codegen_open_environment(NodeID::new(0), w)?; let mut blocks: BTreeMap<_, _> = (0..func.nodes.len()) .filter(|idx| func.nodes[*idx].is_control()) - .map(|idx| (NodeID::new(idx), String::new())) + .map(|idx| (NodeID::new(idx), RustBlock::default())) .collect(); // Emit data flow into basic blocks. @@ -233,26 +242,21 @@ impl<'a> RTContext<'a> { } // Emit control flow into basic blocks. - for id in (0..func.nodes.len()).map(NodeID::new) { - if !func.nodes[id.idx()].is_control() { - continue; - } - self.codegen_control_node(id, &mut blocks)?; + let rev_po = self.control_subgraph.rev_po(NodeID::new(0)); + for id in rev_po.iter() { + self.codegen_control_node(*id, &mut blocks)?; } - // Dump the emitted basic blocks. - for (id, block) in blocks { - write!( - w, - " {} => {{\n{} }}\n", - id.idx(), - block - )?; + // Dump the emitted basic blocks. Do this in reverse postorder since + // fork and join nodes open and close environments, respectively. + for id in rev_po.iter() { + let block = &blocks[id]; + write!(w, "{}{}{}", block.prologue, block.data, block.epilogue)?; } - // Close the match and loop. - write!(w, " _ => panic!()\n }}\n }}\n")?; - write!(w, "}}\n")?; + // Close the root environment. 
+ self.codegen_close_environment(w)?; + write!(w, "}}")?; Ok(()) } @@ -265,7 +269,7 @@ impl<'a> RTContext<'a> { fn codegen_control_node( &self, id: NodeID, - blocks: &mut BTreeMap<NodeID, String>, + blocks: &mut BTreeMap<NodeID, RustBlock>, ) -> Result<(), Error> { let func = &self.get_func(); match func.nodes[id.idx()] { @@ -277,29 +281,133 @@ impl<'a> RTContext<'a> { control: _, selection: _, } => { - let block = &mut blocks.get_mut(&id).unwrap(); + let prologue = &mut blocks.get_mut(&id).unwrap().prologue; + write!(prologue, "{} => {{", id.idx())?; + let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue; let succ = self.control_subgraph.succs(id).next().unwrap(); - write!(block, " control_token = {};\n", succ.idx())? + write!(epilogue, "control_token = {};}}", succ.idx())?; } // If nodes have two successors - examine the projections to // determine which branch is which, and branch between them. Node::If { control: _, cond } => { - let block = &mut blocks.get_mut(&id).unwrap(); + let prologue = &mut blocks.get_mut(&id).unwrap().prologue; + write!(prologue, "{} => {{", id.idx())?; + let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue; let mut succs = self.control_subgraph.succs(id); let succ1 = succs.next().unwrap(); let succ2 = succs.next().unwrap(); let succ1_is_true = func.nodes[succ1.idx()].try_projection(1).is_some(); write!( - block, - " control_token = if {} {{ {} }} else {{ {} }};\n", - self.get_value(cond), + epilogue, + "control_token = if {} {{{}}} else {{{}}};}}", + self.get_value(cond, id), if succ1_is_true { succ1 } else { succ2 }.idx(), if succ1_is_true { succ2 } else { succ1 }.idx(), - )? + )?; } Node::Return { control: _, data } => { - let block = &mut blocks.get_mut(&id).unwrap(); - write!(block, " return {};\n", self.get_value(data))? + let prologue = &mut blocks.get_mut(&id).unwrap().prologue; + write!(prologue, "{} => {{", id.idx())?; + let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue; + write!(epilogue, "return {};}}", self.get_value(data, id))?; + } + // Fork nodes open a new environment for defining an async closure. + Node::Fork { + control: _, + ref factors, + } => { + assert!(func.schedules[id.idx()].contains(&Schedule::ParallelFork)); + + // Set the outer environment control token to the join. + let prologue = &mut blocks.get_mut(&id).unwrap().prologue; + let join = self.fork_join_map[&id]; + write!( + prologue, + "{} => {{control_token = {};", + id.idx(), + join.idx() + )?; + + // Emit loops for the thread IDs. + for (idx, factor) in factors.into_iter().enumerate() { + write!(prologue, "for tid_{}_{} in 0..", id.idx(), idx)?; + self.codegen_dynamic_constant(*factor, prologue)?; + write!(prologue, " {{")?; + } + + // Spawn an async closure and push its future to a Vec. + write!( + prologue, + "fork_{}.push(::async_std::task::spawn(async move {{", + id.idx() + )?; + + // Open a new environment. + self.codegen_open_environment(id, prologue)?; + + // Open the branch inside the async closure for the fork. + let succ = self.control_subgraph.succs(id).next().unwrap(); + write!( + prologue, + "{} => {{control_token = {};", + id.idx(), + succ.idx() + )?; + + // Close the branch for the fork inside the async closure. + let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue; + write!(epilogue, "}}")?; + } + // Join nodes close the environment opened by its corresponding + // fork. + Node::Join { control: _ } => { + // Emit the branch for the join inside the async closure. 
+ let prologue = &mut blocks.get_mut(&id).unwrap().prologue; + write!(prologue, "{} => {{", id.idx())?; + + // Close the branch inside the async closure. + let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue; + write!(epilogue, "return;}}")?; + + // Close the fork's environment. + self.codegen_close_environment(epilogue)?; + + // Close the async closure and push statement from the + // fork. + write!(epilogue, "}}));")?; + + // Close the loops emitted by the fork node. + let fork = self.join_fork_map[&id]; + for _ in 0..func.nodes[fork.idx()].try_fork().unwrap().1.len() { + write!(epilogue, "}}")?; + } + + // Close the branch for the fork outside the async closure. + write!(epilogue, "}}")?; + + // Open the branch in the surrounding context for the join. + let succ = self.control_subgraph.succs(id).next().unwrap(); + write!(epilogue, "{} => {{", id.idx())?; + + // Emit the assignments to the reduce variables in the + // surrounding context. It's very unfortunate that we have to do + // it while lowering the join node (rather than the reduce nodes + // themselves), but this is the only place we can put these + // assignments in the correct control location. + for user in self.def_use.get_users(id) { + if let Some((_, init, _)) = func.nodes[user.idx()].try_reduce() { + write!( + epilogue, + "{} = {};", + self.get_value(*user, id), + self.get_value(init, id) + )?; + } + } + + // Branch to the successor control node in the surrounding + // context, and close the branch for the join. + write!(epilogue, "control_token = {};}}", succ.idx())?; } _ => panic!("PANIC: Can't lower {:?}.", func.nodes[id.idx()]), } @@ -312,22 +420,18 @@ impl<'a> RTContext<'a> { fn codegen_data_node( &self, id: NodeID, - blocks: &mut BTreeMap<NodeID, String>, + blocks: &mut BTreeMap<NodeID, RustBlock>, ) -> Result<(), Error> { let func = &self.get_func(); + let bb = self.bbs.0[id.idx()]; match func.nodes[id.idx()] { Node::Parameter { index } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); - write!( - block, - " {} = p{};\n", - self.get_value(id), - index - )? + let block = &mut blocks.get_mut(&bb).unwrap().data; + write!(block, "{} = p{};", self.get_value(id, bb), index)? 
} Node::Constant { id: cons_id } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); - write!(block, " {} = ", self.get_value(id))?; + let block = &mut blocks.get_mut(&bb).unwrap().data; + write!(block, "{} = ", self.get_value(id, bb))?; let mut size_and_device = None; match self.module.constants[cons_id.idx()] { Constant::Boolean(val) => write!(block, "{}bool", val)?, @@ -357,19 +461,36 @@ impl<'a> RTContext<'a> { size_and_device = Some((self.codegen_type_size(ty), device)); } } - write!(block, ";\n")?; + write!(block, ";")?; if !func.schedules[id.idx()].contains(&Schedule::NoResetConstant) { if let Some((size, device)) = size_and_device { write!( block, - " ::hercules_rt::__{}_zero_mem({}, {} as usize);\n", + "::hercules_rt::__{}_zero_mem({}.0, {} as usize);", device.name(), - self.get_value(id), + self.get_value(id, bb), size )?; } } } + Node::ThreadID { control, dimension } => { + let block = &mut blocks.get_mut(&bb).unwrap().data; + write!( + block, + "{} = tid_{}_{};", + self.get_value(id, bb), + control.idx(), + dimension + )?; + } + Node::Reduce { + control: _, + init: _, + reduct: _, + } => { + assert!(func.schedules[id.idx()].contains(&Schedule::ParallelReduce)); + } Node::Call { control: _, function: callee_id, @@ -378,11 +499,11 @@ impl<'a> RTContext<'a> { } => { // The device backends ensure that device functions have the // same interface as AsyncRust functions. - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + let block = &mut blocks.get_mut(&bb).unwrap().data; write!( block, - " {} = {}(", - self.get_value(id), + "{} = {}(", + self.get_value(id, bb), self.module.functions[callee_id.idx()].name )?; for (device, offset) in self.backing_allocations[&self.func_id] @@ -398,41 +519,41 @@ impl<'a> RTContext<'a> { write!(block, ", ")?; } for arg in args { - write!(block, "{}, ", self.get_value(*arg))?; + write!(block, "{}, ", self.get_value(*arg, bb))?; } let device = self.devices[callee_id.idx()]; if device == Device::AsyncRust { - write!(block, ").await;\n")?; + write!(block, ").await;")?; } else { - write!(block, ");\n")?; + write!(block, ");")?; } } Node::Unary { op, input } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + let block = &mut blocks.get_mut(&bb).unwrap().data; match op { UnaryOperator::Not => write!( block, - " {} = !{};\n", - self.get_value(id), - self.get_value(input) + "{} = !{};", + self.get_value(id, bb), + self.get_value(input, bb) )?, UnaryOperator::Neg => write!( block, - " {} = -{};\n", - self.get_value(id), - self.get_value(input) + "{} = -{};", + self.get_value(id, bb), + self.get_value(input, bb) )?, UnaryOperator::Cast(ty) => write!( block, - " {} = {} as {};\n", - self.get_value(id), - self.get_value(input), + "{} = {} as {};", + self.get_value(id, bb), + self.get_value(input, bb), self.get_type(ty) )?, }; } Node::Binary { op, left, right } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + let block = &mut blocks.get_mut(&bb).unwrap().data; let op = match op { BinaryOperator::Add => "+", BinaryOperator::Sub => "-", @@ -454,11 +575,11 @@ impl<'a> RTContext<'a> { write!( block, - " {} = {} {} {};\n", - self.get_value(id), - self.get_value(left), + "{} = {} {} {};", + self.get_value(id, bb), + self.get_value(left, bb), op, - self.get_value(right) + self.get_value(right, bb) )?; } Node::Ternary { @@ -467,15 +588,15 @@ impl<'a> RTContext<'a> { second, third, } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + let block = &mut 
blocks.get_mut(&bb).unwrap().data; match op { TernaryOperator::Select => write!( block, - " {} = if {} {{ {} }} else {{ {} }};\n", - self.get_value(id), - self.get_value(first), - self.get_value(second), - self.get_value(third), + "{} = if {} {{{}}} else {{{}}};", + self.get_value(id, bb), + self.get_value(first, bb), + self.get_value(second, bb), + self.get_value(third, bb), )?, }; } @@ -483,10 +604,10 @@ impl<'a> RTContext<'a> { collect, ref indices, } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + let block = &mut blocks.get_mut(&bb).unwrap().data; let collect_ty = self.typing[collect.idx()]; let out_size = self.codegen_type_size(self.typing[id.idx()]); - let offset = self.codegen_index_math(collect_ty, indices)?; + let offset = self.codegen_index_math(collect_ty, indices, bb)?; todo!(); } Node::Write { @@ -494,10 +615,10 @@ impl<'a> RTContext<'a> { data, ref indices, } => { - let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + let block = &mut blocks.get_mut(&bb).unwrap().data; let collect_ty = self.typing[collect.idx()]; let data_size = self.codegen_type_size(self.typing[data.idx()]); - let offset = self.codegen_index_math(collect_ty, indices)?; + let offset = self.codegen_index_math(collect_ty, indices, bb)?; todo!(); } _ => panic!( @@ -612,6 +733,7 @@ impl<'a> RTContext<'a> { &self, mut collect_ty: TypeID, indices: &[Index], + bb: NodeID, ) -> Result<String, Error> { let mut acc_offset = "0".to_string(); for index in indices { @@ -662,7 +784,7 @@ impl<'a> RTContext<'a> { // ((0 * s1 + p1) * s2 + p2) * s3 + p3 ... let elem_size = self.codegen_type_size(elem); for (p, s) in zip(pos, dims) { - let p = self.get_value(*p); + let p = self.get_value(*p, bb); acc_offset = format!("{} * ", acc_offset); self.codegen_dynamic_constant(*s, &mut acc_offset)?; acc_offset = format!("({} + {})", acc_offset, p); @@ -749,8 +871,70 @@ impl<'a> RTContext<'a> { } } + fn codegen_open_environment<W: Write>(&self, root: NodeID, w: &mut W) -> Result<(), Error> { + let func = &self.get_func(); + + // Declare intermediary variables for every value in this fork-join (or + // whole function for start) that isn't in any child fork-joins. + for idx in 0..func.nodes.len() { + let id = NodeID::new(idx); + let control = func.nodes[idx].is_control(); + let in_root = self.nodes_in_fork_joins[&root].contains(&id); + let in_child = self.fork_tree[&root] + .iter() + .any(|child| self.nodes_in_fork_joins[&child].contains(&id)); + let is_reduce_on_child = func.nodes[idx] + .try_reduce() + .map(|(control, _, _)| { + self.fork_tree[&root].contains(&self.join_fork_map[&control]) + }) + .unwrap_or(false); + if (control || !in_root || in_child) && !is_reduce_on_child { + continue; + } + + write!( + w, + "let mut {}_{}: {} = {};", + if is_reduce_on_child { "reduce" } else { "node" }, + idx, + self.get_type(self.typing[idx]), + if self.module.types[self.typing[idx].idx()].is_integer() { + "0" + } else if self.module.types[self.typing[idx].idx()].is_float() { + "0.0" + } else { + "::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())" + } + )?; + } + + // Declare Vec for storing futures of fork-joins. + for fork in self.fork_tree[&root].iter() { + write!(w, "let mut fork_{} = vec![];", fork.idx())?; + } + + // The core executor is a Rust loop. We literally run a "control token" + // as described in the original sea of nodes paper through the basic + // blocks to drive execution. 
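+        // As an illustrative sketch (hypothetical node IDs, not actual emitted output), an environment rooted at node 0 with a single successor block 1 that returns node_2 expands to roughly: let mut control_token: i8 = 0; loop { match control_token { 0 => { control_token = 1; } 1 => { return node_2; } _ => panic!() } } — with the variable and fork-future declarations above emitted just before the loop.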
+ write!( + w, + "let mut control_token: i8 = {};loop {{match control_token {{", + root.idx(), + )?; + + Ok(()) + } + + fn codegen_close_environment<W: Write>(&self, w: &mut W) -> Result<(), Error> { + // Close the match and loop. + write!(w, "_ => panic!()}}}}") + } + /* - * Generate a runner object for this function. + * Generate a runner object for this function. The runner object stores + * backing memory for a Hercules function and wraps calls to the Hercules + * function. */ fn codegen_runner_object<W: Write>(&self, w: &mut W) -> Result<(), Error> { // Figure out the devices for the parameters and the return value if @@ -795,29 +979,29 @@ impl<'a> RTContext<'a> { // Emit the type definition. A runner object owns its backing memory. write!( w, - "#[allow(non_camel_case_types)]\nstruct HerculesRunner_{} {{\n", + "#[allow(non_camel_case_types)]struct HerculesRunner_{} {{", func.name )?; for (device, _) in self.backing_allocations[&self.func_id].iter() { - write!(w, " backing_ptr_{}: *mut u8,\n", device.name(),)?; - write!(w, " backing_size_{}: usize,\n", device.name(),)?; + write!(w, "backing_ptr_{}: *mut u8,", device.name(),)?; + write!(w, "backing_size_{}: usize,", device.name(),)?; } - write!(w, "}}\n")?; + write!(w, "}}")?; write!( w, - "impl HerculesRunner_{} {{\n fn new() -> Self {{\n Self {{\n", + "impl HerculesRunner_{} {{fn new() -> Self {{Self {{", func.name )?; for (device, _) in self.backing_allocations[&self.func_id].iter() { write!( w, - " backing_ptr_{}: ::core::ptr::null_mut(),\n backing_size_{}: 0,\n", + "backing_ptr_{}: ::core::ptr::null_mut(),backing_size_{}: 0,", device.name(), device.name() )?; } - write!(w, " }}\n }}\n")?; - write!(w, " async fn run<'a>(&'a mut self")?; + write!(w, "}}}}")?; + write!(w, "async fn run<'a>(&'a mut self")?; for idx in 0..func.num_dynamic_constants { write!(w, ", dc_p{}: u64", idx)?; } @@ -842,7 +1026,7 @@ impl<'a> RTContext<'a> { } } if self.module.types[func.return_type.idx()].is_primitive() { - write!(w, ") -> {} {{\n", self.get_type(func.return_type))?; + write!(w, ") -> {} {{", self.get_type(func.return_type))?; } else { let device = match return_device { Some(Device::LLVM) => "CPU", @@ -852,36 +1036,52 @@ impl<'a> RTContext<'a> { let mutability = if return_mut { "Mut" } else { "" }; write!( w, - ") -> ::hercules_rt::Hercules{}Ref{}<'a> {{\n", + ") -> ::hercules_rt::Hercules{}Ref{}<'a> {{", device, mutability )?; } - write!(w, " unsafe {{\n")?; + write!(w, "unsafe {{")?; for (device, (total, _)) in self.backing_allocations[&self.func_id].iter() { - write!(w, " let size = ")?; + write!(w, "let size = ")?; self.codegen_dynamic_constant(*total, w)?; write!( w, - " as usize;\n if self.backing_size_{} < size {{\n", + " as usize;if self.backing_size_{} < size {{", + device.name() + )?; + write!( + w, + "::hercules_rt::__{}_dealloc(self.backing_ptr_{}, self.backing_size_{});", + device.name(), + device.name(), device.name() )?; - write!(w, " ::hercules_rt::__{}_dealloc(self.backing_ptr_{}, self.backing_size_{});\n", device.name(), device.name(), device.name())?; + write!(w, "self.backing_size_{} = size;", device.name())?; write!( w, - " self.backing_size_{} = size;\n", + "self.backing_ptr_{} = ::hercules_rt::__{}_alloc(self.backing_size_{});", + device.name(), + device.name(), device.name() )?; - write!(w, " self.backing_ptr_{} = ::hercules_rt::__{}_alloc(self.backing_size_{});\n", device.name(), device.name(), device.name())?; - write!(w, " }}\n")?; + write!(w, "}}")?; } for idx in 0..func.param_types.len() { if 
!self.module.types[func.param_types[idx].idx()].is_primitive() { - write!(w, " let p{} = p{}.__ptr();\n", idx, idx)?; + write!( + w, + "let p{} = ::hercules_rt::__RawPtrSendSync(p{}.__ptr());", + idx, idx + )?; } } - write!(w, " let ret = {}(", func.name)?; + write!(w, "let ret = {}(", func.name)?; for (device, _) in self.backing_allocations[&self.func_id].iter() { - write!(w, "self.backing_ptr_{}, ", device.name())?; + write!( + w, + "::hercules_rt::__RawPtrSendSync(self.backing_ptr_{}), ", + device.name() + )?; } for idx in 0..func.num_dynamic_constants { write!(w, "dc_p{}, ", idx)?; @@ -889,9 +1089,9 @@ impl<'a> RTContext<'a> { for idx in 0..func.param_types.len() { write!(w, "p{}, ", idx)?; } - write!(w, ").await;\n")?; + write!(w, ").await;")?; if self.module.types[func.return_type.idx()].is_primitive() { - write!(w, " ret\n")?; + write!(w, " ret")?; } else { let device = match return_device { Some(Device::LLVM) => "CPU", @@ -901,28 +1101,28 @@ impl<'a> RTContext<'a> { let mutability = if return_mut { "Mut" } else { "" }; write!( w, - " ::hercules_rt::Hercules{}Ref{}::__from_parts(ret, {} as usize)\n", + "::hercules_rt::Hercules{}Ref{}::__from_parts(ret.0, {} as usize)", device, mutability, self.codegen_type_size(func.return_type) )?; } - write!(w, " }}\n }}\n")?; + write!(w, "}}}}")?; write!( w, - "}}\nimpl Drop for HerculesRunner_{} {{\n #[allow(unused_unsafe)]\n fn drop(&mut self) {{\n unsafe {{\n", + "}}impl Drop for HerculesRunner_{} {{#[allow(unused_unsafe)]fn drop(&mut self) {{unsafe {{", func.name )?; for (device, _) in self.backing_allocations[&self.func_id].iter() { write!( w, - " ::hercules_rt::__{}_dealloc(self.backing_ptr_{}, self.backing_size_{});\n", + "::hercules_rt::__{}_dealloc(self.backing_ptr_{}, self.backing_size_{});", device.name(), device.name(), device.name() )?; } - write!(w, " }}\n }}\n}}\n")?; + write!(w, "}}}}}}")?; Ok(()) } @@ -930,8 +1130,26 @@ impl<'a> RTContext<'a> { &self.module.functions[self.func_id.idx()] } - fn get_value(&self, id: NodeID) -> String { - format!("node_{}", id.idx()) + fn get_value(&self, id: NodeID, bb: NodeID) -> String { + let func = self.get_func(); + if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce() + && control == bb + { + format!("reduce_{}", id.idx()) + } else if let Some((control, _, _)) = func.nodes[id.idx()].try_reduce() + && let fork = self.join_fork_map[&control] + && !self.nodes_in_fork_joins[&fork].contains(&bb) + { + // Before using the value of a reduction outside the fork-join, + // await the futures. + format!( + "{{for fut in fork_{}.drain(..) 
{{ fut.await; }}; reduce_{}}}", + fork.idx(), + id.idx() + ) + } else { + format!("node_{}", id.idx()) + } } fn get_type(&self, id: TypeID) -> &'static str { @@ -952,7 +1170,9 @@ fn convert_type(ty: &Type) -> &'static str { Type::UnsignedInteger64 => "u64", Type::Float32 => "f32", Type::Float64 => "f64", - Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => "*mut u8", + Type::Product(_) | Type::Summation(_) | Type::Array(_, _) => { + "::hercules_rt::__RawPtrSendSync" + } _ => panic!(), } } diff --git a/hercules_ir/src/fork_join_analysis.rs b/hercules_ir/src/fork_join_analysis.rs index ad3125ba6deaf17a4d53befa11c966b154030ab5..3fcc6af029a50839c0b382be371db6fa593e1119 100644 --- a/hercules_ir/src/fork_join_analysis.rs +++ b/hercules_ir/src/fork_join_analysis.rs @@ -67,6 +67,8 @@ pub fn compute_fork_join_nesting( .filter(|id| function.nodes[id.idx()].is_fork()) // where its corresponding join doesn't dominate the control // node (if so, then this control is after the fork-join). + // Check for strict dominance since the join itself should + // be nested in the fork. .filter(|fork_id| !dom.does_prop_dom(fork_join_map[&fork_id], id)) .collect(), ) @@ -174,6 +176,7 @@ pub fn nodes_in_fork_joins( ) -> HashMap<NodeID, HashSet<NodeID>> { let mut result = HashMap::new(); + // Iterate users of fork until reaching corresponding join or reduces. for (fork, join) in fork_join_map { let mut worklist = vec![*fork]; let mut set = HashSet::new(); @@ -196,5 +199,11 @@ pub fn nodes_in_fork_joins( result.insert(*fork, set); } + // Add an entry for the start node containing every node. + result.insert( + NodeID::new(0), + (0..function.nodes.len()).map(NodeID::new).collect(), + ); + result } diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs index 2ad720434c9ca514af657c90bdd132493fb67c1e..12b64fa3ce7c5f73c2c7470dc546d9913203aa3f 100644 --- a/hercules_rt/src/lib.rs +++ b/hercules_rt/src/lib.rs @@ -354,3 +354,16 @@ macro_rules! 
runner { <concat_idents!(HerculesRunner_, $x)>::new() }; } + +#[derive(Debug, Clone, Copy)] +#[repr(transparent)] +pub struct __RawPtrSendSync(pub *mut u8); + +impl __RawPtrSendSync { + pub unsafe fn byte_add(self, add: usize) -> Self { + __RawPtrSendSync(self.0.byte_add(add)) + } +} + +unsafe impl Send for __RawPtrSendSync {} +unsafe impl Sync for __RawPtrSendSync {} diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index 380100044004a98dcc4ced7b26f3aac3976eb9c7..5665e1faef4f55f49a26545be73e0a8ba81f89a8 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -44,12 +44,11 @@ dce(*); fork-tile[32, 0, true](test6@loop); let out = fork-split(test6@loop); -//let out = outline(out.test6.fj1); -let out = auto-outline(test6); -cpu(out.test6); +let out = outline(out.test6.fj1); +cpu(out); ip-sroa(*); sroa(*); -unforkify(out.test6); +unforkify(out); dce(*); ccp(*); gvn(*); diff --git a/juno_scheduler/Cargo.toml b/juno_scheduler/Cargo.toml index 26055b03fc6fd94fe6440f88a95faf1e2371b404..03a18c8350ebc6a86ece1cc3e65681faaddae4e5 100644 --- a/juno_scheduler/Cargo.toml +++ b/juno_scheduler/Cargo.toml @@ -17,6 +17,8 @@ cfgrammar = "0.13" lrlex = "0.13" lrpar = "0.13" tempfile = "*" +prettyplease = "0.2.29" +syn = { version = "2.0.96", features = ["full"] } hercules_cg = { path = "../hercules_cg" } hercules_ir = { path = "../hercules_ir" } hercules_opt = { path = "../hercules_opt" } diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 9478eb9b55d69ed9c5294e60534024c23889885d..901361c6a438a90abde4ebccf8e85e2cda4bc564 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -647,14 +647,15 @@ impl PassManager { } fn codegen(mut self, output_dir: String, module_name: String) -> Result<(), SchedulerError> { + self.make_def_uses(); self.make_typing(); self.make_control_subgraphs(); - self.make_collection_objects(); - self.make_callgraph(); - self.make_def_uses(); self.make_fork_join_maps(); self.make_fork_control_maps(); self.make_fork_trees(); + self.make_nodes_in_fork_joins(); + self.make_collection_objects(); + self.make_callgraph(); self.make_devices(); let PassManager { @@ -663,14 +664,15 @@ impl PassManager { constants, dynamic_constants, labels, + def_uses: Some(def_uses), typing: Some(typing), control_subgraphs: Some(control_subgraphs), - collection_objects: Some(collection_objects), - callgraph: Some(callgraph), - def_uses: Some(def_uses), fork_join_maps: Some(fork_join_maps), fork_control_maps: Some(fork_control_maps), fork_trees: Some(fork_trees), + nodes_in_fork_joins: Some(nodes_in_fork_joins), + collection_objects: Some(collection_objects), + callgraph: Some(callgraph), devices: Some(devices), bbs: Some(bbs), node_colors: Some(node_colors), @@ -735,8 +737,13 @@ impl PassManager { Device::AsyncRust => rt_codegen( FunctionID::new(idx), &module, + &def_uses[idx], &typing[idx], &control_subgraphs[idx], + &fork_join_maps[idx], + &fork_control_maps[idx], + &fork_trees[idx], + &nodes_in_fork_joins[idx], &collection_objects, &callgraph, &devices, @@ -753,8 +760,20 @@ impl PassManager { } println!("{}", llvm_ir); println!("{}", cuda_ir); + let rust_rt = prettyplease::unparse( + &syn::parse_file(&rust_rt) + .expect(&format!("PANIC: Malformed RT Rust code: {}", rust_rt)), + ); println!("{}", rust_rt); + // Write the Rust runtime into a file. 
+ let output_rt = format!("{}/rt_{}.hrt", output_dir, module_name); + println!("{}", output_rt); + let mut file = + File::create(&output_rt).expect("PANIC: Unable to open output Rust runtime file."); + file.write_all(rust_rt.as_bytes()) + .expect("PANIC: Unable to write output Rust runtime file contents."); + let output_archive = format!("{}/lib{}.a", output_dir, module_name); println!("{}", output_archive); @@ -844,14 +863,6 @@ impl PassManager { ); } - // Write the Rust runtime into a file. - let output_rt = format!("{}/rt_{}.hrt", output_dir, module_name); - println!("{}", output_rt); - let mut file = - File::create(&output_rt).expect("PANIC: Unable to open output Rust runtime file."); - file.write_all(rust_rt.as_bytes()) - .expect("PANIC: Unable to write output Rust runtime file contents."); - Ok(()) } }
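For context, a minimal usage sketch of the generated runner objects (the function name `foo`, its single dynamic constant, and its f32 parameter/return types are hypothetical; only the `runner!`/`run` shape comes from the codegen above):

use hercules_rt::runner;

async fn call_foo() -> f32 {
    // runner!(foo) expands to HerculesRunner_foo::new(), which owns and lazily
    // grows the backing memory for every device the function uses.
    let mut r = runner!(foo);
    // Dynamic constants are passed first, then the ordinary parameters; the
    // call is awaited because the generated entry point is an async fn.
    r.run(16, 2.0).await
}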