From f6674a965fc503820effb50481f266ffafc9e972 Mon Sep 17 00:00:00 2001 From: rarbore2 <rarbore2@illinois.edu> Date: Tue, 4 Feb 2025 20:07:08 -0600 Subject: [PATCH] Misc. preparations for multi-core support. --- hercules_cg/src/rt.rs | 72 ++++++++++ hercules_ir/src/dot.rs | 35 ++++- hercules_ir/src/einsum.rs | 6 +- hercules_ir/src/fork_join_analysis.rs | 18 +-- hercules_opt/src/ccp.rs | 10 +- hercules_opt/src/fork_transforms.rs | 14 +- hercules_opt/src/forkify.rs | 10 +- hercules_rt/Cargo.toml | 1 + hercules_rt/src/lib.rs | 87 +++++++++++- hercules_rt/src/rtdefs.cu | 12 +- juno_samples/fork_join_tests/src/cpu.sch | 30 +++- .../fork_join_tests/src/fork_join_tests.jn | 9 ++ juno_samples/fork_join_tests/src/gpu.sch | 28 +++- juno_samples/fork_join_tests/src/main.rs | 25 ++-- juno_scheduler/src/lib.rs | 2 + juno_scheduler/src/pm.rs | 133 ++++++++++++++---- 16 files changed, 403 insertions(+), 89 deletions(-) diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index cbef5a00..35334a14 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -407,6 +407,78 @@ impl<'a> RTContext<'a> { write!(block, ");\n")?; } } + Node::Unary { op, input } => { + let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + match op { + UnaryOperator::Not => write!( + block, + " {} = !{};\n", + self.get_value(id), + self.get_value(input) + )?, + UnaryOperator::Neg => write!( + block, + " {} = -{};\n", + self.get_value(id), + self.get_value(input) + )?, + UnaryOperator::Cast(ty) => write!( + block, + " {} = {} as {};\n", + self.get_value(id), + self.get_value(input), + self.get_type(ty) + )?, + }; + } + Node::Binary { op, left, right } => { + let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + let op = match op { + BinaryOperator::Add => "+", + BinaryOperator::Sub => "-", + BinaryOperator::Mul => "*", + BinaryOperator::Div => "/", + BinaryOperator::Rem => "%", + BinaryOperator::LT => "<", + BinaryOperator::LTE => "<=", + BinaryOperator::GT => ">", + BinaryOperator::GTE => ">=", + BinaryOperator::EQ => "==", + BinaryOperator::NE => "!=", + BinaryOperator::Or => "|", + BinaryOperator::And => "&", + BinaryOperator::Xor => "^", + BinaryOperator::LSh => "<<", + BinaryOperator::RSh => ">>", + }; + + write!( + block, + " {} = {} {} {};\n", + self.get_value(id), + self.get_value(left), + op, + self.get_value(right) + )?; + } + Node::Ternary { + op, + first, + second, + third, + } => { + let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + match op { + TernaryOperator::Select => write!( + block, + " {} = if {} {{ {} }} else {{ {} }};\n", + self.get_value(id), + self.get_value(first), + self.get_value(second), + self.get_value(third), + )?, + }; + } Node::Read { collect, ref indices, diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs index 5ccda9dc..7ad8c6df 100644 --- a/hercules_ir/src/dot.rs +++ b/hercules_ir/src/dot.rs @@ -343,30 +343,57 @@ fn write_node<W: Write>( } let mut iter = schedules.into_iter(); - if let Some(first) = iter.next() { - let schedules = iter.fold(format!("{:?}", first), |b, i| format!("{}, {:?}", b, i)); + let schedules = if let Some(first) = iter.next() { + iter.fold(format!("{:?}", first), |b, i| format!("{}, {:?}", b, i)) + } else { + String::new() + }; + if tylabel.is_empty() && schedules.is_empty() { write!( w, - "{}_{}_{} [xlabel={}, label=<{}<BR /><FONT POINT-SIZE=\"8\">{}</FONT><BR /><FONT POINT-SIZE=\"8\">{}</FONT>>, color={}];\n", + "{}_{}_{} [xlabel={}, label=<{}>, color={}];\n", + node.lower_case_name(), + function_id.idx(), + node_id.idx(), + xlabel, + label, + color + )?; + } else if schedules.is_empty() { + write!( + w, + "{}_{}_{} [xlabel={}, label=<{}<BR /><FONT POINT-SIZE=\"8\">{}</FONT>>, color={}];\n", node.lower_case_name(), function_id.idx(), node_id.idx(), xlabel, label, tylabel, + color + )?; + } else if tylabel.is_empty() { + write!( + w, + "{}_{}_{} [xlabel={}, label=<{}<BR /><FONT POINT-SIZE=\"8\">{}</FONT>>, color={}];\n", + node.lower_case_name(), + function_id.idx(), + node_id.idx(), + xlabel, + label, schedules, color )?; } else { write!( w, - "{}_{}_{} [xlabel={}, label=<{}<BR /><FONT POINT-SIZE=\"8\">{}</FONT>>, color={}];\n", + "{}_{}_{} [xlabel={}, label=<{}<BR /><FONT POINT-SIZE=\"8\">{}</FONT><BR /><FONT POINT-SIZE=\"8\">{}</FONT>>, color={}];\n", node.lower_case_name(), function_id.idx(), node_id.idx(), xlabel, label, tylabel, + schedules, color )?; } diff --git a/hercules_ir/src/einsum.rs b/hercules_ir/src/einsum.rs index 8d3bec3a..25e15d63 100644 --- a/hercules_ir/src/einsum.rs +++ b/hercules_ir/src/einsum.rs @@ -60,7 +60,7 @@ pub fn einsum( typing: &Vec<TypeID>, fork_join_map: &HashMap<NodeID, NodeID>, fork_join_nest: &HashMap<NodeID, Vec<NodeID>>, - data_nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, + nodes_in_fork_joins: &HashMap<NodeID, HashSet<NodeID>>, ) -> (MathEnv, HashMap<NodeID, MathID>) { let mut env = vec![]; let mut rev_env = HashMap::new(); @@ -101,7 +101,7 @@ pub fn einsum( function, typing, constants, - data_nodes_in_fork_joins, + nodes_in_fork_joins, fork, factors, thread_ids: &thread_ids, @@ -185,7 +185,7 @@ struct EinsumContext<'a> { function: &'a Function, typing: &'a Vec<TypeID>, constants: &'a Vec<Constant>, - data_nodes_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>, + nodes_in_fork_joins: &'a HashMap<NodeID, HashSet<NodeID>>, fork: NodeID, factors: &'a [DynamicConstantID], thread_ids: &'a Vec<(NodeID, usize)>, diff --git a/hercules_ir/src/fork_join_analysis.rs b/hercules_ir/src/fork_join_analysis.rs index 7a098a35..ad3125ba 100644 --- a/hercules_ir/src/fork_join_analysis.rs +++ b/hercules_ir/src/fork_join_analysis.rs @@ -165,10 +165,9 @@ fn reduce_cycle_dfs_helper( } /* - * Top level function to calculate which data nodes are "inside" a fork-join, - * not including its reduces. + * Top level function to calculate which nodes are "inside" a fork-join. */ -pub fn data_nodes_in_fork_joins( +pub fn nodes_in_fork_joins( function: &Function, def_use: &ImmutableDefUseMap, fork_join_map: &HashMap<NodeID, NodeID>, @@ -178,22 +177,19 @@ pub fn data_nodes_in_fork_joins( for (fork, join) in fork_join_map { let mut worklist = vec![*fork]; let mut set = HashSet::new(); + set.insert(*fork); while let Some(item) = worklist.pop() { for u in def_use.get_users(item) { - if function.nodes[u.idx()].is_control() + let terminate = *u == *join || function.nodes[u.idx()] .try_reduce() .map(|(control, _, _)| control == *join) - .unwrap_or(false) - { - // Ignore control users and reduces of the fork-join. - continue; - } - if !set.contains(u) { - set.insert(*u); + .unwrap_or(false); + if !set.contains(u) && !terminate { worklist.push(*u); } + set.insert(*u); } } diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs index 92d52a71..9768198c 100644 --- a/hercules_opt/src/ccp.rs +++ b/hercules_opt/src/ccp.rs @@ -482,9 +482,13 @@ fn ccp_flow_function( reachability: ReachabilityLattice::bottom(), constant: ConstantLattice::Constant(editor.get_constant(*id).clone()), }, - // TODO: This should really be constant interpreted, since dynamic - // constants as values are used frequently. - Node::DynamicConstant { id: _ } => CCPLattice::bottom(), + Node::DynamicConstant { id } => match *editor.get_dynamic_constant(*id) { + DynamicConstant::Constant(value) => CCPLattice { + reachability: ReachabilityLattice::bottom(), + constant: ConstantLattice::Constant(Constant::UnsignedInteger64(value as u64)), + }, + _ => CCPLattice::bottom(), + }, // Interpret unary op on constant. Node::Unary { input, op } => { let CCPLattice { diff --git a/hercules_opt/src/fork_transforms.rs b/hercules_opt/src/fork_transforms.rs index c4a6ba7f..456f670e 100644 --- a/hercules_opt/src/fork_transforms.rs +++ b/hercules_opt/src/fork_transforms.rs @@ -541,18 +541,19 @@ pub fn fork_coalesce_helper( true } -pub fn split_all_forks( +pub fn split_any_fork( editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, -) { +) -> Option<(Vec<NodeID>, Vec<NodeID>)> { for (fork, join) in fork_join_map { - if let Some((forks, _)) = split_fork(editor, *fork, *join, reduce_cycles) + if let Some((forks, joins)) = split_fork(editor, *fork, *join, reduce_cycles) && forks.len() > 1 { - break; + return Some((forks, joins)); } } + None } /* @@ -689,6 +690,7 @@ pub(crate) fn split_fork( Ok(edit) }); if success { + new_joins.reverse(); Some((new_forks, new_joins)) } else { None @@ -749,6 +751,7 @@ pub fn chunk_fork_unguarded( let new_fork = edit.add_node(new_fork); edit = edit.replace_all_uses(fork, new_fork)?; + edit.sub_edit(fork, new_fork); for (tid, node) in fork_users { let Node::ThreadID { @@ -765,6 +768,7 @@ pub fn chunk_fork_unguarded( }; let new_tid = edit.add_node(new_tid); edit = edit.replace_all_uses(tid, new_tid)?; + edit.sub_edit(tid, new_tid); edit = edit.delete_node(tid)?; } else if tid_dim == dim_idx { let tile_tid = Node::ThreadID { @@ -784,6 +788,8 @@ pub fn chunk_fork_unguarded( right: tile_tid, op: BinaryOperator::Add, }); + edit.sub_edit(tid, add); + edit.sub_edit(tid, tile_tid); edit = edit.replace_all_uses_where(tid, add, |usee| *usee != mul)?; } } diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs index f6db06ca..2adfddd8 100644 --- a/hercules_opt/src/forkify.rs +++ b/hercules_opt/src/forkify.rs @@ -298,15 +298,13 @@ pub fn forkify_loop( let (_, factors) = function.nodes[fork_id.idx()].try_fork().unwrap(); let dimension = factors.len() - 1; - // Start failable edit: - let redcutionable_phis_and_init: Vec<(_, NodeID)> = reductionable_phis .iter() .map(|reduction_phi| { let LoopPHI::Reductionable { phi, data_cycle: _, - continue_latch, + continue_latch: _, is_associative: _, } = reduction_phi else { @@ -328,6 +326,7 @@ pub fn forkify_loop( }) .collect(); + // Start failable edit: editor.edit(|mut edit| { let thread_id = Node::ThreadID { control: fork_id, @@ -339,6 +338,7 @@ pub fn forkify_loop( edit = edit.replace_all_uses_where(canonical_iv.phi(), thread_id_id, |node| { loop_nodes.contains(node) })?; + edit.sub_edit(canonical_iv.phi(), thread_id_id); edit = edit.delete_node(canonical_iv.phi())?; @@ -386,12 +386,16 @@ pub fn forkify_loop( edit = edit.replace_all_uses_where(continue_latch, reduce_id, |usee| { !loop_nodes.contains(usee) && *usee != reduce_id })?; + edit.sub_edit(phi, reduce_id); edit = edit.delete_node(phi)? } edit = edit.replace_all_uses(l.header, fork_id)?; edit = edit.replace_all_uses(loop_continue_projection, fork_id)?; edit = edit.replace_all_uses(loop_exit_projection, join_id)?; + edit.sub_edit(l.header, fork_id); + edit.sub_edit(loop_continue_projection, fork_id); + edit.sub_edit(loop_exit_projection, join_id); edit = edit.delete_node(loop_continue_projection)?; edit = edit.delete_node(condition_node)?; // Might have to get rid of other users of this. diff --git a/hercules_rt/Cargo.toml b/hercules_rt/Cargo.toml index c4678b18..46886b12 100644 --- a/hercules_rt/Cargo.toml +++ b/hercules_rt/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [features] cuda = [] +debug = [] [dependencies] diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs index ed5dca1d..2ad72043 100644 --- a/hercules_rt/src/lib.rs +++ b/hercules_rt/src/lib.rs @@ -10,29 +10,101 @@ use std::slice::{from_raw_parts, from_raw_parts_mut}; */ pub unsafe fn __cpu_alloc(size: usize) -> *mut u8 { - alloc(Layout::from_size_align(size, 16).unwrap()) + let ptr = alloc(Layout::from_size_align(size, 16).unwrap()); + if cfg!(feature = "debug") { + eprintln!("__cpu_alloc: {:?}, {}", ptr, size); + assert!(!ptr.is_null() || size == 0); + } + ptr } pub unsafe fn __cpu_dealloc(ptr: *mut u8, size: usize) { + if cfg!(feature = "debug") { + eprintln!("__cpu_dealloc: {:?}, {}", ptr, size); + assert!(!ptr.is_null() || size == 0); + } dealloc(ptr, Layout::from_size_align(size, 16).unwrap()) } pub unsafe fn __cpu_zero_mem(ptr: *mut u8, size: usize) { + if cfg!(feature = "debug") { + eprintln!("__cpu_zero_mem: {:?}, {}", ptr, size); + assert!(!ptr.is_null() || size == 0); + } write_bytes(ptr, 0, size); } pub unsafe fn __copy_cpu_to_cpu(dst: *mut u8, src: *mut u8, size: usize) { + if cfg!(feature = "debug") { + eprintln!("__copy_cpu_to_cpu: {:?}, {:?}, {}", dst, src, size); + assert!((!dst.is_null() && !src.is_null()) || size == 0); + } copy_nonoverlapping(src, dst, size); } +#[cfg(feature = "cuda")] +pub unsafe fn __cuda_alloc(size: usize) -> *mut u8 { + let ptr = ___cuda_alloc(size); + if cfg!(feature = "debug") { + eprintln!("__cuda_alloc: {:?}, {}", ptr, size); + assert!(!ptr.is_null() || size == 0); + } + ptr +} + +#[cfg(feature = "cuda")] +pub unsafe fn __cuda_dealloc(ptr: *mut u8, size: usize) { + if cfg!(feature = "debug") { + eprintln!("__cuda_dealloc: {:?}, {}", ptr, size); + assert!(!ptr.is_null() || size == 0); + } + ___cuda_dealloc(ptr, size); +} + +#[cfg(feature = "cuda")] +pub unsafe fn __cuda_zero_mem(ptr: *mut u8, size: usize) { + if cfg!(feature = "debug") { + eprintln!("__cuda_zero_mem: {:?}, {}", ptr, size); + assert!(!ptr.is_null() || size == 0); + } + ___cuda_zero_mem(ptr, size); +} + +#[cfg(feature = "cuda")] +pub unsafe fn __copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize) { + if cfg!(feature = "debug") { + eprintln!("__copy_cpu_to_cuda: {:?}, {:?}, {}", dst, src, size); + assert!((!dst.is_null() && !src.is_null()) || size == 0); + } + ___copy_cpu_to_cuda(dst, src, size); +} + +#[cfg(feature = "cuda")] +pub unsafe fn __copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize) { + if cfg!(feature = "debug") { + eprintln!("__copy_cuda_to_cpu: {:?}, {:?}, {}", dst, src, size); + assert!((!dst.is_null() && !src.is_null()) || size == 0); + } + ___copy_cuda_to_cpu(dst, src, size); +} + +#[cfg(feature = "cuda")] +pub unsafe fn __copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize) { + if cfg!(feature = "debug") { + eprintln!("__copy_cuda_to_cuda: {:?}, {:?}, {}", dst, src, size); + assert!((!dst.is_null() && !src.is_null()) || size == 0); + } + ___copy_cuda_to_cuda(dst, src, size); +} + #[cfg(feature = "cuda")] extern "C" { - pub fn __cuda_alloc(size: usize) -> *mut u8; - pub fn __cuda_dealloc(ptr: *mut u8, size: usize); - pub fn __cuda_zero_mem(ptr: *mut u8, size: usize); - pub fn __copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize); - pub fn __copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize); - pub fn __copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize); + pub fn ___cuda_alloc(size: usize) -> *mut u8; + pub fn ___cuda_dealloc(ptr: *mut u8, size: usize); + pub fn ___cuda_zero_mem(ptr: *mut u8, size: usize); + pub fn ___copy_cpu_to_cuda(dst: *mut u8, src: *mut u8, size: usize); + pub fn ___copy_cuda_to_cpu(dst: *mut u8, src: *mut u8, size: usize); + pub fn ___copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize); } #[derive(Clone, Debug)] @@ -155,6 +227,7 @@ impl<'a> HerculesCUDARef<'a> { pub fn to_cpu_ref<'b, T>(self, dst: &'b mut [T]) -> HerculesCPURefMut<'b> { unsafe { let size = self.size; + assert_eq!(size, dst.len() * size_of::<T>()); let ptr = NonNull::new(dst.as_ptr() as *mut u8).unwrap(); __copy_cuda_to_cpu(ptr.as_ptr(), self.ptr.as_ptr(), size); HerculesCPURefMut { diff --git a/hercules_rt/src/rtdefs.cu b/hercules_rt/src/rtdefs.cu index 534b297d..50e11fa6 100644 --- a/hercules_rt/src/rtdefs.cu +++ b/hercules_rt/src/rtdefs.cu @@ -1,5 +1,5 @@ extern "C" { - void *__cuda_alloc(size_t size) { + void *___cuda_alloc(size_t size) { void *ptr = NULL; cudaError_t res = cudaMalloc(&ptr, size); if (res != cudaSuccess) { @@ -8,24 +8,24 @@ extern "C" { return ptr; } - void __cuda_dealloc(void *ptr, size_t size) { + void ___cuda_dealloc(void *ptr, size_t size) { (void) size; cudaFree(ptr); } - void __cuda_zero_mem(void *ptr, size_t size) { + void ___cuda_zero_mem(void *ptr, size_t size) { cudaMemset(ptr, 0, size); } - void __copy_cpu_to_cuda(void *dst, void *src, size_t size) { + void ___copy_cpu_to_cuda(void *dst, void *src, size_t size) { cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice); } - void __copy_cuda_to_cpu(void *dst, void *src, size_t size) { + void ___copy_cuda_to_cpu(void *dst, void *src, size_t size) { cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost); } - void __copy_cuda_to_cuda(void *dst, void *src, size_t size) { + void ___copy_cuda_to_cuda(void *dst, void *src, size_t size) { cudaMemcpy(dst, src, size, cudaMemcpyDeviceToDevice); } } diff --git a/juno_samples/fork_join_tests/src/cpu.sch b/juno_samples/fork_join_tests/src/cpu.sch index 0263c275..38010004 100644 --- a/juno_samples/fork_join_tests/src/cpu.sch +++ b/juno_samples/fork_join_tests/src/cpu.sch @@ -1,15 +1,21 @@ +no-memset(test6@const); + +ccp(*); gvn(*); phi-elim(*); dce(*); -let out = auto-outline(*); +let out = auto-outline(test1, test2, test3, test4, test5); cpu(out.test1); cpu(out.test2); cpu(out.test3); +cpu(out.test4); +cpu(out.test5); ip-sroa(*); sroa(*); dce(*); +ccp(*); gvn(*); phi-elim(*); dce(*); @@ -23,20 +29,30 @@ gvn(*); phi-elim(*); dce(*); -gvn(*); -phi-elim(*); -dce(*); - fixpoint panic after 20 { infer-schedules(*); } -fork-split(*); +fork-split(out.test1, out.test2, out.test3, out.test4, out.test5); gvn(*); phi-elim(*); dce(*); -unforkify(*); +unforkify(out.test1, out.test2, out.test3, out.test4, out.test5); +ccp(*); gvn(*); phi-elim(*); dce(*); +fork-tile[32, 0, true](test6@loop); +let out = fork-split(test6@loop); +//let out = outline(out.test6.fj1); +let out = auto-outline(test6); +cpu(out.test6); +ip-sroa(*); +sroa(*); +unforkify(out.test6); +dce(*); +ccp(*); +gvn(*); +phi-elim(*); +dce(*); gcm(*); diff --git a/juno_samples/fork_join_tests/src/fork_join_tests.jn b/juno_samples/fork_join_tests/src/fork_join_tests.jn index 6e5db4cb..806cb0f1 100644 --- a/juno_samples/fork_join_tests/src/fork_join_tests.jn +++ b/juno_samples/fork_join_tests/src/fork_join_tests.jn @@ -72,3 +72,12 @@ fn test5(input : i32) -> i32[4] { } return arr1; } + +#[entry] +fn test6(input: i32) -> i32[1024] { + @const let arr : i32[1024]; + @loop for i = 0 to 1024 { + arr[i] = i as i32 + input; + } + return arr; +} diff --git a/juno_samples/fork_join_tests/src/gpu.sch b/juno_samples/fork_join_tests/src/gpu.sch index f096ea50..f108e2c1 100644 --- a/juno_samples/fork_join_tests/src/gpu.sch +++ b/juno_samples/fork_join_tests/src/gpu.sch @@ -4,17 +4,18 @@ no-memset(test1@const); no-memset(test3@const1); no-memset(test3@const2); no-memset(test3@const3); +no-memset(test6@const); gvn(*); phi-elim(*); dce(*); -let out = auto-outline(*); -gpu(out.test1); -gpu(out.test2); -gpu(out.test3); -gpu(out.test4); -gpu(out.test5); +let auto = auto-outline(test1, test2, test3, test4, test5); +gpu(auto.test1); +gpu(auto.test2); +gpu(auto.test3); +gpu(auto.test4); +gpu(auto.test5); ip-sroa(*); sroa(*); @@ -37,5 +38,18 @@ fixpoint panic after 20 { infer-schedules(*); } -float-collections(test2, out.test2, test4, out.test4, test5, out.test5); +fork-tile[32, 0, true](test6@loop); +let out = fork-split(test6@loop); +let out = auto-outline(test6); +gpu(out.test6); +ip-sroa(*); +sroa(*); +dce(*); +ccp(*); +gvn(*); +phi-elim(*); +dce(*); +gcm(*); + +float-collections(test2, auto.test2, test4, auto.test4, test5, auto.test5); gcm(*); diff --git a/juno_samples/fork_join_tests/src/main.rs b/juno_samples/fork_join_tests/src/main.rs index 5e848ade..19838fd7 100644 --- a/juno_samples/fork_join_tests/src/main.rs +++ b/juno_samples/fork_join_tests/src/main.rs @@ -6,42 +6,47 @@ juno_build::juno!("fork_join_tests"); fn main() { #[cfg(not(feature = "cuda"))] - let assert = |correct, output: hercules_rt::HerculesCPURefMut<'_>| { - assert_eq!(output.as_slice::<i32>(), &correct); + let assert = |correct: &Vec<i32>, output: hercules_rt::HerculesCPURefMut<'_>| { + assert_eq!(output.as_slice::<i32>(), correct); }; #[cfg(feature = "cuda")] - let assert = |correct, output: hercules_rt::HerculesCUDARefMut<'_>| { - let mut dst = vec![0i32; 16]; + let assert = |correct: &Vec<i32>, output: hercules_rt::HerculesCUDARefMut<'_>| { + let mut dst = vec![0i32; correct.len()]; let output = output.to_cpu_ref(&mut dst); - assert_eq!(output.as_slice::<i32>(), &correct); + assert_eq!(output.as_slice::<i32>(), correct); }; async_std::task::block_on(async { let mut r = runner!(test1); let output = r.run(5).await; let correct = vec![5i32; 16]; - assert(correct, output); + assert(&correct, output); let mut r = runner!(test2); let output = r.run(3).await; let correct = vec![24i32; 16]; - assert(correct, output); + assert(&correct, output); let mut r = runner!(test3); let output = r.run(0).await; let correct = vec![11, 10, 9, 10, 9, 8, 9, 8, 7]; - assert(correct, output); + assert(&correct, output); let mut r = runner!(test4); let output = r.run(9).await; let correct = vec![63i32; 16]; - assert(correct, output); + assert(&correct, output); let mut r = runner!(test5); let output = r.run(4).await; let correct = vec![7i32; 4]; - assert(correct, output); + assert(&correct, output); + + let mut r = runner!(test6); + let output = r.run(73).await; + let correct = (73i32..73i32+1024i32).collect(); + assert(&correct, output); }); } diff --git a/juno_scheduler/src/lib.rs b/juno_scheduler/src/lib.rs index ad9195fb..d4ab432a 100644 --- a/juno_scheduler/src/lib.rs +++ b/juno_scheduler/src/lib.rs @@ -1,3 +1,5 @@ +#![feature(exact_size_is_empty)] + use std::collections::{HashMap, HashSet}; use std::fs::File; use std::io::Read; diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index b2845913..9478eb9b 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -182,7 +182,7 @@ pub struct PassManager { pub fork_trees: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, pub loops: Option<Vec<LoopTree>>, pub reduce_cycles: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, - pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, + pub nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, pub reduce_einsums: Option<Vec<(MathEnv, HashMap<NodeID, MathID>)>>, pub no_reset_constants: Option<Vec<BTreeSet<NodeID>>>, pub collection_objects: Option<CollectionObjects>, @@ -221,7 +221,7 @@ impl PassManager { fork_trees: None, loops: None, reduce_cycles: None, - data_nodes_in_fork_joins: None, + nodes_in_fork_joins: None, reduce_einsums: None, no_reset_constants: None, collection_objects: None, @@ -408,11 +408,11 @@ impl PassManager { } } - pub fn make_data_nodes_in_fork_joins(&mut self) { - if self.data_nodes_in_fork_joins.is_none() { + pub fn make_nodes_in_fork_joins(&mut self) { + if self.nodes_in_fork_joins.is_none() { self.make_def_uses(); self.make_fork_join_maps(); - self.data_nodes_in_fork_joins = Some( + self.nodes_in_fork_joins = Some( zip( self.functions.iter(), zip( @@ -421,7 +421,7 @@ impl PassManager { ), ) .map(|(function, (def_use, fork_join_map))| { - data_nodes_in_fork_joins(function, def_use, fork_join_map) + nodes_in_fork_joins(function, def_use, fork_join_map) }) .collect(), ); @@ -434,12 +434,12 @@ impl PassManager { self.make_typing(); self.make_fork_join_maps(); self.make_fork_join_nests(); - self.make_data_nodes_in_fork_joins(); + self.make_nodes_in_fork_joins(); let def_uses = self.def_uses.as_ref().unwrap().iter(); let typing = self.typing.as_ref().unwrap().iter(); let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter(); let fork_join_nests = self.fork_join_nests.as_ref().unwrap().iter(); - let data_nodes_in_fork_joins = self.data_nodes_in_fork_joins.as_ref().unwrap().iter(); + let nodes_in_fork_joins = self.nodes_in_fork_joins.as_ref().unwrap().iter(); self.reduce_einsums = Some( self.functions .iter() @@ -447,11 +447,11 @@ impl PassManager { .zip(typing) .zip(fork_join_maps) .zip(fork_join_nests) - .zip(data_nodes_in_fork_joins) + .zip(nodes_in_fork_joins) .map( |( ((((function, def_use), typing), fork_join_map), fork_join_nest), - data_nodes_in_fork_joins, + nodes_in_fork_joins, )| { einsum( function, @@ -461,7 +461,7 @@ impl PassManager { typing, fork_join_map, fork_join_nest, - data_nodes_in_fork_joins, + nodes_in_fork_joins, ) }, ) @@ -579,7 +579,7 @@ impl PassManager { self.fork_trees = None; self.loops = None; self.reduce_cycles = None; - self.data_nodes_in_fork_joins = None; + self.nodes_in_fork_joins = None; self.reduce_einsums = None; self.no_reset_constants = None; self.collection_objects = None; @@ -809,6 +809,8 @@ impl PassManager { let mut nvcc_process = Command::new("nvcc") .arg("-c") .arg("-O3") + .arg("-diag-suppress") + .arg("177") .arg("-o") .arg(&cuda_object) .arg(&cuda_path) @@ -1515,6 +1517,7 @@ fn run_pass( } Pass::ForkSplit => { assert!(args.is_empty()); + let mut created_fork_joins = vec![vec![vec![]]; pm.functions.len()]; loop { let mut inner_changed = false; pm.make_fork_join_maps(); @@ -1529,17 +1532,80 @@ fn run_pass( let Some(mut func) = func else { continue; }; - split_all_forks(&mut func, fork_join_map, reduce_cycles); + if let Some((forks, joins)) = + split_any_fork(&mut func, fork_join_map, reduce_cycles) + { + let created_fork_joins = &mut created_fork_joins[func.func_id().idx()]; + if forks.len() > created_fork_joins.len() { + created_fork_joins.resize(forks.len(), vec![]); + } + for (idx, (fork, join)) in zip(forks, joins).enumerate() { + created_fork_joins[idx].push((fork, join)); + } + } changed |= func.modified(); inner_changed |= func.modified(); } - pm.delete_gravestones(); pm.clear_analyses(); if !inner_changed { break; } } + + pm.make_nodes_in_fork_joins(); + let nodes_in_fork_joins = pm.nodes_in_fork_joins.take().unwrap(); + let mut new_fork_joins = HashMap::new(); + for (mut func, created_fork_joins) in + build_editors(pm).into_iter().zip(created_fork_joins) + { + // For every function, create a label for every level of fork- + // joins resulting from the split. + let name = func.func().name.clone(); + let func_id = func.func_id(); + let labels = create_labels_for_node_sets( + &mut func, + created_fork_joins.into_iter().map(|level_fork_joins| { + level_fork_joins + .into_iter() + .map(|(fork, _)| { + nodes_in_fork_joins[func_id.idx()][&fork] + .iter() + .map(|id| *id) + }) + .flatten() + }), + ); + + // Assemble those labels into a record for this function. The + // format of the records is <function>.<fjN>, where N is the + // level of the split fork-joins being referred to. + let mut func_record = HashMap::new(); + for (idx, label) in labels { + func_record.insert( + format!("fj{}", idx), + Value::Label { + labels: vec![LabelInfo { + func: func_id, + label: label, + }], + }, + ); + } + + // Try to avoid creating unnecessary record entries. + if !func_record.is_empty() { + new_fork_joins.entry(name).insert_entry(Value::Record { + fields: func_record, + }); + } + } + + pm.delete_gravestones(); + pm.clear_analyses(); + result = Value::Record { + fields: new_fork_joins, + }; } Pass::Forkify => { assert!(args.is_empty()); @@ -1914,27 +1980,22 @@ fn run_pass( } Pass::ForkChunk => { assert_eq!(args.len(), 3); - let tile_size = args.get(0); - let dim_idx = args.get(1); - - let Some(Value::Boolean { val: guarded_flag }) = args.get(2) else { + let Some(Value::Integer { val: tile_size }) = args.get(0) else { return Err(SchedulerError::PassError { pass: "forkChunk".to_string(), - error: "expected boolean argument".to_string(), + error: "expected integer argument".to_string(), }); }; - let Some(Value::Integer { val: dim_idx }) = args.get(1) else { return Err(SchedulerError::PassError { pass: "forkChunk".to_string(), error: "expected integer argument".to_string(), }); }; - - let Some(Value::Integer { val: tile_size }) = args.get(0) else { + let Some(Value::Boolean { val: guarded_flag }) = args.get(2) else { return Err(SchedulerError::PassError { pass: "forkChunk".to_string(), - error: "expected integer argument".to_string(), + error: "expected boolean argument".to_string(), }); }; @@ -2068,9 +2129,33 @@ fn run_pass( // Put BasicBlocks back, since it's needed for Codegen. pm.bbs = bbs; } - Pass::ForkChunk => todo!(), } println!("Ran Pass: {:?}", pass); Ok((result, changed)) } + +fn create_labels_for_node_sets<I, J>( + editor: &mut FunctionEditor, + node_sets: I, +) -> Vec<(usize, LabelID)> +where + I: Iterator<Item = J>, + J: Iterator<Item = NodeID>, +{ + let mut labels = vec![]; + editor.edit(|mut edit| { + for (set_idx, node_set) in node_sets.enumerate() { + let mut node_set = node_set.peekable(); + if node_set.peek().is_some() { + let label = edit.fresh_label(); + for node in node_set { + edit = edit.add_label(node, label).unwrap(); + } + labels.push((set_idx, label)); + } + } + Ok(edit) + }); + labels +} -- GitLab