From 2b326bd3dc54df360ddd41fb4d6da183944eacf7 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Tue, 18 Feb 2025 16:48:04 -0600 Subject: [PATCH] IR changes for multi-return --- hercules_ir/src/build.rs | 15 +++++--- hercules_ir/src/collections.rs | 14 ++++---- hercules_ir/src/def_use.rs | 22 +++++++++--- hercules_ir/src/ir.rs | 63 +++++++++++++++++++++++++--------- hercules_ir/src/parse.rs | 55 ++++++++++++++++++++++++----- hercules_ir/src/typecheck.rs | 44 ++++++++++++++++++++---- hercules_ir/src/verify.rs | 14 +++++--- 7 files changed, 176 insertions(+), 51 deletions(-) diff --git a/hercules_ir/src/build.rs b/hercules_ir/src/build.rs index 40538cef..3e966d53 100644 --- a/hercules_ir/src/build.rs +++ b/hercules_ir/src/build.rs @@ -392,6 +392,7 @@ impl<'a> Builder<'a> { pub fn create_constant_zero(&mut self, typ: TypeID) -> ConstantID { match &self.module.types[typ.idx()] { Type::Control => panic!("Cannot create constant for control types"), + Type::MultiReturn(..) 
=> panic!("Cannot create constant for multi-return types"), Type::Boolean => self.create_constant_bool(false), Type::Integer8 => self.create_constant_i8(0), Type::Integer16 => self.create_constant_i16(0), @@ -503,7 +504,7 @@ impl<'a> Builder<'a> { &mut self, name: &str, param_types: Vec<TypeID>, - return_type: TypeID, + return_types: Vec<TypeID>, num_dynamic_constants: u32, entry: bool, ) -> BuilderResult<(FunctionID, NodeID)> { @@ -515,7 +516,7 @@ impl<'a> Builder<'a> { self.module.functions.push(Function { name: name.to_owned(), param_types, - return_type, + return_types, num_dynamic_constants, entry, nodes: vec![Node::Start], @@ -594,11 +595,15 @@ impl NodeBuilder { }; } - pub fn build_projection(&mut self, control: NodeID, selection: usize) { - self.node = Node::Projection { control, selection }; + pub fn build_control_projection(&mut self, control: NodeID, selection: usize) { + self.node = Node::ControlProjection { control, selection }; } - pub fn build_return(&mut self, control: NodeID, data: NodeID) { + pub fn build_data_projection(&mut self, data: NodeID, selection: usize) { + self.node = Node::DataProjection { data, selection }; + } + + pub fn build_return(&mut self, control: NodeID, data: Box<[NodeID]>) { self.node = Node::Return { control, data }; } diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index 6b631519..c4e71f8b 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -328,7 +328,9 @@ pub fn collection_objects( let mut returned: BTreeSet<CollectionObjectID> = BTreeSet::new(); for node in func.nodes.iter() { if let Node::Return { control: _, data } = node { - returned.extend(&objects_per_node[data.idx()]); + for node in data { + returned.extend(&objects_per_node[node.idx()]); + } } } let returned = returned.into_iter().collect(); @@ -500,16 +502,16 @@ pub fn no_reset_constant_collections( collect: _, data, indices: _, - } - | Node::Return { control: _, data } => { + } => { 
Either::Left(zip(once(&full_indices), once(data))) } - Node::Call { + Node::Return { control: _, ref data } + | Node::Call { control: _, function: _, dynamic_constants: _, - ref args, - } => Either::Right(zip(repeat(&full_indices), args.into_iter().map(|id| *id))), + args: ref data, + } => Either::Right(zip(repeat(&full_indices), data.into_iter().map(|id| *id))), _ => return None, }; diff --git a/hercules_ir/src/def_use.rs b/hercules_ir/src/def_use.rs index ff0e08ed..a99c8a23 100644 --- a/hercules_ir/src/def_use.rs +++ b/hercules_ir/src/def_use.rs @@ -156,7 +156,11 @@ pub fn get_uses(node: &Node) -> NodeUses { init, reduct, } => NodeUses::Three([*control, *init, *reduct]), - Node::Return { control, data } => NodeUses::Two([*control, *data]), + Node::Return { control, data } => { + let mut uses: Vec<NodeID> = vec![*control]; // control first, same order as get_uses_mut + uses.extend_from_slice(&data[..]); + NodeUses::Variable(uses.into_boxed_slice()) + } Node::Parameter { index: _ } => NodeUses::One([NodeID::new(0)]), Node::Constant { id: _ } => NodeUses::One([NodeID::new(0)]), Node::DynamicConstant { id: _ } => NodeUses::One([NodeID::new(0)]), @@ -222,10 +226,14 @@ pub fn get_uses(node: &Node) -> NodeUses { NodeUses::Two([*collect, *data]) } } - Node::Projection { + Node::ControlProjection { control, selection: _, } => NodeUses::One([*control]), + Node::DataProjection { + data, + selection: _, + } => NodeUses::One([*data]), Node::Undef { ty: _ } => NodeUses::One([NodeID::new(0)]), } } @@ -260,7 +268,9 @@ pub fn get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> { init, reduct, } => NodeUsesMut::Three([control, init, reduct]), - Node::Return { control, data } => NodeUsesMut::Two([control, data]), + Node::Return { control, data } => { + NodeUsesMut::Variable(std::iter::once(control).chain(data.iter_mut()).collect()) + } Node::Parameter { index: _ } => NodeUsesMut::Zero, Node::Constant { id: _ } => NodeUsesMut::Zero, Node::DynamicConstant { id: _ } => NodeUsesMut::Zero, @@ -326,10 +336,14 @@ pub fn 
get_uses_mut<'a>(node: &'a mut Node) -> NodeUsesMut<'a> { NodeUsesMut::Two([collect, data]) } } - Node::Projection { + Node::ControlProjection { control, selection: _, } => NodeUsesMut::One([control]), + Node::DataProjection { + data, + selection: _, + } => NodeUsesMut::One([data]), Node::Undef { ty: _ } => NodeUsesMut::Zero, } } diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index bf9698b3..68fdc26c 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -39,7 +39,7 @@ pub struct Module { pub struct Function { pub name: String, pub param_types: Vec<TypeID>, - pub return_type: TypeID, + pub return_types: Vec<TypeID>, pub num_dynamic_constants: u32, pub entry: bool, @@ -77,6 +77,7 @@ pub enum Type { Product(Box<[TypeID]>), Summation(Box<[TypeID]>), Array(TypeID, Box<[DynamicConstantID]>), + MultiReturn(Box<[TypeID]>), } /* @@ -186,7 +187,7 @@ pub enum Node { }, Return { control: NodeID, - data: NodeID, + data: Box<[NodeID]>, }, Parameter { index: usize, @@ -237,10 +238,14 @@ pub enum Node { data: NodeID, indices: Box<[Index]>, }, - Projection { + ControlProjection { control: NodeID, selection: usize, }, + DataProjection { + data: NodeID, + selection: usize, + }, Undef { ty: TypeID, }, @@ -434,6 +439,17 @@ impl Module { } write!(w, ")") } + Type::MultiReturn(fields) => { + write!(w, "MultiReturn(")?; + for idx in 0..fields.len() { + let field_ty_id = fields[idx]; + self.write_type(field_ty_id, w)?; + if idx + 1 < fields.len() { + write!(w, ", ")?; + } + } + write!(w, ")") + } }?; Ok(()) @@ -1262,12 +1278,19 @@ impl Node { ); define_pattern_predicate!(is_match, Node::Match { control: _, sum: _ }); define_pattern_predicate!( - is_projection, - Node::Projection { + is_control_projection, + Node::ControlProjection { control: _, selection: _ } ); + define_pattern_predicate!( + is_data_projection, + Node::DataProjection { + data: _, + selection: _ + } + ); define_pattern_predicate!(is_undef, Node::Undef { ty: _ }); @@ -1287,8 +1310,8 @@ impl Node { } 
} - pub fn try_proj(&self) -> Option<(NodeID, usize)> { - if let Node::Projection { control, selection } = self { + pub fn try_control_proj(&self) -> Option<(NodeID, usize)> { + if let Node::ControlProjection { control, selection } = self { Some((*control, *selection)) } else { None @@ -1303,9 +1326,9 @@ impl Node { } } - pub fn try_return(&self) -> Option<(NodeID, NodeID)> { + pub fn try_return(&self) -> Option<(NodeID, &[NodeID])> { if let Node::Return { control, data } = self { - Some((*control, *data)) + Some((*control, data)) } else { None } @@ -1479,8 +1502,8 @@ impl Node { } } - pub fn try_projection(&self, branch: usize) -> Option<NodeID> { - if let Node::Projection { control, selection } = self + pub fn try_control_projection(&self, branch: usize) -> Option<NodeID> { + if let Node::ControlProjection { control, selection } = self && branch == *selection { Some(*control) @@ -1560,10 +1583,14 @@ impl Node { data: _, indices: _, } => "Write", - Node::Projection { + Node::ControlProjection { control: _, selection: _, - } => "Projection", + } => "ControlProjection", + Node::DataProjection { + data: _, + selection: _, + } => "DataProjection", Node::Undef { ty: _ } => "Undef", } } @@ -1639,10 +1666,14 @@ impl Node { data: _, indices: _, } => "write", - Node::Projection { + Node::ControlProjection { control: _, selection: _, - } => "projection", + } => "control_projection", + Node::DataProjection { + data: _, + selection: _, + } => "data_projection", Node::Undef { ty: _ } => "undef", } } @@ -1655,7 +1686,7 @@ impl Node { || self.is_fork() || self.is_join() || self.is_return() - || self.is_projection() + || self.is_control_projection() } } diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index f1f4153a..42730f77 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -139,7 +139,7 @@ fn parse_module<'a>(ir_text: &'a str, context: Context<'a>) -> nom::IResult<&'a Function { name: String::from(""), param_types: vec![], - return_type: 
TypeID::new(0), + return_types: vec![], num_dynamic_constants: 0, entry: true, nodes: vec![], @@ -245,7 +245,14 @@ fn parse_function<'a>( let ir_text = nom::character::complete::char(')')(ir_text)?.0; let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::bytes::complete::tag("->")(ir_text)?.0; - let (ir_text, return_type) = parse_type_id(ir_text, context)?; + let (ir_text, return_types) = nom::multi::separated_list1( + nom::sequence::tuple(( + nom::character::complete::multispace0, + nom::character::complete::char(','), + nom::character::complete::multispace0, + )), + |text| parse_type_id(text, context), + )(ir_text)?; let (ir_text, nodes) = nom::multi::many1(|x| parse_node(x, context))(ir_text)?; // `nodes`, as returned by parsing, is in parse order, which may differ from @@ -286,7 +293,7 @@ fn parse_function<'a>( Function { name: String::from(function_name), param_types: params.into_iter().map(|x| x.5).collect(), - return_type, + return_types, num_dynamic_constants, entry: true, nodes: fixed_nodes, @@ -334,7 +341,8 @@ fn parse_node<'a>( "return" => parse_return(ir_text, context)?, "constant" => parse_constant_node(ir_text, context)?, "dynamic_constant" => parse_dynamic_constant_node(ir_text, context)?, - "projection" => parse_projection(ir_text, context)?, + "control_projection" => parse_control_projection(ir_text, context)?, + "data_projection" => parse_data_projection(ir_text, context)?, // Unary and binary ops are spelled out in the textual format, but we // parse them into Unary or Binary node kinds. 
"not" => parse_unary(ir_text, context, UnaryOperator::Not)?, @@ -489,9 +497,24 @@ fn parse_return<'a>( ir_text: &'a str, context: &RefCell<Context<'a>>, ) -> nom::IResult<&'a str, Node> { - let (ir_text, (control, data)) = parse_tuple2(parse_identifier, parse_identifier)(ir_text)?; + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let ir_text = nom::character::complete::char('(')(ir_text)?.0; + let (ir_text, control) = parse_identifier(ir_text)?; + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let ir_text = nom::character::complete::char(',')(ir_text)?.0; + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let (ir_text, data) = nom::multi::separated_list1( + nom::sequence::tuple(( + nom::character::complete::multispace0, + nom::character::complete::char(','), + nom::character::complete::multispace0, + )), + parse_identifier)(ir_text)?; + // The operand list is parenthesized; consume the closing ')'. + let ir_text = nom::character::complete::multispace0(ir_text)?.0; + let ir_text = nom::character::complete::char(')')(ir_text)?.0; let control = context.borrow_mut().get_node_id(control); - let data = context.borrow_mut().get_node_id(data); + let data = data.into_iter().map(|d| context.borrow_mut().get_node_id(d)).collect(); Ok((ir_text, Node::Return { control, data })) } @@ -719,7 +742,7 @@ fn parse_index<'a>( Ok((ir_text, idx)) } -fn parse_projection<'a>( +fn parse_control_projection<'a>( ir_text: &'a str, context: &RefCell<Context<'a>>, ) -> nom::IResult<&'a str, Node> { @@ -728,13 +751,29 @@ fn parse_projection<'a>( let control = context.borrow_mut().get_node_id(control); Ok(( ir_text, - Node::Projection { + Node::ControlProjection { control, selection: index, }, )) } +fn parse_data_projection<'a>( + ir_text: &'a str, + context: &RefCell<Context<'a>>, +) -> nom::IResult<&'a str, Node> { + let parse_usize = |x| parse_prim::<usize>(x, "1234567890"); + let (ir_text, (data, index)) = parse_tuple2(parse_identifier, parse_usize)(ir_text)?; + let data = context.borrow_mut().get_node_id(data); + Ok(( + ir_text, + Node::DataProjection { + data, + selection: index, + }, + )) +} + fn parse_read<'a>(ir_text: &'a 
str, context: &RefCell<Context<'a>>) -> nom::IResult<&'a str, Node> { let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::character::complete::char('(')(ir_text)?.0; @@ -991,7 +1027,8 @@ fn parse_constant<'a>( ) -> nom::IResult<&'a str, Constant> { let (ir_text, constant) = match ty { // There are not control constants. - Type::Control => Err(nom::Err::Error(nom::error::Error { + Type::Control + | Type::MultiReturn(_) => Err(nom::Err::Error(nom::error::Error { input: ir_text, code: nom::error::ErrorKind::IsNot, }))?, diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index 1ff890db..dca11fe7 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -441,12 +441,14 @@ fn typeflow( return inputs[0].clone(); } - if let Concrete(id) = inputs[1] { - if *id != function.return_type { - return Error(String::from("Return node's data input type must be the same as the function's return type.")); + for (idx, (input, return_type)) in inputs[1..].iter().zip(function.return_types.iter()).enumerate() { + if let Concrete(id) = input { + if *id != *return_type { + return Error(format!("Return node's data input at index {} does not match function's return type.", idx)); + } + } else if input.is_error() { + return (*input).clone(); } - } else if inputs[1].is_error() { - return inputs[1].clone(); } Concrete(get_type_id( @@ -759,7 +761,7 @@ fn typeflow( } } - Concrete(subst.type_subst(callee.return_type)) + Concrete(subst.build_return_type(&callee.return_types)) } Node::IntrinsicCall { intrinsic, args: _ } => { let num_params = match intrinsic { @@ -1061,13 +1063,35 @@ fn typeflow( TypeSemilattice::Error(msg) => TypeSemilattice::Error(msg), } } - Node::Projection { + Node::ControlProjection { control: _, selection: _, } => { // Type is the type of the _if node inputs[0].clone() } + Node::DataProjection { + data: _, + selection, + } => { + if let Concrete(type_id) = inputs[0] { + match &types[type_id.idx()] { + 
Type::MultiReturn(types) => { + if *selection >= types.len() { + return Error(String::from("Data projection's selection must be in range of the multi-return being indexed")); + } + return Concrete(types[*selection]); // the selected field's type, not the whole multi-return + } + _ => { + return Error(String::from( + "Data projection node must read from multi-return value.", + )); + } + } + } + + inputs[0].clone() + } Node::Undef { ty } => TypeSemilattice::Concrete(*ty), } } @@ -1138,6 +1162,11 @@ impl<'a> DCSubst<'a> { } } + fn build_return_type(&mut self, tys: &[TypeID]) -> TypeID { + let tys = tys.iter().map(|t| self.type_subst(*t)).collect(); + self.intern_type(Type::MultiReturn(tys)) + } + fn type_subst(&mut self, typ: TypeID) -> TypeID { match &self.types[typ.idx()] { Type::Control @@ -1172,6 +1201,7 @@ impl<'a> DCSubst<'a> { let new_elem = self.type_subst(elem); self.intern_type(Type::Array(new_elem, new_dims)) } + Type::MultiReturn(..) => panic!("A multi-return type should never be substituted"), } } diff --git a/hercules_ir/src/verify.rs b/hercules_ir/src/verify.rs index f188932e..b50ab0d2 100644 --- a/hercules_ir/src/verify.rs +++ b/hercules_ir/src/verify.rs @@ -251,11 +251,11 @@ fn verify_structure( Err(format!("If node must have 2 users, not {}.", users.len()))?; } if let ( - Node::Projection { + Node::ControlProjection { control: _, selection: result1, }, - Node::Projection { + Node::ControlProjection { control: _, selection: result2, }, @@ -290,7 +290,8 @@ fn verify_structure( Err("ThreadID node's control input must be a fork node.")?; } } - // Call nodes must depend on a region node. + // Call nodes must depend on a region node and its only users must + // be DataProjections. 
Node::Call { control, function: _, @@ -300,6 +301,11 @@ fn verify_structure( if !function.nodes[control.idx()].is_region() { Err("Call node's control input must be a region node.")?; } + for user in users { + if !function.nodes[user.idx()].is_data_projection() { + Err("Call node users must be DataProjection nodes.")?; + } + } } // Reduce nodes must depend on a join node. Node::Reduce { @@ -339,7 +345,7 @@ fn verify_structure( } let mut users_covered = bitvec![u8, Lsb0; 0; users.len()]; for user in users { - if let Node::Projection { + if let Node::ControlProjection { control: _, ref selection, } = function.nodes[user.idx()] -- GitLab