From 54bf7312724cab4ae2353ff7736b54f4b77d47b4 Mon Sep 17 00:00:00 2001 From: rarbore2 <rarbore2@illinois.edu> Date: Wed, 14 Feb 2024 15:45:28 -0600 Subject: [PATCH] Add zero initializer constants --- Cargo.lock | 1 + hercules_cg/src/cpu_beta.rs | 38 ++++++++++++++-- hercules_ir/src/ir.rs | 21 ++++++++- hercules_ir/src/parse.rs | 20 ++++++--- hercules_ir/src/typecheck.rs | 2 + hercules_opt/Cargo.toml | 1 + hercules_opt/src/ccp.rs | 59 ++++++++++++++++++------- hercules_samples/matmul/matmul.hir | 3 +- hercules_samples/matmul/src/main.rs | 4 +- hercules_tools/hercules_cpu/src/main.rs | 1 + hercules_tools/hercules_dot/src/main.rs | 1 + 11 files changed, 122 insertions(+), 29 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bfcc04d0..5cc9664f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -205,6 +205,7 @@ name = "hercules_opt" version = "0.1.0" dependencies = [ "hercules_ir", + "ordered-float", ] [[package]] diff --git a/hercules_cg/src/cpu_beta.rs b/hercules_cg/src/cpu_beta.rs index e6c858fd..2c5897d2 100644 --- a/hercules_cg/src/cpu_beta.rs +++ b/hercules_cg/src/cpu_beta.rs @@ -106,7 +106,36 @@ pub fn cpu_beta_codegen<W: Write>( // Step 2: render constants into LLVM IR. This is done in a very similar // manner as types. - let mut llvm_constants = vec!["".to_string(); types.len()]; + let mut llvm_constants = vec!["".to_string(); constants.len()]; + fn render_zero_constant(cons_id: ConstantID, ty_id: TypeID, types: &Vec<Type>) -> String { + match &types[ty_id.idx()] { + Type::Control(_) => panic!(), + Type::Boolean => "false".to_string(), + Type::Integer8 + | Type::Integer16 + | Type::Integer32 + | Type::Integer64 + | Type::UnsignedInteger8 + | Type::UnsignedInteger16 + | Type::UnsignedInteger32 + | Type::UnsignedInteger64 => "0".to_string(), + Type::Float32 | Type::Float64 => "0.0".to_string(), + Type::Product(fields) => { + let mut iter = fields.iter(); + if let Some(first) = iter.next() { + iter.fold( + "{".to_string() + &render_zero_constant(cons_id, *first, types), + |s, f| s + ", " + &render_zero_constant(cons_id, *f, types), + ) + "}" + } else { + "{}".to_string() + } + } + Type::Summation(_) => todo!(), + Type::Array(_, _) => format!("%arr.{}", cons_id.idx()), + } + } + for id in module.constants_bottom_up() { match &constants[id.idx()] { Constant::Boolean(val) => { @@ -151,6 +180,9 @@ pub fn cpu_beta_codegen<W: Write>( } Constant::Array(_, _) => llvm_constants[id.idx()] = format!("%arr.{}", id.idx()), Constant::Summation(_, _, _) => todo!(), + Constant::Zero(ty_id) => { + llvm_constants[id.idx()] = render_zero_constant(id, *ty_id, types) + } } } @@ -190,8 +222,8 @@ pub fn cpu_beta_codegen<W: Write>( .chain((0..function.num_dynamic_constants).map(|idx| format!("i64 %dc{}", idx))) .chain( (0..constants.len()) - .filter(|idx| constants[*idx].is_array()) - .map(|idx| format!("%arr.{}", idx)), + .filter(|idx| module.is_array_constant(ConstantID::new(*idx))) + .map(|idx| format!("ptr %arr.{}", idx)), ); write!(w, "define {} @{}(", llvm_ret_type, function.name)?; if let Some(first) = llvm_params.next() { diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index ecc7eecf..8e5997bf 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -77,7 +77,9 @@ pub enum Type { * interning constants during IR construction). Product, summation, and array * constants all contain their own type. This is only strictly necessary for * summation types, but provides a nice mechanism for sanity checking for - * product and array types as well. + * product and array types as well. There is also a zero initializer constant, + * which stores its own type as well. The zero value of a summation is defined + * as the zero value of the first variant. */ #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Constant { @@ -95,6 +97,7 @@ pub enum Constant { Product(TypeID, Box<[ConstantID]>), Summation(TypeID, u32, ConstantID), Array(TypeID, Box<[ConstantID]>), + Zero(TypeID), } /* @@ -374,6 +377,7 @@ impl Module { } write!(w, "]") } + Constant::Zero(_) => write!(w, "zero"), }?; Ok(()) @@ -509,6 +513,18 @@ impl Module { coroutine: Box::new(coroutine), } } + + /* + * Unfortunately, determining if a constant is an array requires both + * knowledge of constants and types, due to zero initializer constants. + */ + pub fn is_array_constant(&self, cons_id: ConstantID) -> bool { + if let Constant::Zero(ty_id) = self.constants[cons_id.idx()] { + self.types[ty_id.idx()].is_array() + } else { + self.constants[cons_id.idx()].is_strictly_array() + } + } } struct CoroutineIterator<G, I> @@ -674,7 +690,7 @@ pub fn element_type(mut ty: TypeID, types: &Vec<Type>) -> TypeID { } impl Constant { - pub fn is_array(&self) -> bool { + pub fn is_strictly_array(&self) -> bool { if let Constant::Array(_, _) = self { true } else { @@ -697,6 +713,7 @@ impl Constant { Constant::UnsignedInteger64(0) => true, Constant::Float32(ord) => *ord == ordered_float::OrderedFloat::<f32>(0.0), Constant::Float64(ord) => *ord == ordered_float::OrderedFloat::<f64>(0.0), + Constant::Zero(_) => true, _ => false, } } diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index 3354cbb2..a00da548 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -876,7 +876,16 @@ fn parse_constant<'a>( ty: Type, context: &RefCell<Context<'a>>, ) -> nom::IResult<&'a str, Constant> { - let (ir_text, constant) = match ty.clone() { + let ty_id = context.borrow_mut().get_type_id(ty.clone()); + let (ir_text, maybe_constant) = nom::combinator::opt(nom::combinator::map( + nom::bytes::complete::tag("zero"), + |_| Constant::Zero(ty_id), + ))(ir_text)?; + if let Some(cons) = maybe_constant { + return Ok((ir_text, cons)); + } + + let (ir_text, constant) = match ty { // There are not control constants. Type::Control(_) => Err(nom::Err::Error(nom::error::Error { input: ir_text, @@ -893,13 +902,13 @@ fn parse_constant<'a>( Type::UnsignedInteger64 => parse_unsigned_integer64(ir_text)?, Type::Float32 => parse_float32(ir_text)?, Type::Float64 => parse_float64(ir_text)?, - Type::Product(tys) => parse_product_constant( + Type::Product(ref tys) => parse_product_constant( ir_text, context.borrow_mut().get_type_id(ty.clone()), tys, context, )?, - Type::Summation(tys) => parse_summation_constant( + Type::Summation(ref tys) => parse_summation_constant( ir_text, context.borrow_mut().get_type_id(ty.clone()), tys, @@ -912,7 +921,6 @@ fn parse_constant<'a>( context, )?, }; - context.borrow_mut().get_type_id(ty); Ok((ir_text, constant)) } @@ -997,7 +1005,7 @@ fn parse_float64<'a>(ir_text: &'a str) -> nom::IResult<&'a str, Constant> { fn parse_product_constant<'a>( ir_text: &'a str, prod_ty: TypeID, - tys: Box<[TypeID]>, + tys: &Box<[TypeID]>, context: &RefCell<Context<'a>>, ) -> nom::IResult<&'a str, Constant> { let ir_text = nom::character::complete::multispace0(ir_text)?.0; @@ -1032,7 +1040,7 @@ fn parse_product_constant<'a>( fn parse_summation_constant<'a>( ir_text: &'a str, sum_ty: TypeID, - tys: Box<[TypeID]>, + tys: &Box<[TypeID]>, context: &RefCell<Context<'a>>, ) -> nom::IResult<&'a str, Constant> { let ir_text = nom::character::complete::multispace0(ir_text)?.0; diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index 1a183411..eabe45d7 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -579,6 +579,8 @@ fn typeflow( )) } } + // Zero constants need to store their type, and we trust it. + Constant::Zero(id) => Concrete(id), } } Node::DynamicConstant { id } => { diff --git a/hercules_opt/Cargo.toml b/hercules_opt/Cargo.toml index 47bd9bd5..bc305405 100644 --- a/hercules_opt/Cargo.toml +++ b/hercules_opt/Cargo.toml @@ -4,4 +4,5 @@ version = "0.1.0" authors = ["Russel Arbore <rarbore2@illinois.edu>"] [dependencies] +ordered-float = "*" hercules_ir = { path = "../hercules_ir" } diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs index c061a877..1381a2f1 100644 --- a/hercules_opt/src/ccp.rs +++ b/hercules_opt/src/ccp.rs @@ -133,13 +133,14 @@ impl Semilattice for ConstantLattice { */ pub fn ccp( function: &mut Function, + types: &Vec<Type>, constants: &mut Vec<Constant>, def_use: &ImmutableDefUseMap, reverse_postorder: &Vec<NodeID>, ) { // Step 1: run ccp analysis to understand the function. let result = dataflow_global(&function, reverse_postorder, |inputs, node_id| { - ccp_flow_function(inputs, node_id, &function, &constants) + ccp_flow_function(inputs, node_id, &function, &types, &constants) }); // Step 2: update uses of constants. Any node that doesn't produce a @@ -371,6 +372,7 @@ fn ccp_flow_function( inputs: &[CCPLattice], node_id: NodeID, function: &Function, + types: &Vec<Type>, old_constants: &Vec<Constant>, ) -> CCPLattice { let node = &function.nodes[node_id.idx()]; @@ -415,7 +417,7 @@ fn ccp_flow_function( // TODO: This should produce a constant zero if the dynamic constant for // for the corresponding fork is one. Node::ThreadID { control } => inputs[control.idx()].clone(), - // TODO: At least for now, collect nodes always produce unknown values. + // TODO: At least for now, reduce nodes always produce unknown values. Node::Reduce { control, init: _, @@ -454,6 +456,7 @@ fn ccp_flow_function( (UnaryOperator::Neg, Constant::Integer64(val)) => Constant::Integer64(-val), (UnaryOperator::Neg, Constant::Float32(val)) => Constant::Float32(-val), (UnaryOperator::Neg, Constant::Float64(val)) => Constant::Float64(-val), + (UnaryOperator::Neg, Constant::Zero(id)) => Constant::Zero(*id), _ => panic!("Unsupported combination of unary operation and constant value. Did typechecking succeed?") }; ConstantLattice::Constant(new_cons) @@ -482,6 +485,32 @@ fn ccp_flow_function( ConstantLattice::Constant(right_cons), ) = (left_constant, right_constant) { + let type_to_zero_cons = |ty_id: TypeID| { + match types[ty_id.idx()] { + Type::Boolean => Constant::Boolean(false), + Type::Integer8 => Constant::Integer8(0), + Type::Integer16 => Constant::Integer16(0), + Type::Integer32 => Constant::Integer32(0), + Type::Integer64 => Constant::Integer64(0), + Type::UnsignedInteger8 => Constant::UnsignedInteger8(0), + Type::UnsignedInteger16 => Constant::UnsignedInteger16(0), + Type::UnsignedInteger32 => Constant::UnsignedInteger32(0), + Type::UnsignedInteger64 => Constant::UnsignedInteger64(0), + Type::Float32 => Constant::Float32(ordered_float::OrderedFloat::<f32>(0.0)), + Type::Float64 => Constant::Float64(ordered_float::OrderedFloat::<f64>(0.0)), + _ => panic!("Unsupported combination of binary operation and constant values. Did typechecking succeed?") + } + }; + let left_cons = if let Constant::Zero(id) = left_cons { + type_to_zero_cons(*id) + } else { + left_cons.clone() + }; + let right_cons = if let Constant::Zero(id) = right_cons { + type_to_zero_cons(*id) + } else { + right_cons.clone() + }; let new_cons = match (op, left_cons, right_cons) { (BinaryOperator::Add, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val + right_val), (BinaryOperator::Add, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val + right_val), @@ -491,8 +520,8 @@ fn ccp_flow_function( (BinaryOperator::Add, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val + right_val), (BinaryOperator::Add, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val + right_val), (BinaryOperator::Add, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val + right_val), - (BinaryOperator::Add, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(*left_val + *right_val), - (BinaryOperator::Add, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(*left_val + *right_val), + (BinaryOperator::Add, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(left_val + right_val), + (BinaryOperator::Add, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(left_val + right_val), (BinaryOperator::Sub, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val - right_val), (BinaryOperator::Sub, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val - right_val), (BinaryOperator::Sub, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val - right_val), @@ -501,8 +530,8 @@ fn ccp_flow_function( (BinaryOperator::Sub, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val - right_val), (BinaryOperator::Sub, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val - right_val), (BinaryOperator::Sub, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val - right_val), - (BinaryOperator::Sub, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(*left_val - *right_val), - (BinaryOperator::Sub, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(*left_val - *right_val), + (BinaryOperator::Sub, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(left_val - right_val), + (BinaryOperator::Sub, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(left_val - right_val), (BinaryOperator::Mul, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val * right_val), (BinaryOperator::Mul, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val * right_val), (BinaryOperator::Mul, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val * right_val), @@ -511,8 +540,8 @@ fn ccp_flow_function( (BinaryOperator::Mul, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val * right_val), (BinaryOperator::Mul, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val * right_val), (BinaryOperator::Mul, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val * right_val), - (BinaryOperator::Mul, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(*left_val * *right_val), - (BinaryOperator::Mul, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(*left_val * *right_val), + (BinaryOperator::Mul, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(left_val * right_val), + (BinaryOperator::Mul, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(left_val * right_val), (BinaryOperator::Div, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val / right_val), (BinaryOperator::Div, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val / right_val), (BinaryOperator::Div, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val / right_val), @@ -521,8 +550,8 @@ fn ccp_flow_function( (BinaryOperator::Div, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val / right_val), (BinaryOperator::Div, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val / right_val), (BinaryOperator::Div, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val / right_val), - (BinaryOperator::Div, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(*left_val / *right_val), - (BinaryOperator::Div, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(*left_val / *right_val), + (BinaryOperator::Div, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(left_val / right_val), + (BinaryOperator::Div, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(left_val / right_val), (BinaryOperator::Rem, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val % right_val), (BinaryOperator::Rem, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val % right_val), (BinaryOperator::Rem, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val % right_val), @@ -531,8 +560,8 @@ fn ccp_flow_function( (BinaryOperator::Rem, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val % right_val), (BinaryOperator::Rem, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val % right_val), (BinaryOperator::Rem, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val % right_val), - (BinaryOperator::Rem, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(*left_val % *right_val), - (BinaryOperator::Rem, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(*left_val % *right_val), + (BinaryOperator::Rem, Constant::Float32(left_val), Constant::Float32(right_val)) => Constant::Float32(left_val % right_val), + (BinaryOperator::Rem, Constant::Float64(left_val), Constant::Float64(right_val)) => Constant::Float64(left_val % right_val), (BinaryOperator::LT, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Boolean(left_val < right_val), (BinaryOperator::LT, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Boolean(left_val < right_val), (BinaryOperator::LT, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Boolean(left_val < right_val), @@ -577,7 +606,7 @@ fn ccp_flow_function( // need to unpack the constants. (BinaryOperator::EQ, left_val, right_val) => Constant::Boolean(left_val == right_val), (BinaryOperator::NE, left_val, right_val) => Constant::Boolean(left_val != right_val), - (BinaryOperator::Or, Constant::Boolean(left_val), Constant::Boolean(right_val)) => Constant::Boolean(*left_val || *right_val), + (BinaryOperator::Or, Constant::Boolean(left_val), Constant::Boolean(right_val)) => Constant::Boolean(left_val || right_val), (BinaryOperator::Or, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val | right_val), (BinaryOperator::Or, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val | right_val), (BinaryOperator::Or, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val | right_val), @@ -586,7 +615,7 @@ fn ccp_flow_function( (BinaryOperator::Or, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val | right_val), (BinaryOperator::Or, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val | right_val), (BinaryOperator::Or, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val | right_val), - (BinaryOperator::And, Constant::Boolean(left_val), Constant::Boolean(right_val)) => Constant::Boolean(*left_val && *right_val), + (BinaryOperator::And, Constant::Boolean(left_val), Constant::Boolean(right_val)) => Constant::Boolean(left_val && right_val), (BinaryOperator::And, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val & right_val), (BinaryOperator::And, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val & right_val), (BinaryOperator::And, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val & right_val), @@ -595,7 +624,7 @@ fn ccp_flow_function( (BinaryOperator::And, Constant::UnsignedInteger16(left_val), Constant::UnsignedInteger16(right_val)) => Constant::UnsignedInteger16(left_val & right_val), (BinaryOperator::And, Constant::UnsignedInteger32(left_val), Constant::UnsignedInteger32(right_val)) => Constant::UnsignedInteger32(left_val & right_val), (BinaryOperator::And, Constant::UnsignedInteger64(left_val), Constant::UnsignedInteger64(right_val)) => Constant::UnsignedInteger64(left_val & right_val), - (BinaryOperator::Xor, Constant::Boolean(left_val), Constant::Boolean(right_val)) => Constant::Boolean(*left_val ^ *right_val), + (BinaryOperator::Xor, Constant::Boolean(left_val), Constant::Boolean(right_val)) => Constant::Boolean(left_val ^ right_val), (BinaryOperator::Xor, Constant::Integer8(left_val), Constant::Integer8(right_val)) => Constant::Integer8(left_val ^ right_val), (BinaryOperator::Xor, Constant::Integer16(left_val), Constant::Integer16(right_val)) => Constant::Integer16(left_val ^ right_val), (BinaryOperator::Xor, Constant::Integer32(left_val), Constant::Integer32(right_val)) => Constant::Integer32(left_val ^ right_val), diff --git a/hercules_samples/matmul/matmul.hir b/hercules_samples/matmul/matmul.hir index 65e0cb10..2f0fb67a 100644 --- a/hercules_samples/matmul/matmul.hir +++ b/hercules_samples/matmul/matmul.hir @@ -1,4 +1,5 @@ -fn matmul<3>(a: array(f32, #0, #1), b: array(f32, #1, #2), c: array(f32, #0, #2)) -> array(f32, #0, #2) +fn matmul<3>(a: array(f32, #0, #1), b: array(f32, #1, #2)) -> array(f32, #0, #2) + c = constant(array(f32, #0, #2), zero) i_ctrl = fork(start, #0) i_idx = thread_id(i_ctrl) j_ctrl = fork(i_ctrl, #2) diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs index 02aeb3a4..2a28cd80 100644 --- a/hercules_samples/matmul/src/main.rs +++ b/hercules_samples/matmul/src/main.rs @@ -10,10 +10,10 @@ fn main() { "matmul", *const f32, *const f32, - *mut f32, u64, u64, u64, + *mut f32, => *const f32 ); @@ -24,10 +24,10 @@ fn main() { matmul( std::mem::transmute(a.as_ptr()), std::mem::transmute(b.as_ptr()), - std::mem::transmute(c.as_mut_ptr()), 2, 2, 2, + std::mem::transmute(c.as_mut_ptr()), ) }; println!("{} {}\n{} {}", c[0][0], c[0][1], c[1][0], c[1][1]); diff --git a/hercules_tools/hercules_cpu/src/main.rs b/hercules_tools/hercules_cpu/src/main.rs index c6b077f9..209d8619 100644 --- a/hercules_tools/hercules_cpu/src/main.rs +++ b/hercules_tools/hercules_cpu/src/main.rs @@ -34,6 +34,7 @@ fn main() { |(mut function, id), (types, mut constants, dynamic_constants)| { hercules_opt::ccp::ccp( &mut function, + &types, &mut constants, &def_uses[id.idx()], &reverse_postorders[id.idx()], diff --git a/hercules_tools/hercules_dot/src/main.rs b/hercules_tools/hercules_dot/src/main.rs index a24285eb..198a73b7 100644 --- a/hercules_tools/hercules_dot/src/main.rs +++ b/hercules_tools/hercules_dot/src/main.rs @@ -42,6 +42,7 @@ fn main() { |(mut function, id), (types, mut constants, dynamic_constants)| { hercules_opt::ccp::ccp( &mut function, + &types, &mut constants, &def_uses[id.idx()], &reverse_postorders[id.idx()], -- GitLab