diff --git a/Cargo.lock b/Cargo.lock index 48b95114f63276542a350dfef6058abbbc083a1b..e85f4f67dc04e3dc5a367a1750112e16af8b8ef0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -259,12 +259,6 @@ dependencies = [ "arrayvec", ] -[[package]] -name = "base64" -version = "0.21.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" - [[package]] name = "bincode" version = "1.3.3" @@ -291,9 +285,6 @@ name = "bitflags" version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1be3f42a67d6d345ecd59f675f3f012d6974981560836e938c22b424b85ce1be" -dependencies = [ - "serde", -] [[package]] name = "bitstream-io" @@ -367,6 +358,7 @@ name = "call" version = "0.1.0" dependencies = [ "async-std", + "hercules_rt", "juno_build", "rand", "with_builtin_macros", @@ -388,6 +380,7 @@ name = "ccp" version = "0.1.0" dependencies = [ "async-std", + "hercules_rt", "juno_build", "rand", "with_builtin_macros", @@ -639,6 +632,7 @@ version = "0.1.0" dependencies = [ "async-std", "clap", + "hercules_rt", "juno_build", "rand", "with_builtin_macros", @@ -825,17 +819,6 @@ dependencies = [ "serde", ] -[[package]] -name = "hercules_driver" -version = "0.1.0" -dependencies = [ - "clap", - "hercules_ir", - "hercules_opt", - "postcard", - "ron", -] - [[package]] name = "hercules_ir" version = "0.1.0" @@ -1013,6 +996,7 @@ name = "juno_casts_and_intrinsics" version = "0.1.0" dependencies = [ "async-std", + "hercules_rt", "juno_build", "with_builtin_macros", ] @@ -1048,6 +1032,7 @@ dependencies = [ "hercules_ir", "hercules_opt", "juno_scheduler", + "juno_utils", "lrlex", "lrpar", "num-rational", @@ -1087,14 +1072,29 @@ dependencies = [ "with_builtin_macros", ] +[[package]] +name = "juno_schedule_test" +version = "0.1.0" +dependencies = [ + "async-std", + "hercules_rt", + "juno_build", + "rand", + "with_builtin_macros", +] + [[package]] name = "juno_scheduler" version = "0.0.1" dependencies = [ "cfgrammar", + "hercules_cg", "hercules_ir", + "hercules_opt", + "juno_utils", "lrlex", "lrpar", + "tempfile", ] [[package]] @@ -1107,6 +1107,13 @@ dependencies = [ "with_builtin_macros", ] +[[package]] +name = "juno_utils" +version = "0.1.0" +dependencies = [ + "serde", +] + [[package]] name = "kv-log-macro" version = "1.0.7" @@ -1744,18 +1751,6 @@ version = "0.8.50" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a" -[[package]] -name = "ron" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b91f7eff05f748767f183df4320a63d6936e9c6107d97c9e6bdd9784f4289c94" -dependencies = [ - "base64", - "bitflags 2.7.0", - "serde", - "serde_derive", -] - [[package]] name = "rustc_version" version = "0.4.1" diff --git a/Cargo.toml b/Cargo.toml index 4e5826caf264974b1e50e13f0bf91083ad43e097..c7e005fef9dcf76395dbff484ac6fb081503d59b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,8 +6,7 @@ members = [ "hercules_opt", "hercules_rt", - "hercules_tools/hercules_driver", - + "juno_utils", "juno_frontend", "juno_scheduler", "juno_build", @@ -27,7 +26,7 @@ members = [ "juno_samples/nested_ccp", "juno_samples/antideps", "juno_samples/implicit_clone", + "juno_samples/cava", "juno_samples/concat", - - "juno_samples/cava", + "juno_samples/schedule_test", ] diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs index 85139b4c573c2287169715599336bd5b2146e8be..3750c4f6abbac3a774269c729eaded8afcc204c3 100644 --- 
a/hercules_cg/src/cpu.rs +++ b/hercules_cg/src/cpu.rs @@ -638,7 +638,7 @@ impl<'a> CPUContext<'a> { fn codegen_index_math( &self, collect_name: &str, - collect_ty: TypeID, + mut collect_ty: TypeID, indices: &[Index], body: &mut String, ) -> Result<String, Error> { @@ -665,11 +665,16 @@ impl<'a> CPUContext<'a> { get_type_alignment(&self.types, fields[*idx]), body, )?; - acc_ptr = Self::gep(collect_name, &acc_offset, body)?; + acc_ptr = Self::gep(&acc_ptr, &acc_offset, body)?; + collect_ty = fields[*idx]; } - Index::Variant(_) => { + Index::Variant(idx) => { // The tag of a summation is at the end of the summation, so // the variant pointer is just the base pointer; only the type // needs updating. + let Type::Summation(ref variants) = self.types[collect_ty.idx()] else { + panic!() + }; + collect_ty = variants[*idx]; } Index::Position(ref pos) => { let Type::Array(elem, ref dims) = self.types[collect_ty.idx()] else { @@ -690,7 +695,8 @@ impl<'a> CPUContext<'a> { // Convert offset in # elements -> # bytes. acc_offset = Self::multiply(&acc_offset, &elem_size, body)?; - acc_ptr = Self::gep(collect_name, &acc_offset, body)?; + acc_ptr = Self::gep(&acc_ptr, &acc_offset, body)?; + collect_ty = elem; } } } diff --git a/hercules_cg/src/gpu.rs index d960de89a6622beb5c9cbd8f44b66a8c1a52fd0e..1a9b6869324f724c58dd4438a1d7f856d91fb6eb 100644 --- a/hercules_cg/src/gpu.rs +++ b/hercules_cg/src/gpu.rs @@ -614,11 +614,15 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; * maximum over its descendants (leaves have base 1). We traverse up (details * in helper) and pass the factor and a map from fork node to a tuple of * (max quota of its siblings (including itself), its quota, its fork factor) - * from each node to its parents. The parent then compares - * - all three are needed for codegen. A node is in the map IFF it will be parallelized. - * If not, the fork will use the parent's quota and serialize over the Fork's - * ThreadIDs. Nodes may be removed from the map when traversing up the tree - * due to an ancestor having a larger factor that conflicts. + * from each node to its parents. The parent then compares the received quota + * of its subtree vs. just its own. If it's associative, it chooses the larger + * of the two; if not, it can parallelize both if applicable and if they fit. + * + * Finally, the map is returned such that a node is in the map IFF it will + * be parallelized. If not, the fork will use the parent's quota and serialize + * over the Fork's ThreadIDs. Nodes may be removed from the map when traversing + * up the tree due to a conflicting ancestor with a larger factor (the conflict + * arising from associativity or the quota limit). */ fn get_thread_quotas( &self, @@ -629,7 +633,6 @@ extern \"C\" {} {}(", ret_type.clone(), self.function.name)?; (tree_map, tree_quota) } - // Helper function for post-order traversal of fork tree fn recurse_thread_quotas( &self, curr_fork: NodeID, diff --git a/hercules_cg/src/lib.rs index fbab6dbcc7d04329cc5fe341cd2b7c7292dd91bf..ad44ecd0e718cfa633c35dfa6b13645563238c29 100644 --- a/hercules_cg/src/lib.rs +++ b/hercules_cg/src/lib.rs @@ -16,18 +16,6 @@ pub use crate::fork_tree::*; use hercules_ir::*; -/* - * Basic block info consists of two things: - * - * 1. A map from node to block (named by control nodes). - * 2. For each node, which nodes are in its own block. - * - * Note that for #2, the structure is Vec<NodeID>, meaning the nodes are ordered - * inside the block.
This order corresponds to the traversal order of the nodes - * in the block needed by the backend code generators. - */ -pub type BasicBlocks = (Vec<NodeID>, Vec<Vec<NodeID>>); - /* * The alignment of a type does not depend on dynamic constants. */ diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index 6e56ebc48d381108f4483411b4e497bed6aec8ed..d093b2b0f937d796e8d43c707693ee976e673081 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -53,7 +53,7 @@ impl<'a> RTContext<'a> { // Dump the function signature. write!( w, - "#[allow(unused_variables,unused_mut,unused_parens)]\nasync fn {}<'a>(", + "#[allow(unused_variables,unused_mut,unused_parens,unused_unsafe)]\nasync fn {}<'a>(", func.name )?; let mut first_param = true; @@ -149,7 +149,7 @@ impl<'a> RTContext<'a> { // blocks to drive execution. write!( w, - " let mut control_token: i8 = 0;\n loop {{\n match control_token {{\n", + " let mut control_token: i8 = 0;\n let return_value = loop {{\n match control_token {{\n", )?; let mut blocks: BTreeMap<_, _> = (0..func.nodes.len()) @@ -182,8 +182,41 @@ impl<'a> RTContext<'a> { )?; } - // Close the match, loop, and function. - write!(w, " _ => panic!()\n }}\n }}\n}}\n")?; + // Close the match and loop. + write!(w, " _ => panic!()\n }}\n }};\n")?; + + // Emit the epilogue of the function. + write!(w, " unsafe {{\n")?; + for idx in 0..func.param_types.len() { + if !self.module.types[func.param_types[idx].idx()].is_primitive() { + write!(w, " p{}.__forget();\n", idx)?; + } + } + if !self.module.types[func.return_type.idx()].is_primitive() { + for object in self.collection_objects[&self.func_id].iter_objects() { + if let CollectionObjectOrigin::Constant(_) = + self.collection_objects[&self.func_id].origin(object) + { + write!( + w, + " if obj{}.__cmp_ids(&return_value) {{\n", + object.idx() + )?; + write!(w, " obj{}.__forget();\n", object.idx())?; + write!(w, " }}\n")?; + } + } + } + for idx in 0..func.nodes.len() { + if !func.nodes[idx].is_control() + && !self.module.types[self.typing[idx].idx()].is_primitive() + { + write!(w, " node_{}.__forget();\n", idx)?; + } + } + write!(w, " }}\n")?; + write!(w, " return_value\n")?; + write!(w, "}}\n")?; Ok(()) } @@ -230,7 +263,15 @@ impl<'a> RTContext<'a> { } Node::Return { control: _, data } => { let block = &mut blocks.get_mut(&id).unwrap(); - write!(block, " return {};\n", self.get_value(data))? + if self.module.types[self.typing[data.idx()].idx()].is_primitive() { + write!(block, " break {};\n", self.get_value(data))? + } else { + write!( + block, + " break unsafe {{ {}.__clone() }};\n", + self.get_value(data) + )? + } } _ => panic!("PANIC: Can't lower {:?}.", func.nodes[id.idx()]), } @@ -259,7 +300,7 @@ impl<'a> RTContext<'a> { } else { write!( block, - " {} = unsafe {{ p{}.__take() }};\n", + " {} = unsafe {{ p{}.__clone() }};\n", self.get_value(id), index )? @@ -284,7 +325,7 @@ impl<'a> RTContext<'a> { let objects = self.collection_objects[&self.func_id].objects(id); assert_eq!(objects.len(), 1); let object = objects[0]; - write!(block, "unsafe {{ obj{}.__take() }}", object.idx())? + write!(block, "unsafe {{ obj{}.__clone() }}", object.idx())? } } write!(block, ";\n")? 
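For reference, a minimal sketch of the epilogue the code above emits, assuming one non-primitive parameter (p0), one constant-origin collection (obj0), and one non-primitive data node (node_3; the names follow the emitted naming scheme, but the node index is hypothetical):

    let return_value = loop {
        match control_token {
            // ... one arm per basic block ...
            _ => panic!()
        }
    };
    unsafe {
        p0.__forget();
        if obj0.__cmp_ids(&return_value) {
            obj0.__forget();
        }
        node_3.__forget();
    }
    return_value

Every handle except the one broken out of the loop is relinquished via __forget, which appears to suppress cleanup rather than free the underlying memory.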
@@ -374,7 +415,7 @@ impl<'a> RTContext<'a> { )?; write!( block, - " {} = unsafe {{ {}.__take() }};\n", + " {} = unsafe {{ {}.__clone() }};\n", self.get_value(id), self.get_value(*arg) )?; @@ -407,13 +448,84 @@ if self.module.types[self.typing[arg.idx()].idx()].is_primitive() { write!(block, "{}, ", self.get_value(*arg))?; } else { - write!(block, "unsafe {{ {}.__take() }}, ", self.get_value(*arg))?; + write!(block, "unsafe {{ {}.__clone() }}, ", self.get_value(*arg))?; } } write!(block, ").await;\n")?; } } } + Node::Read { + collect, + ref indices, + } => { + let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + let collect_ty = self.typing[collect.idx()]; + let out_size = self.codegen_type_size(self.typing[id.idx()]); + let offset = self.codegen_index_math(collect_ty, indices)?; + write!( + block, + " let mut read_offset_obj = unsafe {{ {}.__clone() }};\n", + self.get_value(collect) + )?; + write!( + block, + " unsafe {{ read_offset_obj.__offset({}, {}) }};\n", + offset, out_size, + )?; + if self.module.types[self.typing[id.idx()].idx()].is_primitive() { + write!( + block, + " {} = unsafe {{ *(read_offset_obj.__cpu_ptr() as *const _) }};\n", + self.get_value(id) + )?; + write!( + block, + " unsafe {{ read_offset_obj.__forget() }};\n", + )?; + } else { + write!( + block, + " {} = read_offset_obj;\n", + self.get_value(id) + )?; + } + } + Node::Write { + collect, + data, + ref indices, + } => { + let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap(); + let collect_ty = self.typing[collect.idx()]; + let data_size = self.codegen_type_size(self.typing[data.idx()]); + let offset = self.codegen_index_math(collect_ty, indices)?; + write!( + block, + " let mut write_offset_obj = unsafe {{ {}.__clone() }};\n", + self.get_value(collect) + )?; + write!(block, " let write_offset_ptr = unsafe {{ write_offset_obj.__cpu_ptr_mut().byte_add({}) }};\n", offset)?; + if self.module.types[self.typing[data.idx()].idx()].is_primitive() { + write!( + block, + " unsafe {{ *(write_offset_ptr as *mut _) = {} }};\n", + self.get_value(data) + )?; + } else { + write!( + block, + " unsafe {{ ::core::ptr::copy_nonoverlapping({}.__cpu_ptr(), write_offset_ptr as *mut _, {} as usize) }};\n", + self.get_value(data), + data_size, + )?; + } + write!( + block, + " {} = write_offset_obj;\n", + self.get_value(id), + )?; + } _ => panic!( "PANIC: Can't lower {:?} in {}.", func.nodes[id.idx()], @@ -487,6 +599,78 @@ Ok(()) } + /* + * Emit logic to index into a collection. + */ + fn codegen_index_math( + &self, + mut collect_ty: TypeID, + indices: &[Index], + ) -> Result<String, Error> { + let mut acc_offset = "0".to_string(); + for index in indices { + match index { + Index::Field(idx) => { + let Type::Product(ref fields) = self.module.types[collect_ty.idx()] else { + panic!() + }; + + // Get the offset of the field at index `idx` by calculating + // the product's size up to field `idx`, then offsetting the + // base pointer by that amount.
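+ // For example, indexing field 1 of a product (i32, i64), assuming the
+ // usual four- and eight-byte alignments, first skips the i32:
+ // (((0 + 3) & !3) + 4) = 4, then aligns up for the i64: ((4 + 7) & !7) = 8.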
+ for field in &fields[..*idx] { + let field_align = get_type_alignment(&self.module.types, *field); + let field = self.codegen_type_size(*field); + acc_offset = format!( + "((({} + {}) & !{}) + {})", + acc_offset, + field_align - 1, + field_align - 1, + field + ); + } + let last_align = get_type_alignment(&self.module.types, fields[*idx]); + acc_offset = format!( + "(({} + {}) & !{})", + acc_offset, + last_align - 1, + last_align - 1 + ); + collect_ty = fields[*idx]; + } + Index::Variant(idx) => { + // The tag of a summation is at the end of the summation, so + // the variant pointer is just the base pointer; only the type + // needs updating. + let Type::Summation(ref variants) = self.module.types[collect_ty.idx()] else { + panic!() + }; + collect_ty = variants[*idx]; + } + Index::Position(ref pos) => { + let Type::Array(elem, ref dims) = self.module.types[collect_ty.idx()] else { + panic!() + }; + + // The offset of the position into an array is: + // + // ((0 * s1 + p1) * s2 + p2) * s3 + p3 ... + let elem_size = self.codegen_type_size(elem); + for (p, s) in zip(pos, dims) { + let p = self.get_value(*p); + acc_offset = format!("{} * ", acc_offset); + self.codegen_dynamic_constant(*s, &mut acc_offset)?; + acc_offset = format!("({} + {})", acc_offset, p); + } + + // Convert offset in # elements -> # bytes. + acc_offset = format!("({} * {})", acc_offset, elem_size); + collect_ty = elem; + } + } + } + Ok(acc_offset) + } + /* * Lower the size of a type into a Rust expression. */ diff --git a/hercules_ir/src/build.rs index 78c4eca4c319690254facec7f73893673cc39e0a..1dd326c3ad1abf24e7bfa4aa1f28dfb8255af0e9 100644 --- a/hercules_ir/src/build.rs +++ b/hercules_ir/src/build.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use crate::*; @@ -11,10 +11,11 @@ pub struct Builder<'a> { // Intern function names. function_ids: HashMap<&'a str, FunctionID>, - // Intern types, constants, and dynamic constants on a per-module basis. + // Intern types, constants, dynamic constants, and labels on a per-module basis. interned_types: HashMap<Type, TypeID>, interned_constants: HashMap<Constant, ConstantID>, interned_dynamic_constants: HashMap<DynamicConstant, DynamicConstantID>, + interned_labels: HashMap<String, LabelID>, // For product, summation, and array constant creation, it's useful to know // the type of each constant.
@@ -37,6 +38,7 @@ pub struct NodeBuilder { function_id: FunctionID, node: Node, schedules: Vec<Schedule>, + labels: Vec<LabelID>, } /* @@ -79,6 +81,17 @@ impl<'a> Builder<'a> { } } + pub fn add_label(&mut self, label: &String) -> LabelID { + if let Some(id) = self.interned_labels.get(label) { + *id + } else { + let id = LabelID::new(self.interned_labels.len()); + self.interned_labels.insert(label.clone(), id); + self.module.labels.push(label.clone()); + id + } + } + pub fn create() -> Self { Self::default() } @@ -452,6 +465,10 @@ impl<'a> Builder<'a> { Index::Position(idx) } + pub fn get_labels(&self, func: FunctionID, node: NodeID) -> &HashSet<LabelID> { + &self.module.functions[func.idx()].labels[node.idx()] + } + pub fn create_function( &mut self, name: &str, @@ -473,6 +490,7 @@ impl<'a> Builder<'a> { entry, nodes: vec![Node::Start], schedules: vec![vec![]], + labels: vec![HashSet::new()], device: None, }); Ok((id, NodeID::new(0))) @@ -484,11 +502,15 @@ impl<'a> Builder<'a> { .nodes .push(Node::Start); self.module.functions[function.idx()].schedules.push(vec![]); + self.module.functions[function.idx()] + .labels + .push(HashSet::new()); NodeBuilder { id, function_id: function, node: Node::Start, schedules: vec![], + labels: vec![], } } @@ -499,6 +521,8 @@ impl<'a> Builder<'a> { self.module.functions[builder.function_id.idx()].nodes[builder.id.idx()] = builder.node; self.module.functions[builder.function_id.idx()].schedules[builder.id.idx()] = builder.schedules; + self.module.functions[builder.function_id.idx()].labels[builder.id.idx()] = + builder.labels.into_iter().collect(); Ok(()) } } @@ -617,4 +641,15 @@ impl NodeBuilder { pub fn add_schedule(&mut self, schedule: Schedule) { self.schedules.push(schedule); } + + pub fn add_label(&mut self, label: LabelID) { + self.labels.push(label); + } + + pub fn add_labels<I>(&mut self, labels: I) + where + I: Iterator<Item = LabelID>, + { + self.labels.extend(labels); + } } diff --git a/hercules_ir/src/callgraph.rs b/hercules_ir/src/callgraph.rs index 3a8e6316f8213b2c665e2ad34034d070a495c0cb..834cbbf811193b516a7a4134b8f9aca6d75dcc26 100644 --- a/hercules_ir/src/callgraph.rs +++ b/hercules_ir/src/callgraph.rs @@ -79,10 +79,9 @@ impl CallGraph { /* * Top level function to calculate the call graph of a Hercules module. */ -pub fn callgraph(module: &Module) -> CallGraph { +pub fn callgraph(functions: &Vec<Function>) -> CallGraph { // Step 1: collect the functions called in each function. - let callee_functions: Vec<Vec<FunctionID>> = module - .functions + let callee_functions: Vec<Vec<FunctionID>> = functions .iter() .map(|func| { let mut called: Vec<_> = func diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index 8bb1b359fdbf27c44baaec5ac129419abb066331..9f421221fce6387e2a65cd11d9f5043120b88bfe 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -135,7 +135,8 @@ impl Semilattice for CollectionObjectLattice { * Top level function to analyze collection objects in a Hercules module. 
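 * (The functions and types are passed as separate references rather than as a
 * whole &Module, presumably so that callers can hold mutable borrows of other
 * module fields at the same time; typecheck is split the same way for the
 * stated reason that borrows must not overlap.)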
*/ pub fn collection_objects( - module: &Module, + functions: &Vec<Function>, + types: &Vec<Type>, reverse_postorders: &Vec<Vec<NodeID>>, typing: &ModuleTyping, callgraph: &CallGraph, @@ -146,7 +147,7 @@ pub fn collection_objects( let topo = callgraph.topo(); for func_id in topo { - let func = &module.functions[func_id.idx()]; + let func = &functions[func_id.idx()]; let typing = &typing[func_id.idx()]; let reverse_postorder = &reverse_postorders[func_id.idx()]; @@ -156,14 +157,14 @@ pub fn collection_objects( .param_types .iter() .enumerate() - .filter(|(_, ty_id)| !module.types[ty_id.idx()].is_primitive()) + .filter(|(_, ty_id)| !types[ty_id.idx()].is_primitive()) .map(|(idx, _)| CollectionObjectOrigin::Parameter(idx)); let other_origins = func .nodes .iter() .enumerate() .filter_map(|(idx, node)| match node { - Node::Constant { id: _ } if !module.types[typing[idx].idx()].is_primitive() => { + Node::Constant { id: _ } if !types[typing[idx].idx()].is_primitive() => { Some(CollectionObjectOrigin::Constant(NodeID::new(idx))) } Node::Call { @@ -185,7 +186,7 @@ pub fn collection_objects( // this is determined later. Some(CollectionObjectOrigin::Call(NodeID::new(idx))) } - Node::Undef { ty: _ } if !module.types[typing[idx].idx()].is_primitive() => { + Node::Undef { ty: _ } if !types[typing[idx].idx()].is_primitive() => { Some(CollectionObjectOrigin::Undef(NodeID::new(idx))) } _ => None, @@ -255,7 +256,7 @@ pub fn collection_objects( function: callee, dynamic_constants: _, args: _, - } if !module.types[typing[id.idx()].idx()].is_primitive() => { + } if !types[typing[id.idx()].idx()].is_primitive() => { let new_obj = origins .iter() .position(|origin| *origin == CollectionObjectOrigin::Call(id)) @@ -285,7 +286,7 @@ pub fn collection_objects( Node::Read { collect: _, indices: _, - } if !module.types[typing[id.idx()].idx()].is_primitive() => inputs[0].clone(), + } if !types[typing[id.idx()].idx()].is_primitive() => inputs[0].clone(), Node::Write { collect: _, data: _, diff --git a/hercules_ir/src/dot.rs b/hercules_ir/src/dot.rs index 4d526366808e8b2aea39fecd81f6c00269ffb154..22cd0beb7e5b6946c1116422441a0777f21f064b 100644 --- a/hercules_ir/src/dot.rs +++ b/hercules_ir/src/dot.rs @@ -18,6 +18,7 @@ pub fn xdot_module( reverse_postorders: &Vec<Vec<NodeID>>, doms: Option<&Vec<DomTree>>, fork_join_maps: Option<&Vec<HashMap<NodeID, NodeID>>>, + bbs: Option<&Vec<BasicBlocks>>, ) { let mut tmp_path = temp_dir(); let mut rng = rand::thread_rng(); @@ -30,6 +31,7 @@ pub fn xdot_module( &reverse_postorders, doms, fork_join_maps, + bbs, &mut contents, ) .expect("PANIC: Unable to generate output file contents."); @@ -51,6 +53,7 @@ pub fn write_dot<W: Write>( reverse_postorders: &Vec<Vec<NodeID>>, doms: Option<&Vec<DomTree>>, fork_join_maps: Option<&Vec<HashMap<NodeID, NodeID>>>, + bbs: Option<&Vec<BasicBlocks>>, w: &mut W, ) -> std::fmt::Result { write_digraph_header(w)?; @@ -165,6 +168,26 @@ pub fn write_dot<W: Write>( } } + // Step 4: draw basic block edges in indigo. 
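+ // An edge is drawn from each node to the control node naming its block;
+ // control nodes name their own blocks (idx == bb.idx()) and are skipped.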
+ if let Some(bbs) = bbs { + let bbs = &bbs[function_id.idx()].0; + for (idx, bb) in bbs.into_iter().enumerate() { + if idx != bb.idx() { + write_edge( + NodeID::new(idx), + function_id, + *bb, + function_id, + true, + "indigo", + "dotted", + &module, + w, + )?; + } + } + } + write_graph_footer(w)?; } @@ -196,7 +219,13 @@ fn write_subgraph_header<W: Write>( } else { write!(w, "label=\"{}\"\n", function.name)?; } - write!(w, "bgcolor=ivory4\n")?; + let color = match function.device { + Some(Device::LLVM) => "paleturquoise1", + Some(Device::CUDA) => "darkseagreen1", + Some(Device::AsyncRust) => "peachpuff1", + None => "ivory2", + }; + write!(w, "bgcolor={}\n", color)?; write!(w, "cluster=true\n")?; Ok(()) } diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index c16251f7804dce961cade3bc3f416744bfa0307f..bf468de69a4314c348ff8ccd1d960cb6359c7d2e 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -1,4 +1,5 @@ use std::cmp::{max, min}; +use std::collections::HashSet; use std::fmt::Write; use std::ops::Coroutine; use std::ops::CoroutineState; @@ -23,6 +24,7 @@ pub struct Module { pub types: Vec<Type>, pub constants: Vec<Constant>, pub dynamic_constants: Vec<DynamicConstant>, + pub labels: Vec<String>, } /* @@ -43,7 +45,8 @@ pub struct Function { pub nodes: Vec<Node>, - pub schedules: FunctionSchedule, + pub schedules: FunctionSchedules, + pub labels: FunctionLabels, pub device: Option<Device>, } @@ -341,7 +344,24 @@ pub enum Device { /* * A single node may have multiple schedules. */ -pub type FunctionSchedule = Vec<Vec<Schedule>>; +pub type FunctionSchedules = Vec<Vec<Schedule>>; + +/* + * A single node may have multiple labels. + */ +pub type FunctionLabels = Vec<HashSet<LabelID>>; + +/* + * Basic block info consists of two things: + * + * 1. A map from node to block (named by control nodes). + * 2. For each node, which nodes are in its own block. + * + * Note that for #2, the structure is Vec<NodeID>, meaning the nodes are ordered + * inside the block. This order corresponds to the traversal order of the nodes + * in the block needed by the backend code generators. + */ +pub type BasicBlocks = (Vec<NodeID>, Vec<Vec<NodeID>>); impl Module { /* @@ -734,6 +754,7 @@ impl Function { // Step 4: update the schedules. self.schedules.fix_gravestones(&node_mapping); + self.labels.fix_gravestones(&node_mapping); node_mapping } @@ -1775,3 +1796,4 @@ define_id_type!(NodeID); define_id_type!(TypeID); define_id_type!(ConstantID); define_id_type!(DynamicConstantID); +define_id_type!(LabelID); diff --git a/hercules_ir/src/loops.rs b/hercules_ir/src/loops.rs index 13e935e0dd151ba3a29c4d07c9f9ee50341d5091..1d706c7834cf30fa3bf5e556d812917942a48d8b 100644 --- a/hercules_ir/src/loops.rs +++ b/hercules_ir/src/loops.rs @@ -152,16 +152,7 @@ pub fn loops( }) .collect(); - // Step 6: compute the inverse loop map - this maps control nodes to which - // loop they are in (keyed by header), if they are in one. - let mut inverse_loops = HashMap::new(); - for (header, (contents, _)) in loops.iter() { - for idx in contents.iter_ones() { - inverse_loops.insert(NodeID::new(idx), *header); - } - } - - // Step 7: compute loop tree nesting. + // Step 6: compute loop tree nesting. 
let mut nesting = HashMap::new(); let mut worklist: VecDeque<NodeID> = loops.keys().map(|id| *id).collect(); while let Some(header) = worklist.pop_front() { @@ -175,6 +166,24 @@ pub fn loops( } } + // Step 7: compute the inverse loop map - this maps control nodes to which + // loop they are in (identified by header), if they are in one. Pick the + // most nested loop as the loop they are in. + let mut inverse_loops = HashMap::new(); + for (header, (contents, _)) in loops.iter() { + for idx in contents.iter_ones() { + let id = NodeID::new(idx); + if let Some(old_header) = inverse_loops.get(&id) + && nesting[old_header] > nesting[header] + { + // If the inserted header is more deeply nested, don't do anything. + assert!(nesting[old_header] != nesting[header] || old_header == header); + } else { + inverse_loops.insert(id, *header); + } + } + } + LoopTree { root, loops, diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index 21eb325a530907bec0aa1f34788708da76453a54..cdad54f935afd7b79eaf258b8ec0a83415ecb5ef 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -1,5 +1,5 @@ use std::cell::RefCell; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::str::FromStr; use crate::*; @@ -37,6 +37,7 @@ struct Context<'a> { interned_constants: HashMap<Constant, ConstantID>, interned_dynamic_constants: HashMap<DynamicConstant, DynamicConstantID>, reverse_dynamic_constant_map: HashMap<DynamicConstantID, DynamicConstant>, + interned_labels: HashMap<String, LabelID>, } /* @@ -97,6 +98,16 @@ impl<'a> Context<'a> { id } } + + fn get_label_id(&mut self, label: String) -> LabelID { + if let Some(id) = self.interned_labels.get(&label) { + *id + } else { + let id = LabelID::new(self.interned_labels.len()); + self.interned_labels.insert(label, id); + id + } + } } /* @@ -124,6 +135,7 @@ fn parse_module<'a>(ir_text: &'a str, context: Context<'a>) -> nom::IResult<&'a entry: true, nodes: vec![], schedules: vec![], + labels: vec![], device: None, }; context.function_ids.len() @@ -152,11 +164,16 @@ fn parse_module<'a>(ir_text: &'a str, context: Context<'a>) -> nom::IResult<&'a for (dynamic_constant, id) in context.interned_dynamic_constants { dynamic_constants[id.idx()] = dynamic_constant; } + let mut labels = vec!["".to_string(); context.interned_labels.len()]; + for (label, id) in context.interned_labels { + labels[id.idx()] = label; + } let module = Module { functions: fixed_functions, types, constants, dynamic_constants, + labels, }; Ok((rest, module)) } @@ -225,10 +242,13 @@ fn parse_function<'a>( // `nodes`, as returned by parsing, is in parse order, which may differ from // the order dictated by NodeIDs in the node name intern map. let mut fixed_nodes = vec![Node::Start; context.borrow().node_ids.len()]; + let mut labels = vec![HashSet::new(); fixed_nodes.len()]; for (name, node) in nodes { // We can remove items from the node name intern map now, as the map // will be cleared during the next iteration of parse_function. 
- fixed_nodes[context.borrow_mut().node_ids.remove(name).unwrap().idx()] = node; + let id = context.borrow_mut().node_ids.remove(name).unwrap(); + fixed_nodes[id.idx()] = node; + labels[id.idx()].insert(context.borrow_mut().get_label_id(name.to_string())); } // The nodes removed from node_ids in the previous step are nodes that are @@ -262,6 +282,7 @@ fn parse_function<'a>( entry: true, nodes: fixed_nodes, schedules: vec![vec![]; num_nodes], + labels, device: None, }, )) @@ -1229,7 +1250,7 @@ mod tests { #[test] fn parse_ir1() { - parse( + let module = parse( " fn myfunc(x: i32) -> i32 y = call<0>(add, x, x) @@ -1243,5 +1264,14 @@ fn add<1>(x: i32, y: i32) -> i32 ", ) .unwrap(); + assert_eq!(module.labels.len(), 5); + assert_eq!( + module.functions[0].labels.len(), + module.functions[0].nodes.len() + ); + assert_eq!( + module.functions[1].labels.len(), + module.functions[1].nodes.len() + ); } } diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index 79cbd40399921eceecf4399c31e6dfc2aa1af297..a80dd422128bd3ba2ab6436272943ff1b2deb82f 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -82,19 +82,16 @@ pub type ModuleTyping = Vec<Vec<TypeID>>; * Returns a type for every node in every function. */ pub fn typecheck( - module: &mut Module, + functions: &Vec<Function>, + types: &mut Vec<Type>, + constants: &Vec<Constant>, + dynamic_constants: &mut Vec<DynamicConstant>, reverse_postorders: &Vec<Vec<NodeID>>, ) -> Result<ModuleTyping, String> { // Step 1: assemble a reverse type map. This is needed to get or create the // ID of potentially new types. Break down module into references to // individual elements at this point, so that borrows don't overlap each // other. - let Module { - ref functions, - ref mut types, - ref constants, - ref mut dynamic_constants, - } = module; let mut reverse_type_map: HashMap<Type, TypeID> = types .iter() .enumerate() diff --git a/hercules_ir/src/verify.rs b/hercules_ir/src/verify.rs index 572bb9d11d3aca8efb5bd70b6b18781da83bc0e7..5ee5f1d26850c32e7bd292602185dbb8327167ce 100644 --- a/hercules_ir/src/verify.rs +++ b/hercules_ir/src/verify.rs @@ -29,7 +29,13 @@ pub fn verify( let reverse_postorders: Vec<_> = def_uses.iter().map(reverse_postorder).collect(); // Typecheck the module. - let typing = typecheck(module, &reverse_postorders)?; + let typing = typecheck( + &module.functions, + &mut module.types, + &module.constants, + &mut module.dynamic_constants, + &reverse_postorders, + )?; // Assemble fork join maps for module. let subgraphs: Vec<_> = zip(module.functions.iter(), def_uses.iter()) diff --git a/hercules_opt/src/crc.rs b/hercules_opt/src/crc.rs new file mode 100644 index 0000000000000000000000000000000000000000..be80e61bea8d6522a40b35542bb6b3c39edeb4db --- /dev/null +++ b/hercules_opt/src/crc.rs @@ -0,0 +1,39 @@ +use hercules_ir::*; + +use crate::*; + +/* + * Top level function to collapse read chains in a function. 
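+ *
+ * For example, if y = read(x, .0) and z = read(y, [i]), then z is replaced by
+ * read(x, .0, [i]); the producer read's indices are applied first.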
+ */ +pub fn crc(editor: &mut FunctionEditor) { + let mut changed = true; + while changed { + changed = false; + for id in editor.node_ids() { + if let Node::Read { + collect: lower_collect, + indices: ref lower_indices, + } = editor.func().nodes[id.idx()] + && let Node::Read { + collect: upper_collect, + indices: ref upper_indices, + } = editor.func().nodes[lower_collect.idx()] + { + let collapsed_read = Node::Read { + collect: upper_collect, + indices: upper_indices + .iter() + .chain(lower_indices.iter()) + .map(|idx| idx.clone()) + .collect(), + }; + let success = editor.edit(|mut edit| { + let new_id = edit.add_node(collapsed_read); + let edit = edit.replace_all_uses(id, new_id)?; + edit.delete_node(id) + }); + changed = changed || success; + } + } + } +} diff --git a/hercules_opt/src/dce.rs index 026672a395d783c0abd5257894c4c32335654371..6eec42e59ea102c44e4464ee981b683f988b00d8 100644 --- a/hercules_opt/src/dce.rs +++ b/hercules_opt/src/dce.rs @@ -8,7 +8,7 @@ use crate::*; */ pub fn dce(editor: &mut FunctionEditor) { // Create worklist (starts as all nodes). - let mut worklist: Vec<NodeID> = (0..editor.func().nodes.len()).map(NodeID::new).collect(); + let mut worklist: Vec<NodeID> = editor.node_ids().collect(); while let Some(work) = worklist.pop() { // If a node on the worklist is a start node, it is either *the* start diff --git a/hercules_opt/src/editor.rs index 1318f032e0ac8a02c0375ae2de56311f0d97306a..60745f214d182393fcddac163006532426103199 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -26,6 +26,9 @@ pub struct FunctionEditor<'a> { constants: &'a RefCell<Vec<Constant>>, dynamic_constants: &'a RefCell<Vec<DynamicConstant>>, types: &'a RefCell<Vec<Type>>, + // Keep a RefCell to the string table that tracks labels, so that new labels + // can be added as needed. + labels: &'a RefCell<Vec<String>>, // Most optimizations need def use info, so provide an iteratively updated // mutable version that's automatically updated based on recorded edits. mut_def_use: Vec<HashSet<NodeID>>, @@ -34,6 +37,9 @@ pub struct FunctionEditor<'a> { // are off limits for deletion (equivalently modification) or being replaced // as a use. mutable_nodes: BitVec<u8, Lsb0>, + // Tracks whether this editor has been used to make any edits to the IR of + // this function. + modified: bool, } /* @@ -51,10 +57,13 @@ pub struct FunctionEdit<'a: 'b, 'b> { added_and_updated_nodes: BTreeMap<NodeID, Node>, // Keep track of added and updated schedules. added_and_updated_schedules: BTreeMap<NodeID, Vec<Schedule>>, - // Keep track of added (dynamic) constants and types + // Keep track of added and updated labels. + added_and_updated_labels: BTreeMap<NodeID, HashSet<LabelID>>, + // Keep track of added (dynamic) constants, types, and labels. added_constants: Vec<Constant>, added_dynamic_constants: Vec<DynamicConstant>, added_types: Vec<Type>, + added_labels: Vec<String>, // Compute def-use map entries iteratively.
updated_def_use: BTreeMap<NodeID, HashSet<NodeID>>, updated_param_types: Option<Vec<TypeID>>, @@ -70,6 +79,7 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { constants: &'a RefCell<Vec<Constant>>, dynamic_constants: &'a RefCell<Vec<DynamicConstant>>, types: &'a RefCell<Vec<Type>>, + labels: &'a RefCell<Vec<String>>, def_use: &ImmutableDefUseMap, ) -> Self { let mut_def_use = (0..function.nodes.len()) @@ -89,11 +99,60 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { constants, dynamic_constants, types, + labels, mut_def_use, mutable_nodes, + modified: false, } } + // Constructs an editor, but only makes mutable those nodes that have at + // least one label from the given set. + pub fn new_labeled( + function: &'a mut Function, + function_id: FunctionID, + constants: &'a RefCell<Vec<Constant>>, + dynamic_constants: &'a RefCell<Vec<DynamicConstant>>, + types: &'a RefCell<Vec<Type>>, + labels: &'a RefCell<Vec<String>>, + def_use: &ImmutableDefUseMap, + with_labels: &HashSet<LabelID>, + ) -> Self { + let mut_def_use = (0..function.nodes.len()) + .map(|idx| { + def_use + .get_users(NodeID::new(idx)) + .into_iter() + .map(|x| *x) + .collect() + }) + .collect(); + + let mut mutable_nodes = bitvec![u8, Lsb0; 0; function.nodes.len()]; + // Mark as mutable all nodes that have some label in the with_labels set. + for (idx, labels) in function.labels.iter().enumerate() { + if !labels.is_disjoint(with_labels) { + mutable_nodes.set(idx, true); + } + } + + FunctionEditor { + function, + function_id, + constants, + dynamic_constants, + types, + labels, + mut_def_use, + mutable_nodes, + modified: false, + } + } + + pub fn modified(&self) -> bool { + self.modified + } + pub fn edit<F>(&'b mut self, edit: F) -> bool where F: FnOnce(FunctionEdit<'a, 'b>) -> Result<FunctionEdit<'a, 'b>, FunctionEdit<'a, 'b>>, @@ -105,9 +164,11 @@ added_nodeids: HashSet::new(), added_and_updated_nodes: BTreeMap::new(), added_and_updated_schedules: BTreeMap::new(), + added_and_updated_labels: BTreeMap::new(), added_constants: Vec::new().into(), added_dynamic_constants: Vec::new().into(), added_types: Vec::new().into(), + added_labels: Vec::new().into(), updated_def_use: BTreeMap::new(), updated_param_types: None, updated_return_type: None, @@ -120,17 +181,28 @@ let FunctionEdit { editor, deleted_nodeids, - added_nodeids: _, + added_nodeids, added_and_updated_nodes, added_and_updated_schedules, + added_and_updated_labels, added_constants, added_dynamic_constants, added_types, + added_labels, updated_def_use, updated_param_types, updated_return_type, sub_edits, } = populated_edit; + + // Step 0: determine whether the edit changed the IR by checking if + // any nodes were deleted, added, or updated in any way. + editor.modified |= !deleted_nodeids.is_empty() + || !added_nodeids.is_empty() + || !added_and_updated_nodes.is_empty() + || !added_and_updated_schedules.is_empty() + || !added_and_updated_labels.is_empty(); + // Step 1: update the mutable def use map. for (u, new_users) in updated_def_use { // Go through new def-use entries in order. These are either @@ -160,7 +232,7 @@ } } - // Step 3: add and update schedules. + // Step 3.0: add and update schedules. editor .function .schedules @@ -169,6 +241,15 @@ editor.function.schedules[id.idx()] = schedule; } + // Step 3.1: add and update labels.
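+ // Resize first so that nodes added by this edit have (empty) label sets
+ // before the updated labels are written in.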
+ editor + .function + .labels + .resize(editor.function.nodes.len(), HashSet::new()); + for (id, label) in added_and_updated_labels { + editor.function.labels[id.idx()] = label; + } + // Step 4: delete nodes. This is done using "gravestones", where a // node other than node ID 0 being a start node is considered a // gravestone. @@ -178,8 +259,8 @@ editor.function.nodes[id.idx()] = Node::Start; } - // Step 5: propagate schedules along sub-edit edges. - for (src, dst) in sub_edits { + // Step 5.0: propagate schedules along sub-edit edges. + for (src, dst) in sub_edits.iter() { let mut dst_schedules = take(&mut editor.function.schedules[dst.idx()]); for src_schedule in editor.function.schedules[src.idx()].iter() { if !dst_schedules.contains(src_schedule) { @@ -189,6 +270,32 @@ editor.function.schedules[dst.idx()] = dst_schedules; } + // Step 5.1: update and propagate labels. + editor.labels.borrow_mut().extend(added_labels); + + // We propagate labels in two steps: first along sub-edit edges, and then + // by copying the labels of every deleted node not used in any sub-edit to + // every added node not in any sub-edit. + let mut sources = deleted_nodeids.clone(); + let mut dests = added_nodeids.clone(); + + for (src, dst) in sub_edits { + let mut dst_labels = take(&mut editor.function.labels[dst.idx()]); + dst_labels.extend(editor.function.labels[src.idx()].iter()); + editor.function.labels[dst.idx()] = dst_labels; + + sources.remove(&src); + dests.remove(&dst); + } + + let mut src_labels = HashSet::new(); + for src in sources { + src_labels.extend(editor.function.labels[src.idx()].clone()); + } + for dst in dests { + editor.function.labels[dst.idx()].extend(src_labels.clone()); + } + // Step 6: update the length of mutable_nodes. All added nodes are // mutable. editor @@ -425,7 +532,13 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { if let Some(schedules) = self.added_and_updated_schedules.get_mut(&id) { schedules.push(schedule); } else { - let mut schedules = self.editor.function.schedules[id.idx()].clone(); + let mut schedules = self + .editor + .function + .schedules + .get(id.idx()) + .unwrap_or(&vec![]) + .clone(); if !schedules.contains(&schedule) { schedules.push(schedule); } @@ -446,6 +559,63 @@ } } + pub fn get_label(&self, id: NodeID) -> &HashSet<LabelID> { + // The user may get the labels of a to-be-deleted node. + if let Some(label) = self.added_and_updated_labels.get(&id) { + // Refer to the added or updated labels. + label + } else { + // Refer to the original labels of this node.
+ &self.editor.function.labels[id.idx()] + } + } + + pub fn add_label(mut self, id: NodeID, label: LabelID) -> Result<Self, Self> { + if self.is_mutable(id) { + if let Some(labels) = self.added_and_updated_labels.get_mut(&id) { + labels.insert(label); + } else { + let mut labels = self + .editor + .function + .labels + .get(id.idx()) + .unwrap_or(&HashSet::new()) + .clone(); + labels.insert(label); + self.added_and_updated_labels.insert(id, labels); + } + Ok(self) + } else { + Err(self) + } + } + + // Creates or returns the LabelID for a given label name + pub fn new_label(&mut self, name: String) -> LabelID { + let pos = self + .editor + .labels + .borrow() + .iter() + .chain(self.added_labels.iter()) + .position(|l| *l == name); + if let Some(idx) = pos { + LabelID::new(idx) + } else { + let idx = self.editor.labels.borrow().len() + self.added_labels.len(); + self.added_labels.push(name); + LabelID::new(idx) + } + } + + // Creates an entirely fresh label and returns its LabelID + pub fn fresh_label(&mut self) -> LabelID { + let idx = self.editor.labels.borrow().len() + self.added_labels.len(); + self.added_labels.push(format!("#fresh_{}", idx)); + LabelID::new(idx) + } + pub fn get_users(&self, id: NodeID) -> impl Iterator<Item = NodeID> + '_ { assert!(!self.deleted_nodeids.contains(&id)); if let Some(users) = self.updated_def_use.get(&id) { @@ -671,6 +841,7 @@ fn func(x: i32) -> i32 let constants_ref = RefCell::new(src_module.constants); let dynamic_constants_ref = RefCell::new(src_module.dynamic_constants); let types_ref = RefCell::new(src_module.types); + let labels_ref = RefCell::new(src_module.labels); // Edit the function by replacing the add with a multiply. let mut editor = FunctionEditor::new( func, @@ -678,6 +849,7 @@ fn func(x: i32) -> i32 &constants_ref, &dynamic_constants_ref, &types_ref, + &labels_ref, &def_use(func), ); let success = editor.edit(|mut edit| { diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index 5ea9485d108ea6454d856bf164d990ea5d7895f8..1323d5a05a784e76d4d3b040f014acd216c710c0 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -5,7 +5,6 @@ use bitvec::prelude::*; use either::Either; use union_find::{QuickFindUf, UnionBySize, UnionFind}; -use hercules_cg::*; use hercules_ir::*; use crate::*; @@ -837,19 +836,16 @@ fn liveness_dataflow( liveness.insert(NodeID::new(bb_idx), vec![BTreeSet::new(); insts.len() + 1]); } let mut num_phis_reduces = vec![0; function.nodes.len()]; - let mut reducing = vec![false; function.nodes.len()]; + let mut has_phi = vec![false; function.nodes.len()]; + let mut has_reduce = vec![false; function.nodes.len()]; for (node_idx, bb) in bbs.0.iter().enumerate() { let node = &function.nodes[node_idx]; if node.is_phi() || node.is_reduce() { num_phis_reduces[bb.idx()] += 1; - // Phis and reduces can't be in the same basic block. 
- if node.is_reduce() { - assert!(num_phis_reduces[bb.idx()] == 0 || reducing[bb.idx()]); - reducing[bb.idx()] = true; - } else { - assert!(!reducing[bb.idx()]); - } } + has_phi[bb.idx()] |= node.is_phi(); + has_reduce[bb.idx()] |= node.is_reduce(); + assert!(!has_phi[bb.idx()] || !has_reduce[bb.idx()]); } let is_obj = |id: NodeID| !objects[&func_id].objects(id).is_empty(); @@ -861,7 +857,7 @@ let last_pt = bbs.1[bb.idx()].len(); let old_value = &liveness[&bb][last_pt]; let mut new_value = BTreeSet::new(); - for succ in control_subgraph.succs(*bb).chain(if reducing[bb.idx()] { + for succ in control_subgraph.succs(*bb).chain(if has_reduce[bb.idx()] { Either::Left(once(*bb)) } else { Either::Right(empty()) diff --git a/hercules_opt/src/inline.rs index 54af85828d5e8cc9f0921d2954ce6137dd50d0a3..064e3d73a1d9604ca5b284fe52b8c2e8c5a0339e 100644 --- a/hercules_opt/src/inline.rs +++ b/hercules_opt/src/inline.rs @@ -209,6 +209,11 @@ fn inline_func( for schedule in callee_schedule { edit = edit.add_schedule(add_id, schedule.clone())?; } + // Copy the labels from the callee. + let callee_labels = &called_func.labels[idx]; + for label in callee_labels { + edit = edit.add_label(add_id, *label)?; + } } // Stitch the control use of the inlined start node with the diff --git a/hercules_opt/src/lib.rs index 9935703e839b5a4dd705af03f40a456008fe0f12..0b10bdaef74168d99471632ffdfe08068fd9fb24 100644 --- a/hercules_opt/src/lib.rs +++ b/hercules_opt/src/lib.rs @@ -1,6 +1,7 @@ #![feature(let_chains)] pub mod ccp; +pub mod crc; pub mod dce; pub mod delete_uncalled; pub mod editor; @@ -12,8 +13,8 @@ pub mod gcm; pub mod gvn; pub mod inline; pub mod interprocedural_sroa; +pub mod lift_dc_math; pub mod outline; -pub mod pass; pub mod phi_elim; pub mod pred; pub mod schedule; @@ -23,6 +24,7 @@ pub mod unforkify; pub mod utils; pub use crate::ccp::*; +pub use crate::crc::*; pub use crate::dce::*; pub use crate::delete_uncalled::*; pub use crate::editor::*; @@ -34,8 +36,8 @@ pub use crate::gcm::*; pub use crate::gvn::*; pub use crate::inline::*; pub use crate::interprocedural_sroa::*; +pub use crate::lift_dc_math::*; pub use crate::outline::*; -pub use crate::pass::*; pub use crate::phi_elim::*; pub use crate::pred::*; pub use crate::schedule::*; diff --git a/hercules_opt/src/lift_dc_math.rs new file mode 100644 index 0000000000000000000000000000000000000000..afdb212064d84a0191f87ce366d67b7ea6728fa8 --- /dev/null +++ b/hercules_opt/src/lift_dc_math.rs @@ -0,0 +1,90 @@ +use hercules_ir::ir::*; + +use crate::*; + +/* + * Lift math in IR nodes into dynamic constants. + */ +pub fn lift_dc_math(editor: &mut FunctionEditor) { + // Create worklist (starts as all nodes). + let mut worklist: Vec<NodeID> = editor.node_ids().collect(); + while let Some(work) = worklist.pop() { + // Look for single nodes that can be converted to dynamic constants. + let users: Vec<_> = editor.get_users(work).collect(); + let nodes = &editor.func().nodes; + let dc = match nodes[work.idx()] { + Node::Constant { id } => { + // Why do we need this weird crap? This is due to a limitation + // in Rust's lifetime rules w/ let guards.
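+ // E.g., a u64 constant node with value 8 lifts to
+ // DynamicConstant::Constant(8).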
+ let cons = if let Constant::UnsignedInteger64(cons) = *editor.get_constant(id) { + cons + } else { + continue; + }; + DynamicConstant::Constant(cons as usize) + } + Node::DynamicConstant { id } => { + let Some(cons) = evaluate_dynamic_constant(id, &*editor.get_dynamic_constants()) + else { + continue; + }; + DynamicConstant::Constant(cons) + } + Node::Binary { op, left, right } => { + let (left, right) = if let ( + Node::DynamicConstant { id: left }, + Node::DynamicConstant { id: right }, + ) = (&nodes[left.idx()], &nodes[right.idx()]) + { + (*left, *right) + } else { + continue; + }; + match op { + BinaryOperator::Add => DynamicConstant::Add(left, right), + BinaryOperator::Sub => DynamicConstant::Sub(left, right), + BinaryOperator::Mul => DynamicConstant::Mul(left, right), + BinaryOperator::Div => DynamicConstant::Div(left, right), + BinaryOperator::Rem => DynamicConstant::Rem(left, right), + _ => { + continue; + } + } + } + Node::IntrinsicCall { + intrinsic, + ref args, + } => { + let (left, right) = if args.len() == 2 + && let (Node::DynamicConstant { id: left }, Node::DynamicConstant { id: right }) = + (&nodes[args[0].idx()], &nodes[args[1].idx()]) + { + (*left, *right) + } else { + continue; + }; + match intrinsic { + Intrinsic::Min => DynamicConstant::Min(left, right), + Intrinsic::Max => DynamicConstant::Max(left, right), + _ => { + continue; + } + } + } + _ => { + continue; + } + }; + + // Replace the node with the computed dynamic constant. + let success = editor.edit(|mut edit| { + let dc = edit.add_dynamic_constant(dc); + let node = edit.add_node(Node::DynamicConstant { id: dc }); + edit = edit.replace_all_uses(work, node)?; + edit.delete_node(work) + }); + if success { + worklist.extend(users); + } + } +} diff --git a/hercules_opt/src/outline.rs b/hercules_opt/src/outline.rs index 80f97c7f079ad57bb1edb0929dfe56578eae338c..e59c815da12b505cadc807c4d87e6a2ef913d3fa 100644 --- a/hercules_opt/src/outline.rs +++ b/hercules_opt/src/outline.rs @@ -1,4 +1,4 @@ -use std::collections::{BTreeMap, BTreeSet}; +use std::collections::{BTreeMap, BTreeSet, HashSet}; use std::iter::zip; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -203,6 +203,7 @@ pub fn outline( entry: false, nodes: vec![], schedules: vec![], + labels: vec![], device: None, }; @@ -420,6 +421,13 @@ pub fn outline( outlined.schedules[callee_id.idx()] = edit.get_schedule(*id).clone(); } + // Copy the labels into the new callee. + outlined.labels.resize(outlined.nodes.len(), HashSet::new()); + for id in partition.iter() { + let callee_id = convert_id(*id); + outlined.labels[callee_id.idx()] = edit.get_label(*id).clone(); + } + // Step 3: edit the original function to call the outlined function. let dynamic_constants = (0..edit.get_num_dynamic_constant_params()) .map(|idx| edit.add_dynamic_constant(DynamicConstant::Parameter(idx as usize))) diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs deleted file mode 100644 index 295b8bcbdc96c623a4ef8c2c7e899636e8d3a570..0000000000000000000000000000000000000000 --- a/hercules_opt/src/pass.rs +++ /dev/null @@ -1,1206 +0,0 @@ -use std::cell::RefCell; -use std::collections::{HashMap, HashSet}; -use std::fs::File; -use std::io::Write; -use std::iter::zip; -use std::process::{Command, Stdio}; - -use serde::Deserialize; - -use tempfile::TempDir; - -use hercules_cg::*; -use hercules_ir::*; - -use crate::*; - -/* - * Passes that can be run on a module. 
- */ -#[derive(Debug, Clone, Deserialize)] -pub enum Pass { - DCE, - CCP, - GVN, - PhiElim, - Forkify, - ForkGuardElim, - SLF, - WritePredication, - Predication, - SROA, - Inline, - Outline, - InterproceduralSROA, - DeleteUncalled, - ForkSplit, - Unforkify, - InferSchedules, - GCM, - FloatCollections, - Verify, - // Parameterized over whether analyses that aid visualization are necessary. - // Useful to set to false if displaying a potentially broken module. - Xdot(bool), - // Parameterized over output directory and module name. - Codegen(String, String), - // Parameterized over where to serialize module to. - Serialize(String), -} - -/* - * Manages passes to be run on an IR module. Transparently handles analysis - * requirements for optimizations. - */ -#[derive(Debug, Clone)] -pub struct PassManager { - module: Module, - - // Passes to run. - passes: Vec<Pass>, - - // Cached analysis results. - pub def_uses: Option<Vec<ImmutableDefUseMap>>, - pub reverse_postorders: Option<Vec<Vec<NodeID>>>, - pub typing: Option<ModuleTyping>, - pub control_subgraphs: Option<Vec<Subgraph>>, - pub doms: Option<Vec<DomTree>>, - pub postdoms: Option<Vec<DomTree>>, - pub fork_join_maps: Option<Vec<HashMap<NodeID, NodeID>>>, - pub fork_join_nests: Option<Vec<HashMap<NodeID, Vec<NodeID>>>>, - pub fork_control_maps: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, - pub fork_trees: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, - pub loops: Option<Vec<LoopTree>>, - pub reduce_cycles: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, - pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, - pub bbs: Option<Vec<BasicBlocks>>, - pub collection_objects: Option<CollectionObjects>, - pub callgraph: Option<CallGraph>, -} - -impl PassManager { - pub fn new(module: Module) -> Self { - PassManager { - module, - passes: vec![], - def_uses: None, - reverse_postorders: None, - typing: None, - control_subgraphs: None, - doms: None, - postdoms: None, - fork_join_maps: None, - fork_join_nests: None, - fork_control_maps: None, - fork_trees: None, - loops: None, - reduce_cycles: None, - data_nodes_in_fork_joins: None, - bbs: None, - collection_objects: None, - callgraph: None, - } - } - - pub fn add_pass(&mut self, pass: Pass) { - self.passes.push(pass); - } - - pub fn make_def_uses(&mut self) { - if self.def_uses.is_none() { - self.def_uses = Some(self.module.functions.iter().map(def_use).collect()); - } - } - - pub fn make_reverse_postorders(&mut self) { - if self.reverse_postorders.is_none() { - self.make_def_uses(); - self.reverse_postorders = Some( - self.def_uses - .as_ref() - .unwrap() - .iter() - .map(reverse_postorder) - .collect(), - ); - } - } - - pub fn make_typing(&mut self) { - if self.typing.is_none() { - self.make_reverse_postorders(); - self.typing = Some( - typecheck(&mut self.module, self.reverse_postorders.as_ref().unwrap()).unwrap(), - ); - } - } - - pub fn make_control_subgraphs(&mut self) { - if self.control_subgraphs.is_none() { - self.make_def_uses(); - self.control_subgraphs = Some( - zip(&self.module.functions, self.def_uses.as_ref().unwrap()) - .map(|(function, def_use)| control_subgraph(function, def_use)) - .collect(), - ); - } - } - - pub fn make_doms(&mut self) { - if self.doms.is_none() { - self.make_control_subgraphs(); - self.doms = Some( - self.control_subgraphs - .as_ref() - .unwrap() - .iter() - .map(|subgraph| dominator(subgraph, NodeID::new(0))) - .collect(), - ); - } - } - - pub fn make_postdoms(&mut self) { - if self.postdoms.is_none() { - self.make_control_subgraphs(); 
- self.postdoms = Some( - zip( - self.control_subgraphs.as_ref().unwrap().iter(), - self.module.functions.iter(), - ) - .map(|(subgraph, function)| dominator(subgraph, NodeID::new(function.nodes.len()))) - .collect(), - ); - } - } - - pub fn make_fork_join_maps(&mut self) { - if self.fork_join_maps.is_none() { - self.make_control_subgraphs(); - self.fork_join_maps = Some( - zip( - self.module.functions.iter(), - self.control_subgraphs.as_ref().unwrap().iter(), - ) - .map(|(function, subgraph)| fork_join_map(function, subgraph)) - .collect(), - ); - } - } - - pub fn make_fork_join_nests(&mut self) { - if self.fork_join_nests.is_none() { - self.make_doms(); - self.make_fork_join_maps(); - self.fork_join_nests = Some( - zip( - self.module.functions.iter(), - zip( - self.doms.as_ref().unwrap().iter(), - self.fork_join_maps.as_ref().unwrap().iter(), - ), - ) - .map(|(function, (dom, fork_join_map))| { - compute_fork_join_nesting(function, dom, fork_join_map) - }) - .collect(), - ); - } - } - - pub fn make_fork_control_maps(&mut self) { - if self.fork_control_maps.is_none() { - self.make_fork_join_nests(); - self.fork_control_maps = Some( - self.fork_join_nests.as_ref().unwrap().iter().map(fork_control_map).collect(), - ); - } - } - - pub fn make_fork_trees(&mut self) { - if self.fork_trees.is_none() { - self.make_fork_join_nests(); - self.fork_trees = Some( - zip( - self.module.functions.iter(), - self.fork_join_nests.as_ref().unwrap().iter(), - ) - .map(|(function, fork_join_nesting)| { - fork_tree(function, fork_join_nesting) - }) - .collect(), - ); - } - } - - pub fn make_loops(&mut self) { - if self.loops.is_none() { - self.make_control_subgraphs(); - self.make_doms(); - self.make_fork_join_maps(); - let control_subgraphs = self.control_subgraphs.as_ref().unwrap().iter(); - let doms = self.doms.as_ref().unwrap().iter(); - let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter(); - self.loops = Some( - zip(control_subgraphs, zip(doms, fork_join_maps)) - .map(|(control_subgraph, (dom, fork_join_map))| { - loops(control_subgraph, NodeID::new(0), dom, fork_join_map) - }) - .collect(), - ); - } - } - - pub fn make_reduce_cycles(&mut self) { - if self.reduce_cycles.is_none() { - self.make_def_uses(); - let def_uses = self.def_uses.as_ref().unwrap().iter(); - self.reduce_cycles = Some( - zip(self.module.functions.iter(), def_uses) - .map(|(function, def_use)| reduce_cycles(function, def_use)) - .collect(), - ); - } - } - - pub fn make_data_nodes_in_fork_joins(&mut self) { - if self.data_nodes_in_fork_joins.is_none() { - self.make_def_uses(); - self.make_fork_join_maps(); - self.data_nodes_in_fork_joins = Some( - zip( - self.module.functions.iter(), - zip( - self.def_uses.as_ref().unwrap().iter(), - self.fork_join_maps.as_ref().unwrap().iter(), - ), - ) - .map(|(function, (def_use, fork_join_map))| { - data_nodes_in_fork_joins(function, def_use, fork_join_map) - }) - .collect(), - ); - } - } - - pub fn make_collection_objects(&mut self) { - if self.collection_objects.is_none() { - self.make_reverse_postorders(); - self.make_typing(); - self.make_callgraph(); - let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); - let typing = self.typing.as_ref().unwrap(); - let callgraph = self.callgraph.as_ref().unwrap(); - self.collection_objects = Some(collection_objects( - &self.module, - reverse_postorders, - typing, - callgraph, - )); - } - } - - pub fn make_callgraph(&mut self) { - if self.callgraph.is_none() { - self.callgraph = Some(callgraph(&self.module)); - } - } - - pub fn 
run_passes(&mut self) { - for pass in self.passes.clone().iter() { - match pass { - Pass::DCE => { - self.make_def_uses(); - let def_uses = self.def_uses.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - dce(&mut editor); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::InterproceduralSROA => { - self.make_def_uses(); - self.make_typing(); - - let constants_ref = RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - - let def_uses = self.def_uses.as_ref().unwrap(); - - let mut editors: Vec<_> = self - .module - .functions - .iter_mut() - .enumerate() - .map(|(i, f)| { - FunctionEditor::new( - f, - FunctionID::new(i), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[i], - ) - }) - .collect(); - - interprocedural_sroa(&mut editors); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - for func in self.module.functions.iter_mut() { - func.delete_gravestones(); - } - - self.clear_analyses(); - } - Pass::CCP => { - self.make_def_uses(); - self.make_reverse_postorders(); - let def_uses = self.def_uses.as_ref().unwrap(); - let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - ccp(&mut editor, &reverse_postorders[idx]); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::GVN => { - self.make_def_uses(); - let def_uses = self.def_uses.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - gvn(&mut editor, false); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = 
types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::Forkify => { - self.make_def_uses(); - self.make_loops(); - let def_uses = self.def_uses.as_ref().unwrap(); - let loops = self.loops.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - forkify( - &mut self.module.functions[idx], - &self.module.constants, - &mut self.module.dynamic_constants, - &def_uses[idx], - &loops[idx], - ); - let num_nodes = self.module.functions[idx].nodes.len(); - self.module.functions[idx] - .schedules - .resize(num_nodes, vec![]); - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::PhiElim => { - self.make_def_uses(); - let def_uses = self.def_uses.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - phi_elim(&mut editor); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::ForkGuardElim => { - self.make_def_uses(); - self.make_fork_join_maps(); - let def_uses = self.def_uses.as_ref().unwrap(); - let fork_join_maps = self.fork_join_maps.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - fork_guard_elim( - &mut self.module.functions[idx], - &self.module.constants, - &fork_join_maps[idx], - &def_uses[idx], - ); - let num_nodes = self.module.functions[idx].nodes.len(); - self.module.functions[idx] - .schedules - .resize(num_nodes, vec![]); - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::SLF => { - self.make_def_uses(); - self.make_reverse_postorders(); - self.make_typing(); - let def_uses = self.def_uses.as_ref().unwrap(); - let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); - let typing = self.typing.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - slf(&mut editor, &reverse_postorders[idx], &typing[idx]); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - println!("{}", self.module.functions[idx].name); - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::WritePredication => { - self.make_def_uses(); - let def_uses = self.def_uses.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = 
RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - write_predication(&mut editor); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::Predication => { - self.make_def_uses(); - self.make_typing(); - let def_uses = self.def_uses.as_ref().unwrap(); - let typing = self.typing.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - predication(&mut editor, &typing[idx]); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::SROA => { - self.make_def_uses(); - self.make_reverse_postorders(); - self.make_typing(); - let def_uses = self.def_uses.as_ref().unwrap(); - let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); - let typing = self.typing.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - sroa(&mut editor, &reverse_postorders[idx], &typing[idx]); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::Inline => { - self.make_def_uses(); - self.make_callgraph(); - let def_uses = self.def_uses.as_ref().unwrap(); - let callgraph = self.callgraph.as_ref().unwrap(); - let constants_ref = RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editors: Vec<_> = zip( - self.module.functions.iter_mut().enumerate(), - def_uses.iter(), - ) - .map(|((idx, func), def_use)| { - FunctionEditor::new( - func, - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - def_use, - ) - }) - .collect(); - inline(&mut editors, callgraph); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - for func in self.module.functions.iter_mut() { - func.delete_gravestones(); - } - self.clear_analyses(); - } - Pass::Outline => { - self.make_def_uses(); - let def_uses = 
self.def_uses.as_ref().unwrap(); - let constants_ref = RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let old_num_funcs = self.module.functions.len(); - let mut editors: Vec<_> = zip( - self.module.functions.iter_mut().enumerate(), - def_uses.iter(), - ) - .map(|((idx, func), def_use)| { - FunctionEditor::new( - func, - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - def_use, - ) - }) - .collect(); - for editor in editors.iter_mut() { - collapse_returns(editor); - ensure_between_control_flow(editor); - } - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - self.clear_analyses(); - - self.make_def_uses(); - self.make_typing(); - self.make_control_subgraphs(); - self.make_doms(); - let def_uses = self.def_uses.as_ref().unwrap(); - let typing = self.typing.as_ref().unwrap(); - let control_subgraphs = self.control_subgraphs.as_ref().unwrap(); - let doms = self.doms.as_ref().unwrap(); - let constants_ref = RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editors: Vec<_> = zip( - self.module.functions.iter_mut().enumerate(), - def_uses.iter(), - ) - .map(|((idx, func), def_use)| { - FunctionEditor::new( - func, - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - def_use, - ) - }) - .collect(); - let mut new_funcs = vec![]; - for (idx, editor) in editors.iter_mut().enumerate() { - let new_func_id = FunctionID::new(old_num_funcs + new_funcs.len()); - let new_func = dumb_outline( - editor, - &typing[idx], - &control_subgraphs[idx], - &doms[idx], - new_func_id, - ); - if let Some(new_func) = new_func { - new_funcs.push(new_func); - } - } - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - for func in self.module.functions.iter_mut() { - func.delete_gravestones(); - } - self.module.functions.extend(new_funcs); - self.clear_analyses(); - } - Pass::DeleteUncalled => { - self.make_def_uses(); - self.make_callgraph(); - let def_uses = self.def_uses.as_ref().unwrap(); - let callgraph = self.callgraph.as_ref().unwrap(); - let constants_ref = RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - - // By default in an editor all nodes are mutable, which is desired in this case - // since we are only modifying the IDs of functions that we call. 
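// The per-function editor construction just below is repeated, nearly verbatim, by
// almost every pass in this file. A minimal sketch of the shared shape, assuming the
// same `FunctionEditor::new` signature used throughout this file; `build_editors`
// itself is a hypothetical helper, not part of the actual API:
fn build_editors<'a>(
    functions: &'a mut Vec<Function>,
    constants: &'a RefCell<Vec<Constant>>,
    dynamic_constants: &'a RefCell<Vec<DynamicConstant>>,
    types: &'a RefCell<Vec<Type>>,
    def_uses: &'a [ImmutableDefUseMap],
) -> Vec<FunctionEditor<'a>> {
    zip(functions.iter_mut().enumerate(), def_uses.iter())
        .map(|((idx, func), def_use)| {
            FunctionEditor::new(
                func,
                FunctionID::new(idx),
                constants,
                dynamic_constants,
                types,
                def_use,
            )
        })
        .collect()
}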
- let mut editors: Vec<_> = zip( - self.module.functions.iter_mut().enumerate(), - def_uses.iter(), - ) - .map(|((idx, func), def_use)| { - FunctionEditor::new( - func, - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - def_use, - ) - }) - .collect(); - - let new_idx = delete_uncalled(&mut editors, callgraph); - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - for func in self.module.functions.iter_mut() { - func.delete_gravestones(); - } - - self.fix_deleted_functions(&new_idx); - self.clear_analyses(); - - assert!(self.module.functions.len() > 0, "PANIC: There are no entry functions in the Hercules module being compiled, and they all got deleted by DeleteUncalled. Please mark at least one function as an entry!"); - } - Pass::ForkSplit => { - self.make_def_uses(); - self.make_fork_join_maps(); - self.make_reduce_cycles(); - let def_uses = self.def_uses.as_ref().unwrap(); - let fork_join_maps = self.fork_join_maps.as_ref().unwrap(); - let reduce_cycles = self.reduce_cycles.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - fork_split(&mut editor, &fork_join_maps[idx], &reduce_cycles[idx]); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::Unforkify => { - self.make_def_uses(); - self.make_fork_join_maps(); - let def_uses = self.def_uses.as_ref().unwrap(); - let fork_join_maps = self.fork_join_maps.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - unforkify(&mut editor, &fork_join_maps[idx]); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::GCM => loop { - self.make_def_uses(); - self.make_reverse_postorders(); - self.make_typing(); - self.make_control_subgraphs(); - self.make_doms(); - self.make_fork_join_maps(); - self.make_loops(); - self.make_collection_objects(); - let def_uses = self.def_uses.as_ref().unwrap(); - let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); - let typing = self.typing.as_ref().unwrap(); - let doms = self.doms.as_ref().unwrap(); - let fork_join_maps = self.fork_join_maps.as_ref().unwrap(); - let loops = self.loops.as_ref().unwrap(); - let control_subgraphs = self.control_subgraphs.as_ref().unwrap(); - let collection_objects = 
self.collection_objects.as_ref().unwrap(); - let mut bbs = vec![]; - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - if let Some(bb) = gcm( - &mut editor, - &def_uses[idx], - &reverse_postorders[idx], - &typing[idx], - &control_subgraphs[idx], - &doms[idx], - &fork_join_maps[idx], - &loops[idx], - collection_objects, - ) { - bbs.push(bb); - } - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - if bbs.len() == self.module.functions.len() { - self.bbs = Some(bbs); - break; - } - }, - Pass::FloatCollections => { - self.make_def_uses(); - self.make_typing(); - self.make_callgraph(); - let def_uses = self.def_uses.as_ref().unwrap(); - let typing = self.typing.as_ref().unwrap(); - let callgraph = self.callgraph.as_ref().unwrap(); - let devices = device_placement(&self.module.functions, &callgraph); - let constants_ref = RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editors: Vec<_> = zip( - self.module.functions.iter_mut().enumerate(), - def_uses.iter(), - ) - .map(|((idx, func), def_use)| { - FunctionEditor::new( - func, - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - def_use, - ) - }) - .collect(); - float_collections(&mut editors, typing, callgraph, &devices); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - for func in self.module.functions.iter_mut() { - func.delete_gravestones(); - } - self.clear_analyses(); - } - Pass::InferSchedules => { - self.make_def_uses(); - self.make_fork_join_maps(); - self.make_reduce_cycles(); - let def_uses = self.def_uses.as_ref().unwrap(); - let fork_join_maps = self.fork_join_maps.as_ref().unwrap(); - let reduce_cycles = self.reduce_cycles.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - infer_parallel_reduce( - &mut editor, - &fork_join_maps[idx], - &reduce_cycles[idx], - ); - infer_parallel_fork(&mut editor, &fork_join_maps[idx]); - infer_vectorizable(&mut editor, &fork_join_maps[idx]); - infer_tight_associative(&mut editor, &reduce_cycles[idx]); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } 
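// The Pass::GCM arm above repeats until gcm() produces a basic-block assignment for
// every function: a None return means gcm() edited the function instead (presumably
// to legalize a placement), so the analyses are stale and everything is recomputed.
// A self-contained sketch of that fixed-point shape, with closures standing in for
// the hypothetical analysis/GCM callbacks:
fn gcm_to_fixpoint<BBs>(
    num_functions: usize,
    mut recompute_analyses: impl FnMut(),
    mut try_gcm: impl FnMut(usize) -> Option<BBs>,
) -> Vec<BBs> {
    loop {
        recompute_analyses(); // def-use, typing, dominators, loops, etc.
        let mut bbs = vec![];
        for idx in 0..num_functions {
            if let Some(bb) = try_gcm(idx) {
                bbs.push(bb);
            }
        }
        if bbs.len() == num_functions {
            return bbs; // every function was placed without further edits
        }
    }
}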
- Pass::Verify => { - let ( - def_uses, - reverse_postorders, - typing, - subgraphs, - doms, - postdoms, - fork_join_maps, - ) = verify(&mut self.module) - .expect("PANIC: Failed to verify Hercules IR module."); - - // Verification produces a bunch of analysis results that - // may be useful for later passes. - self.def_uses = Some(def_uses); - self.reverse_postorders = Some(reverse_postorders); - self.typing = Some(typing); - self.control_subgraphs = Some(subgraphs); - self.doms = Some(doms); - self.postdoms = Some(postdoms); - self.fork_join_maps = Some(fork_join_maps); - } - Pass::Xdot(force_analyses) => { - self.make_reverse_postorders(); - if *force_analyses { - self.make_doms(); - self.make_fork_join_maps(); - } - xdot_module( - &self.module, - self.reverse_postorders.as_ref().unwrap(), - self.doms.as_ref(), - self.fork_join_maps.as_ref(), - ); - } - Pass::Codegen(output_dir, module_name) => { - self.make_typing(); - self.make_control_subgraphs(); - self.make_collection_objects(); - self.make_callgraph(); - self.make_fork_join_maps(); - self.make_fork_control_maps(); - self.make_fork_trees(); - self.make_def_uses(); - let typing = self.typing.as_ref().unwrap(); - let control_subgraphs = self.control_subgraphs.as_ref().unwrap(); - let bbs = self.bbs.as_ref().unwrap(); - let collection_objects = self.collection_objects.as_ref().unwrap(); - let callgraph = self.callgraph.as_ref().unwrap(); - let def_uses = self.def_uses.as_ref().unwrap(); - let fork_join_maps = self.fork_join_maps.as_ref().unwrap(); - let fork_control_maps = self.fork_control_maps.as_ref().unwrap(); - let fork_trees = self.fork_trees.as_ref().unwrap(); - - let devices = device_placement(&self.module.functions, callgraph); - - let mut rust_rt = String::new(); - let mut llvm_ir = String::new(); - let mut cuda_ir = String::new(); - for idx in 0..self.module.functions.len() { - match devices[idx] { - Device::LLVM => cpu_codegen( - &self.module.functions[idx], - &self.module.types, - &self.module.constants, - &self.module.dynamic_constants, - &typing[idx], - &control_subgraphs[idx], - &bbs[idx], - &mut llvm_ir, - ) - .unwrap(), - Device::AsyncRust => rt_codegen( - FunctionID::new(idx), - &self.module, - &typing[idx], - &control_subgraphs[idx], - &bbs[idx], - &collection_objects, - &callgraph, - &devices, - &mut rust_rt, - ) - .unwrap(), - Device::CUDA => gpu_codegen( - &self.module.functions[idx], - &self.module.types, - &self.module.constants, - &self.module.dynamic_constants, - &typing[idx], - &control_subgraphs[idx], - &bbs[idx], - &collection_objects[&FunctionID::new(idx)], - &def_uses[idx], - &fork_join_maps[idx], - &fork_control_maps[idx], - &fork_trees[idx], - &mut cuda_ir, - ) - .unwrap(), - _ => todo!(), - } - } - println!("{}", llvm_ir); - println!("{}", cuda_ir); - println!("{}", rust_rt); - - let output_archive = format!("{}/lib{}.a", output_dir, module_name); - println!("{}", output_archive); - - // Write the LLVM IR into a temporary file. - let tmp_dir = TempDir::new().unwrap(); - let mut llvm_path = tmp_dir.path().to_path_buf(); - llvm_path.push(format!("{}.ll", module_name)); - println!("{}", llvm_path.display()); - let mut file = File::create(&llvm_path) - .expect("PANIC: Unable to open output LLVM IR file."); - file.write_all(llvm_ir.as_bytes()) - .expect("PANIC: Unable to write output LLVM IR file contents."); - - // Compile LLVM IR into an ELF object file. 
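// (Roughly equivalent to running `clang <module>.ll -c -O3 -march=native -o
// <tmpdir>/<module>_cpu.o` by hand; the resulting object file is then bundled
// into lib<module>.a by the `ar crus` invocation further down.)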
- let llvm_object = format!("{}/{}_cpu.o", tmp_dir.path().to_str().unwrap(), module_name); - let mut clang_process = Command::new("clang") - .arg(&llvm_path) - .arg("-c") - .arg("-O3") - .arg("-march=native") - .arg("-o") - .arg(&llvm_object) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .spawn() - .expect("Error running clang. Is it installed?"); - assert!(clang_process.wait().unwrap().success()); - - let mut ar_args = vec!["crus", &output_archive, &llvm_object]; - - let cuda_object = format!("{}/{}_cuda.o", tmp_dir.path().to_str().unwrap(), module_name); - if cfg!(feature = "cuda") { - // Write the CUDA IR into a temporary file. - let mut cuda_path = tmp_dir.path().to_path_buf(); - cuda_path.push(format!("{}.cu", module_name)); - let mut file = File::create(&cuda_path) - .expect("PANIC: Unable to open output CUDA IR file."); - file.write_all(cuda_ir.as_bytes()) - .expect("PANIC: Unable to write output CUDA IR file contents."); - - let cuda_text_path = format!("{}.cu", module_name); - let mut cuda_text_file = File::create(&cuda_text_path) - .expect("PANIC: Unable to open CUDA IR text file."); - cuda_text_file.write_all(cuda_ir.as_bytes()) - .expect("PANIC: Unable to write CUDA IR text file contents."); - - let mut nvcc_process = Command::new("nvcc") - .arg("-c") - .arg("-O3") - .arg("-o") - .arg(&cuda_object) - .arg(&cuda_path) - .spawn() - .expect("Error running nvcc. Is it installed?"); - assert!(nvcc_process.wait().unwrap().success()); - - ar_args.push(&cuda_object); - } - - let mut ar_process = Command::new("ar") - .args(&ar_args) - .spawn() - .expect("Error running ar. Is it installed?"); - assert!(ar_process.wait().unwrap().success()); - - // Write the Rust runtime into a file. - let output_rt = format!("{}/rt_{}.hrt", output_dir, module_name); - println!("{}", output_rt); - let mut file = File::create(&output_rt) - .expect("PANIC: Unable to open output Rust runtime file."); - file.write_all(rust_rt.as_bytes()) - .expect("PANIC: Unable to write output Rust runtime file contents."); - - let rt_text_path = format!("{}.hrt", module_name); - let mut rt_text_file = File::create(&rt_text_path) - .expect("PANIC: Unable to open Rust runtime text file."); - rt_text_file.write_all(rust_rt.as_bytes()) - .expect("PANIC: Unable to write Rust runtime text file contents."); - - } - Pass::Serialize(output_file) => { - let module_contents: Vec<u8> = postcard::to_allocvec(&self.module).unwrap(); - let mut file = File::create(&output_file) - .expect("PANIC: Unable to open output module file."); - file.write_all(&module_contents) - .expect("PANIC: Unable to write output module file contents."); - } - } - eprintln!("Ran pass: {:?}", pass); - } - } - - fn clear_analyses(&mut self) { - self.def_uses = None; - self.reverse_postorders = None; - self.typing = None; - self.control_subgraphs = None; - self.doms = None; - self.postdoms = None; - self.fork_join_maps = None; - self.fork_join_nests = None; - self.loops = None; - self.reduce_cycles = None; - self.data_nodes_in_fork_joins = None; - self.bbs = None; - self.collection_objects = None; - self.callgraph = None; - } - - pub fn get_module(self) -> Module { - self.module - } - - fn fix_deleted_functions(&mut self, id_mapping: &[Option<usize>]) { - let mut idx = 0; - - // Rust does not like enumerate here, so use - // idx outside as a hack to make it happy. 
- self.module.functions.retain(|_| { - idx += 1; - id_mapping[idx - 1].is_some() - }); - } -} diff --git a/hercules_opt/src/sroa.rs b/hercules_opt/src/sroa.rs index 6461ad717d5b7cbeac0e237916d11d4a3a4ae6d6..66d11d69c33d1a77ce5a54bfd13ad88618916bfd 100644 --- a/hercules_opt/src/sroa.rs +++ b/hercules_opt/src/sroa.rs @@ -33,6 +33,8 @@ use crate::*; * * - Read: the read node reads primitive fields from product values - these get * replaced by a direct use of the field value + * A read can also extract a product from an array or sum; the value read out + * will be broken into individual fields (by individual reads from the array) * * - Write: the write node writes primitive fields in product values - these get * replaced by a direct def of the field value @@ -54,15 +56,12 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: // for the call's arguments or the return's value let mut call_return_nodes: Vec<NodeID> = vec![]; - let func = editor.func(); - for node in reverse_postorder { - match func.nodes[node.idx()] { + match &editor.func().nodes[node.idx()] { Node::Phi { .. } | Node::Reduce { .. } | Node::Parameter { .. } | Node::Constant { .. } - | Node::Write { .. } | Node::Ternary { first: _, second: _, @@ -70,8 +69,211 @@ pub fn sroa(editor: &mut FunctionEditor, reverse_postorder: &Vec<NodeID>, types: op: TernaryOperator::Select, } if editor.get_type(types[&node]).is_product() => product_nodes.push(*node), - Node::Read { collect, .. } if editor.get_type(types[&collect]).is_product() => { - product_nodes.push(*node) + Node::Write { + collect, + data, + indices, + } => { + let data = *data; + let collect = *collect; + + // For a write, we may need to split it into two pieces if it contains a mix of + // field and non-field indices + let (fields_write, write_prod_into_non) = { + let mut fields = vec![]; + let mut remainder = vec![]; + + let mut indices = indices.iter(); + while let Some(idx) = indices.next() { + if idx.is_field() { + fields.push(idx.clone()); + } else { + remainder.push(idx.clone()); + remainder.extend(indices.cloned()); + break; + } + } + + if fields.is_empty() { + if editor.get_type(types[&data]).is_product() { + (None, Some((*node, collect, remainder))) + } else { + (None, None) + } + } else if remainder.is_empty() { + (Some(*node), None) + } else { + // Here we perform the split into two writes + // We need to find the type of the collection that will be extracted from + // the collection being modified when we read it at the fields index + let after_fields_type = type_at_index(editor, types[&collect], &fields); + + let mut inner_collection = None; + let mut fields_write = None; + let mut remainder_write = None; + editor.edit(|mut edit| { + let read_inner = edit.add_node(Node::Read { + collect, + indices: fields.clone().into(), + }); + types.insert(read_inner, after_fields_type); + product_nodes.push(read_inner); + inner_collection = Some(read_inner); + + let rem_write = edit.add_node(Node::Write { + collect: read_inner, + data, + indices: remainder.clone().into(), + }); + types.insert(rem_write, after_fields_type); + remainder_write = Some(rem_write); + + let complete_write = edit.add_node(Node::Write { + collect, + data: rem_write, + indices: fields.into(), + }); + types.insert(complete_write, types[&collect]); + fields_write = Some(complete_write); + + edit = edit.replace_all_uses(*node, complete_write)?; + edit.delete_node(*node) + }); + let inner_collection = inner_collection.unwrap(); + let fields_write = fields_write.unwrap(); + let
remainder_write = remainder_write.unwrap(); + + if editor.get_type(types[&data]).is_product() { + ( + Some(fields_write), + Some((remainder_write, inner_collection, remainder)), + ) + } else { + (Some(fields_write), None) + } + } + }; + + if let Some(node) = fields_write { + product_nodes.push(node); + } + + if let Some((write_node, collection, index)) = write_prod_into_non { + let node = write_node; + // If we're writing a product into a non-product we need to replace the write + // by a sequence of writes that read each field of the product and write them + // into the collection, then those write nodes can be ignored for SROA but the + // reads will be handled by SROA + + // The value being written must be the data and so must be a product + assert!(editor.get_type(types[&data]).is_product()); + let fields = generate_reads(editor, types[&data], data); + + let mut collection = collection; + let collection_type = types[&collection]; + + fields.for_each(|field: &Vec<Index>, val: &NodeID| { + product_nodes.push(*val); + editor.edit(|mut edit| { + collection = edit.add_node(Node::Write { + collect: collection, + data: *val, + indices: index + .iter() + .chain(field) + .cloned() + .collect::<Vec<_>>() + .into(), + }); + types.insert(collection, collection_type); + Ok(edit) + }); + }); + + editor.edit(|mut edit| { + edit = edit.replace_all_uses(node, collection)?; + edit.delete_node(node) + }); + } + } + Node::Read { collect, indices } => { + // For a read, we split the read into a series of reads where each piece has either + // only field reads or no field reads. Those with fields are the only ones + // considered during SROA but any read whose collection is not a product but + // produces a product (i.e. if there's an array of products) then following the + // read we replace the read that produces a product by reads of each field and add + // that information to the node map for the rest of SROA (this produces some reads + // that mix types of indices, since we only read leaves but that's okay since those + // reads are not handled by SROA) + let indices = indices + .chunk_by(|i, j| i.is_field() && j.is_field()) + .collect::<Vec<_>>(); + + let (field_reads, non_fields_produce_prod) = { + if indices.len() == 0 { + // If there are no indices then there were no indices originally, this is + // only used with clones of arrays + (vec![], vec![]) + } else if indices.len() == 1 { + // If once we perform chunking there's only one set of indices, we can just + // use the original node + if indices[0][0].is_field() { + (vec![*node], vec![]) + } else if editor.get_type(types[node]).is_product() { + (vec![], vec![*node]) + } else { + (vec![], vec![]) + } + } else { + let mut field_reads = vec![]; + let mut non_field = vec![]; + + // To construct the multiple reads we need to track the current collection + // and the type of that collection + let mut collect = *collect; + let mut typ = types[&collect]; + + let indices = indices + .into_iter() + .map(|i| i.into_iter().cloned().collect::<Vec<_>>()) + .collect::<Vec<_>>(); + for index in indices { + let is_field_read = index[0].is_field(); + let field_type = type_at_index(editor, typ, &index); + + editor.edit(|mut edit| { + collect = edit.add_node(Node::Read { + collect, + indices: index.into(), + }); + types.insert(collect, field_type); + typ = field_type; + Ok(edit) + }); + + if is_field_read { + field_reads.push(collect); + } else if editor.get_type(typ).is_product() { + non_field.push(collect); + } + } + + // Replace all uses of the original read 
(with mixed indices) with the + // newly constructed reads + editor.edit(|mut edit| { + edit = edit.replace_all_uses(*node, collect)?; + edit.delete_node(*node) + }); + + (field_reads, non_field) + } + }; + + product_nodes.extend(field_reads); + + for node in non_fields_produce_prod { + field_map.insert(node, generate_reads(editor, types[&node], node)); + } } // We add all calls to the call/return list and check their arguments later @@ -516,8 +718,7 @@ impl<T: std::fmt::Debug> IndexTree<T> { IndexTree::Node(ts) => ts[i].lookup_idx(idx, n + 1), } } else { - // TODO: This could be hit because of an array inside of a product - panic!("Error handling lookup of field"); + panic!("Cannot process a mix of fields and other read indices; pre-processing should prevent this"); } } else { self @@ -548,7 +749,7 @@ impl<T: std::fmt::Debug> IndexTree<T> { } } } else { - panic!("Error handling set of field"); + panic!("Cannot process a mix of fields and other read indices; pre-processing should prevent this"); } } else { IndexTree::Leaf(val) @@ -579,7 +780,7 @@ impl<T: std::fmt::Debug> IndexTree<T> { } } } else { - panic!("Error handling set of field"); + panic!("Cannot process a mix of fields and other read indices; pre-processing should prevent this"); } } else { val @@ -658,6 +859,38 @@ impl<T: std::fmt::Debug> IndexTree<T> { } } +// Given the editor, type of some collection, and a list of indices to access that type at, returns +// the TypeID of accessing the collection at the given indices +fn type_at_index(editor: &FunctionEditor, typ: TypeID, idx: &[Index]) -> TypeID { + let mut typ = typ; + for index in idx { + match index { + Index::Field(i) => { + let Type::Product(ref ts) = *editor.get_type(typ) else { + panic!("Accessing a field of a non-product type; did typechecking succeed?"); + }; + typ = ts[*i]; + } + Index::Variant(i) => { + let Type::Summation(ref ts) = *editor.get_type(typ) else { + panic!( + "Accessing a variant of a non-summation type; did typechecking succeed?" + ); + }; + typ = ts[*i]; + } + Index::Position(pos) => { + let Type::Array(elem, ref dims) = *editor.get_type(typ) else { + panic!("Accessing an array position of a non-array type; did typechecking succeed?"); + }; + assert!(pos.len() == dims.len(), "Read mismatch array dimensions"); + typ = elem; + } + } + } + return typ; +} + // Given a product value val of type typ, constructs a copy of that value by extracting all fields // from that value and then writing them into a new constant // This process also adds all the read nodes that are generated into the read_list so that the @@ -696,7 +929,7 @@ fn reconstruct_product( } // Given a node val of type typ, adds nodes to the function which read all (leaf) fields of val and -// returns a list of pairs of the indices and the node that reads that index +// returns an IndexTree that tracks the nodes reading each leaf field fn generate_reads(editor: &mut FunctionEditor, typ: TypeID, val: NodeID) -> IndexTree<NodeID> { let res = generate_reads_at_index(editor, typ, val, vec![]); res diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs index 6239a644c5f63a37fe2e7b48450b40388288c088..aa0d53fe32f855d5a1f9ad689fab44ef330e7a76 100644 --- a/hercules_opt/src/utils.rs +++ b/hercules_opt/src/utils.rs @@ -243,7 +243,7 @@ pub(crate) fn substitute_dynamic_constants_in_node( /* * Top level function to make a function have only a single return. 
*/ -pub(crate) fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> { +pub fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> { let returns: Vec<NodeID> = (0..editor.func().nodes.len()) .filter(|idx| editor.func().nodes[*idx].is_return()) .map(NodeID::new) .collect(); @@ -293,7 +293,7 @@ pub(crate) fn contains_between_control_flow(func: &Function) -> bool { * Top level function to ensure a Hercules function contains at least one * control node that isn't the start or return nodes. */ -pub(crate) fn ensure_between_control_flow(editor: &mut FunctionEditor) -> Option<NodeID> { +pub fn ensure_between_control_flow(editor: &mut FunctionEditor) -> Option<NodeID> { if !contains_between_control_flow(editor.func()) { let ret = editor .node_ids() diff --git a/hercules_rt/build.rs b/hercules_rt/build.rs index 51fdfa2330b613e3c6304aad60108341472430c3..8f82e86cab73b28f44222a8a3d8eaa4c49f04218 100644 --- a/hercules_rt/build.rs +++ b/hercules_rt/build.rs @@ -19,6 +19,7 @@ fn main() { println!("cargo::rustc-link-search=native={}", out_dir); println!("cargo::rustc-link-search=native=/usr/lib/x86_64-linux-gnu/"); println!("cargo::rustc-link-search=native=/usr/local/cuda/lib64"); + println!("cargo::rustc-link-search=native=/opt/cuda/lib/"); println!("cargo::rustc-link-lib=static=rtdefs"); println!("cargo::rustc-link-lib=cudart"); println!("cargo::rerun-if-changed=src/rtdefs.cu"); diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs index 2a96970e88e2b432c8c9c2ed896dc1c623781cbd..60d3470edc8cf42dd105cb4a2835a2a86d00fe3f 100644 --- a/hercules_rt/src/lib.rs +++ b/hercules_rt/src/lib.rs @@ -1,8 +1,8 @@ use std::alloc::{alloc, alloc_zeroed, dealloc, Layout}; use std::marker::PhantomData; -use std::mem::swap; use std::ptr::{copy_nonoverlapping, NonNull}; use std::slice::from_raw_parts; +use std::sync::atomic::{AtomicUsize, Ordering}; #[cfg(feature = "cuda")] extern "C" { @@ -14,48 +14,115 @@ extern "C" { fn copy_cuda_to_cuda(dst: *mut u8, src: *mut u8, size: usize); } +/* + * Each object needs to get assigned a unique ID. + */ +static NUM_OBJECTS: AtomicUsize = AtomicUsize::new(1); + /* * An in-memory collection object that can be used by functions compiled by the - * Hercules compiler. + * Hercules compiler. Memory objects can be in these states: + * + * 1. Shared CPU - the object has a shared reference to some CPU memory, usually + * from the programmer using the Hercules RT API. + * 2. Exclusive CPU - the object has an exclusive reference to some CPU memory, + * usually from the programmer using the Hercules RT API. + * 3. Owned CPU - the object owns some allocated CPU memory. + * 4. Owned GPU - the object owns some allocated GPU memory. + * + * A single object can be in some combination of these states at the same time. + * Only some combinations are valid, because only some combinations are + * reachable. Under this assumption, we can model an object's placement as a + * state machine, where states are combinations of the aforementioned states, + * and actions are requests on the CPU or GPU, immutably or mutably.
Here's the + * state transition table: + * + * Shared CPU = CS + * Exclusive CPU = CE + * Owned CPU = CO + * Owned GPU = GO + * + * CPU Mut CPU GPU Mut GPU + *--------------------------------------- + * CS | CS CO CS,GO GO + * CE | CE CE CE,GO GO + * CO | CO CO CO,GO GO + * GO | CO CO GO GO + * CS,GO | CS,GO CO CS,GO GO + * CE,GO | CE,GO CE CE,GO GO + * CO,GO | CO,GO CO CO,GO GO + * | + * + * A HerculesBox cannot be cloned, because it may have a mutable reference to + * some CPU memory. */ +#[derive(Debug)] pub struct HerculesBox<'a> { - cpu_shared: Option<NonNull<u8>>, - cpu_exclusive: Option<NonNull<u8>>, - cpu_owned: Option<NonNull<u8>>, + cpu_shared: Option<NonOwned<'a>>, + cpu_exclusive: Option<NonOwned<'a>>, + cpu_owned: Option<Owned>, #[cfg(feature = "cuda")] - cuda_owned: Option<NonNull<u8>>, + cuda_owned: Option<Owned>, size: usize, + id: usize, +} + +#[derive(Clone, Debug)] +struct NonOwned<'a> { + ptr: NonNull<u8>, + offset: usize, _phantom: PhantomData<&'a u8>, } +#[derive(Clone, Debug)] +struct Owned { + ptr: NonNull<u8>, + alloc_size: usize, + offset: usize, +} + impl<'b, 'a: 'b> HerculesBox<'a> { pub fn from_slice<T>(slice: &'a [T]) -> Self { + let ptr = unsafe { NonNull::new_unchecked(slice.as_ptr() as *mut u8) }; + let size = slice.len() * size_of::<T>(); + let id = NUM_OBJECTS.fetch_add(1, Ordering::Relaxed); HerculesBox { - cpu_shared: Some(unsafe { NonNull::new_unchecked(slice.as_ptr() as *mut u8) }), + cpu_shared: Some(NonOwned { + ptr, + offset: 0, + _phantom: PhantomData, + }), cpu_exclusive: None, cpu_owned: None, #[cfg(feature = "cuda")] cuda_owned: None, - size: slice.len() * size_of::<T>(), - _phantom: PhantomData, + size, + id, } } pub fn from_slice_mut<T>(slice: &'a mut [T]) -> Self { + let ptr = unsafe { NonNull::new_unchecked(slice.as_ptr() as *mut u8) }; + let size = slice.len() * size_of::<T>(); + let id = NUM_OBJECTS.fetch_add(1, Ordering::Relaxed); HerculesBox { cpu_shared: None, - cpu_exclusive: Some(unsafe { NonNull::new_unchecked(slice.as_mut_ptr() as *mut u8) }), + cpu_exclusive: Some(NonOwned { + ptr, + offset: 0, + _phantom: PhantomData, + }), cpu_owned: None, #[cfg(feature = "cuda")] cuda_owned: None, - size: slice.len() * size_of::<T>(), - _phantom: PhantomData, + size, + id, } } @@ -65,69 +132,90 @@ impl<'b, 'a: 'b> HerculesBox<'a> { } unsafe fn get_cpu_ptr(&self) -> Option<NonNull<u8>> { - self.cpu_owned.or(self.cpu_exclusive).or(self.cpu_shared) + self.cpu_owned + .as_ref() + .map(|obj| obj.ptr.byte_add(obj.offset)) + .or(self + .cpu_exclusive + .as_ref() + .map(|obj| obj.ptr.byte_add(obj.offset))) + .or(self + .cpu_shared + .as_ref() + .map(|obj| obj.ptr.byte_add(obj.offset))) } #[cfg(feature = "cuda")] unsafe fn get_cuda_ptr(&self) -> Option<NonNull<u8>> { self.cuda_owned + .as_ref() + .map(|obj| obj.ptr.byte_add(obj.offset)) } unsafe fn allocate_cpu(&mut self) -> NonNull<u8> { - if let Some(ptr) = self.cpu_owned { - ptr + if let Some(obj) = self.cpu_owned.as_ref() { + obj.ptr } else { let ptr = NonNull::new(alloc(Layout::from_size_align_unchecked(self.size, 16))).unwrap(); - self.cpu_owned = Some(ptr); + self.cpu_owned = Some(Owned { + ptr, + alloc_size: self.size, + offset: 0, + }); ptr } } #[cfg(feature = "cuda")] unsafe fn allocate_cuda(&mut self) -> NonNull<u8> { - if let Some(ptr) = self.cuda_owned { - ptr + if let Some(obj) = self.cuda_owned.as_ref() { + obj.ptr } else { - let ptr = cuda_alloc(self.size); - self.cuda_owned = Some(NonNull::new(ptr).unwrap()); - self.cuda_owned.unwrap() + let ptr =
NonNull::new(cuda_alloc(self.size)).unwrap(); + self.cuda_owned = Some(Owned { + ptr, + alloc_size: self.size, + offset: 0, + }); + ptr } } unsafe fn deallocate_cpu(&mut self) { - if let Some(ptr) = self.cpu_owned { + if let Some(obj) = self.cpu_owned.take() { dealloc( - ptr.as_ptr(), - Layout::from_size_align_unchecked(self.size, 16), + obj.ptr.as_ptr(), + Layout::from_size_align_unchecked(obj.alloc_size, 16), ); - self.cpu_owned = None; } } #[cfg(feature = "cuda")] unsafe fn deallocate_cuda(&mut self) { - if let Some(ptr) = self.cuda_owned { - cuda_dealloc(ptr.as_ptr()); - self.cuda_owned = None; + if let Some(obj) = self.cuda_owned.take() { + cuda_dealloc(obj.ptr.as_ptr()); } } pub unsafe fn __zeros(size: u64) -> Self { - assert_ne!(size, 0); let size = size as usize; + let id = NUM_OBJECTS.fetch_add(1, Ordering::Relaxed); HerculesBox { cpu_shared: None, cpu_exclusive: None, - cpu_owned: Some( - NonNull::new(alloc_zeroed(Layout::from_size_align_unchecked(size, 16))).unwrap(), - ), + cpu_owned: Some(Owned { + ptr: NonNull::new(alloc_zeroed(Layout::from_size_align_unchecked(size, 16))) + .unwrap(), + alloc_size: size, + offset: 0, + }), #[cfg(feature = "cuda")] cuda_owned: None, - size: size, - _phantom: PhantomData, + size, + id, } } @@ -141,16 +229,10 @@ impl<'b, 'a: 'b> HerculesBox<'a> { cuda_owned: None, size: 0, - _phantom: PhantomData, + id: 0, } } - pub unsafe fn __take(&mut self) -> Self { - let mut ret = Self::__null(); - swap(&mut ret, self); - ret - } - pub unsafe fn __cpu_ptr(&mut self) -> *mut u8 { if let Some(ptr) = self.get_cpu_ptr() { return ptr.as_ptr(); @@ -167,12 +249,15 @@ impl<'b, 'a: 'b> HerculesBox<'a> { pub unsafe fn __cpu_ptr_mut(&mut self) -> *mut u8 { let cpu_ptr = self.__cpu_ptr(); - if Some(cpu_ptr) == self.cpu_shared.map(|nn| nn.as_ptr()) { + if Some(cpu_ptr) == self.cpu_shared.as_ref().map(|obj| obj.ptr.as_ptr()) { self.allocate_cpu(); - copy_nonoverlapping(cpu_ptr, self.cpu_owned.unwrap().as_ptr(), self.size); + copy_nonoverlapping( + cpu_ptr, + self.cpu_owned.as_ref().unwrap().ptr.as_ptr(), + self.size, + ); } self.cpu_shared = None; - self.cpu_exclusive = None; #[cfg(feature = "cuda")] self.deallocate_cuda(); cpu_ptr @@ -198,6 +283,47 @@ impl<'b, 'a: 'b> HerculesBox<'a> { self.deallocate_cpu(); cuda_ptr } + + pub unsafe fn __clone(&self) -> Self { + Self { + cpu_shared: self.cpu_shared.clone(), + cpu_exclusive: self.cpu_exclusive.clone(), + cpu_owned: self.cpu_owned.clone(), + #[cfg(feature = "cuda")] + cuda_owned: self.cuda_owned.clone(), + size: self.size, + id: self.id, + } + } + + pub unsafe fn __forget(&mut self) { + self.cpu_owned = None; + #[cfg(feature = "cuda")] + { + self.cuda_owned = None; + } + } + + pub unsafe fn __offset(&mut self, offset: u64, size: u64) { + if let Some(obj) = self.cpu_shared.as_mut() { + obj.offset += offset as usize; + } + if let Some(obj) = self.cpu_exclusive.as_mut() { + obj.offset += offset as usize; + } + if let Some(obj) = self.cpu_owned.as_mut() { + obj.offset += offset as usize; + } + #[cfg(feature = "cuda")] + if let Some(obj) = self.cuda_owned.as_mut() { + obj.offset += offset as usize; + } + self.size = size as usize; + } + + pub unsafe fn __cmp_ids(&self, other: &HerculesBox<'_>) -> bool { + self.id == other.id + } } impl<'a> Drop for HerculesBox<'a> { diff --git a/hercules_samples/call/Cargo.toml b/hercules_samples/call/Cargo.toml index a5a44c2e061043101e377bff1492a4eaba995927..38a3d1b6cf4a247b4078e62d380ea942d7fb6e09 100644 --- a/hercules_samples/call/Cargo.toml +++ b/hercules_samples/call/Cargo.toml @@ -12,6 
+12,7 @@ juno_build = { path = "../../juno_build" } [dependencies] juno_build = { path = "../../juno_build" } +hercules_rt = { path = "../../hercules_rt" } rand = "*" async-std = "*" with_builtin_macros = "0.1.0" diff --git a/hercules_samples/ccp/Cargo.toml b/hercules_samples/ccp/Cargo.toml index 313fd179ed25292c51942ff8ddb9da77da0dee79..c10aced18c8a2261852e49f5f569bf13e689f72b 100644 --- a/hercules_samples/ccp/Cargo.toml +++ b/hercules_samples/ccp/Cargo.toml @@ -9,6 +9,7 @@ cuda = ["juno_build/cuda"] [dependencies] juno_build = { path = "../../juno_build" } +hercules_rt = { path = "../../hercules_rt" } rand = "*" async-std = "*" with_builtin_macros = "0.1.0" diff --git a/hercules_samples/dot/build.rs b/hercules_samples/dot/build.rs index 2a239bc6c3ebd3780cb15358375c59bdfb2e25ae..4cfd2a87fba14d3c542bb54806a65da2d1a9b8f5 100644 --- a/hercules_samples/dot/build.rs +++ b/hercules_samples/dot/build.rs @@ -4,6 +4,9 @@ fn main() { JunoCompiler::new() .ir_in_src("dot.hir") .unwrap() + //.schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" }) + .schedule_in_src("cpu.sch") + .unwrap() .build() .unwrap(); } diff --git a/hercules_samples/dot/src/cpu.sch b/hercules_samples/dot/src/cpu.sch new file mode 100644 index 0000000000000000000000000000000000000000..58a7266df5c71232aae41a969dcf286ec3a98385 --- /dev/null +++ b/hercules_samples/dot/src/cpu.sch @@ -0,0 +1,12 @@ +gvn(*); +phi-elim(*); +dce(*); + +auto-outline(*); + +ip-sroa(*); +sroa(*); +unforkify(*); +dce(*); + +gcm(*); diff --git a/hercules_samples/dot/src/gpu.sch b/hercules_samples/dot/src/gpu.sch new file mode 100644 index 0000000000000000000000000000000000000000..956eb99628a03a3efb3d77e97d93a8cb677bbd6a --- /dev/null +++ b/hercules_samples/dot/src/gpu.sch @@ -0,0 +1,13 @@ +gvn(*); +phi-elim(*); +dce(*); + +auto-outline(*); +gpu(*); +host(dot); + +ip-sroa(*); +sroa(*); +dce(*); + +gcm(*); diff --git a/hercules_samples/fac/Cargo.toml b/hercules_samples/fac/Cargo.toml index 350e365890ecb3ef2e4337b97d22839c9e8051b9..72f82672b18e1bf9529075c233a3efab742fda81 100644 --- a/hercules_samples/fac/Cargo.toml +++ b/hercules_samples/fac/Cargo.toml @@ -13,6 +13,7 @@ juno_build = { path = "../../juno_build" } [dependencies] clap = { version = "*", features = ["derive"] } juno_build = { path = "../../juno_build" } +hercules_rt = { path = "../../hercules_rt" } rand = "*" async-std = "*" with_builtin_macros = "0.1.0" diff --git a/hercules_samples/matmul/build.rs b/hercules_samples/matmul/build.rs index 08478deaac459d9a94f79fdabce37da9a1205f89..f895af867a019dfd23381a4df2d9a02f80a032f8 100644 --- a/hercules_samples/matmul/build.rs +++ b/hercules_samples/matmul/build.rs @@ -4,6 +4,9 @@ fn main() { JunoCompiler::new() .ir_in_src("matmul.hir") .unwrap() + //.schedule_in_src(if cfg!(feature = "cuda") { "gpu.sch" } else { "cpu.sch" }) + .schedule_in_src("cpu.sch") + .unwrap() .build() .unwrap(); } diff --git a/hercules_samples/matmul/src/cpu.sch b/hercules_samples/matmul/src/cpu.sch new file mode 100644 index 0000000000000000000000000000000000000000..42dda6e3fc02b23e72ca31ef89a83f020bc9bebc --- /dev/null +++ b/hercules_samples/matmul/src/cpu.sch @@ -0,0 +1,14 @@ +gvn(*); +phi-elim(*); +dce(*); + +auto-outline(*); + +ip-sroa(*); +sroa(*); +fork-split(*); +unforkify(*); +dce(*); +float-collections(*); + +gcm(*); diff --git a/hercules_samples/matmul/src/gpu.sch b/hercules_samples/matmul/src/gpu.sch new file mode 100644 index 0000000000000000000000000000000000000000..9067a1908c6615a56f917cb4eb435ace93e9ba3a --- /dev/null +++ 
b/hercules_samples/matmul/src/gpu.sch @@ -0,0 +1,15 @@ +gvn(*); +phi-elim(*); +dce(*); + +auto-outline(*); +gpu(*); +host(matmul); + +ip-sroa(*); +sroa(*); +dce(*); +float-collections(*); + +gcm(*); +xdot[true](*); diff --git a/hercules_tools/hercules_driver/Cargo.toml b/hercules_tools/hercules_driver/Cargo.toml deleted file mode 100644 index ad9397b140052539a341084646d5f7fde1cbafff..0000000000000000000000000000000000000000 --- a/hercules_tools/hercules_driver/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "hercules_driver" -version = "0.1.0" -authors = ["Russel Arbore <rarbore2@illinois.edu>"] -edition = "2021" - -[dependencies] -clap = { version = "*", features = ["derive"] } -ron = "*" -postcard = { version = "*", features = ["alloc"] } -hercules_ir = { path = "../../hercules_ir" } -hercules_opt = { path = "../../hercules_opt" } diff --git a/hercules_tools/hercules_driver/src/main.rs b/hercules_tools/hercules_driver/src/main.rs deleted file mode 100644 index a2550022129029387664fb4327528d09078c2e02..0000000000000000000000000000000000000000 --- a/hercules_tools/hercules_driver/src/main.rs +++ /dev/null @@ -1,58 +0,0 @@ -use std::fs::File; -use std::io::prelude::*; -use std::path::Path; - -use clap::Parser; - -#[derive(Parser, Debug)] -#[command(author, version, about, long_about = None)] -struct Args { - file: String, - passes: String, -} - -fn parse_file_from_hir(path: &Path) -> hercules_ir::ir::Module { - let mut file = File::open(path).expect("PANIC: Unable to open input file."); - let mut contents = String::new(); - file.read_to_string(&mut contents) - .expect("PANIC: Unable to read input file contents."); - hercules_ir::parse::parse(&contents).expect("PANIC: Failed to parse Hercules IR file.") -} - -fn parse_file_from_hbin(path: &Path) -> hercules_ir::ir::Module { - let mut file = File::open(path).expect("PANIC: Unable to open input file."); - let mut buffer = vec![]; - file.read_to_end(&mut buffer).unwrap(); - postcard::from_bytes(&buffer).unwrap() -} - -fn main() { - let args = Args::parse(); - assert!( - args.file.ends_with(".hir") || args.file.ends_with(".hbin"), - "PANIC: Running hercules_driver on a file without a .hir or .hbin extension." - ); - let path = Path::new(&args.file); - let module = if args.file.ends_with(".hir") { - parse_file_from_hir(path) - } else { - parse_file_from_hbin(path) - }; - - let mut pm = hercules_opt::pass::PassManager::new(module); - let passes: Vec<hercules_opt::pass::Pass> = args - .passes - .split(char::is_whitespace) - .map(|pass_str| { - assert_ne!( - pass_str, "", - "PANIC: Can't interpret empty pass name. Try giving a list of pass names." - ); - ron::from_str(pass_str).expect("PANIC: Couldn't parse list of passes.") - }) - .collect(); - for pass in passes { - pm.add_pass(pass); - } - pm.run_passes(); -} diff --git a/juno_build/src/lib.rs b/juno_build/src/lib.rs index 0c676e4c1b203c53e98a8430e0f2354104540e07..b30a5d250d1fd42c365d6d3da9481afe91c1a1e1 100644 --- a/juno_build/src/lib.rs +++ b/juno_build/src/lib.rs @@ -1,12 +1,9 @@ use juno_compiler::*; use std::env::{current_dir, var}; -use std::fmt::Write; use std::fs::{create_dir_all, read_to_string}; use std::path::{Path, PathBuf}; -use with_builtin_macros::with_builtin; - // JunoCompiler is used to compile juno files into a library and manifest file appropriately to // import the definitions into a rust project via the juno! 
macro defined below // You can also specify a Hercules IR file instead of a Juno file and this will compile that IR @@ -15,9 +12,7 @@ pub struct JunoCompiler { ir_src_path: Option<PathBuf>, src_path: Option<PathBuf>, out_path: Option<PathBuf>, - verify: JunoVerify, - x_dot: bool, - schedule: JunoSchedule, + schedule: Option<String>, } impl JunoCompiler { @@ -26,9 +21,7 @@ impl JunoCompiler { ir_src_path: None, src_path: None, out_path: None, - verify: JunoVerify::None, - x_dot: false, - schedule: JunoSchedule::None, + schedule: None, } } @@ -119,35 +112,6 @@ impl JunoCompiler { ); } - pub fn verify(mut self, enabled: bool) -> Self { - if enabled && !self.verify.verify() { - self.verify = JunoVerify::JunoOpts; - } else if !enabled && self.verify.verify() { - self.verify = JunoVerify::None; - } - self - } - - pub fn verify_all(mut self, enabled: bool) -> Self { - if enabled { - self.verify = JunoVerify::AllPasses; - } else if !enabled && self.verify.verify_all() { - self.verify = JunoVerify::JunoOpts; - } - self - } - - pub fn x_dot(mut self, enabled: bool) -> Self { - self.x_dot = enabled; - self - } - - // Sets the schedule to be the default schedule - pub fn default_schedule(mut self) -> Self { - self.schedule = JunoSchedule::DefaultSchedule; - self - } - // Set the schedule as a schedule file in the src directory pub fn schedule_in_src<P>(mut self, file: P) -> Result<Self, String> where @@ -166,7 +130,7 @@ impl JunoCompiler { }; path.push("src"); path.push(file.as_ref()); - self.schedule = JunoSchedule::Schedule(path.to_str().unwrap().to_string()); + self.schedule = Some(path.to_str().unwrap().to_string()); // Tell cargo to rerun if the schedule changes println!( @@ -180,11 +144,9 @@ impl JunoCompiler { // Builds the juno file into a library and a manifest file.
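// A typical build.rs invocation of this builder, mirroring the sample crates
// updated elsewhere in this diff (e.g. hercules_samples/matmul/build.rs):
//
//     fn main() {
//         JunoCompiler::new()
//             .ir_in_src("matmul.hir")
//             .unwrap()
//             .schedule_in_src("cpu.sch")
//             .unwrap()
//             .build()
//             .unwrap();
//     }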
pub fn build(self) -> Result<(), String> { let JunoCompiler { - ir_src_path: ir_src_path, - src_path: src_path, + ir_src_path, + src_path, out_path: Some(out_path), - verify, - x_dot, schedule, } = self else { @@ -196,7 +158,7 @@ impl JunoCompiler { if let Some(src_path) = src_path { let src_file = src_path.to_str().unwrap().to_string(); - match compile(src_file, verify, x_dot, schedule, out_dir) { + match compile(src_file, schedule, out_dir) { Ok(()) => Ok(()), Err(errs) => Err(format!("{}", errs)), } @@ -216,7 +178,7 @@ impl JunoCompiler { return Err("Unable to parse Hercules IR file.".to_string()); }; - match compile_ir(ir_mod, None, verify, x_dot, schedule, out_dir, module_name) { + match compile_ir(ir_mod, schedule, out_dir, module_name) { Ok(()) => Ok(()), Err(errs) => Err(format!("{}", errs)), } diff --git a/juno_frontend/Cargo.toml b/juno_frontend/Cargo.toml index dff81db22ac19bd481726ea73f4156b839256ff0..b6d9a71d70786d389649a80ea865c9b4153eef7f 100644 --- a/juno_frontend/Cargo.toml +++ b/juno_frontend/Cargo.toml @@ -33,3 +33,4 @@ ordered-float = "*" phf = { version = "0.11", features = ["macros"] } hercules_ir = { path = "../hercules_ir" } juno_scheduler = { path = "../juno_scheduler" } +juno_utils = { path = "../juno_utils" } diff --git a/juno_frontend/src/codegen.rs b/juno_frontend/src/codegen.rs index 31c85ae9a03c0668b5ff78875eab8c44e6e34211..2ff9fa9f6b5dd07d30b71eb130aa3b4f019f28e9 100644 --- a/juno_frontend/src/codegen.rs +++ b/juno_frontend/src/codegen.rs @@ -1,8 +1,7 @@ -use std::collections::{HashMap, HashSet, VecDeque}; +use std::collections::{HashMap, VecDeque}; use hercules_ir::ir; use hercules_ir::ir::*; -use juno_scheduler::{FunctionMap, LabeledStructure}; use crate::labeled_builder::LabeledBuilder; use crate::semant; @@ -10,32 +9,12 @@ use crate::semant::{BinaryOp, Expr, Function, Literal, Prg, Stmt, UnaryOp}; use crate::ssa::SSA; use crate::types::{Either, Primitive, TypeSolver, TypeSolverInst}; +use juno_scheduler::labels::*; + // Loop info is a stack of the loop levels, recording the latch and exit block of each type LoopInfo = Vec<(NodeID, NodeID)>; -fn merge_function_maps( - mut functions: HashMap<(usize, Vec<TypeID>), FunctionID>, - funcs: &Vec<Function>, - mut tree: HashMap<FunctionID, Vec<(LabeledStructure, HashSet<usize>)>>, - mut labels: HashMap<FunctionID, HashMap<NodeID, usize>>, -) -> FunctionMap { - let mut res = HashMap::new(); - for ((func_num, type_vars), func_id) in functions.drain() { - let func_labels = tree.remove(&func_id).unwrap(); - let node_labels = labels.remove(&func_id).unwrap(); - let func_info = res.entry(func_num).or_insert(( - funcs[func_num].label_map.clone(), - funcs[func_num].name.clone(), - vec![], - )); - func_info - .2 - .push((type_vars, func_id, func_labels, node_labels)); - } - res -} - -pub fn codegen_program(prg: Prg) -> (Module, FunctionMap) { +pub fn codegen_program(prg: Prg) -> (Module, JunoInfo) { CodeGenerator::build(prg) } @@ -51,10 +30,20 @@ struct CodeGenerator<'a> { // type-solving instantiation (account for the type parameters), the function id, and the entry // block id worklist: VecDeque<(usize, TypeSolverInst<'a>, FunctionID, NodeID)>, + + // The JunoInfo needed for scheduling, which tracks the Juno function names and their + // associated FunctionIDs. 
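// (Concretely: JunoInfo is seeded below with every Juno function's name, and
// get_function pushes each monomorphized instance's FunctionID into
// juno_info.func_info.func_ids as it is created.)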
+ juno_info: JunoInfo, } impl CodeGenerator<'_> { - fn build((types, funcs): Prg) -> (Module, FunctionMap) { + fn build( + Prg { + types, + funcs, + labels, + }: Prg, + ) -> (Module, JunoInfo) { // Identify the functions (by index) which have no type arguments, these are the ones we // ask for code to be generated for let func_idx = @@ -63,13 +52,16 @@ impl CodeGenerator<'_> { .enumerate() .filter_map(|(i, f)| if f.num_type_args == 0 { Some(i) } else { None }); + let juno_info = JunoInfo::new(funcs.iter().map(|f| f.name.clone())); + let mut codegen = CodeGenerator { - builder: LabeledBuilder::create(), + builder: LabeledBuilder::create(labels), types: &types, funcs: &funcs, uid: 0, functions: HashMap::new(), worklist: VecDeque::new(), + juno_info, }; // Add the identifed functions to the list to code-gen @@ -80,7 +72,7 @@ impl CodeGenerator<'_> { codegen.finish() } - fn finish(mut self) -> (Module, FunctionMap) { + fn finish(mut self) -> (Module, JunoInfo) { while !self.worklist.is_empty() { let (idx, mut type_inst, func, entry) = self.worklist.pop_front().unwrap(); self.builder.set_function(func); @@ -94,12 +86,10 @@ impl CodeGenerator<'_> { uid: _, functions, worklist: _, + juno_info, } = self; - let (module, label_tree, label_map) = builder.finish(); - ( - module, - merge_function_maps(functions, funcs, label_tree, label_map), - ) + + (builder.finish(), juno_info) } fn get_function(&mut self, func_idx: usize, ty_args: Vec<TypeID>) -> FunctionID { @@ -138,11 +128,11 @@ impl CodeGenerator<'_> { param_types, return_type, func.num_dyn_consts as u32, - func.num_labels, func.entry, ) .unwrap(); + self.juno_info.func_info.func_ids[func_idx].push(func_id); self.functions.insert((func_idx, ty_args), func_id); self.worklist .push_back((func_idx, solver_inst, func_id, entry)); @@ -164,7 +154,7 @@ impl CodeGenerator<'_> { } // Generate code for the body - let (_, None) = self.codegen_stmt(&func.body, types, &mut ssa, entry, &mut vec![]) else { + let None = self.codegen_stmt(&func.body, types, &mut ssa, entry, &mut vec![]) else { panic!("Generated code for a function missing a return") }; } @@ -176,42 +166,39 @@ impl CodeGenerator<'_> { ssa: &mut SSA, cur_block: NodeID, loops: &mut LoopInfo, - ) -> (LabeledStructure, Option<NodeID>) { + ) -> Option<NodeID> { match stmt { Stmt::AssignStmt { var, val } => { let (val, block) = self.codegen_expr(val, types, ssa, cur_block); ssa.write_variable(*var, block, val); - (LabeledStructure::Expression(val), Some(block)) + Some(block) } Stmt::IfStmt { cond, thn, els } => { let (val_cond, block_cond) = self.codegen_expr(cond, types, ssa, cur_block); let (mut if_node, block_then, block_else) = ssa.create_cond(&mut self.builder, block_cond); - let (_, then_end) = self.codegen_stmt(thn, types, ssa, block_then, loops); + let then_end = self.codegen_stmt(thn, types, ssa, block_then, loops); let else_end = match els { None => Some(block_else), - Some(els_stmt) => self.codegen_stmt(els_stmt, types, ssa, block_else, loops).1, + Some(els_stmt) => self.codegen_stmt(els_stmt, types, ssa, block_else, loops), }; let if_id = if_node.id(); if_node.build_if(block_cond, val_cond); self.builder.add_node(if_node); - ( - LabeledStructure::Branch(if_id), - match (then_end, else_end) { - (None, els) => els, - (thn, None) => thn, - (Some(then_term), Some(else_term)) => { - let block_join = ssa.create_block(&mut self.builder); - ssa.add_pred(block_join, then_term); - ssa.add_pred(block_join, else_term); - ssa.seal_block(block_join, &mut self.builder); - Some(block_join) - } - }, - ) + 
match (then_end, else_end) { + (None, els) => els, + (thn, None) => thn, + (Some(then_term), Some(else_term)) => { + let block_join = ssa.create_block(&mut self.builder); + ssa.add_pred(block_join, then_term); + ssa.add_pred(block_join, else_term); + ssa.seal_block(block_join, &mut self.builder); + Some(block_join) + } + } } Stmt::LoopStmt { cond, update, body } => { // We generate guarded loops, so the first step is to create @@ -236,7 +223,6 @@ impl CodeGenerator<'_> { None => block_latch, Some(stmt) => self .codegen_stmt(stmt, types, ssa, block_latch, loops) - .1 .expect("Loop update should return control"), }; let (val_cond, block_cond) = self.codegen_expr(cond, types, ssa, block_updated); @@ -258,7 +244,7 @@ impl CodeGenerator<'_> { // Generate code for the body loops.push((block_latch, block_exit)); - let (_, body_res) = self.codegen_stmt(body, types, ssa, body_block, loops); + let body_res = self.codegen_stmt(body, types, ssa, body_block, loops); loops.pop(); // If the body of the loop can reach some block, we add that block as a predecessor @@ -276,51 +262,46 @@ impl CodeGenerator<'_> { // It is always assumed a loop may be skipped and so control can reach after the // loop - (LabeledStructure::Loop(body_block), Some(block_exit)) + Some(block_exit) } Stmt::ReturnStmt { expr } => { let (val_ret, block_ret) = self.codegen_expr(expr, types, ssa, cur_block); let mut return_node = self.builder.allocate_node(); return_node.build_return(block_ret, val_ret); self.builder.add_node(return_node); - (LabeledStructure::Expression(val_ret), None) + None } Stmt::BreakStmt {} => { let last_loop = loops.len() - 1; let (_latch, exit) = loops[last_loop]; ssa.add_pred(exit, cur_block); // The block that contains this break now leads to // the exit - (LabeledStructure::Nothing(), None) + None } Stmt::ContinueStmt {} => { let last_loop = loops.len() - 1; let (latch, _exit) = loops[last_loop]; ssa.add_pred(latch, cur_block); // The block that contains this continue now leads // to the latch - (LabeledStructure::Nothing(), None) + None } - Stmt::BlockStmt { body, label_last } => { - let mut label = None; + Stmt::BlockStmt { body } => { let mut block = Some(cur_block); for stmt in body.iter() { - let (new_label, new_block) = - self.codegen_stmt(stmt, types, ssa, block.unwrap(), loops); + let new_block = self.codegen_stmt(stmt, types, ssa, block.unwrap(), loops); block = new_block; - if label.is_none() || *label_last { - label = Some(new_label); - } } - (label.unwrap_or(LabeledStructure::Nothing()), block) + block } Stmt::ExprStmt { expr } => { - let (val, block) = self.codegen_expr(expr, types, ssa, cur_block); - (LabeledStructure::Expression(val), Some(block)) + let (_val, block) = self.codegen_expr(expr, types, ssa, cur_block); + Some(block) } Stmt::LabeledStmt { label, stmt } => { self.builder.push_label(*label); - let (labeled, res) = self.codegen_stmt(&*stmt, types, ssa, cur_block, loops); - self.builder.pop_label(labeled); - (labeled, res) + let res = self.codegen_stmt(&*stmt, types, ssa, cur_block, loops); + self.builder.pop_label(); + res } } } diff --git a/juno_frontend/src/labeled_builder.rs b/juno_frontend/src/labeled_builder.rs index 3cbf36da93f7a903e838a6ce4445851f22258701..15bed6c26d29485b3133c1a89bea1540994693d2 100644 --- a/juno_frontend/src/labeled_builder.rs +++ b/juno_frontend/src/labeled_builder.rs @@ -1,7 +1,7 @@ use hercules_ir::build::*; use hercules_ir::ir::*; -use juno_scheduler::LabeledStructure; -use std::collections::{HashMap, HashSet}; + +use juno_utils::stringtab::StringTable; 
// A label-tracking code generator which tracks the current function that we're // generating code for and what labels apply to the nodes being created @@ -12,39 +12,21 @@ use std::collections::{HashMap, HashSet}; pub struct LabeledBuilder<'a> { pub builder: Builder<'a>, function: Option<FunctionID>, - label: usize, - label_stack: Vec<usize>, - label_tree: HashMap<FunctionID, Vec<(LabeledStructure, HashSet<usize>)>>, - label_map: HashMap<FunctionID, HashMap<NodeID, usize>>, + label_stack: Vec<LabelID>, + label_tab: StringTable, } impl<'a> LabeledBuilder<'a> { - pub fn create() -> LabeledBuilder<'a> { + pub fn create(labels: StringTable) -> LabeledBuilder<'a> { LabeledBuilder { builder: Builder::create(), function: None, - label: 0, // 0 is always the root label label_stack: vec![], - label_tree: HashMap::new(), - label_map: HashMap::new(), + label_tab: labels, } } - pub fn finish( - self, - ) -> ( - Module, - HashMap<FunctionID, Vec<(LabeledStructure, HashSet<usize>)>>, - HashMap<FunctionID, HashMap<NodeID, usize>>, - ) { - let LabeledBuilder { - builder, - function: _, - label: _, - label_stack: _, - label_tree, - label_map, - } = self; - (builder.finish(), label_tree, label_map) + pub fn finish(self) -> Module { + self.builder.finish() } pub fn create_function( @@ -53,7 +35,6 @@ impl<'a> LabeledBuilder<'a> { param_types: Vec<TypeID>, return_type: TypeID, num_dynamic_constants: u32, - num_labels: usize, entry: bool, ) -> Result<(FunctionID, NodeID), String> { let (func, entry) = self.builder.create_function( @@ -64,78 +45,53 @@ impl<'a> LabeledBuilder<'a> { entry, )?; - self.label_tree.insert( - func, - vec![(LabeledStructure::Nothing(), HashSet::new()); num_labels], - ); - self.label_map.insert(func, HashMap::new()); - Ok((func, entry)) } pub fn set_function(&mut self, func: FunctionID) { self.function = Some(func); - self.label = 0; + assert!(self.label_stack.is_empty()); } pub fn push_label(&mut self, label: usize) { - let Some(cur_func) = self.function else { - panic!("Setting label without function") + let Some(label_str) = self.label_tab.lookup_id(label) else { + panic!("Label missing from string table") }; - - let cur_label = self.label; - self.label_stack.push(cur_label); - - for ancestor in self.label_stack.iter() { - self.label_tree.get_mut(&cur_func).unwrap()[*ancestor] - .1 - .insert(label); - } - - self.label = label; + let label_id = self.builder.add_label(&label_str); + self.label_stack.push(label_id); } - pub fn pop_label(&mut self, structure: LabeledStructure) { - let Some(cur_func) = self.function else { - panic!("Setting label without function") - }; - self.label_tree.get_mut(&cur_func).unwrap()[self.label].0 = structure; - let Some(label) = self.label_stack.pop() else { - panic!("Cannot pop label not pushed first") + pub fn pop_label(&mut self) { + let Some(_) = self.label_stack.pop() else { + panic!("No label to pop") }; - self.label = label; } - fn allocate_node_labeled(&mut self, label: usize) -> NodeBuilder { + fn allocate_node_labeled(&mut self, label: Vec<LabelID>) -> NodeBuilder { let Some(func) = self.function else { panic!("Cannot allocate node without function") }; - let builder = self.builder.allocate_node(func); - - self.label_map - .get_mut(&func) - .unwrap() - .insert(builder.id(), label); + let mut builder = self.builder.allocate_node(func); + builder.add_labels(label.into_iter()); builder } pub fn allocate_node(&mut self) -> NodeBuilder { - self.allocate_node_labeled(self.label) + self.allocate_node_labeled(self.label_stack.clone()) } pub fn
allocate_node_labeled_with(&mut self, other: NodeID) -> NodeBuilder { let Some(func) = self.function else { panic!("Cannot allocate node without function") }; - let label = self - .label_map - .get(&func) - .unwrap() - .get(&other) - .expect("Other node not labeled"); - - self.allocate_node_labeled(*label) + let labels = self + .builder + .get_labels(func, other) + .iter() + .cloned() + .collect::<Vec<_>>(); + self.allocate_node_labeled(labels) } pub fn add_node(&mut self, builder: NodeBuilder) { diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs index 46d3489156d34f5152388d08e6e66a600600b4f8..c85c9d71cf50dcb81de569bf148a3460c4e89f03 100644 --- a/juno_frontend/src/lib.rs +++ b/juno_frontend/src/lib.rs @@ -1,6 +1,5 @@ mod codegen; mod dynconst; -mod env; mod intrinsics; mod labeled_builder; mod locs; @@ -9,11 +8,11 @@ mod semant; mod ssa; mod types; +use juno_scheduler::{schedule_hercules, schedule_juno}; + use std::fmt; use std::path::Path; -use juno_scheduler::FunctionMap; - pub enum JunoVerify { None, JunoOpts, @@ -80,9 +79,7 @@ impl fmt::Display for ErrorMessage { pub fn compile( src_file: String, - verify: JunoVerify, - x_dot: bool, - schedule: JunoSchedule, + schedule: Option<String>, output_dir: String, ) -> Result<(), ErrorMessage> { let src_file_path = Path::new(&src_file); @@ -94,128 +91,18 @@ pub fn compile( return Err(ErrorMessage::SemanticError(msg)); } }; - let (module, func_info) = codegen::codegen_program(prg); + let (module, juno_info) = codegen::codegen_program(prg); - compile_ir( - module, - Some(func_info), - verify, - x_dot, - schedule, - output_dir, - module_name, - ) + schedule_juno(module, juno_info, schedule, output_dir, module_name) + .map_err(|s| ErrorMessage::SchedulingError(s)) } pub fn compile_ir( module: hercules_ir::ir::Module, - func_info: Option<FunctionMap>, - verify: JunoVerify, - x_dot: bool, - schedule: JunoSchedule, + schedule: Option<String>, output_dir: String, module_name: String, ) -> Result<(), ErrorMessage> { - let mut pm = match schedule { - JunoSchedule::None => hercules_opt::pass::PassManager::new(module), - _ => todo!(), - /* - JunoSchedule::DefaultSchedule => { - let mut pm = hercules_opt::pass::PassManager::new(module); - pm.make_plans(); - pm - } - JunoSchedule::Schedule(file) => { - let Some(func_info) = func_info else { - return Err(ErrorMessage::SchedulingError( - "Cannot schedule, no function information provided".to_string(), - )); - }; - - match juno_scheduler::schedule(&module, func_info, file) { - Ok(plans) => { - let mut pm = hercules_opt::pass::PassManager::new(module); - pm.set_plans(plans); - pm - } - Err(msg) => { - return Err(ErrorMessage::SchedulingError(msg)); - } - } - } - */ - }; - if verify.verify() || verify.verify_all() { - pm.add_pass(hercules_opt::pass::Pass::Verify); - } - add_verified_pass!(pm, verify, GVN); - add_verified_pass!(pm, verify, PhiElim); - add_pass!(pm, verify, DCE); - if x_dot { - pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); - } - add_pass!(pm, verify, Inline); - if x_dot { - pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); - } - // Inlining may make some functions uncalled, so run this pass. - // In general, this should always be run after inlining. - add_pass!(pm, verify, DeleteUncalled); - if x_dot { - pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); - } - // Run SROA pretty early (though after inlining which can make SROA more effective) so that - // CCP, GVN, etc. 
can work on the result of SROA - add_pass!(pm, verify, InterproceduralSROA); - add_pass!(pm, verify, SROA); - // We run phi-elim again because SROA can introduce new phis that might be able to be - // simplified - add_verified_pass!(pm, verify, PhiElim); - add_pass!(pm, verify, DCE); - if x_dot { - pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); - } - add_pass!(pm, verify, CCP); - add_pass!(pm, verify, DCE); - add_pass!(pm, verify, GVN); - add_pass!(pm, verify, DCE); - add_pass!(pm, verify, WritePredication); - add_pass!(pm, verify, PhiElim); - add_pass!(pm, verify, DCE); - add_pass!(pm, verify, SLF); - add_pass!(pm, verify, DCE); - add_pass!(pm, verify, Predication); - add_pass!(pm, verify, DCE); - add_pass!(pm, verify, CCP); - add_pass!(pm, verify, DCE); - add_pass!(pm, verify, GVN); - add_pass!(pm, verify, DCE); - if x_dot { - pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); - } - //add_pass!(pm, verify, Forkify); - //add_pass!(pm, verify, ForkGuardElim); - add_verified_pass!(pm, verify, DCE); - add_pass!(pm, verify, ForkSplit); - add_pass!(pm, verify, Unforkify); - add_pass!(pm, verify, GVN); - add_verified_pass!(pm, verify, DCE); - add_pass!(pm, verify, DCE); - add_pass!(pm, verify, Outline); - add_pass!(pm, verify, InterproceduralSROA); - add_pass!(pm, verify, SROA); - add_pass!(pm, verify, InferSchedules); - add_verified_pass!(pm, verify, DCE); - if x_dot { - pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); - } - - add_pass!(pm, verify, GCM); - add_verified_pass!(pm, verify, DCE); - add_pass!(pm, verify, FloatCollections); - add_pass!(pm, verify, GCM); - pm.add_pass(hercules_opt::pass::Pass::Codegen(output_dir, module_name)); - pm.run_passes(); - - Ok(()) + schedule_hercules(module, schedule, output_dir, module_name) + .map_err(|s| ErrorMessage::SchedulingError(s)) } diff --git a/juno_frontend/src/main.rs b/juno_frontend/src/main.rs index d98c1e29e5aedeac5ce2a4e791c669a7af34c320..6c2722fc799fc09677ef8815b2983a46c01c3e09 100644 --- a/juno_frontend/src/main.rs +++ b/juno_frontend/src/main.rs @@ -1,52 +1,21 @@ use juno_compiler::*; -use clap::{ArgGroup, Parser}; +use clap::Parser; use std::path::PathBuf; #[derive(Parser)] #[clap(author, version, about, long_about = None)] -#[clap(group( - ArgGroup::new("scheduling") - .required(false) - .args(&["schedule", "default_schedule", "no_schedule"])))] struct Cli { src_file: String, - #[clap(short, long)] - verify: bool, - #[clap(long = "verify-all")] - verify_all: bool, - #[arg(short, long = "x-dot")] - x_dot: bool, #[clap(short, long, value_name = "SCHEDULE")] schedule: Option<String>, - #[clap(short, long = "default-schedule")] - default_schedule: bool, - #[clap(short, long)] - no_schedule: bool, #[arg(short, long = "output-dir", value_name = "OUTPUT DIR")] output_dir: Option<String>, } fn main() { let args = Cli::parse(); - let verify = if args.verify_all { - JunoVerify::AllPasses - } else if args.verify { - JunoVerify::JunoOpts - } else { - JunoVerify::None - }; - let schedule = match args.schedule { - Some(file) => JunoSchedule::Schedule(file), - None => { - if args.default_schedule { - JunoSchedule::DefaultSchedule - } else { - JunoSchedule::None - } - } - }; let output_dir = match args.output_dir { Some(dir) => dir, None => PathBuf::from(args.src_file.clone()) @@ -56,7 +25,7 @@ fn main() { .unwrap() .to_string(), }; - match compile(args.src_file, verify, args.x_dot, schedule, output_dir) { + match compile(args.src_file, args.schedule, output_dir) { Ok(()) => {} Err(errs) => { eprintln!("{}", errs); diff --git 
a/juno_frontend/src/semant.rs b/juno_frontend/src/semant.rs index 660d8afe35a803e7395d67871236e0f596b09a87..2fe4bf88278c2478749a0ed67451d5371c32d847 100644 --- a/juno_frontend/src/semant.rs +++ b/juno_frontend/src/semant.rs @@ -10,7 +10,6 @@ use lrpar::NonStreamingLexer; use ordered_float::OrderedFloat; use crate::dynconst::DynConst; -use crate::env::Env; use crate::intrinsics; use crate::locs::{span_to_loc, Location}; use crate::parser; @@ -18,6 +17,9 @@ use crate::parser::*; use crate::types; use crate::types::{Either, Type, TypeSolver}; +use juno_utils::env::Env; +use juno_utils::stringtab::StringTable; + // Definitions and data structures for semantic analysis // Entities in the environment @@ -76,72 +78,6 @@ impl PartialEq for Literal { impl Eq for Literal {} -// Map strings to unique identifiers and counts uids -struct StringTable { - count: usize, - string_to_index: HashMap<String, usize>, - index_to_string: HashMap<usize, String>, -} -impl StringTable { - fn new() -> StringTable { - StringTable { - count: 0, - string_to_index: HashMap::new(), - index_to_string: HashMap::new(), - } - } - - // Produce the UID for a string - fn lookup_string(&mut self, s: String) -> usize { - match self.string_to_index.get(&s) { - Some(n) => *n, - None => { - let n = self.count; - self.count += 1; - self.string_to_index.insert(s.clone(), n); - self.index_to_string.insert(n, s); - n - } - } - } - - // Identify the string corresponding to a UID - fn lookup_id(&self, n: usize) -> Option<String> { - self.index_to_string.get(&n).cloned() - } -} - -// Maps label names to unique identifiers (numbered 0..n for each function) -// Also tracks the map from function names to their numbers -struct LabelSet { - count: usize, - string_to_index: HashMap<String, usize>, -} -impl LabelSet { - fn new() -> LabelSet { - // Label number 0 is reserved to be the "root" label in code generation - LabelSet { - count: 1, - string_to_index: HashMap::from([("<root>".to_string(), 0)]), - } - } - - // Inserts a string if it is not already contained in this set, if it is - // contained does nothing and returns the label back wrapped in an error, - // otherwise inserts and returns the new label's id wrapped in Ok - fn insert_new(&mut self, label: String) -> Result<usize, String> { - match self.string_to_index.get(&label) { - Some(_) => Err(label), - None => { - let uid = self.count; - self.count += 1; - self.string_to_index.insert(label, uid); - Ok(uid) - } - } - } -} - // Convert spans into uids in the String Table fn intern_id( n: &Span, @@ -252,7 +188,11 @@ fn append_errors3<A, B, C>( // Normalized AST forms after semantic analysis // These include type information at all expression nodes, and remove names and locations -pub type Prg = (TypeSolver, Vec<Function>); +pub struct Prg { + pub types: TypeSolver, + pub funcs: Vec<Function>, + pub labels: StringTable, +} // The function stores information for code-generation. The type information therefore is not the // type information that is needed for type checking code that uses this function. 
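The StringTable deleted above now lives in juno_utils::stringtab (imported near the top of this file), and the new labels field of Prg carries one of these tables for the function labels. A minimal sketch of the interning round-trip this relies on, assuming the moved implementation matches the copy removed here:

    let mut labels = StringTable::new();
    let id = labels.lookup_string("my_label".to_string()); // intern, returning the label's uid
    assert_eq!(labels.lookup_id(id), Some("my_label".to_string()));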
@@ -263,8 +203,6 @@ pub struct Function { pub num_type_args: usize, pub arguments: Vec<(usize, Type)>, pub return_type: Type, - pub num_labels: usize, - pub label_map: HashMap<String, usize>, pub body: Stmt, pub entry: bool, } @@ -304,7 +242,6 @@ pub enum Stmt { ContinueStmt {}, BlockStmt { body: Vec<Stmt>, - label_last: bool, }, ExprStmt { expr: Expr, @@ -583,6 +520,7 @@ fn analyze_program( lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, ) -> Result<Prg, ErrorMessages> { let mut stringtab = StringTable::new(); + let mut labels = StringTable::new(); let mut env: Env<usize, Entity> = Env::new(); let mut types = TypeSolver::new(); @@ -889,9 +827,6 @@ fn analyze_program( } } - // Create a set of the labels in this function - let mut labels = LabelSet::new(); - // Finally, we have a properly built environment and we can // start processing the body let (mut body, end_reachable) = process_stmt( @@ -927,7 +862,6 @@ fn analyze_program( &mut types, ), ], - label_last: false, }; } else { Err(singleton_error(ErrorMessage::SemanticError( @@ -965,8 +899,6 @@ fn analyze_program( .map(|(v, n)| (*n, v.1)) .collect::<Vec<_>>(), return_type: pure_return_type, - num_labels: labels.count, - label_map: labels.string_to_index, body: body, entry: entry, }); @@ -998,7 +930,11 @@ fn analyze_program( } } - Ok((types, res)) + Ok(Prg { + types, + funcs: res, + labels, + }) } fn process_type_def( @@ -1686,7 +1622,7 @@ fn process_stmt( return_type: Type, inout_vars: &Vec<usize>, inout_types: &Vec<Type>, - labels: &mut LabelSet, + labels: &mut StringTable, ) -> Result<(Stmt, bool), ErrorMessages> { match stmt { parser::Stmt::LetStmt { @@ -2286,9 +2222,6 @@ fn process_stmt( body: Box::new(body), }, ], - // A label applied to this loop should be applied to the - // loop, not the initialization - label_last: true, }, true, )) @@ -2438,13 +2371,7 @@ fn process_stmt( if !errors.is_empty() { Err(errors) } else { - Ok(( - Stmt::BlockStmt { - body: res, - label_last: false, - }, - reachable, - )) + Ok((Stmt::BlockStmt { body: res }, reachable)) } } parser::Stmt::CallStmt { @@ -2480,15 +2407,8 @@ fn process_stmt( label, stmt, } => { - let label_str = lexer.span_str(label).to_string(); - - let label_id = match labels.insert_new(label_str) { - Err(label_str) => Err(singleton_error(ErrorMessage::SemanticError( - span_to_loc(label, lexer), - format!("Label {} already exists", label_str), - )))?, - Ok(id) => id, - }; + let label_str = lexer.span_str(label)[1..].to_string(); + let label_id = labels.lookup_string(label_str); let (body, reach_end) = process_stmt( *stmt, diff --git a/juno_frontend/src/ssa.rs b/juno_frontend/src/ssa.rs index 578f7a9af120644d265d5c5b679fb97b35806d36..7076d62259e04cf17f2e629825bd9fd7b7f05747 100644 --- a/juno_frontend/src/ssa.rs +++ b/juno_frontend/src/ssa.rs @@ -7,9 +7,9 @@ use std::collections::{HashMap, HashSet}; +use crate::labeled_builder::LabeledBuilder; use hercules_ir::build::*; use hercules_ir::ir::*; -use crate::labeled_builder::LabeledBuilder; pub struct SSA { // Map from variable (usize) to build (NodeID) to definition (NodeID) diff --git a/juno_samples/antideps/src/antideps.jn b/juno_samples/antideps/src/antideps.jn index 6886741b6de9cde99d0372b5b9e885e08f9b95e0..f40640d2d7fdec15f3b4ec18d753b01ae145cb23 100644 --- a/juno_samples/antideps/src/antideps.jn +++ b/juno_samples/antideps/src/antideps.jn @@ -110,3 +110,28 @@ fn very_complex_antideps(x: usize) -> usize { } return arr4[w] + w; } + +#[entry] +fn read_chains(input : i32) -> i32 { + let arrs : (i32[2], i32[2]); + let sub = arrs.0; + 
sub[1] = input + 7; + arrs.0[1] = input + 3; + let result = sub[1] + arrs.0[1]; + sub[1] = 99; + arrs.0[1] = 99; + return result + sub[1] - arrs.0[1]; +} + +#[entry] +fn array_of_structs(input: i32) -> i32 { + let arr : (i32, i32)[2]; + let sub = arr[0]; + sub.1 = input + 7; + arr[0] = sub; + arr[0].1 = input + 3; + let result = sub.1 + arr[0].1; + sub.1 = 99; + arr[0].1 = 99; + return result + sub.1 - arr[0].1; +} diff --git a/juno_samples/antideps/src/main.rs b/juno_samples/antideps/src/main.rs index 0b065cbaa6e6cffcaf9ff7b3fbc5a2c882dc7248..2f1e8efc1de92ebaa9e95a3d08c32742ad2c76e1 100644 --- a/juno_samples/antideps/src/main.rs +++ b/juno_samples/antideps/src/main.rs @@ -23,6 +23,14 @@ fn main() { let output = very_complex_antideps(3).await; println!("{}", output); assert_eq!(output, 144); + + let output = read_chains(2).await; + println!("{}", output); + assert_eq!(output, 14); + + let output = array_of_structs(2).await; + println!("{}", output); + assert_eq!(output, 14); }); } diff --git a/juno_samples/casts_and_intrinsics/Cargo.toml b/juno_samples/casts_and_intrinsics/Cargo.toml index 9fac18b77db9e01bb55ad491a93f1558a5339c54..b2e7c815b0e93ba1741503cb29fe7528b78b5996 100644 --- a/juno_samples/casts_and_intrinsics/Cargo.toml +++ b/juno_samples/casts_and_intrinsics/Cargo.toml @@ -16,5 +16,6 @@ juno_build = { path = "../../juno_build" } [dependencies] juno_build = { path = "../../juno_build" } +hercules_rt = { path = "../../hercules_rt" } with_builtin_macros = "0.1.0" async-std = "*" diff --git a/juno_samples/schedule_test/Cargo.toml b/juno_samples/schedule_test/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..be5d949bf1959d48c3463717f28b9fd186e05170 --- /dev/null +++ b/juno_samples/schedule_test/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "juno_schedule_test" +version = "0.1.0" +authors = ["Aaron Councilman <aaronjc4@illinois.edu>"] +edition = "2021" + +[[bin]] +name = "juno_schedule_test" +path = "src/main.rs" + +[build-dependencies] +juno_build = { path = "../../juno_build" } + +[dependencies] +juno_build = { path = "../../juno_build" } +hercules_rt = { path = "../../hercules_rt" } +with_builtin_macros = "0.1.0" +async-std = "*" +rand = "*" diff --git a/juno_samples/schedule_test/build.rs b/juno_samples/schedule_test/build.rs new file mode 100644 index 0000000000000000000000000000000000000000..4a4282473e87d6c24e12b5e3d59521ee8c99141e --- /dev/null +++ b/juno_samples/schedule_test/build.rs @@ -0,0 +1,11 @@ +use juno_build::JunoCompiler; + +fn main() { + JunoCompiler::new() + .file_in_src("code.jn") + .unwrap() + .schedule_in_src("sched.sch") + .unwrap() + .build() + .unwrap(); +} diff --git a/juno_samples/schedule_test/src/code.jn b/juno_samples/schedule_test/src/code.jn new file mode 100644 index 0000000000000000000000000000000000000000..5bb923bf6c68c92fcdb8869333945fcbd401e545 --- /dev/null +++ b/juno_samples/schedule_test/src/code.jn @@ -0,0 +1,30 @@ +#[entry] +fn test<n, m, k: usize>(a: i32[n, m], b: i32[m, k], c: i32[k]) -> i32[n] { + let prod: i32[n, k]; + + @outer for i = 0 to n { + @middle for j = 0 to k { + let val = 0; + + @inner for k = 0 to m { + val += a[i, k] * b[k, j]; + } + + prod[i, j] = val; + } + } + + let res: i32[n]; + + @row for i = 0 to n { + let val = 0; + + @col for j = 0 to k { + val += prod[i, j] * c[j]; + } + + res[i] = val; + } + + return res; +} diff --git a/juno_samples/schedule_test/src/main.rs b/juno_samples/schedule_test/src/main.rs new file mode 100644 index 
0000000000000000000000000000000000000000..a64cd16f8e8b7b50ae11199b128ebe12b5d08412 --- /dev/null +++ b/juno_samples/schedule_test/src/main.rs @@ -0,0 +1,42 @@ +#![feature(box_as_ptr, let_chains)] + +use rand::random; + +use hercules_rt::HerculesBox; + +juno_build::juno!("code"); + +fn main() { + async_std::task::block_on(async { + const N: usize = 256; + const M: usize = 64; + const K: usize = 128; + let a: Box<[i32]> = (0..N * M).map(|_| random::<i32>() % 100).collect(); + let b: Box<[i32]> = (0..M * K).map(|_| random::<i32>() % 100).collect(); + let c: Box<[i32]> = (0..K).map(|_| random::<i32>() % 100).collect(); + + let mut correct_res: Box<[i32]> = (0..N).map(|_| 0).collect(); + for i in 0..N { + for j in 0..K { + let mut res = 0; + for k in 0..M { + res += a[i * M + k] * b[k * K + j]; + } + correct_res[i] += c[j] * res; + } + } + + let mut res = { + let a = HerculesBox::from_slice(&a); + let b = HerculesBox::from_slice(&b); + let c = HerculesBox::from_slice(&c); + test(N as u64, M as u64, K as u64, a, b, c).await + }; + assert_eq!(res.as_slice::<i32>(), &*correct_res); + }); +} + +#[test] +fn schedule_test() { + main(); +} diff --git a/juno_samples/schedule_test/src/sched.sch b/juno_samples/schedule_test/src/sched.sch new file mode 100644 index 0000000000000000000000000000000000000000..960dca22481e06f7aaf3875477bc6ed5595159e7 --- /dev/null +++ b/juno_samples/schedule_test/src/sched.sch @@ -0,0 +1,47 @@ +macro juno-setup!(X) { + //gvn(X); + phi-elim(X); + dce(X); + lift-dc-math(X); +} +macro codegen-prep!(X) { + infer-schedules(X); + dce(X); + gcm(X); + dce(X); + phi-elim(X); + float-collections(X); + gcm(X); +} + + +juno-setup!(*); + +let first = outline(test@outer); +let second = outline(test@row); + +// We can use the functions produced by outlining in our schedules +gvn(first, second, test); + +ip-sroa(*); +sroa(*); + +// We can evaluate expressions using labels and save them for later use +let inner = first@inner; + +// A fixpoint can run a series of passes until no more changes are made +// (though some passes seem to make edits even if there are no real changes, +// so this is fragile). +// We could just let it run until it converges, but we can also tell it to panic +// if it hasn't converged after a number of iterations (like here), tell it to +// just stop after a certain number of iterations (stop after #), or to print +// the iteration number (print iter)
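+// For example (the pass choices here are just illustrative), the other limit +// forms from the grammar look like: +// fixpoint stop after 5 { dce(*); } +// fixpoint print iter { gvn(*); }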
+fixpoint panic after 2 { + phi-elim(*); +} + +host(*); +cpu(first, second); + +codegen-prep!(*); +//xdot[true](*); diff --git a/juno_scheduler/Cargo.toml b/juno_scheduler/Cargo.toml index 49e5f4a31b94834489aaf551a2a24568292aa305..1c837d4a32764abb179b6a05fd5225b808ea764a 100644 --- a/juno_scheduler/Cargo.toml +++ b/juno_scheduler/Cargo.toml @@ -13,4 +13,8 @@ lrpar = "0.13" cfgrammar = "0.13" lrlex = "0.13" lrpar = "0.13" +tempfile = "*" +hercules_cg = { path = "../hercules_cg" } hercules_ir = { path = "../hercules_ir" } +hercules_opt = { path = "../hercules_opt" } +juno_utils = { path = "../juno_utils" } diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs new file mode 100644 index 0000000000000000000000000000000000000000..fc3c1279606dd33c4fe79e1b03c9ec345cbfc61b --- /dev/null +++ b/juno_scheduler/src/compile.rs @@ -0,0 +1,496 @@ +use crate::ir; +use crate::parser; + +use juno_utils::env::Env; +use juno_utils::stringtab::StringTable; + +extern crate hercules_ir; +use self::hercules_ir::ir::{Device, Schedule}; + +use lrlex::DefaultLexerTypes; +use lrpar::NonStreamingLexer; +use lrpar::Span; + +use std::fmt; +use std::str::FromStr; + +type Location = ((usize, usize), (usize, usize)); + +pub enum ScheduleCompilerError { + UndefinedMacro(String, Location), + NoSuchPass(String, Location), + IncorrectArguments { + expected: usize, + actual: usize, + loc: Location, + }, +} + +impl fmt::Display for ScheduleCompilerError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ScheduleCompilerError::UndefinedMacro(name, loc) => write!( + f, + "({}, {}) -- ({}, {}): Undefined macro '{}'", + loc.0 .0, loc.0 .1, loc.1 .0, loc.1 .1, name + ), + ScheduleCompilerError::NoSuchPass(name, loc) => write!( + f, + "({}, {}) -- ({}, {}): Undefined pass '{}'", + loc.0 .0, loc.0 .1, loc.1 .0, loc.1 .1, name + ), + ScheduleCompilerError::IncorrectArguments { + expected, + actual, + loc, + } => write!( + f, + "({}, {}) -- ({}, {}): Expected {} arguments, found {}", + loc.0 .0, loc.0 .1, loc.1 .0, loc.1 .1, expected, actual + ), + } + } +} + +pub fn compile_schedule( + sched: parser::OperationList, + lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, +) -> Result<ir::ScheduleStmt, ScheduleCompilerError> { + let mut macrostab = StringTable::new(); + let mut macros = Env::new(); + + macros.open_scope(); + + Ok(ir::ScheduleStmt::Block { + body: compile_ops_as_block(sched, lexer, &mut macrostab, &mut macros)?, + }) +} + +#[derive(Debug, Clone)] +struct MacroInfo { + params: Vec<String>, + selection_name: String, + def: ir::ScheduleExp, +} + +enum Appliable { + Pass(ir::Pass), + Schedule(Schedule), + Device(Device), +} + +impl Appliable { + fn num_args(&self) -> usize { + match self { + Appliable::Pass(pass) => pass.num_args(), + // Schedules and devices do not take arguments (at the moment) + _ => 0, + } + } +} + +impl FromStr for Appliable { + type Err = String; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s { + "auto-outline" => Ok(Appliable::Pass(ir::Pass::AutoOutline)), + "ccp" => Ok(Appliable::Pass(ir::Pass::CCP)), + "crc" | "collapse-read-chains" => Ok(Appliable::Pass(ir::Pass::CRC)), + "dce" => Ok(Appliable::Pass(ir::Pass::DCE)), + "delete-uncalled" =>
Ok(Appliable::Pass(ir::Pass::DeleteUncalled)), + "float-collections" | "collections" => Ok(Appliable::Pass(ir::Pass::FloatCollections)), + "fork-guard-elim" => Ok(Appliable::Pass(ir::Pass::ForkGuardElim)), + "fork-split" => Ok(Appliable::Pass(ir::Pass::ForkSplit)), + "forkify" => Ok(Appliable::Pass(ir::Pass::Forkify)), + "gcm" | "bbs" => Ok(Appliable::Pass(ir::Pass::GCM)), + "gvn" => Ok(Appliable::Pass(ir::Pass::GVN)), + "infer-schedules" => Ok(Appliable::Pass(ir::Pass::InferSchedules)), + "inline" => Ok(Appliable::Pass(ir::Pass::Inline)), + "ip-sroa" | "interprocedural-sroa" => { + Ok(Appliable::Pass(ir::Pass::InterproceduralSROA)) + } + "lift-dc-math" => Ok(Appliable::Pass(ir::Pass::LiftDCMath)), + "outline" => Ok(Appliable::Pass(ir::Pass::Outline)), + "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)), + "predication" => Ok(Appliable::Pass(ir::Pass::Predication)), + "slf" | "store-load-forward" => Ok(Appliable::Pass(ir::Pass::SLF)), + "sroa" => Ok(Appliable::Pass(ir::Pass::SROA)), + "unforkify" => Ok(Appliable::Pass(ir::Pass::Unforkify)), + "verify" => Ok(Appliable::Pass(ir::Pass::Verify)), + "xdot" => Ok(Appliable::Pass(ir::Pass::Xdot)), + + "cpu" | "llvm" => Ok(Appliable::Device(Device::LLVM)), + "gpu" | "cuda" | "nvidia" => Ok(Appliable::Device(Device::CUDA)), + "host" | "rust" | "rust-async" => Ok(Appliable::Device(Device::AsyncRust)), + + "associative" => Ok(Appliable::Schedule(Schedule::TightAssociative)), + "parallel-fork" => Ok(Appliable::Schedule(Schedule::ParallelFork)), + "parallel-reduce" => Ok(Appliable::Schedule(Schedule::ParallelReduce)), + "vectorize" => Ok(Appliable::Schedule(Schedule::Vectorizable)), + + _ => Err(s.to_string()), + } + } +} + +fn compile_ops_as_block( + sched: parser::OperationList, + lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + macrostab: &mut StringTable, + macros: &mut Env<usize, MacroInfo>, +) -> Result<Vec<ir::ScheduleStmt>, ScheduleCompilerError> { + match sched { + parser::OperationList::NilStmt() => Ok(vec![]), + parser::OperationList::FinalExpr(expr) => { + Ok(vec![compile_exp_as_stmt(expr, lexer, macrostab, macros)?]) + } + parser::OperationList::ConsStmt(stmt, ops) => { + let mut res = compile_stmt(stmt, lexer, macrostab, macros)?; + res.extend(compile_ops_as_block(*ops, lexer, macrostab, macros)?); + Ok(res) + } + } +} + +fn compile_stmt( + stmt: parser::Stmt, + lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + macrostab: &mut StringTable, + macros: &mut Env<usize, MacroInfo>, +) -> Result<Vec<ir::ScheduleStmt>, ScheduleCompilerError> { + match stmt { + parser::Stmt::LetStmt { span: _, var, expr } => { + let var = lexer.span_str(var).to_string(); + Ok(vec![ir::ScheduleStmt::Let { + var, + exp: compile_exp_as_expr(expr, lexer, macrostab, macros)?, + }]) + } + parser::Stmt::AssignStmt { span: _, var, rhs } => { + let var = lexer.span_str(var).to_string(); + Ok(vec![ir::ScheduleStmt::Assign { + var, + exp: compile_exp_as_expr(rhs, lexer, macrostab, macros)?, + }]) + } + parser::Stmt::ExprStmt { span: _, exp } => { + Ok(vec![compile_exp_as_stmt(exp, lexer, macrostab, macros)?]) + } + parser::Stmt::Fixpoint { + span: _, + limit, + body, + } => { + let limit = match limit { + parser::FixpointLimit::NoLimit { .. 
} => ir::FixpointLimit::NoLimit(), + parser::FixpointLimit::StopAfter { span: _, limit } => { + ir::FixpointLimit::StopAfter( + lexer + .span_str(limit) + .parse() + .expect("Parsing ensures integer"), + ) + } + parser::FixpointLimit::PanicAfter { span: _, limit } => { + ir::FixpointLimit::PanicAfter( + lexer + .span_str(limit) + .parse() + .expect("Parsing ensures integer"), + ) + } + parser::FixpointLimit::PrintIter { .. } => ir::FixpointLimit::PrintIter(), + }; + + macros.open_scope(); + let body = compile_ops_as_block(*body, lexer, macrostab, macros); + macros.close_scope(); + + Ok(vec![ir::ScheduleStmt::Fixpoint { + body: Box::new(ir::ScheduleStmt::Block { body: body? }), + limit, + }]) + } + parser::Stmt::MacroDecl { span: _, def } => { + let parser::MacroDecl { + name, + params, + selection_name, + def, + } = def; + let name = lexer.span_str(name).to_string(); + let macro_id = macrostab.lookup_string(name); + + let selection_name = lexer.span_str(selection_name).to_string(); + + let params = params + .into_iter() + .map(|s| lexer.span_str(s).to_string()) + .collect(); + + let def = compile_macro_def(*def, params, selection_name, lexer, macrostab, macros)?; + macros.insert(macro_id, def); + + Ok(vec![]) + } + } +} + +fn compile_exp_as_stmt( + expr: parser::Expr, + lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + macrostab: &mut StringTable, + macros: &mut Env<usize, MacroInfo>, +) -> Result<ir::ScheduleStmt, ScheduleCompilerError> { + match compile_expr(expr, lexer, macrostab, macros)? { + ExprResult::Expr(exp) => Ok(ir::ScheduleStmt::Let { + var: "_".to_string(), + exp, + }), + ExprResult::Stmt(stm) => Ok(stm), + } +} + +fn compile_exp_as_expr( + expr: parser::Expr, + lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + macrostab: &mut StringTable, + macros: &mut Env<usize, MacroInfo>, +) -> Result<ir::ScheduleExp, ScheduleCompilerError> { + match compile_expr(expr, lexer, macrostab, macros)? 
{ + ExprResult::Expr(exp) => Ok(exp), + ExprResult::Stmt(stm) => Ok(ir::ScheduleExp::Block { + body: vec![stm], + res: Box::new(ir::ScheduleExp::Record { fields: vec![] }), + }), + } +} + +enum ExprResult { + Expr(ir::ScheduleExp), + Stmt(ir::ScheduleStmt), +} + +fn compile_expr( + expr: parser::Expr, + lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + macrostab: &mut StringTable, + macros: &mut Env<usize, MacroInfo>, +) -> Result<ExprResult, ScheduleCompilerError> { + match expr { + parser::Expr::Function { + span, + name, + args, + selection, + } => { + let func: Appliable = lexer + .span_str(name) + .to_lowercase() + .parse() + .map_err(|s| ScheduleCompilerError::NoSuchPass(s, lexer.line_col(name)))?; + + if args.len() != func.num_args() { + return Err(ScheduleCompilerError::IncorrectArguments { + expected: func.num_args(), + actual: args.len(), + loc: lexer.line_col(span), + }); + } + + let mut arg_vals = vec![]; + for arg in args { + arg_vals.push(compile_exp_as_expr(arg, lexer, macrostab, macros)?); + } + + let selection = compile_selector(selection, lexer, macrostab, macros)?; + + match func { + Appliable::Pass(pass) => Ok(ExprResult::Expr(ir::ScheduleExp::RunPass { + pass, + args: arg_vals, + on: selection, + })), + Appliable::Schedule(sched) => Ok(ExprResult::Stmt(ir::ScheduleStmt::AddSchedule { + sched, + on: selection, + })), + Appliable::Device(device) => Ok(ExprResult::Stmt(ir::ScheduleStmt::AddDevice { + device, + on: selection, + })), + } + } + parser::Expr::Macro { + span, + name, + args, + selection, + } => { + let name_str = lexer.span_str(name).to_string(); + let macro_id = macrostab.lookup_string(name_str.clone()); + let Some(macro_def) = macros.lookup(&macro_id) else { + return Err(ScheduleCompilerError::UndefinedMacro( + name_str, + lexer.line_col(name), + )); + }; + let macro_def: MacroInfo = macro_def.clone(); + let MacroInfo { + params, + selection_name, + def, + } = macro_def; + + if args.len() != params.len() { + return Err(ScheduleCompilerError::IncorrectArguments { + expected: params.len(), + actual: args.len(), + loc: lexer.line_col(span), + }); + } + + // To initialize the macro's arguments, we have to do this in two steps: we first + // evaluate all of the arguments and store them into new variables, using names that + // cannot conflict with other values in the program, and then we assign those variables + // to the macro's parameters; this avoids any shadowing issues, for instance: + // macro![3, x] where macro!'s arguments are named x and y becomes + // let #0 = 3; let #1 = x; let x = #0; let y = #1; + // which has the desired semantics, as opposed to + // let x = 3; let y = x; + let mut arg_eval = vec![]; + let mut arg_setters = vec![]; + + for (i, (exp, var)) in args.into_iter().zip(params.into_iter()).enumerate() { + let tmp = format!("#{}", i); + arg_eval.push(ir::ScheduleStmt::Let { + var: tmp.clone(), + exp: compile_exp_as_expr(exp, lexer, macrostab, macros)?, + }); + arg_setters.push(ir::ScheduleStmt::Let { + var, + exp: ir::ScheduleExp::Variable { var: tmp }, + }); + } + + // Set the selection + arg_eval.push(ir::ScheduleStmt::Let { + var: selection_name, + exp: ir::ScheduleExp::Selection { + selection: compile_selector(selection, lexer, macrostab, macros)?, + }, + }); + + // Combine the evaluation and initialization code + arg_eval.extend(arg_setters); + + Ok(ExprResult::Expr(ir::ScheduleExp::Block { + body: arg_eval, + res: Box::new(def), + })) + } + parser::Expr::Variable { span } => { + let var = lexer.span_str(span).to_string(); +
Ok(ExprResult::Expr(ir::ScheduleExp::Variable { var })) + } + parser::Expr::Integer { span } => { + let val: usize = lexer.span_str(span).parse().expect("Parsing"); + Ok(ExprResult::Expr(ir::ScheduleExp::Integer { val })) + } + parser::Expr::Boolean { span: _, val } => { + Ok(ExprResult::Expr(ir::ScheduleExp::Boolean { val })) + } + parser::Expr::Field { + span: _, + lhs, + field, + } => { + let field = lexer.span_str(field).to_string(); + let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?; + Ok(ExprResult::Expr(ir::ScheduleExp::Field { + collect: Box::new(lhs), + field, + })) + } + parser::Expr::BlockExpr { span: _, body } => { + compile_ops_as_expr(*body, lexer, macrostab, macros) + } + parser::Expr::Record { span: _, fields } => { + let mut result = vec![]; + for (name, expr) in fields { + let name = lexer.span_str(name).to_string(); + let expr = compile_exp_as_expr(expr, lexer, macrostab, macros)?; + result.push((name, expr)); + } + Ok(ExprResult::Expr(ir::ScheduleExp::Record { fields: result })) + } + } +} + +fn compile_ops_as_expr( + mut sched: parser::OperationList, + lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + macrostab: &mut StringTable, + macros: &mut Env<usize, MacroInfo>, +) -> Result<ExprResult, ScheduleCompilerError> { + let mut body = vec![]; + loop { + match sched { + parser::OperationList::NilStmt() => { + return Ok(ExprResult::Stmt(ir::ScheduleStmt::Block { body })); + } + parser::OperationList::FinalExpr(expr) => { + return Ok(ExprResult::Expr(ir::ScheduleExp::Block { + body, + res: Box::new(compile_exp_as_expr(expr, lexer, macrostab, macros)?), + })); + } + parser::OperationList::ConsStmt(stmt, ops) => { + body.extend(compile_stmt(stmt, lexer, macrostab, macros)?); + sched = *ops; + } + } + } +} + +fn compile_selector( + sel: parser::Selector, + lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + macrostab: &mut StringTable, + macros: &mut Env<usize, MacroInfo>, +) -> Result<ir::Selector, ScheduleCompilerError> { + match sel { + parser::Selector::SelectAll { span: _ } => Ok(ir::Selector::Everything()), + parser::Selector::SelectExprs { span: _, exprs } => { + let mut res = vec![]; + for exp in exprs { + res.push(compile_exp_as_expr(exp, lexer, macrostab, macros)?); + } + Ok(ir::Selector::Selection(res)) + } + } +} + +fn compile_macro_def( + body: parser::OperationList, + params: Vec<String>, + selection_name: String, + lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + macrostab: &mut StringTable, + macros: &mut Env<usize, MacroInfo>, +) -> Result<MacroInfo, ScheduleCompilerError> { + // FIXME: The body should be checked in an environment that prohibits running anything on + // everything (*) and check that only local variables/parameters are used + Ok(MacroInfo { + params, + selection_name, + def: match compile_ops_as_expr(body, lexer, macrostab, macros)? { + ExprResult::Expr(expr) => expr, + ExprResult::Stmt(stmt) => ir::ScheduleExp::Block { + body: vec![stmt], + res: Box::new(ir::ScheduleExp::Record { fields: vec![] }), + }, + }, + }) +} diff --git a/juno_scheduler/src/default.rs b/juno_scheduler/src/default.rs new file mode 100644 index 0000000000000000000000000000000000000000..46b51b43b8040105f92d452df8f939522df40d2c --- /dev/null +++ b/juno_scheduler/src/default.rs @@ -0,0 +1,85 @@ +use crate::ir::*; + +macro_rules! pass { + ($p:ident) => { + ScheduleStmt::Let { + var: String::from("_"), + exp: ScheduleExp::RunPass { + pass: Pass::$p, + args: vec![], + on: Selector::Everything(), + }, + } + }; +} + +macro_rules! 
default_schedule { + () => { + ScheduleStmt::Block { + body: vec![], + } + }; + ($($p:ident),+ $(,)?) => { + ScheduleStmt::Block { + body: vec![$(pass!($p)),+], + } + }; +} + +// Defualt schedule, which is used if no schedule is provided +pub fn default_schedule() -> ScheduleStmt { + default_schedule![ + GVN, + DCE, + PhiElim, + DCE, + CRC, + DCE, + SLF, + DCE, + Inline, + /*DeleteUncalled,*/ + InterproceduralSROA, + SROA, + PhiElim, + DCE, + CCP, + DCE, + GVN, + DCE, + WritePredication, + PhiElim, + DCE, + CRC, + DCE, + SLF, + DCE, + Predication, + DCE, + CCP, + DCE, + GVN, + DCE, + LiftDCMath, + DCE, + GVN, + DCE, + /*Forkify,*/ + /*ForkGuardElim,*/ + DCE, + ForkSplit, + Unforkify, + GVN, + DCE, + DCE, + AutoOutline, + InterproceduralSROA, + SROA, + InferSchedules, + DCE, + GCM, + DCE, + FloatCollections, + GCM, + ] +} diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs new file mode 100644 index 0000000000000000000000000000000000000000..381c3475e0b4f52285bf333f59cbe175d60fce60 --- /dev/null +++ b/juno_scheduler/src/ir.rs @@ -0,0 +1,115 @@ +extern crate hercules_ir; + +use self::hercules_ir::ir::{Device, Schedule}; + +#[derive(Debug, Copy, Clone)] +pub enum Pass { + AutoOutline, + CCP, + CRC, + DCE, + DeleteUncalled, + FloatCollections, + ForkGuardElim, + ForkSplit, + Forkify, + GCM, + GVN, + InferSchedules, + Inline, + InterproceduralSROA, + LiftDCMath, + Outline, + PhiElim, + Predication, + SLF, + SROA, + Unforkify, + WritePredication, + Verify, + Xdot, +} + +impl Pass { + pub fn num_args(&self) -> usize { + match self { + Pass::Xdot => 1, + _ => 0, + } + } +} + +#[derive(Debug, Clone)] +pub enum Selector { + Everything(), + Selection(Vec<ScheduleExp>), +} + +#[derive(Debug, Clone)] +pub enum ScheduleExp { + Variable { + var: String, + }, + Integer { + val: usize, + }, + Boolean { + val: bool, + }, + Field { + collect: Box<ScheduleExp>, + field: String, + }, + RunPass { + pass: Pass, + args: Vec<ScheduleExp>, + on: Selector, + }, + Record { + fields: Vec<(String, ScheduleExp)>, + }, + Block { + body: Vec<ScheduleStmt>, + res: Box<ScheduleExp>, + }, + // This is used to "box" a selection by evaluating it at one point and then + // allowing it to be used as a selector later on + Selection { + selection: Selector, + }, +} + +#[derive(Debug, Copy, Clone)] +pub enum FixpointLimit { + NoLimit(), + PrintIter(), + StopAfter(usize), + PanicAfter(usize), +} + +#[derive(Debug, Clone)] +pub enum ScheduleStmt { + Fixpoint { + body: Box<ScheduleStmt>, + limit: FixpointLimit, + }, + Block { + body: Vec<ScheduleStmt>, + }, + Let { + var: String, + exp: ScheduleExp, + }, + Assign { + var: String, + exp: ScheduleExp, + }, + AddSchedule { + sched: Schedule, + on: Selector, + }, + AddDevice { + device: Device, + on: Selector, + }, +} diff --git a/juno_scheduler/src/labels.rs b/juno_scheduler/src/labels.rs new file mode 100644 index 0000000000000000000000000000000000000000..6690e17a50a3bed1bbf63ae11f5d05be43c0239f --- /dev/null +++ b/juno_scheduler/src/labels.rs @@ -0,0 +1,61 @@ +use hercules_ir::ir::*; + +use std::collections::HashMap; + +#[derive(Debug, Copy, Clone)] +pub struct LabelInfo { + pub func: FunctionID, + pub label: LabelID, +} + +#[derive(Debug, Copy, Clone)] +pub struct JunoFunctionID { + pub idx: usize, +} + +impl JunoFunctionID { + pub fn new(idx: usize) -> Self { + JunoFunctionID { idx } + } +} + +// From the Juno frontend we collect certain information we need for scheduling it, in particular a +// map from the function names to a "JunoFunctionID" which can be used to lookup 
+#[derive(Debug, Clone)] +pub struct JunoInfo { + pub func_names: HashMap<String, JunoFunctionID>, + pub func_info: JunoFunctions, +} + +#[derive(Debug, Clone)] +pub struct JunoFunctions { + pub func_ids: Vec<Vec<FunctionID>>, +} + +impl JunoInfo { + pub fn new<I>(funcs: I) -> Self + where + I: Iterator<Item = String>, + { + let mut func_names = HashMap::new(); + let mut func_ids = vec![]; + + for (idx, name) in funcs.enumerate() { + func_names.insert(name, JunoFunctionID::new(idx)); + func_ids.push(vec![]); + } + + JunoInfo { + func_names, + func_info: JunoFunctions { func_ids }, + } + } +} + +impl JunoFunctions { + pub fn get_function(&self, id: JunoFunctionID) -> &Vec<FunctionID> { + &self.func_ids[id.idx] + } +} diff --git a/juno_scheduler/src/lang.l b/juno_scheduler/src/lang.l index e6526c74ffdbcd2db03ec1b27f9674ff9fb70cc3..9d4c34bf8f7e50aaef048c7eba1a22c6f5872a96 100644 --- a/juno_scheduler/src/lang.l +++ b/juno_scheduler/src/lang.l @@ -13,15 +13,38 @@ [\n\r] ; , "," +; ";" += "=" +@ "@" +\* "*" +\. "." +apply "apply" +fixpoint "fixpoint" +let "let" +macro "macro_keyword" +on "on" +set "set" +target "target" + +true "true" +false "false" + +\( "(" +\) ")" +\< "<" +\> ">" +\[ "[" +\] "]" \{ "{" \} "}" -function "function" -on "on" -partition "partition" +panic[\t \n\r]+after "panic_after" +print[\t \n\r]+iter "print_iter" +stop[\t \n\r]+after "stop_after" -[a-zA-Z][a-zA-Z0-9_]* "ID" -@[a-zA-Z0-9_]+ "LABEL" +[a-zA-Z][a-zA-Z0-9_\-]*! "MACRO" +[a-zA-Z][a-zA-Z0-9_\-]* "ID" +[0-9]+ "INT" . "UNMATCHED" diff --git a/juno_scheduler/src/lang.y b/juno_scheduler/src/lang.y index e7d98dbaf025a7db36cc54df666598d18b92233a..9cb728428e08b0f77b77be526e2f3e0aa9c86a37 100644 --- a/juno_scheduler/src/lang.y +++ b/juno_scheduler/src/lang.y @@ -1,61 +1,112 @@ %start Schedule -%avoid_insert "ID" "LABEL" +%avoid_insert "ID" "INT" %expect-unused Unmatched 'UNMATCHED' %% -Schedule -> Vec<FuncDirectives> : FunctionList { $1 }; +Schedule -> OperationList + : { OperationList::NilStmt() } + | Expr { OperationList::FinalExpr($1) } + | Stmt Schedule { OperationList::ConsStmt($1, Box::new($2)) } + ; -FunctionList -> Vec<FuncDirectives> - : { vec![] } - | FunctionList FunctionDef { snoc($1, $2) } +Stmt -> Stmt + : 'let' 'ID' '=' Expr ';' + { Stmt::LetStmt { span: $span, var: span_of_tok($2), expr: $4 } } + | 'ID' '=' Expr ';' + { Stmt::AssignStmt { span: $span, var: span_of_tok($1), rhs: $3 } } + | Expr ';' + { Stmt::ExprStmt { span: $span, exp: $1 } } + | 'fixpoint' FixpointLimit '{' Schedule '}' + { Stmt::Fixpoint { span: $span, limit: $2, body: Box::new($4) } } + | MacroDecl + { Stmt::MacroDecl { span: $span, def: $1 } } ; -FunctionDef -> FuncDirectives - : 'function' Func '{' DirectiveList '}' - { FuncDirectives { span : $span, func : $2, directives : $4 }}; +FixpointLimit -> FixpointLimit + : { FixpointLimit::NoLimit { span: $span } } + | 'stop_after' 'INT' + { FixpointLimit::StopAfter { span: $span, limit: span_of_tok($2) } } + | 'panic_after' 'INT' + { FixpointLimit::PanicAfter { span: $span, limit: span_of_tok($2) } } + | 'print_iter' + { FixpointLimit::PrintIter { span: $span } } + ; -DirectiveList -> Vec<Directive> - : { vec![] } - | DirectiveList Directive { snoc($1, $2) } +Expr -> Expr + : 'ID' Args Selector + { Expr::Function { span: $span, name: span_of_tok($1), args: $2, selection: $3 } } + | 'MACRO' Args Selector + {
Expr::Macro { span: $span, name: span_of_tok($1), args: $2, selection: $3 } } + | 'ID' + { Expr::Variable { span: span_of_tok($1) } } + | 'INT' + { Expr::Integer { span: span_of_tok($1) } } + | 'true' + { Expr::Boolean { span: $span, val: true } } + | 'false' + { Expr::Boolean { span: $span, val: false } } + | Expr '.' 'ID' + { Expr::Field { span: $span, lhs: Box::new($1), field: span_of_tok($3) } } + | Expr '@' 'ID' + { Expr::Field { span: $span, lhs: Box::new($1), field: span_of_tok($3) } } + | '(' Expr ')' + { $2 } + | '{' Schedule '}' + { Expr::BlockExpr { span: $span, body: Box::new($2) } } + | '<' Fields '>' + { Expr::Record { span: $span, fields: rev($2) } } ; -Directive -> Directive - : 'partition' Labels 'on' Devices - { Directive::Partition { span : $span, labels : $2, devices : $4 } } - | 'ID' Labels - { Directive::Schedule { span : $span, command : span_of_tok($1), args : $2 } } +Args -> Vec<Expr> + : { vec![] } + | '[' Exprs ']' { rev($2) } ; -Func -> Func - : 'ID' { Func { span : $span, name : $span, }} +Exprs -> Vec<Expr> + : { vec![] } + | Expr { vec![$1] } + | Expr ',' Exprs { snoc($1, $3) } ; -Labels -> Vec<Span> - : 'LABEL' { vec![span_of_tok($1)] } - | '{' LabelsRev '}' { rev($2) } +Fields -> Vec<(Span, Expr)> + : { vec![] } + | 'ID' '=' Expr { vec![(span_of_tok($1), $3)] } + | 'ID' '=' Expr ',' Fields { snoc((span_of_tok($1), $3), $5) } ; -LabelsRev -> Vec<Span> - : { vec![] } - | 'LABEL' { vec![span_of_tok($1)] } - | 'LABEL' ',' LabelsRev { cons(span_of_tok($1), $3) } + +Selector -> Selector + : '(' '*' ')' + { Selector::SelectAll { span: $span } } + | '(' Exprs ')' + { Selector::SelectExprs { span: $span, exprs: $2 } } ; -Devices -> Vec<Device> - : Device { vec![$1] } - | '{' SomeDevices '}' { $2 } +MacroDecl -> MacroDecl + : 'macro_keyword' 'MACRO' Params '(' 'ID' ')' MacroDef + { MacroDecl { + name: span_of_tok($2), + params: rev($3), + selection_name: span_of_tok($5), + def: Box::new($7), + } + } ; -SomeDevices -> Vec<Device> - : Device { vec![$1] } - | SomeDevices ',' Device { snoc($1, $3) } + +Params -> Vec<Span> + : { vec![] } + | '[' Ids ']' { $2 } ; -Device -> Device - : 'ID' - { Device { span : $span, name : span_of_tok($1), } } +Ids -> Vec<Span> + : { vec![] } + | 'ID' { vec![span_of_tok($1)] } + | 'ID' ',' Ids { snoc(span_of_tok($1), $3) } ; +MacroDef -> OperationList : '{' Schedule '}' { $2 }; + Unmatched -> () : 'UNMATCHED' {}; %% @@ -63,32 +114,60 @@ Unmatched -> () : 'UNMATCHED' {}; use cfgrammar::Span; use lrlex::DefaultLexeme; +fn snoc<T>(x: T, mut xs: Vec<T>) -> Vec<T> { + xs.push(x); + xs +} + +fn rev<T>(mut xs: Vec<T>) -> Vec<T> { + xs.reverse(); + xs +} + fn span_of_tok(t : Result<DefaultLexeme, DefaultLexeme>) -> Span { t.map_err(|_| ()).map(|l| l.span()).unwrap() } -fn cons<A>(hd : A, mut tl : Vec<A>) -> Vec<A> { - tl.push(hd); - tl +pub enum OperationList { + NilStmt(), + FinalExpr(Expr), + ConsStmt(Stmt, Box<OperationList>), } -fn snoc<A>(mut hd : Vec<A>, tl : A) -> Vec<A> { - hd.push(tl); - hd +pub enum Stmt { + LetStmt { span: Span, var: Span, expr: Expr }, + AssignStmt { span: Span, var: Span, rhs: Expr }, + ExprStmt { span: Span, exp: Expr }, + Fixpoint { span: Span, limit: FixpointLimit, body: Box<OperationList> }, + MacroDecl { span: Span, def: MacroDecl }, } -fn rev<A>(mut lst : Vec<A>) -> Vec<A> { - lst.reverse(); - lst +pub enum FixpointLimit { + NoLimit { span: Span }, + StopAfter { span: Span, limit: Span }, + PanicAfter { span: Span, limit: Span }, + PrintIter { span: Span }, } -pub struct Func { pub span : Span, pub name : Span, 
} -pub struct Device { pub span : Span, pub name : Span, } +pub enum Expr { + Function { span: Span, name: Span, args: Vec<Expr>, selection: Selector }, + Macro { span: Span, name: Span, args: Vec<Expr>, selection: Selector }, + Variable { span: Span }, + Integer { span: Span }, + Boolean { span: Span, val: bool }, + Field { span: Span, lhs: Box<Expr>, field: Span }, + BlockExpr { span: Span, body: Box<OperationList> }, + Record { span: Span, fields: Vec<(Span, Expr)> }, +} -pub struct FuncDirectives { pub span : Span, pub func : Func, - pub directives : Vec<Directive> } +pub enum Selector { + SelectAll { span: Span }, + SelectExprs { span: Span, exprs: Vec<Expr> }, +} -pub enum Directive { - Schedule { span : Span, command : Span, args : Vec<Span> }, - Partition { span : Span, labels : Vec<Span>, devices : Vec<Device> }, +pub struct MacroDecl { + pub name: Span, + pub params: Vec<Span>, + pub selection_name: Span, + pub def: Box<OperationList>, } diff --git a/juno_scheduler/src/lib.rs b/juno_scheduler/src/lib.rs index d515633eec468012ceee7305544b1859c1ebd621..1caafe4f0f3a0275130e47bb230f8a2047865cd6 100644 --- a/juno_scheduler/src/lib.rs +++ b/juno_scheduler/src/lib.rs @@ -6,47 +6,27 @@ use lrlex::DefaultLexerTypes; use lrpar::NonStreamingLexer; use hercules_ir::ir::*; +use juno_utils::env::Env; +use juno_utils::stringtab::StringTable; mod parser; use crate::parser::lexer; -// FunctionMap tracks a map from function numbers (produced by semantic analysis) to a tuple of -// - The map from label names to their numbers -// - The name of the function -// - A list of the instances of the function tracking -// + The instantiated type variables -// + The resulting FunctionID -// + A list of each label, tracking the structure at the label and a set of -// the labels which are its descendants -// + A map from NodeID to the innermost label containing it -// This is the core data structure provided from code generation, along with the -// module -pub type FunctionMap = HashMap< - usize, - ( - HashMap<String, usize>, - String, - Vec<( - Vec<TypeID>, - FunctionID, - Vec<(LabeledStructure, HashSet<usize>)>, - HashMap<NodeID, usize>, - )>, - ), ->; -// LabeledStructure represents structures from the source code and where they -// exist in the IR -#[derive(Copy, Clone)] -pub enum LabeledStructure { - Nothing(), - Expression(NodeID), - Loop(NodeID), // Header - Branch(NodeID), // If node -} - -/* -pub fn schedule(module: &Module, info: FunctionMap, schedule: String) -> Result<Vec<Plan>, String> { - if let Ok(mut file) = File::open(schedule) { +mod compile; +mod default; +mod ir; +pub mod labels; +mod pm; + +use crate::compile::*; +use crate::default::*; +use crate::ir::*; +use crate::labels::*; +use crate::pm::*; + +// Given a schedule's filename parse and process the schedule +fn build_schedule(sched_filename: String) -> Result<ScheduleStmt, String> { + if let Ok(mut file) = File::open(sched_filename) { let mut contents = String::new(); if let Ok(_) = file.read_to_string(&mut contents) { let lexerdef = lexer::lexerdef(); @@ -55,284 +35,113 @@ pub fn schedule(module: &Module, info: FunctionMap, schedule: String) -> Result< if errs.is_empty() { match res { - None => Err(format!("No parse errors, but no parsing failed")), + None => Err(format!("No parse errors, but parsing the schedule failed")), Some(schd) => { - let mut sched = generate_schedule(module, info, schd, &lexer)?; - let mut schedules = vec![]; - for i in 0..sched.len() { - schedules.push(sched.remove(&FunctionID::new(i)).unwrap()); - } - 
Ok(schedules) + compile_schedule(schd, &lexer).map_err(|e| format!("Schedule Error: {}", e)) } } } else { Err(errs .iter() - .map(|e| format!("Syntax Error: {}", e.pp(&lexer, &parser::token_epp))) + .map(|e| { + format!( + "Schedule Syntax Error: {}", + e.pp(&lexer, &parser::token_epp) + ) + }) .collect::<Vec<_>>() .join("\n")) } } else { - Err(format!("Unable to read input file")) + Err(format!("Unable to read schedule")) } } else { - Err(format!("Unable to open input file")) + Err(format!("Unable to open schedule")) } } -// a plan that tracks additional information useful while we construct the -// schedule -struct TempPlan { - schedules: Vec<Vec<Schedule>>, - // we track both the partition each node is in and what labeled caused us - // to assign that partition - partitions: Vec<(usize, PartitionNumber)>, - partition_devices: Vec<Vec<Device>>, -} -type PartitionNumber = usize; - -impl Into<Plan> for TempPlan { - fn into(self) -> Plan { - let num_partitions = self.partition_devices.len(); - Plan { - schedules: self.schedules, - partitions: self - .partitions - .into_iter() - .map(|(_, n)| PartitionID::new(n)) - .collect::<Vec<_>>(), - partition_devices: self - .partition_devices - .into_iter() - .map(|mut d| { - if d.len() != 1 { - panic!("Partition with multiple devices") - } else { - d.pop().unwrap() - } - }) - .collect::<Vec<_>>(), - num_partitions: num_partitions, - } +fn process_schedule(sched_filename: Option<String>) -> Result<ScheduleStmt, String> { + if let Some(name) = sched_filename { + build_schedule(name) + } else { + Ok(default_schedule()) } } -fn generate_schedule( - module: &Module, - info: FunctionMap, - schedule: Vec<parser::FuncDirectives>, - lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, -) -> Result<HashMap<FunctionID, Plan>, String> { - let mut res: HashMap<FunctionID, TempPlan> = HashMap::new(); - - // We initialize every node in every function as not having any schedule - // and being in the default partition which is a CPU-only partition - // (a result of label 0) - for (_, (_, _, func_insts)) in info.iter() { - for (_, func_id, _, _) in func_insts.iter() { - let num_nodes = module.functions[func_id.idx()].nodes.len(); - res.insert( - *func_id, - TempPlan { - schedules: vec![vec![]; num_nodes], - partitions: vec![(0, 0); num_nodes], - partition_devices: vec![vec![Device::CPU]], - }, - ); - } - } - - // Construct a map from function names to function numbers - let mut function_names: HashMap<String, usize> = HashMap::new(); - for (num, (_, nm, _)) in info.iter() { - function_names.insert(nm.clone(), *num); +pub fn schedule_juno( + module: Module, + juno_info: JunoInfo, + sched_filename: Option<String>, + output_dir: String, + module_name: String, +) -> Result<(), String> { + let sched = process_schedule(sched_filename)?; + + // Prepare the scheduler's string table and environment + // For this, we need to put all of the Juno functions into the environment + // and string table + let mut strings = StringTable::new(); + let mut env = Env::new(); + + env.open_scope(); + + let JunoInfo { + func_names, + func_info, + } = juno_info; + for (func_name, func_id) in func_names { + let func_name = strings.lookup_string(func_name); + env.insert(func_name, Value::JunoFunction { func: func_id }); } - // Make the map immutable - let function_names = function_names; - - for parser::FuncDirectives { - span: _, - func, - directives, - } in schedule - { - // Identify the function - let parser::Func { - span: _, - name: func_name, - } = func; - let name = 
lexer.span_str(func_name).to_string(); - let func_num = match function_names.get(&name) { - Some(num) => num, - None => { - return Err(format!("Function {} is undefined", name)); - } - }; - - // Identify label information - let (label_map, _, func_inst) = info.get(func_num).unwrap(); - let get_label_num = |label_span| { - let label_name = lexer.span_str(label_span).to_string(); - match label_map.get(&label_name) { - Some(num) => Ok(*num), - None => Err(format!("Label {} undefined in {}", label_name, name)), - } - }; - - // Process the partitioning and scheduling directives for each instance - // of the function - for (_, func_id, label_info, node_labels) in func_inst { - let func_info = res.get_mut(func_id).unwrap(); - for directive in &directives { - match directive { - parser::Directive::Partition { - span: _, - labels, - devices, - } => { - // Setup the new partition - let partition_num = func_info.partition_devices.len(); - let mut partition_devices = vec![]; - - for parser::Device { span: _, name } in devices { - let device_name = lexer.span_str(*name).to_string(); - if device_name == "cpu" { - partition_devices.push(Device::CPU); - } else if device_name == "gpu" { - partition_devices.push(Device::GPU); - } else { - return Err(format!("Invalid device {}", device_name)); - } - } - - func_info.partition_devices.push(partition_devices); - - for label in labels { - let label_num = get_label_num(*label)?; - let descendants = &label_info[label_num].1; - - node_labels - .iter() - .filter_map(|(node, label)| { - if *label == label_num || descendants.contains(label) { - Some(node.idx()) - } else { - None - } - }) - .for_each(|node| { - let node_part: &mut (usize, PartitionNumber) = - &mut func_info.partitions[node]; - if !descendants.contains(&node_part.0) { - *node_part = (label_num, partition_num); - } - }); - } - } - parser::Directive::Schedule { - span: _, - command, - args, - } => { - let command = lexer.span_str(*command).to_string(); - if command == "parallelize" { - for label in args { - let label_num = get_label_num(*label)?; - match label_info[label_num].0 { - LabeledStructure::Loop(header) => { - func_info.schedules[header.idx()] - .push(Schedule::ParallelReduce); - } - _ => { - return Err(format!( - "Cannot parallelize {}, not a loop", - lexer.span_str(*label) - )); - } - } - } - } else if command == "vectorize" { - for label in args { - let label_num = get_label_num(*label)?; - match label_info[label_num].0 { - LabeledStructure::Loop(header) => { - // FIXME: Take the factor as part of schedule - func_info.schedules[header.idx()] - .push(Schedule::Vectorizable(8)); - } - _ => { - return Err(format!( - "Cannot vectorize {}, not a loop", - lexer.span_str(*label) - )); - } - } - } - } else { - return Err(format!("Command {} undefined", command)); - } - } - } - } - /* - - /* - for parser::Command { span : _, name : command_name, - args : command_args } in commands.iter() { - if command_args.len() != 0 { todo!("Command arguments not supported") } - - let command = lexer.span_str(*command_name).to_string(); - if command == "cpu" || command == "gpu" { - let partition = res.get(func_id).unwrap() - .partition_devices.len(); - res.get_mut(func_id).unwrap().partition_devices.push( - if command == "cpu" { Device::CPU } - else { Device::GPU }); - - node_labels.iter() - .filter_map(|(node, label)| - if label_num == *label - || label_info[label_num].1.contains(&label) { - Some(node.idx()) - } else { - None - }) - .for_each(|node| { - let node_part : &mut (usize, PartitionNumber) = - &mut 
res.get_mut(func_id).unwrap().partitions[node]; - if !label_info[label_num].1.contains(&node_part.0) { - *node_part = (label_num, partition); - }}); - } else if command == "parallel" || command == "vectorize" { - match label_info[label_num].0 { - LabeledStructure::Loop(header) => { - res.get_mut(func_id).unwrap() - .schedules[header.idx()] - .push(if command == "parallel" { - Schedule::ParallelReduce - } else { - Schedule::Vectorize - }); - }, - _ => { - return Err(format!("Cannot parallelize, not a loop")); - }, - } - } else { - return Err(format!("Command {} undefined", command)); - } - } - */ + env.open_scope(); + schedule_codegen( + module, + sched, + strings, + env, + func_info, + output_dir, + module_name, + ) + .map_err(|e| format!("Scheduling Error: {}", e)) +} - func_info.partition_devices.push(partition_devices); - */ - } +pub fn schedule_hercules( + module: Module, + sched_filename: Option<String>, + output_dir: String, + module_name: String, +) -> Result<(), String> { + let sched = process_schedule(sched_filename)?; + + // Prepare the scheduler's string table and environment + // For this, we put all of the Hercules function names into the environment + // and string table + let mut strings = StringTable::new(); + let mut env = Env::new(); + + env.open_scope(); + + for (idx, func) in module.functions.iter().enumerate() { + let func_name = strings.lookup_string(func.name.clone()); + env.insert( + func_name, + Value::HerculesFunction { + func: FunctionID::new(idx), + }, + ); } - Ok(res - .into_iter() - .map(|(f, p)| (f, p.into())) - .collect::<HashMap<_, _>>()) + env.open_scope(); + schedule_codegen( + module, + sched, + strings, + env, + JunoFunctions { func_ids: vec![] }, + output_dir, + module_name, + ) + .map_err(|e| format!("Scheduling Error: {}", e)) } -*/ diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs new file mode 100644 index 0000000000000000000000000000000000000000..452c1995ee313e77d787ba7b769788c6e67b1271 --- /dev/null +++ b/juno_scheduler/src/pm.rs @@ -0,0 +1,1610 @@ +use crate::ir::*; +use crate::labels::*; +use hercules_cg::*; +use hercules_ir::*; +use hercules_opt::FunctionEditor; +use hercules_opt::{ + ccp, collapse_returns, crc, dce, dumb_outline, ensure_between_control_flow, float_collections, + fork_split, gcm, gvn, infer_parallel_fork, infer_parallel_reduce, infer_tight_associative, + infer_vectorizable, inline, interprocedural_sroa, lift_dc_math, outline, phi_elim, predication, + slf, sroa, unforkify, write_predication, +}; + +use tempfile::TempDir; + +use juno_utils::env::Env; +use juno_utils::stringtab::StringTable; + +use std::cell::RefCell; +use std::collections::{BTreeSet, HashMap, HashSet}; +use std::env::temp_dir; +use std::fmt; +use std::fs::File; +use std::io::Write; +use std::iter::zip; +use std::process::{Command, Stdio}; + +#[derive(Debug, Clone)] +pub enum Value { + Label { labels: Vec<LabelInfo> }, + JunoFunction { func: JunoFunctionID }, + HerculesFunction { func: FunctionID }, + Record { fields: HashMap<String, Value> }, + Everything {}, + Selection { selection: Vec<Value> }, + Integer { val: usize }, + Boolean { val: bool }, +} + +#[derive(Debug, Copy, Clone)] +enum CodeLocation { + Label(LabelInfo), + Function(FunctionID), +} + +impl Value { + fn is_everything(&self) -> bool { + match self { + Value::Everything {} => true, + _ => false, + } + } + + fn as_labels(&self) -> Result<Vec<LabelInfo>, SchedulerError> { + match self { + Value::Label { labels } => Ok(labels.clone()), + Value::Selection { selection } => { + let mut 
result = vec![]; + for val in selection { + result.extend(val.as_labels()?); + } + Ok(result) + } + Value::JunoFunction { .. } | Value::HerculesFunction { .. } => Err( + SchedulerError::SemanticError("Expected labels, found function".to_string()), + ), + Value::Record { .. } => Err(SchedulerError::SemanticError( + "Expected labels, found record".to_string(), + )), + Value::Everything {} => Err(SchedulerError::SemanticError( + "Expected labels, found everything".to_string(), + )), + Value::Integer { .. } => Err(SchedulerError::SemanticError( + "Expected labels, found integer".to_string(), + )), + Value::Boolean { .. } => Err(SchedulerError::SemanticError( + "Expected labels, found boolean".to_string(), + )), + } + } + + fn as_functions(&self, funcs: &JunoFunctions) -> Result<Vec<FunctionID>, SchedulerError> { + match self { + Value::JunoFunction { func } => Ok(funcs.get_function(*func).clone()), + Value::HerculesFunction { func } => Ok(vec![*func]), + Value::Selection { selection } => { + let mut result = vec![]; + for val in selection { + result.extend(val.as_functions(funcs)?); + } + Ok(result) + } + Value::Label { .. } => Err(SchedulerError::SemanticError( + "Expected functions, found label".to_string(), + )), + Value::Record { .. } => Err(SchedulerError::SemanticError( + "Expected functions, found record".to_string(), + )), + Value::Everything {} => Err(SchedulerError::SemanticError( + "Expected functions, found everything".to_string(), + )), + Value::Integer { .. } => Err(SchedulerError::SemanticError( + "Expected functions, found integer".to_string(), + )), + Value::Boolean { .. } => Err(SchedulerError::SemanticError( + "Expected functions, found boolean".to_string(), + )), + } + } + + fn as_locations(&self, funcs: &JunoFunctions) -> Result<Vec<CodeLocation>, SchedulerError> { + match self { + Value::Label { labels } => Ok(labels.iter().map(|l| CodeLocation::Label(*l)).collect()), + Value::JunoFunction { func } => Ok(funcs + .get_function(*func) + .iter() + .map(|f| CodeLocation::Function(*f)) + .collect()), + Value::HerculesFunction { func } => Ok(vec![CodeLocation::Function(*func)]), + Value::Selection { selection } => { + let mut result = vec![]; + for val in selection { + result.extend(val.as_locations(funcs)?); + } + Ok(result) + } + Value::Record { .. } => Err(SchedulerError::SemanticError( + "Expected code locations, found record".to_string(), + )), + Value::Everything {} => { + panic!("Internal error, check is_everything() before using as_locations()") + } + Value::Integer { .. } => Err(SchedulerError::SemanticError( + "Expected code locations, found integer".to_string(), + )), + Value::Boolean { .. 
} => Err(SchedulerError::SemanticError( + "Expected code locations, found boolean".to_string(), + )), + } + } +} + +#[derive(Debug, Clone)] +pub enum SchedulerError { + UndefinedVariable(String), + UndefinedField(String), + UndefinedLabel(String), + SemanticError(String), + PassError { pass: String, error: String }, + FixpointFailure(), +} + +impl fmt::Display for SchedulerError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SchedulerError::UndefinedVariable(nm) => write!(f, "Undefined variable '{}'", nm), + SchedulerError::UndefinedField(nm) => write!(f, "No field '{}'", nm), + SchedulerError::UndefinedLabel(nm) => write!(f, "No label '{}'", nm), + SchedulerError::SemanticError(msg) => write!(f, "{}", msg), + SchedulerError::PassError { pass, error } => { + write!(f, "Error in pass {}: {}", pass, error) + } + SchedulerError::FixpointFailure() => { + write!(f, "Fixpoint did not converge within limit") + } + } + } +} + +#[derive(Debug)] +struct PassManager { + functions: Vec<Function>, + types: RefCell<Vec<Type>>, + constants: RefCell<Vec<Constant>>, + dynamic_constants: RefCell<Vec<DynamicConstant>>, + labels: RefCell<Vec<String>>, + + // Cached analysis results. + pub def_uses: Option<Vec<ImmutableDefUseMap>>, + pub reverse_postorders: Option<Vec<Vec<NodeID>>>, + pub typing: Option<ModuleTyping>, + pub control_subgraphs: Option<Vec<Subgraph>>, + pub doms: Option<Vec<DomTree>>, + pub postdoms: Option<Vec<DomTree>>, + pub fork_join_maps: Option<Vec<HashMap<NodeID, NodeID>>>, + pub fork_join_nests: Option<Vec<HashMap<NodeID, Vec<NodeID>>>>, + pub loops: Option<Vec<LoopTree>>, + pub reduce_cycles: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, + pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, + pub bbs: Option<Vec<BasicBlocks>>, + pub collection_objects: Option<CollectionObjects>, + pub callgraph: Option<CallGraph>, +} + +impl PassManager { + fn new(module: Module) -> Self { + let Module { + functions, + types, + constants, + dynamic_constants, + labels, + } = module; + PassManager { + functions, + types: RefCell::new(types), + constants: RefCell::new(constants), + dynamic_constants: RefCell::new(dynamic_constants), + labels: RefCell::new(labels), + def_uses: None, + reverse_postorders: None, + typing: None, + control_subgraphs: None, + doms: None, + postdoms: None, + fork_join_maps: None, + fork_join_nests: None, + loops: None, + reduce_cycles: None, + data_nodes_in_fork_joins: None, + bbs: None, + collection_objects: None, + callgraph: None, + } + } + + pub fn make_def_uses(&mut self) { + if self.def_uses.is_none() { + self.def_uses = Some(self.functions.iter().map(def_use).collect()); + } + } + + pub fn make_reverse_postorders(&mut self) { + if self.reverse_postorders.is_none() { + self.make_def_uses(); + self.reverse_postorders = Some( + self.def_uses + .as_ref() + .unwrap() + .iter() + .map(reverse_postorder) + .collect(), + ); + } + } + + pub fn make_typing(&mut self) { + if self.typing.is_none() { + self.make_reverse_postorders(); + self.typing = Some( + typecheck( + &self.functions, + &mut self.types.borrow_mut(), + &self.constants.borrow(), + &mut self.dynamic_constants.borrow_mut(), + self.reverse_postorders.as_ref().unwrap(), + ) + .unwrap(), + ); + } + } + + pub fn make_control_subgraphs(&mut self) { + if self.control_subgraphs.is_none() { + self.make_def_uses(); + self.control_subgraphs = Some( + zip(&self.functions, self.def_uses.as_ref().unwrap()) + .map(|(function, def_use)| control_subgraph(function, def_use)) + 
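// One control subgraph is computed per function from its def-use map. + 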
.collect(), + ); + } + } + + pub fn make_doms(&mut self) { + if self.doms.is_none() { + self.make_control_subgraphs(); + self.doms = Some( + self.control_subgraphs + .as_ref() + .unwrap() + .iter() + .map(|subgraph| dominator(subgraph, NodeID::new(0))) + .collect(), + ); + } + } + + pub fn make_postdoms(&mut self) { + if self.postdoms.is_none() { + self.make_control_subgraphs(); + self.postdoms = Some( + zip( + self.control_subgraphs.as_ref().unwrap().iter(), + self.functions.iter(), + ) + .map(|(subgraph, function)| dominator(subgraph, NodeID::new(function.nodes.len()))) + .collect(), + ); + } + } + + pub fn make_fork_join_maps(&mut self) { + if self.fork_join_maps.is_none() { + self.make_control_subgraphs(); + self.fork_join_maps = Some( + zip( + self.functions.iter(), + self.control_subgraphs.as_ref().unwrap().iter(), + ) + .map(|(function, subgraph)| fork_join_map(function, subgraph)) + .collect(), + ); + } + } + + pub fn make_fork_join_nests(&mut self) { + if self.fork_join_nests.is_none() { + self.make_doms(); + self.make_fork_join_maps(); + self.fork_join_nests = Some( + zip( + self.functions.iter(), + zip( + self.doms.as_ref().unwrap().iter(), + self.fork_join_maps.as_ref().unwrap().iter(), + ), + ) + .map(|(function, (dom, fork_join_map))| { + compute_fork_join_nesting(function, dom, fork_join_map) + }) + .collect(), + ); + } + } + + pub fn make_loops(&mut self) { + if self.loops.is_none() { + self.make_control_subgraphs(); + self.make_doms(); + self.make_fork_join_maps(); + let control_subgraphs = self.control_subgraphs.as_ref().unwrap().iter(); + let doms = self.doms.as_ref().unwrap().iter(); + let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter(); + self.loops = Some( + zip(control_subgraphs, zip(doms, fork_join_maps)) + .map(|(control_subgraph, (dom, fork_join_map))| { + loops(control_subgraph, NodeID::new(0), dom, fork_join_map) + }) + .collect(), + ); + } + } + + pub fn make_reduce_cycles(&mut self) { + if self.reduce_cycles.is_none() { + self.make_def_uses(); + let def_uses = self.def_uses.as_ref().unwrap().iter(); + self.reduce_cycles = Some( + zip(self.functions.iter(), def_uses) + .map(|(function, def_use)| reduce_cycles(function, def_use)) + .collect(), + ); + } + } + + pub fn make_data_nodes_in_fork_joins(&mut self) { + if self.data_nodes_in_fork_joins.is_none() { + self.make_def_uses(); + self.make_fork_join_maps(); + self.data_nodes_in_fork_joins = Some( + zip( + self.functions.iter(), + zip( + self.def_uses.as_ref().unwrap().iter(), + self.fork_join_maps.as_ref().unwrap().iter(), + ), + ) + .map(|(function, (def_use, fork_join_map))| { + data_nodes_in_fork_joins(function, def_use, fork_join_map) + }) + .collect(), + ); + } + } + + pub fn make_collection_objects(&mut self) { + if self.collection_objects.is_none() { + self.make_reverse_postorders(); + self.make_typing(); + self.make_callgraph(); + let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); + let typing = self.typing.as_ref().unwrap(); + let callgraph = self.callgraph.as_ref().unwrap(); + self.collection_objects = Some(collection_objects( + &self.functions, + &self.types.borrow(), + reverse_postorders, + typing, + callgraph, + )); + } + } + + pub fn make_callgraph(&mut self) { + if self.callgraph.is_none() { + self.callgraph = Some(callgraph(&self.functions)); + } + } + + pub fn delete_gravestones(&mut self) { + for func in self.functions.iter_mut() { + func.delete_gravestones(); + } + } + + fn clear_analyses(&mut self) { + self.def_uses = None; + self.reverse_postorders = None; + 
self.typing = None; + self.control_subgraphs = None; + self.doms = None; + self.postdoms = None; + self.fork_join_maps = None; + self.fork_join_nests = None; + self.loops = None; + self.reduce_cycles = None; + self.data_nodes_in_fork_joins = None; + self.bbs = None; + self.collection_objects = None; + self.callgraph = None; + } + + fn with_mod<B, F>(&mut self, mut f: F) -> B + where + F: FnMut(&mut Module) -> B, + { + let mut module = Module { + functions: std::mem::take(&mut self.functions), + types: self.types.take(), + constants: self.constants.take(), + dynamic_constants: self.dynamic_constants.take(), + labels: self.labels.take(), + }; + + let res = f(&mut module); + + let Module { + functions, + types, + constants, + dynamic_constants, + labels, + } = module; + self.functions = functions; + self.types.replace(types); + self.constants.replace(constants); + self.dynamic_constants.replace(dynamic_constants); + self.labels.replace(labels); + + res + } + + fn codegen(mut self, output_dir: String, module_name: String) -> Result<(), SchedulerError> { + self.make_typing(); + self.make_control_subgraphs(); + self.make_collection_objects(); + self.make_callgraph(); + + let PassManager { + functions, + types, + constants, + dynamic_constants, + labels, + typing: Some(typing), + control_subgraphs: Some(control_subgraphs), + bbs: Some(bbs), + collection_objects: Some(collection_objects), + callgraph: Some(callgraph), + .. + } = self + else { + return Err(SchedulerError::PassError { + pass: "codegen".to_string(), + error: "Missing basic blocks".to_string(), + }); + }; + + let module = Module { + functions, + types: types.into_inner(), + constants: constants.into_inner(), + dynamic_constants: dynamic_constants.into_inner(), + labels: labels.into_inner(), + }; + + let devices = device_placement(&module.functions, &callgraph); + + let mut rust_rt = String::new(); + let mut llvm_ir = String::new(); + for idx in 0..module.functions.len() { + match devices[idx] { + Device::LLVM => cpu_codegen( + &module.functions[idx], + &module.types, + &module.constants, + &module.dynamic_constants, + &typing[idx], + &control_subgraphs[idx], + &bbs[idx], + &mut llvm_ir, + ) + .map_err(|e| SchedulerError::PassError { + pass: "cpu codegen".to_string(), + error: format!("{}", e), + })?, + Device::AsyncRust => rt_codegen( + FunctionID::new(idx), + &module, + &typing[idx], + &control_subgraphs[idx], + &bbs[idx], + &collection_objects, + &callgraph, + &devices, + &mut rust_rt, + ) + .map_err(|e| SchedulerError::PassError { + pass: "rust codegen".to_string(), + error: format!("{}", e), + })?, + _ => todo!(), + } + } + println!("{}", llvm_ir); + println!("{}", rust_rt); + + // Write the LLVM IR into a temporary file. + let tmp_dir = TempDir::new().unwrap(); + let mut tmp_path = tmp_dir.path().to_path_buf(); + tmp_path.push(format!("{}.ll", module_name)); + println!("{}", tmp_path.display()); + let mut file = File::create(&tmp_path).expect("PANIC: Unable to open output LLVM IR file."); + file.write_all(llvm_ir.as_bytes()) + .expect("PANIC: Unable to write output LLVM IR file contents."); + + // Compile LLVM IR into an ELF object file. + let output_archive = format!("{}/lib{}.a", output_dir, module_name); + let mut clang_process = Command::new("clang") + .arg(&tmp_path) + .arg("--emit-static-lib") + .arg("-O3") + .arg("-march=native") + .arg("-o") + .arg(&output_archive) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .spawn() + .expect("Error running clang. 
Is it installed?"); + assert!(clang_process.wait().unwrap().success()); + + // Write the Rust runtime into a file. + let output_rt = format!("{}/rt_{}.hrt", output_dir, module_name); + println!("{}", output_rt); + let mut file = + File::create(&output_rt).expect("PANIC: Unable to open output Rust runtime file."); + file.write_all(rust_rt.as_bytes()) + .expect("PANIC: Unable to write output Rust runtime file contents."); + + Ok(()) + } +} + +pub fn schedule_codegen( + mut module: Module, + schedule: ScheduleStmt, + mut stringtab: StringTable, + mut env: Env<usize, Value>, + functions: JunoFunctions, + output_dir: String, + module_name: String, +) -> Result<(), SchedulerError> { + let mut pm = PassManager::new(module); + let _ = schedule_interpret(&mut pm, &schedule, &mut stringtab, &mut env, &functions)?; + pm.codegen(output_dir, module_name) +} + +// Interpreter for statements and expressions returns a bool indicating whether +// any optimization ran and changed the IR. This is used for implementing +// the fixpoint +fn schedule_interpret( + pm: &mut PassManager, + schedule: &ScheduleStmt, + stringtab: &mut StringTable, + env: &mut Env<usize, Value>, + functions: &JunoFunctions, +) -> Result<bool, SchedulerError> { + match schedule { + ScheduleStmt::Fixpoint { body, limit } => { + let mut i = 0; + loop { + // If no change was made, we've reached the fixpoint and are done + if !schedule_interpret(pm, body, stringtab, env, functions)? { + break; + } + // Otherwise, increase the iteration count and check the limit + i += 1; + match limit { + FixpointLimit::NoLimit() => {} + FixpointLimit::PrintIter() => { + println!("Finished Iteration {}", i - 1) + } + FixpointLimit::StopAfter(n) => { + if i >= *n { + break; + } + } + FixpointLimit::PanicAfter(n) => { + if i >= *n { + return Err(SchedulerError::FixpointFailure()); + } + } + } + } + // If we ran just 1 iteration, then no changes were made and otherwise some changes + // were made + Ok(i > 1) + } + ScheduleStmt::Block { body } => { + let mut modified = false; + env.open_scope(); + for command in body { + modified |= schedule_interpret(pm, command, stringtab, env, functions)?; + } + env.close_scope(); + Ok(modified) + } + ScheduleStmt::Let { var, exp } => { + let (res, modified) = interp_expr(pm, exp, stringtab, env, functions)?; + let var_id = stringtab.lookup_string(var.clone()); + env.insert(var_id, res); + Ok(modified) + } + ScheduleStmt::Assign { var, exp } => { + let (res, modified) = interp_expr(pm, exp, stringtab, env, functions)?; + let var_id = stringtab.lookup_string(var.clone()); + match env.lookup_mut(&var_id) { + None => { + return Err(SchedulerError::UndefinedVariable(var.clone())); + } + Some(val) => { + *val = res; + } + } + Ok(modified) + } + ScheduleStmt::AddSchedule { sched, on } => match on { + Selector::Everything() => Err(SchedulerError::SemanticError( + "Cannot apply schedule to everything".to_string(), + )), + Selector::Selection(selection) => { + let mut changed = false; + for label in selection { + let (label, modified) = interp_expr(pm, label, stringtab, env, functions)?; + changed |= modified; + add_schedule(pm, sched.clone(), label.as_labels()?); + } + Ok(changed) + } + }, + ScheduleStmt::AddDevice { device, on } => match on { + Selector::Everything() => { + for func in pm.functions.iter_mut() { + func.device = Some(device.clone()); + } + Ok(false) + } + Selector::Selection(selection) => { + let mut changed = false; + for func in selection { + let (func, modified) = interp_expr(pm, func, stringtab, env, 
functions)?; + changed |= modified; + add_device(pm, device.clone(), func.as_functions(functions)?); + } + Ok(changed) + } + }, + } +} + +fn interp_expr( + pm: &mut PassManager, + expr: &ScheduleExp, + stringtab: &mut StringTable, + env: &mut Env<usize, Value>, + functions: &JunoFunctions, +) -> Result<(Value, bool), SchedulerError> { + match expr { + ScheduleExp::Variable { var } => { + let var_id = stringtab.lookup_string(var.clone()); + match env.lookup(&var_id) { + None => Err(SchedulerError::UndefinedVariable(var.clone())), + Some(v) => Ok((v.clone(), false)), + } + } + ScheduleExp::Integer { val } => Ok((Value::Integer { val: *val }, false)), + ScheduleExp::Boolean { val } => Ok((Value::Boolean { val: *val }, false)), + ScheduleExp::Field { collect, field } => { + let (lhs, changed) = interp_expr(pm, collect, stringtab, env, functions)?; + match lhs { + Value::Label { .. } + | Value::Selection { .. } + | Value::Everything { .. } + | Value::Integer { .. } + | Value::Boolean { .. } => Err(SchedulerError::UndefinedField(field.clone())), + Value::JunoFunction { func } => { + match pm.labels.borrow().iter().position(|s| s == field) { + None => Err(SchedulerError::UndefinedLabel(field.clone())), + Some(label_idx) => Ok(( + Value::Label { + labels: functions + .get_function(func) + .iter() + .map(|f| LabelInfo { + func: *f, + label: LabelID::new(label_idx), + }) + .collect(), + }, + changed, + )), + } + } + Value::HerculesFunction { func } => { + match pm.labels.borrow().iter().position(|s| s == field) { + None => Err(SchedulerError::UndefinedLabel(field.clone())), + Some(label_idx) => Ok(( + Value::Label { + labels: vec![LabelInfo { + func: func, + label: LabelID::new(label_idx), + }], + }, + changed, + )), + } + } + Value::Record { fields } => match fields.get(field) { + None => Err(SchedulerError::UndefinedField(field.clone())), + Some(v) => Ok((v.clone(), changed)), + }, + } + } + ScheduleExp::RunPass { pass, args, on } => { + let mut changed = false; + let mut arg_vals = vec![]; + for arg in args { + let (val, modified) = interp_expr(pm, arg, stringtab, env, functions)?; + arg_vals.push(val); + changed |= modified; + } + + let selection = match on { + Selector::Everything() => None, + Selector::Selection(selection) => { + let mut locs = vec![]; + let mut everything = false; + for loc in selection { + let (val, modified) = interp_expr(pm, loc, stringtab, env, functions)?; + changed |= modified; + if val.is_everything() { + everything = true; + break; + } + locs.extend(val.as_locations(functions)?); + } + if everything { + None + } else { + Some(locs) + } + } + }; + + let (res, modified) = run_pass(pm, *pass, arg_vals, selection)?; + changed |= modified; + Ok((res, changed)) + } + ScheduleExp::Record { fields } => { + let mut result = HashMap::new(); + let mut changed = false; + for (field, val) in fields { + let (val, modified) = interp_expr(pm, val, stringtab, env, functions)?; + result.insert(field.clone(), val); + changed |= modified; + } + Ok((Value::Record { fields: result }, changed)) + } + ScheduleExp::Block { body, res } => { + let mut changed = false; + + env.open_scope(); + for command in body { + changed |= schedule_interpret(pm, command, stringtab, env, functions)?; + } + let (res, modified) = interp_expr(pm, res, stringtab, env, functions)?; + env.close_scope(); + + Ok((res, changed || modified)) + } + ScheduleExp::Selection { selection } => match selection { + Selector::Everything() => Ok((Value::Everything {}, false)), + Selector::Selection(selection) => { + let mut 
values = vec![]; + let mut changed = false; + for e in selection { + let (val, modified) = interp_expr(pm, e, stringtab, env, functions)?; + values.push(val); + changed |= modified; + } + Ok((Value::Selection { selection: values }, changed)) + } + }, + } +} + +fn add_schedule(pm: &mut PassManager, sched: Schedule, label_ids: Vec<LabelInfo>) { + for LabelInfo { func, label } in label_ids { + let nodes = pm.functions[func.idx()] + .labels + .iter() + .enumerate() + .filter(|(_, ls)| ls.contains(&label)) + .map(|(i, _)| i) + .collect::<Vec<_>>(); + for node in nodes { + pm.functions[func.idx()].schedules[node].push(sched.clone()); + } + } +} + +fn add_device(pm: &mut PassManager, device: Device, funcs: Vec<FunctionID>) { + for func in funcs { + pm.functions[func.idx()].device = Some(device.clone()); + } +} + +#[derive(Debug, Clone)] +enum FunctionSelection { + Nothing(), + Everything(), + Labels(HashSet<LabelID>), +} + +impl FunctionSelection { + fn add_label(&mut self, label: LabelID) { + match self { + FunctionSelection::Nothing() => { + *self = FunctionSelection::Labels(HashSet::from([label])); + } + FunctionSelection::Everything() => {} + FunctionSelection::Labels(set) => { + set.insert(label); + } + } + } + + fn add_everything(&mut self) { + *self = FunctionSelection::Everything(); + } +} + +fn build_editors<'a>(pm: &'a mut PassManager) -> Vec<FunctionEditor<'a>> { + pm.make_def_uses(); + let def_uses = pm.def_uses.take().unwrap(); + pm.functions + .iter_mut() + .zip(def_uses.iter()) + .enumerate() + .map(|(idx, (func, def_use))| { + FunctionEditor::new( + func, + FunctionID::new(idx), + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + def_use, + ) + }) + .collect() +} + +// Process a selection to identify which labels in which functions are selected +fn construct_selection(pm: &PassManager, selection: Vec<CodeLocation>) -> Vec<FunctionSelection> { + let mut selected = vec![FunctionSelection::Nothing(); pm.functions.len()]; + for loc in selection { + match loc { + CodeLocation::Label(label) => selected[label.func.idx()].add_label(label.label), + CodeLocation::Function(func) => selected[func.idx()].add_everything(), + } + } + selected +} + +// Given a selection, constructs the set of selected functions (each of which must be +// selected in full) +fn selection_of_functions( + pm: &PassManager, + selection: Option<Vec<CodeLocation>>, +) -> Option<Vec<FunctionID>> { + if let Some(selection) = selection { + let selection = construct_selection(pm, selection); + + let mut result = vec![]; + + for (idx, selected) in selection.into_iter().enumerate() { + match selected { + FunctionSelection::Nothing() => {} + FunctionSelection::Everything() => result.push(FunctionID::new(idx)), + FunctionSelection::Labels(_) => { + return None; + } + } + } + + Some(result) + } else { + Some( + pm.functions + .iter() + .enumerate() + .map(|(i, _)| FunctionID::new(i)) + .collect(), + ) + } +} + +// Given a selection, constructs the set of nodes selected within a single function, returning +// that function's id +fn selection_as_set( + pm: &PassManager, + selection: Option<Vec<CodeLocation>>, +) -> Option<(BTreeSet<NodeID>, FunctionID)> { + if let Some(selection) = selection { + let selection = construct_selection(pm, selection); + let mut result = None; + + for (idx, (selected, func)) in selection.into_iter().zip(pm.functions.iter()).enumerate() { + match selected { + FunctionSelection::Nothing() => {} + FunctionSelection::Everything() => match result { + Some(_) => { + 
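// Some other function was already selected, so this is not a single-function selection. + 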
return None; + } + None => { + result = Some(( + (0..func.nodes.len()).map(|i| NodeID::new(i)).collect(), + FunctionID::new(idx), + )); + } + }, + FunctionSelection::Labels(labels) => match result { + Some(_) => { + return None; + } + None => { + result = Some(( + (0..func.nodes.len()) + .filter(|i| !func.labels[*i].is_disjoint(&labels)) + .map(|i| NodeID::new(i)) + .collect(), + FunctionID::new(idx), + )); + } + }, + } + } + + result + } else { + None + } +} + +fn build_selection<'a>( + pm: &'a mut PassManager, + selection: Option<Vec<CodeLocation>>, +) -> Vec<Option<FunctionEditor<'a>>> { + // Build def uses, which are needed for the editors we'll construct + pm.make_def_uses(); + let def_uses = pm.def_uses.take().unwrap(); + + if let Some(selection) = selection { + let selected = construct_selection(pm, selection); + + pm.functions + .iter_mut() + .zip(selected.iter()) + .zip(def_uses.iter()) + .enumerate() + .map(|(idx, ((func, selected), def_use))| match selected { + FunctionSelection::Nothing() => None, + FunctionSelection::Everything() => Some(FunctionEditor::new( + func, + FunctionID::new(idx), + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + def_use, + )), + FunctionSelection::Labels(labels) => Some(FunctionEditor::new_labeled( + func, + FunctionID::new(idx), + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + def_use, + labels, + )), + }) + .collect() + } else { + build_editors(pm) + .into_iter() + .map(|func| Some(func)) + .collect() + } +} + +fn run_pass( + pm: &mut PassManager, + pass: Pass, + args: Vec<Value>, + selection: Option<Vec<CodeLocation>>, +) -> Result<(Value, bool), SchedulerError> { + let mut result = Value::Record { + fields: HashMap::new(), + }; + let mut changed = false; + + match pass { + Pass::AutoOutline => { + let Some(funcs) = selection_of_functions(pm, selection) else { + return Err(SchedulerError::PassError { + pass: "autoOutline".to_string(), + error: "must be applied to whole functions".to_string(), + }); + }; + + pm.make_def_uses(); + let def_uses = pm.def_uses.take().unwrap(); + + for func in funcs.iter() { + let mut editor = FunctionEditor::new( + &mut pm.functions[func.idx()], + *func, + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + &def_uses[func.idx()], + ); + collapse_returns(&mut editor); + ensure_between_control_flow(&mut editor); + changed |= editor.modified(); + } + pm.clear_analyses(); + + pm.make_def_uses(); + pm.make_typing(); + pm.make_control_subgraphs(); + pm.make_doms(); + + let def_uses = pm.def_uses.take().unwrap(); + let typing = pm.typing.take().unwrap(); + let control_subgraphs = pm.control_subgraphs.take().unwrap(); + let doms = pm.doms.take().unwrap(); + let old_num_funcs = pm.functions.len(); + + let mut new_funcs = vec![]; + // Track the names of the old functions and the new function IDs for returning + let mut new_func_ids = HashMap::new(); + + for func in funcs { + let mut editor = FunctionEditor::new( + &mut pm.functions[func.idx()], + func, + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + &def_uses[func.idx()], + ); + + let new_func_id = FunctionID::new(old_num_funcs + new_funcs.len()); + + let new_func = dumb_outline( + &mut editor, + &typing[func.idx()], + &control_subgraphs[func.idx()], + &doms[func.idx()], + new_func_id, + ); + changed |= editor.modified(); + + if let Some(new_func) = new_func { + new_func_ids.insert( + editor.func().name.clone(), + Value::HerculesFunction { func: new_func_id }, + ); + new_funcs.push(new_func); + } 
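+ // Outlining deleted the region's nodes from the original function; those gravestones are flushed below. 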
+ + pm.functions[func.idx()].delete_gravestones(); + } + + pm.functions.extend(new_funcs); + pm.clear_analyses(); + + result = Value::Record { + fields: new_func_ids, + }; + } + Pass::CCP => { + assert!(args.is_empty()); + pm.make_reverse_postorders(); + let reverse_postorders = pm.reverse_postorders.take().unwrap(); + for (func, reverse_postorder) in build_selection(pm, selection) + .into_iter() + .zip(reverse_postorders.iter()) + { + let Some(mut func) = func else { + continue; + }; + ccp(&mut func, reverse_postorder); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::CRC => { + assert!(args.is_empty()); + for func in build_selection(pm, selection) { + let Some(mut func) = func else { + continue; + }; + crc(&mut func); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::DCE => { + assert!(args.is_empty()); + for func in build_selection(pm, selection) { + let Some(mut func) = func else { + continue; + }; + dce(&mut func); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::DeleteUncalled => { + todo!("Delete Uncalled changes FunctionIDs, a bunch of bookkeeping is needed for the pass manager to address this") + } + Pass::FloatCollections => { + assert!(args.is_empty()); + if let Some(_) = selection { + return Err(SchedulerError::PassError { + pass: "floatCollections".to_string(), + error: "must be applied to the entire module".to_string(), + }); + } + + pm.make_typing(); + pm.make_callgraph(); + let typing = pm.typing.take().unwrap(); + let callgraph = pm.callgraph.take().unwrap(); + + let devices = device_placement(&pm.functions, &callgraph); + + let mut editors = build_editors(pm); + float_collections(&mut editors, &typing, &callgraph, &devices); + + for func in editors { + changed |= func.modified(); + } + + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::ForkGuardElim => { + todo!("Fork Guard Elim doesn't use editor") + } + Pass::ForkSplit => { + assert!(args.is_empty()); + pm.make_fork_join_maps(); + pm.make_reduce_cycles(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let reduce_cycles = pm.reduce_cycles.take().unwrap(); + for ((func, fork_join_map), reduce_cycles) in build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(reduce_cycles.iter()) + { + let Some(mut func) = func else { + continue; + }; + fork_split(&mut func, fork_join_map, reduce_cycles); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::Forkify => { + todo!("Forkify doesn't use editor") + } + Pass::GCM => { + assert!(args.is_empty()); + if let Some(_) = selection { + return Err(SchedulerError::PassError { + pass: "gcm".to_string(), + error: "must be applied to the entire module".to_string(), + }); + } + + loop { + pm.make_def_uses(); + pm.make_reverse_postorders(); + pm.make_typing(); + pm.make_control_subgraphs(); + pm.make_doms(); + pm.make_fork_join_maps(); + pm.make_loops(); + pm.make_collection_objects(); + + let def_uses = pm.def_uses.take().unwrap(); + let reverse_postorders = pm.reverse_postorders.take().unwrap(); + let typing = pm.typing.take().unwrap(); + let doms = pm.doms.take().unwrap(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let loops = pm.loops.take().unwrap(); + let control_subgraphs = pm.control_subgraphs.take().unwrap(); + let collection_objects = pm.collection_objects.take().unwrap(); + + let mut bbs = vec![]; + + for ( + ( + ( + ((((mut func, 
def_use), reverse_postorder), typing), control_subgraph), + doms, + ), + fork_join_map, + ), + loops, + ) in build_editors(pm) + .into_iter() + .zip(def_uses.iter()) + .zip(reverse_postorders.iter()) + .zip(typing.iter()) + .zip(control_subgraphs.iter()) + .zip(doms.iter()) + .zip(fork_join_maps.iter()) + .zip(loops.iter()) + { + if let Some(bb) = gcm( + &mut func, + def_use, + reverse_postorder, + typing, + control_subgraph, + doms, + fork_join_map, + loops, + &collection_objects, + ) { + bbs.push(bb); + } + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + if bbs.len() == pm.functions.len() { + pm.bbs = Some(bbs); + break; + } + } + } + Pass::GVN => { + assert!(args.is_empty()); + for func in build_selection(pm, selection) { + let Some(mut func) = func else { + continue; + }; + gvn(&mut func, false); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::InferSchedules => { + assert!(args.is_empty()); + pm.make_fork_join_maps(); + pm.make_reduce_cycles(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let reduce_cycles = pm.reduce_cycles.take().unwrap(); + for ((func, fork_join_map), reduce_cycles) in build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(reduce_cycles.iter()) + { + let Some(mut func) = func else { + continue; + }; + infer_parallel_reduce(&mut func, fork_join_map, reduce_cycles); + infer_parallel_fork(&mut func, fork_join_map); + infer_vectorizable(&mut func, fork_join_map); + infer_tight_associative(&mut func, reduce_cycles); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::Inline => { + assert!(args.is_empty()); + if let Some(_) = selection { + return Err(SchedulerError::PassError { + pass: "inline".to_string(), + error: "must be applied to the entire module (currently)".to_string(), + }); + } + + pm.make_callgraph(); + let callgraph = pm.callgraph.take().unwrap(); + + let mut editors = build_editors(pm); + inline(&mut editors, &callgraph); + + for func in editors { + changed |= func.modified(); + } + + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::InterproceduralSROA => { + assert!(args.is_empty()); + if let Some(_) = selection { + return Err(SchedulerError::PassError { + pass: "interproceduralSROA".to_string(), + error: "must be applied to the entire module".to_string(), + }); + } + + let mut editors = build_editors(pm); + interprocedural_sroa(&mut editors); + + for func in editors { + changed |= func.modified(); + } + + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::LiftDCMath => { + assert!(args.is_empty()); + for func in build_selection(pm, selection) { + let Some(mut func) = func else { + continue; + }; + lift_dc_math(&mut func); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::Outline => { + let Some((nodes, func)) = selection_as_set(pm, selection) else { + return Err(SchedulerError::PassError { + pass: "outline".to_string(), + error: "must be applied to nodes in a single function".to_string(), + }); + }; + + pm.make_def_uses(); + let def_uses = pm.def_uses.take().unwrap(); + + let mut editor = FunctionEditor::new( + &mut pm.functions[func.idx()], + func, + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + &def_uses[func.idx()], + ); + + collapse_returns(&mut editor); + ensure_between_control_flow(&mut editor); + pm.clear_analyses(); + + pm.make_def_uses(); + pm.make_typing(); + pm.make_control_subgraphs(); + 
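// Dominators are needed by outline() below. + 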
pm.make_doms(); + + let def_uses = pm.def_uses.take().unwrap(); + let typing = pm.typing.take().unwrap(); + let control_subgraphs = pm.control_subgraphs.take().unwrap(); + let doms = pm.doms.take().unwrap(); + let new_func_id = FunctionID::new(pm.functions.len()); + + let mut editor = FunctionEditor::new( + &mut pm.functions[func.idx()], + func, + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + &def_uses[func.idx()], + ); + + let new_func = outline( + &mut editor, + &typing[func.idx()], + &control_subgraphs[func.idx()], + &doms[func.idx()], + &nodes, + new_func_id, + ); + let Some(new_func) = new_func else { + return Err(SchedulerError::PassError { + pass: "outlining".to_string(), + error: "failed to outline".to_string(), + }); + }; + + pm.functions.push(new_func); + changed = true; + pm.functions[func.idx()].delete_gravestones(); + pm.clear_analyses(); + + result = Value::HerculesFunction { func: new_func_id }; + } + Pass::PhiElim => { + assert!(args.is_empty()); + for func in build_selection(pm, selection) { + let Some(mut func) = func else { + continue; + }; + phi_elim(&mut func); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::Predication => { + assert!(args.is_empty()); + pm.make_typing(); + let typing = pm.typing.take().unwrap(); + + for (func, types) in build_selection(pm, selection) + .into_iter() + .zip(typing.iter()) + { + let Some(mut func) = func else { + continue; + }; + predication(&mut func, types); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::SLF => { + assert!(args.is_empty()); + pm.make_reverse_postorders(); + pm.make_typing(); + let reverse_postorders = pm.reverse_postorders.take().unwrap(); + let typing = pm.typing.take().unwrap(); + + for ((func, reverse_postorder), types) in build_selection(pm, selection) + .into_iter() + .zip(reverse_postorders.iter()) + .zip(typing.iter()) + { + let Some(mut func) = func else { + continue; + }; + slf(&mut func, reverse_postorder, types); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::SROA => { + assert!(args.is_empty()); + pm.make_reverse_postorders(); + pm.make_typing(); + let reverse_postorders = pm.reverse_postorders.take().unwrap(); + let typing = pm.typing.take().unwrap(); + + for ((func, reverse_postorder), types) in build_selection(pm, selection) + .into_iter() + .zip(reverse_postorders.iter()) + .zip(typing.iter()) + { + let Some(mut func) = func else { + continue; + }; + sroa(&mut func, reverse_postorder, types); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::Unforkify => { + assert!(args.is_empty()); + pm.make_fork_join_maps(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + + for (func, fork_join_map) in build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) + { + let Some(mut func) = func else { + continue; + }; + unforkify(&mut func, fork_join_map); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::WritePredication => { + assert!(args.is_empty()); + for func in build_selection(pm, selection) { + let Some(mut func) = func else { + continue; + }; + write_predication(&mut func); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::Verify => { + assert!(args.is_empty()); + let (def_uses, reverse_postorders, typing, subgraphs, doms, postdoms, fork_join_maps) = + pm.with_mod(|module| 
verify(module)) + .map_err(|msg| SchedulerError::PassError { + pass: "verify".to_string(), + error: format!("failed: {}", msg), + })?; + + // Verification produces a bunch of analysis results that + // may be useful for later passes. + pm.def_uses = Some(def_uses); + pm.reverse_postorders = Some(reverse_postorders); + pm.typing = Some(typing); + pm.control_subgraphs = Some(subgraphs); + pm.doms = Some(doms); + pm.postdoms = Some(postdoms); + pm.fork_join_maps = Some(fork_join_maps); + } + Pass::Xdot => { + let force_analyses = match args.get(0) { + Some(Value::Boolean { val }) => *val, + Some(_) => { + return Err(SchedulerError::PassError { + pass: "xdot".to_string(), + error: "expected boolean argument".to_string(), + }); + } + None => true, + }; + + pm.make_reverse_postorders(); + if force_analyses { + pm.make_doms(); + pm.make_fork_join_maps(); + } + + let reverse_postorders = pm.reverse_postorders.take().unwrap(); + let doms = pm.doms.take(); + let fork_join_maps = pm.fork_join_maps.take(); + let bbs = pm.bbs.take(); + pm.with_mod(|module| { + xdot_module( + module, + &reverse_postorders, + doms.as_ref(), + fork_join_maps.as_ref(), + bbs.as_ref(), + ) + }); + + // Put BasicBlocks back, since it's needed for Codegen. + pm.bbs = bbs; + } + } + println!("Ran Pass: {:?}", pass); + + Ok((result, changed)) +} diff --git a/juno_utils/.gitignore b/juno_utils/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ef5f7e557e275015cb02004a1f69e6b98ae8f2f0 --- /dev/null +++ b/juno_utils/.gitignore @@ -0,0 +1,4 @@ +*.aux +*.log +*.out +*.pdf diff --git a/juno_utils/Cargo.toml b/juno_utils/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..8de3b6518f334b1e75b5a7388a9840531e37a938 --- /dev/null +++ b/juno_utils/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "juno_utils" +version = "0.1.0" +authors = ["Aaron Councilman <aaronjc4@illinois.edu>"] +edition = "2021" + +[lib] +name = "juno_utils" +path = "src/lib.rs" + +[dependencies] +serde = { version = "*", features = ["derive"] } diff --git a/juno_frontend/src/env.rs b/juno_utils/src/env.rs similarity index 91% rename from juno_frontend/src/env.rs rename to juno_utils/src/env.rs index fb746045c075ac3bb4b89a1880759c584a4014a5..cfa84b7875be3f5154cf4051d136ed234d75cd39 100644 --- a/juno_frontend/src/env.rs +++ b/juno_utils/src/env.rs @@ -24,6 +24,13 @@ impl<K: Eq + Hash + Copy, V> Env<K, V> { } } + pub fn lookup_mut(&mut self, k: &K) -> Option<&mut V> { + match self.table.get_mut(k) { + None => None, + Some(l) => l.last_mut(), + } + } + pub fn insert(&mut self, k: K, v: V) { if self.scope[self.scope.len() - 1].contains(&k) { match self.table.get_mut(&k) { diff --git a/juno_utils/src/lib.rs b/juno_utils/src/lib.rs new file mode 100644 index 0000000000000000000000000000000000000000..56b404bec65dcc18058e40d2f211fcc2643d9785 --- /dev/null +++ b/juno_utils/src/lib.rs @@ -0,0 +1,2 @@ +pub mod env; +pub mod stringtab; diff --git a/juno_utils/src/stringtab.rs b/juno_utils/src/stringtab.rs new file mode 100644 index 0000000000000000000000000000000000000000..e151b830d7b51d3baa4f64ae32214afca81a9eab --- /dev/null +++ b/juno_utils/src/stringtab.rs @@ -0,0 +1,48 @@ +extern crate serde; + +use self::serde::{Deserialize, Serialize}; + +use std::collections::HashMap; + +// Map strings to unique identifiers and counts uids +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct StringTable { + count: usize, + string_to_index: HashMap<String, usize>, + index_to_string: HashMap<usize, String>, +} + +impl StringTable 
{ + pub fn new() -> StringTable { + StringTable { + count: 0, + string_to_index: HashMap::new(), + index_to_string: HashMap::new(), + } + } + + // Produce the UID for a string + pub fn lookup_string(&mut self, s: String) -> usize { + match self.string_to_index.get(&s) { + Some(n) => *n, + None => { + let n = self.count; + self.count += 1; + self.string_to_index.insert(s.clone(), n); + self.index_to_string.insert(n, s); + n + } + } + } + + // Identify the string corresponding to a UID + pub fn lookup_id(&self, n: usize) -> Option<String> { + self.index_to_string.get(&n).cloned() + } +} + +impl Default for StringTable { + fn default() -> Self { + StringTable::new() + } +}
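+ +// A short usage sketch for StringTable (a hypothetical test, assuming only the API +// above): interning a string yields a dense, stable UID, and lookup_id maps a UID back +// to its string. +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn uids_are_stable_and_reversible() { + let mut tab = StringTable::new(); + let foo = tab.lookup_string("foo".to_string()); + let bar = tab.lookup_string("bar".to_string()); + // Re-interning an existing string returns its original UID. + assert_eq!(foo, tab.lookup_string("foo".to_string())); + // UIDs are handed out densely, starting from zero. + assert_eq!((foo, bar), (0, 1)); + // lookup_id inverts lookup_string; unknown UIDs yield None. + assert_eq!(tab.lookup_id(bar), Some("bar".to_string())); + assert_eq!(tab.lookup_id(2), None); + } +}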