From a55ce6342e805467a3366b9fcfe16b163a0e636b Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Sat, 23 Nov 2024 00:49:41 -0600 Subject: [PATCH] Fix schedule gen to panic only if non-root array types/constants are used --- Cargo.lock | 10 ++ Cargo.toml | 2 +- hercules_cg/src/sched_gen.rs | 234 +++++++++++++++-------------- juno_samples/matmul/src/main.rs | 13 +- juno_samples/matmul/src/matmul.sch | 2 +- 5 files changed, 141 insertions(+), 120 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 440cdd3c..37216ee3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -729,6 +729,16 @@ dependencies = [ "phf", ] +[[package]] +name = "juno_matmul" +version = "0.1.0" +dependencies = [ + "async-std", + "hercules_rt", + "juno_build", + "with_builtin_macros", +] + [[package]] name = "juno_scheduler" version = "0.0.1" diff --git a/Cargo.toml b/Cargo.toml index 8fa3079b..c6b2468f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,6 @@ members = [ "juno_scheduler", "juno_build", - #"juno_samples/matmul", + "juno_samples/matmul", "juno_samples/simple3", ] diff --git a/hercules_cg/src/sched_gen.rs b/hercules_cg/src/sched_gen.rs index f065340c..e5378b9c 100644 --- a/hercules_cg/src/sched_gen.rs +++ b/hercules_cg/src/sched_gen.rs @@ -24,9 +24,7 @@ pub fn sched_compile( bbs: &Vec<Vec<NodeID>>, plans: &Vec<Plan>, ) -> SModule { - verify_types_well_formed_for_sched_ir(&module.types); let stypes = convert_to_sched_ir_types(&module.types); - verify_constants_well_formed_for_sched_ir(&module.constants); let sconstants = convert_to_sched_ir_constants(&module.constants); let function_names: HashMap<FunctionID, String> = module .functions @@ -67,111 +65,99 @@ pub fn sched_compile( } } -/* - * Checks for the following conditions: - * 1. Array types are only ever root types. - * 2. Summation types do not exist - TODO: properly support summation types. - */ -fn verify_types_well_formed_for_sched_ir(types: &Vec<Type>) { - for ty in types.iter() { - match ty { - Type::Product(fields) => { - let any_array = fields - .iter() - .any(|field_ty| types[field_ty.idx()].is_array()); - assert!(!any_array, "PANIC: Found non-root array type."); - } - Type::Summation(_) => panic!("PANIC: Can't lower summations to schedule IR yet."), - Type::Array(elem, _) => assert!( - !types[elem.idx()].is_array(), - "PANIC: Found non-root array type." - ), - _ => {} - } - } -} - -fn convert_to_sched_ir_types(types: &Vec<Type>) -> Vec<SType> { - let mut stypes = vec![SType::Boolean; types.len()]; +fn convert_to_sched_ir_types(types: &Vec<Type>) -> Vec<Option<SType>> { + let mut stypes = vec![None; types.len()]; for id in types_bottom_up(types) { stypes[id.idx()] = match &types[id.idx()] { - Type::Control => SType::Boolean, - Type::Boolean => SType::Boolean, - Type::Integer8 => SType::Integer8, - Type::Integer16 => SType::Integer16, - Type::Integer32 => SType::Integer32, - Type::Integer64 => SType::Integer64, - Type::UnsignedInteger8 => SType::UnsignedInteger8, - Type::UnsignedInteger16 => SType::UnsignedInteger16, - Type::UnsignedInteger32 => SType::UnsignedInteger32, - Type::UnsignedInteger64 => SType::UnsignedInteger64, - Type::Float32 => SType::Float32, - Type::Float64 => SType::Float64, + Type::Control => None, + Type::Boolean => Some(SType::Boolean), + Type::Integer8 => Some(SType::Integer8), + Type::Integer16 => Some(SType::Integer16), + Type::Integer32 => Some(SType::Integer32), + Type::Integer64 => Some(SType::Integer64), + Type::UnsignedInteger8 => Some(SType::UnsignedInteger8), + Type::UnsignedInteger16 => Some(SType::UnsignedInteger16), + Type::UnsignedInteger32 => Some(SType::UnsignedInteger32), + Type::UnsignedInteger64 => Some(SType::UnsignedInteger64), + Type::Float32 => Some(SType::Float32), + Type::Float64 => Some(SType::Float64), Type::Product(fields) => { - SType::Product(fields.iter().map(|id| stypes[id.idx()].clone()).collect()) + let mut typs = vec![]; + let mut res_none = false; + for id in fields { + if types[id.idx()].is_array() { + res_none = true; + break; + } else { + match &stypes[id.idx()] { + None => { + res_none = true; + break; + } + Some(t) => typs.push(t.clone()), + } + } + } + if res_none { + None + } else { + Some(SType::Product(typs.into())) + } } Type::Summation(_) => todo!(), - Type::Array(elem_ty, _) => SType::ArrayRef(Box::new(stypes[elem_ty.idx()].clone())), + Type::Array(elem_ty, _) => match &stypes[elem_ty.idx()] { + None => None, + Some(t) => Some(SType::ArrayRef(Box::new(t.clone()))), + }, }; } stypes } -/* - * Checks for the following conditions: - * 1. Array constants are only ever root constants. - * 2. Summation constants do not exist - TODO: properly support summation - * constants. - */ -fn verify_constants_well_formed_for_sched_ir(constants: &Vec<Constant>) { - for ty in constants.iter() { - match ty { - Constant::Product(_, fields) => { - let any_array = fields - .iter() - .any(|field_ty| constants[field_ty.idx()].is_array()); - assert!(!any_array, "PANIC: Found non-root array constant."); - } - Constant::Summation(_, _, _) => { - panic!("PANIC: Can't lower summations to schedule IR yet.") - } - // We don't need to check array constants explicitly, since they - // explicitly store their type, and nothing else - if an array - // constant were invalid by condition #1, then its corresponding - // type would be invalid, which would get caught by - // `verify_types_well_formed_for_sched_ir`. - _ => {} - } - } -} - -fn convert_to_sched_ir_constants(constants: &Vec<Constant>) -> Vec<SConstant> { - let mut sconstants = vec![SConstant::Boolean(false); constants.len()]; +fn convert_to_sched_ir_constants(constants: &Vec<Constant>) -> Vec<Option<SConstant>> { + let mut sconstants = vec![None; constants.len()]; for id in constants_bottom_up(constants) { sconstants[id.idx()] = match &constants[id.idx()] { - Constant::Boolean(val) => SConstant::Boolean(*val), - Constant::Integer8(val) => SConstant::Integer8(*val), - Constant::Integer16(val) => SConstant::Integer16(*val), - Constant::Integer32(val) => SConstant::Integer32(*val), - Constant::Integer64(val) => SConstant::Integer64(*val), - Constant::UnsignedInteger8(val) => SConstant::UnsignedInteger8(*val), - Constant::UnsignedInteger16(val) => SConstant::UnsignedInteger16(*val), - Constant::UnsignedInteger32(val) => SConstant::UnsignedInteger32(*val), - Constant::UnsignedInteger64(val) => SConstant::UnsignedInteger64(*val), - Constant::Float32(val) => SConstant::Float32(*val), - Constant::Float64(val) => SConstant::Float64(*val), - Constant::Product(_, fields) => SConstant::Product( - fields - .iter() - .map(|id| sconstants[id.idx()].clone()) - .collect(), - ), + Constant::Boolean(val) => Some(SConstant::Boolean(*val)), + Constant::Integer8(val) => Some(SConstant::Integer8(*val)), + Constant::Integer16(val) => Some(SConstant::Integer16(*val)), + Constant::Integer32(val) => Some(SConstant::Integer32(*val)), + Constant::Integer64(val) => Some(SConstant::Integer64(*val)), + Constant::UnsignedInteger8(val) => Some(SConstant::UnsignedInteger8(*val)), + Constant::UnsignedInteger16(val) => Some(SConstant::UnsignedInteger16(*val)), + Constant::UnsignedInteger32(val) => Some(SConstant::UnsignedInteger32(*val)), + Constant::UnsignedInteger64(val) => Some(SConstant::UnsignedInteger64(*val)), + Constant::Float32(val) => Some(SConstant::Float32(*val)), + Constant::Float64(val) => Some(SConstant::Float64(*val)), + Constant::Product(_, fields) => { + let mut consts = vec![]; + let mut res_none = false; + for id in fields { + if constants[id.idx()].is_array() { + res_none = true; + break; + } else { + match &sconstants[id.idx()] { + None => { + res_none = true; + break; + } + Some(c) => consts.push(c.clone()), + } + } + } + if res_none { + None + } else { + Some(SConstant::Product(consts.into())) + } + } Constant::Summation(_, _, _) => todo!(), // Array constants are never generated inline schedule IR. - Constant::Array(_) => SConstant::Boolean(false), + Constant::Array(_) => None, }; } @@ -195,8 +181,8 @@ struct FunctionContext<'a> { antideps: &'a Vec<(NodeID, NodeID)>, bbs: &'a Vec<NodeID>, plan: &'a Plan, - stypes: &'a Vec<SType>, - sconstants: &'a Vec<SConstant>, + stypes: &'a Vec<Option<SType>>, + sconstants: &'a Vec<Option<SConstant>>, function_names: &'a HashMap<FunctionID, String>, top_nodes: Vec<NodeID>, @@ -222,8 +208,8 @@ impl<'a> FunctionContext<'a> { antideps: &'a Vec<(NodeID, NodeID)>, bbs: &'a Vec<NodeID>, plan: &'a Plan, - stypes: &'a Vec<SType>, - sconstants: &'a Vec<SConstant>, + stypes: &'a Vec<Option<SType>>, + sconstants: &'a Vec<Option<SConstant>>, function_names: &'a HashMap<FunctionID, String>, ) -> Self { let inverted_partition_map = plan.invert_partition_map(); @@ -348,7 +334,7 @@ impl<'a> FunctionContext<'a> { parameters.extend(self.function.param_types.iter().enumerate().map( |(param_idx, ty_id)| { ( - self.stypes[ty_id.idx()].clone(), + self.stypes[ty_id.idx()].clone().unwrap(), ParameterKind::HerculesParameter(param_idx), ) }, @@ -356,7 +342,9 @@ impl<'a> FunctionContext<'a> { } else { parameters.extend(self.data_inputs[partition_idx].iter().map(|node_id| { ( - self.stypes[self.typing[node_id.idx()].idx()].clone(), + self.stypes[self.typing[node_id.idx()].idx()] + .clone() + .unwrap(), ParameterKind::DataInput(*node_id), ) })) @@ -378,7 +366,9 @@ impl<'a> FunctionContext<'a> { .filter_map(|use_id| { if let Some(array_id) = array_node_to_array_id.get(use_id) { Some(( - self.stypes[self.typing[use_id.idx()].idx()].clone(), + self.stypes[self.typing[use_id.idx()].idx()] + .clone() + .unwrap(), ParameterKind::ArrayConstant(*array_id), )) } else { @@ -423,14 +413,18 @@ impl<'a> FunctionContext<'a> { { assert_eq!(successors.len(), 0); returns.push(( - self.stypes[self.function.return_type.idx()].clone(), + self.stypes[self.function.return_type.idx()] + .clone() + .unwrap(), ReturnKind::HerculesReturn, )); } else { assert!(successors.len() > 0); returns.extend(self.data_outputs[partition_idx].iter().map(|node_id| { ( - self.stypes[self.typing[node_id.idx()].idx()].clone(), + self.stypes[self.typing[node_id.idx()].idx()] + .clone() + .unwrap(), ReturnKind::DataOutput(*node_id), ) })); @@ -461,14 +455,16 @@ impl<'a> FunctionContext<'a> { param_types.extend(self.function.param_types.iter().enumerate().map( |(param_idx, ty_id)| { ( - self.stypes[ty_id.idx()].clone(), + self.stypes[ty_id.idx()].clone().unwrap(), ParameterKind::HerculesParameter(param_idx), ) }, )); param_types.extend(array_node_to_array_id.iter().map(|(node_id, array_id)| { ( - self.stypes[self.typing[node_id.idx()].idx()].clone(), + self.stypes[self.typing[node_id.idx()].idx()] + .clone() + .unwrap(), ParameterKind::ArrayConstant(*array_id), ) })); @@ -481,7 +477,9 @@ impl<'a> FunctionContext<'a> { // The return type is just the schedule IR type corresponding to the // Hercules function's return type. - let return_type = self.stypes[self.function.return_type.idx()].clone(); + let return_type = self.stypes[self.function.return_type.idx()] + .clone() + .unwrap(); let manifest = Manifest { param_types, @@ -647,7 +645,7 @@ impl<'a> FunctionContext<'a> { .unwrap(), ) } else { - SValue::Constant(self.sconstants[id.idx()].clone()) + SValue::Constant(self.sconstants[id.idx()].clone().unwrap()) }; Some((*use_id, svalue)) } else { @@ -998,7 +996,7 @@ impl<'a> FunctionContext<'a> { block.insts.push(SInst::Phi { inputs }); block.virt_regs.push(( self_virt_reg(), - self.stypes[self.typing[id.idx()].idx()].clone(), + self.stypes[self.typing[id.idx()].idx()].clone().unwrap(), )); } } @@ -1031,7 +1029,7 @@ impl<'a> FunctionContext<'a> { block.insts.push(SInst::ReductionVariable { number }); block.virt_regs.push(( self_virt_reg(), - self.stypes[self.typing[id.idx()].idx()].clone(), + self.stypes[self.typing[id.idx()].idx()].clone().unwrap(), )); } @@ -1042,7 +1040,7 @@ impl<'a> FunctionContext<'a> { }); block.virt_regs.push(( self_virt_reg(), - self.stypes[self.typing[id.idx()].idx()].clone(), + self.stypes[self.typing[id.idx()].idx()].clone().unwrap(), )); } Node::Binary { left, right, op } => { @@ -1053,7 +1051,7 @@ impl<'a> FunctionContext<'a> { }); block.virt_regs.push(( self_virt_reg(), - self.stypes[self.typing[id.idx()].idx()].clone(), + self.stypes[self.typing[id.idx()].idx()].clone().unwrap(), )); } Node::Ternary { @@ -1070,7 +1068,7 @@ impl<'a> FunctionContext<'a> { }); block.virt_regs.push(( self_virt_reg(), - self.stypes[self.typing[id.idx()].idx()].clone(), + self.stypes[self.typing[id.idx()].idx()].clone().unwrap(), )); } @@ -1097,8 +1095,10 @@ impl<'a> FunctionContext<'a> { // Array loads need the dynamic constant bounds for indexing // math. let bounds = lower_extents(collect, &mut block); - let load_ty = if let SType::ArrayRef(elem_ty) = - self.stypes[self.typing[collect.idx()].idx()].clone() + let load_ty = if let SType::ArrayRef(elem_ty) = self.stypes + [self.typing[collect.idx()].idx()] + .clone() + .unwrap() { *elem_ty } else { @@ -1129,7 +1129,7 @@ impl<'a> FunctionContext<'a> { }); block.virt_regs.push(( self_virt_reg(), - self.stypes[self.typing[id.idx()].idx()].clone(), + self.stypes[self.typing[id.idx()].idx()].clone().unwrap(), )); } } @@ -1175,8 +1175,10 @@ impl<'a> FunctionContext<'a> { // Load the product. let load_virt_reg = self.make_virt_reg(partition_idx); - let load_ty = if let SType::ArrayRef(elem_ty) = - self.stypes[self.typing[collect.idx()].idx()].clone() + let load_ty = if let SType::ArrayRef(elem_ty) = self.stypes + [self.typing[collect.idx()].idx()] + .clone() + .unwrap() { *elem_ty } else { @@ -1227,7 +1229,7 @@ impl<'a> FunctionContext<'a> { // they create a new product value. block.virt_regs.push(( self_virt_reg(), - self.stypes[self.typing[id.idx()].idx()].clone(), + self.stypes[self.typing[id.idx()].idx()].clone().unwrap(), )); } } @@ -1412,11 +1414,11 @@ impl<'a> FunctionContext<'a> { } } -fn convert_unary_op(op: UnaryOperator, simple_ir_types: &[SType]) -> SUnaryOperator { +fn convert_unary_op(op: UnaryOperator, simple_ir_types: &[Option<SType>]) -> SUnaryOperator { match op { UnaryOperator::Not => SUnaryOperator::Not, UnaryOperator::Neg => SUnaryOperator::Neg, - UnaryOperator::Cast(ty) => SUnaryOperator::Cast(simple_ir_types[ty.idx()].clone()), + UnaryOperator::Cast(ty) => SUnaryOperator::Cast(simple_ir_types[ty.idx()].clone().unwrap()), } } diff --git a/juno_samples/matmul/src/main.rs b/juno_samples/matmul/src/main.rs index e6a73a3f..6d1867ac 100644 --- a/juno_samples/matmul/src/main.rs +++ b/juno_samples/matmul/src/main.rs @@ -4,7 +4,7 @@ extern crate async_std; extern crate juno_build; extern crate hercules_rt; -juno_build::juno!("matmul.jn"); +juno_build::juno!("matmul"); fn main() { async_std::task::block_on(async { @@ -12,8 +12,17 @@ fn main() { let mut b = vec![5.0, 6.0, 7.0, 8.0]; let mut c = vec![0.0, 0.0, 0.0, 0.0]; unsafe { - matmul(a.as_mut_prt(), b.as_mut_ptr(), c.as_mut_ptr(), 2, 2, 2).await; + matmul(a.as_mut_ptr(), b.as_mut_ptr(), c.as_mut_ptr(), 2, 2, 2).await; } println!("[[{}, {}], [{}, {}]]", c[0], c[1], c[2], c[3]); + assert_eq!(c[0], 19.0); + assert_eq!(c[1], 22.0); + assert_eq!(c[2], 43.0); + assert_eq!(c[3], 50.0); }); } + +#[test] +fn matmul_test() { + main(); +} diff --git a/juno_samples/matmul/src/matmul.sch b/juno_samples/matmul/src/matmul.sch index bbc7ed0e..847a9121 100644 --- a/juno_samples/matmul/src/matmul.sch +++ b/juno_samples/matmul/src/matmul.sch @@ -1,5 +1,5 @@ function matmul { - partition { @outer, @middle, @inner } on gpu + partition { @outer, @middle, @inner } on cpu //gpu partition @exit on cpu parallelize @outer -- GitLab