From 494d793e3e0f8df04c46a0ee9b3b8c5db72651aa Mon Sep 17 00:00:00 2001 From: Aaron Councilman <aaronjc4@illinois.edu> Date: Thu, 23 Jan 2025 10:12:11 -0600 Subject: [PATCH] Scheduler --- Cargo.lock | 55 +- Cargo.toml | 7 +- hercules_ir/src/build.rs | 39 +- hercules_ir/src/callgraph.rs | 5 +- hercules_ir/src/collections.rs | 15 +- hercules_ir/src/ir.rs | 14 +- hercules_ir/src/parse.rs | 5 +- hercules_ir/src/typecheck.rs | 11 +- hercules_ir/src/verify.rs | 8 +- hercules_opt/src/editor.rs | 170 ++- hercules_opt/src/inline.rs | 5 + hercules_opt/src/lib.rs | 2 - hercules_opt/src/outline.rs | 10 +- hercules_opt/src/pass.rs | 1132 -------------- hercules_opt/src/utils.rs | 4 +- hercules_tools/hercules_driver/Cargo.toml | 12 - hercules_tools/hercules_driver/src/main.rs | 58 - juno_build/src/lib.rs | 52 +- juno_frontend/Cargo.toml | 1 + juno_frontend/src/codegen.rs | 125 +- juno_frontend/src/labeled_builder.rs | 98 +- juno_frontend/src/lib.rs | 138 +- juno_frontend/src/main.rs | 35 +- juno_frontend/src/semant.rs | 116 +- juno_frontend/src/ssa.rs | 2 +- juno_samples/schedule_test/Cargo.toml | 19 + juno_samples/schedule_test/build.rs | 11 + juno_samples/schedule_test/src/code.jn | 30 + juno_samples/schedule_test/src/main.rs | 42 + juno_samples/schedule_test/src/sched.sch | 43 + juno_scheduler/Cargo.toml | 4 + juno_scheduler/src/compile.rs | 495 +++++++ juno_scheduler/src/default.rs | 81 + juno_scheduler/src/ir.rs | 114 ++ juno_scheduler/src/labels.rs | 61 + juno_scheduler/src/lang.l | 33 +- juno_scheduler/src/lang.y | 179 ++- juno_scheduler/src/lib.rs | 397 ++--- juno_scheduler/src/pm.rs | 1564 ++++++++++++++++++++ juno_utils/.gitignore | 4 + juno_utils/Cargo.toml | 12 + {juno_frontend => juno_utils}/src/env.rs | 7 + juno_utils/src/lib.rs | 2 + juno_utils/src/stringtab.rs | 48 + 44 files changed, 3196 insertions(+), 2069 deletions(-) delete mode 100644 hercules_opt/src/pass.rs delete mode 100644 hercules_tools/hercules_driver/Cargo.toml delete mode 100644 
hercules_tools/hercules_driver/src/main.rs create mode 100644 juno_samples/schedule_test/Cargo.toml create mode 100644 juno_samples/schedule_test/build.rs create mode 100644 juno_samples/schedule_test/src/code.jn create mode 100644 juno_samples/schedule_test/src/main.rs create mode 100644 juno_samples/schedule_test/src/sched.sch create mode 100644 juno_scheduler/src/compile.rs create mode 100644 juno_scheduler/src/default.rs create mode 100644 juno_scheduler/src/ir.rs create mode 100644 juno_scheduler/src/labels.rs create mode 100644 juno_scheduler/src/pm.rs create mode 100644 juno_utils/.gitignore create mode 100644 juno_utils/Cargo.toml rename {juno_frontend => juno_utils}/src/env.rs (91%) create mode 100644 juno_utils/src/lib.rs create mode 100644 juno_utils/src/stringtab.rs diff --git a/Cargo.lock b/Cargo.lock index 3c89534c..8bb64bd2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -259,12 +259,6 @@ dependencies = [ "arrayvec", ] -[[package]] -name = "base64" -version = "0.21.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" - [[package]] name = "bincode" version = "1.3.3" @@ -291,9 +285,6 @@ name = "bitflags" version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1be3f42a67d6d345ecd59f675f3f012d6974981560836e938c22b424b85ce1be" -dependencies = [ - "serde", -] [[package]] name = "bitstream-io" @@ -828,17 +819,6 @@ dependencies = [ "serde", ] -[[package]] -name = "hercules_driver" -version = "0.1.0" -dependencies = [ - "clap", - "hercules_ir", - "hercules_opt", - "postcard", - "ron", -] - [[package]] name = "hercules_ir" version = "0.1.0" @@ -1052,6 +1032,7 @@ dependencies = [ "hercules_ir", "hercules_opt", "juno_scheduler", + "juno_utils", "lrlex", "lrpar", "num-rational", @@ -1091,14 +1072,29 @@ dependencies = [ "with_builtin_macros", ] +[[package]] +name = "juno_schedule_test" +version = "0.1.0" +dependencies = [ + "async-std", 
+ "hercules_rt", + "juno_build", + "rand", + "with_builtin_macros", +] + [[package]] name = "juno_scheduler" version = "0.0.1" dependencies = [ "cfgrammar", + "hercules_cg", "hercules_ir", + "hercules_opt", + "juno_utils", "lrlex", "lrpar", + "tempfile", ] [[package]] @@ -1111,6 +1107,13 @@ dependencies = [ "with_builtin_macros", ] +[[package]] +name = "juno_utils" +version = "0.1.0" +dependencies = [ + "serde", +] + [[package]] name = "kv-log-macro" version = "1.0.7" @@ -1748,18 +1751,6 @@ version = "0.8.50" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a" -[[package]] -name = "ron" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b91f7eff05f748767f183df4320a63d6936e9c6107d97c9e6bdd9784f4289c94" -dependencies = [ - "base64", - "bitflags 2.7.0", - "serde", - "serde_derive", -] - [[package]] name = "rustc_version" version = "0.4.1" diff --git a/Cargo.toml b/Cargo.toml index c57125f7..6adefdba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,8 +6,7 @@ members = [ "hercules_opt", "hercules_rt", - "hercules_tools/hercules_driver", - + "juno_utils", "juno_frontend", "juno_scheduler", "juno_build", @@ -27,7 +26,7 @@ members = [ "juno_samples/nested_ccp", "juno_samples/antideps", "juno_samples/implicit_clone", + "juno_samples/cava", "juno_samples/concat", - - "juno_samples/cava", + "juno_samples/schedule_test", ] diff --git a/hercules_ir/src/build.rs b/hercules_ir/src/build.rs index 78c4eca4..1dd326c3 100644 --- a/hercules_ir/src/build.rs +++ b/hercules_ir/src/build.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use crate::*; @@ -11,10 +11,11 @@ pub struct Builder<'a> { // Intern function names. function_ids: HashMap<&'a str, FunctionID>, - // Intern types, constants, and dynamic constants on a per-module basis. 
+ // Intern types, constants, dynamic constants, and labels on a per-module basis. interned_types: HashMap<Type, TypeID>, interned_constants: HashMap<Constant, ConstantID>, interned_dynamic_constants: HashMap<DynamicConstant, DynamicConstantID>, + interned_labels: HashMap<String, LabelID>, // For product, summation, and array constant creation, it's useful to know // the type of each constant. @@ -37,6 +38,7 @@ pub struct NodeBuilder { function_id: FunctionID, node: Node, schedules: Vec<Schedule>, + labels: Vec<LabelID>, } /* @@ -79,6 +81,17 @@ impl<'a> Builder<'a> { } } + pub fn add_label(&mut self, label: &String) -> LabelID { + if let Some(id) = self.interned_labels.get(label) { + *id + } else { + let id = LabelID::new(self.interned_labels.len()); + self.interned_labels.insert(label.clone(), id); + self.module.labels.push(label.clone()); + id + } + } + pub fn create() -> Self { Self::default() } @@ -452,6 +465,10 @@ impl<'a> Builder<'a> { Index::Position(idx) } + pub fn get_labels(&self, func: FunctionID, node: NodeID) -> &HashSet<LabelID> { + &self.module.functions[func.idx()].labels[node.idx()] + } + pub fn create_function( &mut self, name: &str, @@ -473,6 +490,7 @@ impl<'a> Builder<'a> { entry, nodes: vec![Node::Start], schedules: vec![vec![]], + labels: vec![HashSet::new()], device: None, }); Ok((id, NodeID::new(0))) @@ -484,11 +502,15 @@ impl<'a> Builder<'a> { .nodes .push(Node::Start); self.module.functions[function.idx()].schedules.push(vec![]); + self.module.functions[function.idx()] + .labels + .push(HashSet::new()); NodeBuilder { id, function_id: function, node: Node::Start, schedules: vec![], + labels: vec![], } } @@ -499,6 +521,8 @@ impl<'a> Builder<'a> { self.module.functions[builder.function_id.idx()].nodes[builder.id.idx()] = builder.node; self.module.functions[builder.function_id.idx()].schedules[builder.id.idx()] = builder.schedules; + self.module.functions[builder.function_id.idx()].labels[builder.id.idx()] = + 
builder.labels.into_iter().collect(); Ok(()) } } @@ -617,4 +641,15 @@ impl NodeBuilder { pub fn add_schedule(&mut self, schedule: Schedule) { self.schedules.push(schedule); } + + pub fn add_label(&mut self, label: LabelID) { + self.labels.push(label); + } + + pub fn add_labels<I>(&mut self, labels: I) + where + I: Iterator<Item = LabelID>, + { + self.labels.extend(labels); + } } diff --git a/hercules_ir/src/callgraph.rs b/hercules_ir/src/callgraph.rs index 3a8e6316..834cbbf8 100644 --- a/hercules_ir/src/callgraph.rs +++ b/hercules_ir/src/callgraph.rs @@ -79,10 +79,9 @@ impl CallGraph { /* * Top level function to calculate the call graph of a Hercules module. */ -pub fn callgraph(module: &Module) -> CallGraph { +pub fn callgraph(functions: &Vec<Function>) -> CallGraph { // Step 1: collect the functions called in each function. - let callee_functions: Vec<Vec<FunctionID>> = module - .functions + let callee_functions: Vec<Vec<FunctionID>> = functions .iter() .map(|func| { let mut called: Vec<_> = func diff --git a/hercules_ir/src/collections.rs b/hercules_ir/src/collections.rs index 8bb1b359..9f421221 100644 --- a/hercules_ir/src/collections.rs +++ b/hercules_ir/src/collections.rs @@ -135,7 +135,8 @@ impl Semilattice for CollectionObjectLattice { * Top level function to analyze collection objects in a Hercules module. 
*/ pub fn collection_objects( - module: &Module, + functions: &Vec<Function>, + types: &Vec<Type>, reverse_postorders: &Vec<Vec<NodeID>>, typing: &ModuleTyping, callgraph: &CallGraph, @@ -146,7 +147,7 @@ pub fn collection_objects( let topo = callgraph.topo(); for func_id in topo { - let func = &module.functions[func_id.idx()]; + let func = &functions[func_id.idx()]; let typing = &typing[func_id.idx()]; let reverse_postorder = &reverse_postorders[func_id.idx()]; @@ -156,14 +157,14 @@ pub fn collection_objects( .param_types .iter() .enumerate() - .filter(|(_, ty_id)| !module.types[ty_id.idx()].is_primitive()) + .filter(|(_, ty_id)| !types[ty_id.idx()].is_primitive()) .map(|(idx, _)| CollectionObjectOrigin::Parameter(idx)); let other_origins = func .nodes .iter() .enumerate() .filter_map(|(idx, node)| match node { - Node::Constant { id: _ } if !module.types[typing[idx].idx()].is_primitive() => { + Node::Constant { id: _ } if !types[typing[idx].idx()].is_primitive() => { Some(CollectionObjectOrigin::Constant(NodeID::new(idx))) } Node::Call { @@ -185,7 +186,7 @@ pub fn collection_objects( // this is determined later. 
Some(CollectionObjectOrigin::Call(NodeID::new(idx))) } - Node::Undef { ty: _ } if !module.types[typing[idx].idx()].is_primitive() => { + Node::Undef { ty: _ } if !types[typing[idx].idx()].is_primitive() => { Some(CollectionObjectOrigin::Undef(NodeID::new(idx))) } _ => None, @@ -255,7 +256,7 @@ pub fn collection_objects( function: callee, dynamic_constants: _, args: _, - } if !module.types[typing[id.idx()].idx()].is_primitive() => { + } if !types[typing[id.idx()].idx()].is_primitive() => { let new_obj = origins .iter() .position(|origin| *origin == CollectionObjectOrigin::Call(id)) @@ -285,7 +286,7 @@ pub fn collection_objects( Node::Read { collect: _, indices: _, - } if !module.types[typing[id.idx()].idx()].is_primitive() => inputs[0].clone(), + } if !types[typing[id.idx()].idx()].is_primitive() => inputs[0].clone(), Node::Write { collect: _, data: _, diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index cef94a2d..d1fdd225 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -1,4 +1,5 @@ use std::cmp::{max, min}; +use std::collections::HashSet; use std::fmt::Write; use std::ops::Coroutine; use std::ops::CoroutineState; @@ -23,6 +24,7 @@ pub struct Module { pub types: Vec<Type>, pub constants: Vec<Constant>, pub dynamic_constants: Vec<DynamicConstant>, + pub labels: Vec<String>, } /* @@ -43,7 +45,8 @@ pub struct Function { pub nodes: Vec<Node>, - pub schedules: FunctionSchedule, + pub schedules: FunctionSchedules, + pub labels: FunctionLabels, pub device: Option<Device>, } @@ -341,7 +344,12 @@ pub enum Device { /* * A single node may have multiple schedules. */ -pub type FunctionSchedule = Vec<Vec<Schedule>>; +pub type FunctionSchedules = Vec<Vec<Schedule>>; + +/* + * A single node may have multiple labels. + */ +pub type FunctionLabels = Vec<HashSet<LabelID>>; impl Module { /* @@ -734,6 +742,7 @@ impl Function { // Step 4: update the schedules. 
self.schedules.fix_gravestones(&node_mapping); + self.labels.fix_gravestones(&node_mapping); node_mapping } @@ -1767,3 +1776,4 @@ define_id_type!(NodeID); define_id_type!(TypeID); define_id_type!(ConstantID); define_id_type!(DynamicConstantID); +define_id_type!(LabelID); diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index 21eb325a..6c533bef 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -1,5 +1,5 @@ use std::cell::RefCell; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::str::FromStr; use crate::*; @@ -124,6 +124,7 @@ fn parse_module<'a>(ir_text: &'a str, context: Context<'a>) -> nom::IResult<&'a entry: true, nodes: vec![], schedules: vec![], + labels: vec![], device: None, }; context.function_ids.len() @@ -157,6 +158,7 @@ fn parse_module<'a>(ir_text: &'a str, context: Context<'a>) -> nom::IResult<&'a types, constants, dynamic_constants, + labels: vec![], }; Ok((rest, module)) } @@ -262,6 +264,7 @@ fn parse_function<'a>( entry: true, nodes: fixed_nodes, schedules: vec![vec![]; num_nodes], + labels: vec![HashSet::new(); num_nodes], device: None, }, )) diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index 79cbd403..a80dd422 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -82,19 +82,16 @@ pub type ModuleTyping = Vec<Vec<TypeID>>; * Returns a type for every node in every function. */ pub fn typecheck( - module: &mut Module, + functions: &Vec<Function>, + types: &mut Vec<Type>, + constants: &Vec<Constant>, + dynamic_constants: &mut Vec<DynamicConstant>, reverse_postorders: &Vec<Vec<NodeID>>, ) -> Result<ModuleTyping, String> { // Step 1: assemble a reverse type map. This is needed to get or create the // ID of potentially new types. Break down module into references to // individual elements at this point, so that borrows don't overlap each // other. 
- let Module { - ref functions, - ref mut types, - ref constants, - ref mut dynamic_constants, - } = module; let mut reverse_type_map: HashMap<Type, TypeID> = types .iter() .enumerate() diff --git a/hercules_ir/src/verify.rs b/hercules_ir/src/verify.rs index 572bb9d1..5ee5f1d2 100644 --- a/hercules_ir/src/verify.rs +++ b/hercules_ir/src/verify.rs @@ -29,7 +29,13 @@ pub fn verify( let reverse_postorders: Vec<_> = def_uses.iter().map(reverse_postorder).collect(); // Typecheck the module. - let typing = typecheck(module, &reverse_postorders)?; + let typing = typecheck( + &module.functions, + &mut module.types, + &module.constants, + &mut module.dynamic_constants, + &reverse_postorders, + )?; // Assemble fork join maps for module. let subgraphs: Vec<_> = zip(module.functions.iter(), def_uses.iter()) diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index 1318f032..6271a958 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -26,6 +26,9 @@ pub struct FunctionEditor<'a> { constants: &'a RefCell<Vec<Constant>>, dynamic_constants: &'a RefCell<Vec<DynamicConstant>>, types: &'a RefCell<Vec<Type>>, + // Keep a RefCell to the string table that tracks labels, so that new labels + // can be added as needed + labels: &'a RefCell<Vec<String>>, // Most optimizations need def use info, so provide an iteratively updated // mutable version that's automatically updated based on recorded edits. mut_def_use: Vec<HashSet<NodeID>>, @@ -34,6 +37,9 @@ pub struct FunctionEditor<'a> { // are off limits for deletion (equivalently modification) or being replaced // as a use. mutable_nodes: BitVec<u8, Lsb0>, + // Tracks whether this editor has been used to make any edits to the IR of + // this function + modified: bool, } /* @@ -51,10 +57,13 @@ pub struct FunctionEdit<'a: 'b, 'b> { added_and_updated_nodes: BTreeMap<NodeID, Node>, // Keep track of added and updated schedules. 
added_and_updated_schedules: BTreeMap<NodeID, Vec<Schedule>>, - // Keep track of added (dynamic) constants and types + // Keep track of added and updated labels. + added_and_updated_labels: BTreeMap<NodeID, HashSet<LabelID>>, + // Keep track of added (dynamic) constants, types, and labels added_constants: Vec<Constant>, added_dynamic_constants: Vec<DynamicConstant>, added_types: Vec<Type>, + added_labels: Vec<String>, // Compute a def-use map entries iteratively. updated_def_use: BTreeMap<NodeID, HashSet<NodeID>>, updated_param_types: Option<Vec<TypeID>>, @@ -70,6 +79,7 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { constants: &'a RefCell<Vec<Constant>>, dynamic_constants: &'a RefCell<Vec<DynamicConstant>>, types: &'a RefCell<Vec<Type>>, + labels: &'a RefCell<Vec<String>>, def_use: &ImmutableDefUseMap, ) -> Self { let mut_def_use = (0..function.nodes.len()) @@ -89,11 +99,60 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { constants, dynamic_constants, types, + labels, mut_def_use, mutable_nodes, + modified: false, } } + // Constructs an editor but only makes the nodes with at least one of the set of labels as + // mutable + pub fn new_labeled( + function: &'a mut Function, + function_id: FunctionID, + constants: &'a RefCell<Vec<Constant>>, + dynamic_constants: &'a RefCell<Vec<DynamicConstant>>, + types: &'a RefCell<Vec<Type>>, + labels: &'a RefCell<Vec<String>>, + def_use: &ImmutableDefUseMap, + with_labels: &HashSet<LabelID>, + ) -> Self { + let mut_def_use = (0..function.nodes.len()) + .map(|idx| { + def_use + .get_users(NodeID::new(idx)) + .into_iter() + .map(|x| *x) + .collect() + }) + .collect(); + + let mut mutable_nodes = bitvec![u8, Lsb0; 0; function.nodes.len()]; + // Add all nodes which have some label which is in the with_labels set + for (idx, labels) in function.labels.iter().enumerate() { + if !labels.is_disjoint(with_labels) { + mutable_nodes.set(idx, true); + } + } + + FunctionEditor { + function, + function_id, + constants, + dynamic_constants, + types, + 
labels, + mut_def_use, + mutable_nodes, + modified: false, + } + } + + pub fn modified(&self) -> bool { + self.modified + } + pub fn edit<F>(&'b mut self, edit: F) -> bool where F: FnOnce(FunctionEdit<'a, 'b>) -> Result<FunctionEdit<'a, 'b>, FunctionEdit<'a, 'b>>, @@ -105,9 +164,11 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { added_nodeids: HashSet::new(), added_and_updated_nodes: BTreeMap::new(), added_and_updated_schedules: BTreeMap::new(), + added_and_updated_labels: BTreeMap::new(), added_constants: Vec::new().into(), added_dynamic_constants: Vec::new().into(), added_types: Vec::new().into(), + added_labels: Vec::new().into(), updated_def_use: BTreeMap::new(), updated_param_types: None, updated_return_type: None, @@ -120,17 +181,28 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { let FunctionEdit { editor, deleted_nodeids, - added_nodeids: _, + added_nodeids, added_and_updated_nodes, added_and_updated_schedules, + added_and_updated_labels, added_constants, added_dynamic_constants, added_types, + added_labels, updated_def_use, updated_param_types, updated_return_type, sub_edits, } = populated_edit; + + // Step 0: determine whether the edit changed the IR by checking if + // any nodes were deleted, added, or updated in any way + editor.modified |= !deleted_nodeids.is_empty() + || !added_nodeids.is_empty() + || !added_and_updated_nodes.is_empty() + || !added_and_updated_schedules.is_empty() + || !added_and_updated_labels.is_empty(); + // Step 1: update the mutable def use map. for (u, new_users) in updated_def_use { // Go through new def-use entries in order. These are either @@ -160,7 +232,7 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { } } - // Step 3: add and update schedules. + // Step 3.0: add and update schedules. editor .function .schedules @@ -169,6 +241,15 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { editor.function.schedules[id.idx()] = schedule; } + // Step 3.1: add and update labels. 
+ editor + .function + .labels + .resize(editor.function.nodes.len(), HashSet::new()); + for (id, label) in added_and_updated_labels { + editor.function.labels[id.idx()] = label; + } + // Step 4: delete nodes. This is done using "gravestones", where a // node other than node ID 0 being a start node is considered a // gravestone. @@ -178,8 +259,8 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { editor.function.nodes[id.idx()] = Node::Start; } - // Step 5: propagate schedules along sub-edit edges. - for (src, dst) in sub_edits { + // Step 5.0: propagate schedules along sub-edit edges. + for (src, dst) in sub_edits.iter() { let mut dst_schedules = take(&mut editor.function.schedules[dst.idx()]); for src_schedule in editor.function.schedules[src.idx()].iter() { if !dst_schedules.contains(src_schedule) { @@ -189,6 +270,32 @@ impl<'a: 'b, 'b> FunctionEditor<'a> { editor.function.schedules[dst.idx()] = dst_schedules; } + // Step 5.1: update and propagate labels + editor.labels.borrow_mut().extend(added_labels); + + // We propagate labels in two steps, first along sub-edits and then + // all the labels on any deleted node not used in any sub-edit to all + // added nodes not in any sub-edit + let mut sources = deleted_nodeids.clone(); + let mut dests = added_nodeids.clone(); + + for (src, dst) in sub_edits { + let mut dst_labels = take(&mut editor.function.labels[dst.idx()]); + dst_labels.extend(editor.function.labels[src.idx()].iter()); + editor.function.labels[dst.idx()] = dst_labels; + + sources.remove(&src); + dests.remove(&dst); + } + + let mut src_labels = HashSet::new(); + for src in sources { + src_labels.extend(editor.function.labels[src.idx()].clone()); + } + for dst in dests { + editor.function.labels[dst.idx()].extend(src_labels.clone()); + } + // Step 6: update the length of mutable_nodes. All added nodes are // mutable. 
editor @@ -446,6 +553,57 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { } } + pub fn get_label(&self, id: NodeID) -> &HashSet<LabelID> { + // The user may get the labels of a to-be deleted node. + if let Some(label) = self.added_and_updated_labels.get(&id) { + // Refer to added or updated label. + label + } else { + // Refer to the origin label of this code. + &self.editor.function.labels[id.idx()] + } + } + + pub fn add_label(mut self, id: NodeID, label: LabelID) -> Result<Self, Self> { + if self.is_mutable(id) { + if let Some(labels) = self.added_and_updated_labels.get_mut(&id) { + labels.insert(label); + } else { + let mut labels = self.editor.function.labels[id.idx()].clone(); + labels.insert(label); + self.added_and_updated_labels.insert(id, labels); + } + Ok(self) + } else { + Err(self) + } + } + + // Creates or returns the LabelID for a given label name + pub fn new_label(&mut self, name: String) -> LabelID { + let pos = self + .editor + .labels + .borrow() + .iter() + .chain(self.added_labels.iter()) + .position(|l| *l == name); + if let Some(idx) = pos { + LabelID::new(idx) + } else { + let idx = self.editor.labels.borrow().len() + self.added_labels.len(); + self.added_labels.push(name); + LabelID::new(idx) + } + } + + // Creates an entirely fresh label and returns its LabelID + pub fn fresh_label(&mut self) -> LabelID { + let idx = self.editor.labels.borrow().len() + self.added_labels.len(); + self.added_labels.push(format!("#fresh_{}", idx)); + LabelID::new(idx) + } + pub fn get_users(&self, id: NodeID) -> impl Iterator<Item = NodeID> + '_ { assert!(!self.deleted_nodeids.contains(&id)); if let Some(users) = self.updated_def_use.get(&id) { @@ -671,6 +829,7 @@ fn func(x: i32) -> i32 let constants_ref = RefCell::new(src_module.constants); let dynamic_constants_ref = RefCell::new(src_module.dynamic_constants); let types_ref = RefCell::new(src_module.types); + let labels_ref = RefCell::new(src_module.labels); // Edit the function by replacing the add with a multiply. 
let mut editor = FunctionEditor::new( func, @@ -678,6 +837,7 @@ fn func(x: i32) -> i32 &constants_ref, &dynamic_constants_ref, &types_ref, + &labels_ref, &def_use(func), ); let success = editor.edit(|mut edit| { diff --git a/hercules_opt/src/inline.rs b/hercules_opt/src/inline.rs index 54af8582..064e3d73 100644 --- a/hercules_opt/src/inline.rs +++ b/hercules_opt/src/inline.rs @@ -209,6 +209,11 @@ fn inline_func( for schedule in callee_schedule { edit = edit.add_schedule(add_id, schedule.clone())?; } + // Copy the labels from the callee. + let callee_labels = &called_func.labels[idx]; + for label in callee_labels { + edit = edit.add_label(add_id, *label)?; + } } // Stitch the control use of the inlined start node with the diff --git a/hercules_opt/src/lib.rs b/hercules_opt/src/lib.rs index e351deba..2c9d4372 100644 --- a/hercules_opt/src/lib.rs +++ b/hercules_opt/src/lib.rs @@ -14,7 +14,6 @@ pub mod gvn; pub mod inline; pub mod interprocedural_sroa; pub mod outline; -pub mod pass; pub mod phi_elim; pub mod pred; pub mod schedule; @@ -37,7 +36,6 @@ pub use crate::gvn::*; pub use crate::inline::*; pub use crate::interprocedural_sroa::*; pub use crate::outline::*; -pub use crate::pass::*; pub use crate::phi_elim::*; pub use crate::pred::*; pub use crate::schedule::*; diff --git a/hercules_opt/src/outline.rs b/hercules_opt/src/outline.rs index 80f97c7f..e59c815d 100644 --- a/hercules_opt/src/outline.rs +++ b/hercules_opt/src/outline.rs @@ -1,4 +1,4 @@ -use std::collections::{BTreeMap, BTreeSet}; +use std::collections::{BTreeMap, BTreeSet, HashSet}; use std::iter::zip; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -203,6 +203,7 @@ pub fn outline( entry: false, nodes: vec![], schedules: vec![], + labels: vec![], device: None, }; @@ -420,6 +421,13 @@ pub fn outline( outlined.schedules[callee_id.idx()] = edit.get_schedule(*id).clone(); } + // Copy the labels into the new callee. 
+ outlined.labels.resize(outlined.nodes.len(), HashSet::new()); + for id in partition.iter() { + let callee_id = convert_id(*id); + outlined.labels[callee_id.idx()] = edit.get_label(*id).clone(); + } + // Step 3: edit the original function to call the outlined function. let dynamic_constants = (0..edit.get_num_dynamic_constant_params()) .map(|idx| edit.add_dynamic_constant(DynamicConstant::Parameter(idx as usize))) diff --git a/hercules_opt/src/pass.rs b/hercules_opt/src/pass.rs deleted file mode 100644 index e528b35d..00000000 --- a/hercules_opt/src/pass.rs +++ /dev/null @@ -1,1132 +0,0 @@ -use std::cell::RefCell; -use std::collections::{HashMap, HashSet}; -use std::fs::File; -use std::io::Write; -use std::iter::zip; -use std::process::{Command, Stdio}; - -use serde::Deserialize; - -use tempfile::TempDir; - -use hercules_cg::*; -use hercules_ir::*; - -use crate::*; - -/* - * Passes that can be run on a module. - */ -#[derive(Debug, Clone, Deserialize)] -pub enum Pass { - DCE, - CCP, - GVN, - PhiElim, - Forkify, - ForkGuardElim, - CRC, - SLF, - WritePredication, - Predication, - SROA, - Inline, - Outline, - InterproceduralSROA, - DeleteUncalled, - ForkSplit, - Unforkify, - InferSchedules, - GCM, - FloatCollections, - Verify, - // Parameterized over whether analyses that aid visualization are necessary. - // Useful to set to false if displaying a potentially broken module. - Xdot(bool), - // Parameterized over output directory and module name. - Codegen(String, String), - // Parameterized over where to serialize module to. - Serialize(String), -} - -/* - * Manages passes to be run on an IR module. Transparently handles analysis - * requirements for optimizations. - */ -#[derive(Debug, Clone)] -pub struct PassManager { - module: Module, - - // Passes to run. - passes: Vec<Pass>, - - // Cached analysis results. 
- pub def_uses: Option<Vec<ImmutableDefUseMap>>, - pub reverse_postorders: Option<Vec<Vec<NodeID>>>, - pub typing: Option<ModuleTyping>, - pub control_subgraphs: Option<Vec<Subgraph>>, - pub doms: Option<Vec<DomTree>>, - pub postdoms: Option<Vec<DomTree>>, - pub fork_join_maps: Option<Vec<HashMap<NodeID, NodeID>>>, - pub fork_join_nests: Option<Vec<HashMap<NodeID, Vec<NodeID>>>>, - pub loops: Option<Vec<LoopTree>>, - pub reduce_cycles: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, - pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, - pub bbs: Option<Vec<BasicBlocks>>, - pub collection_objects: Option<CollectionObjects>, - pub callgraph: Option<CallGraph>, -} - -impl PassManager { - pub fn new(module: Module) -> Self { - PassManager { - module, - passes: vec![], - def_uses: None, - reverse_postorders: None, - typing: None, - control_subgraphs: None, - doms: None, - postdoms: None, - fork_join_maps: None, - fork_join_nests: None, - loops: None, - reduce_cycles: None, - data_nodes_in_fork_joins: None, - bbs: None, - collection_objects: None, - callgraph: None, - } - } - - pub fn add_pass(&mut self, pass: Pass) { - self.passes.push(pass); - } - - pub fn make_def_uses(&mut self) { - if self.def_uses.is_none() { - self.def_uses = Some(self.module.functions.iter().map(def_use).collect()); - } - } - - pub fn make_reverse_postorders(&mut self) { - if self.reverse_postorders.is_none() { - self.make_def_uses(); - self.reverse_postorders = Some( - self.def_uses - .as_ref() - .unwrap() - .iter() - .map(reverse_postorder) - .collect(), - ); - } - } - - pub fn make_typing(&mut self) { - if self.typing.is_none() { - self.make_reverse_postorders(); - self.typing = Some( - typecheck(&mut self.module, self.reverse_postorders.as_ref().unwrap()).unwrap(), - ); - } - } - - pub fn make_control_subgraphs(&mut self) { - if self.control_subgraphs.is_none() { - self.make_def_uses(); - self.control_subgraphs = Some( - zip(&self.module.functions, 
self.def_uses.as_ref().unwrap()) - .map(|(function, def_use)| control_subgraph(function, def_use)) - .collect(), - ); - } - } - - pub fn make_doms(&mut self) { - if self.doms.is_none() { - self.make_control_subgraphs(); - self.doms = Some( - self.control_subgraphs - .as_ref() - .unwrap() - .iter() - .map(|subgraph| dominator(subgraph, NodeID::new(0))) - .collect(), - ); - } - } - - pub fn make_postdoms(&mut self) { - if self.postdoms.is_none() { - self.make_control_subgraphs(); - self.postdoms = Some( - zip( - self.control_subgraphs.as_ref().unwrap().iter(), - self.module.functions.iter(), - ) - .map(|(subgraph, function)| dominator(subgraph, NodeID::new(function.nodes.len()))) - .collect(), - ); - } - } - - pub fn make_fork_join_maps(&mut self) { - if self.fork_join_maps.is_none() { - self.make_control_subgraphs(); - self.fork_join_maps = Some( - zip( - self.module.functions.iter(), - self.control_subgraphs.as_ref().unwrap().iter(), - ) - .map(|(function, subgraph)| fork_join_map(function, subgraph)) - .collect(), - ); - } - } - - pub fn make_fork_join_nests(&mut self) { - if self.fork_join_nests.is_none() { - self.make_doms(); - self.make_fork_join_maps(); - self.fork_join_nests = Some( - zip( - self.module.functions.iter(), - zip( - self.doms.as_ref().unwrap().iter(), - self.fork_join_maps.as_ref().unwrap().iter(), - ), - ) - .map(|(function, (dom, fork_join_map))| { - compute_fork_join_nesting(function, dom, fork_join_map) - }) - .collect(), - ); - } - } - - pub fn make_loops(&mut self) { - if self.loops.is_none() { - self.make_control_subgraphs(); - self.make_doms(); - self.make_fork_join_maps(); - let control_subgraphs = self.control_subgraphs.as_ref().unwrap().iter(); - let doms = self.doms.as_ref().unwrap().iter(); - let fork_join_maps = self.fork_join_maps.as_ref().unwrap().iter(); - self.loops = Some( - zip(control_subgraphs, zip(doms, fork_join_maps)) - .map(|(control_subgraph, (dom, fork_join_map))| { - loops(control_subgraph, NodeID::new(0), dom, 
fork_join_map) - }) - .collect(), - ); - } - } - - pub fn make_reduce_cycles(&mut self) { - if self.reduce_cycles.is_none() { - self.make_def_uses(); - let def_uses = self.def_uses.as_ref().unwrap().iter(); - self.reduce_cycles = Some( - zip(self.module.functions.iter(), def_uses) - .map(|(function, def_use)| reduce_cycles(function, def_use)) - .collect(), - ); - } - } - - pub fn make_data_nodes_in_fork_joins(&mut self) { - if self.data_nodes_in_fork_joins.is_none() { - self.make_def_uses(); - self.make_fork_join_maps(); - self.data_nodes_in_fork_joins = Some( - zip( - self.module.functions.iter(), - zip( - self.def_uses.as_ref().unwrap().iter(), - self.fork_join_maps.as_ref().unwrap().iter(), - ), - ) - .map(|(function, (def_use, fork_join_map))| { - data_nodes_in_fork_joins(function, def_use, fork_join_map) - }) - .collect(), - ); - } - } - - pub fn make_collection_objects(&mut self) { - if self.collection_objects.is_none() { - self.make_reverse_postorders(); - self.make_typing(); - self.make_callgraph(); - let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); - let typing = self.typing.as_ref().unwrap(); - let callgraph = self.callgraph.as_ref().unwrap(); - self.collection_objects = Some(collection_objects( - &self.module, - reverse_postorders, - typing, - callgraph, - )); - } - } - - pub fn make_callgraph(&mut self) { - if self.callgraph.is_none() { - self.callgraph = Some(callgraph(&self.module)); - } - } - - pub fn run_passes(&mut self) { - for pass in self.passes.clone().iter() { - match pass { - Pass::DCE => { - self.make_def_uses(); - let def_uses = self.def_uses.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut 
self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - dce(&mut editor); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::InterproceduralSROA => { - self.make_def_uses(); - self.make_typing(); - - let constants_ref = RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - - let def_uses = self.def_uses.as_ref().unwrap(); - - let mut editors: Vec<_> = self - .module - .functions - .iter_mut() - .enumerate() - .map(|(i, f)| { - FunctionEditor::new( - f, - FunctionID::new(i), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[i], - ) - }) - .collect(); - - interprocedural_sroa(&mut editors); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - for func in self.module.functions.iter_mut() { - func.delete_gravestones(); - } - - self.clear_analyses(); - } - Pass::CCP => { - self.make_def_uses(); - self.make_reverse_postorders(); - let def_uses = self.def_uses.as_ref().unwrap(); - let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - 
&def_uses[idx], - ); - ccp(&mut editor, &reverse_postorders[idx]); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::GVN => { - self.make_def_uses(); - let def_uses = self.def_uses.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - gvn(&mut editor, false); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::Forkify => { - self.make_def_uses(); - self.make_loops(); - let def_uses = self.def_uses.as_ref().unwrap(); - let loops = self.loops.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - forkify( - &mut self.module.functions[idx], - &self.module.constants, - &mut self.module.dynamic_constants, - &def_uses[idx], - &loops[idx], - ); - let num_nodes = self.module.functions[idx].nodes.len(); - self.module.functions[idx] - .schedules - .resize(num_nodes, vec![]); - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::PhiElim => { - self.make_def_uses(); - let def_uses = self.def_uses.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut 
self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - phi_elim(&mut editor); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::ForkGuardElim => { - self.make_def_uses(); - self.make_fork_join_maps(); - let def_uses = self.def_uses.as_ref().unwrap(); - let fork_join_maps = self.fork_join_maps.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - fork_guard_elim( - &mut self.module.functions[idx], - &self.module.constants, - &fork_join_maps[idx], - &def_uses[idx], - ); - let num_nodes = self.module.functions[idx].nodes.len(); - self.module.functions[idx] - .schedules - .resize(num_nodes, vec![]); - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::CRC => { - self.make_def_uses(); - let def_uses = self.def_uses.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - crc(&mut editor); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::SLF => { - self.make_def_uses(); 
- self.make_reverse_postorders(); - self.make_typing(); - let def_uses = self.def_uses.as_ref().unwrap(); - let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); - let typing = self.typing.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - slf(&mut editor, &reverse_postorders[idx], &typing[idx]); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::WritePredication => { - self.make_def_uses(); - let def_uses = self.def_uses.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - write_predication(&mut editor); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::Predication => { - self.make_def_uses(); - self.make_typing(); - let def_uses = self.def_uses.as_ref().unwrap(); - let typing = 
self.typing.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - predication(&mut editor, &typing[idx]); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::SROA => { - self.make_def_uses(); - self.make_reverse_postorders(); - self.make_typing(); - let def_uses = self.def_uses.as_ref().unwrap(); - let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); - let typing = self.typing.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - sroa(&mut editor, &reverse_postorders[idx], &typing[idx]); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::Inline => { - self.make_def_uses(); - self.make_callgraph(); - let def_uses = self.def_uses.as_ref().unwrap(); - let callgraph = 
self.callgraph.as_ref().unwrap(); - let constants_ref = RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editors: Vec<_> = zip( - self.module.functions.iter_mut().enumerate(), - def_uses.iter(), - ) - .map(|((idx, func), def_use)| { - FunctionEditor::new( - func, - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - def_use, - ) - }) - .collect(); - inline(&mut editors, callgraph); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - for func in self.module.functions.iter_mut() { - func.delete_gravestones(); - } - self.clear_analyses(); - } - Pass::Outline => { - self.make_def_uses(); - let def_uses = self.def_uses.as_ref().unwrap(); - let constants_ref = RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let old_num_funcs = self.module.functions.len(); - let mut editors: Vec<_> = zip( - self.module.functions.iter_mut().enumerate(), - def_uses.iter(), - ) - .map(|((idx, func), def_use)| { - FunctionEditor::new( - func, - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - def_use, - ) - }) - .collect(); - for editor in editors.iter_mut() { - collapse_returns(editor); - ensure_between_control_flow(editor); - } - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - self.clear_analyses(); - - self.make_def_uses(); - self.make_typing(); - self.make_control_subgraphs(); - self.make_doms(); - let def_uses = self.def_uses.as_ref().unwrap(); - let 
typing = self.typing.as_ref().unwrap(); - let control_subgraphs = self.control_subgraphs.as_ref().unwrap(); - let doms = self.doms.as_ref().unwrap(); - let constants_ref = RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editors: Vec<_> = zip( - self.module.functions.iter_mut().enumerate(), - def_uses.iter(), - ) - .map(|((idx, func), def_use)| { - FunctionEditor::new( - func, - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - def_use, - ) - }) - .collect(); - let mut new_funcs = vec![]; - for (idx, editor) in editors.iter_mut().enumerate() { - let new_func_id = FunctionID::new(old_num_funcs + new_funcs.len()); - let new_func = dumb_outline( - editor, - &typing[idx], - &control_subgraphs[idx], - &doms[idx], - new_func_id, - ); - if let Some(new_func) = new_func { - new_funcs.push(new_func); - } - } - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - for func in self.module.functions.iter_mut() { - func.delete_gravestones(); - } - self.module.functions.extend(new_funcs); - self.clear_analyses(); - } - Pass::DeleteUncalled => { - self.make_def_uses(); - self.make_callgraph(); - let def_uses = self.def_uses.as_ref().unwrap(); - let callgraph = self.callgraph.as_ref().unwrap(); - let constants_ref = RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - - // By default in an editor all nodes are mutable, which is desired in this case - // since we are only modifying the IDs of functions that we call. 
- let mut editors: Vec<_> = zip( - self.module.functions.iter_mut().enumerate(), - def_uses.iter(), - ) - .map(|((idx, func), def_use)| { - FunctionEditor::new( - func, - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - def_use, - ) - }) - .collect(); - - let new_idx = delete_uncalled(&mut editors, callgraph); - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - for func in self.module.functions.iter_mut() { - func.delete_gravestones(); - } - - self.fix_deleted_functions(&new_idx); - self.clear_analyses(); - - assert!(self.module.functions.len() > 0, "PANIC: There are no entry functions in the Hercules module being compiled, and they all got deleted by DeleteUncalled. Please mark at least one function as an entry!"); - } - Pass::ForkSplit => { - self.make_def_uses(); - self.make_fork_join_maps(); - self.make_reduce_cycles(); - let def_uses = self.def_uses.as_ref().unwrap(); - let fork_join_maps = self.fork_join_maps.as_ref().unwrap(); - let reduce_cycles = self.reduce_cycles.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - fork_split(&mut editor, &fork_join_maps[idx], &reduce_cycles[idx]); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::Unforkify => { - self.make_def_uses(); - 
self.make_fork_join_maps(); - let def_uses = self.def_uses.as_ref().unwrap(); - let fork_join_maps = self.fork_join_maps.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - unforkify(&mut editor, &fork_join_maps[idx]); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::GCM => loop { - self.make_def_uses(); - self.make_reverse_postorders(); - self.make_typing(); - self.make_control_subgraphs(); - self.make_doms(); - self.make_fork_join_maps(); - self.make_loops(); - self.make_collection_objects(); - let def_uses = self.def_uses.as_ref().unwrap(); - let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); - let typing = self.typing.as_ref().unwrap(); - let doms = self.doms.as_ref().unwrap(); - let fork_join_maps = self.fork_join_maps.as_ref().unwrap(); - let loops = self.loops.as_ref().unwrap(); - let control_subgraphs = self.control_subgraphs.as_ref().unwrap(); - let collection_objects = self.collection_objects.as_ref().unwrap(); - let mut bbs = vec![]; - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - 
FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - if let Some(bb) = gcm( - &mut editor, - &def_uses[idx], - &reverse_postorders[idx], - &typing[idx], - &control_subgraphs[idx], - &doms[idx], - &fork_join_maps[idx], - &loops[idx], - collection_objects, - ) { - bbs.push(bb); - } - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - if bbs.len() == self.module.functions.len() { - self.bbs = Some(bbs); - break; - } - }, - Pass::FloatCollections => { - self.make_def_uses(); - self.make_typing(); - self.make_callgraph(); - let def_uses = self.def_uses.as_ref().unwrap(); - let typing = self.typing.as_ref().unwrap(); - let callgraph = self.callgraph.as_ref().unwrap(); - let devices = device_placement(&self.module.functions, &callgraph); - let constants_ref = RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editors: Vec<_> = zip( - self.module.functions.iter_mut().enumerate(), - def_uses.iter(), - ) - .map(|((idx, func), def_use)| { - FunctionEditor::new( - func, - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - def_use, - ) - }) - .collect(); - float_collections(&mut editors, typing, callgraph, &devices); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - for func in self.module.functions.iter_mut() { - func.delete_gravestones(); - } - self.clear_analyses(); - } - Pass::InferSchedules => { - self.make_def_uses(); - self.make_fork_join_maps(); - self.make_reduce_cycles(); - let def_uses = 
self.def_uses.as_ref().unwrap(); - let fork_join_maps = self.fork_join_maps.as_ref().unwrap(); - let reduce_cycles = self.reduce_cycles.as_ref().unwrap(); - for idx in 0..self.module.functions.len() { - let constants_ref = - RefCell::new(std::mem::take(&mut self.module.constants)); - let dynamic_constants_ref = - RefCell::new(std::mem::take(&mut self.module.dynamic_constants)); - let types_ref = RefCell::new(std::mem::take(&mut self.module.types)); - let mut editor = FunctionEditor::new( - &mut self.module.functions[idx], - FunctionID::new(idx), - &constants_ref, - &dynamic_constants_ref, - &types_ref, - &def_uses[idx], - ); - infer_parallel_reduce( - &mut editor, - &fork_join_maps[idx], - &reduce_cycles[idx], - ); - infer_parallel_fork(&mut editor, &fork_join_maps[idx]); - infer_vectorizable(&mut editor, &fork_join_maps[idx]); - infer_tight_associative(&mut editor, &reduce_cycles[idx]); - - self.module.constants = constants_ref.take(); - self.module.dynamic_constants = dynamic_constants_ref.take(); - self.module.types = types_ref.take(); - - self.module.functions[idx].delete_gravestones(); - } - self.clear_analyses(); - } - Pass::Verify => { - let ( - def_uses, - reverse_postorders, - typing, - subgraphs, - doms, - postdoms, - fork_join_maps, - ) = verify(&mut self.module) - .expect("PANIC: Failed to verify Hercules IR module."); - - // Verification produces a bunch of analysis results that - // may be useful for later passes. 
- self.def_uses = Some(def_uses); - self.reverse_postorders = Some(reverse_postorders); - self.typing = Some(typing); - self.control_subgraphs = Some(subgraphs); - self.doms = Some(doms); - self.postdoms = Some(postdoms); - self.fork_join_maps = Some(fork_join_maps); - } - Pass::Xdot(force_analyses) => { - self.make_reverse_postorders(); - if *force_analyses { - self.make_doms(); - self.make_fork_join_maps(); - } - xdot_module( - &self.module, - self.reverse_postorders.as_ref().unwrap(), - self.doms.as_ref(), - self.fork_join_maps.as_ref(), - ); - } - Pass::Codegen(output_dir, module_name) => { - self.make_typing(); - self.make_control_subgraphs(); - self.make_collection_objects(); - self.make_callgraph(); - let typing = self.typing.as_ref().unwrap(); - let control_subgraphs = self.control_subgraphs.as_ref().unwrap(); - let bbs = self.bbs.as_ref().unwrap(); - let collection_objects = self.collection_objects.as_ref().unwrap(); - let callgraph = self.callgraph.as_ref().unwrap(); - - let devices = device_placement(&self.module.functions, &callgraph); - - let mut rust_rt = String::new(); - let mut llvm_ir = String::new(); - for idx in 0..self.module.functions.len() { - match devices[idx] { - Device::LLVM => cpu_codegen( - &self.module.functions[idx], - &self.module.types, - &self.module.constants, - &self.module.dynamic_constants, - &typing[idx], - &control_subgraphs[idx], - &bbs[idx], - &mut llvm_ir, - ) - .unwrap(), - Device::AsyncRust => rt_codegen( - FunctionID::new(idx), - &self.module, - &typing[idx], - &control_subgraphs[idx], - &bbs[idx], - &collection_objects, - &callgraph, - &devices, - &mut rust_rt, - ) - .unwrap(), - _ => todo!(), - } - } - println!("{}", llvm_ir); - println!("{}", rust_rt); - - // Write the LLVM IR into a temporary file. 
- let tmp_dir = TempDir::new().unwrap(); - let mut tmp_path = tmp_dir.path().to_path_buf(); - tmp_path.push(format!("{}.ll", module_name)); - println!("{}", tmp_path.display()); - let mut file = File::create(&tmp_path) - .expect("PANIC: Unable to open output LLVM IR file."); - file.write_all(llvm_ir.as_bytes()) - .expect("PANIC: Unable to write output LLVM IR file contents."); - - // Compile LLVM IR into an ELF object file. - let output_archive = format!("{}/lib{}.a", output_dir, module_name); - println!("{}", output_archive); - let mut clang_process = Command::new("clang") - .arg(&tmp_path) - .arg("--emit-static-lib") - .arg("-O3") - .arg("-march=native") - .arg("-o") - .arg(&output_archive) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .spawn() - .expect("Error running clang. Is it installed?"); - assert!(clang_process.wait().unwrap().success()); - - // Write the Rust runtime into a file. - let output_rt = format!("{}/rt_{}.hrt", output_dir, module_name); - println!("{}", output_rt); - let mut file = File::create(&output_rt) - .expect("PANIC: Unable to open output Rust runtime file."); - file.write_all(rust_rt.as_bytes()) - .expect("PANIC: Unable to write output Rust runtime file contents."); - } - Pass::Serialize(output_file) => { - let module_contents: Vec<u8> = postcard::to_allocvec(&self.module).unwrap(); - let mut file = File::create(&output_file) - .expect("PANIC: Unable to open output module file."); - file.write_all(&module_contents) - .expect("PANIC: Unable to write output module file contents."); - } - } - eprintln!("Ran pass: {:?}", pass); - } - } - - fn clear_analyses(&mut self) { - self.def_uses = None; - self.reverse_postorders = None; - self.typing = None; - self.control_subgraphs = None; - self.doms = None; - self.postdoms = None; - self.fork_join_maps = None; - self.fork_join_nests = None; - self.loops = None; - self.reduce_cycles = None; - self.data_nodes_in_fork_joins = None; - self.bbs = None; - self.collection_objects = None; - 
self.callgraph = None; - } - - pub fn get_module(self) -> Module { - self.module - } - - fn fix_deleted_functions(&mut self, id_mapping: &[Option<usize>]) { - let mut idx = 0; - - // Rust does not like enumerate here, so use - // idx outside as a hack to make it happy. - self.module.functions.retain(|_| { - idx += 1; - id_mapping[idx - 1].is_some() - }); - } -} diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs index 6239a644..aa0d53fe 100644 --- a/hercules_opt/src/utils.rs +++ b/hercules_opt/src/utils.rs @@ -243,7 +243,7 @@ pub(crate) fn substitute_dynamic_constants_in_node( /* * Top level function to make a function have only a single return. */ -pub(crate) fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> { +pub fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> { let returns: Vec<NodeID> = (0..editor.func().nodes.len()) .filter(|idx| editor.func().nodes[*idx].is_return()) .map(NodeID::new) @@ -293,7 +293,7 @@ pub(crate) fn contains_between_control_flow(func: &Function) -> bool { * Top level function to ensure a Hercules function contains at least one * control node that isn't the start or return nodes. 
*/ -pub(crate) fn ensure_between_control_flow(editor: &mut FunctionEditor) -> Option<NodeID> { +pub fn ensure_between_control_flow(editor: &mut FunctionEditor) -> Option<NodeID> { if !contains_between_control_flow(editor.func()) { let ret = editor .node_ids() diff --git a/hercules_tools/hercules_driver/Cargo.toml b/hercules_tools/hercules_driver/Cargo.toml deleted file mode 100644 index ad9397b1..00000000 --- a/hercules_tools/hercules_driver/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "hercules_driver" -version = "0.1.0" -authors = ["Russel Arbore <rarbore2@illinois.edu>"] -edition = "2021" - -[dependencies] -clap = { version = "*", features = ["derive"] } -ron = "*" -postcard = { version = "*", features = ["alloc"] } -hercules_ir = { path = "../../hercules_ir" } -hercules_opt = { path = "../../hercules_opt" } diff --git a/hercules_tools/hercules_driver/src/main.rs b/hercules_tools/hercules_driver/src/main.rs deleted file mode 100644 index a2550022..00000000 --- a/hercules_tools/hercules_driver/src/main.rs +++ /dev/null @@ -1,58 +0,0 @@ -use std::fs::File; -use std::io::prelude::*; -use std::path::Path; - -use clap::Parser; - -#[derive(Parser, Debug)] -#[command(author, version, about, long_about = None)] -struct Args { - file: String, - passes: String, -} - -fn parse_file_from_hir(path: &Path) -> hercules_ir::ir::Module { - let mut file = File::open(path).expect("PANIC: Unable to open input file."); - let mut contents = String::new(); - file.read_to_string(&mut contents) - .expect("PANIC: Unable to read input file contents."); - hercules_ir::parse::parse(&contents).expect("PANIC: Failed to parse Hercules IR file.") -} - -fn parse_file_from_hbin(path: &Path) -> hercules_ir::ir::Module { - let mut file = File::open(path).expect("PANIC: Unable to open input file."); - let mut buffer = vec![]; - file.read_to_end(&mut buffer).unwrap(); - postcard::from_bytes(&buffer).unwrap() -} - -fn main() { - let args = Args::parse(); - assert!( - 
args.file.ends_with(".hir") || args.file.ends_with(".hbin"), - "PANIC: Running hercules_driver on a file without a .hir or .hbin extension." - ); - let path = Path::new(&args.file); - let module = if args.file.ends_with(".hir") { - parse_file_from_hir(path) - } else { - parse_file_from_hbin(path) - }; - - let mut pm = hercules_opt::pass::PassManager::new(module); - let passes: Vec<hercules_opt::pass::Pass> = args - .passes - .split(char::is_whitespace) - .map(|pass_str| { - assert_ne!( - pass_str, "", - "PANIC: Can't interpret empty pass name. Try giving a list of pass names." - ); - ron::from_str(pass_str).expect("PANIC: Couldn't parse list of passes.") - }) - .collect(); - for pass in passes { - pm.add_pass(pass); - } - pm.run_passes(); -} diff --git a/juno_build/src/lib.rs b/juno_build/src/lib.rs index 0c676e4c..b30a5d25 100644 --- a/juno_build/src/lib.rs +++ b/juno_build/src/lib.rs @@ -1,12 +1,9 @@ use juno_compiler::*; use std::env::{current_dir, var}; -use std::fmt::Write; use std::fs::{create_dir_all, read_to_string}; use std::path::{Path, PathBuf}; -use with_builtin_macros::with_builtin; - // JunoCompiler is used to compile juno files into a library and manifest file appropriately to // import the definitions into a rust project via the juno! 
macro defined below // You can also specify a Hercules IR file instead of a Juno file and this will compile that IR @@ -15,9 +12,7 @@ pub struct JunoCompiler { ir_src_path: Option<PathBuf>, src_path: Option<PathBuf>, out_path: Option<PathBuf>, - verify: JunoVerify, - x_dot: bool, - schedule: JunoSchedule, + schedule: Option<String>, } impl JunoCompiler { @@ -26,9 +21,7 @@ impl JunoCompiler { ir_src_path: None, src_path: None, out_path: None, - verify: JunoVerify::None, - x_dot: false, - schedule: JunoSchedule::None, + schedule: None, } } @@ -119,35 +112,6 @@ impl JunoCompiler { ); } - pub fn verify(mut self, enabled: bool) -> Self { - if enabled && !self.verify.verify() { - self.verify = JunoVerify::JunoOpts; - } else if !enabled && self.verify.verify() { - self.verify = JunoVerify::None; - } - self - } - - pub fn verify_all(mut self, enabled: bool) -> Self { - if enabled { - self.verify = JunoVerify::AllPasses; - } else if !enabled && self.verify.verify_all() { - self.verify = JunoVerify::JunoOpts; - } - self - } - - pub fn x_dot(mut self, enabled: bool) -> Self { - self.x_dot = enabled; - self - } - - // Sets the schedule to be the default schedule - pub fn default_schedule(mut self) -> Self { - self.schedule = JunoSchedule::DefaultSchedule; - self - } - // Set the schedule as a schedule file in the src directory pub fn schedule_in_src<P>(mut self, file: P) -> Result<Self, String> where @@ -166,7 +130,7 @@ impl JunoCompiler { }; path.push("src"); path.push(file.as_ref()); - self.schedule = JunoSchedule::Schedule(path.to_str().unwrap().to_string()); + self.schedule = Some(path.to_str().unwrap().to_string()); // Tell cargo to rerun if the schedule changes println!( @@ -180,11 +144,9 @@ impl JunoCompiler { // Builds the juno file into a libary and a manifest file. 
pub fn build(self) -> Result<(), String> { let JunoCompiler { - ir_src_path: ir_src_path, - src_path: src_path, + ir_src_path, + src_path, out_path: Some(out_path), - verify, - x_dot, schedule, } = self else { @@ -196,7 +158,7 @@ impl JunoCompiler { if let Some(src_path) = src_path { let src_file = src_path.to_str().unwrap().to_string(); - match compile(src_file, verify, x_dot, schedule, out_dir) { + match compile(src_file, schedule, out_dir) { Ok(()) => Ok(()), Err(errs) => Err(format!("{}", errs)), } @@ -216,7 +178,7 @@ impl JunoCompiler { return Err("Unable to parse Hercules IR file.".to_string()); }; - match compile_ir(ir_mod, None, verify, x_dot, schedule, out_dir, module_name) { + match compile_ir(ir_mod, schedule, out_dir, module_name) { Ok(()) => Ok(()), Err(errs) => Err(format!("{}", errs)), } diff --git a/juno_frontend/Cargo.toml b/juno_frontend/Cargo.toml index 39e18baa..ad35b84c 100644 --- a/juno_frontend/Cargo.toml +++ b/juno_frontend/Cargo.toml @@ -29,3 +29,4 @@ phf = { version = "0.11", features = ["macros"] } hercules_ir = { path = "../hercules_ir" } hercules_opt = { path = "../hercules_opt" } juno_scheduler = { path = "../juno_scheduler" } +juno_utils = { path = "../juno_utils" } diff --git a/juno_frontend/src/codegen.rs b/juno_frontend/src/codegen.rs index 31c85ae9..2ff9fa9f 100644 --- a/juno_frontend/src/codegen.rs +++ b/juno_frontend/src/codegen.rs @@ -1,8 +1,7 @@ -use std::collections::{HashMap, HashSet, VecDeque}; +use std::collections::{HashMap, VecDeque}; use hercules_ir::ir; use hercules_ir::ir::*; -use juno_scheduler::{FunctionMap, LabeledStructure}; use crate::labeled_builder::LabeledBuilder; use crate::semant; @@ -10,32 +9,12 @@ use crate::semant::{BinaryOp, Expr, Function, Literal, Prg, Stmt, UnaryOp}; use crate::ssa::SSA; use crate::types::{Either, Primitive, TypeSolver, TypeSolverInst}; +use juno_scheduler::labels::*; + // Loop info is a stack of the loop levels, recording the latch and exit block of each type LoopInfo = Vec<(NodeID, 
NodeID)>; -fn merge_function_maps( - mut functions: HashMap<(usize, Vec<TypeID>), FunctionID>, - funcs: &Vec<Function>, - mut tree: HashMap<FunctionID, Vec<(LabeledStructure, HashSet<usize>)>>, - mut labels: HashMap<FunctionID, HashMap<NodeID, usize>>, -) -> FunctionMap { - let mut res = HashMap::new(); - for ((func_num, type_vars), func_id) in functions.drain() { - let func_labels = tree.remove(&func_id).unwrap(); - let node_labels = labels.remove(&func_id).unwrap(); - let func_info = res.entry(func_num).or_insert(( - funcs[func_num].label_map.clone(), - funcs[func_num].name.clone(), - vec![], - )); - func_info - .2 - .push((type_vars, func_id, func_labels, node_labels)); - } - res -} - -pub fn codegen_program(prg: Prg) -> (Module, FunctionMap) { +pub fn codegen_program(prg: Prg) -> (Module, JunoInfo) { CodeGenerator::build(prg) } @@ -51,10 +30,20 @@ struct CodeGenerator<'a> { // type-solving instantiation (account for the type parameters), the function id, and the entry // block id worklist: VecDeque<(usize, TypeSolverInst<'a>, FunctionID, NodeID)>, + + // The JunoInfo needed for scheduling, which tracks the Juno function names and their + // associated FunctionIDs. 
+ juno_info: JunoInfo, } impl CodeGenerator<'_> { - fn build((types, funcs): Prg) -> (Module, FunctionMap) { + fn build( + Prg { + types, + funcs, + labels, + }: Prg, + ) -> (Module, JunoInfo) { // Identify the functions (by index) which have no type arguments, these are the ones we // ask for code to be generated for let func_idx = @@ -63,13 +52,16 @@ impl CodeGenerator<'_> { .enumerate() .filter_map(|(i, f)| if f.num_type_args == 0 { Some(i) } else { None }); + let juno_info = JunoInfo::new(funcs.iter().map(|f| f.name.clone())); + let mut codegen = CodeGenerator { - builder: LabeledBuilder::create(), + builder: LabeledBuilder::create(labels), types: &types, funcs: &funcs, uid: 0, functions: HashMap::new(), worklist: VecDeque::new(), + juno_info, }; // Add the identifed functions to the list to code-gen @@ -80,7 +72,7 @@ impl CodeGenerator<'_> { codegen.finish() } - fn finish(mut self) -> (Module, FunctionMap) { + fn finish(mut self) -> (Module, JunoInfo) { while !self.worklist.is_empty() { let (idx, mut type_inst, func, entry) = self.worklist.pop_front().unwrap(); self.builder.set_function(func); @@ -94,12 +86,10 @@ impl CodeGenerator<'_> { uid: _, functions, worklist: _, + juno_info, } = self; - let (module, label_tree, label_map) = builder.finish(); - ( - module, - merge_function_maps(functions, funcs, label_tree, label_map), - ) + + (builder.finish(), juno_info) } fn get_function(&mut self, func_idx: usize, ty_args: Vec<TypeID>) -> FunctionID { @@ -138,11 +128,11 @@ impl CodeGenerator<'_> { param_types, return_type, func.num_dyn_consts as u32, - func.num_labels, func.entry, ) .unwrap(); + self.juno_info.func_info.func_ids[func_idx].push(func_id); self.functions.insert((func_idx, ty_args), func_id); self.worklist .push_back((func_idx, solver_inst, func_id, entry)); @@ -164,7 +154,7 @@ impl CodeGenerator<'_> { } // Generate code for the body - let (_, None) = self.codegen_stmt(&func.body, types, &mut ssa, entry, &mut vec![]) else { + let None = 
self.codegen_stmt(&func.body, types, &mut ssa, entry, &mut vec![]) else { panic!("Generated code for a function missing a return") }; } @@ -176,42 +166,39 @@ impl CodeGenerator<'_> { ssa: &mut SSA, cur_block: NodeID, loops: &mut LoopInfo, - ) -> (LabeledStructure, Option<NodeID>) { + ) -> Option<NodeID> { match stmt { Stmt::AssignStmt { var, val } => { let (val, block) = self.codegen_expr(val, types, ssa, cur_block); ssa.write_variable(*var, block, val); - (LabeledStructure::Expression(val), Some(block)) + Some(block) } Stmt::IfStmt { cond, thn, els } => { let (val_cond, block_cond) = self.codegen_expr(cond, types, ssa, cur_block); let (mut if_node, block_then, block_else) = ssa.create_cond(&mut self.builder, block_cond); - let (_, then_end) = self.codegen_stmt(thn, types, ssa, block_then, loops); + let then_end = self.codegen_stmt(thn, types, ssa, block_then, loops); let else_end = match els { None => Some(block_else), - Some(els_stmt) => self.codegen_stmt(els_stmt, types, ssa, block_else, loops).1, + Some(els_stmt) => self.codegen_stmt(els_stmt, types, ssa, block_else, loops), }; let if_id = if_node.id(); if_node.build_if(block_cond, val_cond); self.builder.add_node(if_node); - ( - LabeledStructure::Branch(if_id), - match (then_end, else_end) { - (None, els) => els, - (thn, None) => thn, - (Some(then_term), Some(else_term)) => { - let block_join = ssa.create_block(&mut self.builder); - ssa.add_pred(block_join, then_term); - ssa.add_pred(block_join, else_term); - ssa.seal_block(block_join, &mut self.builder); - Some(block_join) - } - }, - ) + match (then_end, else_end) { + (None, els) => els, + (thn, None) => thn, + (Some(then_term), Some(else_term)) => { + let block_join = ssa.create_block(&mut self.builder); + ssa.add_pred(block_join, then_term); + ssa.add_pred(block_join, else_term); + ssa.seal_block(block_join, &mut self.builder); + Some(block_join) + } + } } Stmt::LoopStmt { cond, update, body } => { // We generate guarded loops, so the first step is to 
create @@ -236,7 +223,6 @@ impl CodeGenerator<'_> { None => block_latch, Some(stmt) => self .codegen_stmt(stmt, types, ssa, block_latch, loops) - .1 .expect("Loop update should return control"), }; let (val_cond, block_cond) = self.codegen_expr(cond, types, ssa, block_updated); @@ -258,7 +244,7 @@ impl CodeGenerator<'_> { // Generate code for the body loops.push((block_latch, block_exit)); - let (_, body_res) = self.codegen_stmt(body, types, ssa, body_block, loops); + let body_res = self.codegen_stmt(body, types, ssa, body_block, loops); loops.pop(); // If the body of the loop can reach some block, we add that block as a predecessor @@ -276,51 +262,46 @@ impl CodeGenerator<'_> { // It is always assumed a loop may be skipped and so control can reach after the // loop - (LabeledStructure::Loop(body_block), Some(block_exit)) + Some(block_exit) } Stmt::ReturnStmt { expr } => { let (val_ret, block_ret) = self.codegen_expr(expr, types, ssa, cur_block); let mut return_node = self.builder.allocate_node(); return_node.build_return(block_ret, val_ret); self.builder.add_node(return_node); - (LabeledStructure::Expression(val_ret), None) + None } Stmt::BreakStmt {} => { let last_loop = loops.len() - 1; let (_latch, exit) = loops[last_loop]; ssa.add_pred(exit, cur_block); // The block that contains this break now leads to // the exit - (LabeledStructure::Nothing(), None) + None } Stmt::ContinueStmt {} => { let last_loop = loops.len() - 1; let (latch, _exit) = loops[last_loop]; ssa.add_pred(latch, cur_block); // The block that contains this continue now leads // to the latch - (LabeledStructure::Nothing(), None) + None } - Stmt::BlockStmt { body, label_last } => { - let mut label = None; + Stmt::BlockStmt { body } => { let mut block = Some(cur_block); for stmt in body.iter() { - let (new_label, new_block) = - self.codegen_stmt(stmt, types, ssa, block.unwrap(), loops); + let new_block = self.codegen_stmt(stmt, types, ssa, block.unwrap(), loops); block = new_block; - if 
label.is_none() || *label_last { - label = Some(new_label); - } } - (label.unwrap_or(LabeledStructure::Nothing()), block) + block } Stmt::ExprStmt { expr } => { - let (val, block) = self.codegen_expr(expr, types, ssa, cur_block); - (LabeledStructure::Expression(val), Some(block)) + let (_val, block) = self.codegen_expr(expr, types, ssa, cur_block); + Some(block) } Stmt::LabeledStmt { label, stmt } => { self.builder.push_label(*label); - let (labeled, res) = self.codegen_stmt(&*stmt, types, ssa, cur_block, loops); - self.builder.pop_label(labeled); - (labeled, res) + let res = self.codegen_stmt(&*stmt, types, ssa, cur_block, loops); + self.builder.pop_label(); + res } } } diff --git a/juno_frontend/src/labeled_builder.rs b/juno_frontend/src/labeled_builder.rs index 3cbf36da..15bed6c2 100644 --- a/juno_frontend/src/labeled_builder.rs +++ b/juno_frontend/src/labeled_builder.rs @@ -1,7 +1,7 @@ use hercules_ir::build::*; use hercules_ir::ir::*; -use juno_scheduler::LabeledStructure; -use std::collections::{HashMap, HashSet}; + +use juno_utils::stringtab::StringTable; // A label-tracking code generator which tracks the current function that we're // generating code for and the what labels apply to the nodes being created @@ -12,39 +12,21 @@ use std::collections::{HashMap, HashSet}; pub struct LabeledBuilder<'a> { pub builder: Builder<'a>, function: Option<FunctionID>, - label: usize, - label_stack: Vec<usize>, - label_tree: HashMap<FunctionID, Vec<(LabeledStructure, HashSet<usize>)>>, - label_map: HashMap<FunctionID, HashMap<NodeID, usize>>, + label_stack: Vec<LabelID>, + label_tab: StringTable, } impl<'a> LabeledBuilder<'a> { - pub fn create() -> LabeledBuilder<'a> { + pub fn create(labels: StringTable) -> LabeledBuilder<'a> { LabeledBuilder { builder: Builder::create(), function: None, - label: 0, // 0 is always the root label label_stack: vec![], - label_tree: HashMap::new(), - label_map: HashMap::new(), + label_tab: labels, } } - pub fn finish( - self, - ) -> ( - 
Module, - HashMap<FunctionID, Vec<(LabeledStructure, HashSet<usize>)>>, - HashMap<FunctionID, HashMap<NodeID, usize>>, - ) { - let LabeledBuilder { - builder, - function: _, - label: _, - label_stack: _, - label_tree, - label_map, - } = self; - (builder.finish(), label_tree, label_map) + pub fn finish(self) -> Module { + self.builder.finish() } pub fn create_function( @@ -53,7 +35,6 @@ impl<'a> LabeledBuilder<'a> { param_types: Vec<TypeID>, return_type: TypeID, num_dynamic_constants: u32, - num_labels: usize, entry: bool, ) -> Result<(FunctionID, NodeID), String> { let (func, entry) = self.builder.create_function( @@ -64,78 +45,53 @@ impl<'a> LabeledBuilder<'a> { entry, )?; - self.label_tree.insert( - func, - vec![(LabeledStructure::Nothing(), HashSet::new()); num_labels], - ); - self.label_map.insert(func, HashMap::new()); - Ok((func, entry)) } pub fn set_function(&mut self, func: FunctionID) { self.function = Some(func); - self.label = 0; + assert!(self.label_stack.is_empty()); } pub fn push_label(&mut self, label: usize) { - let Some(cur_func) = self.function else { - panic!("Setting label without function") + let Some(label_str) = self.label_tab.lookup_id(label) else { + panic!("Label missing from string table") }; - - let cur_label = self.label; - self.label_stack.push(cur_label); - - for ancestor in self.label_stack.iter() { - self.label_tree.get_mut(&cur_func).unwrap()[*ancestor] - .1 - .insert(label); - } - - self.label = label; + let label_id = self.builder.add_label(&label_str); + self.label_stack.push(label_id); } - pub fn pop_label(&mut self, structure: LabeledStructure) { - let Some(cur_func) = self.function else { - panic!("Setting label without function") - }; - self.label_tree.get_mut(&cur_func).unwrap()[self.label].0 = structure; - let Some(label) = self.label_stack.pop() else { - panic!("Cannot pop label not pushed first") + pub fn pop_label(&mut self) { + let Some(_) = self.label_stack.pop() else { + panic!("No label to pop") }; - self.label = 
label; } - fn allocate_node_labeled(&mut self, label: usize) -> NodeBuilder { + fn allocate_node_labeled(&mut self, label: Vec<LabelID>) -> NodeBuilder { let Some(func) = self.function else { panic!("Cannot allocate node without function") }; - let builder = self.builder.allocate_node(func); - - self.label_map - .get_mut(&func) - .unwrap() - .insert(builder.id(), label); + let mut builder = self.builder.allocate_node(func); + builder.add_labels(label.into_iter()); builder } pub fn allocate_node(&mut self) -> NodeBuilder { - self.allocate_node_labeled(self.label) + self.allocate_node_labeled(self.label_stack.clone()) } pub fn allocate_node_labeled_with(&mut self, other: NodeID) -> NodeBuilder { let Some(func) = self.function else { panic!("Cannot allocate node without function") }; - let label = self - .label_map - .get(&func) - .unwrap() - .get(&other) - .expect("Other node not labeled"); - - self.allocate_node_labeled(*label) + let labels = self + .builder + .get_labels(func, other) + .iter() + .cloned() + .collect::<Vec<_>>(); + self.allocate_node_labeled(labels) } pub fn add_node(&mut self, builder: NodeBuilder) { diff --git a/juno_frontend/src/lib.rs b/juno_frontend/src/lib.rs index d9a59a38..c85c9d71 100644 --- a/juno_frontend/src/lib.rs +++ b/juno_frontend/src/lib.rs @@ -1,6 +1,5 @@ mod codegen; mod dynconst; -mod env; mod intrinsics; mod labeled_builder; mod locs; @@ -9,11 +8,11 @@ mod semant; mod ssa; mod types; +use juno_scheduler::{schedule_hercules, schedule_juno}; + use std::fmt; use std::path::Path; -use juno_scheduler::FunctionMap; - pub enum JunoVerify { None, JunoOpts, @@ -80,9 +79,7 @@ impl fmt::Display for ErrorMessage { pub fn compile( src_file: String, - verify: JunoVerify, - x_dot: bool, - schedule: JunoSchedule, + schedule: Option<String>, output_dir: String, ) -> Result<(), ErrorMessage> { let src_file_path = Path::new(&src_file); @@ -94,135 +91,18 @@ pub fn compile( return Err(ErrorMessage::SemanticError(msg)); } }; - let (module, func_info) 
= codegen::codegen_program(prg); + let (module, juno_info) = codegen::codegen_program(prg); - compile_ir( - module, - Some(func_info), - verify, - x_dot, - schedule, - output_dir, - module_name, - ) + schedule_juno(module, juno_info, schedule, output_dir, module_name) + .map_err(|s| ErrorMessage::SchedulingError(s)) } pub fn compile_ir( module: hercules_ir::ir::Module, - func_info: Option<FunctionMap>, - verify: JunoVerify, - x_dot: bool, - schedule: JunoSchedule, + schedule: Option<String>, output_dir: String, module_name: String, ) -> Result<(), ErrorMessage> { - let mut pm = match schedule { - JunoSchedule::None => hercules_opt::pass::PassManager::new(module), - _ => todo!(), - /* - JunoSchedule::DefaultSchedule => { - let mut pm = hercules_opt::pass::PassManager::new(module); - pm.make_plans(); - pm - } - JunoSchedule::Schedule(file) => { - let Some(func_info) = func_info else { - return Err(ErrorMessage::SchedulingError( - "Cannot schedule, no function information provided".to_string(), - )); - }; - - match juno_scheduler::schedule(&module, func_info, file) { - Ok(plans) => { - let mut pm = hercules_opt::pass::PassManager::new(module); - pm.set_plans(plans); - pm - } - Err(msg) => { - return Err(ErrorMessage::SchedulingError(msg)); - } - } - } - */ - }; - if verify.verify() || verify.verify_all() { - pm.add_pass(hercules_opt::pass::Pass::Verify); - } - add_verified_pass!(pm, verify, GVN); - add_pass!(pm, verify, DCE); - add_verified_pass!(pm, verify, PhiElim); - add_pass!(pm, verify, DCE); - add_pass!(pm, verify, CRC); - add_pass!(pm, verify, DCE); - add_pass!(pm, verify, SLF); - add_pass!(pm, verify, DCE); - if x_dot { - pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); - } - add_pass!(pm, verify, Inline); - if x_dot { - pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); - } - // Inlining may make some functions uncalled, so run this pass. - // In general, this should always be run after inlining. 
- add_pass!(pm, verify, DeleteUncalled); - if x_dot { - pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); - } - // Run SROA pretty early (though after inlining which can make SROA more effective) so that - // CCP, GVN, etc. can work on the result of SROA - add_pass!(pm, verify, InterproceduralSROA); - add_pass!(pm, verify, SROA); - // We run phi-elim again because SROA can introduce new phis that might be able to be - // simplified - add_verified_pass!(pm, verify, PhiElim); - add_pass!(pm, verify, DCE); - if x_dot { - pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); - } - add_pass!(pm, verify, CCP); - add_pass!(pm, verify, DCE); - add_pass!(pm, verify, GVN); - add_pass!(pm, verify, DCE); - add_pass!(pm, verify, WritePredication); - add_pass!(pm, verify, PhiElim); - add_pass!(pm, verify, DCE); - add_pass!(pm, verify, CRC); - add_pass!(pm, verify, DCE); - add_pass!(pm, verify, SLF); - add_pass!(pm, verify, DCE); - add_pass!(pm, verify, Predication); - add_pass!(pm, verify, DCE); - add_pass!(pm, verify, CCP); - add_pass!(pm, verify, DCE); - add_pass!(pm, verify, GVN); - add_pass!(pm, verify, DCE); - if x_dot { - pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); - } - //add_pass!(pm, verify, Forkify); - //add_pass!(pm, verify, ForkGuardElim); - add_verified_pass!(pm, verify, DCE); - add_pass!(pm, verify, ForkSplit); - add_pass!(pm, verify, Unforkify); - add_pass!(pm, verify, GVN); - add_verified_pass!(pm, verify, DCE); - add_pass!(pm, verify, DCE); - add_pass!(pm, verify, Outline); - add_pass!(pm, verify, InterproceduralSROA); - add_pass!(pm, verify, SROA); - add_pass!(pm, verify, InferSchedules); - add_verified_pass!(pm, verify, DCE); - if x_dot { - pm.add_pass(hercules_opt::pass::Pass::Xdot(true)); - } - - add_pass!(pm, verify, GCM); - add_verified_pass!(pm, verify, DCE); - add_pass!(pm, verify, FloatCollections); - add_pass!(pm, verify, GCM); - pm.add_pass(hercules_opt::pass::Pass::Codegen(output_dir, module_name)); - pm.run_passes(); - - Ok(()) + 
schedule_hercules(module, schedule, output_dir, module_name) + .map_err(|s| ErrorMessage::SchedulingError(s)) } diff --git a/juno_frontend/src/main.rs b/juno_frontend/src/main.rs index d98c1e29..6c2722fc 100644 --- a/juno_frontend/src/main.rs +++ b/juno_frontend/src/main.rs @@ -1,52 +1,21 @@ use juno_compiler::*; -use clap::{ArgGroup, Parser}; +use clap::Parser; use std::path::PathBuf; #[derive(Parser)] #[clap(author, version, about, long_about = None)] -#[clap(group( - ArgGroup::new("scheduling") - .required(false) - .args(&["schedule", "default_schedule", "no_schedule"])))] struct Cli { src_file: String, - #[clap(short, long)] - verify: bool, - #[clap(long = "verify-all")] - verify_all: bool, - #[arg(short, long = "x-dot")] - x_dot: bool, #[clap(short, long, value_name = "SCHEDULE")] schedule: Option<String>, - #[clap(short, long = "default-schedule")] - default_schedule: bool, - #[clap(short, long)] - no_schedule: bool, #[arg(short, long = "output-dir", value_name = "OUTPUT DIR")] output_dir: Option<String>, } fn main() { let args = Cli::parse(); - let verify = if args.verify_all { - JunoVerify::AllPasses - } else if args.verify { - JunoVerify::JunoOpts - } else { - JunoVerify::None - }; - let schedule = match args.schedule { - Some(file) => JunoSchedule::Schedule(file), - None => { - if args.default_schedule { - JunoSchedule::DefaultSchedule - } else { - JunoSchedule::None - } - } - }; let output_dir = match args.output_dir { Some(dir) => dir, None => PathBuf::from(args.src_file.clone()) @@ -56,7 +25,7 @@ fn main() { .unwrap() .to_string(), }; - match compile(args.src_file, verify, args.x_dot, schedule, output_dir) { + match compile(args.src_file, args.schedule, output_dir) { Ok(()) => {} Err(errs) => { eprintln!("{}", errs); diff --git a/juno_frontend/src/semant.rs b/juno_frontend/src/semant.rs index 660d8afe..2fe4bf88 100644 --- a/juno_frontend/src/semant.rs +++ b/juno_frontend/src/semant.rs @@ -10,7 +10,6 @@ use lrpar::NonStreamingLexer; use 
ordered_float::OrderedFloat; use crate::dynconst::DynConst; -use crate::env::Env; use crate::intrinsics; use crate::locs::{span_to_loc, Location}; use crate::parser; @@ -18,6 +17,9 @@ use crate::parser::*; use crate::types; use crate::types::{Either, Type, TypeSolver}; +use juno_utils::env::Env; +use juno_utils::stringtab::StringTable; + // Definitions and data structures for semantic analysis // Entities in the environment @@ -76,72 +78,6 @@ impl PartialEq for Literal { impl Eq for Literal {} -// Map strings to unique identifiers and counts uids -struct StringTable { - count: usize, - string_to_index: HashMap<String, usize>, - index_to_string: HashMap<usize, String>, -} -impl StringTable { - fn new() -> StringTable { - StringTable { - count: 0, - string_to_index: HashMap::new(), - index_to_string: HashMap::new(), - } - } - - // Produce the UID for a string - fn lookup_string(&mut self, s: String) -> usize { - match self.string_to_index.get(&s) { - Some(n) => *n, - None => { - let n = self.count; - self.count += 1; - self.string_to_index.insert(s.clone(), n); - self.index_to_string.insert(n, s); - n - } - } - } - - // Identify the string corresponding to a UID - fn lookup_id(&self, n: usize) -> Option<String> { - self.index_to_string.get(&n).cloned() - } -} - -// Maps label names to unique identifiers (numbered 0..n for each function) -// Also tracks the map from function names to their numbers -struct LabelSet { - count: usize, - string_to_index: HashMap<String, usize>, -} -impl LabelSet { - fn new() -> LabelSet { - // Label number 0 is reserved to be the "root" label in code generation - LabelSet { - count: 1, - string_to_index: HashMap::from([("<root>".to_string(), 0)]), - } - } - - // Inserts a string if it is not already contained in this set, if it is - // contained does nothing and returns the label back wrapped in an error, - // otherwise inserts and returns the new label's id wrapped in Ok - fn insert_new(&mut self, label: String) -> Result<usize, String> 
{ - match self.string_to_index.get(&label) { - Some(_) => Err(label), - None => { - let uid = self.count; - self.count += 1; - self.string_to_index.insert(label, uid); - Ok(uid) - } - } - } -} - // Convert spans into uids in the String Table fn intern_id( n: &Span, @@ -252,7 +188,11 @@ fn append_errors3<A, B, C>( // Normalized AST forms after semantic analysis // These include type information at all expression nodes, and remove names and locations -pub type Prg = (TypeSolver, Vec<Function>); +pub struct Prg { + pub types: TypeSolver, + pub funcs: Vec<Function>, + pub labels: StringTable, +} // The function stores information for code-generation. The type information therefore is not the // type information that is needed for type checking code that uses this function. @@ -263,8 +203,6 @@ pub struct Function { pub num_type_args: usize, pub arguments: Vec<(usize, Type)>, pub return_type: Type, - pub num_labels: usize, - pub label_map: HashMap<String, usize>, pub body: Stmt, pub entry: bool, } @@ -304,7 +242,6 @@ pub enum Stmt { ContinueStmt {}, BlockStmt { body: Vec<Stmt>, - label_last: bool, }, ExprStmt { expr: Expr, @@ -583,6 +520,7 @@ fn analyze_program( lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, ) -> Result<Prg, ErrorMessages> { let mut stringtab = StringTable::new(); + let mut labels = StringTable::new(); let mut env: Env<usize, Entity> = Env::new(); let mut types = TypeSolver::new(); @@ -889,9 +827,6 @@ fn analyze_program( } } - // Create a set of the labels in this function - let mut labels = LabelSet::new(); - // Finally, we have a properly built environment and we can // start processing the body let (mut body, end_reachable) = process_stmt( @@ -927,7 +862,6 @@ fn analyze_program( &mut types, ), ], - label_last: false, }; } else { Err(singleton_error(ErrorMessage::SemanticError( @@ -965,8 +899,6 @@ fn analyze_program( .map(|(v, n)| (*n, v.1)) .collect::<Vec<_>>(), return_type: pure_return_type, - num_labels: labels.count, - label_map: 
labels.string_to_index, body: body, entry: entry, }); @@ -998,7 +930,11 @@ fn analyze_program( } } - Ok((types, res)) + Ok(Prg { + types, + funcs: res, + labels, + }) } fn process_type_def( @@ -1686,7 +1622,7 @@ fn process_stmt( return_type: Type, inout_vars: &Vec<usize>, inout_types: &Vec<Type>, - labels: &mut LabelSet, + labels: &mut StringTable, ) -> Result<(Stmt, bool), ErrorMessages> { match stmt { parser::Stmt::LetStmt { @@ -2286,9 +2222,6 @@ fn process_stmt( body: Box::new(body), }, ], - // A label applied to this loop should be applied to the - // loop, not the initialization - label_last: true, }, true, )) @@ -2438,13 +2371,7 @@ fn process_stmt( if !errors.is_empty() { Err(errors) } else { - Ok(( - Stmt::BlockStmt { - body: res, - label_last: false, - }, - reachable, - )) + Ok((Stmt::BlockStmt { body: res }, reachable)) } } parser::Stmt::CallStmt { @@ -2480,15 +2407,8 @@ fn process_stmt( label, stmt, } => { - let label_str = lexer.span_str(label).to_string(); - - let label_id = match labels.insert_new(label_str) { - Err(label_str) => Err(singleton_error(ErrorMessage::SemanticError( - span_to_loc(label, lexer), - format!("Label {} already exists", label_str), - )))?, - Ok(id) => id, - }; + let label_str = lexer.span_str(label)[1..].to_string(); + let label_id = labels.lookup_string(label_str); let (body, reach_end) = process_stmt( *stmt, diff --git a/juno_frontend/src/ssa.rs b/juno_frontend/src/ssa.rs index 578f7a9a..7076d622 100644 --- a/juno_frontend/src/ssa.rs +++ b/juno_frontend/src/ssa.rs @@ -7,9 +7,9 @@ use std::collections::{HashMap, HashSet}; +use crate::labeled_builder::LabeledBuilder; use hercules_ir::build::*; use hercules_ir::ir::*; -use crate::labeled_builder::LabeledBuilder; pub struct SSA { // Map from variable (usize) to build (NodeID) to definition (NodeID) diff --git a/juno_samples/schedule_test/Cargo.toml b/juno_samples/schedule_test/Cargo.toml new file mode 100644 index 00000000..be5d949b --- /dev/null +++ 
b/juno_samples/schedule_test/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "juno_schedule_test" +version = "0.1.0" +authors = ["Aaron Councilman <aaronjc4@illinois.edu>"] +edition = "2021" + +[[bin]] +name = "juno_schedule_test" +path = "src/main.rs" + +[build-dependencies] +juno_build = { path = "../../juno_build" } + +[dependencies] +juno_build = { path = "../../juno_build" } +hercules_rt = { path = "../../hercules_rt" } +with_builtin_macros = "0.1.0" +async-std = "*" +rand = "*" diff --git a/juno_samples/schedule_test/build.rs b/juno_samples/schedule_test/build.rs new file mode 100644 index 00000000..4a428247 --- /dev/null +++ b/juno_samples/schedule_test/build.rs @@ -0,0 +1,11 @@ +use juno_build::JunoCompiler; + +fn main() { + JunoCompiler::new() + .file_in_src("code.jn") + .unwrap() + .schedule_in_src("sched.sch") + .unwrap() + .build() + .unwrap(); +} diff --git a/juno_samples/schedule_test/src/code.jn b/juno_samples/schedule_test/src/code.jn new file mode 100644 index 00000000..5bb923bf --- /dev/null +++ b/juno_samples/schedule_test/src/code.jn @@ -0,0 +1,30 @@ +#[entry] +fn test<n, m, k: usize>(a: i32[n, m], b: i32[m, k], c: i32[k]) -> i32[n] { + let prod: i32[n, k]; + + @outer for i = 0 to n { + @middle for j = 0 to k { + let val = 0; + + @inner for k = 0 to m { + val += a[i, k] * b[k, j]; + } + + prod[i, j] = val; + } + } + + let res: i32[n]; + + @row for i = 0 to n { + let val = 0; + + @col for j = 0 to k { + val += prod[i, j] * c[j]; + } + + res[i] = val; + } + + return res; +} diff --git a/juno_samples/schedule_test/src/main.rs b/juno_samples/schedule_test/src/main.rs new file mode 100644 index 00000000..a64cd16f --- /dev/null +++ b/juno_samples/schedule_test/src/main.rs @@ -0,0 +1,42 @@ +#![feature(box_as_ptr, let_chains)] + +use rand::random; + +use hercules_rt::HerculesBox; + +juno_build::juno!("code"); + +fn main() { + async_std::task::block_on(async { + const N: usize = 256; + const M: usize = 64; + const K: usize = 128; + let a: Box<[i32]> = 
(0..N * M).map(|_| random::<i32>() % 100).collect(); + let b: Box<[i32]> = (0..M * K).map(|_| random::<i32>() % 100).collect(); + let c: Box<[i32]> = (0..K).map(|_| random::<i32>() % 100).collect(); + + let mut correct_res: Box<[i32]> = (0..N).map(|_| 0).collect(); + for i in 0..N { + for j in 0..K { + let mut res = 0; + for k in 0..M { + res += a[i * M + k] * b[k * K + j]; + } + correct_res[i] += c[j] * res; + } + } + + let mut res = { + let a = HerculesBox::from_slice(&a); + let b = HerculesBox::from_slice(&b); + let c = HerculesBox::from_slice(&c); + test(N as u64, M as u64, K as u64, a, b, c).await + }; + assert_eq!(res.as_slice::<i32>(), &*correct_res); + }); +} + +#[test] +fn schedule_test() { + main(); +} diff --git a/juno_samples/schedule_test/src/sched.sch b/juno_samples/schedule_test/src/sched.sch new file mode 100644 index 00000000..f73e3e70 --- /dev/null +++ b/juno_samples/schedule_test/src/sched.sch @@ -0,0 +1,43 @@ +macro juno-setup!(X) { + //gvn(X); + phi-elim(X); + dce(X); +} +macro codegen-prep!(X) { + infer-schedules(X); + dce(X); + gcm(X); + dce(X); + phi-elim(X); + float-collections(X); + gcm(X); +} + + +juno-setup!(*); + +let first = outline(test@outer); +let second = outline(test@row); + +// We can use the functions produced by outlining in our schedules +gvn(first, second, test); + +ip-sroa(*); +sroa(*); + +// We can evaluate expressions using labels and save them for later use +let inner = first@inner; + +// A fixpoint can run a (series) of passes until no more changes are made +// (though some passes seem to make edits even if there are no real changes, +// so this is fragile). 
+// We could just let it run until it converges but can also tell it to panic +// if it hasn't converged after a number of iterations (like here) tell it to +// just stop after a certain number of iterations (stop after #) or to print +// the iteration number (print iter) +fixpoint panic after 2 { + phi-elim(*); +} + +codegen-prep!(*); +//xdot[true](*); diff --git a/juno_scheduler/Cargo.toml b/juno_scheduler/Cargo.toml index 49e5f4a3..1c837d4a 100644 --- a/juno_scheduler/Cargo.toml +++ b/juno_scheduler/Cargo.toml @@ -13,4 +13,8 @@ lrpar = "0.13" cfgrammar = "0.13" lrlex = "0.13" lrpar = "0.13" +tempfile = "*" +hercules_cg = { path = "../hercules_cg" } hercules_ir = { path = "../hercules_ir" } +hercules_opt = { path = "../hercules_opt" } +juno_utils = { path = "../juno_utils" } diff --git a/juno_scheduler/src/compile.rs b/juno_scheduler/src/compile.rs new file mode 100644 index 00000000..5317eb86 --- /dev/null +++ b/juno_scheduler/src/compile.rs @@ -0,0 +1,495 @@ +use crate::ir; +use crate::parser; + +use juno_utils::env::Env; +use juno_utils::stringtab::StringTable; + +extern crate hercules_ir; +use self::hercules_ir::ir::{Device, Schedule}; + +use lrlex::DefaultLexerTypes; +use lrpar::NonStreamingLexer; +use lrpar::Span; + +use std::fmt; +use std::str::FromStr; + +type Location = ((usize, usize), (usize, usize)); + +pub enum ScheduleCompilerError { + UndefinedMacro(String, Location), + NoSuchPass(String, Location), + IncorrectArguments { + expected: usize, + actual: usize, + loc: Location, + }, +} + +impl fmt::Display for ScheduleCompilerError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ScheduleCompilerError::UndefinedMacro(name, loc) => write!( + f, + "({}, {}) -- ({}, {}): Undefined macro '{}'", + loc.0 .0, loc.0 .1, loc.1 .0, loc.1 .1, name + ), + ScheduleCompilerError::NoSuchPass(name, loc) => write!( + f, + "({}, {}) -- ({}, {}): Undefined pass '{}'", + loc.0 .0, loc.0 .1, loc.1 .0, loc.1 .1, name + ), + 
ScheduleCompilerError::IncorrectArguments { + expected, + actual, + loc, + } => write!( + f, + "({}, {}) -- ({}, {}): Expected {} arguments, found {}", + loc.0 .0, loc.0 .1, loc.1 .0, loc.1 .1, expected, actual + ), + } + } +} + +pub fn compile_schedule( + sched: parser::OperationList, + lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, +) -> Result<ir::ScheduleStmt, ScheduleCompilerError> { + let mut macrostab = StringTable::new(); + let mut macros = Env::new(); + + macros.open_scope(); + + Ok(ir::ScheduleStmt::Block { + body: compile_ops_as_block(sched, lexer, &mut macrostab, &mut macros)?, + }) +} + +#[derive(Debug, Clone)] +struct MacroInfo { + params: Vec<String>, + selection_name: String, + def: ir::ScheduleExp, +} + +enum Appliable { + Pass(ir::Pass), + Schedule(Schedule), + Device(Device), +} + +impl Appliable { + fn num_args(&self) -> usize { + match self { + Appliable::Pass(pass) => pass.num_args(), + // Schedules and devices do not arguments (at the moment) + _ => 0, + } + } +} + +impl FromStr for Appliable { + type Err = String; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s { + "auto-outline" => Ok(Appliable::Pass(ir::Pass::AutoOutline)), + "ccp" => Ok(Appliable::Pass(ir::Pass::CCP)), + "crc" | "collapse-read-chains" => Ok(Appliable::Pass(ir::Pass::CRC)), + "dce" => Ok(Appliable::Pass(ir::Pass::DCE)), + "delete-uncalled" => Ok(Appliable::Pass(ir::Pass::DeleteUncalled)), + "float-collections" | "collections" => Ok(Appliable::Pass(ir::Pass::FloatCollections)), + "fork-guard-elim" => Ok(Appliable::Pass(ir::Pass::ForkGuardElim)), + "fork-split" => Ok(Appliable::Pass(ir::Pass::ForkSplit)), + "forkify" => Ok(Appliable::Pass(ir::Pass::Forkify)), + "gcm" | "bbs" => Ok(Appliable::Pass(ir::Pass::GCM)), + "gvn" => Ok(Appliable::Pass(ir::Pass::GVN)), + "infer-schedules" => Ok(Appliable::Pass(ir::Pass::InferSchedules)), + "inline" => Ok(Appliable::Pass(ir::Pass::Inline)), + "ip-sroa" | "interprocedural-sroa" => { + 
Ok(Appliable::Pass(ir::Pass::InterproceduralSROA)) + } + "outline" => Ok(Appliable::Pass(ir::Pass::Outline)), + "phi-elim" => Ok(Appliable::Pass(ir::Pass::PhiElim)), + "predication" => Ok(Appliable::Pass(ir::Pass::Predication)), + "slf" | "store-load-forward" => Ok(Appliable::Pass(ir::Pass::SLF)), + "sroa" => Ok(Appliable::Pass(ir::Pass::SROA)), + "unforkify" => Ok(Appliable::Pass(ir::Pass::Unforkify)), + "verify" => Ok(Appliable::Pass(ir::Pass::Verify)), + "xdot" => Ok(Appliable::Pass(ir::Pass::Xdot)), + + "cpu" | "llvm" => Ok(Appliable::Device(Device::LLVM)), + "gpu" | "cuda" | "nvidia" => Ok(Appliable::Device(Device::CUDA)), + "host" | "rust" | "rust-async" => Ok(Appliable::Device(Device::AsyncRust)), + + "associative" => Ok(Appliable::Schedule(Schedule::TightAssociative)), + "parallel-fork" => Ok(Appliable::Schedule(Schedule::ParallelFork)), + "parallel-reduce" => Ok(Appliable::Schedule(Schedule::ParallelReduce)), + "vectorize" => Ok(Appliable::Schedule(Schedule::Vectorizable)), + + _ => Err(s.to_string()), + } + } +} + +fn compile_ops_as_block( + sched: parser::OperationList, + lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + macrostab: &mut StringTable, + macros: &mut Env<usize, MacroInfo>, +) -> Result<Vec<ir::ScheduleStmt>, ScheduleCompilerError> { + match sched { + parser::OperationList::NilStmt() => Ok(vec![]), + parser::OperationList::FinalExpr(expr) => { + Ok(vec![compile_exp_as_stmt(expr, lexer, macrostab, macros)?]) + } + parser::OperationList::ConsStmt(stmt, ops) => { + let mut res = compile_stmt(stmt, lexer, macrostab, macros)?; + res.extend(compile_ops_as_block(*ops, lexer, macrostab, macros)?); + Ok(res) + } + } +} + +fn compile_stmt( + stmt: parser::Stmt, + lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + macrostab: &mut StringTable, + macros: &mut Env<usize, MacroInfo>, +) -> Result<Vec<ir::ScheduleStmt>, ScheduleCompilerError> { + match stmt { + parser::Stmt::LetStmt { span: _, var, expr } => { + let var = 
lexer.span_str(var).to_string(); + Ok(vec![ir::ScheduleStmt::Let { + var, + exp: compile_exp_as_expr(expr, lexer, macrostab, macros)?, + }]) + } + parser::Stmt::AssignStmt { span: _, var, rhs } => { + let var = lexer.span_str(var).to_string(); + Ok(vec![ir::ScheduleStmt::Assign { + var, + exp: compile_exp_as_expr(rhs, lexer, macrostab, macros)?, + }]) + } + parser::Stmt::ExprStmt { span: _, exp } => { + Ok(vec![compile_exp_as_stmt(exp, lexer, macrostab, macros)?]) + } + parser::Stmt::Fixpoint { + span: _, + limit, + body, + } => { + let limit = match limit { + parser::FixpointLimit::NoLimit { .. } => ir::FixpointLimit::NoLimit(), + parser::FixpointLimit::StopAfter { span: _, limit } => { + ir::FixpointLimit::StopAfter( + lexer + .span_str(limit) + .parse() + .expect("Parsing ensures integer"), + ) + } + parser::FixpointLimit::PanicAfter { span: _, limit } => { + ir::FixpointLimit::PanicAfter( + lexer + .span_str(limit) + .parse() + .expect("Parsing ensures integer"), + ) + } + parser::FixpointLimit::PrintIter { .. } => ir::FixpointLimit::PrintIter(), + }; + + macros.open_scope(); + let body = compile_ops_as_block(*body, lexer, macrostab, macros); + macros.close_scope(); + + Ok(vec![ir::ScheduleStmt::Fixpoint { + body: Box::new(ir::ScheduleStmt::Block { body: body? 
}), + limit, + }]) + } + parser::Stmt::MacroDecl { span: _, def } => { + let parser::MacroDecl { + name, + params, + selection_name, + def, + } = def; + let name = lexer.span_str(name).to_string(); + let macro_id = macrostab.lookup_string(name); + + let selection_name = lexer.span_str(selection_name).to_string(); + + let params = params + .into_iter() + .map(|s| lexer.span_str(s).to_string()) + .collect(); + + let def = compile_macro_def(*def, params, selection_name, lexer, macrostab, macros)?; + macros.insert(macro_id, def); + + Ok(vec![]) + } + } +} + +fn compile_exp_as_stmt( + expr: parser::Expr, + lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + macrostab: &mut StringTable, + macros: &mut Env<usize, MacroInfo>, +) -> Result<ir::ScheduleStmt, ScheduleCompilerError> { + match compile_expr(expr, lexer, macrostab, macros)? { + ExprResult::Expr(exp) => Ok(ir::ScheduleStmt::Let { + var: "_".to_string(), + exp, + }), + ExprResult::Stmt(stm) => Ok(stm), + } +} + +fn compile_exp_as_expr( + expr: parser::Expr, + lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + macrostab: &mut StringTable, + macros: &mut Env<usize, MacroInfo>, +) -> Result<ir::ScheduleExp, ScheduleCompilerError> { + match compile_expr(expr, lexer, macrostab, macros)? 
{ + ExprResult::Expr(exp) => Ok(exp), + ExprResult::Stmt(stm) => Ok(ir::ScheduleExp::Block { + body: vec![stm], + res: Box::new(ir::ScheduleExp::Record { fields: vec![] }), + }), + } +} + +enum ExprResult { + Expr(ir::ScheduleExp), + Stmt(ir::ScheduleStmt), +} + +fn compile_expr( + expr: parser::Expr, + lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + macrostab: &mut StringTable, + macros: &mut Env<usize, MacroInfo>, +) -> Result<ExprResult, ScheduleCompilerError> { + match expr { + parser::Expr::Function { + span, + name, + args, + selection, + } => { + let func: Appliable = lexer + .span_str(name) + .to_lowercase() + .parse() + .map_err(|s| ScheduleCompilerError::NoSuchPass(s, lexer.line_col(name)))?; + + if args.len() != func.num_args() { + return Err(ScheduleCompilerError::IncorrectArguments { + expected: func.num_args(), + actual: args.len(), + loc: lexer.line_col(span), + }); + } + + let mut arg_vals = vec![]; + for arg in args { + arg_vals.push(compile_exp_as_expr(arg, lexer, macrostab, macros)?); + } + + let selection = compile_selector(selection, lexer, macrostab, macros)?; + + match func { + Appliable::Pass(pass) => Ok(ExprResult::Expr(ir::ScheduleExp::RunPass { + pass, + args: arg_vals, + on: selection, + })), + Appliable::Schedule(sched) => Ok(ExprResult::Stmt(ir::ScheduleStmt::AddSchedule { + sched, + on: selection, + })), + Appliable::Device(device) => Ok(ExprResult::Stmt(ir::ScheduleStmt::AddDevice { + device, + on: selection, + })), + } + } + parser::Expr::Macro { + span, + name, + args, + selection, + } => { + let name_str = lexer.span_str(name).to_string(); + let macro_id = macrostab.lookup_string(name_str.clone()); + let Some(macro_def) = macros.lookup(¯o_id) else { + return Err(ScheduleCompilerError::UndefinedMacro( + name_str, + lexer.line_col(name), + )); + }; + let macro_def: MacroInfo = macro_def.clone(); + let MacroInfo { + params, + selection_name, + def, + } = macro_def; + + if args.len() != params.len() { + return 
Err(ScheduleCompilerError::IncorrectArguments { + expected: params.len(), + actual: args.len(), + loc: lexer.line_col(span), + }); + } + + // To initialize the macro's arguments, we have to do this in two steps, we first + // evaluate all of the arguments and store them into new variables, using names that + // cannot conflict with other values in the program and then we assign those variables + // to the macro's parameters; this avoids any shadowing issues, for instance: + // macro![3, x] where macro!'s arguments are named x and y becomes + // let #0 = 3; let #1 = x; let x = #0; let y = #1; + // which has the desired semantics, as opposed to + // let x = 3; let y = x; + let mut arg_eval = vec![]; + let mut arg_setters = vec![]; + + for (i, (exp, var)) in args.into_iter().zip(params.into_iter()).enumerate() { + let tmp = format!("#{}", i); + arg_eval.push(ir::ScheduleStmt::Let { + var: tmp.clone(), + exp: compile_exp_as_expr(exp, lexer, macrostab, macros)?, + }); + arg_setters.push(ir::ScheduleStmt::Let { + var, + exp: ir::ScheduleExp::Variable { var: tmp }, + }); + } + + // Set the selection + arg_eval.push(ir::ScheduleStmt::Let { + var: selection_name, + exp: ir::ScheduleExp::Selection { + selection: compile_selector(selection, lexer, macrostab, macros)?, + }, + }); + + // Combine the evaluation and initialization code + arg_eval.extend(arg_setters); + + Ok(ExprResult::Expr(ir::ScheduleExp::Block { + body: arg_eval, + res: Box::new(def), + })) + } + parser::Expr::Variable { span } => { + let var = lexer.span_str(span).to_string(); + Ok(ExprResult::Expr(ir::ScheduleExp::Variable { var })) + } + parser::Expr::Integer { span } => { + let val: usize = lexer.span_str(span).parse().expect("Parsing"); + Ok(ExprResult::Expr(ir::ScheduleExp::Integer { val })) + } + parser::Expr::Boolean { span: _, val } => { + Ok(ExprResult::Expr(ir::ScheduleExp::Boolean { val })) + } + parser::Expr::Field { + span: _, + lhs, + field, + } => { + let field = 
lexer.span_str(field).to_string(); + let lhs = compile_exp_as_expr(*lhs, lexer, macrostab, macros)?; + Ok(ExprResult::Expr(ir::ScheduleExp::Field { + collect: Box::new(lhs), + field, + })) + } + parser::Expr::BlockExpr { span: _, body } => { + compile_ops_as_expr(*body, lexer, macrostab, macros) + } + parser::Expr::Record { span: _, fields } => { + let mut result = vec![]; + for (name, expr) in fields { + let name = lexer.span_str(name).to_string(); + let expr = compile_exp_as_expr(expr, lexer, macrostab, macros)?; + result.push((name, expr)); + } + Ok(ExprResult::Expr(ir::ScheduleExp::Record { fields: result })) + } + } +} + +fn compile_ops_as_expr( + mut sched: parser::OperationList, + lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + macrostab: &mut StringTable, + macros: &mut Env<usize, MacroInfo>, +) -> Result<ExprResult, ScheduleCompilerError> { + let mut body = vec![]; + loop { + match sched { + parser::OperationList::NilStmt() => { + return Ok(ExprResult::Stmt(ir::ScheduleStmt::Block { body })); + } + parser::OperationList::FinalExpr(expr) => { + return Ok(ExprResult::Expr(ir::ScheduleExp::Block { + body, + res: Box::new(compile_exp_as_expr(expr, lexer, macrostab, macros)?), + })); + } + parser::OperationList::ConsStmt(stmt, ops) => { + body.extend(compile_stmt(stmt, lexer, macrostab, macros)?); + sched = *ops; + } + } + } +} + +fn compile_selector( + sel: parser::Selector, + lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + macrostab: &mut StringTable, + macros: &mut Env<usize, MacroInfo>, +) -> Result<ir::Selector, ScheduleCompilerError> { + match sel { + parser::Selector::SelectAll { span: _ } => Ok(ir::Selector::Everything()), + parser::Selector::SelectExprs { span: _, exprs } => { + let mut res = vec![]; + for exp in exprs { + res.push(compile_exp_as_expr(exp, lexer, macrostab, macros)?); + } + Ok(ir::Selector::Selection(res)) + } + } +} + +fn compile_macro_def( + body: parser::OperationList, + params: Vec<String>, + selection_name: 
String, + lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + macrostab: &mut StringTable, + macros: &mut Env<usize, MacroInfo>, +) -> Result<MacroInfo, ScheduleCompilerError> { + // FIXME: The body should be checked in an environment that prohibits running anything on + // everything (*) and check that only local variables/parameters are used + Ok(MacroInfo { + params, + selection_name, + def: match compile_ops_as_expr(body, lexer, macrostab, macros)? { + ExprResult::Expr(expr) => expr, + ExprResult::Stmt(stmt) => ir::ScheduleExp::Block { + body: vec![stmt], + res: Box::new(ir::ScheduleExp::Record { fields: vec![] }), + }, + }, + }) +} diff --git a/juno_scheduler/src/default.rs b/juno_scheduler/src/default.rs new file mode 100644 index 00000000..8274b81a --- /dev/null +++ b/juno_scheduler/src/default.rs @@ -0,0 +1,81 @@ +use crate::ir::*; + +macro_rules! pass { + ($p:ident) => { + ScheduleStmt::Let { + var: String::from("_"), + exp: ScheduleExp::RunPass { + pass: Pass::$p, + args: vec![], + on: Selector::Everything(), + }, + } + }; +} + +macro_rules! default_schedule { + () => { + ScheduleStmt::Block { + body: vec![], + } + }; + ($($p:ident),+ $(,)?) 
=> {
+ ScheduleStmt::Block {
+ body: vec![$(pass!($p)),+],
+ }
+ };
+}
+
+// Default schedule, which is used if no schedule is provided
+pub fn default_schedule() -> ScheduleStmt {
+ default_schedule![
+ GVN,
+ DCE,
+ PhiElim,
+ DCE,
+ CRC,
+ DCE,
+ SLF,
+ DCE,
+ Inline,
+ /*DeleteUncalled,*/
+ InterproceduralSROA,
+ SROA,
+ PhiElim,
+ DCE,
+ CCP,
+ DCE,
+ GVN,
+ DCE,
+ WritePredication,
+ PhiElim,
+ DCE,
+ CRC,
+ DCE,
+ SLF,
+ DCE,
+ Predication,
+ DCE,
+ CCP,
+ DCE,
+ GVN,
+ DCE,
+ /*Forkify,*/
+ /*ForkGuardElim,*/
+ DCE,
+ ForkSplit,
+ Unforkify,
+ GVN,
+ DCE,
+ DCE,
+ AutoOutline,
+ InterproceduralSROA,
+ SROA,
+ InferSchedules,
+ DCE,
+ GCM,
+ DCE,
+ FloatCollections,
+ GCM,
+ ]
+}
diff --git a/juno_scheduler/src/ir.rs b/juno_scheduler/src/ir.rs
new file mode 100644
index 00000000..16f2de9b
--- /dev/null
+++ b/juno_scheduler/src/ir.rs
@@ -0,0 +1,114 @@
+extern crate hercules_ir;
+
+use self::hercules_ir::ir::{Device, Schedule};
+
+#[derive(Debug, Copy, Clone)]
+pub enum Pass {
+ AutoOutline,
+ CCP,
+ CRC,
+ DCE,
+ DeleteUncalled,
+ FloatCollections,
+ ForkGuardElim,
+ ForkSplit,
+ Forkify,
+ GCM,
+ GVN,
+ InferSchedules,
+ Inline,
+ InterproceduralSROA,
+ Outline,
+ PhiElim,
+ Predication,
+ SLF,
+ SROA,
+ Unforkify,
+ WritePredication,
+ Verify,
+ Xdot,
+}
+
+impl Pass {
+ pub fn num_args(&self) -> usize {
+ match self {
+ Pass::Xdot => 1,
+ _ => 0,
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub enum Selector {
+ Everything(),
+ Selection(Vec<ScheduleExp>),
+}
+
+#[derive(Debug, Clone)]
+pub enum ScheduleExp {
+ Variable {
+ var: String,
+ },
+ Integer {
+ val: usize,
+ },
+ Boolean {
+ val: bool,
+ },
+ Field {
+ collect: Box<ScheduleExp>,
+ field: String,
+ },
+ RunPass {
+ pass: Pass,
+ args: Vec<ScheduleExp>,
+ on: Selector,
+ },
+ Record {
+ fields: Vec<(String, ScheduleExp)>,
+ },
+ Block {
+ body: Vec<ScheduleStmt>,
+ res: Box<ScheduleExp>,
+ },
+ // This is used to "box" a selection by evaluating it at one point and then
+ // allowing it to be used as a 
selector later on
+ Selection {
+ selection: Selector,
+ },
+}
+
+#[derive(Debug, Copy, Clone)]
+pub enum FixpointLimit {
+ NoLimit(),
+ PrintIter(),
+ StopAfter(usize),
+ PanicAfter(usize),
+}
+
+#[derive(Debug, Clone)]
+pub enum ScheduleStmt {
+ Fixpoint {
+ body: Box<ScheduleStmt>,
+ limit: FixpointLimit,
+ },
+ Block {
+ body: Vec<ScheduleStmt>,
+ },
+ Let {
+ var: String,
+ exp: ScheduleExp,
+ },
+ Assign {
+ var: String,
+ exp: ScheduleExp,
+ },
+ AddSchedule {
+ sched: Schedule,
+ on: Selector,
+ },
+ AddDevice {
+ device: Device,
+ on: Selector,
+ },
+}
diff --git a/juno_scheduler/src/labels.rs b/juno_scheduler/src/labels.rs
new file mode 100644
index 00000000..6690e17a
--- /dev/null
+++ b/juno_scheduler/src/labels.rs
@@ -0,0 +1,61 @@
+use hercules_ir::ir::*;
+
+use std::collections::HashMap;
+
+#[derive(Debug, Copy, Clone)]
+pub struct LabelInfo {
+ pub func: FunctionID,
+ pub label: LabelID,
+}
+
+#[derive(Debug, Copy, Clone)]
+pub struct JunoFunctionID {
+ pub idx: usize,
+}
+
+impl JunoFunctionID {
+ pub fn new(idx: usize) -> Self {
+ JunoFunctionID { idx }
+ }
+}
+
+// From the Juno frontend we collect certain information we need for scheduling it, in particular a
+// map from the function names to a "JunoFunctionID" which can be used to look up the Hercules
+// FunctionIDs that are definitions of that function (it may be multiple since the same Juno
+// function may be instantiated with multiple different type variables). 
+#[derive(Debug, Clone)] +pub struct JunoInfo { + pub func_names: HashMap<String, JunoFunctionID>, + pub func_info: JunoFunctions, +} + +#[derive(Debug, Clone)] +pub struct JunoFunctions { + pub func_ids: Vec<Vec<FunctionID>>, +} + +impl JunoInfo { + pub fn new<I>(funcs: I) -> Self + where + I: Iterator<Item = String>, + { + let mut func_names = HashMap::new(); + let mut func_ids = vec![]; + + for (idx, name) in funcs.enumerate() { + func_names.insert(name, JunoFunctionID::new(idx)); + func_ids.push(vec![]); + } + + JunoInfo { + func_names, + func_info: JunoFunctions { func_ids }, + } + } +} + +impl JunoFunctions { + pub fn get_function(&self, id: JunoFunctionID) -> &Vec<FunctionID> { + &self.func_ids[id.idx] + } +} diff --git a/juno_scheduler/src/lang.l b/juno_scheduler/src/lang.l index e6526c74..9d4c34bf 100644 --- a/juno_scheduler/src/lang.l +++ b/juno_scheduler/src/lang.l @@ -13,15 +13,38 @@ [\n\r] ; , "," +; ";" += "=" +@ "@" +\* "*" +\. "." +apply "apply" +fixpoint "fixpoint" +let "let" +macro "macro_keyword" +on "on" +set "set" +target "target" + +true "true" +false "false" + +\( "(" +\) ")" +\< "<" +\> ">" +\[ "[" +\] "]" \{ "{" \} "}" -function "function" -on "on" -partition "partition" +panic[\t \n\r]+after "panic_after" +print[\t \n\r]+iter "print_iter" +stop[\t \n\r]+after "stop_after" -[a-zA-Z][a-zA-Z0-9_]* "ID" -@[a-zA-Z0-9_]+ "LABEL" +[a-zA-Z][a-zA-Z0-9_\-]*! "MACRO" +[a-zA-Z][a-zA-Z0-9_\-]* "ID" +[0-9]+ "INT" . 
"UNMATCHED" diff --git a/juno_scheduler/src/lang.y b/juno_scheduler/src/lang.y index e7d98dba..9cb72842 100644 --- a/juno_scheduler/src/lang.y +++ b/juno_scheduler/src/lang.y @@ -1,61 +1,112 @@ %start Schedule -%avoid_insert "ID" "LABEL" +%avoid_insert "ID" "INT" %expect-unused Unmatched 'UNMATCHED' %% -Schedule -> Vec<FuncDirectives> : FunctionList { $1 }; +Schedule -> OperationList + : { OperationList::NilStmt() } + | Expr { OperationList::FinalExpr($1) } + | Stmt Schedule { OperationList::ConsStmt($1, Box::new($2)) } + ; -FunctionList -> Vec<FuncDirectives> - : { vec![] } - | FunctionList FunctionDef { snoc($1, $2) } +Stmt -> Stmt + : 'let' 'ID' '=' Expr ';' + { Stmt::LetStmt { span: $span, var: span_of_tok($2), expr: $4 } } + | 'ID' '=' Expr ';' + { Stmt::AssignStmt { span: $span, var: span_of_tok($1), rhs: $3 } } + | Expr ';' + { Stmt::ExprStmt { span: $span, exp: $1 } } + | 'fixpoint' FixpointLimit '{' Schedule '}' + { Stmt::Fixpoint { span: $span, limit: $2, body: Box::new($4) } } + | MacroDecl + { Stmt::MacroDecl { span: $span, def: $1 } } ; -FunctionDef -> FuncDirectives - : 'function' Func '{' DirectiveList '}' - { FuncDirectives { span : $span, func : $2, directives : $4 }}; +FixpointLimit -> FixpointLimit + : { FixpointLimit::NoLimit { span: $span } } + | 'stop_after' 'INT' + { FixpointLimit::StopAfter { span: $span, limit: span_of_tok($2) } } + | 'panic_after' 'INT' + { FixpointLimit::PanicAfter { span: $span, limit: span_of_tok($2) } } + | 'print_iter' + { FixpointLimit::PrintIter { span: $span } } + ; -DirectiveList -> Vec<Directive> - : { vec![] } - | DirectiveList Directive { snoc($1, $2) } +Expr -> Expr + : 'ID' Args Selector + { Expr::Function { span: $span, name: span_of_tok($1), args: $2, selection: $3 } } + | 'MACRO' Args Selector + { Expr::Macro { span: $span, name: span_of_tok($1), args: $2, selection: $3 } } + | 'ID' + { Expr::Variable { span: span_of_tok($1) } } + | 'INT' + { Expr::Integer { span: span_of_tok($1) } } + | 'true' + { 
Expr::Boolean { span: $span, val: true } } + | 'false' + { Expr::Boolean { span: $span, val: false } } + | Expr '.' 'ID' + { Expr::Field { span: $span, lhs: Box::new($1), field: span_of_tok($3) } } + | Expr '@' 'ID' + { Expr::Field { span: $span, lhs: Box::new($1), field: span_of_tok($3) } } + | '(' Expr ')' + { $2 } + | '{' Schedule '}' + { Expr::BlockExpr { span: $span, body: Box::new($2) } } + | '<' Fields '>' + { Expr::Record { span: $span, fields: rev($2) } } ; -Directive -> Directive - : 'partition' Labels 'on' Devices - { Directive::Partition { span : $span, labels : $2, devices : $4 } } - | 'ID' Labels - { Directive::Schedule { span : $span, command : span_of_tok($1), args : $2 } } +Args -> Vec<Expr> + : { vec![] } + | '[' Exprs ']' { rev($2) } ; -Func -> Func - : 'ID' { Func { span : $span, name : $span, }} +Exprs -> Vec<Expr> + : { vec![] } + | Expr { vec![$1] } + | Expr ',' Exprs { snoc($1, $3) } ; -Labels -> Vec<Span> - : 'LABEL' { vec![span_of_tok($1)] } - | '{' LabelsRev '}' { rev($2) } +Fields -> Vec<(Span, Expr)> + : { vec![] } + | 'ID' '=' Expr { vec![(span_of_tok($1), $3)] } + | 'ID' '=' Expr ',' Fields { snoc((span_of_tok($1), $3), $5) } ; -LabelsRev -> Vec<Span> - : { vec![] } - | 'LABEL' { vec![span_of_tok($1)] } - | 'LABEL' ',' LabelsRev { cons(span_of_tok($1), $3) } + +Selector -> Selector + : '(' '*' ')' + { Selector::SelectAll { span: $span } } + | '(' Exprs ')' + { Selector::SelectExprs { span: $span, exprs: $2 } } ; -Devices -> Vec<Device> - : Device { vec![$1] } - | '{' SomeDevices '}' { $2 } +MacroDecl -> MacroDecl + : 'macro_keyword' 'MACRO' Params '(' 'ID' ')' MacroDef + { MacroDecl { + name: span_of_tok($2), + params: rev($3), + selection_name: span_of_tok($5), + def: Box::new($7), + } + } ; -SomeDevices -> Vec<Device> - : Device { vec![$1] } - | SomeDevices ',' Device { snoc($1, $3) } + +Params -> Vec<Span> + : { vec![] } + | '[' Ids ']' { $2 } ; -Device -> Device - : 'ID' - { Device { span : $span, name : span_of_tok($1), } } +Ids 
-> Vec<Span> + : { vec![] } + | 'ID' { vec![span_of_tok($1)] } + | 'ID' ',' Ids { snoc(span_of_tok($1), $3) } ; +MacroDef -> OperationList : '{' Schedule '}' { $2 }; + Unmatched -> () : 'UNMATCHED' {}; %% @@ -63,32 +114,60 @@ Unmatched -> () : 'UNMATCHED' {}; use cfgrammar::Span; use lrlex::DefaultLexeme; +fn snoc<T>(x: T, mut xs: Vec<T>) -> Vec<T> { + xs.push(x); + xs +} + +fn rev<T>(mut xs: Vec<T>) -> Vec<T> { + xs.reverse(); + xs +} + fn span_of_tok(t : Result<DefaultLexeme, DefaultLexeme>) -> Span { t.map_err(|_| ()).map(|l| l.span()).unwrap() } -fn cons<A>(hd : A, mut tl : Vec<A>) -> Vec<A> { - tl.push(hd); - tl +pub enum OperationList { + NilStmt(), + FinalExpr(Expr), + ConsStmt(Stmt, Box<OperationList>), } -fn snoc<A>(mut hd : Vec<A>, tl : A) -> Vec<A> { - hd.push(tl); - hd +pub enum Stmt { + LetStmt { span: Span, var: Span, expr: Expr }, + AssignStmt { span: Span, var: Span, rhs: Expr }, + ExprStmt { span: Span, exp: Expr }, + Fixpoint { span: Span, limit: FixpointLimit, body: Box<OperationList> }, + MacroDecl { span: Span, def: MacroDecl }, } -fn rev<A>(mut lst : Vec<A>) -> Vec<A> { - lst.reverse(); - lst +pub enum FixpointLimit { + NoLimit { span: Span }, + StopAfter { span: Span, limit: Span }, + PanicAfter { span: Span, limit: Span }, + PrintIter { span: Span }, } -pub struct Func { pub span : Span, pub name : Span, } -pub struct Device { pub span : Span, pub name : Span, } +pub enum Expr { + Function { span: Span, name: Span, args: Vec<Expr>, selection: Selector }, + Macro { span: Span, name: Span, args: Vec<Expr>, selection: Selector }, + Variable { span: Span }, + Integer { span: Span }, + Boolean { span: Span, val: bool }, + Field { span: Span, lhs: Box<Expr>, field: Span }, + BlockExpr { span: Span, body: Box<OperationList> }, + Record { span: Span, fields: Vec<(Span, Expr)> }, +} -pub struct FuncDirectives { pub span : Span, pub func : Func, - pub directives : Vec<Directive> } +pub enum Selector { + SelectAll { span: Span }, + SelectExprs { span: 
Span, exprs: Vec<Expr> }, +} -pub enum Directive { - Schedule { span : Span, command : Span, args : Vec<Span> }, - Partition { span : Span, labels : Vec<Span>, devices : Vec<Device> }, +pub struct MacroDecl { + pub name: Span, + pub params: Vec<Span>, + pub selection_name: Span, + pub def: Box<OperationList>, } diff --git a/juno_scheduler/src/lib.rs b/juno_scheduler/src/lib.rs index d515633e..1caafe4f 100644 --- a/juno_scheduler/src/lib.rs +++ b/juno_scheduler/src/lib.rs @@ -6,47 +6,27 @@ use lrlex::DefaultLexerTypes; use lrpar::NonStreamingLexer; use hercules_ir::ir::*; +use juno_utils::env::Env; +use juno_utils::stringtab::StringTable; mod parser; use crate::parser::lexer; -// FunctionMap tracks a map from function numbers (produced by semantic analysis) to a tuple of -// - The map from label names to their numbers -// - The name of the function -// - A list of the instances of the function tracking -// + The instantiated type variables -// + The resulting FunctionID -// + A list of each label, tracking the structure at the label and a set of -// the labels which are its descendants -// + A map from NodeID to the innermost label containing it -// This is the core data structure provided from code generation, along with the -// module -pub type FunctionMap = HashMap< - usize, - ( - HashMap<String, usize>, - String, - Vec<( - Vec<TypeID>, - FunctionID, - Vec<(LabeledStructure, HashSet<usize>)>, - HashMap<NodeID, usize>, - )>, - ), ->; -// LabeledStructure represents structures from the source code and where they -// exist in the IR -#[derive(Copy, Clone)] -pub enum LabeledStructure { - Nothing(), - Expression(NodeID), - Loop(NodeID), // Header - Branch(NodeID), // If node -} - -/* -pub fn schedule(module: &Module, info: FunctionMap, schedule: String) -> Result<Vec<Plan>, String> { - if let Ok(mut file) = File::open(schedule) { +mod compile; +mod default; +mod ir; +pub mod labels; +mod pm; + +use crate::compile::*; +use crate::default::*; +use crate::ir::*; +use 
crate::labels::*; +use crate::pm::*; + +// Given a schedule's filename parse and process the schedule +fn build_schedule(sched_filename: String) -> Result<ScheduleStmt, String> { + if let Ok(mut file) = File::open(sched_filename) { let mut contents = String::new(); if let Ok(_) = file.read_to_string(&mut contents) { let lexerdef = lexer::lexerdef(); @@ -55,284 +35,113 @@ pub fn schedule(module: &Module, info: FunctionMap, schedule: String) -> Result< if errs.is_empty() { match res { - None => Err(format!("No parse errors, but no parsing failed")), + None => Err(format!("No parse errors, but parsing the schedule failed")), Some(schd) => { - let mut sched = generate_schedule(module, info, schd, &lexer)?; - let mut schedules = vec![]; - for i in 0..sched.len() { - schedules.push(sched.remove(&FunctionID::new(i)).unwrap()); - } - Ok(schedules) + compile_schedule(schd, &lexer).map_err(|e| format!("Schedule Error: {}", e)) } } } else { Err(errs .iter() - .map(|e| format!("Syntax Error: {}", e.pp(&lexer, &parser::token_epp))) + .map(|e| { + format!( + "Schedule Syntax Error: {}", + e.pp(&lexer, &parser::token_epp) + ) + }) .collect::<Vec<_>>() .join("\n")) } } else { - Err(format!("Unable to read input file")) + Err(format!("Unable to read schedule")) } } else { - Err(format!("Unable to open input file")) + Err(format!("Unable to open schedule")) } } -// a plan that tracks additional information useful while we construct the -// schedule -struct TempPlan { - schedules: Vec<Vec<Schedule>>, - // we track both the partition each node is in and what labeled caused us - // to assign that partition - partitions: Vec<(usize, PartitionNumber)>, - partition_devices: Vec<Vec<Device>>, -} -type PartitionNumber = usize; - -impl Into<Plan> for TempPlan { - fn into(self) -> Plan { - let num_partitions = self.partition_devices.len(); - Plan { - schedules: self.schedules, - partitions: self - .partitions - .into_iter() - .map(|(_, n)| PartitionID::new(n)) - .collect::<Vec<_>>(), - 
partition_devices: self - .partition_devices - .into_iter() - .map(|mut d| { - if d.len() != 1 { - panic!("Partition with multiple devices") - } else { - d.pop().unwrap() - } - }) - .collect::<Vec<_>>(), - num_partitions: num_partitions, - } +fn process_schedule(sched_filename: Option<String>) -> Result<ScheduleStmt, String> { + if let Some(name) = sched_filename { + build_schedule(name) + } else { + Ok(default_schedule()) } } -fn generate_schedule( - module: &Module, - info: FunctionMap, - schedule: Vec<parser::FuncDirectives>, - lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, -) -> Result<HashMap<FunctionID, Plan>, String> { - let mut res: HashMap<FunctionID, TempPlan> = HashMap::new(); - - // We initialize every node in every function as not having any schedule - // and being in the default partition which is a CPU-only partition - // (a result of label 0) - for (_, (_, _, func_insts)) in info.iter() { - for (_, func_id, _, _) in func_insts.iter() { - let num_nodes = module.functions[func_id.idx()].nodes.len(); - res.insert( - *func_id, - TempPlan { - schedules: vec![vec![]; num_nodes], - partitions: vec![(0, 0); num_nodes], - partition_devices: vec![vec![Device::CPU]], - }, - ); - } - } - - // Construct a map from function names to function numbers - let mut function_names: HashMap<String, usize> = HashMap::new(); - for (num, (_, nm, _)) in info.iter() { - function_names.insert(nm.clone(), *num); +pub fn schedule_juno( + module: Module, + juno_info: JunoInfo, + sched_filename: Option<String>, + output_dir: String, + module_name: String, +) -> Result<(), String> { + let sched = process_schedule(sched_filename)?; + + // Prepare the scheduler's string table and environment + // For this, we need to put all of the Juno functions into the environment + // and string table + let mut strings = StringTable::new(); + let mut env = Env::new(); + + env.open_scope(); + + let JunoInfo { + func_names, + func_info, + } = juno_info; + for (func_name, func_id) in 
func_names { + let func_name = strings.lookup_string(func_name); + env.insert(func_name, Value::JunoFunction { func: func_id }); } - // Make the map immutable - let function_names = function_names; - - for parser::FuncDirectives { - span: _, - func, - directives, - } in schedule - { - // Identify the function - let parser::Func { - span: _, - name: func_name, - } = func; - let name = lexer.span_str(func_name).to_string(); - let func_num = match function_names.get(&name) { - Some(num) => num, - None => { - return Err(format!("Function {} is undefined", name)); - } - }; - - // Identify label information - let (label_map, _, func_inst) = info.get(func_num).unwrap(); - let get_label_num = |label_span| { - let label_name = lexer.span_str(label_span).to_string(); - match label_map.get(&label_name) { - Some(num) => Ok(*num), - None => Err(format!("Label {} undefined in {}", label_name, name)), - } - }; - - // Process the partitioning and scheduling directives for each instance - // of the function - for (_, func_id, label_info, node_labels) in func_inst { - let func_info = res.get_mut(func_id).unwrap(); - for directive in &directives { - match directive { - parser::Directive::Partition { - span: _, - labels, - devices, - } => { - // Setup the new partition - let partition_num = func_info.partition_devices.len(); - let mut partition_devices = vec![]; - - for parser::Device { span: _, name } in devices { - let device_name = lexer.span_str(*name).to_string(); - if device_name == "cpu" { - partition_devices.push(Device::CPU); - } else if device_name == "gpu" { - partition_devices.push(Device::GPU); - } else { - return Err(format!("Invalid device {}", device_name)); - } - } - - func_info.partition_devices.push(partition_devices); - - for label in labels { - let label_num = get_label_num(*label)?; - let descendants = &label_info[label_num].1; - - node_labels - .iter() - .filter_map(|(node, label)| { - if *label == label_num || descendants.contains(label) { - Some(node.idx()) - 
} else { - None - } - }) - .for_each(|node| { - let node_part: &mut (usize, PartitionNumber) = - &mut func_info.partitions[node]; - if !descendants.contains(&node_part.0) { - *node_part = (label_num, partition_num); - } - }); - } - } - parser::Directive::Schedule { - span: _, - command, - args, - } => { - let command = lexer.span_str(*command).to_string(); - if command == "parallelize" { - for label in args { - let label_num = get_label_num(*label)?; - match label_info[label_num].0 { - LabeledStructure::Loop(header) => { - func_info.schedules[header.idx()] - .push(Schedule::ParallelReduce); - } - _ => { - return Err(format!( - "Cannot parallelize {}, not a loop", - lexer.span_str(*label) - )); - } - } - } - } else if command == "vectorize" { - for label in args { - let label_num = get_label_num(*label)?; - match label_info[label_num].0 { - LabeledStructure::Loop(header) => { - // FIXME: Take the factor as part of schedule - func_info.schedules[header.idx()] - .push(Schedule::Vectorizable(8)); - } - _ => { - return Err(format!( - "Cannot vectorize {}, not a loop", - lexer.span_str(*label) - )); - } - } - } - } else { - return Err(format!("Command {} undefined", command)); - } - } - } - } - /* - - /* - for parser::Command { span : _, name : command_name, - args : command_args } in commands.iter() { - if command_args.len() != 0 { todo!("Command arguments not supported") } - - let command = lexer.span_str(*command_name).to_string(); - if command == "cpu" || command == "gpu" { - let partition = res.get(func_id).unwrap() - .partition_devices.len(); - res.get_mut(func_id).unwrap().partition_devices.push( - if command == "cpu" { Device::CPU } - else { Device::GPU }); - - node_labels.iter() - .filter_map(|(node, label)| - if label_num == *label - || label_info[label_num].1.contains(&label) { - Some(node.idx()) - } else { - None - }) - .for_each(|node| { - let node_part : &mut (usize, PartitionNumber) = - &mut res.get_mut(func_id).unwrap().partitions[node]; - if 
!label_info[label_num].1.contains(&node_part.0) { - *node_part = (label_num, partition); - }}); - } else if command == "parallel" || command == "vectorize" { - match label_info[label_num].0 { - LabeledStructure::Loop(header) => { - res.get_mut(func_id).unwrap() - .schedules[header.idx()] - .push(if command == "parallel" { - Schedule::ParallelReduce - } else { - Schedule::Vectorize - }); - }, - _ => { - return Err(format!("Cannot parallelize, not a loop")); - }, - } - } else { - return Err(format!("Command {} undefined", command)); - } - } - */ + env.open_scope(); + schedule_codegen( + module, + sched, + strings, + env, + func_info, + output_dir, + module_name, + ) + .map_err(|e| format!("Scheduling Error: {}", e)) +} - func_info.partition_devices.push(partition_devices); - */ - } +pub fn schedule_hercules( + module: Module, + sched_filename: Option<String>, + output_dir: String, + module_name: String, +) -> Result<(), String> { + let sched = process_schedule(sched_filename)?; + + // Prepare the scheduler's string table and environment + // For this, we put all of the Hercules function names into the environment + // and string table + let mut strings = StringTable::new(); + let mut env = Env::new(); + + env.open_scope(); + + for (idx, func) in module.functions.iter().enumerate() { + let func_name = strings.lookup_string(func.name.clone()); + env.insert( + func_name, + Value::HerculesFunction { + func: FunctionID::new(idx), + }, + ); } - Ok(res - .into_iter() - .map(|(f, p)| (f, p.into())) - .collect::<HashMap<_, _>>()) + env.open_scope(); + schedule_codegen( + module, + sched, + strings, + env, + JunoFunctions { func_ids: vec![] }, + output_dir, + module_name, + ) + .map_err(|e| format!("Scheduling Error: {}", e)) } -*/ diff --git a/juno_scheduler/src/pm.rs b/juno_scheduler/src/pm.rs new file mode 100644 index 00000000..e93e7ecb --- /dev/null +++ b/juno_scheduler/src/pm.rs @@ -0,0 +1,1564 @@ +use crate::ir::*; +use crate::labels::*; +use hercules_cg::*; +use 
hercules_ir::*; +use hercules_opt::FunctionEditor; +use hercules_opt::{ + ccp, collapse_returns, crc, dce, dumb_outline, ensure_between_control_flow, float_collections, + fork_split, gcm, gvn, infer_parallel_fork, infer_parallel_reduce, infer_tight_associative, + infer_vectorizable, inline, interprocedural_sroa, outline, phi_elim, predication, slf, sroa, + unforkify, write_predication, +}; + +use tempfile::TempDir; + +use juno_utils::env::Env; +use juno_utils::stringtab::StringTable; + +use std::cell::RefCell; +use std::collections::{BTreeSet, HashMap, HashSet}; +use std::env::temp_dir; +use std::fmt; +use std::fs::File; +use std::io::Write; +use std::iter::zip; +use std::process::{Command, Stdio}; + +#[derive(Debug, Clone)] +pub enum Value { + Label { labels: Vec<LabelInfo> }, + JunoFunction { func: JunoFunctionID }, + HerculesFunction { func: FunctionID }, + Record { fields: HashMap<String, Value> }, + Everything {}, + Selection { selection: Vec<Value> }, + Integer { val: usize }, + Boolean { val: bool }, +} + +#[derive(Debug, Copy, Clone)] +enum CodeLocation { + Label(LabelInfo), + Function(FunctionID), +} + +impl Value { + fn is_everything(&self) -> bool { + match self { + Value::Everything {} => true, + _ => false, + } + } + + fn as_labels(&self) -> Result<Vec<LabelInfo>, SchedulerError> { + match self { + Value::Label { labels } => Ok(labels.clone()), + Value::Selection { selection } => { + let mut result = vec![]; + for val in selection { + result.extend(val.as_labels()?); + } + Ok(result) + } + Value::JunoFunction { .. } | Value::HerculesFunction { .. } => Err( + SchedulerError::SemanticError("Expected labels, found function".to_string()), + ), + Value::Record { .. } => Err(SchedulerError::SemanticError( + "Expected labels, found record".to_string(), + )), + Value::Everything {} => Err(SchedulerError::SemanticError( + "Expected labels, found everything".to_string(), + )), + Value::Integer { .. 
} => Err(SchedulerError::SemanticError( + "Expected labels, found integer".to_string(), + )), + Value::Boolean { .. } => Err(SchedulerError::SemanticError( + "Expected labels, found boolean".to_string(), + )), + } + } + + fn as_functions(&self, funcs: &JunoFunctions) -> Result<Vec<FunctionID>, SchedulerError> { + match self { + Value::JunoFunction { func } => Ok(funcs.get_function(*func).clone()), + Value::HerculesFunction { func } => Ok(vec![*func]), + Value::Selection { selection } => { + let mut result = vec![]; + for val in selection { + result.extend(val.as_functions(funcs)?); + } + Ok(result) + } + Value::Label { .. } => Err(SchedulerError::SemanticError( + "Expected functions, found label".to_string(), + )), + Value::Record { .. } => Err(SchedulerError::SemanticError( + "Expected functions, found record".to_string(), + )), + Value::Everything {} => Err(SchedulerError::SemanticError( + "Expected functions, found everything".to_string(), + )), + Value::Integer { .. } => Err(SchedulerError::SemanticError( + "Expected functions, found integer".to_string(), + )), + Value::Boolean { .. } => Err(SchedulerError::SemanticError( + "Expected functions, found boolean".to_string(), + )), + } + } + + fn as_locations(&self, funcs: &JunoFunctions) -> Result<Vec<CodeLocation>, SchedulerError> { + match self { + Value::Label { labels } => Ok(labels.iter().map(|l| CodeLocation::Label(*l)).collect()), + Value::JunoFunction { func } => Ok(funcs + .get_function(*func) + .iter() + .map(|f| CodeLocation::Function(*f)) + .collect()), + Value::HerculesFunction { func } => Ok(vec![CodeLocation::Function(*func)]), + Value::Selection { selection } => { + let mut result = vec![]; + for val in selection { + result.extend(val.as_locations(funcs)?); + } + Ok(result) + } + Value::Record { .. 
} => Err(SchedulerError::SemanticError( + "Expected code locations, found record".to_string(), + )), + Value::Everything {} => { + panic!("Internal error, check is_everything() before using as_functions()") + } + Value::Integer { .. } => Err(SchedulerError::SemanticError( + "Expected code locations, found integer".to_string(), + )), + Value::Boolean { .. } => Err(SchedulerError::SemanticError( + "Expected code locations, found boolean".to_string(), + )), + } + } +} + +#[derive(Debug, Clone)] +pub enum SchedulerError { + UndefinedVariable(String), + UndefinedField(String), + UndefinedLabel(String), + SemanticError(String), + PassError { pass: String, error: String }, + FixpointFailure(), +} + +impl fmt::Display for SchedulerError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SchedulerError::UndefinedVariable(nm) => write!(f, "Undefined variable '{}'", nm), + SchedulerError::UndefinedField(nm) => write!(f, "No field '{}'", nm), + SchedulerError::UndefinedLabel(nm) => write!(f, "No label '{}'", nm), + SchedulerError::SemanticError(msg) => write!(f, "{}", msg), + SchedulerError::PassError { pass, error } => { + write!(f, "Error in pass {}: {}", pass, error) + } + SchedulerError::FixpointFailure() => { + write!(f, "Fixpoint did not converge within limit") + } + } + } +} + +#[derive(Debug)] +struct PassManager { + functions: Vec<Function>, + types: RefCell<Vec<Type>>, + constants: RefCell<Vec<Constant>>, + dynamic_constants: RefCell<Vec<DynamicConstant>>, + labels: RefCell<Vec<String>>, + + // Cached analysis results. 
+ pub def_uses: Option<Vec<ImmutableDefUseMap>>, + pub reverse_postorders: Option<Vec<Vec<NodeID>>>, + pub typing: Option<ModuleTyping>, + pub control_subgraphs: Option<Vec<Subgraph>>, + pub doms: Option<Vec<DomTree>>, + pub postdoms: Option<Vec<DomTree>>, + pub fork_join_maps: Option<Vec<HashMap<NodeID, NodeID>>>, + pub fork_join_nests: Option<Vec<HashMap<NodeID, Vec<NodeID>>>>, + pub loops: Option<Vec<LoopTree>>, + pub reduce_cycles: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, + pub data_nodes_in_fork_joins: Option<Vec<HashMap<NodeID, HashSet<NodeID>>>>, + pub bbs: Option<Vec<BasicBlocks>>, + pub collection_objects: Option<CollectionObjects>, + pub callgraph: Option<CallGraph>, +} + +impl PassManager { + fn new(module: Module) -> Self { + let Module { + functions, + types, + constants, + dynamic_constants, + labels, + } = module; + PassManager { + functions, + types: RefCell::new(types), + constants: RefCell::new(constants), + dynamic_constants: RefCell::new(dynamic_constants), + labels: RefCell::new(labels), + def_uses: None, + reverse_postorders: None, + typing: None, + control_subgraphs: None, + doms: None, + postdoms: None, + fork_join_maps: None, + fork_join_nests: None, + loops: None, + reduce_cycles: None, + data_nodes_in_fork_joins: None, + bbs: None, + collection_objects: None, + callgraph: None, + } + } + + pub fn make_def_uses(&mut self) { + if self.def_uses.is_none() { + self.def_uses = Some(self.functions.iter().map(def_use).collect()); + } + } + + pub fn make_reverse_postorders(&mut self) { + if self.reverse_postorders.is_none() { + self.make_def_uses(); + self.reverse_postorders = Some( + self.def_uses + .as_ref() + .unwrap() + .iter() + .map(reverse_postorder) + .collect(), + ); + } + } + + pub fn make_typing(&mut self) { + if self.typing.is_none() { + self.make_reverse_postorders(); + self.typing = Some( + typecheck( + &self.functions, + &mut self.types.borrow_mut(), + &self.constants.borrow(), + &mut self.dynamic_constants.borrow_mut(), + 
self.reverse_postorders.as_ref().unwrap(), + ) + .unwrap(), + ); + } + } + + pub fn make_control_subgraphs(&mut self) { + if self.control_subgraphs.is_none() { + self.make_def_uses(); + self.control_subgraphs = Some( + zip(&self.functions, self.def_uses.as_ref().unwrap()) + .map(|(function, def_use)| control_subgraph(function, def_use)) + .collect(), + ); + } + } + + pub fn make_doms(&mut self) { + if self.doms.is_none() { + self.make_control_subgraphs(); + self.doms = Some( + self.control_subgraphs + .as_ref() + .unwrap() + .iter() + .map(|subgraph| dominator(subgraph, NodeID::new(0))) + .collect(), + ); + } + } + + pub fn make_postdoms(&mut self) { + if self.postdoms.is_none() { + self.make_control_subgraphs(); + self.postdoms = Some( + zip( + self.control_subgraphs.as_ref().unwrap().iter(), + self.functions.iter(), + ) + .map(|(subgraph, function)| dominator(subgraph, NodeID::new(function.nodes.len()))) + .collect(), + ); + } + } + + pub fn make_fork_join_maps(&mut self) { + if self.fork_join_maps.is_none() { + self.make_control_subgraphs(); + self.fork_join_maps = Some( + zip( + self.functions.iter(), + self.control_subgraphs.as_ref().unwrap().iter(), + ) + .map(|(function, subgraph)| fork_join_map(function, subgraph)) + .collect(), + ); + } + } + + pub fn make_fork_join_nests(&mut self) { + if self.fork_join_nests.is_none() { + self.make_doms(); + self.make_fork_join_maps(); + self.fork_join_nests = Some( + zip( + self.functions.iter(), + zip( + self.doms.as_ref().unwrap().iter(), + self.fork_join_maps.as_ref().unwrap().iter(), + ), + ) + .map(|(function, (dom, fork_join_map))| { + compute_fork_join_nesting(function, dom, fork_join_map) + }) + .collect(), + ); + } + } + + pub fn make_loops(&mut self) { + if self.loops.is_none() { + self.make_control_subgraphs(); + self.make_doms(); + self.make_fork_join_maps(); + let control_subgraphs = self.control_subgraphs.as_ref().unwrap().iter(); + let doms = self.doms.as_ref().unwrap().iter(); + let fork_join_maps = 
self.fork_join_maps.as_ref().unwrap().iter(); + self.loops = Some( + zip(control_subgraphs, zip(doms, fork_join_maps)) + .map(|(control_subgraph, (dom, fork_join_map))| { + loops(control_subgraph, NodeID::new(0), dom, fork_join_map) + }) + .collect(), + ); + } + } + + pub fn make_reduce_cycles(&mut self) { + if self.reduce_cycles.is_none() { + self.make_def_uses(); + let def_uses = self.def_uses.as_ref().unwrap().iter(); + self.reduce_cycles = Some( + zip(self.functions.iter(), def_uses) + .map(|(function, def_use)| reduce_cycles(function, def_use)) + .collect(), + ); + } + } + + pub fn make_data_nodes_in_fork_joins(&mut self) { + if self.data_nodes_in_fork_joins.is_none() { + self.make_def_uses(); + self.make_fork_join_maps(); + self.data_nodes_in_fork_joins = Some( + zip( + self.functions.iter(), + zip( + self.def_uses.as_ref().unwrap().iter(), + self.fork_join_maps.as_ref().unwrap().iter(), + ), + ) + .map(|(function, (def_use, fork_join_map))| { + data_nodes_in_fork_joins(function, def_use, fork_join_map) + }) + .collect(), + ); + } + } + + pub fn make_collection_objects(&mut self) { + if self.collection_objects.is_none() { + self.make_reverse_postorders(); + self.make_typing(); + self.make_callgraph(); + let reverse_postorders = self.reverse_postorders.as_ref().unwrap(); + let typing = self.typing.as_ref().unwrap(); + let callgraph = self.callgraph.as_ref().unwrap(); + self.collection_objects = Some(collection_objects( + &self.functions, + &self.types.borrow(), + reverse_postorders, + typing, + callgraph, + )); + } + } + + pub fn make_callgraph(&mut self) { + if self.callgraph.is_none() { + self.callgraph = Some(callgraph(&self.functions)); + } + } + + pub fn delete_gravestones(&mut self) { + for func in self.functions.iter_mut() { + func.delete_gravestones(); + } + } + + fn clear_analyses(&mut self) { + self.def_uses = None; + self.reverse_postorders = None; + self.typing = None; + self.control_subgraphs = None; + self.doms = None; + self.postdoms = None; + 
self.fork_join_maps = None; + self.fork_join_nests = None; + self.loops = None; + self.reduce_cycles = None; + self.data_nodes_in_fork_joins = None; + self.bbs = None; + self.collection_objects = None; + self.callgraph = None; + } + + fn with_mod<B, F>(&mut self, mut f: F) -> B + where + F: FnMut(&mut Module) -> B, + { + let mut module = Module { + functions: std::mem::take(&mut self.functions), + types: self.types.take(), + constants: self.constants.take(), + dynamic_constants: self.dynamic_constants.take(), + labels: self.labels.take(), + }; + + let res = f(&mut module); + + let Module { + functions, + types, + constants, + dynamic_constants, + labels, + } = module; + self.functions = functions; + self.types.replace(types); + self.constants.replace(constants); + self.dynamic_constants.replace(dynamic_constants); + self.labels.replace(labels); + + res + } + + fn codegen(mut self, output_dir: String, module_name: String) -> Result<(), SchedulerError> { + self.make_typing(); + self.make_control_subgraphs(); + self.make_collection_objects(); + self.make_callgraph(); + + let PassManager { + functions, + types, + constants, + dynamic_constants, + labels, + typing: Some(typing), + control_subgraphs: Some(control_subgraphs), + bbs: Some(bbs), + collection_objects: Some(collection_objects), + callgraph: Some(callgraph), + .. 
+ } = self + else { + return Err(SchedulerError::PassError { + pass: "codegen".to_string(), + error: "Missing basic blocks".to_string(), + }); + }; + + let module = Module { + functions, + types: types.into_inner(), + constants: constants.into_inner(), + dynamic_constants: dynamic_constants.into_inner(), + labels: labels.into_inner(), + }; + + let devices = device_placement(&module.functions, &callgraph); + + let mut rust_rt = String::new(); + let mut llvm_ir = String::new(); + for idx in 0..module.functions.len() { + match devices[idx] { + Device::LLVM => cpu_codegen( + &module.functions[idx], + &module.types, + &module.constants, + &module.dynamic_constants, + &typing[idx], + &control_subgraphs[idx], + &bbs[idx], + &mut llvm_ir, + ) + .map_err(|e| SchedulerError::PassError { + pass: "cpu codegen".to_string(), + error: format!("{}", e), + })?, + Device::AsyncRust => rt_codegen( + FunctionID::new(idx), + &module, + &typing[idx], + &control_subgraphs[idx], + &bbs[idx], + &collection_objects, + &callgraph, + &devices, + &mut rust_rt, + ) + .map_err(|e| SchedulerError::PassError { + pass: "rust codegen".to_string(), + error: format!("{}", e), + })?, + _ => todo!(), + } + } + println!("{}", llvm_ir); + println!("{}", rust_rt); + + // Write the LLVM IR into a temporary file. + let tmp_dir = TempDir::new().unwrap(); + let mut tmp_path = tmp_dir.path().to_path_buf(); + tmp_path.push(format!("{}.ll", module_name)); + println!("{}", tmp_path.display()); + let mut file = File::create(&tmp_path).expect("PANIC: Unable to open output LLVM IR file."); + file.write_all(llvm_ir.as_bytes()) + .expect("PANIC: Unable to write output LLVM IR file contents."); + + // Compile LLVM IR into an ELF object file. 
+ let output_archive = format!("{}/lib{}.a", output_dir, module_name); + let mut clang_process = Command::new("clang") + .arg(&tmp_path) + .arg("--emit-static-lib") + .arg("-O3") + .arg("-march=native") + .arg("-o") + .arg(&output_archive) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .spawn() + .expect("Error running clang. Is it installed?"); + assert!(clang_process.wait().unwrap().success()); + + // Write the Rust runtime into a file. + let output_rt = format!("{}/rt_{}.hrt", output_dir, module_name); + println!("{}", output_rt); + let mut file = + File::create(&output_rt).expect("PANIC: Unable to open output Rust runtime file."); + file.write_all(rust_rt.as_bytes()) + .expect("PANIC: Unable to write output Rust runtime file contents."); + + Ok(()) + } +} + +pub fn schedule_codegen( + mut module: Module, + schedule: ScheduleStmt, + mut stringtab: StringTable, + mut env: Env<usize, Value>, + functions: JunoFunctions, + output_dir: String, + module_name: String, +) -> Result<(), SchedulerError> { + let mut pm = PassManager::new(module); + let _ = schedule_interpret(&mut pm, &schedule, &mut stringtab, &mut env, &functions)?; + pm.codegen(output_dir, module_name) +} + +// Interpreter for statements and expressions returns a bool indicating whether +// any optimization ran and changed the IR. This is used for implementing +// the fixpoint +fn schedule_interpret( + pm: &mut PassManager, + schedule: &ScheduleStmt, + stringtab: &mut StringTable, + env: &mut Env<usize, Value>, + functions: &JunoFunctions, +) -> Result<bool, SchedulerError> { + match schedule { + ScheduleStmt::Fixpoint { body, limit } => { + let mut i = 0; + loop { + // If no change was made, we've reached the fixpoint and are done + if !schedule_interpret(pm, body, stringtab, env, functions)? 
{ + break; + } + // Otherwise, increase the iteration count and check the limit + i += 1; + match limit { + FixpointLimit::NoLimit() => {} + FixpointLimit::PrintIter() => { + println!("Finished Iteration {}", i - 1) + } + FixpointLimit::StopAfter(n) => { + if i >= *n { + break; + } + } + FixpointLimit::PanicAfter(n) => { + if i >= *n { + return Err(SchedulerError::FixpointFailure()); + } + } + } + } + // If we ran just 1 iteration, then no changes were made and otherwise some changes + // were made + Ok(i > 1) + } + ScheduleStmt::Block { body } => { + let mut modified = false; + env.open_scope(); + for command in body { + modified |= schedule_interpret(pm, command, stringtab, env, functions)?; + } + env.close_scope(); + Ok(modified) + } + ScheduleStmt::Let { var, exp } => { + let (res, modified) = interp_expr(pm, exp, stringtab, env, functions)?; + let var_id = stringtab.lookup_string(var.clone()); + env.insert(var_id, res); + Ok(modified) + } + ScheduleStmt::Assign { var, exp } => { + let (res, modified) = interp_expr(pm, exp, stringtab, env, functions)?; + let var_id = stringtab.lookup_string(var.clone()); + match env.lookup_mut(&var_id) { + None => { + return Err(SchedulerError::UndefinedVariable(var.clone())); + } + Some(val) => { + *val = res; + } + } + Ok(modified) + } + ScheduleStmt::AddSchedule { sched, on } => match on { + Selector::Everything() => Err(SchedulerError::SemanticError( + "Cannot apply schedule to everything".to_string(), + )), + Selector::Selection(selection) => { + let mut changed = false; + for label in selection { + let (label, modified) = interp_expr(pm, label, stringtab, env, functions)?; + changed |= modified; + add_schedule(pm, sched.clone(), label.as_labels()?); + } + Ok(changed) + } + }, + ScheduleStmt::AddDevice { device, on } => match on { + Selector::Everything() => Err(SchedulerError::SemanticError( + "Cannot apply device to everything".to_string(), + )), + Selector::Selection(selection) => { + let mut changed = false; + for func 
in selection { + let (func, modified) = interp_expr(pm, func, stringtab, env, functions)?; + changed |= modified; + add_device(pm, device.clone(), func.as_functions(functions)?); + } + Ok(changed) + } + }, + } +} + +fn interp_expr( + pm: &mut PassManager, + expr: &ScheduleExp, + stringtab: &mut StringTable, + env: &mut Env<usize, Value>, + functions: &JunoFunctions, +) -> Result<(Value, bool), SchedulerError> { + match expr { + ScheduleExp::Variable { var } => { + let var_id = stringtab.lookup_string(var.clone()); + match env.lookup(&var_id) { + None => Err(SchedulerError::UndefinedVariable(var.clone())), + Some(v) => Ok((v.clone(), false)), + } + } + ScheduleExp::Integer { val } => Ok((Value::Integer { val: *val }, false)), + ScheduleExp::Boolean { val } => Ok((Value::Boolean { val: *val }, false)), + ScheduleExp::Field { collect, field } => { + let (lhs, changed) = interp_expr(pm, collect, stringtab, env, functions)?; + match lhs { + Value::Label { .. } + | Value::Selection { .. } + | Value::Everything { .. } + | Value::Integer { .. } + | Value::Boolean { .. 
} => Err(SchedulerError::UndefinedField(field.clone())), + Value::JunoFunction { func } => { + match pm.labels.borrow().iter().position(|s| s == field) { + None => Err(SchedulerError::UndefinedLabel(field.clone())), + Some(label_idx) => Ok(( + Value::Label { + labels: functions + .get_function(func) + .iter() + .map(|f| LabelInfo { + func: *f, + label: LabelID::new(label_idx), + }) + .collect(), + }, + changed, + )), + } + } + Value::HerculesFunction { func } => { + match pm.labels.borrow().iter().position(|s| s == field) { + None => Err(SchedulerError::UndefinedLabel(field.clone())), + Some(label_idx) => Ok(( + Value::Label { + labels: vec![LabelInfo { + func: func, + label: LabelID::new(label_idx), + }], + }, + changed, + )), + } + } + Value::Record { fields } => match fields.get(field) { + None => Err(SchedulerError::UndefinedField(field.clone())), + Some(v) => Ok((v.clone(), changed)), + }, + } + } + ScheduleExp::RunPass { pass, args, on } => { + let mut changed = false; + let mut arg_vals = vec![]; + for arg in args { + let (val, modified) = interp_expr(pm, arg, stringtab, env, functions)?; + arg_vals.push(val); + changed |= modified; + } + + let selection = match on { + Selector::Everything() => None, + Selector::Selection(selection) => { + let mut locs = vec![]; + let mut everything = false; + for loc in selection { + let (val, modified) = interp_expr(pm, loc, stringtab, env, functions)?; + changed |= modified; + if val.is_everything() { + everything = true; + break; + } + locs.extend(val.as_locations(functions)?); + } + if everything { + None + } else { + Some(locs) + } + } + }; + + let (res, modified) = run_pass(pm, *pass, arg_vals, selection)?; + changed |= modified; + Ok((res, changed)) + } + ScheduleExp::Record { fields } => { + let mut result = HashMap::new(); + let mut changed = false; + for (field, val) in fields { + let (val, modified) = interp_expr(pm, val, stringtab, env, functions)?; + result.insert(field.clone(), val); + changed |= modified; + } 
+ Ok((Value::Record { fields: result }, changed)) + } + ScheduleExp::Block { body, res } => { + let mut changed = false; + + env.open_scope(); + for command in body { + changed |= schedule_interpret(pm, command, stringtab, env, functions)?; + } + let (res, modified) = interp_expr(pm, res, stringtab, env, functions)?; + env.close_scope(); + + Ok((res, changed || modified)) + } + ScheduleExp::Selection { selection } => match selection { + Selector::Everything() => Ok((Value::Everything {}, false)), + Selector::Selection(selection) => { + let mut values = vec![]; + let mut changed = false; + for e in selection { + let (val, modified) = interp_expr(pm, e, stringtab, env, functions)?; + values.push(val); + changed |= modified; + } + Ok((Value::Selection { selection: values }, changed)) + } + }, + } +} + +fn add_schedule(pm: &mut PassManager, sched: Schedule, label_ids: Vec<LabelInfo>) { + for LabelInfo { func, label } in label_ids { + let nodes = pm.functions[func.idx()] + .labels + .iter() + .enumerate() + .filter(|(i, ls)| ls.contains(&label)) + .map(|(i, ls)| i) + .collect::<Vec<_>>(); + for node in nodes { + pm.functions[func.idx()].schedules[node].push(sched.clone()); + } + } +} + +fn add_device(pm: &mut PassManager, device: Device, funcs: Vec<FunctionID>) { + for func in funcs { + pm.functions[func.idx()].device = Some(device.clone()); + } +} + +#[derive(Debug, Clone)] +enum FunctionSelection { + Nothing(), + Everything(), + Labels(HashSet<LabelID>), +} + +impl FunctionSelection { + fn add_label(&mut self, label: LabelID) { + match self { + FunctionSelection::Nothing() => { + *self = FunctionSelection::Labels(HashSet::from([label])); + } + FunctionSelection::Everything() => {} + FunctionSelection::Labels(set) => { + set.insert(label); + } + } + } + + fn add_everything(&mut self) { + *self = FunctionSelection::Everything(); + } +} + +fn build_editors<'a>(pm: &'a mut PassManager) -> Vec<FunctionEditor<'a>> { + pm.make_def_uses(); + let def_uses = 
pm.def_uses.take().unwrap(); + pm.functions + .iter_mut() + .zip(def_uses.iter()) + .enumerate() + .map(|(idx, (func, def_use))| { + FunctionEditor::new( + func, + FunctionID::new(idx), + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + def_use, + ) + }) + .collect() +} + +// With a selection, we process it to identify which labels in which functions are to be selected +fn construct_selection(pm: &PassManager, selection: Vec<CodeLocation>) -> Vec<FunctionSelection> { + let mut selected = vec![FunctionSelection::Nothing(); pm.functions.len()]; + for loc in selection { + match loc { + CodeLocation::Label(label) => selected[label.func.idx()].add_label(label.label), + CodeLocation::Function(func) => selected[func.idx()].add_everything(), + } + } + selected +} + +// Given a selection, constructs the set of the nodes selected for a single function, returning the +// function's id +fn selection_as_set( + pm: &PassManager, + selection: Option<Vec<CodeLocation>>, +) -> Option<(BTreeSet<NodeID>, FunctionID)> { + if let Some(selection) = selection { + let selection = construct_selection(pm, selection); + let mut result = None; + + for (idx, (selected, func)) in selection.into_iter().zip(pm.functions.iter()).enumerate() { + match selected { + FunctionSelection::Nothing() => {} + FunctionSelection::Everything() => match result { + Some(_) => { + return None; + } + None => { + result = Some(( + (0..func.nodes.len()).map(|i| NodeID::new(i)).collect(), + FunctionID::new(idx), + )); + } + }, + FunctionSelection::Labels(labels) => match result { + Some(_) => { + return None; + } + None => { + result = Some(( + (0..func.nodes.len()) + .filter(|i| !func.labels[*i].is_disjoint(&labels)) + .map(|i| NodeID::new(i)) + .collect(), + FunctionID::new(idx), + )); + } + }, + } + } + + result + } else { + None + } +} + +fn build_selection<'a>( + pm: &'a mut PassManager, + selection: Option<Vec<CodeLocation>>, +) -> Vec<Option<FunctionEditor<'a>>> { + // Build def uses, which 
are needed for the editors we'll construct + pm.make_def_uses(); + let def_uses = pm.def_uses.take().unwrap(); + + if let Some(selection) = selection { + let selected = construct_selection(pm, selection); + + pm.functions + .iter_mut() + .zip(selected.iter()) + .zip(def_uses.iter()) + .enumerate() + .map(|(idx, ((func, selected), def_use))| match selected { + FunctionSelection::Nothing() => None, + FunctionSelection::Everything() => Some(FunctionEditor::new( + func, + FunctionID::new(idx), + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + def_use, + )), + FunctionSelection::Labels(labels) => Some(FunctionEditor::new_labeled( + func, + FunctionID::new(idx), + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + def_use, + labels, + )), + }) + .collect() + } else { + build_editors(pm) + .into_iter() + .map(|func| Some(func)) + .collect() + } +} + +fn run_pass( + pm: &mut PassManager, + pass: Pass, + args: Vec<Value>, + selection: Option<Vec<CodeLocation>>, +) -> Result<(Value, bool), SchedulerError> { + let mut result = Value::Record { + fields: HashMap::new(), + }; + let mut changed = false; + + match pass { + Pass::AutoOutline => { + if let Some(_) = selection { + return Err(SchedulerError::PassError { + pass: "autoOutline".to_string(), + error: "must be applied to the entire module".to_string(), + }); + } + + pm.make_def_uses(); + let def_uses = pm.def_uses.take().unwrap(); + let mut editors: Vec<_> = pm + .functions + .iter_mut() + .zip(def_uses.iter()) + .enumerate() + .map(|(idx, (func, def_use))| { + FunctionEditor::new( + func, + FunctionID::new(idx), + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + def_use, + ) + }) + .collect(); + for editor in editors.iter_mut() { + collapse_returns(editor); + ensure_between_control_flow(editor); + } + pm.clear_analyses(); + + pm.make_def_uses(); + pm.make_typing(); + pm.make_control_subgraphs(); + pm.make_doms(); + + let def_uses = pm.def_uses.take().unwrap(); + 
let typing = pm.typing.take().unwrap(); + let control_subgraphs = pm.control_subgraphs.take().unwrap(); + let doms = pm.doms.take().unwrap(); + let old_num_funcs = pm.functions.len(); + + let mut editors: Vec<_> = pm + .functions + .iter_mut() + .zip(def_uses.iter()) + .enumerate() + .map(|(idx, (func, def_use))| { + FunctionEditor::new( + func, + FunctionID::new(idx), + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + def_use, + ) + }) + .collect(); + + let mut new_funcs = vec![]; + for (idx, editor) in editors.iter_mut().enumerate() { + let new_func_id = FunctionID::new(old_num_funcs + new_funcs.len()); + let new_func = dumb_outline( + editor, + &typing[idx], + &control_subgraphs[idx], + &doms[idx], + new_func_id, + ); + if let Some(new_func) = new_func { + let Value::Record { ref mut fields } = result else { + panic!("AutoOutline produces a record"); + }; + fields.insert( + new_func.name.clone(), + Value::HerculesFunction { func: new_func_id }, + ); + new_funcs.push(new_func); + } + } + + for func in pm.functions.iter_mut() { + func.delete_gravestones(); + } + pm.functions.extend(new_funcs); + pm.clear_analyses(); + } + Pass::CCP => { + assert!(args.is_empty()); + pm.make_reverse_postorders(); + let reverse_postorders = pm.reverse_postorders.take().unwrap(); + for (func, reverse_postorder) in build_selection(pm, selection) + .into_iter() + .zip(reverse_postorders.iter()) + { + let Some(mut func) = func else { + continue; + }; + ccp(&mut func, reverse_postorder); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::CRC => { + assert!(args.is_empty()); + for func in build_selection(pm, selection) { + let Some(mut func) = func else { + continue; + }; + crc(&mut func); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::DCE => { + assert!(args.is_empty()); + for func in build_selection(pm, selection) { + let Some(mut func) = func else { + continue; + }; + 
dce(&mut func); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::DeleteUncalled => { + todo!("Delete Uncalled changes FunctionIDs, a bunch of bookkeeping is needed for the pass manager to address this") + } + Pass::FloatCollections => { + assert!(args.is_empty()); + if let Some(_) = selection { + return Err(SchedulerError::PassError { + pass: "floatCollections".to_string(), + error: "must be applied to the entire module".to_string(), + }); + } + + pm.make_typing(); + pm.make_callgraph(); + let typing = pm.typing.take().unwrap(); + let callgraph = pm.callgraph.take().unwrap(); + + let devices = device_placement(&pm.functions, &callgraph); + + let mut editors = build_editors(pm); + float_collections(&mut editors, &typing, &callgraph, &devices); + + for func in editors { + changed |= func.modified(); + } + + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::ForkGuardElim => { + todo!("Fork Guard Elim doesn't use editor") + } + Pass::ForkSplit => { + assert!(args.is_empty()); + pm.make_fork_join_maps(); + pm.make_reduce_cycles(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let reduce_cycles = pm.reduce_cycles.take().unwrap(); + for ((func, fork_join_map), reduce_cycles) in build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(reduce_cycles.iter()) + { + let Some(mut func) = func else { + continue; + }; + fork_split(&mut func, fork_join_map, reduce_cycles); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::Forkify => { + todo!("Forkify doesn't use editor") + } + Pass::GCM => { + assert!(args.is_empty()); + if let Some(_) = selection { + return Err(SchedulerError::PassError { + pass: "gcm".to_string(), + error: "must be applied to the entire module".to_string(), + }); + } + + loop { + pm.make_def_uses(); + pm.make_reverse_postorders(); + pm.make_typing(); + pm.make_control_subgraphs(); + pm.make_doms(); + 
pm.make_fork_join_maps(); + pm.make_loops(); + pm.make_collection_objects(); + + let def_uses = pm.def_uses.take().unwrap(); + let reverse_postorders = pm.reverse_postorders.take().unwrap(); + let typing = pm.typing.take().unwrap(); + let doms = pm.doms.take().unwrap(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let loops = pm.loops.take().unwrap(); + let control_subgraphs = pm.control_subgraphs.take().unwrap(); + let collection_objects = pm.collection_objects.take().unwrap(); + + let mut bbs = vec![]; + + for ( + ( + ( + ((((mut func, def_use), reverse_postorder), typing), control_subgraph), + doms, + ), + fork_join_map, + ), + loops, + ) in build_editors(pm) + .into_iter() + .zip(def_uses.iter()) + .zip(reverse_postorders.iter()) + .zip(typing.iter()) + .zip(control_subgraphs.iter()) + .zip(doms.iter()) + .zip(fork_join_maps.iter()) + .zip(loops.iter()) + { + if let Some(bb) = gcm( + &mut func, + def_use, + reverse_postorder, + typing, + control_subgraph, + doms, + fork_join_map, + loops, + &collection_objects, + ) { + bbs.push(bb); + } + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + if bbs.len() == pm.functions.len() { + pm.bbs = Some(bbs); + break; + } + } + } + Pass::GVN => { + assert!(args.is_empty()); + for func in build_selection(pm, selection) { + let Some(mut func) = func else { + continue; + }; + gvn(&mut func, false); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::InferSchedules => { + assert!(args.is_empty()); + pm.make_fork_join_maps(); + pm.make_reduce_cycles(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + let reduce_cycles = pm.reduce_cycles.take().unwrap(); + for ((func, fork_join_map), reduce_cycles) in build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) + .zip(reduce_cycles.iter()) + { + let Some(mut func) = func else { + continue; + }; + infer_parallel_reduce(&mut func, fork_join_map, reduce_cycles); + 
infer_parallel_fork(&mut func, fork_join_map); + infer_vectorizable(&mut func, fork_join_map); + infer_tight_associative(&mut func, reduce_cycles); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::Inline => { + assert!(args.is_empty()); + if let Some(_) = selection { + return Err(SchedulerError::PassError { + pass: "inline".to_string(), + error: "must be applied to the entire module (currently)".to_string(), + }); + } + + pm.make_callgraph(); + let callgraph = pm.callgraph.take().unwrap(); + + let mut editors = build_editors(pm); + inline(&mut editors, &callgraph); + + for func in editors { + changed |= func.modified(); + } + + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::InterproceduralSROA => { + assert!(args.is_empty()); + if let Some(_) = selection { + return Err(SchedulerError::PassError { + pass: "interproceduralSROA".to_string(), + error: "must be applied to the entire module".to_string(), + }); + } + + let mut editors = build_editors(pm); + interprocedural_sroa(&mut editors); + + for func in editors { + changed |= func.modified(); + } + + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::Outline => { + let Some((nodes, func)) = selection_as_set(pm, selection) else { + return Err(SchedulerError::PassError { + pass: "outline".to_string(), + error: "must be applied to nodes in a single function".to_string(), + }); + }; + + pm.make_def_uses(); + let def_uses = pm.def_uses.take().unwrap(); + + let mut editor = FunctionEditor::new( + &mut pm.functions[func.idx()], + func, + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + &def_uses[func.idx()], + ); + + collapse_returns(&mut editor); + ensure_between_control_flow(&mut editor); + pm.clear_analyses(); + + pm.make_def_uses(); + pm.make_typing(); + pm.make_control_subgraphs(); + pm.make_doms(); + + let def_uses = pm.def_uses.take().unwrap(); + let typing = pm.typing.take().unwrap(); + let control_subgraphs = 
pm.control_subgraphs.take().unwrap(); + let doms = pm.doms.take().unwrap(); + let new_func_id = FunctionID::new(pm.functions.len()); + + let mut editor = FunctionEditor::new( + &mut pm.functions[func.idx()], + func, + &pm.constants, + &pm.dynamic_constants, + &pm.types, + &pm.labels, + &def_uses[func.idx()], + ); + + let new_func = outline( + &mut editor, + &typing[func.idx()], + &control_subgraphs[func.idx()], + &doms[func.idx()], + &nodes, + new_func_id, + ); + let Some(new_func) = new_func else { + return Err(SchedulerError::PassError { + pass: "outlining".to_string(), + error: "failed to outline".to_string(), + }); + }; + + pm.functions.push(new_func); + changed = true; + pm.functions[func.idx()].delete_gravestones(); + pm.clear_analyses(); + + result = Value::HerculesFunction { func: new_func_id }; + } + Pass::PhiElim => { + assert!(args.is_empty()); + for func in build_selection(pm, selection) { + let Some(mut func) = func else { + continue; + }; + phi_elim(&mut func); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::Predication => { + assert!(args.is_empty()); + pm.make_typing(); + let typing = pm.typing.take().unwrap(); + + for (func, types) in build_selection(pm, selection) + .into_iter() + .zip(typing.iter()) + { + let Some(mut func) = func else { + continue; + }; + predication(&mut func, types); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::SLF => { + assert!(args.is_empty()); + pm.make_reverse_postorders(); + pm.make_typing(); + let reverse_postorders = pm.reverse_postorders.take().unwrap(); + let typing = pm.typing.take().unwrap(); + + for ((func, reverse_postorder), types) in build_selection(pm, selection) + .into_iter() + .zip(reverse_postorders.iter()) + .zip(typing.iter()) + { + let Some(mut func) = func else { + continue; + }; + slf(&mut func, reverse_postorder, types); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); 
+ } + Pass::SROA => { + assert!(args.is_empty()); + pm.make_reverse_postorders(); + pm.make_typing(); + let reverse_postorders = pm.reverse_postorders.take().unwrap(); + let typing = pm.typing.take().unwrap(); + + for ((func, reverse_postorder), types) in build_selection(pm, selection) + .into_iter() + .zip(reverse_postorders.iter()) + .zip(typing.iter()) + { + let Some(mut func) = func else { + continue; + }; + sroa(&mut func, reverse_postorder, types); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::Unforkify => { + assert!(args.is_empty()); + pm.make_fork_join_maps(); + let fork_join_maps = pm.fork_join_maps.take().unwrap(); + + for (func, fork_join_map) in build_selection(pm, selection) + .into_iter() + .zip(fork_join_maps.iter()) + { + let Some(mut func) = func else { + continue; + }; + unforkify(&mut func, fork_join_map); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::WritePredication => { + assert!(args.is_empty()); + for func in build_selection(pm, selection) { + let Some(mut func) = func else { + continue; + }; + write_predication(&mut func); + changed |= func.modified(); + } + pm.delete_gravestones(); + pm.clear_analyses(); + } + Pass::Verify => { + assert!(args.is_empty()); + let (def_uses, reverse_postorders, typing, subgraphs, doms, postdoms, fork_join_maps) = + pm.with_mod(|module| verify(module)) + .map_err(|msg| SchedulerError::PassError { + pass: "verify".to_string(), + error: format!("failed: {}", msg), + })?; + + // Verification produces a bunch of analysis results that + // may be useful for later passes. 
+ pm.def_uses = Some(def_uses); + pm.reverse_postorders = Some(reverse_postorders); + pm.typing = Some(typing); + pm.control_subgraphs = Some(subgraphs); + pm.doms = Some(doms); + pm.postdoms = Some(postdoms); + pm.fork_join_maps = Some(fork_join_maps); + } + Pass::Xdot => { + assert!(args.len() == 1); + let force_analyses = match args[0] { + Value::Boolean { val } => val, + _ => { + return Err(SchedulerError::PassError { + pass: "xdot".to_string(), + error: "expected boolean argument".to_string(), + }); + } + }; + + pm.make_reverse_postorders(); + if force_analyses { + pm.make_doms(); + pm.make_fork_join_maps(); + } + + let reverse_postorders = pm.reverse_postorders.take().unwrap(); + let doms = pm.doms.take(); + let fork_join_maps = pm.fork_join_maps.take(); + pm.with_mod(|module| { + xdot_module( + module, + &reverse_postorders, + doms.as_ref(), + fork_join_maps.as_ref(), + ) + }); + } + } + + Ok((result, changed)) +} diff --git a/juno_utils/.gitignore b/juno_utils/.gitignore new file mode 100644 index 00000000..ef5f7e55 --- /dev/null +++ b/juno_utils/.gitignore @@ -0,0 +1,4 @@ +*.aux +*.log +*.out +*.pdf diff --git a/juno_utils/Cargo.toml b/juno_utils/Cargo.toml new file mode 100644 index 00000000..8de3b651 --- /dev/null +++ b/juno_utils/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "juno_utils" +version = "0.1.0" +authors = ["Aaron Councilman <aaronjc4@illinois.edu>"] +edition = "2021" + +[lib] +name = "juno_utils" +path = "src/lib.rs" + +[dependencies] +serde = { version = "*", features = ["derive"] } diff --git a/juno_frontend/src/env.rs b/juno_utils/src/env.rs similarity index 91% rename from juno_frontend/src/env.rs rename to juno_utils/src/env.rs index fb746045..cfa84b78 100644 --- a/juno_frontend/src/env.rs +++ b/juno_utils/src/env.rs @@ -24,6 +24,13 @@ impl<K: Eq + Hash + Copy, V> Env<K, V> { } } + pub fn lookup_mut(&mut self, k: &K) -> Option<&mut V> { + match self.table.get_mut(k) { + None => None, + Some(l) => l.last_mut(), + } + } + pub fn 
insert(&mut self, k: K, v: V) { if self.scope[self.scope.len() - 1].contains(&k) { match self.table.get_mut(&k) { diff --git a/juno_utils/src/lib.rs b/juno_utils/src/lib.rs new file mode 100644 index 00000000..56b404be --- /dev/null +++ b/juno_utils/src/lib.rs @@ -0,0 +1,2 @@ +pub mod env; +pub mod stringtab; diff --git a/juno_utils/src/stringtab.rs b/juno_utils/src/stringtab.rs new file mode 100644 index 00000000..e151b830 --- /dev/null +++ b/juno_utils/src/stringtab.rs @@ -0,0 +1,48 @@ +extern crate serde; + +use self::serde::{Deserialize, Serialize}; + +use std::collections::HashMap; + +// Map strings to unique identifiers and counts uids +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct StringTable { + count: usize, + string_to_index: HashMap<String, usize>, + index_to_string: HashMap<usize, String>, +} + +impl StringTable { + pub fn new() -> StringTable { + StringTable { + count: 0, + string_to_index: HashMap::new(), + index_to_string: HashMap::new(), + } + } + + // Produce the UID for a string + pub fn lookup_string(&mut self, s: String) -> usize { + match self.string_to_index.get(&s) { + Some(n) => *n, + None => { + let n = self.count; + self.count += 1; + self.string_to_index.insert(s.clone(), n); + self.index_to_string.insert(n, s); + n + } + } + } + + // Identify the string corresponding to a UID + pub fn lookup_id(&self, n: usize) -> Option<String> { + self.index_to_string.get(&n).cloned() + } +} + +impl Default for StringTable { + fn default() -> Self { + StringTable::new() + } +} -- GitLab