Commit d391d163 authored by prrathi

mm dot works

parent aa1cf408
1 merge request: !115 GPU backend
Pipeline #201216 failed
@@ -79,10 +79,6 @@ pub fn gpu_codegen<W: Write>(
  * - Add float8, float16, bfloat16 dtypes if they come
  */
-    // Temporary for matmul (both true) and dot (thread true) test while we don't have schedule annotations
-    let block_parallel_override = false;
-    let thread_parallel_override = false;
     let reduce_nodes: Vec<NodeID> = (0..function.nodes.len())
         .filter(|idx| function.nodes[*idx].is_reduce())
         .map(NodeID::new)
@@ -164,8 +160,6 @@ pub fn gpu_codegen<W: Write>(
         threads_per_warp: 32,
     };
-    std::fs::write("out.txt", "debug\n\n").unwrap();
     let ctx = GPUContext {
         function,
         types,
@@ -185,8 +179,6 @@ pub fn gpu_codegen<W: Write>(
         control_data_phi_map,
         return_parameter,
         kernel_params,
-        block_parallel_override,
-        thread_parallel_override,
     };
     ctx.codegen_function(w)
 }
@@ -215,8 +207,6 @@ struct GPUContext<'a> {
     control_data_phi_map: HashMap<NodeID, Vec<(NodeID, NodeID)>>,
     return_parameter: Option<usize>,
    kernel_params: &'a GPUKernelParams,
-    block_parallel_override: bool,
-    thread_parallel_override: bool,
 }
 
 /*
@@ -265,8 +255,6 @@ enum CGType {
 impl GPUContext<'_> {
     fn codegen_function<W: Write>(&self, w: &mut W) -> Result<(), Error> {
-        let mut file = OpenOptions::new().append(true).open("out.txt").unwrap();
-
         // Emit all code up to the "goto" to Start's block
         let mut top = String::new();
         self.codegen_kernel_begin(self.return_parameter.is_none(), &mut top)?;
@@ -289,19 +277,14 @@ impl GPUContext<'_> {
         // If there are no forks, fast forward to single-block, single-thread codegen
         let (num_blocks, num_threads) = if self.fork_join_map.is_empty() {
-            writeln!(file, "shortcut to 1b1t").unwrap();
             self.codegen_data_control_no_forks(&HashSet::new(), &mut dynamic_shared_offset, &mut gotos)?;
             ("1".to_string(), "1".to_string())
         } else {
-            writeln!(file, "no shortcut! fork tree: {:?}", self.fork_tree).unwrap();
             // Create structures and determine block and thread parallelization strategy
             let (root_forks, num_blocks, is_block_parallel) =
-                self.get_root_forks_and_num_blocks(&self.fork_tree);
-            writeln!(file, "is_block_parallel: {}", is_block_parallel).unwrap();
-            let (thread_root_root_fork, thread_root_forks) = self.get_thread_root_forks(&root_forks, &self.fork_tree, is_block_parallel);
-            writeln!(file, "thread_root_root_fork: {:?}", thread_root_root_fork).unwrap();
-            let (fork_thread_quota_map, num_threads) = self.get_thread_quotas(&self.fork_tree, thread_root_root_fork);
-            writeln!(file, "fork_thread_quota_map: {:?}", fork_thread_quota_map).unwrap();
+                self.get_root_forks_and_num_blocks(self.fork_tree);
+            let (thread_root_root_fork, thread_root_forks) = self.get_thread_root_forks(&root_forks, self.fork_tree, is_block_parallel);
+            let (fork_thread_quota_map, num_threads) = self.get_thread_quotas(self.fork_tree, thread_root_root_fork);
             // TODO: Uncomment and adjust once we know logic of extra dim
             // let extra_dim_collects = self.get_extra_dim_collects(&fork_control_map, &fork_thread_quota_map);
             let extra_dim_collects = HashSet::new();
@@ -590,7 +573,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
                 self.collection_objects.origin(*object).try_parameter().is_some()
             })
         }), "All collection reduces in block fork must originate from parameters");
-        if self.block_parallel_override || self.function.schedules[root_fork.idx()].contains(&Schedule::ParallelFork)
+        if self.function.schedules[root_fork.idx()].contains(&Schedule::ParallelFork)
         {
             let fork_size = factors.iter().map(|dc| format!("dc{}", dc.idx())).collect::<Vec<_>>().join(" * ");
             (root_forks, fork_size, true)
@@ -697,7 +680,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             && fork_size.is_power_of_two()
             && reduces.iter().all(|&reduce| {
                 self.function.schedules[reduce.idx()].contains(&Schedule::ParallelReduce)
-                    || self.thread_parallel_override || self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
+                    || self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
             })
         {
             // If there's an associative Reduce, parallelize the larger factor
@@ -710,7 +693,7 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?;
             // restriction doesn't help for parallel Writes, so nested parallelization
             // is possible.
             if reduces.iter().any(|&reduce| {
-                self.thread_parallel_override || self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
+                self.function.schedules[reduce.idx()].contains(&Schedule::TightAssociative)
             }) || fork_size > self.kernel_params.max_num_threads / subtree_quota {
                 if fork_size >= subtree_quota {
                     (HashMap::new(), fork_size, true)
...
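With the two temporary overrides removed above, whether a fork is parallelized now depends only on the schedules attached to its nodes. Below is a minimal, self-contained sketch of that decision, using simplified stand-ins for the Hercules Schedule type and schedule lookup; the helper names and signatures are illustrative, not the actual gpu.rs API.

// Sketch only: simplified stand-in for the Schedule enum referenced in the hunks above.
#[derive(PartialEq)]
enum Schedule {
    ParallelFork,
    ParallelReduce,
    TightAssociative,
}

// Block level: a root fork is parallelized across blocks only if it carries an
// explicit ParallelFork schedule (previously block_parallel_override could force this).
fn is_block_parallel(root_fork_schedules: &[Schedule]) -> bool {
    root_fork_schedules.contains(&Schedule::ParallelFork)
}

// Thread level: every reduce under the fork must be ParallelReduce or
// TightAssociative (previously thread_parallel_override could force this).
fn is_thread_parallel(reduce_schedules: &[Vec<Schedule>]) -> bool {
    reduce_schedules.iter().all(|schedules| {
        schedules.contains(&Schedule::ParallelReduce)
            || schedules.contains(&Schedule::TightAssociative)
    })
}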
@@ -10,13 +10,11 @@ juno_build::juno!("matmul");
 fn main() {
     async_std::task::block_on(async {
-        const I: usize = 4;
-        const J: usize = 2;
-        const K: usize = 8;
-        // let mut a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect();
-        // let mut b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect();
-        let mut a: Box<[i32]> = (0..I * J).map(|i| (i as i32) % 100).collect();
-        let mut b: Box<[i32]> = (0..J * K).map(|i| (i as i32) % 100).collect();
+        const I: usize = 256;
+        const J: usize = 8; // hardcoded constant in matmul.hir
+        const K: usize = 128;
+        let mut a: Box<[i32]> = (0..I * J).map(|_| random::<i32>() % 100).collect();
+        let mut b: Box<[i32]> = (0..J * K).map(|_| random::<i32>() % 100).collect();
         let mut correct_c: Box<[i32]> = (0..I * K).map(|_| 0).collect();
         for i in 0..I {
             for k in 0..K {
...
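The reference-check loops in this test are truncated above; for context, here is a minimal sketch of the kind of CPU reference computation correct_c is filled with, assuming a row-major I x J by J x K product. The function name and exact loop body are illustrative, not copied from the test file.

// Sketch only: row-major reference matmul, c[i][k] += a[i][j] * b[j][k].
fn reference_matmul(a: &[i32], b: &[i32], i_dim: usize, j_dim: usize, k_dim: usize) -> Vec<i32> {
    let mut c = vec![0i32; i_dim * k_dim];
    for i in 0..i_dim {
        for k in 0..k_dim {
            for j in 0..j_dim {
                c[i * k_dim + k] += a[i * j_dim + j] * b[j * k_dim + k];
            }
        }
    }
    c
}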