diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs index af2420d83a550e7a50ff22962302612f21627995..446231de1677a34e6e6e0be43b380b6737905f5e 100644 --- a/hercules_cg/src/lib.rs +++ b/hercules_cg/src/lib.rs @@ -23,7 +23,7 @@ pub const LARGEST_ALIGNMENT: usize = 32; */ pub fn get_type_alignment(types: &Vec<Type>, ty: TypeID) -> usize { match types[ty.idx()] { - Type::Control => panic!(), + Type::Control | Type::MultiReturn(_) => panic!(), Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => 1, Type::Integer16 | Type::UnsignedInteger16 | Type::BFloat16 => 2, Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => 4, @@ -46,7 +46,7 @@ pub fn get_type_alignment(types: &Vec<Type>, ty: TypeID) -> usize { pub type FunctionNodeColors = ( BTreeMap<NodeID, Device>, Vec<Option<Device>>, - Option<Device>, + Vec<Option<Device>>, ); pub type NodeColors = BTreeMap<FunctionID, FunctionNodeColors>; diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index fb3e6bbda7e50397112078f87d834b76a1ff5695..cc0703abbe6f95ff250d75ed0cad63f4019eaa08 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -96,6 +96,10 @@ impl FunctionCollectionObjects { &self.returned[selection] } + pub fn all_returned_objects(&self) -> impl Iterator<Item = CollectionObjectID> + '_ { + self.returned.iter().flat_map(|colls| colls.iter().map(|c| *c)) + } + pub fn is_mutated(&self, object: CollectionObjectID) -> bool { !self.mutators(object).is_empty() } diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 7a0158fbc3d4dcef22f1f6231ed67d39b84346a8..3d625a3931b5125f2259b11f7833cc281518e881 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -863,6 +863,14 @@ impl Type { } } + pub fn is_multireturn(&self) -> bool { + if let Type::MultiReturn(_) = self { + true + } else { + false + } + } + pub fn is_bool(&self) -> bool { self == &Type::Boolean } diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index a019f4d37d2ebfa645d8c73729e93e49290c13d3..d61ff6e76716dd4ff4243a41deab0930e6a3dfdb 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -248,11 +248,11 @@ fn parse_function<'a>( let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::bytes::complete::tag("->")(ir_text)?.0; let (ir_text, return_types) = nom::multi::separated_list1( - nom::sequence::tuple(( + ( nom::character::complete::multispace0, nom::character::complete::char(','), nom::character::complete::multispace0, - )), + ), |text| parse_type_id(text, context), ).parse(ir_text)?; let (ir_text, nodes) = nom::multi::many1(|x| parse_node(x, context)).parse(ir_text)?; @@ -506,13 +506,13 @@ fn parse_return<'a>( let ir_text = nom::character::complete::char(',')(ir_text)?.0; let ir_text = nom::character::complete::multispace0(ir_text)?.0; let (ir_text, data) = nom::multi::separated_list1( - nom::sequence::tuple(( + ( nom::character::complete::multispace0, nom::character::complete::char(','), nom::character::complete::multispace0, - )), + ), parse_identifier, - )(ir_text)?; + ).parse(ir_text)?; let control = context.borrow_mut().get_node_id(control); let data = data .into_iter() diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index f240589300520d680c1e126ba93728977f1d17e8..c2ec4e9498dabab01e3041a4e502ca426f36a70f 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -127,7 +127,7 @@ pub fn gcm( let mut alignments = vec![]; Ref::map(editor.get_types(), |types| { for idx in 0..types.len() { - if types[idx].is_control() { + if types[idx].is_control() || types[idx].is_multireturn() { alignments.push(0); } else { alignments.push(get_type_alignment(types, TypeID::new(idx))); @@ -255,6 +255,15 @@ fn basic_blocks( dynamic_constants: _, args: _, } => bbs[idx] = Some(control), + Node::DataProjection { + data, + selection: _, + } => { + let Node::Call { control, .. } = function.nodes[data.idx()] else { + panic!(); + }; + bbs[idx] = Some(control); + } Node::Parameter { index: _ } => bbs[idx] = Some(NodeID::new(0)), _ if function.nodes[idx].is_control() => bbs[idx] = Some(NodeID::new(idx)), _ => {} @@ -508,7 +517,7 @@ fn basic_blocks( && objects[&func_id] .objects(id) .into_iter() - .any(|obj| objects[&func_id].returned_objects().contains(obj)); + .any(|obj| objects[&func_id].all_returned_objects().any(|ret| ret == *obj)); let old_nest = loops .header_of(location) .map(|header| loops.nesting(header).unwrap()); @@ -646,9 +655,9 @@ fn terminating_reads<'a>( ref args, } => Box::new(args.into_iter().enumerate().filter_map(move |(idx, arg)| { let objects = &objects[&callee]; - let returns = objects.returned_objects(); + let mut returns = objects.all_returned_objects(); let param_obj = objects.param_to_object(idx)?; - if !objects.is_mutated(param_obj) && !returns.contains(¶m_obj) { + if !objects.is_mutated(param_obj) && !returns.any(|ret| ret == param_obj) { Some(*arg) } else { None @@ -692,9 +701,9 @@ fn forwarding_reads<'a>( ref args, } => Box::new(args.into_iter().enumerate().filter_map(move |(idx, arg)| { let objects = &objects[&callee]; - let returns = objects.returned_objects(); + let mut returns = objects.all_returned_objects(); let param_obj = objects.param_to_object(idx)?; - if !objects.is_mutated(param_obj) && returns.contains(¶m_obj) { + if !objects.is_mutated(param_obj) && returns.any(|ret| ret == param_obj) { Some(*arg) } else { None @@ -1218,7 +1227,7 @@ fn color_nodes( let mut func_colors = ( BTreeMap::new(), vec![None; editor.func().param_types.len()], - None, + vec![None; editor.func().return_types.len()], ); // Assigning nodes to devices is tricky due to function calls. Technically, @@ -1320,17 +1329,31 @@ fn color_nodes( equations.push((UTerm::Node(*arg), UTerm::Device(device))); } } + } + Node::DataProjection { + data, + selection, + } => { + let Node::Call { + control: _, + function: callee, + dynamic_constants: _, + ref args, + } = &nodes[data.idx()] else { + panic!() + }; - // If the callee has a definite device for the returned value, - // add an equation for the call node itself. - if let Some(device) = node_colors[&callee].2 { + // If the callee has a definite device for this returned value, + // add an equation for the data projection node itself. + if let Some(device) = node_colors[&callee].2[selection] { equations.push((UTerm::Node(id), UTerm::Device(device))); } - // For any object that may be returned by the callee that - // originates as a parameter in the callee, the device of the - // corresponding argument and call node itself must be equal. - for ret in objects[&callee].returned_objects() { + // For any object that may be returned in this position by the + // callee that originates as a parameter in the callee, the + // device of the corresponding argument and the data projection + // must be equal. + for ret in objects[&callee].returned_objects(selection) { if let Some(idx) = objects[&callee].origin(*ret).try_parameter() { equations.push((UTerm::Node(args[idx]), UTerm::Node(id))); } @@ -1365,11 +1388,13 @@ fn color_nodes( { assert!(func_colors.1[index].is_none(), "PANIC: Found multiple parameter nodes for the same index in GCM. Please just run GVN first."); func_colors.1[index] = Some(*device); - } else if let Node::Return { control: _, data } = nodes[id.idx()] - && let Some(device) = func_colors.0.get(&data) - { - assert!(func_colors.2.is_none(), "PANIC: Found multiple return nodes in GCM. Contact Russel if you see this, it's an easy fix."); - func_colors.2 = Some(*device); + } else if let Node::Return { control: _, ref data } = nodes[id.idx()] { + for (idx, val) in data.iter().enumerate() { + if let Some(device) = func_colors.0.get(val) { + assert!(func_colors.2[idx].is_none(), "PANIC: Found multiple return nodes in GCM. Contact Russel if you see this, it's an easy fix."); + func_colors.2[idx] = Some(*device); + } + } } } Some(func_colors) @@ -1420,6 +1445,7 @@ fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) -> let ty = edit.get_type(ty_id).clone(); let size = match ty { Type::Control => panic!(), + Type::MultiReturn(_) => panic!(), Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => { edit.add_dynamic_constant(DynamicConstant::Constant(1)) } diff --git a/juno_samples/multi_return/src/cpu.sch b/juno_samples/multi_return/src/cpu.sch index 03fb2585cff07dff23d26bbd66296611e5555f46..972405f50bb0f2a6a0574807ff5aeb64decef4fd 100644 --- a/juno_samples/multi_return/src/cpu.sch +++ b/juno_samples/multi_return/src/cpu.sch @@ -4,6 +4,10 @@ dce(*); ip-sroa(*); sroa(*); + +ip-sroa[true](rolling_sum); +sroa[true](rolling_sum, rolling_sum_prod); + dce(*); forkify(*); @@ -29,3 +33,4 @@ ccp(*); dce(*); gcm(*); +xdot[true](*); diff --git a/juno_samples/multi_return/src/multi_return.jn b/juno_samples/multi_return/src/multi_return.jn index a49df91c2179d76242965f088cda93c261ebec28..84bab01542bd63db8220bb8c0133bb609b106d10 100644 --- a/juno_samples/multi_return/src/multi_return.jn +++ b/juno_samples/multi_return/src/multi_return.jn @@ -1,4 +1,4 @@ -fn rolling_sum<t: number, n: usize>(x: t[n]) -> t, t[n + 1] { +fn rolling_sum<t: number, n: usize>(x: t[n]) -> (t, t[n + 1]) { let rolling_sum: t[n + 1]; let sum = 0;