diff --git a/.gitignore b/.gitignore index feb7ba10a50766c241d492635590ba42b5e56b3b..16e4eda72be12bf1449508941f676c9453459632 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ *.ll *.c *.o -*.hbin +*.a +*.hman .*.swp -.vscode \ No newline at end of file +.vscode diff --git a/Cargo.lock b/Cargo.lock index 6dfa8e9cf56fce92035040654966de942a727d98..768b6a6fb6f3fac6f5fa3133aa9a543d6e9dd93d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -676,6 +676,7 @@ version = "0.1.0" dependencies = [ "anyhow", "hercules_cg", + "hercules_ir", "postcard", "serde", ] diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs index 716613efe6ccd0b5f6ce93f35c0515d56d7cecfd..915974bc350f0ac62c86b55ab42d2345bb74e97a 100644 --- a/hercules_cg/src/cpu.rs +++ b/hercules_cg/src/cpu.rs @@ -1,6 +1,7 @@ extern crate bitvec; -use std::collections::{HashMap, VecDeque}; +use std::cell::{Cell, RefCell}; +use std::collections::{HashMap, HashSet, VecDeque}; use std::fmt::{Error, Write}; use std::iter::once; @@ -63,7 +64,8 @@ pub fn cpu_compile<W: Write>( w: &mut W, ) -> Result<(), Error> { // Calculate basic analyses over schedule IR. - let dep_graph = sched_dependence_graph(function); + let virt_reg_to_inst_id = sched_virt_reg_to_inst_id(function); + let dep_graph = sched_dependence_graph(function, &virt_reg_to_inst_id); let svalue_types = sched_svalue_types(function); let parallel_reduce_infos = sched_parallel_reduce_sections(function); @@ -111,7 +113,18 @@ pub fn cpu_compile<W: Write>( true }; - if is_inside { + // Check if the parent is a vectorized fork-join. + let is_parent_vectorized = possible_parent + // Check if the parent fork-join has a vector width. + .map(|parent| parallel_reduce_infos[&parent].vector_width.is_some()) + // Sequential blocks are not vectorized. + .unwrap_or(false); + + // If we are inside the block's fork-join or the block's fork-join + // is vectorized, then refer to the blocks directly. Vectorized + // fork-joins have the same LLVM IR control flow as the schedule IR + // control flow. + if is_inside || is_parent_vectorized { block_names.insert((block_id, fork_join_id), format!("bb_{}", block_idx)); } else { block_names.insert( @@ -122,526 +135,1045 @@ pub fn cpu_compile<W: Write>( } } - // Generate a dummy uninitialized global - this is needed so that there'll - // be a non-empty .bss section in the ELF object file. - write!( - w, - "@dummy_{} = dso_local global i8 0, align 1\n", - manifest.name - )?; - - // Emit the partition function signature. - write!(w, "define ")?; - if function.return_types.len() == 1 { - emit_type(&function.return_types[0], w)?; - } else { - // Functions with multiple return values return said values in a struct. - emit_type(&SType::Product(function.return_types.clone().into()), w)?; - } - write!(w, " @{}(", manifest.name)?; - (0..function.param_types.len()) - .map(|param_idx| Some(SValue::VirtualRegister(param_idx))) - .intersperse(None) - .map(|token| -> Result<(), Error> { - match token { - Some(param_svalue) => { - emit_svalue(¶m_svalue, true, &svalue_types, w)?; + // Create context for emitting LLVM IR. 
+ let ctx = CPUContext { + function, + manifest, + block: Cell::new((0, &function.blocks[0])), + + virt_reg_to_inst_id, + dep_graph, + svalue_types, + parallel_reduce_infos, + + block_names, + + vector_width: Cell::new(None), + outside_def_used_in_vector: RefCell::new(HashSet::new()), + vectors_from_parallel: RefCell::new(HashSet::new()), + vector_reduce_associative_vars: RefCell::new(HashSet::new()), + vector_reduce_cycle: Cell::new(false), + }; + ctx.emit_function(w)?; + + Ok(()) +} + +/* + * Top level structure to hold analysis info and cell-ed state. + */ +struct CPUContext<'a> { + function: &'a SFunction, + manifest: &'a PartitionManifest, + block: Cell<(usize, &'a SBlock)>, + + // Basic analyses over schedule IR. + virt_reg_to_inst_id: HashMap<usize, InstID>, + dep_graph: HashMap<InstID, Vec<InstID>>, + svalue_types: HashMap<SValue, SType>, + parallel_reduce_infos: HashMap<ForkJoinID, ParallelReduceInfo>, + + // Calculate the names of each block up front. For blocks that are the top + // or bottom blocks of sequential fork-joins, references outside the fork- + // join actually need to refer to the header block. This is a bit + // complicated to handle, and we use these names in several places, so pre- + // calculate the block names. Intuitively, if we are "inside" a sequential + // fork-join, references to the top or bottom blocks actually refer to those + // blocks, while if we are "outside" the sequential fork-join, references to + // both the top or bottom blocks actually refer to the loop header block. + // Fully vectorized fork-joins are not considered sequential. + block_names: HashMap<(BlockID, Option<ForkJoinID>), String>, + + // Track whether we are currently in a vectorized parallel section - this + // affects how we lower types, for example. + vector_width: Cell<Option<usize>>, + // Track which virtual registers are defined outside the vectorized parallel + // section and used within it. + outside_def_used_in_vector: RefCell<HashSet<usize>>, + // Track which virtual registers are defined in the vectorized parallel + // section and used in the vectorized reduce section. + vectors_from_parallel: RefCell<HashSet<usize>>, + // Track which reduction variables (store their virtual register and + // variable number) are associative in the vectorized reduce section. + vector_reduce_associative_vars: RefCell<HashSet<(usize, usize)>>, + // track whether there are any non-associative reduction variables in a + // vectorized reduce section (which corresponds to whether we need to + // generate explicit control flow or not). + vector_reduce_cycle: Cell<bool>, +} + +impl<'a> CPUContext<'a> { + fn emit_function<W: Write>(&self, w: &mut W) -> Result<(), Error> { + // Emit the partition function signature. + write!(w, "define ")?; + if self.function.return_types.len() == 1 { + self.emit_type(&self.function.return_types[0], w)?; + } else { + // Functions with multiple return values return said values in a + // struct. 
+ self.emit_type( + &SType::Product(self.function.return_types.clone().into()), + w, + )?; + } + write!(w, " @{}(", self.manifest.name)?; + (0..self.function.param_types.len()) + .map(|param_idx| Some(SValue::VirtualRegister(param_idx))) + .intersperse(None) + .map(|token| -> Result<(), Error> { + match token { + Some(param_svalue) => { + self.emit_svalue(¶m_svalue, true, w)?; + } + None => write!(w, ", ")?, } - None => write!(w, ", ")?, + Ok(()) + }) + .collect::<Result<(), Error>>()?; + // Technically, this may fail if for some reason there's a parallel + // launch partition with no parameters. Blame LLVM for being + // unnecessarily strict about commas of all things... + for parallel_launch_dim in 0..self.manifest.device.num_parallel_launch_dims() { + write!(w, ", i64 %parallel_launch_{}_low", parallel_launch_dim)?; + write!(w, ", i64 %parallel_launch_{}_len", parallel_launch_dim)?; + } + write!(w, ") {{\n",)?; + + // Emit the function body. Emit each block, one at a time. + for (block_idx, block) in self.function.blocks.iter().enumerate() { + self.block.set((block_idx, block)); + + // For "tops" of sequential fork-joins, we emit a special top block + // to be the loop header for the fork-join loop. + if let Some(fork_join_id) = block.kind.try_parallel() + && self.parallel_reduce_infos[&fork_join_id] + .top_parallel_block + .idx() + == block_idx + && self.parallel_reduce_infos[&fork_join_id] + .vector_width + .is_none() + { + self.emit_fork_join_seq_header(fork_join_id, block_idx, w)?; } - Ok(()) - }) - .collect::<Result<(), Error>>()?; - write!(w, ") {{\n",)?; - // Emit the function body. Emit each block, one at a time. - for (block_idx, block) in function.blocks.iter().enumerate() { - // For "tops" of sequential fork-joins, we emit a special top block to - // be the loop header for the fork-join loop. - if let Some(fork_join_id) = block.kind.try_parallel() - && parallel_reduce_infos[&fork_join_id] - .top_parallel_block - .idx() - == block_idx - { - emit_fork_join_seq_header( - fork_join_id, - ¶llel_reduce_infos[&fork_join_id], - &svalue_types, - &block_names, - block_idx, + // Emit the header for the block. + write!( w, + "{}:\n", + &self.block_names[&(BlockID::new(block_idx), block.kind.try_fork_join_id())] )?; - } - // Emit the header for the block. - write!( - w, - "{}:\n", - &block_names[&(BlockID::new(block_idx), block.kind.try_fork_join_id())] - )?; + // If this block is in a vectorized parallel section, set up the + // context for vector code generation. + if let Some(fork_join_id) = block.kind.try_parallel() + && let Some(width) = self.parallel_reduce_infos[&fork_join_id].vector_width + { + self.setup_vectorized_parallel_block(width, w)?; + } - // For each basic block, emit instructions in that block. Emit using a - // worklist over the dependency graph. - let mut emitted = bitvec![u8, Lsb0; 0; block.insts.len()]; - let mut worklist = VecDeque::from((0..block.insts.len()).collect::<Vec<_>>()); - while let Some(inst_idx) = worklist.pop_front() { - let inst_id = InstID::new(block_idx, inst_idx); - let dependencies = &dep_graph[&inst_id]; - let all_uses_emitted = dependencies - .into_iter() - // Check that all used instructions in this block... - .filter(|inst_id| inst_id.idx_0() == block_idx) - // were already emitted. - .all(|inst_id| emitted[inst_id.idx_1()]); - // Phis don't need to wait for all of their uses to be emitted. 
- if block.insts[inst_idx].is_phi() || all_uses_emitted { - emit_inst( - block.virt_regs[inst_id.idx_1()].0, - &block.insts[inst_idx], - block.kind.try_fork_join_id(), - &block_names, - &svalue_types, - w, - )?; - emitted.set(inst_id.idx_1(), true); - } else { - worklist.push_back(inst_idx); + // If this block is in a vectorized reduce section, set up either a + // post-parallel reduction loop or a vector reduction, depending on + // whether there's an associative schedule on each reduction + // variable. + if let Some(fork_join_id) = block.kind.try_reduce() + && let Some(width) = self.parallel_reduce_infos[&fork_join_id].vector_width + { + self.setup_vectorized_reduce_block(fork_join_id, width, w)?; } - } - } - write!(w, "}}\n\n",)?; - Ok(()) -} + // For each basic block, emit instructions in that block. Emit using + // a worklist over the dependency graph. + let mut emitted = bitvec![u8, Lsb0; 0; block.insts.len()]; + let mut worklist = VecDeque::from((0..block.insts.len()).collect::<Vec<_>>()); + while let Some(inst_idx) = worklist.pop_front() { + let inst_id = InstID::new(block_idx, inst_idx); + let dependencies = &self.dep_graph[&inst_id]; + let all_uses_emitted = dependencies + .into_iter() + // Check that all used instructions in this block... + .filter(|inst_id| inst_id.idx_0() == block_idx) + // were already emitted. + .all(|inst_id| emitted[inst_id.idx_1()]); + // Phis don't need to wait for all of their uses to be emitted. + if block.insts[inst_idx].is_phi() || all_uses_emitted { + self.emit_inst( + block.virt_regs[inst_id.idx_1()].0, + &block.insts[inst_idx], + block.kind.try_fork_join_id(), + w, + )?; + emitted.set(inst_id.idx_1(), true); + } else { + worklist.push_back(inst_idx); + } + } -fn emit_type<W: Write>(stype: &SType, w: &mut W) -> Result<(), Error> { - match stype { - SType::Boolean => write!(w, "i1")?, - SType::Integer8 | SType::UnsignedInteger8 => write!(w, "i8")?, - SType::Integer16 | SType::UnsignedInteger16 => write!(w, "i16")?, - SType::Integer32 | SType::UnsignedInteger32 => write!(w, "i32")?, - SType::Integer64 | SType::UnsignedInteger64 => write!(w, "i64")?, - SType::Float32 => write!(w, "float")?, - SType::Float64 => write!(w, "double")?, - SType::Product(fields) => { - write!(w, "{{")?; - fields - .into_iter() - .map(Some) - .intersperse(None) - .map(|token| -> Result<(), Error> { - match token { - Some(field_ty) => emit_type(field_ty, w)?, - None => write!(w, ", ")?, - } - Ok(()) - }) - .collect::<Result<(), Error>>()?; - write!(w, "}}")?; + self.reset_cells(); } - SType::ArrayRef(_) => write!(w, "ptr")?, - } + write!(w, "}}\n",)?; - Ok(()) -} + Ok(()) + } -fn emit_constant<W: Write>(sconstant: &SConstant, w: &mut W) -> Result<(), Error> { - match sconstant { - SConstant::Boolean(val) => write!(w, "{}", val)?, - SConstant::Integer8(val) => write!(w, "{}", val)?, - SConstant::Integer16(val) => write!(w, "{}", val)?, - SConstant::Integer32(val) => write!(w, "{}", val)?, - SConstant::Integer64(val) => write!(w, "{}", val)?, - SConstant::UnsignedInteger8(val) => write!(w, "{}", val)?, - SConstant::UnsignedInteger16(val) => write!(w, "{}", val)?, - SConstant::UnsignedInteger32(val) => write!(w, "{}", val)?, - SConstant::UnsignedInteger64(val) => write!(w, "{}", val)?, - SConstant::Float32(val) => { - if val.fract() == 0.0 { - write!(w, "{}.0", val)? - } else { - write!(w, "{}", val)? 
- } + fn emit_type<W: Write>(&self, stype: &SType, w: &mut W) -> Result<(), Error> { + if let Some(width) = self.vector_width.get() { + write!(w, "<{} x ", width)?; } - SConstant::Float64(val) => { - if val.fract() == 0.0 { - write!(w, "{}.0", val)? - } else { - write!(w, "{}", val)? + + match stype { + SType::Boolean => write!(w, "i1")?, + SType::Integer8 | SType::UnsignedInteger8 => write!(w, "i8")?, + SType::Integer16 | SType::UnsignedInteger16 => write!(w, "i16")?, + SType::Integer32 | SType::UnsignedInteger32 => write!(w, "i32")?, + SType::Integer64 | SType::UnsignedInteger64 => write!(w, "i64")?, + SType::Float32 => write!(w, "float")?, + SType::Float64 => write!(w, "double")?, + SType::Product(fields) => { + write!(w, "{{")?; + fields + .into_iter() + .map(Some) + .intersperse(None) + .map(|token| -> Result<(), Error> { + match token { + Some(field_ty) => self.emit_type(field_ty, w)?, + None => write!(w, ", ")?, + } + Ok(()) + }) + .collect::<Result<(), Error>>()?; + write!(w, "}}")?; } + SType::ArrayRef(_) => write!(w, "ptr")?, } - SConstant::Product(fields) => { - write!(w, "{{")?; - fields - .into_iter() - .map(Some) - .intersperse(None) - .map(|token| -> Result<(), Error> { - match token { - Some(field_cons) => { - emit_type(&field_cons.get_type(), w)?; - write!(w, " ")?; - emit_constant(field_cons, w)?; - } - None => write!(w, ", ")?, - } - Ok(()) - }) - .collect::<Result<(), Error>>()?; - write!(w, "}}")?; + + if self.vector_width.get().is_some() { + write!(w, ">")?; } + + Ok(()) } - Ok(()) -} + fn emit_constant<W: Write>(&self, sconstant: &SConstant, w: &mut W) -> Result<(), Error> { + match sconstant { + SConstant::Boolean(val) => write!(w, "{}", val)?, + SConstant::Integer8(val) => write!(w, "{}", val)?, + SConstant::Integer16(val) => write!(w, "{}", val)?, + SConstant::Integer32(val) => write!(w, "{}", val)?, + SConstant::Integer64(val) => write!(w, "{}", val)?, + SConstant::UnsignedInteger8(val) => write!(w, "{}", val)?, + SConstant::UnsignedInteger16(val) => write!(w, "{}", val)?, + SConstant::UnsignedInteger32(val) => write!(w, "{}", val)?, + SConstant::UnsignedInteger64(val) => write!(w, "{}", val)?, + SConstant::Float32(val) => { + if val.fract() == 0.0 { + write!(w, "{}.0", val)? + } else { + write!(w, "{}", val)? + } + } + SConstant::Float64(val) => { + if val.fract() == 0.0 { + write!(w, "{}.0", val)? + } else { + write!(w, "{}", val)? + } + } + SConstant::Product(fields) => { + write!(w, "{{")?; + fields + .into_iter() + .map(Some) + .intersperse(None) + .map(|token| -> Result<(), Error> { + match token { + Some(field_cons) => { + self.emit_type(&field_cons.get_type(), w)?; + write!(w, " ")?; + self.emit_constant(field_cons, w)?; + } + None => write!(w, ", ")?, + } + Ok(()) + }) + .collect::<Result<(), Error>>()?; + write!(w, "}}")?; + } + } -fn emit_svalue<W: Write>( - svalue: &SValue, - add_ty: bool, - types: &HashMap<SValue, SType>, - w: &mut W, -) -> Result<(), Error> { - if add_ty { - emit_type(&types[svalue], w)?; - write!(w, " ")?; - } - match svalue { - SValue::VirtualRegister(virt_reg) => write!(w, "%v{}", virt_reg)?, - SValue::Constant(cons) => emit_constant(cons, w)?, + Ok(()) } - Ok(()) -} -fn emit_inst<W: Write>( - virt_reg: usize, - inst: &SInst, - location: Option<ForkJoinID>, - block_names: &HashMap<(BlockID, Option<ForkJoinID>), String>, - types: &HashMap<SValue, SType>, - w: &mut W, -) -> Result<(), Error> { - // Helper to emit the initial assignment to the destination virtual - // register, when applicable. 
- let self_svalue = SValue::VirtualRegister(virt_reg); - let emit_assign = |w: &mut W| -> Result<(), Error> { write!(w, "%v{} = ", virt_reg) }; - - write!(w, " ")?; - match inst { - SInst::Phi { inputs } => { - emit_assign(w)?; - write!(w, "phi ")?; - emit_type(&types[&self_svalue], w)?; + fn emit_svalue<W: Write>(&self, svalue: &SValue, add_ty: bool, w: &mut W) -> Result<(), Error> { + if add_ty { + self.emit_type(&self.svalue_types[svalue], w)?; write!(w, " ")?; - inputs - .into_iter() - .map(Some) - .intersperse(None) - .map(|token| match token { - Some((pred_block_id, svalue)) => { - write!(w, "[ ")?; - emit_svalue(svalue, false, types, w)?; - write!(w, ", %{} ]", &block_names[&(*pred_block_id, location)])?; - Ok(()) - } - None => write!(w, ", "), - }) - .collect::<Result<(), Error>>()?; } - SInst::ThreadID { - dimension, - fork_join, - } => { - emit_assign(w)?; - write!(w, "add i64 0, %thread_id_{}_{}", fork_join.idx(), dimension)?; - } - SInst::ReductionVariable { number } => { - write!(w, "; Already emitted reduction variable #{number}.")?; + if self.vector_width.get().is_some() + && svalue + .try_virt_reg() + .map(|virt_reg| self.outside_def_used_in_vector.borrow().contains(&virt_reg)) + .unwrap_or(false) + { + match svalue { + SValue::VirtualRegister(virt_reg) => { + write!(w, "%vec_{}_v{}", self.block.get().0, virt_reg)? + } + SValue::Constant(_) => todo!(), + } + } else if svalue + .try_virt_reg() + .map(|virt_reg| self.vectors_from_parallel.borrow().contains(&virt_reg)) + .unwrap_or(false) + { + match svalue { + SValue::VirtualRegister(virt_reg) => write!(w, "%extract_v{}", virt_reg)?, + SValue::Constant(_) => todo!(), + } + } else { + match svalue { + SValue::VirtualRegister(virt_reg) => write!(w, "%v{}", virt_reg)?, + SValue::Constant(cons) => self.emit_constant(cons, w)?, + } } - SInst::Jump { - target, - parallel_entry: _, - reduce_exit, - } => { - if reduce_exit.is_some() { + Ok(()) + } + + fn emit_inst<W: Write>( + &self, + virt_reg: usize, + inst: &SInst, + location: Option<ForkJoinID>, + w: &mut W, + ) -> Result<(), Error> { + // Helper to emit the initial assignment to the destination virtual + // register, when applicable. 
+ let self_svalue = SValue::VirtualRegister(virt_reg); + let emit_assign = |w: &mut W| -> Result<(), Error> { write!(w, "%v{} = ", virt_reg) }; + + write!(w, " ")?; + match inst { + SInst::Phi { inputs } => { + emit_assign(w)?; + write!(w, "phi ")?; + self.emit_type(&self.svalue_types[&self_svalue], w)?; + write!(w, " ")?; + inputs + .into_iter() + .map(Some) + .intersperse(None) + .map(|token| match token { + Some((pred_block_id, svalue)) => { + write!(w, "[ ")?; + self.emit_svalue(svalue, false, w)?; + write!(w, ", %{} ]", &self.block_names[&(*pred_block_id, location)])?; + Ok(()) + } + None => write!(w, ", "), + }) + .collect::<Result<(), Error>>()?; + } + SInst::ThreadID { + dimension, + fork_join, + } => { + emit_assign(w)?; + if let Some(width) = self.vector_width.get() { + write!(w, "add <{} x i64> <", width)?; + for idx in 0..width { + if idx != 0 { + write!(w, ", ")?; + } + write!(w, "i64 {}", idx)?; + } + write!(w, ">, zeroinitializer")?; + } else { + write!(w, "add i64 0, %thread_id_{}_{}", fork_join.idx(), dimension)?; + } + } + SInst::ReductionVariable { number } => { + write!(w, "; Already emitted reduction variable #{number}.")?; + } + SInst::Jump { + target, + parallel_entry: _, + reduce_exit, + } => { + if reduce_exit.is_some() && self.vector_reduce_cycle.get() { + // If we're closing a non-vectorized reduction for a + // vectorized parallel, jump back to the beginning of the + // reduction, not the beginning of the parallel section. + let self_block_idx = self.block.get().0; + write!( + w, + "br label %{}", + &self.block_names[&(BlockID::new(self_block_idx), location)], + )?; + } else if reduce_exit.is_some() && self.vectors_from_parallel.borrow().is_empty() { + // If we're closing a reduction and the parallel-reduce is + // not vectorized, we need to branch back to the beginning + // of the parallel-reduce. + write!( + w, + "br label %fork_join_seq_header_{}", + location.unwrap().idx(), + )?; + } else { + // If this is a normal jump (or is closing a reduction and + // is vectorized, along with the parallel section), then + // branch to the successor as expected. + write!(w, "br label %{}", &self.block_names[&(*target, location)])?; + } + } + SInst::Branch { + cond, + false_target, + true_target, + } => { + // Branches aren't involved in any parallel-reduce shenanigans, + // so lowering them is straightforward. 
+ write!(w, "br ")?; + self.emit_svalue(cond, true, w)?; write!( w, - "br label %fork_join_seq_header_{}", - location.unwrap().idx(), + ", label %{}, label %{}", + &self.block_names[&(*true_target, location)], + &self.block_names[&(*false_target, location)], )?; - } else { - write!(w, "br label %{}", &block_names[&(*target, location)])?; } - } - SInst::Branch { - cond, - false_target, - true_target, - } => { - write!(w, "br ")?; - emit_svalue(cond, true, types, w)?; - write!( - w, - ", label %{}, label %{}", - &block_names[&(*true_target, location)], - &block_names[&(*false_target, location)], - )?; - } - SInst::PartitionExit { data_outputs } => { - if data_outputs.len() == 0 { - write!(w, "ret {{}} zeroinitializer")?; - } else if data_outputs.len() == 1 { - write!(w, "ret ")?; - emit_svalue(&data_outputs[0], true, types, w)?; - } else { - let ret_ty = SType::Product( - data_outputs - .iter() - .map(|svalue| types[svalue].clone()) - .collect(), - ); - write!(w, "%v{}_0 = insertvalue ", virt_reg)?; - emit_type(&ret_ty, w)?; - write!(w, " undef, ")?; - emit_svalue(&data_outputs[0], true, types, w)?; - write!(w, ", 0\n")?; - for idx in 1..data_outputs.len() { - write!(w, " %v{}_{} = insertvalue ", virt_reg, idx)?; - emit_type(&ret_ty, w)?; - write!(w, " %v{}_{}, ", virt_reg, idx - 1)?; - emit_svalue(&data_outputs[idx], true, types, w)?; - write!(w, ", {}\n", idx)?; + SInst::PartitionExit { data_outputs } => { + if data_outputs.len() == 0 { + write!(w, "ret {{}} zeroinitializer")?; + } else if data_outputs.len() == 1 { + write!(w, "ret ")?; + self.emit_svalue(&data_outputs[0], true, w)?; + } else { + let ret_ty = SType::Product( + data_outputs + .iter() + .map(|svalue| self.svalue_types[svalue].clone()) + .collect(), + ); + write!(w, "%v{}_0 = insertvalue ", virt_reg)?; + self.emit_type(&ret_ty, w)?; + write!(w, " undef, ")?; + self.emit_svalue(&data_outputs[0], true, w)?; + write!(w, ", 0\n")?; + for idx in 1..data_outputs.len() { + write!(w, " %v{}_{} = insertvalue ", virt_reg, idx)?; + self.emit_type(&ret_ty, w)?; + write!(w, " %v{}_{}, ", virt_reg, idx - 1)?; + self.emit_svalue(&data_outputs[idx], true, w)?; + write!(w, ", {}\n", idx)?; + } + write!(w, " ret ")?; + self.emit_type(&ret_ty, w)?; + write!(w, " %v{}_{}", virt_reg, data_outputs.len() - 1)?; } - write!(w, " ret ")?; - emit_type(&ret_ty, w)?; - write!(w, " %v{}_{}", virt_reg, data_outputs.len() - 1)?; } - } - SInst::Return { value } => { - write!(w, "ret ")?; - emit_svalue(value, true, types, w)?; - } - SInst::Unary { input, op } => { - emit_assign(w)?; - match op { - SUnaryOperator::Not => { - write!(w, "xor ")?; - emit_svalue(input, true, types, w)?; - write!(w, ", -1")?; - } - SUnaryOperator::Neg => { - if types[input].is_float() { - write!(w, "fneg ")?; - emit_svalue(input, true, types, w)?; - } else { - write!(w, "mul ")?; - emit_svalue(input, true, types, w)?; + SInst::Return { value } => { + write!(w, "ret ")?; + self.emit_svalue(value, true, w)?; + } + SInst::Unary { input, op } => { + emit_assign(w)?; + match op { + SUnaryOperator::Not => { + write!(w, "xor ")?; + self.emit_svalue(input, true, w)?; write!(w, ", -1")?; } + SUnaryOperator::Neg => { + if self.svalue_types[input].is_float() { + write!(w, "fneg ")?; + self.emit_svalue(input, true, w)?; + } else { + write!(w, "mul ")?; + self.emit_svalue(input, true, w)?; + write!(w, ", -1")?; + } + } + SUnaryOperator::Cast(_) => todo!(), } - SUnaryOperator::Cast(_) => todo!(), } - } - SInst::Binary { left, right, op } => { - emit_assign(w)?; - let op = 
op.get_llvm_op(&types[left]); - write!(w, "{} ", op)?; - emit_svalue(left, true, types, w)?; - write!(w, ", ")?; - emit_svalue(right, false, types, w)?; - } - SInst::Ternary { - first, - second, - third, - op, - } => { - emit_assign(w)?; - match op { - STernaryOperator::Select => { - write!(w, "select ")?; - emit_svalue(first, true, types, w)?; + SInst::Binary { left, right, op } => { + // If we're in a vectorized reduce block and this binary + // operation is reducing over an associative reduction variable, + // then we need to emit a LLVM vector reduce intrinsic. + // Otherwise lower into a normal LLVM binary op. + let try_associative_reduction = |sval: &SValue| { + sval.try_virt_reg() + .map(|virt_reg| { + self.vector_reduce_associative_vars + .borrow() + .iter() + .filter(|(red_virt_reg, _)| *red_virt_reg == virt_reg) + .map(|(red_virt_reg, red_num)| (*red_virt_reg, *red_num)) + .next() + }) + .flatten() + }; + if let Some((red_virt_reg, red_num)) = + try_associative_reduction(left).or(try_associative_reduction(right)) + { + let left_virt_reg = left + .try_virt_reg() + .expect("PANIC: Associative reduction can't involve constants."); + let right_virt_reg = right + .try_virt_reg() + .expect("PANIC: Associative reduction can't involve constants."); + let vector_virt_reg = if left_virt_reg != red_virt_reg { + left_virt_reg + } else if right_virt_reg != red_virt_reg { + right_virt_reg + } else { + panic!("PANIC: Associative reduction can't use the reduction variable more than once."); + }; + let info = &self.parallel_reduce_infos[&location.unwrap()]; + write!(w, "%v{} = call reassoc ", red_virt_reg)?; + self.emit_type(&self.svalue_types[&self_svalue], w)?; + let op = op.get_llvm_op(&self.svalue_types[left]); + write!(w, " @llvm.vector.reduce.{}", op)?; + let width = info.vector_width.unwrap(); + self.emit_reduce_suffix(width, &self.svalue_types[&self_svalue], w)?; + write!(w, "(")?; + self.emit_svalue(&info.reduce_inits[red_num], true, w)?; write!(w, ", ")?; - emit_svalue(second, true, types, w)?; + self.vector_width.set(Some(width)); + let old_vectors_from_parallel = self.vectors_from_parallel.take(); + self.emit_svalue(&SValue::VirtualRegister(vector_virt_reg), true, w)?; + self.vector_width.set(None); + self.vectors_from_parallel + .replace(old_vectors_from_parallel); + write!(w, ")")?; + } else { + emit_assign(w)?; + let op = op.get_llvm_op(&self.svalue_types[left]); + write!(w, "{} ", op)?; + self.emit_svalue(left, true, w)?; write!(w, ", ")?; - emit_svalue(third, true, types, w)?; + self.emit_svalue(right, false, w)?; } } + SInst::Ternary { + first, + second, + third, + op, + } => { + emit_assign(w)?; + match op { + STernaryOperator::Select => { + write!(w, "select ")?; + self.emit_svalue(first, true, w)?; + write!(w, ", ")?; + self.emit_svalue(second, true, w)?; + write!(w, ", ")?; + self.emit_svalue(third, true, w)?; + } + } + } + SInst::ArrayLoad { + array, + position, + bounds, + } => { + self.emit_linear_index_calc(virt_reg, position, bounds, w)?; + write!(w, "%load_ptr_{} = getelementptr ", virt_reg)?; + let old_width = self.vector_width.take(); + self.emit_type(&self.svalue_types[&self_svalue], w)?; + self.vector_width.set(old_width); + write!(w, ", ")?; + self.emit_svalue(array, true, w)?; + write!(w, ", ")?; + self.emit_type(&self.svalue_types[&position[0]], w)?; + write!(w, " %calc_linear_idx_{}\n ", virt_reg)?; + emit_assign(w)?; + if let Some(width) = self.vector_width.get() { + write!(w, "call ")?; + self.emit_type(&self.svalue_types[&self_svalue], w)?; + write!(w, " 
@llvm.masked.gather")?; + self.emit_gather_scatter_suffix(width, &self.svalue_types[&self_svalue], w)?; + write!(w, "(")?; + self.emit_type(&self.svalue_types[array], w)?; + write!(w, " %load_ptr_{}, i32 8, <{} x i1> <", virt_reg, width)?; + for idx in 0..width { + if idx != 0 { + write!(w, ", ")?; + } + write!(w, "i1 true")?; + } + write!(w, ">, ")?; + self.emit_type(&self.svalue_types[&self_svalue], w)?; + write!(w, " undef)")?; + } else { + write!(w, "load ")?; + self.emit_type(&self.svalue_types[&self_svalue], w)?; + write!(w, ", ptr %load_ptr_{}", virt_reg)?; + } + } + SInst::ArrayStore { + array, + value, + position, + bounds, + } => { + self.emit_linear_index_calc(virt_reg, position, bounds, w)?; + write!(w, "%store_ptr_{} = getelementptr ", virt_reg)?; + let old_width = self.vector_width.take(); + self.emit_type(&self.svalue_types[value], w)?; + self.vector_width.set(old_width); + write!(w, ", ")?; + self.emit_svalue(array, true, w)?; + write!(w, ", ")?; + self.emit_type(&self.svalue_types[&position[0]], w)?; + write!(w, " %calc_linear_idx_{}\n ", virt_reg)?; + if let Some(width) = self.vector_width.get() { + write!(w, "call ")?; + self.emit_type(&self.svalue_types[&self_svalue], w)?; + write!(w, " @llvm.masked.scatter")?; + self.emit_gather_scatter_suffix(width, &self.svalue_types[&self_svalue], w)?; + write!(w, "(")?; + self.emit_svalue(array, true, w)?; + write!(w, ", ")?; + self.emit_type(&self.svalue_types[array], w)?; + write!(w, " %store_ptr_{}, i32 8, <{} x i1> <", virt_reg, width)?; + for idx in 0..width { + if idx != 0 { + write!(w, ", ")?; + } + write!(w, "i1 true")?; + } + write!(w, ">)")?; + } else { + write!(w, "store ")?; + self.emit_svalue(value, true, w)?; + write!(w, ", ptr %store_ptr_{}", virt_reg)?; + } + } + _ => {} } - SInst::ArrayLoad { - array, - position, - bounds, - } => { - emit_linear_index_calc(virt_reg, position, bounds, types, w)?; - write!(w, "%load_ptr_{} = getelementptr ", virt_reg)?; - emit_type(&types[&self_svalue], w)?; - write!(w, ", ")?; - emit_svalue(array, true, types, w)?; - write!(w, ", i64 %calc_linear_idx_{}\n ", virt_reg)?; - emit_assign(w)?; - write!(w, "load ")?; - emit_type(&types[&self_svalue], w)?; - write!(w, ", ptr %load_ptr_{}", virt_reg)?; - } - SInst::ArrayStore { - array, - value, - position, - bounds, - } => { - emit_linear_index_calc(virt_reg, position, bounds, types, w)?; - write!(w, "%store_ptr_{} = getelementptr ", virt_reg)?; - emit_type(&types[value], w)?; + write!(w, "\n")?; + + Ok(()) + } + + /* + * Implement the index math to convert a multi-dimensional position to a + * linear position inside an array. 
+ */ + fn emit_linear_index_calc<W: Write>( + &self, + virt_reg: usize, + position: &[SValue], + bounds: &[SValue], + w: &mut W, + ) -> Result<(), Error> { + assert_eq!(position.len(), bounds.len()); + + if position.len() == 1 { + write!(w, "%calc_linear_idx_{} = add ", virt_reg)?; + self.emit_svalue(&position[0], true, w)?; + write!(w, ", zeroinitializer\n ")?; + } else if position.len() == 2 { + write!(w, "%calc_linear_idx_{}_0 = mul ", virt_reg)?; + self.emit_svalue(&position[0], true, w)?; write!(w, ", ")?; - emit_svalue(array, true, types, w)?; - write!(w, ", i64 %calc_linear_idx_{}\n ", virt_reg)?; - write!(w, "store ")?; - emit_svalue(value, true, types, w)?; - write!(w, ", ptr %store_ptr_{}", virt_reg)?; + self.emit_svalue(&bounds[1], false, w)?; + write!(w, "\n %calc_linear_idx_{} = add ", virt_reg)?; + self.emit_svalue(&position[1], true, w)?; + write!(w, ", %calc_linear_idx_{}_0", virt_reg)?; + write!(w, "\n ")?; + } else { + todo!("TODO: Handle the 3 or more dimensional array case.") } - _ => {} + + Ok(()) } - write!(w, "\n")?; - Ok(()) -} + /* + * LLVM intrinsics are a pain to emit textually... + */ + fn intrinsic_type_str(elem_ty: &SType) -> &'static str { + // We can't just use our previous routines for emitting types, because + // only inside intrinsics does LLVM use "f32" and "f64" properly! + match elem_ty { + SType::Boolean => "i1", + SType::Integer8 | SType::UnsignedInteger8 => "i8", + SType::Integer16 | SType::UnsignedInteger16 => "i16", + SType::Integer32 | SType::UnsignedInteger32 => "i32", + SType::Integer64 | SType::UnsignedInteger64 => "i64", + SType::Float32 => "f32", + SType::Float64 => "f64", + _ => panic!(), + } + } -/* - * Emit the loop header implementing a sequential fork-join. - */ -fn emit_fork_join_seq_header<W: Write>( - fork_join_id: ForkJoinID, - info: &ParallelReduceInfo, - types: &HashMap<SValue, SType>, - block_names: &HashMap<(BlockID, Option<ForkJoinID>), String>, - block_idx: usize, - w: &mut W, -) -> Result<(), Error> { - // Start the header of the loop. - write!(w, "fork_join_seq_header_{}:\n", fork_join_id.idx())?; - - // Emit the phis for the linear loop index variable and the reduction - // variables. - let entry_name = &block_names[&(info.predecessor, Some(fork_join_id))]; - let loop_name = &block_names[&(info.reduce_block, Some(fork_join_id))]; - write!( - w, - " %linear_{} = phi i64 [ 0, %{} ], [ %linear_{}_inc, %{} ]\n", - block_idx, entry_name, block_idx, loop_name, - )?; - for (var_num, virt_reg) in info.reduction_variables.iter() { - write!(w, " %v{} = phi ", virt_reg)?; - emit_type(&types[&SValue::VirtualRegister(*virt_reg)], w)?; - write!(w, " [ ")?; - emit_svalue(&info.reduce_inits[*var_num], false, types, w)?; - write!(w, ", %{} ], [ ", entry_name)?; - emit_svalue(&info.reduce_reducts[*var_num], false, types, w)?; - write!(w, ", %{} ]\n", loop_name)?; + fn emit_reduce_suffix<W: Write>( + &self, + width: usize, + elem_ty: &SType, + w: &mut W, + ) -> Result<(), Error> { + write!(w, ".v{}{}", width, Self::intrinsic_type_str(elem_ty))?; + Ok(()) } - // Calculate the loop bounds. 
- if info.thread_counts.len() == 1 { - write!(w, " %bound_{} = add i64 0, ", block_idx)?; - emit_svalue(&info.thread_counts[0], false, types, w)?; - write!(w, "\n")?; - } else if info.thread_counts.len() == 2 { - write!(w, " %bound_{} = mul ", block_idx)?; - emit_svalue(&info.thread_counts[0], true, types, w)?; - write!(w, ", ")?; - emit_svalue(&info.thread_counts[1], false, types, w)?; - write!(w, "\n")?; - } else { - todo!("TODO: Handle the 3 or more dimensional fork-join case.") + fn emit_gather_scatter_suffix<W: Write>( + &self, + width: usize, + elem_ty: &SType, + w: &mut W, + ) -> Result<(), Error> { + write!( + w, + ".v{}{}.v{}p0", + width, + Self::intrinsic_type_str(elem_ty), + width + )?; + Ok(()) } - // Calculate the multi-dimensional thread indices. - if info.thread_counts.len() == 1 { + /* + * Emit the loop header implementing a sequential fork-join. For historical + * reasons, "sequential" fork-joins are just fork-joins that are lowered to + * LLVM level loops. This includes fork-joins that end up getting + * parallelized across threads via low/high bounds. + */ + fn emit_fork_join_seq_header<W: Write>( + &self, + fork_join_id: ForkJoinID, + block_idx: usize, + w: &mut W, + ) -> Result<(), Error> { + let info = &self.parallel_reduce_infos[&fork_join_id]; + let entry_name = &self.block_names[&(info.predecessor, Some(fork_join_id))]; + let loop_name = &self.block_names[&(info.reduce_block, Some(fork_join_id))]; + let parallel_launch = self.manifest.device.num_parallel_launch_dims() > 0 && info.top_level; + + // Start the header of the loop. + write!(w, "fork_join_seq_header_{}:\n", fork_join_id.idx())?; + + // Emit the phis for the linear loop index variable and the reduction + // variables. + write!( + w, + " %linear_{} = phi i64 [ 0, %{} ], [ %linear_{}_inc, %{} ]\n", + block_idx, entry_name, block_idx, loop_name, + )?; + for (var_num, virt_reg) in info.reduction_variables.iter() { + write!(w, " %v{} = phi ", virt_reg)?; + self.emit_type(&self.svalue_types[&SValue::VirtualRegister(*virt_reg)], w)?; + write!(w, " [ ")?; + self.emit_svalue(&info.reduce_inits[*var_num], false, w)?; + write!(w, ", %{} ], [ ", entry_name)?; + self.emit_svalue(&info.reduce_reducts[*var_num], false, w)?; + write!(w, ", %{} ]\n", loop_name)?; + } + + // Calculate the loop bounds. + if info.thread_counts.len() == 1 { + write!(w, " %bound_{} = add i64 0, ", block_idx)?; + if parallel_launch { + write!(w, "%parallel_launch_0_len")?; + } else { + self.emit_svalue(&info.thread_counts[0], false, w)?; + } + write!(w, "\n")?; + } else if info.thread_counts.len() == 2 { + write!(w, " %bound_{} = mul ", block_idx)?; + if parallel_launch { + write!(w, "i64 %parallel_launch_0_len, %parallel_launch_1_len")?; + } else { + self.emit_svalue(&info.thread_counts[0], true, w)?; + write!(w, ", ")?; + self.emit_svalue(&info.thread_counts[1], false, w)?; + } + write!(w, "\n")?; + } else { + todo!("TODO: Handle the 3 or more dimensional fork-join case.") + } + + // Calculate the multi-dimensional thread indices. 
+ if info.thread_counts.len() == 1 && parallel_launch { + write!( + w, + " %thread_id_{}_0 = add i64 %parallel_launch_0_low, %linear_{}\n", + fork_join_id.idx(), + block_idx + )?; + } else if info.thread_counts.len() == 1 { + write!( + w, + " %thread_id_{}_0 = add i64 0, %linear_{}\n", + fork_join_id.idx(), + block_idx + )?; + } else if info.thread_counts.len() == 2 && parallel_launch { + write!( + w, + " %unshifted_id_{}_0 = udiv i64 %linear_{}, %parallel_launch_1_len\n", + fork_join_id.idx(), + block_idx + )?; + write!( + w, + " %unshifted_id_{}_1 = urem i64 %linear_{}, %parallel_launch_1_len\n", + fork_join_id.idx(), + block_idx + )?; + write!( + w, + " %thread_id_{}_0 = add i64 %unshifted_id_{}_0, %parallel_launch_0_low\n", + fork_join_id.idx(), + fork_join_id.idx(), + )?; + write!( + w, + " %thread_id_{}_1 = add i64 %unshifted_id_{}_1, %parallel_launch_1_low\n", + fork_join_id.idx(), + fork_join_id.idx(), + )?; + } else if info.thread_counts.len() == 2 { + write!( + w, + " %thread_id_{}_0 = udiv i64 %linear_{}, ", + fork_join_id.idx(), + block_idx + )?; + self.emit_svalue(&info.thread_counts[1], false, w)?; + write!(w, "\n")?; + write!( + w, + " %thread_id_{}_1 = urem i64 %linear_{}, ", + fork_join_id.idx(), + block_idx + )?; + self.emit_svalue(&info.thread_counts[1], false, w)?; + write!(w, "\n")?; + } else { + todo!("TODO: Handle the 3 or more dimensional fork-join case.") + } + + // Increment the linear index. write!( w, - " %thread_id_{}_0 = add i64 0, %linear_{}\n", - fork_join_id.idx(), - block_idx + " %linear_{}_inc = add i64 %linear_{}, 1\n", + block_idx, block_idx )?; - } else if info.thread_counts.len() == 2 { + + // Emit the branch. write!( w, - " %thread_id_{}_0 = udiv i64 %linear_{}, ", - fork_join_id.idx(), - block_idx + " %cond_{} = icmp ult i64 %linear_{}, %bound_{}\n", + block_idx, block_idx, block_idx )?; - emit_svalue(&info.thread_counts[1], false, types, w)?; - write!(w, "\n")?; + let top_name = &self.block_names[&(BlockID::new(block_idx), Some(fork_join_id))]; + let succ_name = &self.block_names[&(info.successor, Some(fork_join_id))]; write!( w, - " %thread_id_{}_1 = urem i64 %linear_{}, ", - fork_join_id.idx(), - block_idx + " br i1 %cond_{}, label %{}, label %{}\n", + block_idx, top_name, succ_name )?; - emit_svalue(&info.thread_counts[1], false, types, w)?; - write!(w, "\n")?; - } else { - todo!("TODO: Handle the 3 or more dimensional fork-join case.") + + Ok(()) } - // Increment the linear index. - write!( - w, - " %linear_{}_inc = add i64 %linear_{}, 1\n", - block_idx, block_idx - )?; - - // Emit the branch. - write!( - w, - " %cond_{} = icmp ult i64 %linear_{}, %bound_{}\n", - block_idx, block_idx, block_idx - )?; - let top_name = &block_names[&(BlockID::new(block_idx), Some(fork_join_id))]; - let succ_name = &block_names[&(info.successor, Some(fork_join_id))]; - write!( - w, - " br i1 %cond_{}, label %{}, label %{}\n", - block_idx, top_name, succ_name - )?; + /* + * Calculate and emit block-level info for vectorized parallel blocks. + */ + fn setup_vectorized_parallel_block<W: Write>( + &self, + width: usize, + w: &mut W, + ) -> Result<(), Error> { + let (block_idx, block) = self.block.get(); - Ok(()) -} + // Get the uses of virtual registers defined outside the + // vectorized region. 
+ let mut outside_def_used_in_vector = HashSet::new(); + for inst in block.insts.iter() { + for virt_reg in sched_get_uses(inst).filter_map(|svalue| svalue.try_virt_reg()) { + let outside = match self.virt_reg_to_inst_id.get(&virt_reg) { + Some(use_inst_id) => { + block.kind != self.function.blocks[use_inst_id.idx_0()].kind + } + // Parameters are always defined outside the vectorized + // region as scalars. + None => true, + }; + if outside { + outside_def_used_in_vector.insert(virt_reg); + } + } + } -/* - * Implement the index math to convert a multi-dimensional position to a linear - * position inside an array. - */ -fn emit_linear_index_calc<W: Write>( - virt_reg: usize, - position: &[SValue], - bounds: &[SValue], - types: &HashMap<SValue, SType>, - w: &mut W, -) -> Result<(), Error> { - assert_eq!(position.len(), bounds.len()); - - if position.len() == 1 { - write!(w, "%calc_linear_idx_{} = add i64 0, ", virt_reg)?; - emit_svalue(&position[0], false, types, w)?; - write!(w, "\n ")?; - } else if position.len() == 2 { - write!(w, "%calc_linear_idx_{}_0 = mul ", virt_reg)?; - emit_svalue(&position[0], true, types, w)?; - write!(w, ", ")?; - emit_svalue(&bounds[1], false, types, w)?; - write!( - w, - "\n %calc_linear_idx_{} = add i64 %calc_linear_idx_{}_0, ", - virt_reg, virt_reg - )?; - emit_svalue(&position[1], false, types, w)?; - write!(w, "\n ")?; - } else { - todo!("TODO: Handle the 3 or more dimensional array case.") + // Broadcast scalar values into vector values. The vector + // register produced needs to be indexed in name by the block + // index. This is because we may end up using the same value in + // multiple vectorized blocks, and we can't have those + // vectorized scalars have the same name. + for outside_virt_reg in outside_def_used_in_vector.iter() { + write!( + w, + " %vec1_{}_v{} = insertelement <1 x ", + block_idx, outside_virt_reg + )?; + let elem_ty = &self.svalue_types[&SValue::VirtualRegister(*outside_virt_reg)]; + self.emit_type(elem_ty, w)?; + write!(w, "> undef, ")?; + self.emit_type(elem_ty, w)?; + write!(w, " %v{}, i32 0\n", outside_virt_reg)?; + write!( + w, + " %vec_{}_v{} = shufflevector <1 x ", + block_idx, outside_virt_reg + )?; + self.emit_type(elem_ty, w)?; + write!(w, "> %vec1_{}_v{}, <1 x ", block_idx, outside_virt_reg)?; + self.emit_type(elem_ty, w)?; + write!(w, "> undef, <{} x i32> zeroinitializer\n", width)?; + } + + // Set the cell values in the context. + self.vector_width.set(Some(width)); + self.outside_def_used_in_vector + .replace(outside_def_used_in_vector); + + Ok(()) } - Ok(()) + /* + * Calculate and emit block-level info for vectorized reduce blocks. + */ + fn setup_vectorized_reduce_block<W: Write>( + &self, + fork_join_id: ForkJoinID, + width: usize, + w: &mut W, + ) -> Result<(), Error> { + let (block_idx, block) = self.block.get(); + + // Get uses of vector values defined in the parallel region. + let mut vectors_from_parallel = HashSet::new(); + for inst in block.insts.iter() { + for virt_reg in sched_get_uses(inst).filter_map(|svalue| svalue.try_virt_reg()) { + if let Some(inst_id) = self.virt_reg_to_inst_id.get(&virt_reg) + && self.function.blocks[inst_id.idx_0()].kind + == SBlockKind::Parallel(fork_join_id) + { + vectors_from_parallel.insert(virt_reg); + } + } + } + + // Each reduction may be representable by an LLVM reduction intrinsic. + // If every reduction in this reduce block is, then we don't need to + // generate an explicit loop. 
If any one reduction isn't representable + // as a single intrinsic, then we need to generate an explicit loop. The + // explicit loop calculates the reduction for all reductions that can't + // be represented by intrinsics, while intrinsics are still used to + // calculate reductions that can be represented by them. Currently, the + // "associative" schedule captures this info per reduction variable. + let all_intrinsic_representable = block + .insts + .iter() + .enumerate() + .filter(|(_, inst)| inst.is_reduction_variable()) + .all(|(inst_idx, _)| block.schedules[&inst_idx].contains(&SSchedule::Associative)); + if !all_intrinsic_representable { + let info = &self.parallel_reduce_infos[&fork_join_id]; + let entry_name = &self.block_names[&(info.bottom_parallel_block, Some(fork_join_id))]; + let self_name = &self.block_names[&(info.reduce_block, Some(fork_join_id))]; + let succ_name = &self.block_names[&(info.successor, Some(fork_join_id))]; + + // Emit a loop header for the reduce. + write!( + w, + " %linear_{} = phi i64 [ 0, %{} ], [ %linear_{}_inc, %{}_reduce_body ]\n", + block_idx, entry_name, block_idx, self_name, + )?; + // Emit phis for reduction variables here, since they need to be + // above everything emitted below. + for (var_num, virt_reg) in info.reduction_variables.iter() { + // Only emit phis for reduction variables that aren't + // implemented in intrinsics. + if !block.schedules[&self.virt_reg_to_inst_id[virt_reg].idx_1()] + .contains(&SSchedule::Associative) + { + write!(w, " %v{} = phi ", virt_reg)?; + self.emit_type(&self.svalue_types[&SValue::VirtualRegister(*virt_reg)], w)?; + write!(w, " [ ")?; + self.emit_svalue(&info.reduce_inits[*var_num], false, w)?; + write!(w, ", %{} ], [ ", entry_name)?; + self.emit_svalue(&info.reduce_reducts[*var_num], false, w)?; + write!(w, ", %{}_reduce_body ]\n", self_name)?; + } + } + write!( + w, + " %linear_{}_inc = add i64 %linear_{}, 1\n", + block_idx, block_idx + )?; + // The loop bound is the constant vector width. + write!( + w, + " %cond_{} = icmp ult i64 %linear_{}, {}\n", + block_idx, block_idx, width + )?; + // Branch to the reduce loop body. + write!( + w, + " br i1 %cond_{}, label %{}_reduce_body, label %{}\n", + block_idx, self_name, succ_name + )?; + // The rest of the reduce block gets put into a "body" block. + write!(w, "{}_reduce_body:\n", self_name)?; + // Extract the needed element from the used parallel vectors. + self.vector_width.set(Some(width)); + for virt_reg in vectors_from_parallel.iter() { + write!(w, " %extract_v{} = extractelement ", virt_reg)?; + self.emit_svalue(&SValue::VirtualRegister(*virt_reg), true, w)?; + write!(w, ", i64 %linear_{}\n", block_idx)?; + } + self.vector_width.set(None); + + // Signal that the terminator needs to be a conditional branch to + // close the loop. + self.vector_reduce_cycle.set(true); + } + + let vector_reduce_associative_vars = block + .insts + .iter() + .enumerate() + .filter_map(|(inst_idx, inst)| { + inst.try_reduction_variable() + .map(|num| (inst_idx, block.virt_regs[inst_idx].0, num)) + }) + .filter(|(inst_idx, _, _)| block.schedules[&inst_idx].contains(&SSchedule::Associative)) + .map(|(_, virt_reg, num)| (virt_reg, num)) + .collect(); + + self.vectors_from_parallel.replace(vectors_from_parallel); + self.vector_reduce_associative_vars + .replace(vector_reduce_associative_vars); + + Ok(()) + } + + /* + * Reset the cells storing block specific context configuration. 
+ */ + pub fn reset_cells(&self) { + self.vector_width.take(); + self.outside_def_used_in_vector.take(); + self.vectors_from_parallel.take(); + self.vector_reduce_associative_vars.take(); + self.vector_reduce_cycle.take(); + } } impl SBinaryOperator { diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs index 3ecb506a0ba9eed16865e297f5d5cb814526bc3c..88171d33923fc773bd8ba8377e6eba80986612be 100644 --- a/hercules_cg/src/lib.rs +++ b/hercules_cg/src/lib.rs @@ -1,4 +1,4 @@ -#![feature(let_chains, iter_intersperse)] +#![feature(let_chains, iter_intersperse, map_try_insert)] pub mod cpu; pub mod manifest; diff --git a/hercules_cg/src/manifest.rs b/hercules_cg/src/manifest.rs index 1634b87c131b1b9e4dc1217bf89e8f9118418925..7f13c4b4809ab2cb5643e7df303a8d73ea012b82 100644 --- a/hercules_cg/src/manifest.rs +++ b/hercules_cg/src/manifest.rs @@ -43,6 +43,10 @@ pub struct PartitionManifest { pub returns: Vec<(SType, ReturnKind)>, // Record the list of possible successors from this partition. pub successors: Vec<PartitionID>, + // Device specific parts of the manifest. Represents details of calling + // partition functions not present in the schedule IR type information + // (since schedule IR is target independent). + pub device: DeviceManifest, } #[derive(Debug, Clone, Hash, Serialize, Deserialize, PartialEq, Eq)] @@ -70,6 +74,17 @@ pub enum ReturnKind { NextPartition, } +#[derive(Debug, Clone, Hash, Serialize, Deserialize)] +pub enum DeviceManifest { + CPU { + // If there's a top level fork-join that we parallel launch, specify, + // for each thread dimension, how many tiles we want to spawn, and + // the thread count. The thread count is a dynamic constant. + parallel_launch: Box<[(usize, DynamicConstantID)]>, + }, + GPU, +} + impl Manifest { pub fn all_visible_types(&self) -> impl Iterator<Item = SType> + '_ { self.param_types @@ -121,3 +136,21 @@ impl PartitionManifest { }) } } + +impl DeviceManifest { + pub fn default(device: Device) -> Self { + match device { + Device::CPU => DeviceManifest::CPU { + parallel_launch: Box::new([]), + }, + Device::GPU => todo!(), + } + } + + pub fn num_parallel_launch_dims(&self) -> usize { + match self { + DeviceManifest::CPU { parallel_launch } => parallel_launch.len(), + DeviceManifest::GPU => 0, + } + } +} diff --git a/hercules_cg/src/sched_dot.rs b/hercules_cg/src/sched_dot.rs index 9f70dcff07433f32cd1a294dc10b59ad6a9944bf..b997138d2f5c7506d2bd50322ac3e852631e7f58 100644 --- a/hercules_cg/src/sched_dot.rs +++ b/hercules_cg/src/sched_dot.rs @@ -42,10 +42,11 @@ pub fn write_dot<W: Write>(module: &SModule, w: &mut W) -> std::fmt::Result { for (function_name, function) in module.functions.iter() { // Schedule the SFunction to form a linear ordering of instructions. 
- let dep_graph = sched_dependence_graph(function); + let virt_reg_to_inst_id = sched_virt_reg_to_inst_id(function); + let dep_graph = sched_dependence_graph(function, &virt_reg_to_inst_id); let mut block_to_inst_list = (0..function.blocks.len()) .map(|block_idx| (block_idx, vec![])) - .collect::<HashMap<usize, Vec<(&SInst, usize)>>>(); + .collect::<HashMap<usize, Vec<(&SInst, usize, Option<&Vec<SSchedule>>)>>>(); for (block_idx, block) in function.blocks.iter().enumerate() { let mut emitted = bitvec![u8, Lsb0; 0; block.insts.len()]; let mut worklist = VecDeque::from((0..block.insts.len()).collect::<Vec<_>>()); @@ -60,10 +61,11 @@ pub fn write_dot<W: Write>(module: &SModule, w: &mut W) -> std::fmt::Result { .all(|inst_id| emitted[inst_id.idx_1()]); // Phis don't need to wait for all of their uses to be added. if block.insts[inst_idx].is_phi() || all_uses_emitted { - block_to_inst_list - .get_mut(&block_idx) - .unwrap() - .push((&block.insts[inst_idx], block.virt_regs[inst_idx].0)); + block_to_inst_list.get_mut(&block_idx).unwrap().push(( + &block.insts[inst_idx], + block.virt_regs[inst_idx].0, + block.schedules.get(&inst_idx), + )); emitted.set(inst_id.idx_1(), true); } else { worklist.push_back(inst_idx); @@ -74,15 +76,11 @@ pub fn write_dot<W: Write>(module: &SModule, w: &mut W) -> std::fmt::Result { // A SFunction is a subgraph. write_subgraph_header(function_name, w)?; - // Each SBlock is a nested subgraph. + // Each SBlock is a record node. for (block_idx, block) in function.blocks.iter().enumerate() { - write_block_header(function_name, block_idx, "lightblue", w)?; - // Emit the instructions in scheduled order. write_block(function_name, block_idx, &block_to_inst_list[&block_idx], w)?; - write_graph_footer(w)?; - // Add control edges. for succ in block.successors().as_ref() { write_control_edge(function_name, block_idx, succ.idx(), w)?; @@ -110,20 +108,6 @@ fn write_subgraph_header<W: Write>(function_name: &SFunctionName, w: &mut W) -> Ok(()) } -fn write_block_header<W: Write>( - function_name: &SFunctionName, - block_idx: usize, - color: &str, - w: &mut W, -) -> std::fmt::Result { - write!(w, "subgraph {}_block_{} {{\n", function_name, block_idx)?; - write!(w, "label=\"\"\n")?; - write!(w, "style=rounded\n")?; - write!(w, "bgcolor={}\n", color)?; - write!(w, "cluster=true\n")?; - Ok(()) -} - fn write_graph_footer<W: Write>(w: &mut W) -> std::fmt::Result { write!(w, "}}\n")?; Ok(()) @@ -132,17 +116,13 @@ fn write_graph_footer<W: Write>(w: &mut W) -> std::fmt::Result { fn write_block<W: Write>( function_name: &SFunctionName, block_idx: usize, - insts: &[(&SInst, usize)], + insts: &[(&SInst, usize, Option<&Vec<SSchedule>>)], w: &mut W, ) -> std::fmt::Result { - write!( - w, - "{}_{} [xlabel={}, label=\"{{", - function_name, block_idx, block_idx - )?; + write!(w, "{}_{} [label=\"{{", function_name, block_idx,)?; for token in insts.into_iter().map(|token| Some(token)).intersperse(None) { match token { - Some((inst, virt_reg)) => { + Some((inst, virt_reg, schedules)) => { write!(w, "%{} = {}(", virt_reg, inst.upper_case_name())?; for token in sched_get_uses(inst).map(|u| Some(u)).intersperse(None) { match token { @@ -154,11 +134,26 @@ fn write_block<W: Write>( } } write!(w, ")")?; + if let Some(schedules) = schedules + && !schedules.is_empty() + { + write!(w, " [")?; + for token in schedules.into_iter().map(|s| Some(s)).intersperse(None) { + match token { + Some(schedule) => write!(w, "{:?}", schedule)?, + None => write!(w, ", ")?, + } + } + write!(w, "]")?; + } } None => write!(w, " | 
")?, } } - write!(w, "}}\", shape = \"record\"];\n")?; + write!( + w, + "}}\", shape = \"Mrecord\", style = \"filled\", fillcolor = \"lightblue\"];\n" + )?; Ok(()) } diff --git a/hercules_cg/src/sched_gen.rs b/hercules_cg/src/sched_gen.rs index 586dd1486e98890a1f90483e56b2d03a3074a6b6..758eb786cc42eb423df06c2188c21e50205456a1 100644 --- a/hercules_cg/src/sched_gen.rs +++ b/hercules_cg/src/sched_gen.rs @@ -256,7 +256,7 @@ impl<'a> FunctionContext<'a> { * functions. */ fn compile_function(&self) -> (HashMap<SFunctionName, SFunction>, Manifest) { - let (manifest, array_node_to_array_id) = self.compute_manifest(); + let (mut manifest, array_node_to_array_id) = self.compute_manifest(); manifest .partitions @@ -268,10 +268,11 @@ impl<'a> FunctionContext<'a> { let partition_functions = (0..self.plan.num_partitions) .map(|partition_idx| { - ( - self.get_sfunction_name(partition_idx), - self.compile_partition(partition_idx, &manifest, &array_node_to_array_id), - ) + let name = self.get_sfunction_name(partition_idx); + let sfunction = + self.compile_partition(partition_idx, &manifest, &array_node_to_array_id); + self.update_manifest(&mut manifest.partitions[partition_idx], &sfunction); + (name, sfunction) }) .collect(); @@ -433,6 +434,7 @@ impl<'a> FunctionContext<'a> { parameters, returns, successors, + device: DeviceManifest::default(self.plan.partition_devices[partition_idx]), } }) .collect(); @@ -728,6 +730,7 @@ impl<'a> FunctionContext<'a> { let top_node = self.top_nodes[partition_idx]; let top_block = control_id_to_block_id[&top_node]; let parallel_entry = if self.function.nodes[top_node.idx()].is_fork() { + self.copy_schedules(top_node, &mut blocks[0]); Some(self.compile_parallel_entry( top_node, &data_id_to_svalue, @@ -748,6 +751,13 @@ impl<'a> FunctionContext<'a> { .push((self.make_virt_reg(partition_idx), SType::Boolean)); } + // Fifth, make sure every block's schedules map is "filled". + for block in blocks.iter_mut() { + for inst_idx in 0..block.insts.len() { + let _ = block.schedules.try_insert(inst_idx, vec![]); + } + } + SFunction { blocks, param_types: manifest.partitions[partition_idx] @@ -790,6 +800,7 @@ impl<'a> FunctionContext<'a> { // the jump if we're jumping into a parallel section or out of a // reduce section. Note that both of those may be true at once. let parallel_entry = if self.function.nodes[dst.idx()].is_fork() { + self.copy_schedules(dst, block); Some(self.compile_parallel_entry( dst, data_id_to_svalue, @@ -990,6 +1001,7 @@ impl<'a> FunctionContext<'a> { .filter(|user| self.function.nodes[user.idx()].is_reduce()) .position(|user| *user == id) .unwrap(); + self.copy_schedules(id, &mut block); block.insts.push(SInst::ReductionVariable { number }); block.virt_regs.push(( self_virt_reg(), @@ -1201,12 +1213,23 @@ impl<'a> FunctionContext<'a> { blocks[block_id.get().idx()] = block; } + /* + * Helper to copy over schedules. + */ + fn copy_schedules(&self, src: NodeID, block: &mut SBlock) { + block.schedules.insert( + block.insts.len(), + self.plan.schedules[src.idx()] + .iter() + .map(|schedule| sched_make_schedule(schedule)) + .collect(), + ); + } + /* * Compiles a reference to a dynamic constant into math to compute that * dynamic constant. We need a mutable reference to some basic block, since * we may need to generate math inline to compute the dynamic constant. - * TODO: actually implement dynamic constant math - only then will the above - * be true. 
*/ fn compile_dynamic_constant( &self, @@ -1315,6 +1338,52 @@ impl<'a> FunctionContext<'a> { fn get_sfunction_name(&self, partition_idx: usize) -> SFunctionName { format!("{}_{}", self.function.name, partition_idx) } + + /* + * There is some information we can only add to the manifest once we've + * computed the schedule IR. + */ + fn update_manifest(&self, manifest: &mut PartitionManifest, function: &SFunction) { + let parallel_reduce_infos = sched_parallel_reduce_sections(function); + + // Add parallel launch info for CPU partitions. This relies on checking + // schedules inside the generated schedule IR. + let partition_name = manifest.name.clone(); + if let Some(tiles) = function.blocks[0].schedules[&0] + .iter() + .filter_map(|schedule| schedule.try_parallel_launch()) + .next() + && parallel_reduce_infos + .into_iter() + .any(|(_, info)| info.top_level) + && let DeviceManifest::CPU { parallel_launch } = &mut manifest.device + { + let parallel_entry = function.blocks[0].insts[0].try_jump().unwrap().1.unwrap(); + assert_eq!(tiles.len(), parallel_entry.thread_counts.len()); + let top_level_fork_id = self + .fork_join_nest + .iter() + // Find control nodes in the fork join nesting whose only nest + // it itself (is a top level fork-join). + .filter(|(id, nest)| nest.len() == 1 && nest[0] == **id) + // Only consider forks in this partition. + .filter(|(id, _)| { + self.get_sfunction_name(self.plan.partitions[id.idx()].idx()) == partition_name + }) + .next() + .unwrap() + .0; + *parallel_launch = zip( + tiles.into_iter(), + self.function.nodes[top_level_fork_id.idx()] + .try_fork() + .unwrap() + .1, + ) + .map(|(num_chunks, count_dc_id)| (*num_chunks, *count_dc_id)) + .collect(); + } + } } fn convert_unary_op(op: UnaryOperator, simple_ir_types: &[SType]) -> SUnaryOperator { diff --git a/hercules_cg/src/sched_ir.rs b/hercules_cg/src/sched_ir.rs index 3db4a9e14a897e4f09462ad8ebf47f1c899b2867..02563834b7f7d29745b793b9c709ff8dac4dbb7f 100644 --- a/hercules_cg/src/sched_ir.rs +++ b/hercules_cg/src/sched_ir.rs @@ -82,6 +82,9 @@ pub struct SBlock { // registers produced by certain instructions, like Jump or ArrayStore, is // set to SType::Boolean, but it's not meaningful. pub virt_regs: Vec<(usize, SType)>, + // Map from instruction index in the block to a list of schedules attached + // to that instruction. + pub schedules: HashMap<usize, Vec<SSchedule>>, pub kind: SBlockKind, } @@ -129,6 +132,41 @@ impl SBlockKind { } } +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SSchedule { + ParallelLaunch(Box<[usize]>), + ParallelReduce, + Vectorizable(usize), + Associative, +} + +impl SSchedule { + pub fn try_parallel_launch(&self) -> Option<&[usize]> { + if let SSchedule::ParallelLaunch(tiles) = self { + Some(tiles) + } else { + None + } + } + + pub fn try_vectorizable(&self) -> Option<usize> { + if let SSchedule::Vectorizable(width) = self { + Some(*width) + } else { + None + } + } +} + +pub fn sched_make_schedule(schedule: &Schedule) -> SSchedule { + match schedule { + Schedule::ParallelFork(tiles) => SSchedule::ParallelLaunch(tiles.clone()), + Schedule::ParallelReduce => SSchedule::ParallelReduce, + Schedule::Vectorizable(width) => SSchedule::Vectorizable(*width), + Schedule::Associative => SSchedule::Associative, + } +} + /* * Unlike Hercules IR, we can represent a reference to an array (so that we * don't need to use an array value in this IR). 
This is fine, since we're not @@ -319,6 +357,14 @@ pub enum SInst { } impl SInst { + pub fn is_reduction_variable(&self) -> bool { + if let SInst::ReductionVariable { number: _ } = self { + true + } else { + false + } + } + pub fn is_phi(&self) -> bool { if let SInst::Phi { inputs: _ } = self { true diff --git a/hercules_cg/src/sched_schedule.rs b/hercules_cg/src/sched_schedule.rs index ee16f5bd0be335574cafefc64e919d64e4e83cc5..5300b990efeb2d345de1085bda746595570ecb97 100644 --- a/hercules_cg/src/sched_schedule.rs +++ b/hercules_cg/src/sched_schedule.rs @@ -79,12 +79,9 @@ pub fn sched_get_uses(inst: &SInst) -> Box<dyn Iterator<Item = &SValue> + '_> { } /* - * Build a dependency graph of instructions in an SFunction. + * Map virtual registers to corresponding instruction IDs. */ -pub fn sched_dependence_graph(function: &SFunction) -> HashMap<InstID, Vec<InstID>> { - let mut dep_graph = HashMap::new(); - - // First, map each virtual register to the instruction ID producing it. +pub fn sched_virt_reg_to_inst_id(function: &SFunction) -> HashMap<usize, InstID> { let mut virt_reg_to_inst_id = HashMap::new(); for block_idx in 0..function.blocks.len() { let block = &function.blocks[block_idx]; @@ -92,9 +89,22 @@ pub fn sched_dependence_graph(function: &SFunction) -> HashMap<InstID, Vec<InstI let virt_reg = block.virt_regs[inst_idx].0; let inst_id = InstID::new(block_idx, inst_idx); virt_reg_to_inst_id.insert(virt_reg, inst_id); - dep_graph.insert(inst_id, vec![]); } } + virt_reg_to_inst_id +} + +/* + * Build a dependency graph of instructions in an SFunction. + */ +pub fn sched_dependence_graph( + function: &SFunction, + virt_reg_to_inst_id: &HashMap<usize, InstID>, +) -> HashMap<InstID, Vec<InstID>> { + let mut dep_graph = HashMap::new(); + for inst_id in virt_reg_to_inst_id.values() { + dep_graph.insert(*inst_id, vec![]); + } // Process the dependencies in each block. This includes inter-block // dependencies for normal def-use edges. @@ -249,6 +259,19 @@ pub struct ParallelReduceInfo { // the parent's ForkJoinID. Parallel-reduce sections in an SFunction form a // forest. pub parent_fork_join_id: Option<ForkJoinID>, + + // Information about how this fork-join should be scheduled. Collecting this + // info here just makes writing the backends more convenient. + pub vector_width: Option<usize>, + // For each reduction variable, track if its associative or parallel + // individually. + pub associative_reduce: HashMap<usize, bool>, + pub parallel_reduce: HashMap<usize, bool>, + // Track if this is a "top-level" parallel-reduce. That is, the parallel- + // reduce is the "only thing" inside this partition function. Only these + // parallel-reduces can be parallelized on the CPU, even if this parallel- + // reduce has a parallel schedule on the entry jump. + pub top_level: bool, } /* @@ -262,7 +285,7 @@ pub fn sched_parallel_reduce_sections( for (block_idx, block) in function.blocks.iter().enumerate() { // Start by identifying a jump into a parallel section. - for inst in block.insts.iter() { + for (inst_idx, inst) in block.insts.iter().enumerate() { if let SInst::Jump { target, parallel_entry, @@ -275,6 +298,10 @@ pub fn sched_parallel_reduce_sections( thread_counts, reduce_inits, } = parallel_entry.clone(); + let vector_width = block.schedules[&inst_idx] + .iter() + .filter_map(|schedule| schedule.try_vectorizable()) + .next(); // The jump target is the top of the parallel section. Get the // fork-join ID from that block. 
@@ -332,13 +359,22 @@ pub fn sched_parallel_reduce_sections( } // Find the reduction variable instructions. + let mut associative_reduce = HashMap::new(); + let mut parallel_reduce = HashMap::new(); + let reduce_sblock = &function.blocks[reduce_block.idx()]; let reduction_variables = zip( - function.blocks[reduce_block.idx()].insts.iter(), - function.blocks[reduce_block.idx()].virt_regs.iter(), + reduce_sblock.insts.iter().enumerate(), + reduce_sblock.virt_regs.iter(), ) - .filter_map(|(inst, (virt_reg, _))| { - inst.try_reduction_variable() - .map(|number| (number, *virt_reg)) + .filter_map(|((inst_idx, inst), (virt_reg, _))| { + inst.try_reduction_variable().map(|number| { + let schedules = &reduce_sblock.schedules[&inst_idx]; + associative_reduce + .insert(number, schedules.contains(&SSchedule::Associative)); + parallel_reduce + .insert(number, schedules.contains(&SSchedule::ParallelReduce)); + (number, *virt_reg) + }) }) .collect(); @@ -359,12 +395,35 @@ pub fn sched_parallel_reduce_sections( reduction_variables, parent_fork_join_id: None, + vector_width, + associative_reduce, + parallel_reduce, + + top_level: false, }; result.insert(fork_join_id, info); } } } + // Figure out if any parallel-reduces are top level - that is, they are the + // "only thing" in the partition function. + for (_, parallel_reduce_info) in result.iter_mut() { + // A parallel-reduce is top-level if its predecessor is the entry block + // containing only a jump and its successor is an exit block containing + // just a function terminator. + let pred_block = &function.blocks[parallel_reduce_info.predecessor.idx()]; + let succ_block = &function.blocks[parallel_reduce_info.successor.idx()]; + if parallel_reduce_info.predecessor == BlockID::new(0) + && pred_block.insts.len() == 1 + && pred_block.insts[0].is_jump() + && succ_block.insts.len() == 1 + && (succ_block.insts[0].is_partition_exit() || succ_block.insts[0].is_return()) + { + parallel_reduce_info.top_level = true; + } + } + // Compute the parallel-reduce forest last, since this requires some info we // just computed above. let mut parents = HashMap::new(); diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index d369138c670ee39cb3525de6d980a86d1c88110b..9b946d7a41c19a5b4732c3e2a299d4ccbbdba24e 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -113,7 +113,7 @@ pub enum DynamicConstant { // The usize here is an index (which dynamic constant parameter of a // function is this). Parameter(usize), - + // Supported integer operations on dynamic constants. Add(DynamicConstantID, DynamicConstantID), Sub(DynamicConstantID, DynamicConstantID), Mul(DynamicConstantID, DynamicConstantID), @@ -417,7 +417,7 @@ impl Module { write!(w, ", ")?; self.write_dynamic_constant(*y, w)?; write!(w, ")") - }, + } }?; Ok(()) @@ -848,6 +848,31 @@ impl DynamicConstant { } } +pub fn evaluate_dynamic_constant( + cons: DynamicConstantID, + dcs: &Vec<DynamicConstant>, +) -> Option<usize> { + match dcs[cons.idx()] { + DynamicConstant::Constant(cons) => Some(cons), + DynamicConstant::Parameter(_) => None, + DynamicConstant::Add(left, right) => { + Some(evaluate_dynamic_constant(left, dcs)? + evaluate_dynamic_constant(right, dcs)?) + } + DynamicConstant::Sub(left, right) => { + Some(evaluate_dynamic_constant(left, dcs)? - evaluate_dynamic_constant(right, dcs)?) + } + DynamicConstant::Mul(left, right) => { + Some(evaluate_dynamic_constant(left, dcs)? * evaluate_dynamic_constant(right, dcs)?) 
+ } + DynamicConstant::Div(left, right) => { + Some(evaluate_dynamic_constant(left, dcs)? / evaluate_dynamic_constant(right, dcs)?) + } + DynamicConstant::Rem(left, right) => { + Some(evaluate_dynamic_constant(left, dcs)? % evaluate_dynamic_constant(right, dcs)?) + } + } +} + /* * Simple predicate functions on nodes take a lot of space, so use a macro. */ diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index cb888154b23235c084b2202ab325e38f03f01dd0..5a16aced6f5eab1f62e46de121c09e631cb2fc00 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -892,7 +892,7 @@ fn parse_dynamic_constant_id<'a>( fn parse_dynamic_constant<'a>( ir_text: &'a str, - context : &RefCell<Context<'a>>, + context: &RefCell<Context<'a>>, ) -> nom::IResult<&'a str, DynamicConstant> { let ir_text = nom::character::complete::multispace0(ir_text)?.0; let (ir_text, dc) = nom::branch::alt(( @@ -911,16 +911,20 @@ fn parse_dynamic_constant<'a>( // Dynamic constant math is written using a prefix function nom::combinator::map( nom::sequence::tuple(( - nom::character::complete::one_of("+-*/%"), - parse_tuple2(|x| parse_dynamic_constant_id(x, context), - |x| parse_dynamic_constant_id(x, context)))), - |(op, (x, y))| - match op { '+' => DynamicConstant::Add(x, y), - '-' => DynamicConstant::Sub(x, y), - '*' => DynamicConstant::Mul(x, y), - '/' => DynamicConstant::Div(x, y), - '%' => DynamicConstant::Rem(x, y), - _ => panic!("Invalid parse") } + nom::character::complete::one_of("+-*/%"), + parse_tuple2( + |x| parse_dynamic_constant_id(x, context), + |x| parse_dynamic_constant_id(x, context), + ), + )), + |(op, (x, y))| match op { + '+' => DynamicConstant::Add(x, y), + '-' => DynamicConstant::Sub(x, y), + '*' => DynamicConstant::Mul(x, y), + '/' => DynamicConstant::Div(x, y), + '%' => DynamicConstant::Rem(x, y), + _ => panic!("Invalid parse"), + }, ), ))(ir_text)?; Ok((ir_text, dc)) diff --git a/hercules_ir/src/schedule.rs b/hercules_ir/src/schedule.rs index 86f18424fb603e443650ba3b4d131723921c9737..ea2c6153e40b9c0bfa28c7cc84a69d41c493ad39 100644 --- a/hercules_ir/src/schedule.rs +++ b/hercules_ir/src/schedule.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; use std::collections::VecDeque; -use std::iter::zip; +use std::iter::{repeat, zip}; use crate::*; @@ -11,8 +11,22 @@ use crate::*; */ #[derive(Debug, Clone, PartialEq, Eq)] pub enum Schedule { + // This fork can be run in parallel and has a "natural" tiling, which may or + // may not be respected by certain backends. The field stores at least how + // many parallel tiles should be run concurrently, along each dimension. + // Some backends (such as GPU) may spawn more parallel tiles (each tile + // being a single thread in that case) along each axis. + ParallelFork(Box<[usize]>), + // This reduce can be "run in parallel" - conceptually, the `reduct` + // backedge can be removed, and the reduce code can be merged into the + // parallel code. ParallelReduce, - Vectorize, + // This fork-join has no impeding control flow. The field stores the vector + // width. + Vectorizable(usize), + // This reduce can be re-associated. This may lower a sequential dependency + // chain into a reduction tree. + Associative, } /* @@ -20,7 +34,7 @@ pub enum Schedule { * refers to a specific backend, so difference "devices" may refer to the same * "kind" of hardware. 
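+ * (Device now derives Copy, so a partition's device can be passed by value,
+ * e.g. into DeviceManifest::default in sched_gen.)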
*/ -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub enum Device { CPU, GPU, @@ -469,8 +483,11 @@ impl Plan { */ pub fn default_plan( function: &Function, + dynamic_constants: &Vec<DynamicConstant>, + def_use: &ImmutableDefUseMap, reverse_postorder: &Vec<NodeID>, fork_join_map: &HashMap<NodeID, NodeID>, + fork_join_nesting: &HashMap<NodeID, Vec<NodeID>>, bbs: &Vec<NodeID>, ) -> Plan { // Start by creating a completely bare-bones plan doing nothing interesting. @@ -483,7 +500,15 @@ pub fn default_plan( // Infer schedules. infer_parallel_reduce(function, fork_join_map, &mut plan); - infer_vectorize(function, fork_join_map, &mut plan); + infer_parallel_fork( + function, + def_use, + fork_join_map, + fork_join_nesting, + &mut plan, + ); + infer_vectorizable(function, dynamic_constants, fork_join_map, &mut plan); + infer_associative(function, &mut plan); // Infer a partitioning. partition_out_forks(function, reverse_postorder, fork_join_map, bbs, &mut plan); @@ -493,6 +518,38 @@ pub fn default_plan( plan } +/* + * Infer parallel fork-joins. These are fork-joins with only parallel reduction + * variables and no parent fork-joins. + */ +pub fn infer_parallel_fork( + function: &Function, + def_use: &ImmutableDefUseMap, + fork_join_map: &HashMap<NodeID, NodeID>, + fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, + plan: &mut Plan, +) { + for id in (0..function.nodes.len()).map(NodeID::new) { + let Node::Fork { + control: _, + ref factors, + } = function.nodes[id.idx()] + else { + continue; + }; + let join_id = fork_join_map[&id]; + let all_parallel_reduce = def_use.get_users(join_id).as_ref().into_iter().all(|user| { + plan.schedules[user.idx()].contains(&Schedule::ParallelReduce) + || function.nodes[user.idx()].is_control() + }); + let top_level = fork_join_nest[&id].len() == 1 && fork_join_nest[&id][0] == id; + if all_parallel_reduce && top_level { + let tiling = repeat(4).take(factors.len()).collect(); + plan.schedules[id.idx()].push(Schedule::ParallelFork(tiling)); + } + } +} + /* * Infer parallel reductions consisting of a simple cycle between a Reduce node * and a Write node, where indices of the Write are position indices using the @@ -585,10 +642,11 @@ pub fn infer_parallel_reduce( /* * Infer vectorizable fork-joins. Just check that there are no control nodes - * between a fork and its join. + * between a fork and its join and the factor is a constant. */ -pub fn infer_vectorize( +pub fn infer_vectorizable( function: &Function, + dynamic_constants: &Vec<DynamicConstant>, fork_join_map: &HashMap<NodeID, NodeID>, plan: &mut Plan, ) { @@ -600,7 +658,39 @@ pub fn infer_vectorize( if let Some(join) = fork_join_map.get(&u) && *join == id { - plan.schedules[u.idx()].push(Schedule::Vectorize); + let factors = function.nodes[u.idx()].try_fork().unwrap().1; + if factors.len() == 1 + && let Some(width) = evaluate_dynamic_constant(factors[0], dynamic_constants) + { + plan.schedules[u.idx()].push(Schedule::Vectorizable(width)); + } + } + } +} + +/* + * Infer associative reduction loops. 
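+ * A reduce qualifies when its `reduct` is a binary node (add, mul, and, or, or
+ * xor) that uses the reduce itself as one operand; such sequential chains can
+ * later be re-associated into a reduction tree (see Schedule::Associative).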
+ */ +pub fn infer_associative(function: &Function, plan: &mut Plan) { + let is_associative = |op| match op { + BinaryOperator::Add + | BinaryOperator::Mul + | BinaryOperator::Or + | BinaryOperator::And + | BinaryOperator::Xor => true, + _ => false, + }; + + for (id, reduct) in (0..function.nodes.len()).map(NodeID::new).filter_map(|id| { + function.nodes[id.idx()] + .try_reduce() + .map(|(_, _, reduct)| (id, reduct)) + }) { + if let Node::Binary { left, right, op } = function.nodes[reduct.idx()] + && (left == id || right == id) + && is_associative(op) + { + plan.schedules[id.idx()].push(Schedule::Associative); } } } diff --git a/hercules_ir/src/subgraph.rs b/hercules_ir/src/subgraph.rs index 62ef123e7528cf5526de59017cb31503113456c5..f1466ce3ab91770c5c2130c00cb9a234dcfd928b 100644 --- a/hercules_ir/src/subgraph.rs +++ b/hercules_ir/src/subgraph.rs @@ -271,10 +271,17 @@ pub fn partition_graph(function: &Function, def_use: &ImmutableDefUseMap, plan: // Record the source of the edges (the current partition). let old_num_edges = subgraph.backward_edges.len(); subgraph.first_backward_edges.push(old_num_edges as u32); - for node in partition.iter() { - // Look at all the uses from nodes in that partition. + for node in partition + .iter() + .filter(|id| function.nodes[id.idx()].is_control()) + { + // Look at all the control uses of control nodes in that partition. let uses = get_uses(&function.nodes[node.idx()]); - for use_id in uses.as_ref() { + for use_id in uses + .as_ref() + .iter() + .filter(|id| function.nodes[id.idx()].is_control()) + { // Add a backward edge to any different partition we are using // and don't add duplicate backward edges. if plan.partitions[use_id.idx()] != plan.partitions[node.idx()] @@ -294,10 +301,17 @@ pub fn partition_graph(function: &Function, def_use: &ImmutableDefUseMap, plan: // Record the source of the edges (the current partition). let old_num_edges = subgraph.forward_edges.len(); subgraph.first_forward_edges.push(old_num_edges as u32); - for node in partition.iter() { - // Look at all the uses from nodes in that partition. + for node in partition + .iter() + .filter(|id| function.nodes[id.idx()].is_control()) + { + // Look at all the control users of control nodes in that partition. let users = def_use.get_users(*node); - for user_id in users.as_ref() { + for user_id in users + .as_ref() + .iter() + .filter(|id| function.nodes[id.idx()].is_control()) + { // Add a forward edge to any different partition that we are a // user of and don't add duplicate forward edges. if plan.partitions[user_id.idx()] != plan.partitions[node.idx()] diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index 505b6d2424c1c1de993ecb84ad94521820f92963..b6b9b2811f5f1e1323876f3138c907f26b9c1e8e 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -99,12 +99,12 @@ pub fn typecheck( .enumerate() .map(|(idx, ty)| (ty.clone(), TypeID::new(idx))) .collect(); - let mut reverse_dynamic_constant_map: HashMap<DynamicConstant, DynamicConstantID> - = dynamic_constants - .iter() - .enumerate() - .map(|(idx, ty)| (ty.clone(), DynamicConstantID::new(idx))) - .collect(); + let mut reverse_dynamic_constant_map: HashMap<DynamicConstant, DynamicConstantID> = + dynamic_constants + .iter() + .enumerate() + .map(|(idx, ty)| (ty.clone(), DynamicConstantID::new(idx))) + .collect(); // Step 2: run dataflow. 
This is an occurrence of dataflow where the flow // function performs a non-associative operation on the predecessor "out" @@ -159,7 +159,7 @@ fn typeflow( constants: &Vec<Constant>, dynamic_constants: &mut Vec<DynamicConstant>, reverse_type_map: &mut HashMap<Type, TypeID>, - reverse_dynamic_constant_map : &mut HashMap<DynamicConstant, DynamicConstantID>, + reverse_dynamic_constant_map: &mut HashMap<DynamicConstant, DynamicConstantID>, ) -> TypeSemilattice { // Whenever we want to reference a specific type (for example, for the // start node), we need to get its type ID. This helper function gets the @@ -193,9 +193,10 @@ fn typeflow( | DynamicConstant::Sub(x, y) | DynamicConstant::Mul(x, y) | DynamicConstant::Div(x, y) - | DynamicConstant::Rem(x, y) - => check_dynamic_constants(x, dynamic_constants, num_parameters) - && check_dynamic_constants(y, dynamic_constants, num_parameters), + | DynamicConstant::Rem(x, y) => { + check_dynamic_constants(x, dynamic_constants, num_parameters) + && check_dynamic_constants(y, dynamic_constants, num_parameters) + } } } @@ -757,8 +758,14 @@ fn typeflow( } } - Concrete(type_subst(types, dynamic_constants, reverse_type_map, - reverse_dynamic_constant_map, dc_args, callee.return_type)) + Concrete(type_subst( + types, + dynamic_constants, + reverse_type_map, + reverse_dynamic_constant_map, + dc_args, + callee.return_type, + )) } Node::IntrinsicCall { intrinsic, args: _ } => { let num_params = match intrinsic { @@ -1119,34 +1126,45 @@ pub fn cast_compatible(src_ty: &Type, dst_ty: &Type) -> bool { * dynamic constants are substituted in for the dynamic constants used in the * parameter type. */ -fn types_match(types: &Vec<Type>, dynamic_constants: &Vec<DynamicConstant>, - dc_args : &Box<[DynamicConstantID]>, param : TypeID, input : TypeID) -> bool { +fn types_match( + types: &Vec<Type>, + dynamic_constants: &Vec<DynamicConstant>, + dc_args: &Box<[DynamicConstantID]>, + param: TypeID, + input: TypeID, +) -> bool { // Note that we can't just check whether the type ids are equal since them // being equal does not mean they match when we properly substitute in the // dynamic constant arguments match (&types[param.idx()], &types[input.idx()]) { - (Type::Control, Type::Control) | (Type::Boolean, Type::Boolean) - | (Type::Integer8, Type::Integer8) | (Type::Integer16, Type::Integer16) - | (Type::Integer32, Type::Integer32) | (Type::Integer64, Type::Integer64) + (Type::Control, Type::Control) + | (Type::Boolean, Type::Boolean) + | (Type::Integer8, Type::Integer8) + | (Type::Integer16, Type::Integer16) + | (Type::Integer32, Type::Integer32) + | (Type::Integer64, Type::Integer64) | (Type::UnsignedInteger8, Type::UnsignedInteger8) | (Type::UnsignedInteger16, Type::UnsignedInteger16) | (Type::UnsignedInteger32, Type::UnsignedInteger32) | (Type::UnsignedInteger64, Type::UnsignedInteger64) - | (Type::Float32, Type::Float32) | (Type::Float64, Type::Float64) - => true, - (Type::Product(ps), Type::Product(is)) - | (Type::Summation(ps), Type::Summation(is)) => { + | (Type::Float32, Type::Float32) + | (Type::Float64, Type::Float64) => true, + (Type::Product(ps), Type::Product(is)) | (Type::Summation(ps), Type::Summation(is)) => { ps.len() == is.len() - && ps.iter().zip(is.iter()) - .all(|(p, i)| types_match(types, dynamic_constants, dc_args, *p, *i)) - }, + && ps + .iter() + .zip(is.iter()) + .all(|(p, i)| types_match(types, dynamic_constants, dc_args, *p, *i)) + } (Type::Array(p, pds), Type::Array(i, ids)) => { types_match(types, dynamic_constants, dc_args, *p, *i) - && pds.len() 
== ids.len() - && pds.iter().zip(ids.iter()) - .all(|(pd, id)| dyn_consts_match(dynamic_constants, dc_args, *pd, *id)) - }, + && pds.len() == ids.len() + && pds + .iter() + .zip(ids.iter()) + .all(|(pd, id)| dyn_consts_match(dynamic_constants, dc_args, *pd, *id)) + } (_, _) => false, } } @@ -1156,21 +1174,26 @@ fn types_match(types: &Vec<Type>, dynamic_constants: &Vec<DynamicConstant>, * constants when the provided dynamic constants are substituted in for the * dynamic constants used in the parameter's dynamic constant */ -fn dyn_consts_match(dynamic_constants: &Vec<DynamicConstant>, - dc_args: &Box<[DynamicConstantID]>, param: DynamicConstantID, - input: DynamicConstantID) -> bool { - match (&dynamic_constants[param.idx()], &dynamic_constants[input.idx()]) { - (DynamicConstant::Constant(x), DynamicConstant::Constant(y)) - => x == y, - (DynamicConstant::Parameter(i), _) - => input == dc_args[*i], +fn dyn_consts_match( + dynamic_constants: &Vec<DynamicConstant>, + dc_args: &Box<[DynamicConstantID]>, + param: DynamicConstantID, + input: DynamicConstantID, +) -> bool { + match ( + &dynamic_constants[param.idx()], + &dynamic_constants[input.idx()], + ) { + (DynamicConstant::Constant(x), DynamicConstant::Constant(y)) => x == y, + (DynamicConstant::Parameter(i), _) => input == dc_args[*i], (DynamicConstant::Add(pl, pr), DynamicConstant::Add(il, ir)) | (DynamicConstant::Sub(pl, pr), DynamicConstant::Sub(il, ir)) | (DynamicConstant::Mul(pl, pr), DynamicConstant::Mul(il, ir)) | (DynamicConstant::Div(pl, pr), DynamicConstant::Div(il, ir)) - | (DynamicConstant::Rem(pl, pr), DynamicConstant::Rem(il, ir)) - => dyn_consts_match(dynamic_constants, dc_args, *pl, *il) - && dyn_consts_match(dynamic_constants, dc_args, *pr, *ir), + | (DynamicConstant::Rem(pl, pr), DynamicConstant::Rem(il, ir)) => { + dyn_consts_match(dynamic_constants, dc_args, *pl, *il) + && dyn_consts_match(dynamic_constants, dc_args, *pr, *ir) + } (_, _) => false, } } @@ -1180,13 +1203,19 @@ fn dyn_consts_match(dynamic_constants: &Vec<DynamicConstant>, * returns the appropriate typeID (potentially creating new types and dynamic * constants in the process) */ -fn type_subst(types: &mut Vec<Type>, dynamic_constants: &mut Vec<DynamicConstant>, - reverse_type_map: &mut HashMap<Type, TypeID>, - reverse_dynamic_constant_map: &mut HashMap<DynamicConstant, DynamicConstantID>, - dc_args: &Box<[DynamicConstantID]>, typ : TypeID) -> TypeID { - - fn intern_type(ty : Type, types: &mut Vec<Type>, reverse_type_map: &mut HashMap<Type, TypeID>) - -> TypeID { +fn type_subst( + types: &mut Vec<Type>, + dynamic_constants: &mut Vec<DynamicConstant>, + reverse_type_map: &mut HashMap<Type, TypeID>, + reverse_dynamic_constant_map: &mut HashMap<DynamicConstant, DynamicConstantID>, + dc_args: &Box<[DynamicConstantID]>, + typ: TypeID, +) -> TypeID { + fn intern_type( + ty: Type, + types: &mut Vec<Type>, + reverse_type_map: &mut HashMap<Type, TypeID>, + ) -> TypeID { if let Some(id) = reverse_type_map.get(&ty) { *id } else { @@ -1198,50 +1227,85 @@ fn type_subst(types: &mut Vec<Type>, dynamic_constants: &mut Vec<DynamicConstant } match &types[typ.idx()] { - Type::Control | Type::Boolean | Type::Integer8 | Type::Integer16 - | Type::Integer32 | Type::Integer64 | Type::UnsignedInteger8 - | Type::UnsignedInteger16 | Type::UnsignedInteger32 - | Type::UnsignedInteger64 | Type::Float32 | Type::Float64 - => typ, + Type::Control + | Type::Boolean + | Type::Integer8 + | Type::Integer16 + | Type::Integer32 + | Type::Integer64 + | Type::UnsignedInteger8 + | 
Type::UnsignedInteger16 + | Type::UnsignedInteger32 + | Type::UnsignedInteger64 + | Type::Float32 + | Type::Float64 => typ, Type::Product(ts) => { let mut new_ts = vec![]; for t in ts.clone().iter() { - new_ts.push(type_subst(types, dynamic_constants, reverse_type_map, - reverse_dynamic_constant_map, dc_args, *t)); + new_ts.push(type_subst( + types, + dynamic_constants, + reverse_type_map, + reverse_dynamic_constant_map, + dc_args, + *t, + )); } intern_type(Type::Product(new_ts.into()), types, reverse_type_map) - }, + } Type::Summation(ts) => { let mut new_ts = vec![]; for t in ts.clone().iter() { - new_ts.push(type_subst(types, dynamic_constants, reverse_type_map, - reverse_dynamic_constant_map, dc_args, *t)); + new_ts.push(type_subst( + types, + dynamic_constants, + reverse_type_map, + reverse_dynamic_constant_map, + dc_args, + *t, + )); } intern_type(Type::Summation(new_ts.into()), types, reverse_type_map) - }, + } Type::Array(elem, dims) => { let ds = dims.clone(); - let new_elem = type_subst(types, dynamic_constants, reverse_type_map, - reverse_dynamic_constant_map, dc_args, *elem); + let new_elem = type_subst( + types, + dynamic_constants, + reverse_type_map, + reverse_dynamic_constant_map, + dc_args, + *elem, + ); let mut new_dims = vec![]; for d in ds.iter() { - new_dims.push(dyn_const_subst(dynamic_constants, - reverse_dynamic_constant_map, - dc_args, *d)); + new_dims.push(dyn_const_subst( + dynamic_constants, + reverse_dynamic_constant_map, + dc_args, + *d, + )); } - intern_type(Type::Array(new_elem, new_dims.into()), types, reverse_type_map) - }, + intern_type( + Type::Array(new_elem, new_dims.into()), + types, + reverse_type_map, + ) + } } } -fn dyn_const_subst(dynamic_constants: &mut Vec<DynamicConstant>, - reverse_dynamic_constant_map: &mut HashMap<DynamicConstant, DynamicConstantID>, - dc_args: &Box<[DynamicConstantID]>, dyn_const : DynamicConstantID) - -> DynamicConstantID { - - fn intern_dyn_const(dc: DynamicConstant, dynamic_constants: &mut Vec<DynamicConstant>, - reverse_dynamic_constant_map: &mut HashMap<DynamicConstant, DynamicConstantID>) - -> DynamicConstantID { +fn dyn_const_subst( + dynamic_constants: &mut Vec<DynamicConstant>, + reverse_dynamic_constant_map: &mut HashMap<DynamicConstant, DynamicConstantID>, + dc_args: &Box<[DynamicConstantID]>, + dyn_const: DynamicConstantID, +) -> DynamicConstantID { + fn intern_dyn_const( + dc: DynamicConstant, + dynamic_constants: &mut Vec<DynamicConstant>, + reverse_dynamic_constant_map: &mut HashMap<DynamicConstant, DynamicConstantID>, + ) -> DynamicConstantID { if let Some(id) = reverse_dynamic_constant_map.get(&dc) { *id } else { @@ -1258,52 +1322,57 @@ fn dyn_const_subst(dynamic_constants: &mut Vec<DynamicConstant>, DynamicConstant::Add(l, r) => { let x = *l; let y = *r; - let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, - dc_args, x); - let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, - dc_args, y); - intern_dyn_const(DynamicConstant::Add(sx, sy), dynamic_constants, - reverse_dynamic_constant_map) - }, + let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, x); + let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, y); + intern_dyn_const( + DynamicConstant::Add(sx, sy), + dynamic_constants, + reverse_dynamic_constant_map, + ) + } DynamicConstant::Sub(l, r) => { let x = *l; let y = *r; - let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, - dc_args, x); - let sy = 
dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, - dc_args, y); - intern_dyn_const(DynamicConstant::Sub(sx, sy), dynamic_constants, - reverse_dynamic_constant_map) - }, + let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, x); + let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, y); + intern_dyn_const( + DynamicConstant::Sub(sx, sy), + dynamic_constants, + reverse_dynamic_constant_map, + ) + } DynamicConstant::Mul(l, r) => { let x = *l; let y = *r; - let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, - dc_args, x); - let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, - dc_args, y); - intern_dyn_const(DynamicConstant::Mul(sx, sy), dynamic_constants, - reverse_dynamic_constant_map) - }, + let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, x); + let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, y); + intern_dyn_const( + DynamicConstant::Mul(sx, sy), + dynamic_constants, + reverse_dynamic_constant_map, + ) + } DynamicConstant::Div(l, r) => { let x = *l; let y = *r; - let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, - dc_args, x); - let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, - dc_args, y); - intern_dyn_const(DynamicConstant::Div(sx, sy), dynamic_constants, - reverse_dynamic_constant_map) - }, + let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, x); + let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, y); + intern_dyn_const( + DynamicConstant::Div(sx, sy), + dynamic_constants, + reverse_dynamic_constant_map, + ) + } DynamicConstant::Rem(l, r) => { let x = *l; let y = *r; - let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, - dc_args, x); - let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, - dc_args, y); - intern_dyn_const(DynamicConstant::Rem(sx, sy), dynamic_constants, - reverse_dynamic_constant_map) - }, + let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, x); + let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, y); + intern_dyn_const( + DynamicConstant::Rem(sx, sy), + dynamic_constants, + reverse_dynamic_constant_map, + ) + } } } diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs index 92b24279999d8605cc65219ad4aee17d7a812313..c61177a014c8661865adcd03b1b5736196be33b6 100644 --- a/hercules_opt/src/pass.rs +++ b/hercules_opt/src/pass.rs @@ -5,9 +5,11 @@ extern crate serde; extern crate take_mut; use std::collections::HashMap; +use std::env::temp_dir; use std::fs::File; use std::io::Write; use std::iter::zip; +use std::path::Path; use std::process::{Command, Stdio}; use self::serde::Deserialize; @@ -35,8 +37,7 @@ pub enum Pass { // Useful to set to false if displaying a potentially broken module. Xdot(bool), SchedXdot, - // Parameterized by output file name. 
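+    // Output paths (lib<stem>.a and <stem>.hman) are now derived from the input
+    // file handed to run_passes, so Codegen no longer carries a file name.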
- Codegen(String), + Codegen, } /* @@ -257,26 +258,49 @@ impl PassManager { pub fn make_plans(&mut self) { if self.plans.is_none() { + self.make_def_uses(); self.make_reverse_postorders(); self.make_fork_join_maps(); + self.make_fork_join_nests(); self.make_bbs(); + let def_uses = self.def_uses.as_ref().unwrap().iter(); let reverse_postorders = self.reverse_postorders.as_ref().unwrap().iter(); let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter(); + let fork_join_nests = self.fork_join_nests.as_ref().unwrap().iter(); let bbs = self.bbs.as_ref().unwrap().iter(); self.plans = Some( zip( self.module.functions.iter(), - zip(reverse_postorders, zip(fork_join_maps, bbs)), + zip( + def_uses, + zip( + reverse_postorders, + zip(fork_join_maps, zip(fork_join_nests, bbs)), + ), + ), + ) + .map( + |( + function, + (def_use, (reverse_postorder, (fork_join_map, (fork_join_nest, bb)))), + )| { + default_plan( + function, + &self.module.dynamic_constants, + def_use, + reverse_postorder, + fork_join_map, + fork_join_nest, + bb, + ) + }, ) - .map(|(function, (reverse_postorder, (fork_join_map, bb)))| { - default_plan(function, reverse_postorder, fork_join_map, bb) - }) .collect(), ); } } - pub fn run_passes(&mut self) { + pub fn run_passes(&mut self, input_file: &Path) { for pass in self.passes.clone().iter() { match pass { Pass::DCE => { @@ -473,7 +497,7 @@ impl PassManager { // Xdot doesn't require clearing analysis results. continue; } - Pass::Codegen(output_file_name) => { + Pass::Codegen => { self.make_def_uses(); self.make_typing(); self.make_control_subgraphs(); @@ -504,30 +528,43 @@ impl PassManager { } println!("{}", llvm_ir); - // Compile LLVM IR into ELF object. - let llc_process = Command::new("llc") - .arg("-filetype=obj") + // Write the LLVM IR into a temporary file. + let output_file_prefix = input_file.file_stem().unwrap().to_str().unwrap(); + let mut tmp_path = temp_dir(); + tmp_path.push(format!("{}.ll", output_file_prefix)); + let mut file = File::create(&tmp_path) + .expect("PANIC: Unable to open output LLVM IR file."); + file.write_all(llvm_ir.as_bytes()) + .expect("PANIC: Unable to write output LLVM IR file contents."); + + // Compile LLVM IR into an ELF object file. + let mut clang_process = Command::new("clang") + .arg(&tmp_path) + .arg("--emit-static-lib") .arg("-O3") + .arg("-march=native") + .arg("-o") + .arg({ + let mut lib_path = input_file.parent().unwrap().to_path_buf(); + lib_path.push(format!("lib{}.a", output_file_prefix)); + lib_path + }) .stdin(Stdio::piped()) .stdout(Stdio::piped()) .spawn() .unwrap(); - llc_process - .stdin - .as_ref() - .unwrap() - .write(llvm_ir.as_bytes()) - .unwrap(); - let elf_object = llc_process.wait_with_output().unwrap().stdout; - - // Package manifest and ELF object into the same file. - let hbin_module = (smodule.manifests, elf_object); - let hbin_contents: Vec<u8> = postcard::to_allocvec(&hbin_module).unwrap(); - - let mut file = - File::create(output_file_name).expect("PANIC: Unable to open output file."); - file.write_all(&hbin_contents) - .expect("PANIC: Unable to write output file contents."); + assert!(clang_process.wait().unwrap().success()); + + // Package manifest into a file. 
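+                        // The manifests are serialized with postcard and written next to the
+                        // input file as <stem>.hman; the use_hman! proc macro deserializes
+                        // this file to generate the async runtime glue.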
+ let hman_contents: Vec<u8> = postcard::to_allocvec(&smodule.manifests).unwrap(); + let mut file = File::create({ + let mut hman_path = input_file.parent().unwrap().to_path_buf(); + hman_path.push(format!("{}.hman", output_file_prefix)); + hman_path + }) + .expect("PANIC: Unable to open output manifest file."); + file.write_all(&hman_contents) + .expect("PANIC: Unable to write output manifest file contents."); // Codegen doesn't require clearing analysis results. continue; diff --git a/hercules_rt/src/elf.rs b/hercules_rt/src/elf.rs deleted file mode 100644 index 9a7aac0d87d1712b5b37dd7803d881329f2b472a..0000000000000000000000000000000000000000 --- a/hercules_rt/src/elf.rs +++ /dev/null @@ -1,223 +0,0 @@ -extern crate libc; - -use std::ffi::CStr; -use std::mem::size_of; -use std::ptr::copy_nonoverlapping; -use std::ptr::null_mut; -use std::ptr::read_unaligned; -use std::str::from_utf8; - -use self::libc::*; - -/* - * The libc crate doesn't have everything from elf.h, so these things need to be - * manually defined. - */ -#[repr(C)] -#[derive(Debug)] -struct Elf64_Rela { - r_offset: Elf64_Addr, - r_info: Elf64_Xword, - r_addend: Elf64_Sxword, -} - -const R_X86_64_PC32: u64 = 2; -const R_X86_64_PLT32: u64 = 4; -const STT_FUNC: u8 = 2; - -/* - * Holds a mmaped copy of .text + .bss for direct execution, plus metadata for - * each function. The .bss section holds a table storing addresses to internal - * runtime functions, since this is literally easier than patching the object - * code to directly jump to those runtime functions. - */ -#[derive(Debug)] -pub(crate) struct Elf { - pub(crate) function_names: Vec<String>, - pub(crate) function_pointers: Vec<isize>, - pub(crate) program_section: *mut u8, - pub(crate) program_size: usize, -} - -/* - * Mmaps are visible to all threads and are thread safe, so we can share the ELF - * across threads. - */ -unsafe impl Send for Elf {} -unsafe impl Sync for Elf {} - -impl Drop for Elf { - fn drop(&mut self) { - unsafe { munmap(self.program_section as *mut _, self.program_size) }; - } -} - -/* - * Function for parsing our internal memory representation of an ELF file from - * the raw bytes of an ELF file. This includes creating a executable section of - * code, and relocating function calls and global variables. This whole thing is - * very unsafe, and is predicated on the elf parameter referencing properly - * formatted bytes. - */ -pub(crate) unsafe fn parse_elf(elf: &[u8]) -> Elf { - fn page_align(n: usize) -> usize { - (n + (4096 - 1)) & !(4096 - 1) - } - - // read_unaligned corresponds to memcpys in C - we need to memcpy structs - // out of the file's bytes, since they may be stored without proper - // alignment. - let header: Elf64_Ehdr = read_unaligned(elf.as_ptr() as *const _); - assert!(header.e_shentsize as usize == size_of::<Elf64_Shdr>()); - let section_header_table: Box<[_]> = (0..header.e_shnum) - .map(|idx| { - read_unaligned( - (elf.as_ptr().offset(header.e_shoff as isize) as *const Elf64_Shdr) - .offset(idx as isize), - ) - }) - .collect(); - - // Look for the .symtab, .strtab, .text, .bss, and .rela.text sections. Only - // the .rela.text section is not necessary. 
- let mut symtab_ndx = -1; - let mut strtab_ndx = -1; - let mut text_ndx = -1; - let mut bss_ndx = -1; - let mut rela_text_ndx = -1; - let shstrtab = &elf[section_header_table[header.e_shstrndx as usize].sh_offset as usize..]; - for i in 0..header.e_shnum as usize { - let section_name = &shstrtab[section_header_table[i].sh_name as usize..]; - let null_position = section_name - .iter() - .position(|&c| c == b'\0') - .unwrap_or(section_name.len()); - let name_str = from_utf8(§ion_name[..null_position]).unwrap(); - if name_str == ".symtab" { - symtab_ndx = i as i32; - } else if name_str == ".strtab" { - strtab_ndx = i as i32; - } else if name_str == ".text" { - text_ndx = i as i32; - } else if name_str == ".bss" { - bss_ndx = i as i32; - } else if name_str == ".rela.text" { - rela_text_ndx = i as i32; - } - } - assert!(symtab_ndx != -1); - assert!(strtab_ndx != -1); - assert!(text_ndx != -1); - assert!(bss_ndx != -1); - - // Get the headers for the required sections. - let symtab_hdr = section_header_table[symtab_ndx as usize]; - let strtab_hdr = section_header_table[strtab_ndx as usize]; - let text_hdr = section_header_table[text_ndx as usize]; - let bss_hdr = section_header_table[bss_ndx as usize]; - - // Collect the symbols in the symbol table. - assert!(symtab_hdr.sh_entsize as usize == size_of::<Elf64_Sym>()); - let num_symbols = symtab_hdr.sh_size as usize / size_of::<Elf64_Sym>(); - let symbol_table: Box<[_]> = (0..num_symbols) - .map(|idx| { - read_unaligned( - (elf.as_ptr().offset(symtab_hdr.sh_offset as isize) as *const Elf64_Sym) - .offset(idx as isize), - ) - }) - .collect(); - - // The mmaped region includes both the .text and .bss sections. - let program_size = page_align(text_hdr.sh_size as usize) + page_align(bss_hdr.sh_size as usize); - let program_base = mmap( - null_mut(), - program_size, - PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, - -1, - 0, - ) as *mut u8; - let text_base = program_base; - let bss_base = text_base.offset(page_align(text_hdr.sh_size as usize) as isize); - - // Copy the object code into the mmaped region. - copy_nonoverlapping( - elf.as_ptr().offset(text_hdr.sh_offset as isize), - text_base, - text_hdr.sh_size as usize, - ); - - // If there are relocations, we process them here. - if rela_text_ndx != -1 { - let rela_text_hdr = section_header_table[rela_text_ndx as usize]; - let num_relocations = rela_text_hdr.sh_size / rela_text_hdr.sh_entsize; - - // We only iterate the relocations in order, so no need to collect. - let relocations = (0..num_relocations).map(|idx| { - read_unaligned( - (elf.as_ptr().offset(rela_text_hdr.sh_offset as isize) as *const Elf64_Rela) - .offset(idx as isize), - ) - }); - for relocation in relocations { - let symbol_idx = relocation.r_info >> 32; - let ty = relocation.r_info & 0xFFFFFFFF; - let patch_offset = text_base.offset(relocation.r_offset as isize); - - // We support PLT32 relocations only in the .text section, and PC32 - // relocations only in the .bss section. 
- match ty { - R_X86_64_PLT32 => { - let symbol_address = - text_base.offset(symbol_table[symbol_idx as usize].st_value as isize); - let patch = symbol_address - .offset(relocation.r_addend as isize) - .offset_from(patch_offset); - (patch_offset as *mut u32).write_unaligned(patch as u32); - } - R_X86_64_PC32 => { - let symbol_address = - bss_base.offset(symbol_table[symbol_idx as usize].st_value as isize); - let patch = symbol_address - .offset(relocation.r_addend as isize) - .offset_from(patch_offset); - (patch_offset as *mut u32).write_unaligned(patch as u32); - } - _ => panic!("ERROR: Unrecognized relocation type: {}.", ty), - } - } - } - - // Make the .text section readable and executable. The .bss section should - // still be readable and writable. - mprotect( - text_base as *mut c_void, - page_align(text_hdr.sh_size as usize), - PROT_READ | PROT_EXEC, - ); - - // Construct the final in-memory ELF representation. Look up the names of - // function symbols in the string table. - let strtab = &elf[strtab_hdr.sh_offset as usize..]; - let mut elf = Elf { - function_names: vec![], - function_pointers: vec![], - program_section: program_base, - program_size, - }; - for i in 0..num_symbols { - if symbol_table[i].st_info & 0xF == STT_FUNC { - let function_name_base = &strtab[symbol_table[i].st_name as usize..]; - let function_name = CStr::from_ptr(function_name_base.as_ptr() as *const _) - .to_str() - .unwrap() - .to_owned(); - elf.function_names.push(function_name); - elf.function_pointers - .push(symbol_table[i].st_value as isize); - } - } - - elf -} diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs index c8a9c729ac1d10af3c5ddc01c10adc3889cfb223..425e6bab50f49e072a4391ad33254073b3001248 100644 --- a/hercules_rt/src/lib.rs +++ b/hercules_rt/src/lib.rs @@ -1,35 +1,3 @@ extern crate hercules_rt_proc; -pub(crate) mod elf; -pub(crate) use crate::elf::*; -pub use hercules_rt_proc::use_hbin; - -#[derive(Debug)] -pub struct Module { - elf: Elf, -} - -impl Module { - /* - * Parse ELF object from ELF bytes buffer. - */ - pub fn new(buffer: &[u8]) -> Self { - let elf = unsafe { parse_elf(buffer) }; - Module { elf } - } - - /* - * Get the function pointer corresponding to a function name. Panic if not - * found. 
- */ - pub unsafe fn get_function_ptr(&self, name: &str) -> *mut u8 { - self.elf.program_section.offset( - self.elf.function_pointers[self - .elf - .function_names - .iter() - .position(|s| s == name) - .unwrap()], - ) - } -} +pub use hercules_rt_proc::use_hman; diff --git a/hercules_rt_proc/Cargo.toml b/hercules_rt_proc/Cargo.toml index bbf52c6ba75b93db67bcd878f57cd5d475401591..c6e55af051cd412aa5930a8239a933fa64fcb828 100644 --- a/hercules_rt_proc/Cargo.toml +++ b/hercules_rt_proc/Cargo.toml @@ -11,4 +11,5 @@ proc-macro = true postcard = { version = "*", features = ["alloc"] } serde = { version = "*", features = ["derive"] } hercules_cg = { path = "../hercules_cg" } +hercules_ir = { path = "../hercules_ir" } anyhow = "*" \ No newline at end of file diff --git a/hercules_rt_proc/src/lib.rs b/hercules_rt_proc/src/lib.rs index 69625c19b658ae868029fe5b579d96927e215f15..047548c1ed1283421542f64680dd7c2a060d61c6 100644 --- a/hercules_rt_proc/src/lib.rs +++ b/hercules_rt_proc/src/lib.rs @@ -2,6 +2,7 @@ extern crate anyhow; extern crate hercules_cg; +extern crate hercules_ir; extern crate postcard; extern crate proc_macro; @@ -9,20 +10,13 @@ use std::collections::{HashMap, HashSet}; use std::ffi::OsStr; use std::fmt::Write; use std::fs::File; -use std::hash::{DefaultHasher, Hash, Hasher}; use std::io::prelude::*; use std::path::Path; use proc_macro::*; use self::hercules_cg::*; - -/* - * Parse manifest from header of .hbin file. - */ -fn manifests_and_module_bytes(buffer: &[u8]) -> (HashMap<String, Manifest>, Vec<u8>) { - postcard::from_bytes(buffer).unwrap() -} +use self::hercules_ir::{DynamicConstant, DynamicConstantID}; /* * Convert schedule IR types to the Rust types generated in the interface. @@ -71,40 +65,70 @@ fn generate_type_name(ty: &SType) -> String { } } +fn compute_dynamic_constant<W: Write>( + dc: DynamicConstantID, + manifest: &Manifest, + rust_code: &mut W, +) -> Result<(), anyhow::Error> { + match manifest.dynamic_constants[dc.idx()] { + DynamicConstant::Constant(cons) => write!(rust_code, "{}", cons)?, + DynamicConstant::Parameter(idx) => write!(rust_code, "dc_{}", idx)?, + DynamicConstant::Add(left, right) => { + write!(rust_code, "(")?; + compute_dynamic_constant(left, manifest, rust_code)?; + write!(rust_code, " + ")?; + compute_dynamic_constant(right, manifest, rust_code)?; + write!(rust_code, ")")?; + } + DynamicConstant::Sub(left, right) => { + write!(rust_code, "(")?; + compute_dynamic_constant(left, manifest, rust_code)?; + write!(rust_code, " - ")?; + compute_dynamic_constant(right, manifest, rust_code)?; + write!(rust_code, ")")?; + } + DynamicConstant::Mul(left, right) => { + write!(rust_code, "(")?; + compute_dynamic_constant(left, manifest, rust_code)?; + write!(rust_code, " * ")?; + compute_dynamic_constant(right, manifest, rust_code)?; + write!(rust_code, ")")?; + } + DynamicConstant::Div(left, right) => { + write!(rust_code, "(")?; + compute_dynamic_constant(left, manifest, rust_code)?; + write!(rust_code, " / ")?; + compute_dynamic_constant(right, manifest, rust_code)?; + write!(rust_code, ")")?; + } + DynamicConstant::Rem(left, right) => { + write!(rust_code, "(")?; + compute_dynamic_constant(left, manifest, rust_code)?; + write!(rust_code, " % ")?; + compute_dynamic_constant(right, manifest, rust_code)?; + write!(rust_code, ")")?; + } + } + Ok(()) +} + /* * Generate async Rust code orchestrating partition execution. 
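+ * Partition functions are declared as extern "C" symbols (resolved by linking
+ * the static library emitted by the Codegen pass) and driven by a control token
+ * loop; CPU partitions with a parallel launch spawn one async task per tile.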
*/ -fn codegen(manifests: &HashMap<String, Manifest>, elf: &[u8]) -> Result<String, anyhow::Error> { +fn codegen(manifests: &HashMap<String, Manifest>) -> Result<String, anyhow::Error> { // Write to a String containing all of the Rust code. let mut rust_code = "".to_string(); - // Emit the ELF bytes as a static byte string constant. Construct the module - // object on first access with LazyLock. + // Rust doesn't allow you to send pointers between threads. In order to send + // pointers between threads, we need to wrap them in a struct that unsafely + // implements Send and Sync. This passes the responsibility of + // synchronization onto us, which we do by being careful with how we lower + // parallel code. Make this type generic so that we actually wrap all + // arguments in it for ease of macro codegen. write!( rust_code, - "const __HERCULES_ELF_OBJ: &[u8] = {};\n", - Literal::byte_string(elf) + "#[derive(Clone, Copy, Debug)]\nstruct SendSyncWrapper<T: Copy>(T);\nunsafe impl<T: Copy> Send for SendSyncWrapper<T> {{}}\nunsafe impl<T: Copy> Sync for SendSyncWrapper<T> {{}}\n" )?; - write!( - rust_code, - "static __HERCULES_MODULE_OBJ: ::std::sync::LazyLock<::hercules_rt::Module> = ::std::sync::LazyLock::new(|| {{\n", - )?; - // Check that the ELF got embedded properly. - let hash = { - let mut s = DefaultHasher::new(); - elf.hash(&mut s); - s.finish() - }; - write!( - rust_code, - " use std::hash::{{DefaultHasher, Hash, Hasher}};\n debug_assert_eq!({{let mut s = DefaultHasher::new(); __HERCULES_ELF_OBJ.hash(&mut s); s.finish()}}, {});\n", - hash - )?; - write!( - rust_code, - " ::hercules_rt::Module::new(__HERCULES_ELF_OBJ)\n" - )?; - write!(rust_code, "}});\n")?; // Emit the product types used in this module. We can't just emit product // types, since we need #[repr(C)] to interact with LLVM. @@ -151,17 +175,35 @@ fn codegen(manifests: &HashMap<String, Manifest>, elf: &[u8]) -> Result<String, generate_type_string(&manifest.return_type) )?; - // Compute the signature for each partition function and emit the - // function pointers. - for (partition_idx, partition) in manifest.partitions.iter().enumerate() { - write!( - rust_code, - " let fn_ptr_part_{}: extern \"C\" fn(", - partition_idx - )?; - for (param_stype, _) in partition.parameters.iter() { + // Compute the signature for each partition function and emit the extern + // function signatures. + write!(rust_code, " extern \"C\" {{\n")?; + for partition in manifest.partitions.iter() { + write!(rust_code, " fn {}(", partition.name)?; + + // Add parameters for SFunction signature. + for (param_stype, kind) in partition.parameters.iter() { + match kind { + ParameterKind::HerculesParameter(idx) => write!(rust_code, "param_{}: ", idx)?, + ParameterKind::DataInput(id) => write!(rust_code, "data_{}: ", id.idx())?, + ParameterKind::DynamicConstant(idx) => write!(rust_code, "dc_{}: ", idx)?, + ParameterKind::ArrayConstant(id) => write!(rust_code, "array_{}: ", id.idx())?, + } write!(rust_code, "{}, ", generate_type_string(param_stype))?; } + + // Add parameters for device specific lowering details. + if let DeviceManifest::CPU { parallel_launch } = &partition.device { + for parallel_launch_dim in 0..parallel_launch.len() { + write!( + rust_code, + "parallel_launch_low_{}: u64, parallel_launch_len_{}: u64, ", + parallel_launch_dim, parallel_launch_dim + )?; + } + } + + // Add the return product of the SFunction signature. 
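+            // A single return value keeps its own type; multiple returns are packed
+            // into one of the #[repr(C)] product structs emitted above, matching the
+            // struct returned by the corresponding LLVM partition function.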
let return_stype = if partition.returns.len() == 1 { partition.returns[0].0.clone() } else { @@ -173,13 +215,9 @@ fn codegen(manifests: &HashMap<String, Manifest>, elf: &[u8]) -> Result<String, .collect(), ) }; - write!( - rust_code, - ") -> {} = ::core::mem::transmute(__HERCULES_MODULE_OBJ.get_function_ptr(\"{}\"));\n", - generate_type_string(&return_stype), - partition.name, - )?; + write!(rust_code, ") -> {};\n", generate_type_string(&return_stype),)?; } + write!(rust_code, " }}\n")?; // Declare all of the intermediary data input / output variables. They // are declared as MaybeUninit, since they get assigned after running a @@ -193,7 +231,7 @@ fn codegen(manifests: &HashMap<String, Manifest>, elf: &[u8]) -> Result<String, } assert_eq!(data_inputs, data_outputs); for (node, stype) in data_inputs { - write!(rust_code, " let mut node_{}: ::std::mem::MaybeUninit<{}> = ::std::mem::MaybeUninit::uninit();\n", node.idx(), generate_type_string(stype))?; + write!(rust_code, " let mut node_{}: ::core::mem::MaybeUninit<{}> = ::core::mem::MaybeUninit::uninit();\n", node.idx(), generate_type_string(stype))?; } // The core executor is a Rust loop. We literally run a "control token" @@ -213,47 +251,155 @@ fn codegen(manifests: &HashMap<String, Manifest>, elf: &[u8]) -> Result<String, // Open the arm. write!(rust_code, " {} => {{\n", idx)?; - // Call the partition function. - write!( - rust_code, - " let output = fn_ptr_part_{}(", - idx - )?; - for (_, kind) in partition.parameters.iter() { - match kind { - ParameterKind::HerculesParameter(idx) => write!(rust_code, "param_{}, ", idx)?, - ParameterKind::DataInput(id) => { - write!(rust_code, "node_{}.assume_init(), ", id.idx())? + match partition.device { + DeviceManifest::CPU { + ref parallel_launch, + } => { + for (idx, (_, kind)) in partition.parameters.iter().enumerate() { + write!( + rust_code, + " let local_param_{} = SendSyncWrapper(", + idx + )?; + match kind { + ParameterKind::HerculesParameter(idx) => { + write!(rust_code, "param_{}", idx)? + } + ParameterKind::DataInput(id) => { + write!(rust_code, "node_{}.assume_init()", id.idx())? + } + ParameterKind::DynamicConstant(idx) => write!(rust_code, "dc_{}", idx)?, + ParameterKind::ArrayConstant(id) => { + write!(rust_code, "array_{}", id.idx())? + } + } + write!(rust_code, ");\n")?; } - ParameterKind::DynamicConstant(idx) => write!(rust_code, "dc_{}, ", idx)?, - ParameterKind::ArrayConstant(id) => write!(rust_code, "array_{}, ", id.idx())?, - } - } - write!(rust_code, ");\n")?; - - // Assign the outputs. - for (output_idx, (_, kind)) in partition.returns.iter().enumerate() { - let output_ref = if partition.returns.len() == 1 { - "output".to_string() - } else { - format!("output.{}", output_idx) - }; - match kind { - ReturnKind::HerculesReturn => { - write!(rust_code, " return {};\n", output_ref)? + + if parallel_launch.is_empty() { + // Call the partition function. + write!( + rust_code, + " let output = {}(", + partition.name + )?; + for idx in 0..partition.parameters.len() { + write!(rust_code, "local_param_{}.0, ", idx)?; + } + write!(rust_code, ");\n")?; + } else { + // Compute the dynamic constant bounds. + for (dim, (_, dc)) in parallel_launch.into_iter().enumerate() { + write!(rust_code, " let bound_{} = ", dim)?; + compute_dynamic_constant(*dc, manifest, &mut rust_code)?; + write!(rust_code, ";\n let low_{} = 0;\n", dim)?; + } + + // Simultaneously calculate the tiles lows and lens and + // spawn the tiles. Emit the launches unrolled. 
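+                            // Each dimension's bound is split across its tiles: every tile
+                            // gets bound / num_tiles iterations, and the first
+                            // bound % num_tiles tiles get one extra; e.g. a bound of 10
+                            // over 4 tiles gives lens 3, 3, 2, 2 starting at lows 0, 3, 6, 8.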
+ let mut tile = vec![0; parallel_launch.len()]; + let total_num_tiles = parallel_launch + .into_iter() + .fold(1, |acc, (num_tiles, _)| acc * num_tiles); + for tile_num in 0..total_num_tiles { + // Calculate the lows and lens for this tile. + for (dim, tile) in tile.iter().enumerate() { + let num_tiles = parallel_launch[dim].0; + write!( + rust_code, + " let len_{} = bound_{} / {} + ({} < bound_{} % {}) as u64;\n", + dim, dim, num_tiles, tile, dim, num_tiles + )?; + } + + // Spawn the tile. We need to explicitly copy the + // SendSyncWrappers, or else the path expression for + // the parameters get interpreted as what needs to + // be moved, when we want the wrapper itself to be + // what gets moved. Ugh. + write!( + rust_code, + " let tile_{} = async_std::task::spawn(async move {{ ", + tile_num, + )?; + for idx in 0..partition.parameters.len() { + write!( + rust_code, + "let local_param_{} = local_param_{}; ", + idx, idx + )?; + } + write!(rust_code, "SendSyncWrapper({}(", partition.name)?; + for idx in 0..partition.parameters.len() { + write!(rust_code, "local_param_{}.0, ", idx)?; + } + for dim in 0..parallel_launch.len() { + write!(rust_code, "low_{}, len_{}, ", dim, dim)?; + } + write!(rust_code, ")) }});\n")?; + + // Go to the next tile. + for dim in (0..parallel_launch.len()).rev() { + tile[dim] += 1; + let num_tiles = parallel_launch[dim].0; + if tile[dim] < num_tiles { + write!( + rust_code, + " let low_{} = low_{} + len_{};\n", + dim, dim, dim + )?; + break; + } else { + tile[dim] = 0; + write!(rust_code, " let low_{} = 0;\n", dim)?; + } + } + } + + // Join the JoinHandles, and get the output from one of + // them. + write!( + rust_code, + " let output = ::core::future::join!(", + )?; + for tile_num in 0..total_num_tiles { + write!(rust_code, "tile_{}, ", tile_num)?; + } + // join! unhelpfully returns either a tuple or a single + // value, but never a singleton tuple. + if total_num_tiles == 1 { + write!(rust_code, ").await.0;\n")?; + } else { + write!(rust_code, ").await.0.0;\n")?; + } + } + + // Assign the outputs. + for (output_idx, (_, kind)) in partition.returns.iter().enumerate() { + let output_ref = if partition.returns.len() == 1 { + "output".to_string() + } else { + format!("output.{}", output_idx) + }; + match kind { + ReturnKind::HerculesReturn => { + write!(rust_code, " return {};\n", output_ref)? + } + ReturnKind::DataOutput(id) => write!( + rust_code, + " node_{}.write({});\n", + id.idx(), + output_ref + )?, + ReturnKind::NextPartition => write!( + rust_code, + " control_token = {};\n", + output_ref + )?, + } } - ReturnKind::DataOutput(id) => write!( - rust_code, - " node_{}.write({});\n", - id.idx(), - output_ref - )?, - ReturnKind::NextPartition => write!( - rust_code, - " control_token = {};l\n", - output_ref - )?, } + _ => todo!(), } // If there's only one partition successor, then an explicit @@ -287,44 +433,44 @@ fn codegen(manifests: &HashMap<String, Manifest>, elf: &[u8]) -> Result<String, } /* - * Load a Hercules compiled module from a .hbin file. + * Generate the async Rust runtime from the manifest of a Hercules module. */ #[proc_macro] -pub fn use_hbin(path: TokenStream) -> TokenStream { +pub fn use_hman(path: TokenStream) -> TokenStream { use TokenTree::Literal; - // Get the path as a Rust path object, and make sure it's a .hbin file. + // Get the path as a Rust path object, and make sure it's a .hman file. let mut tokens_iter = path.into_iter(); let token = tokens_iter .next() - .expect("Please provide a path to a .hbin file to the use_hbin! 
macro."); - assert!(tokens_iter.next().is_none(), "Too many tokens provided to the use_hbin! macro. Please provide only a path to a .hbin file."); + .expect("Please provide a path to a .hman file to the use_hman! macro."); + assert!(tokens_iter.next().is_none(), "Too many tokens provided to the use_hman! macro. Please provide only one path to a .hman file."); let literal = if let Literal(literal) = token { literal } else { - panic!("Please provide a string literal containing the path to a .hbin file to the use_hbin! macro."); + panic!("Please provide a string literal containing the path to a .hman file to the use_hman! macro."); }; let literal_string = literal.to_string(); let path = Path::new(&literal_string[1..(literal_string.len() - 1)]); assert_eq!( path.extension(), - Some(OsStr::new("hbin")), - "Please provide only .hbin files to the use_hbin! macro." + Some(OsStr::new("hman")), + "Please provide only .hman files to the use_hman! macro." ); assert_eq!( path.try_exists().ok(), Some(true), - "Please provide a valid path to a .hbin file to the use_hbin! macro." + "Please provide a valid path to a .hman file to the use_hman! macro." ); // Load manifest from path. let mut f = File::open(path).unwrap(); let mut buffer = vec![]; f.read_to_end(&mut buffer).unwrap(); - let (manifests, elf) = manifests_and_module_bytes(&buffer); + let manifests = postcard::from_bytes(&buffer).unwrap(); // Generate Rust code. - let rust_code = codegen(&manifests, &elf).unwrap(); + let rust_code = codegen(&manifests).unwrap(); eprintln!("{}", rust_code); rust_code.parse().unwrap() } diff --git a/hercules_samples/dot/build.rs b/hercules_samples/dot/build.rs new file mode 100644 index 0000000000000000000000000000000000000000..47fb55e82fc0266b03c6a7ba94717738038b489c --- /dev/null +++ b/hercules_samples/dot/build.rs @@ -0,0 +1,12 @@ +use std::env::current_dir; + +fn main() { + println!( + "cargo::rustc-link-search=native={}", + current_dir().unwrap().display() + ); + println!("cargo::rustc-link-lib=static=dot"); + + println!("cargo::rerun-if-changed=dot.hman"); + println!("cargo::rerun-if-changed=libdot.a"); +} diff --git a/hercules_samples/dot/dot.hir b/hercules_samples/dot/dot.hir index 63790a343e9ec8dcbcac0829dea082a6e4b60e9c..6f484462bd3df796a5334f2a8f2dbce80ee4a572 100644 --- a/hercules_samples/dot/dot.hir +++ b/hercules_samples/dot/dot.hir @@ -1,11 +1,19 @@ fn dot<1>(a: array(f32, #0), b: array(f32, #0)) -> f32 zero = constant(f32, 0.0) - fork = fork(start, #0) - id = thread_id(fork, 0) - join = join(fork) - r = return(join, dot_red) - aval = read(a, position(id)) - bval = read(b, position(id)) - mul = mul(aval, bval) - dot = add(mul, dot_red) - dot_red = reduce(join, zero, dot) + eight = constant(u64, 8) + outer_fork = fork(start, /(#0, 8)) + outer_id = thread_id(outer_fork, 0) + vector_fork = fork(outer_fork, 8) + vector_id = thread_id(vector_fork, 0) + outer_idx = mul(outer_id, eight) + vector_idx = add(outer_idx, vector_id) + vector_aval = read(a, position(vector_idx)) + vector_bval = read(b, position(vector_idx)) + vector_mul = mul(vector_aval, vector_bval) + vector_add = add(vector_mul, vector_dot) + vector_dot = reduce(vector_join, zero, vector_add) + vector_join = join(vector_fork) + outer_add = add(vector_dot, outer_dot) + outer_dot = reduce(outer_join, zero, outer_add) + outer_join = join(vector_join) + return = return(outer_join, outer_dot) diff --git a/hercules_samples/dot/src/main.rs b/hercules_samples/dot/src/main.rs index 
3aea6409dd08eb6c8b8157877b55ad57cc811b80..624c193761b2c1996dcc91eee3e182775103ddaa 100644 --- a/hercules_samples/dot/src/main.rs +++ b/hercules_samples/dot/src/main.rs @@ -3,16 +3,16 @@ extern crate clap; extern crate hercules_rt; // To compile currently, run from the Hercules project root directory: -// cargo run --bin hercules_driver hercules_samples/dot/dot.hir "Codegen(\"dot.hbin\")" +// cargo run --bin hercules_driver hercules_samples/dot/dot.hir "Codegen" // Then, you can execute this example with: // cargo run --bin hercules_dot -hercules_rt::use_hbin!("dot.hbin"); +hercules_rt::use_hman!("hercules_samples/dot/dot.hman"); fn main() { async_std::task::block_on(async { - let mut a = vec![1.0, 2.0, 3.0, 4.0]; - let mut b = vec![5.0, 6.0, 7.0, 8.0]; - let c = unsafe { dot(a.as_mut_ptr(), b.as_mut_ptr(), 4).await }; + let mut a = vec![0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0]; + let mut b = vec![0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0]; + let c = unsafe { dot(a.as_mut_ptr(), b.as_mut_ptr(), 8).await }; println!("{}", c,); }); } diff --git a/hercules_samples/matmul/build.rs b/hercules_samples/matmul/build.rs new file mode 100644 index 0000000000000000000000000000000000000000..aa4cd97f33d5f06754a3306f2d824e0cf38eb4ea --- /dev/null +++ b/hercules_samples/matmul/build.rs @@ -0,0 +1,12 @@ +use std::env::current_dir; + +fn main() { + println!( + "cargo::rustc-link-search=native={}", + current_dir().unwrap().display() + ); + println!("cargo::rustc-link-lib=static=matmul"); + + println!("cargo::rerun-if-changed=matmul.hman"); + println!("cargo::rerun-if-changed=libmatmul.a"); +} diff --git a/hercules_samples/matmul/matmul.hir b/hercules_samples/matmul/matmul.hir index 6207eb8848a3b3e9d94b5da438c40ff7ed3e4993..8bbccfdfc012d09610c21fe024d9b709374d18de 100644 --- a/hercules_samples/matmul/matmul.hir +++ b/hercules_samples/matmul/matmul.hir @@ -1,24 +1,21 @@ fn matmul<3>(a: array(f32, #0, #1), b: array(f32, #1, #2)) -> array(f32, #0, #2) c = constant(array(f32, #0, #2), []) - i_ctrl = fork(start, #0) - i_idx = thread_id(i_ctrl, 0) - j_ctrl = fork(i_ctrl, #2) - j_idx = thread_id(j_ctrl, 0) - k_ctrl = fork(j_ctrl, #1) + i_j_ctrl = fork(start, #0, #2) + i_idx = thread_id(i_j_ctrl, 0) + j_idx = thread_id(i_j_ctrl, 1) + k_ctrl = fork(i_j_ctrl, #1) k_idx = thread_id(k_ctrl, 0) k_join_ctrl = join(k_ctrl) - j_join_ctrl = join(k_join_ctrl) - i_join_ctrl = join(j_join_ctrl) - r = return(i_join_ctrl, update_i_c) + i_j_join_ctrl = join(k_join_ctrl) + r = return(i_j_join_ctrl, update_i_j_c) zero = constant(f32, 0) a_val = read(a, position(i_idx, k_idx)) b_val = read(b, position(k_idx, j_idx)) mul = mul(a_val, b_val) add = add(mul, dot) dot = reduce(k_join_ctrl, zero, add) - updated_c = write(update_j_c, dot, position(i_idx, j_idx)) - update_j_c = reduce(j_join_ctrl, update_i_c, updated_c) - update_i_c = reduce(i_join_ctrl, c, update_j_c) + update_c = write(update_i_j_c, dot, position(i_idx, j_idx)) + update_i_j_c = reduce(i_j_join_ctrl, c, update_c) diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs index baeab2685e1d74ea85b314f9048e3533c5b65f92..9a84c6ef89cc7b007eb680d891934367f1ef631d 100644 --- a/hercules_samples/matmul/src/main.rs +++ b/hercules_samples/matmul/src/main.rs @@ -1,12 +1,14 @@ +#![feature(future_join)] + extern crate async_std; extern crate clap; extern crate hercules_rt; // To compile currently, run from the Hercules project root directory: -// cargo run --bin hercules_driver hercules_samples/matmul/matmul.hir "Codegen(\"matmul.hbin\")" +// cargo run --bin 
hercules_driver hercules_samples/matmul/matmul.hir "Codegen" // Then, you can execute this example with: // cargo run --bin hercules_matmul -hercules_rt::use_hbin!("matmul.hbin"); +hercules_rt::use_hman!("hercules_samples/matmul/matmul.hman"); fn main() { async_std::task::block_on(async { @@ -16,9 +18,6 @@ fn main() { unsafe { matmul(a.as_mut_ptr(), b.as_mut_ptr(), c.as_mut_ptr(), 2, 2, 2).await; } - println!( - "[[{}, {}], [{}, {}]]", - c[0], c[1], c[2], c[3] - ); + println!("[[{}, {}], [{}, {}]]", c[0], c[1], c[2], c[3]); }); } diff --git a/hercules_tools/hercules_driver/src/main.rs b/hercules_tools/hercules_driver/src/main.rs index 17be3596292b912769cfc785add7e502987e37cf..70a7bafb6621de8fa5f54a33e8ec1cd030eaa433 100644 --- a/hercules_tools/hercules_driver/src/main.rs +++ b/hercules_tools/hercules_driver/src/main.rs @@ -2,6 +2,7 @@ extern crate clap; use std::fs::File; use std::io::prelude::*; +use std::path::Path; use clap::Parser; @@ -18,7 +19,8 @@ fn main() { eprintln!("WARNING: Running hercules_driver on a file without a .hir extension - interpreting as a textual Hercules IR file."); } - let mut file = File::open(args.hir_file).expect("PANIC: Unable to open input file."); + let hir_path = Path::new(&args.hir_file); + let mut file = File::open(hir_path).expect("PANIC: Unable to open input file."); let mut contents = String::new(); file.read_to_string(&mut contents) .expect("PANIC: Unable to read input file contents."); @@ -40,5 +42,5 @@ fn main() { for pass in passes { pm.add_pass(pass); } - pm.run_passes(); + pm.run_passes(hir_path); }
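A note on the tiling arithmetic emitted by the runtime codegen above: each tile receives bound / num_tiles elements, and the remainder bound % num_tiles is spread over the first tiles, one extra element apiece, so the per-tile lens always add back up to bound. The stand-alone sketch below checks that property; the helper name tile_len is illustrative and not part of this patch.

fn tile_len(bound: u64, num_tiles: u64, tile: u64) -> u64 {
    // Same formula the generated code emits per dimension:
    // bound / num_tiles, plus one for the first (bound % num_tiles) tiles.
    bound / num_tiles + (tile < bound % num_tiles) as u64
}

fn main() {
    for bound in 0..100u64 {
        for num_tiles in 1..=16u64 {
            let total: u64 = (0..num_tiles)
                .map(|tile| tile_len(bound, num_tiles, tile))
                .sum();
            assert_eq!(total, bound);
        }
    }
    println!("per-tile lens always sum back to the bound");
}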
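To make the shape of the generated async code concrete, here is a hand-written sketch of the same spawn/join pattern for a one-dimensional launch split into two tiles. The partition stand-in dot_tile, the local SendSyncWrapper definition, and the final combining of partial sums are illustrative assumptions only; the real generated code calls the compiled partition out of the linked static library and takes its output from one of the spawned tiles rather than combining them. Like the matmul sample, this needs nightly Rust for the future_join feature.

#![feature(future_join)]

extern crate async_std;

// Illustrative stand-in for the wrapper the runtime generates so that
// results (which may hold raw pointers) can cross task boundaries.
struct SendSyncWrapper<T>(T);
unsafe impl<T> Send for SendSyncWrapper<T> {}
unsafe impl<T> Sync for SendSyncWrapper<T> {}

// Illustrative stand-in for a compiled partition: sum a[low .. low + len].
fn dot_tile(a: &[f32], low: u64, len: u64) -> f32 {
    a[low as usize..(low + len) as usize].iter().sum()
}

fn main() {
    async_std::task::block_on(async {
        let a: &'static [f32] = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let bound_0 = a.len() as u64;
        let num_tiles = 2u64;

        // Tile 0: low starts at 0, len follows the emitted formula.
        let low_0 = 0u64;
        let len_0 = bound_0 / num_tiles + (0 < bound_0 % num_tiles) as u64;
        let local_a = SendSyncWrapper(a);
        let tile_0 = async_std::task::spawn(async move {
            // Rebind the wrapper so the whole wrapper, not a field path,
            // is what gets moved into the task (mirrors the generated code).
            let local_a = local_a;
            SendSyncWrapper(dot_tile(local_a.0, low_0, len_0))
        });

        // Tile 1 starts where tile 0 ended.
        let low_0 = low_0 + len_0;
        let len_0 = bound_0 / num_tiles + (1 < bound_0 % num_tiles) as u64;
        let local_a = SendSyncWrapper(a);
        let tile_1 = async_std::task::spawn(async move {
            let local_a = local_a;
            SendSyncWrapper(dot_tile(local_a.0, low_0, len_0))
        });

        // With more than one handle, join! yields a tuple once both tiles
        // finish; unwrap each tile's wrapper and combine the partial sums.
        let (p0, p1) = ::core::future::join!(tile_0, tile_1).await;
        println!("{}", p0.0 + p1.0);
    });
}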