From 78c4ad8319cce45036c4ecf205adbd9f895f4e86 Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Fri, 21 Feb 2025 09:45:01 -0600 Subject: [PATCH] Formatting --- hercules_cg/src/cpu.rs | 20 +-- hercules_cg/src/gpu.rs | 93 +++++++---- hercules_cg/src/rt.rs | 150 ++++++++++++------ hercules_ir/src/collections.rs | 40 +++-- hercules_ir/src/parse.rs | 6 +- hercules_opt/src/gcm.rs | 28 ++-- hercules_opt/src/interprocedural_sroa.rs | 37 +++-- hercules_opt/src/sroa.rs | 2 +- .../hercules_interpreter/src/interpreter.rs | 3 +- juno_samples/multi_return/src/main.rs | 18 +-- juno_samples/rodinia/backprop/src/main.rs | 38 ++--- juno_scheduler/src/pm.rs | 30 ++-- 12 files changed, 284 insertions(+), 181 deletions(-) diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs index 1f3ab0a4..6ad38fc0 100644 --- a/hercules_cg/src/cpu.rs +++ b/hercules_cg/src/cpu.rs @@ -82,17 +82,14 @@ impl<'a> CPUContext<'a> { w, "%return.{} = type {{ {} }}\n", self.function.name, - self.function.return_types + self.function + .return_types .iter() .map(|t| self.get_type(*t)) .collect::<Vec<_>>() .join(", "), )?; - write!( - w, - "define dso_local void @{}(", - self.function.name, - )?; + write!(w, "define dso_local void @{}(", self.function.name,)?; } let mut first_param = true; // The first parameter is a pointer to CPU backing memory, if it's @@ -216,7 +213,9 @@ impl<'a> CPUContext<'a> { let mut succs = self.control_subgraph.succs(id); let succ1 = succs.next().unwrap(); let succ2 = succs.next().unwrap(); - let succ1_is_true = self.function.nodes[succ1.idx()].try_control_projection(1).is_some(); + let succ1_is_true = self.function.nodes[succ1.idx()] + .try_control_projection(1) + .is_some(); write!( term, " br {}, label %{}, label %{}\n", @@ -225,7 +224,10 @@ impl<'a> CPUContext<'a> { self.get_block_name(if succ1_is_true { succ2 } else { succ1 }), )? } - Node::Return { control: _, ref data } => { + Node::Return { + control: _, + ref data, + } => { if data.len() == 1 { let ret_data = data[0]; let term = &mut blocks.get_mut(&id).unwrap().term; @@ -1027,7 +1029,7 @@ fn convert_intrinsic(intrinsic: &Intrinsic, ty: &Type) -> String { } else { panic!() } - }, + } Intrinsic::ACos => "acos", Intrinsic::ASin => "asin", Intrinsic::ATan => "atan", diff --git a/hercules_cg/src/gpu.rs b/hercules_cg/src/gpu.rs index 453d33d5..76aba7e0 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -153,12 +153,14 @@ pub fn gpu_codegen<W: Write>( // Tracks for each return value whether it is always the same parameter // collection let return_parameters = (0..function.return_types.len()) - .map(|idx| if collection_objects.returned_objects(idx).len() == 1 { - collection_objects - .origin(*collection_objects.returned_objects(idx).first().unwrap()) - .try_parameter() - } else { - None + .map(|idx| { + if collection_objects.returned_objects(idx).len() == 1 { + collection_objects + .origin(*collection_objects.returned_objects(idx).first().unwrap()) + .try_parameter() + } else { + None + } }) .collect::<Vec<_>>(); @@ -373,12 +375,17 @@ namespace cg = cooperative_groups; } fn codegen_return_struct<W: Write>(&self, w: &mut W) -> Result<(), Error> { - write!(w, "struct return_{} {{ {} }};\n", - self.function.name, - self.function.return_types.iter().enumerate() - .map(|(idx, typ)| format!("{} f{};", self.get_type(*typ, false), idx)) - .collect::<Vec<_>>() - .join(" "), + write!( + w, + "struct return_{} {{ {} }};\n", + self.function.name, + self.function + .return_types + .iter() + .enumerate() + .map(|(idx, typ)| format!("{} f{};", self.get_type(*typ, false), idx)) + .collect::<Vec<_>>() + .join(" "), ) } @@ -425,10 +432,12 @@ namespace cg = cooperative_groups; .return_parameters .iter() .enumerate() - .filter_map(|(idx, param)| if param.is_some() { - None - } else { - Some((idx, self.function.return_types[idx])) + .filter_map(|(idx, param)| { + if param.is_some() { + None + } else { + Some((idx, self.function.return_types[idx])) + } }) .collect::<Vec<(usize, TypeID)>>(); if !ret_fields.is_empty() { @@ -682,22 +691,26 @@ namespace cg = cooperative_groups; // For case of dynamic block count self.codegen_dynamic_constants(w)?; - let (kernel_returns, param_returns) = - self.return_parameters.iter().enumerate() - .fold((vec![], vec![]), - |(mut kernel_returns, mut param_returns), (idx, param)| { - if let Some(param_idx) = param { - param_returns.push((idx, param_idx)); - } else { - kernel_returns.push((idx, self.function.return_types[idx])); - } - (kernel_returns, param_returns) - }); + let (kernel_returns, param_returns) = self.return_parameters.iter().enumerate().fold( + (vec![], vec![]), + |(mut kernel_returns, mut param_returns), (idx, param)| { + if let Some(param_idx) = param { + param_returns.push((idx, param_idx)); + } else { + kernel_returns.push((idx, self.function.return_types[idx])); + } + (kernel_returns, param_returns) + }, + ); if !kernel_returns.is_empty() { // Allocate kernel return struct write!(w, "\treturn_{}* ret_cuda;\n", self.function.name)?; - write!(w, "\tcudaMalloc((void**)&ret_cuda, sizeof(return_{}));\n", self.function.name)?; + write!( + w, + "\tcudaMalloc((void**)&ret_cuda, sizeof(return_{}));\n", + self.function.name + )?; // Add the return pointer to the kernel arguments if !first_param { write!(pass_args, ", ")?; @@ -737,7 +750,8 @@ namespace cg = cooperative_groups; // the parameter returns if !kernel_returns.is_empty() { // Copy from the device directly into the output struct - write!(w, + write!( + w, "\tcudaMemcpy(ret_ptr, ret_cuda, sizeof(return_{}), cudaMemcpyDeviceToHost);\n", self.function.name, )?; @@ -1627,7 +1641,9 @@ namespace cg = cooperative_groups; let mut succs = self.control_subgraph.succs(id); let succ1 = succs.next().unwrap(); let succ2 = succs.next().unwrap(); - let succ1_is_true = self.function.nodes[succ1.idx()].try_control_projection(1).is_some(); + let succ1_is_true = self.function.nodes[succ1.idx()] + .try_control_projection(1) + .is_some(); let succ1_block_name = self.get_block_name(succ1, false); let succ2_block_name = self.get_block_name(succ2, false); write!( @@ -1780,14 +1796,23 @@ namespace cg = cooperative_groups; } tabs } - Node::Return { control: _, ref data } => { + Node::Return { + control: _, + ref data, + } => { write!(w_term, "\tif (grid.thread_rank() == 0) {{\n")?; - for (idx, (data, param)) in data.iter().zip(self.return_parameters.iter()).enumerate() { + for (idx, (data, param)) in + data.iter().zip(self.return_parameters.iter()).enumerate() + { // For return values that are not identical to some parameter, we write it into // the output struct if !param.is_some() { - write!(w_term, "\t\tret->f{} = {};\n", idx, - self.get_value(*data, false, false))?; + write!( + w_term, + "\t\tret->f{} = {};\n", + idx, + self.get_value(*data, false, false) + )?; } } write!(w_term, "\t}}\n")?; diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 19cf29f0..884129c7 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -218,13 +218,16 @@ impl<'a> RTContext<'a> { // Generate the wrapper function for multi-return device functions write!(w, " {{ ")?; // Define the return struct - write!(w, "#[repr(C)] struct ReturnStruct {{ {} }} ", - callee.return_types - .iter() - .enumerate() - .map(|(idx, t)| format!("f{}: {}", idx, self.get_type(*t))) - .collect::<Vec<_>>() - .join(", "), + write!( + w, + "#[repr(C)] struct ReturnStruct {{ {} }} ", + callee + .return_types + .iter() + .enumerate() + .map(|(idx, t)| format!("f{}: {}", idx, self.get_type(*t))) + .collect::<Vec<_>>() + .join(", "), )?; // Declare the extern function's signature write!(w, "extern \"C\" {{ ")?; @@ -234,7 +237,8 @@ impl<'a> RTContext<'a> { write!(w, "let mut ret_struct: ::std::mem::MaybeUninit<ReturnStruct> = ::std::mem::MaybeUninit::uninit();")?; // Call the device function write!(w, "{}(", callee.name)?; - if self.backing_allocations[&callee_id].contains_key(&self.devices[callee_id.idx()]) { + if self.backing_allocations[&callee_id].contains_key(&self.devices[callee_id.idx()]) + { write!(w, "backing, ")?; } for idx in 0..callee.num_dynamic_constants { @@ -246,11 +250,13 @@ impl<'a> RTContext<'a> { write!(w, "ret_struct.as_mut_ptr());")?; // Extract the result into a Rust product write!(w, "let ret_struct = ret_struct.assume_init();")?; - write!(w, "({})", + write!( + w, + "({})", (0..callee.return_types.len()) - .map(|idx| format!("ret_struct.f{}", idx)) - .collect::<Vec<_>>() - .join(", "), + .map(|idx| format!("ret_struct.f{}", idx)) + .collect::<Vec<_>>() + .join(", "), )?; write!(w, "}}")?; } @@ -358,14 +364,19 @@ impl<'a> RTContext<'a> { if succ1_is_true { succ2 } else { succ1 }.idx(), )?; } - Node::Return { control: _, ref data } => { + Node::Return { + control: _, + ref data, + } => { let prologue = &mut blocks.get_mut(&id).unwrap().prologue; write!(prologue, "{} => {{", id.idx())?; let epilogue = &mut blocks.get_mut(&id).unwrap().epilogue; if data.len() == 1 { write!(epilogue, "return {};}}", self.get_value(data[0], id, false))?; } else { - write!(epilogue, "return ({});}}", + write!( + epilogue, + "return ({});}}", data.iter() .map(|v| self.get_value(*v, id, false)) .collect::<Vec<_>>() @@ -684,20 +695,28 @@ impl<'a> RTContext<'a> { } Node::DataProjection { data, selection } => { let block = &mut blocks.get_mut(&bb).unwrap().data; - let Node::Call { function: callee_id, .. } = func.nodes[data.idx()] else { + let Node::Call { + function: callee_id, + .. + } = func.nodes[data.idx()] + else { panic!() }; if self.module.functions[callee_id.idx()].return_types.len() == 1 { assert!(selection == 0); - write!(block, "{} = {};", - self.get_value(id, bb, true), - self.get_value(data, bb, false), + write!( + block, + "{} = {};", + self.get_value(id, bb, true), + self.get_value(data, bb, false), )?; } else { - write!(block, "{} = {}.{};", - self.get_value(id, bb, true), - self.get_value(data, bb, false), - selection, + write!( + block, + "{} = {}.{};", + self.get_value(id, bb, true), + self.get_value(data, bb, false), + selection, )?; } } @@ -1296,11 +1315,14 @@ impl<'a> RTContext<'a> { // its own lifetime. We use lifetime bounds to ensure that the runner // and parameters are borrowed for the lifetimes needed by the outputs let returned_origins: Vec<HashSet<_>> = (0..num_returns) - .map(|idx| objects.returned_objects(idx) - .iter() - .map(|obj| objects.origin(*obj)) - .collect() - ).collect(); + .map(|idx| { + objects + .returned_objects(idx) + .iter() + .map(|obj| objects.origin(*obj)) + .collect() + }) + .collect(); write!(w, "async fn run<'runner:")?; for (ret_idx, origins) in returned_origins.iter().enumerate() { @@ -1314,11 +1336,12 @@ impl<'a> RTContext<'a> { for idx in 0..func.param_types.len() { write!(w, ", 'p{}:", idx)?; for (ret_idx, origins) in returned_origins.iter().enumerate() { - if origins.iter().any(|origin| origin - .try_parameter() - .map(|oidx| idx == oidx) - .unwrap_or(false)) - { + if origins.iter().any(|origin| { + origin + .try_parameter() + .map(|oidx| idx == oidx) + .unwrap_or(false) + }) { write!(w, " 'r{} +", ret_idx)?; } } @@ -1340,18 +1363,19 @@ impl<'a> RTContext<'a> { write!( w, ", p{}: ::hercules_rt::Hercules{}Ref{}<'p{}>", - idx, - device, - mutability, - idx, + idx, device, mutability, idx, )?; } } - write!(w, ") -> {}{}{} {{", + write!( + w, + ") -> {}{}{} {{", if num_returns != 1 { "(" } else { "" }, - func.return_types.iter().enumerate() - .map(|(ret_idx, typ)| - if self.module.types[typ.idx()].is_primitive() { + func.return_types + .iter() + .enumerate() + .map( + |(ret_idx, typ)| if self.module.types[typ.idx()].is_primitive() { self.get_type(*typ) } else { let device = match return_devices[ret_idx] { @@ -1360,8 +1384,10 @@ impl<'a> RTContext<'a> { _ => panic!(), }; let mutability = if return_muts[ret_idx] { "Mut" } else { "" }; - format!("::hercules_rt::Hercules{}Ref{}<'r{}>", - device, mutability, ret_idx) + format!( + "::hercules_rt::Hercules{}Ref{}<'r{}>", + device, mutability, ret_idx + ) } ) .collect::<Vec<_>>() @@ -1535,7 +1561,13 @@ impl<'a> RTContext<'a> { } else if typ.is_float() { "0.0".to_string() } else if let Some(ts) = typ.try_multi_return() { - format!("({})", ts.iter().map(|t| self.get_default_value(*t)).collect::<Vec<_>>().join(", ")) + format!( + "({})", + ts.iter() + .map(|t| self.get_default_value(*t)) + .collect::<Vec<_>>() + .join(", ") + ) } else { "::hercules_rt::__RawPtrSendSync(::core::ptr::null_mut())".to_string() } @@ -1545,7 +1577,9 @@ impl<'a> RTContext<'a> { if tys.len() == 1 { write!(w, "{}", self.get_type(tys[0])) } else { - write!(w, "({})", + write!( + w, + "({})", tys.iter() .map(|t| self.get_type(*t)) .collect::<Vec<_>>() @@ -1558,9 +1592,19 @@ impl<'a> RTContext<'a> { // this means that if the function is multi-return it will return a product in the produced // Rust code // Writes from the "fn" keyword up to the end of the return type - fn write_device_signature_async<W: Write>(&self, w: &mut W, func_id: FunctionID, is_unsafe: bool) -> Result<(), Error> { + fn write_device_signature_async<W: Write>( + &self, + w: &mut W, + func_id: FunctionID, + is_unsafe: bool, + ) -> Result<(), Error> { let func = &self.module.functions[func_id.idx()]; - write!(w, "{}fn {}(", if is_unsafe { "unsafe " } else { "" }, func.name)?; + write!( + w, + "{}fn {}(", + if is_unsafe { "unsafe " } else { "" }, + func.name + )?; let mut first_param = true; if self.backing_allocations[&func_id].contains_key(&self.devices[func_id.idx()]) { first_param = false; @@ -1588,7 +1632,11 @@ impl<'a> RTContext<'a> { // Writes the true signature of a device function // Compared to the _async version this converts multi-return into a return struct - fn write_device_signature<W: Write>(&self, w: &mut W, func_id: FunctionID) -> Result<(), Error> { + fn write_device_signature<W: Write>( + &self, + w: &mut W, + func_id: FunctionID, + ) -> Result<(), Error> { let func = &self.module.functions[func_id.idx()]; write!(w, "fn {}(", func.name)?; let mut first_param = true; @@ -1656,7 +1704,13 @@ fn convert_type(ty: &Type, types: &[Type]) -> String { "::hercules_rt::__RawPtrSendSync".to_string() } Type::MultiReturn(ts) => { - format!("({})", ts.iter().map(|t| convert_type(&types[t.idx()], types)).collect::<Vec<_>>().join(", ")) + format!( + "({})", + ts.iter() + .map(|t| convert_type(&types[t.idx()], types)) + .collect::<Vec<_>>() + .join(", ") + ) } _ => panic!(), } diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index cc0703ab..60f4fb1c 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -97,7 +97,9 @@ impl FunctionCollectionObjects { } pub fn all_returned_objects(&self) -> impl Iterator<Item = CollectionObjectID> + '_ { - self.returned.iter().flat_map(|colls| colls.iter().map(|c| *c)) + self.returned + .iter() + .flat_map(|colls| colls.iter().map(|c| *c)) } pub fn is_mutated(&self, object: CollectionObjectID) -> bool { @@ -187,19 +189,21 @@ pub fn collection_objects( Some(CollectionObjectOrigin::Constant(NodeID::new(idx))) } Node::DataProjection { data, selection } => { - let Node::Call { + let Node::Call { control: _, function: callee, dynamic_constants: _, args: _, - } = func.nodes[data.idx()] else { + } = func.nodes[data.idx()] + else { panic!("Data-projection's data is not a call node"); }; let fco = &collection_objects[&callee]; if fco.returned[*selection] - .iter() - .any(|returned| fco.origins[returned.idx()].try_parameter().is_some()) { + .iter() + .any(|returned| fco.origins[returned.idx()].try_parameter().is_some()) + { // If the callee may return a new collection object, then // this data projection node originates a single collection object. The // node may output multiple collection objects, say if the @@ -283,14 +287,16 @@ pub fn collection_objects( objs: obj.into_iter().collect(), } } - Node::DataProjection { data, selection } - if !types[typing[id.idx()].idx()].is_primitive() => { + Node::DataProjection { data, selection } + if !types[typing[id.idx()].idx()].is_primitive() => + { let Node::Call { control: _, function: callee, dynamic_constants: _, ref args, - } = func.nodes[data.idx()] else { + } = func.nodes[data.idx()] + else { panic!(); }; @@ -299,8 +305,7 @@ pub fn collection_objects( .position(|origin| *origin == CollectionObjectOrigin::DataProjection(id)) .map(CollectionObjectID::new); let fco = &collection_objects[&callee]; - let param_objs = fco - .returned[selection] + let param_objs = fco.returned[selection] .iter() .filter_map(|returned| fco.origins[returned.idx()].try_parameter()) .map(|param_index| &global_input[args[param_index].idx()]); @@ -346,7 +351,8 @@ pub fn collection_objects( .collect(); // Look at the collection objects that each return value may take as input. - let mut returned: Vec<BTreeSet<CollectionObjectID>> = vec![BTreeSet::new(); func.return_types.len()]; + let mut returned: Vec<BTreeSet<CollectionObjectID>> = + vec![BTreeSet::new(); func.return_types.len()]; for node in func.nodes.iter() { if let Node::Return { control: _, data } = node { for (idx, node) in data.iter().enumerate() { @@ -354,7 +360,10 @@ pub fn collection_objects( } } } - let returned = returned.into_iter().map(|set| set.into_iter().collect()).collect(); + let returned = returned + .into_iter() + .map(|set| set.into_iter().collect()) + .collect(); // Determine which objects are potentially mutated. let mut mutated = vec![vec![]; origins.len()]; @@ -523,10 +532,11 @@ pub fn no_reset_constant_collections( collect: _, data, indices: _, - } => { - Either::Left(zip(once(&full_indices), once(data))) + } => Either::Left(zip(once(&full_indices), once(data))), + Node::Return { + control: _, + ref data, } - Node::Return { control: _, ref data } | Node::Call { control: _, function: _, diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index a5a05d0f..9462df4d 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -254,7 +254,8 @@ fn parse_function<'a>( nom::character::complete::multispace0, ), |text| parse_type_id(text, context), - ).parse(ir_text)?; + ) + .parse(ir_text)?; let (ir_text, nodes) = nom::multi::many1(|x| parse_node(x, context)).parse(ir_text)?; // `nodes`, as returned by parsing, is in parse order, which may differ from @@ -512,7 +513,8 @@ fn parse_return<'a>( nom::character::complete::multispace0, ), parse_identifier, - ).parse(ir_text)?; + ) + .parse(ir_text)?; let ir_text = nom::character::complete::multispace0(ir_text)?.0; let ir_text = nom::character::complete::char(')')(ir_text)?.0; let control = context.borrow_mut().get_node_id(control); diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index c2ec4e94..d3119705 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -255,10 +255,7 @@ fn basic_blocks( dynamic_constants: _, args: _, } => bbs[idx] = Some(control), - Node::DataProjection { - data, - selection: _, - } => { + Node::DataProjection { data, selection: _ } => { let Node::Call { control, .. } = function.nodes[data.idx()] else { panic!(); }; @@ -514,10 +511,11 @@ fn basic_blocks( || function.nodes[id.idx()].is_undef()) && !types[typing[id.idx()].idx()].is_primitive(); let is_gpu_returned = devices[func_id.idx()] == Device::CUDA - && objects[&func_id] - .objects(id) - .into_iter() - .any(|obj| objects[&func_id].all_returned_objects().any(|ret| ret == *obj)); + && objects[&func_id].objects(id).into_iter().any(|obj| { + objects[&func_id] + .all_returned_objects() + .any(|ret| ret == *obj) + }); let old_nest = loops .header_of(location) .map(|header| loops.nesting(header).unwrap()); @@ -1330,16 +1328,14 @@ fn color_nodes( } } } - Node::DataProjection { - data, - selection, - } => { + Node::DataProjection { data, selection } => { let Node::Call { control: _, function: callee, dynamic_constants: _, ref args, - } = &nodes[data.idx()] else { + } = &nodes[data.idx()] + else { panic!() }; @@ -1388,7 +1384,11 @@ fn color_nodes( { assert!(func_colors.1[index].is_none(), "PANIC: Found multiple parameter nodes for the same index in GCM. Please just run GVN first."); func_colors.1[index] = Some(*device); - } else if let Node::Return { control: _, ref data } = nodes[id.idx()] { + } else if let Node::Return { + control: _, + ref data, + } = nodes[id.idx()] + { for (idx, val) in data.iter().enumerate() { if let Some(device) = func_colors.0.get(val) { assert!(func_colors.2[idx].is_none(), "PANIC: Found multiple return nodes in GCM. Contact Russel if you see this, it's an easy fix."); diff --git a/hercules_opt/src/interprocedural_sroa.rs b/hercules_opt/src/interprocedural_sroa.rs index 32fa9cc8..ad4ce19e 100644 --- a/hercules_opt/src/interprocedural_sroa.rs +++ b/hercules_opt/src/interprocedural_sroa.rs @@ -29,7 +29,11 @@ pub fn interprocedural_sroa( let callsites = get_callsites(editors); - for ((func_id, apply), callsites) in (0..func_selection.len()).map(FunctionID::new).zip(func_selection.iter()).zip(callsites.into_iter()) { + for ((func_id, apply), callsites) in (0..func_selection.len()) + .map(FunctionID::new) + .zip(func_selection.iter()) + .zip(callsites.into_iter()) + { if !apply { continue; } @@ -62,13 +66,17 @@ pub fn interprocedural_sroa( } // Now, modify each return in the current function and the return type - let return_nodes = editor.func().nodes + let return_nodes = editor + .func() + .nodes .iter() .enumerate() - .filter_map(|(idx, node)| if node.try_return().is_some() { - Some(NodeID::new(idx)) - } else { - None + .filter_map(|(idx, node)| { + if node.try_return().is_some() { + Some(NodeID::new(idx)) + } else { + None + } }) .collect::<Vec<_>>(); let success = editor.edit(|mut edit| { @@ -80,7 +88,9 @@ pub fn interprocedural_sroa( let data = data.to_vec(); let mut new_data = vec![]; - for (idx, (data_id, update_info)) in data.into_iter().zip(old_return_type_map.iter()).enumerate() { + for (idx, (data_id, update_info)) in + data.into_iter().zip(old_return_type_map.iter()).enumerate() + { if let IndexTree::Leaf(new_idx) = update_info { // Unchanged return value assert!(new_data.len() == *new_idx); @@ -128,7 +138,11 @@ pub fn interprocedural_sroa( } } -fn sroa_type(editor: &FunctionEditor, typ: TypeID, type_index: usize) -> (Vec<TypeID>, IndexTree<usize>) { +fn sroa_type( + editor: &FunctionEditor, + typ: TypeID, + type_index: usize, +) -> (Vec<TypeID>, IndexTree<usize>) { match &*editor.get_type(typ) { Type::Product(ts) => { let mut res_types = vec![]; @@ -157,7 +171,8 @@ fn get_callsites(editors: &Vec<FunctionEditor>) -> Vec<Vec<(FunctionID, NodeID)> .nodes .iter() .enumerate() - .filter_map(|(idx, node)| node.try_call().map(|c| (idx, c))) { + .filter_map(|(idx, node)| node.try_call().map(|c| (idx, c))) + { assert!(editor.is_mutable(NodeID::new(callsite)), "IP SROA expects to be able to edit everything, specify what functions to IP SROA via the func_selection argument"); callsites[callee.idx()].push((caller, NodeID::new(callsite))); } @@ -178,9 +193,7 @@ fn replace_returned_value( let constant = generate_constant(editor, proj_typ); let success = editor.edit(|mut edit| { - let mut new_val = edit.add_node(Node::Constant { - id: constant, - }); + let mut new_val = edit.add_node(Node::Constant { id: constant }); of_new_call.for_each(|idx, selection| { let new_proj = edit.add_node(Node::DataProjection { data: call_node, diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index 68a1b25e..e658ff88 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -985,7 +985,7 @@ pub fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> result = Some(generate_reads_edit(&mut edit, typ, val)); Ok(edit) }); - + result.unwrap() } diff --git a/hercules_test/hercules_interpreter/src/interpreter.rs b/hercules_test/hercules_interpreter/src/interpreter.rs index f9d666a5..8a577839 100644 --- a/hercules_test/hercules_interpreter/src/interpreter.rs +++ b/hercules_test/hercules_interpreter/src/interpreter.rs @@ -868,7 +868,8 @@ impl<'a> FunctionExecutionState<'a> { } } Node::Return { control: _, data } => { - let results = data.iter() + let results = data + .iter() .map(|data| self.handle_data(&ctrl_token, *data)) .collect(); break 'outer InterpreterVal::MultiReturn(results); diff --git a/juno_samples/multi_return/src/main.rs b/juno_samples/multi_return/src/main.rs index 0b3508a7..b0fd169f 100644 --- a/juno_samples/multi_return/src/main.rs +++ b/juno_samples/multi_return/src/main.rs @@ -9,19 +9,19 @@ fn main() { let a: Box<[f32]> = (1..=N).map(|i| i as f32).collect(); let arg = HerculesImmBox::from(a.as_ref()); let mut r = runner!(rolling_sum_prod); - let (sums, sum, prods, prod) = async_std::task::block_on(async { r.run(N as u64, arg.to()).await }); + let (sums, sum, prods, prod) = + async_std::task::block_on(async { r.run(N as u64, arg.to()).await }); let mut sums = HerculesMutBox::<f32>::from(sums); let mut prods = HerculesMutBox::<f32>::from(prods); - let (expected_sums, expected_sum) = a.iter() - .fold((vec![0.0], 0.0), |(mut sums, sum), v| { - let new_sum = sum + v; - sums.push(new_sum); - (sums, new_sum) - }); - let (expected_prods, expected_prod) = a.iter() - .fold((vec![1.0], 1.0), |(mut prods, prod), v| { + let (expected_sums, expected_sum) = a.iter().fold((vec![0.0], 0.0), |(mut sums, sum), v| { + let new_sum = sum + v; + sums.push(new_sum); + (sums, new_sum) + }); + let (expected_prods, expected_prod) = + a.iter().fold((vec![1.0], 1.0), |(mut prods, prod), v| { let new_prod = prod * v; prods.push(new_prod); (prods, new_prod) diff --git a/juno_samples/rodinia/backprop/src/main.rs b/juno_samples/rodinia/backprop/src/main.rs index 23f78fe4..fa80a7a5 100644 --- a/juno_samples/rodinia/backprop/src/main.rs +++ b/juno_samples/rodinia/backprop/src/main.rs @@ -37,28 +37,22 @@ fn run_backprop( let mut hidden_prev_weights = HerculesMutBox::from(hidden_prev_weights.to_vec()); let mut runner = runner!(backprop); - let ( - out_err, - hid_err, - input_weights, - input_prev_weights, - hidden_weights, - hidden_prev_weights - ) = async_std::task::block_on(async { - runner - .run( - input_n, - hidden_n, - output_n, - input_vals.to(), - input_weights.to(), - hidden_weights.to(), - target.to(), - input_prev_weights.to(), - hidden_prev_weights.to(), - ) - .await - }); + let (out_err, hid_err, input_weights, input_prev_weights, hidden_weights, hidden_prev_weights) = + async_std::task::block_on(async { + runner + .run( + input_n, + hidden_n, + output_n, + input_vals.to(), + input_weights.to(), + hidden_weights.to(), + target.to(), + input_prev_weights.to(), + hidden_prev_weights.to(), + ) + .await + }); let mut input_weights = HerculesMutBox::from(input_weights); let mut hidden_weights = HerculesMutBox::from(hidden_weights); let mut input_prev_weights = HerculesMutBox::from(input_prev_weights); diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs index 2c288097..84b25811 100644 --- a/juno_scheduler/src/pm.rs +++ b/juno_scheduler/src/pm.rs @@ -2088,15 +2088,16 @@ fn run_pass( None => false, }; - let selection = selection_of_functions(pm, selection) - .ok_or_else(|| { - SchedulerError::PassError { - pass: "xdot".to_string(), - error: "expected coarse-grained selection (can't partially xdot a function)".to_string(), - } + let selection = + selection_of_functions(pm, selection).ok_or_else(|| SchedulerError::PassError { + pass: "xdot".to_string(), + error: "expected coarse-grained selection (can't partially xdot a function)" + .to_string(), })?; let mut bool_selection = vec![false; pm.functions.len()]; - selection.into_iter().for_each(|func| bool_selection[func.idx()] = true); + selection + .into_iter() + .for_each(|func| bool_selection[func.idx()] = true); pm.make_typing(); let typing = pm.typing.take().unwrap(); @@ -2733,15 +2734,16 @@ fn run_pass( None => true, }; - let selection = selection_of_functions(pm, selection) - .ok_or_else(|| { - SchedulerError::PassError { - pass: "xdot".to_string(), - error: "expected coarse-grained selection (can't partially xdot a function)".to_string(), - } + let selection = + selection_of_functions(pm, selection).ok_or_else(|| SchedulerError::PassError { + pass: "xdot".to_string(), + error: "expected coarse-grained selection (can't partially xdot a function)" + .to_string(), })?; let mut bool_selection = vec![false; pm.functions.len()]; - selection.into_iter().for_each(|func| bool_selection[func.idx()] = true); + selection + .into_iter() + .for_each(|func| bool_selection[func.idx()] = true); pm.make_reverse_postorders(); if force_analyses { -- GitLab