diff --git a/hercules_cg/src/cpu_beta.rs b/hercules_cg/src/cpu_beta.rs index 2c5897d23c689817f4b359ea2a5d787ebfc5dbc2..9a1c6741009e121a6b9fce71a939e733ec0f4ca0 100644 --- a/hercules_cg/src/cpu_beta.rs +++ b/hercules_cg/src/cpu_beta.rs @@ -726,6 +726,25 @@ fn emit_llvm_for_node( virtual_register(right), ); } + Node::Ternary { + first, + second, + third, + op, + } => { + let opcode = match op { + TernaryOperator::Select => "select", + }; + + llvm_bbs.get_mut(&bb[id.idx()]).unwrap().data += &format!( + " {} = {} {}, {}, {}\n", + virtual_register(id), + opcode, + normal_value(first), + normal_value(second), + normal_value(third), + ); + } Node::Read { collect, ref indices, diff --git a/hercules_ir/src/build.rs b/hercules_ir/src/build.rs index 3c37ecb86d59e87f08a8b2a8f5caaa80152ba8f3..4345f7125f7da274d7f9f19d1d0709661dde6a29 100644 --- a/hercules_ir/src/build.rs +++ b/hercules_ir/src/build.rs @@ -498,6 +498,21 @@ impl NodeBuilder { self.node = Node::Binary { left, right, op }; } + pub fn build_ternary( + &mut self, + first: NodeID, + second: NodeID, + third: NodeID, + op: TernaryOperator, + ) { + self.node = Node::Ternary { + first, + second, + third, + op, + }; + } + pub fn build_call( &mut self, function: FunctionID, diff --git a/hercules_ir/src/def_use.rs b/hercules_ir/src/def_use.rs index a8fd5c17f6799887364f3cf3ef62b36e6e1953e1..0116bb4e841d891f8e7609941b3271e91cde3ca5 100644 --- a/hercules_ir/src/def_use.rs +++ b/hercules_ir/src/def_use.rs @@ -152,6 +152,12 @@ pub fn get_uses<'a>(node: &'a Node) -> NodeUses<'a> { Node::DynamicConstant { id: _ } => NodeUses::One([NodeID::new(0)]), Node::Unary { input, op: _ } => NodeUses::One([*input]), Node::Binary { left, right, op: _ } => NodeUses::Two([*left, *right]), + Node::Ternary { + first, + second, + third, + op: _, + } => NodeUses::Three([*first, *second, *third]), Node::Call { function: _, dynamic_constants: _, @@ -227,6 +233,12 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> { Node::DynamicConstant { id: _ } => NodeUsesMut::Zero, Node::Unary { input, op: _ } => NodeUsesMut::One([input]), Node::Binary { left, right, op: _ } => NodeUsesMut::Two([left, right]), + Node::Ternary { + first, + second, + third, + op: _, + } => NodeUsesMut::Three([first, second, third]), Node::Call { function: _, dynamic_constants: _, diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 8e5997bf2aa90f242b4452a2f112181f54c0539e..b3df19d9e3603de1004ed1815c71b3a15c98d588 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -199,6 +199,12 @@ pub enum Node { right: NodeID, op: BinaryOperator, }, + Ternary { + first: NodeID, + second: NodeID, + third: NodeID, + op: TernaryOperator, + }, Call { function: FunctionID, dynamic_constants: Box<[DynamicConstantID]>, @@ -241,6 +247,11 @@ pub enum BinaryOperator { RSh, } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum TernaryOperator { + Select, +} + impl Module { /* * There are many transformations that need to iterate over the functions @@ -889,6 +900,12 @@ impl Node { right: _, op, } => op.upper_case_name(), + Node::Ternary { + first: _, + second: _, + third: _, + op, + } => op.upper_case_name(), Node::Call { function: _, dynamic_constants: _, @@ -943,6 +960,12 @@ impl Node { right: _, op, } => op.lower_case_name(), + Node::Ternary { + first: _, + second: _, + third: _, + op, + } => op.lower_case_name(), Node::Call { function: _, dynamic_constants: _, @@ -1037,6 +1060,20 @@ impl BinaryOperator { } } +impl TernaryOperator { + pub fn upper_case_name(&self) -> &'static str { + match self { + TernaryOperator::Select => "Select", + } + } + + pub fn lower_case_name(&self) -> &'static str { + match self { + TernaryOperator::Select => "select", + } + } +} + /* * Rust things to make newtyped IDs usable. */ diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index a00da548fa8bfd668445a504afdc93493eb2fcbe..63107c81d863934f101e859e7985b6585baee1e1 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -308,6 +308,7 @@ fn parse_node<'a>( "xor" => parse_binary(ir_text, context, BinaryOperator::Xor)?, "lsh" => parse_binary(ir_text, context, BinaryOperator::LSh)?, "rsh" => parse_binary(ir_text, context, BinaryOperator::RSh)?, + "select" => parse_ternary(ir_text, context, TernaryOperator::Select)?, "call" => parse_call(ir_text, context)?, "read" => parse_read(ir_text, context)?, "write" => parse_write(ir_text, context)?, @@ -485,6 +486,27 @@ fn parse_binary<'a>( Ok((ir_text, Node::Binary { left, right, op })) } +fn parse_ternary<'a>( + ir_text: &'a str, + context: &RefCell<Context<'a>>, + op: TernaryOperator, +) -> nom::IResult<&'a str, Node> { + let (ir_text, (first, second, third)) = + parse_tuple3(parse_identifier, parse_identifier, parse_identifier)(ir_text)?; + let first = context.borrow_mut().get_node_id(first); + let second = context.borrow_mut().get_node_id(second); + let third = context.borrow_mut().get_node_id(third); + Ok(( + ir_text, + Node::Ternary { + first, + second, + third, + op, + }, + )) +} + fn parse_call<'a>(ir_text: &'a str, context: &RefCell<Context<'a>>) -> nom::IResult<&'a str, Node> { // Call nodes are a bit complicated because they 1. optionally take dynamic // constants as "arguments" (though these are specified between <>), 2. diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index eabe45d73395f8ad2cbe2971449ca0f468a40455..ff631bd7d64828cb0519fd709f513736cc41b08b 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -701,6 +701,37 @@ fn typeflow( input_ty.clone() } + Node::Ternary { + first: _, + second: _, + third: _, + op, + } => { + if inputs.len() != 3 { + return Error(String::from("Ternary node must have exactly three inputs.")); + } + + if let Concrete(id) = inputs[0] { + match op { + TernaryOperator::Select => { + if !types[id.idx()].is_bool() { + return Error(String::from( + "Select ternary node input cannot have non-bool condition input.", + )); + } + + let data_ty = TypeSemilattice::meet(inputs[1], inputs[2]); + if let Concrete(data_id) = data_ty { + return Concrete(data_id); + } else { + return data_ty; + } + } + } + } + + Error(String::from("Unhandled ternary types.")) + } Node::Call { function: callee_id, dynamic_constants: dc_args, diff --git a/hercules_opt/src/ccp.rs b/hercules_opt/src/ccp.rs index 1381a2f176f275339fa76557891e38289f3676e4..8999730fcb78065433519e1ed219065ccba53ff9 100644 --- a/hercules_opt/src/ccp.rs +++ b/hercules_opt/src/ccp.rs @@ -665,6 +665,62 @@ fn ccp_flow_function( constant: new_constant, } } + Node::Ternary { + first, + second, + third, + op, + } => { + let CCPLattice { + reachability: ref first_reachability, + constant: ref first_constant, + } = inputs[first.idx()]; + let CCPLattice { + reachability: ref second_reachability, + constant: ref second_constant, + } = inputs[second.idx()]; + let CCPLattice { + reachability: ref third_reachability, + constant: ref third_constant, + } = inputs[third.idx()]; + + let new_constant = if let ( + ConstantLattice::Constant(first_cons), + ConstantLattice::Constant(second_cons), + ConstantLattice::Constant(third_cons), + ) = (first_constant, second_constant, third_constant) + { + let new_cons = match(op, first_cons, second_cons, third_cons) { + (TernaryOperator::Select, Constant::Boolean(first_val), second_val, third_val) => if *first_val {second_val.clone()} else {third_val.clone()}, + _ => panic!("Unsupported combination of ternary operation and constant values. Did typechecking succeed?") + }; + ConstantLattice::Constant(new_cons) + } else if (first_constant.is_top() + && !second_constant.is_bottom() + && !third_constant.is_bottom()) + || (!first_constant.is_bottom() + && second_constant.is_top() + && !first_constant.is_bottom()) + || (!first_constant.is_bottom() + && !second_constant.is_bottom() + && third_constant.is_top()) + { + ConstantLattice::top() + } else { + ConstantLattice::meet( + first_constant, + &ConstantLattice::meet(second_constant, third_constant), + ) + }; + + CCPLattice { + reachability: ReachabilityLattice::meet( + first_reachability, + &ReachabilityLattice::meet(second_reachability, third_reachability), + ), + constant: new_constant, + } + } // Call nodes are uninterpretable. Node::Call { function: _,