use std::cell::Ref; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque}; use std::iter::{empty, once, zip, FromIterator}; use bitvec::prelude::*; use either::Either; use union_find::{QuickFindUf, UnionBySize, UnionFind}; use hercules_cg::*; use hercules_ir::*; use crate::*; /* * Top level function to legalize the reference semantics of a Hercules IR * function. Hercules IR is a value semantics representation, meaning that all * program state is in the form of copyable values, and mutation takes place by * making a new value that is a copy of the old value with some modification. * This representation is extremely convenient for optimization, but is not good * for code generation, where we need to generate code with references to get * good performance. Hercules IR can alternatively be interpreted using * reference semantics, where pointers to collection objects are passed around, * read from, and written to. However, the value semantics and reference * semantics interpretation of a Hercules IR function may not be equal - this * pass transforms a Hercules IR function such that its new value semantics is * the same as its old value semantics and that its new reference semantics is * the same as its new value semantics. This pass returns a placement of nodes * into ordered basic blocks, since the reference semantics of a function * depends on the order of execution with respect to anti-dependencies. This * is analogous to global code motion from the original sea of nodes paper. * * Our strategy for handling multiple mutating users of a collection is to treat * the problem similar to register allocation; we perform a liveness analysis, * spill constants into newly allocated constants, and read back the spilled * contents when they are used after the first mutation. It's not obvious how * many spills are needed upfront, and newly spilled constants may affect the * liveness analysis result, so every spill restarts the process of checking for * spills. Once no more spills are found, the process terminates. When a spill * is found, the basic block assignments, and all the other analyses, are not * necessarily valid anymore, so this function is called in a loop in the pass * manager until no more spills are found. * * GCM also generally tries to massage the code to be properly formed for the * device backends in other ways. For example, reduction cycles through the * `init` inputs of an inner reduce and a use of a non-parallel outer reduce in * nested fork-joins is not schedulable, since the outer reduce doesn't dominate * the fork corresponding to the inner reduce. In such cases, the outer fork- * join must be split and unforkified. * * GCM is additionally complicated by the need to generate code that references * objects across multiple devices. In particular, GCM makes sure that every * object lives on exactly one device, so that references to that object always * live on a single device. Additionally, GCM makes sure that the objects that a * node may produce are all on the same device, so that a pointer produced by, * for example, a select node can only refer to memory on a single device. Extra * collection constants and potentially inter-device copies are inserted as * necessary to make sure this is true - an inter-device copy is represented by * a write where the `collect` and `data` inputs are on different devices. This * is only valid in RT functions - it is asserted that this isn't necessary in * device functions. This process "colors" the nodes in the function. 
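*
* As a schematic example of why spilling is needed: if one collection constant
* is the `collect` input of two different write nodes, value semantics says
* each write starts from the original constant, but under reference semantics
* the second write would observe the first write's mutation. Spilling copies
* the constant into a fresh allocation so that both interpretations agree.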
* * GCM has one final responsibility - object allocation. Each Hercules function * receives a pointer to a "backing" memory where collection constants live. The * backing memory a function receives is for the constants in that function and * the constants of every called function. Concretely, a function will pass a * sub-region of its backing memory to a callee, which during the call serves as * the callee's backing memory. Object allocation consists of finding the required * sizes of all collection constants and functions in terms of dynamic constants * (dynamic constant math is expressive enough to represent sizes of types, * which is very convenient) and determining the concrete offsets into the * backing memory where constants and callee sub-regions live. When two users of * backing memory are never live at once, they may share backing memory. This is * done after nodes are given a single device color, since we need to know what * values are on what devices before we can allocate them to backing memory, * since there are separate backing memories per-device. */ pub fn gcm( editor: &mut FunctionEditor, def_use: &ImmutableDefUseMap, reverse_postorder: &Vec<NodeID>, typing: &Vec<TypeID>, control_subgraph: &Subgraph, dom: &DomTree, fork_join_map: &HashMap<NodeID, NodeID>, loops: &LoopTree, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, objects: &CollectionObjects, devices: &Vec<Device>, node_colors: &NodeColors, backing_allocations: &BackingAllocations, ) -> Option<(BasicBlocks, FunctionNodeColors, FunctionBackingAllocation)> { if preliminary_fixups(editor, fork_join_map, loops, reduce_cycles) { return None; } let bbs = basic_blocks( editor.func(), editor.get_types(), editor.func_id(), def_use, reverse_postorder, typing, dom, loops, reduce_cycles, fork_join_map, objects, devices, ); let liveness = liveness_dataflow( editor.func(), editor.func_id(), control_subgraph, objects, &bbs, ); if spill_clones(editor, typing, control_subgraph, objects, &bbs, &liveness) { return None; } let Some(node_colors) = color_nodes(editor, typing, &objects, &devices, node_colors) else { return None; }; let mut alignments = vec![]; Ref::map(editor.get_types(), |types| { for idx in 0..types.len() { if types[idx].is_control() { alignments.push(0); } else { alignments.push(get_type_alignment(types, TypeID::new(idx))); } } &() }); let backing_allocation = object_allocation( editor, typing, &node_colors, &alignments, &liveness, backing_allocations, ); Some((bbs, node_colors, backing_allocation)) } /* * Do misc. fixups on the IR, such as unforkifying sequential outer forks with * problematic reduces. */ fn preliminary_fixups( editor: &mut FunctionEditor, fork_join_map: &HashMap<NodeID, NodeID>, loops: &LoopTree, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, ) -> bool { let nodes = &editor.func().nodes; let schedules = &editor.func().schedules; // Sequentialize non-parallel forks that contain problematic reduce cycles. for (reduce, cycle) in reduce_cycles { if !schedules[reduce.idx()].contains(&Schedule::ParallelReduce) && cycle.into_iter().any(|id| nodes[id.idx()].is_reduce()) { let join = nodes[reduce.idx()].try_reduce().unwrap().0; let fork = fork_join_map .into_iter() .filter(|(_, j)| join == **j) .map(|(f, _)| *f) .next() .unwrap(); let (forks, _) = split_fork(editor, fork, join, reduce_cycles).unwrap(); if forks.len() > 1 { return true; } unforkify(editor, fork, join, loops); return true; } } // Get rid of the backward edge on parallel reduces in fork-joins.
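// A reduce marked ParallelReduce doesn't actually consume the value carried
// around its loop, so uses of the reduce inside its own cycle can be redirected
// to the `init` input instead; the loop below performs exactly this rewrite to
// remove the backward edge.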
for (_, join) in fork_join_map { let parallel_reduces: Vec<_> = editor .get_users(*join) .filter(|id| { nodes[id.idx()].is_reduce() && schedules[id.idx()].contains(&Schedule::ParallelReduce) }) .collect(); for reduce in parallel_reduces { if reduce_cycles[&reduce].is_empty() { continue; } let (_, init, _) = nodes[reduce.idx()].try_reduce().unwrap(); // Replace uses of the reduce in its cycle with the init. let success = editor.edit(|edit| { edit.replace_all_uses_where(reduce, init, |id| reduce_cycles[&reduce].contains(id)) }); assert!(success); return true; } } false } /* * Top level global code motion function. Assigns each data node to one of its * immediate control use / user nodes, forming (unordered) basic blocks. Returns * the control node / basic block each node is in. Takes in a partial * partitioning that must be respected. Based on the schedule-early-schedule- * late method from Cliff Click's PhD thesis. May fail if an anti-dependency * edge can't be satisfied - in this case, a clone that has to be induced is * returned instead. */ fn basic_blocks( function: &Function, types: Ref<Vec<Type>>, func_id: FunctionID, def_use: &ImmutableDefUseMap, reverse_postorder: &Vec<NodeID>, typing: &Vec<TypeID>, dom: &DomTree, loops: &LoopTree, reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>, fork_join_map: &HashMap<NodeID, NodeID>, objects: &CollectionObjects, devices: &Vec<Device>, ) -> BasicBlocks { let mut bbs: Vec<Option<NodeID>> = vec![None; function.nodes.len()]; // Step 1: assign the basic block locations of all nodes that must be in a // specific block. This includes control nodes as well as some special data // nodes, such as phis. for idx in 0..function.nodes.len() { match function.nodes[idx] { Node::Phi { control, data: _ } => bbs[idx] = Some(control), Node::ThreadID { control, dimension: _, } => bbs[idx] = Some(control), Node::Reduce { control, init: _, reduct: _, } => bbs[idx] = Some(control), Node::Call { control, function: _, dynamic_constants: _, args: _, } => bbs[idx] = Some(control), Node::Parameter { index: _ } => bbs[idx] = Some(NodeID::new(0)), _ if function.nodes[idx].is_control() => bbs[idx] = Some(NodeID::new(idx)), _ => {} } } // Step 2: schedule early. Place nodes in the earliest position they could // go - use worklist to iterate nodes. let mut schedule_early = bbs.clone(); let mut worklist = VecDeque::from(reverse_postorder.clone()); while let Some(id) = worklist.pop_front() { if schedule_early[id.idx()].is_some() { continue; } // For every use, check what block is its "schedule early" block. This // node goes in the lowest block amongst those blocks. let use_places: Option<Vec<NodeID>> = get_uses(&function.nodes[id.idx()]) .as_ref() .into_iter() .map(|id| *id) .map(|id| schedule_early[id.idx()]) .collect(); if let Some(use_places) = use_places { // If every use has been placed, we can place this node as the // lowest place in the domtree that dominates all of the use places. let lowest = dom.lowest_amongst(use_places.into_iter()); schedule_early[id.idx()] = Some(lowest); } else { // If not, then just push this node back on the worklist. worklist.push_back(id); } } // Step 3: find anti-dependence edges. An anti-dependence edge needs to be // drawn between a collection reading node and a collection mutating node // when the following conditions are true: // // 1: The reading and mutating nodes may involve the same collection. 
// 2: The node producing the collection used by the reading node is in a // schedule early block that dominates the schedule early block of the // mutating node. The node producing the collection used by the reading // node may be an originator of a collection, phi or reduce, or mutator, // but not forwarding read - forwarding reads are collapsed, and the // bottom read is treated as reading from the transitive parent of the // forwarding read(s). // 3: If the node producing the collection is a reduce node, then any read // users that aren't in the reduce's cycle shouldn't anti-depend on any // mutators in the reduce cycle. // // Because we do a liveness analysis based spill of collections, anti- // dependencies can be best effort. Thus, when we encounter a read and // mutator where the read doesn't dominate the mutator, but an anti-dependence // edge is derived for the pair, we just don't draw the edge since it would // break the scheduler. let mut antideps = BTreeSet::new(); for id in reverse_postorder.iter() { // Find a terminating read node and the collections it reads. let terminating_reads: BTreeSet<_> = terminating_reads(function, func_id, *id, objects).collect(); if !terminating_reads.is_empty() { // Walk forwarding reads to find anti-dependency roots. let mut workset = terminating_reads.clone(); let mut roots = BTreeSet::new(); while let Some(pop) = workset.pop_first() { let forwarded: BTreeSet<_> = forwarding_reads(function, func_id, pop, objects).collect(); if forwarded.is_empty() { roots.insert(pop); } else { workset.extend(forwarded); } } // For each root, find mutating nodes dominated by the root that // modify an object read on any input of the current node (the // terminating read). // TODO: make this less outrageously inefficient. let func_objects = &objects[&func_id]; for root in roots.iter() { let root_is_reduce_and_read_isnt_in_cycle = reduce_cycles .get(root) .map(|cycle| !cycle.contains(&id)) .unwrap_or(false); let root_early = schedule_early[root.idx()].unwrap(); let mut root_block_iterated_users: BTreeSet<NodeID> = BTreeSet::new(); let mut workset = BTreeSet::new(); workset.insert(*root); while let Some(pop) = workset.pop_first() { let users = def_use.get_users(pop).into_iter().filter(|user| { !function.nodes[user.idx()].is_phi() && !function.nodes[user.idx()].is_reduce() && schedule_early[user.idx()].unwrap() == root_early }); workset.extend(users.clone()); root_block_iterated_users.extend(users); } let read_objs: BTreeSet<_> = terminating_reads .iter() .map(|read_use| func_objects.objects(*read_use).into_iter()) .flatten() .map(|id| *id) .collect(); for mutator in reverse_postorder.iter() { let mutator_early = schedule_early[mutator.idx()].unwrap(); if dom.does_dom(root_early, mutator_early) && (root_early != mutator_early || root_block_iterated_users.contains(&mutator)) && mutating_objects(function, func_id, *mutator, objects) .any(|mutated| read_objs.contains(&mutated)) && id != mutator && (!root_is_reduce_and_read_isnt_in_cycle || !reduce_cycles .get(root) .map(|cycle| cycle.contains(mutator)) .unwrap_or(false)) && dom.does_dom(schedule_early[id.idx()].unwrap(), mutator_early) { antideps.insert((*id, *mutator)); } } } } } let mut antideps_uses = vec![vec![]; function.nodes.len()]; let mut antideps_users = vec![vec![]; function.nodes.len()]; for (reader, mutator) in antideps.iter() { antideps_uses[mutator.idx()].push(*reader); antideps_users[reader.idx()].push(*mutator); } // Step 4: schedule late and pick each node's final position.
Since the late // schedule of each node depends on the final positions of its users, these // two steps must be fused. Compute their latest position, then use the // control dependent + shallow loop heuristic to actually place them. A // placement might not necessarily be found due to anti-dependency edges. // These are optional and not necessary to consider, but we do since obeying // them can reduce the number of clones. If the worklist stops making // progress, stop considering the anti-dependency edges. let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map .into_iter() .map(|(fork, join)| (*join, *fork)) .collect(); let mut worklist = VecDeque::from_iter(reverse_postorder.into_iter().map(|id| *id).rev()); let mut num_skip_iters = 0; let mut consider_antidependencies = true; while let Some(id) = worklist.pop_front() { if num_skip_iters >= worklist.len() { consider_antidependencies = false; } if bbs[id.idx()].is_some() { num_skip_iters = 0; continue; } // Calculate the least common ancestor of user blocks, a.k.a. the "late" // schedule. let calculate_lca = || -> Option<_> { let mut lca = None; // Helper to incrementally update the LCA. let mut update_lca = |a| { if let Some(acc) = lca { lca = Some(dom.least_common_ancestor(acc, a)); } else { lca = Some(a); } }; // For every user, consider where we need to be to directly dominate the // user. for user in def_use .get_users(id) .as_ref() .into_iter() .chain(if consider_antidependencies { Either::Left(antideps_users[id.idx()].iter()) } else { Either::Right(empty()) }) .map(|id| *id) { if let Node::Phi { control, data } = &function.nodes[user.idx()] { // For phis, we need to dominate the block jumping to the phi in // the slot that corresponds to our use. for (control, data) in zip(get_uses(&function.nodes[control.idx()]).as_ref(), data) { if id == *data { update_lca(*control); } } } else if let Node::Reduce { control, init, reduct, } = &function.nodes[user.idx()] { // For reduces, we need to either dominate the block right // before the fork if we're the init input, or we need to // dominate the join if we're the reduct input. if id == *init { let before_fork = function.nodes[join_fork_map[control].idx()] .try_fork() .unwrap() .0; update_lca(before_fork); } else { assert_eq!(id, *reduct); update_lca(*control); } } else { // For everything else, we just need to dominate the user. update_lca(bbs[user.idx()]?); } } Some(lca) }; // Check if all users have been placed. If one of them hasn't, then add // this node back on to the worklist. let Some(lca) = calculate_lca() else { worklist.push_back(id); num_skip_iters += 1; continue; }; // Look between the LCA and the schedule early location to place the // node. let schedule_early = schedule_early[id.idx()].unwrap(); let schedule_late = lca.unwrap_or(schedule_early); let mut chain = dom // If the node has no users, then it doesn't really matter where we // place it - just place it at the early placement. .chain(schedule_late, schedule_early); if let Some(mut location) = chain.next() { while let Some(control_node) = chain.next() { // If the next node further up the dominator tree is in a shallower // loop nest or if we can get out of a reduce loop when we don't // need to be in one, place this data node in a higher-up location. // Only do this if the node isn't a constant or undef - if a // node is a constant or undef, we want its placement to be as // control dependent as possible, even inside loops. In GPU // functions specifically, lift constants that may be returned // outside fork-joins.
let is_constant_or_undef = (function.nodes[id.idx()].is_constant() || function.nodes[id.idx()].is_undef()) && !types[typing[id.idx()].idx()].is_primitive(); let is_gpu_returned = devices[func_id.idx()] == Device::CUDA && objects[&func_id] .objects(id) .into_iter() .any(|obj| objects[&func_id].returned_objects().contains(obj)); let old_nest = loops .header_of(location) .map(|header| loops.nesting(header).unwrap()); let new_nest = loops .header_of(control_node) .map(|header| loops.nesting(header).unwrap()); let shallower_nest = if let (Some(old_nest), Some(new_nest)) = (old_nest, new_nest) { old_nest > new_nest } else { // If the new location isn't a loop, its nesting level should // be considered "shallower" if the current location is in a // loop. old_nest.is_some() }; // This will move all nodes that don't need to be in reduce loops // outside of reduce loops. Nodes that do need to be in a reduce // loop use the reduce node forming the loop, so the dominator chain // will consist of one block, and this loop won't ever iterate. let currently_at_join = function.nodes[location.idx()].is_join() && !function.nodes[control_node.idx()].is_join(); if (!is_constant_or_undef || is_gpu_returned) && (shallower_nest || currently_at_join) { location = control_node; } } bbs[id.idx()] = Some(location); num_skip_iters = 0; } else { // If there is no valid location for this node, then it's a reading // node of a collection that can't be placed above a mutation that // anti-depends on it. Push the node back on the list, and we'll // stop considering anti-dependencies soon. Don't immediately stop // considering anti-dependencies, as we may be able to eke out some // more use of them. worklist.push_back(id); num_skip_iters += 1; continue; } } let bbs: Vec<_> = bbs.into_iter().map(Option::unwrap).collect(); // Calculate the number of phis and reduces per basic block. We use this to // emit phis and reduces at the top of basic blocks. We want to emit phis // and reduces first into ordered basic blocks for two reasons: // 1. This is useful for liveness analysis. // 2. This is needed for some backends - LLVM expects phis to be at the top // of basic blocks. let mut num_phis_reduces = vec![0; function.nodes.len()]; for (node_idx, bb) in bbs.iter().enumerate() { let node = &function.nodes[node_idx]; if node.is_phi() || node.is_reduce() { num_phis_reduces[bb.idx()] += 1; } } // Step 5: determine the order of nodes inside each block. Use worklist to // add nodes to blocks in order that obeys dependencies. let mut order: Vec<Vec<NodeID>> = vec![vec![]; function.nodes.len()]; let mut worklist = VecDeque::from_iter( reverse_postorder .into_iter() .filter(|id| !function.nodes[id.idx()].is_control()), ); let mut visited = bitvec![u8, Lsb0; 0; function.nodes.len()]; let mut num_skip_iters = 0; let mut consider_antidependencies = true; while let Some(id) = worklist.pop_front() { // If the worklist isn't making progress, then there's at least one // reading node of a collection that is in an anti-depend + normal depend // use cycle with a mutating node. See above comment about anti- // dependencies being optional; we just stop considering them here. if num_skip_iters >= worklist.len() { consider_antidependencies = false; } // Phis and reduces always get emitted. Other nodes need to obey // dependency relationships and need to come after phis and reduces.
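// Concretely, a non-phi/non-reduce node is ready to be emitted once all phis
// and reduces in its block have been emitted and each of its uses (including
// anti-dependence uses, while those are still considered) is either a control
// node, placed in a different block, or already emitted.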
let node = &function.nodes[id.idx()]; let bb = bbs[id.idx()]; if node.is_phi() || node.is_reduce() || (num_phis_reduces[bb.idx()] == 0 && get_uses(node) .as_ref() .into_iter() .chain(if consider_antidependencies { Either::Left(antideps_uses[id.idx()].iter()) } else { Either::Right(empty()) }) .all(|u| { function.nodes[u.idx()].is_control() || bbs[u.idx()] != bbs[id.idx()] || visited[u.idx()] })) { order[bb.idx()].push(*id); visited.set(id.idx(), true); num_skip_iters = 0; if node.is_phi() || node.is_reduce() { num_phis_reduces[bb.idx()] -= 1; } } else { worklist.push_back(id); num_skip_iters += 1; } } (bbs, order) } fn terminating_reads<'a>( function: &'a Function, func_id: FunctionID, reader: NodeID, objects: &'a CollectionObjects, ) -> Box<dyn Iterator<Item = NodeID> + 'a> { match function.nodes[reader.idx()] { Node::Read { collect, indices: _, } if objects[&func_id].objects(reader).is_empty() => Box::new(once(collect)), Node::Write { collect: _, data, indices: _, } if !objects[&func_id].objects(data).is_empty() => Box::new(once(data)), Node::Call { control: _, function: callee, dynamic_constants: _, ref args, } => Box::new(args.into_iter().enumerate().filter_map(move |(idx, arg)| { let objects = &objects[&callee]; let returns = objects.returned_objects(); let param_obj = objects.param_to_object(idx)?; if !objects.is_mutated(param_obj) && !returns.contains(&param_obj) { Some(*arg) } else { None } })), Node::LibraryCall { library_function, ref args, ty: _, device: _, } => match library_function { LibraryFunction::GEMM => Box::new(once(args[1]).chain(once(args[2]))), }, _ => Box::new(empty()), } } fn forwarding_reads<'a>( function: &'a Function, func_id: FunctionID, reader: NodeID, objects: &'a CollectionObjects, ) -> Box<dyn Iterator<Item = NodeID> + 'a> { match function.nodes[reader.idx()] { Node::Read { collect, indices: _, } if !objects[&func_id].objects(reader).is_empty() => Box::new(once(collect)), Node::Ternary { op: TernaryOperator::Select, first: _, second, third, } if !objects[&func_id].objects(reader).is_empty() => { Box::new(once(second).chain(once(third))) } Node::Call { control: _, function: callee, dynamic_constants: _, ref args, } => Box::new(args.into_iter().enumerate().filter_map(move |(idx, arg)| { let objects = &objects[&callee]; let returns = objects.returned_objects(); let param_obj = objects.param_to_object(idx)?; if !objects.is_mutated(param_obj) && returns.contains(&param_obj) { Some(*arg) } else { None } })), _ => Box::new(empty()), } } fn mutating_objects<'a>( function: &'a Function, func_id: FunctionID, mutator: NodeID, objects: &'a CollectionObjects, ) -> Box<dyn Iterator<Item = CollectionObjectID> + 'a> { match function.nodes[mutator.idx()] { Node::Write { collect, data: _, indices: _, } => Box::new(objects[&func_id].objects(collect).into_iter().map(|id| *id)), Node::Call { control: _, function: callee, dynamic_constants: _, ref args, } => Box::new( args.into_iter() .enumerate() .filter_map(move |(idx, arg)| { let callee_objects = &objects[&callee]; let param_obj = callee_objects.param_to_object(idx)?; if callee_objects.is_mutated(param_obj) { Some(objects[&func_id].objects(*arg).into_iter().map(|id| *id)) } else { None } }) .flatten(), ), Node::LibraryCall { library_function, ref args, ty: _, device: _, } => match library_function { LibraryFunction::GEMM => { Box::new(objects[&func_id].objects(args[0]).into_iter().map(|id| *id)) } }, _ => Box::new(empty()), } } fn mutating_writes<'a>( function: &'a Function, mutator: NodeID, objects: &'a CollectionObjects, ) ->
Box<dyn Iterator<Item = NodeID> + 'a> { match function.nodes[mutator.idx()] { Node::Write { collect, data: _, indices: _, } => Box::new(once(collect)), Node::Call { control: _, function: callee, dynamic_constants: _, ref args, } => Box::new(args.into_iter().enumerate().filter_map(move |(idx, arg)| { let callee_objects = &objects[&callee]; let param_obj = callee_objects.param_to_object(idx)?; if callee_objects.is_mutated(param_obj) { Some(*arg) } else { None } })), Node::LibraryCall { library_function, ref args, ty: _, device: _, } => match library_function { LibraryFunction::GEMM => Box::new(once(args[0])), }, _ => Box::new(empty()), } } /* * Top level function to find implicit clones that need to be spilled. Returns * whether a clone was spilled, in which case the whole scheduling process must * be restarted. */ fn spill_clones( editor: &mut FunctionEditor, typing: &Vec<TypeID>, control_subgraph: &Subgraph, objects: &CollectionObjects, bbs: &BasicBlocks, liveness: &Liveness, ) -> bool { // Step 1: compute an interference graph from the liveness result. This // graph contains a vertex per node ID producing a collection value and an // edge per pair of node IDs that interfere. Nodes A and B interfere if node // A is defined right above a point where node B is live and A != B. Extra // edges are drawn for forwarding reads - when there is a node A that is a // forwarding read of a node B, A and B really have the same live range for // the purpose of determining when spills are necessary, since forwarding // reads can be thought of as nothing but pointer math. For this purpose, we // maintain a union-find of nodes that form a forwarding read DAG (notably, // phis and reduces are not considered forwarding reads). The more precise // version of the interference condition is: nodes A and B interfere if node // A is defined right above a point where a node C is live, where C is in the // same union-find class as B. // Assemble the union-find to group forwarding read DAGs. let mut union_find = QuickFindUf::<UnionBySize>::new(editor.func().nodes.len()); for id in editor.node_ids() { for forwarding_read in forwarding_reads(editor.func(), editor.func_id(), id, objects) { union_find.union(id.idx(), forwarding_read.idx()); } } // Figure out which classes contain which node IDs, since we need to iterate // the disjoint sets. let mut disjoint_sets: BTreeMap<usize, Vec<NodeID>> = BTreeMap::new(); for id in editor.node_ids() { disjoint_sets .entry(union_find.find(id.idx())) .or_default() .push(id); } // Create the graph. let mut edges = vec![]; for (bb, liveness) in liveness { let insts = &bbs.1[bb.idx()]; for (node, live) in zip(insts, liveness.into_iter().skip(1)) { for live_node in live { for live_node in disjoint_sets[&union_find.find(live_node.idx())].iter() { if *node != *live_node { edges.push((*node, *live_node)); } } } } } // Step 2: filter edges (A, B) down to those where A mutates B or A is a phi // or reduce that uses B (excluding the `init` input of parallel reduces). // These are the edges that may require a spill. let mut spill_edges = edges.into_iter().filter(|(a, b)| { mutating_writes(editor.func(), *a, objects).any(|id| id == *b) || (get_uses(&editor.func().nodes[a.idx()]) .as_ref() .into_iter() .any(|u| *u == *b) && (editor.func().nodes[a.idx()].is_phi() || editor.func().nodes[a.idx()].is_reduce()) && !editor.func().nodes[a.idx()] .try_reduce() .map(|(_, init, _)| { init == *b && editor.func().schedules[a.idx()].contains(&Schedule::ParallelReduce) }) .unwrap_or(false)) }); // Step 3: if there is a spill edge, spill it and return true.
Otherwise, // return false. if let Some((user, obj)) = spill_edges.next() { // Figure out the most immediate dominating region for every basic // block. These are the points where spill slot phis get placed. let nodes = &editor.func().nodes; let mut imm_dom_reg = vec![NodeID::new(0); editor.func().nodes.len()]; for (idx, node) in nodes.into_iter().enumerate() { if node.is_region() { imm_dom_reg[idx] = NodeID::new(idx); } } let rev_po = control_subgraph.rev_po(NodeID::new(0)); for bb in rev_po.iter() { if !nodes[bb.idx()].is_region() && !nodes[bb.idx()].is_start() { imm_dom_reg[bb.idx()] = imm_dom_reg[control_subgraph.preds(*bb).next().unwrap().idx()]; } } let other_obj_users: Vec<_> = editor.get_users(obj).filter(|id| *id != user).collect(); let mut dummy_phis = vec![NodeID::new(0); imm_dom_reg.len()]; let mut success = editor.edit(|mut edit| { // Construct the spill slot. This is just a constant that gets phi- // ed throughout the entire function. let cons_id = edit.add_zero_constant(typing[obj.idx()]); let slot_id = edit.add_node(Node::Constant { id: cons_id }); edit = edit.add_schedule(slot_id, Schedule::NoResetConstant)?; // Allocate IDs for phis that move the spill slot throughout the // function without implicit clones. These are dummy phis, since // there are potentially cycles between them. We will replace them // later. for (idx, reg) in imm_dom_reg.iter().enumerate().skip(1) { if idx == reg.idx() { dummy_phis[idx] = edit.add_node(Node::Phi { control: *reg, data: empty().collect(), }); } } // Spill `obj` before `user` potentially modifies it. let spill_region = imm_dom_reg[bbs.0[obj.idx()].idx()]; let spill_id = edit.add_node(Node::Write { collect: if spill_region == NodeID::new(0) { slot_id } else { dummy_phis[spill_region.idx()] }, data: obj, indices: empty().collect(), }); // Before each other user, unspill `obj`. for other_user in other_obj_users { let other_region = imm_dom_reg[bbs.0[other_user.idx()].idx()]; // If this assert fails, then `obj` is not in the first basic // block, but it has a user that is in the first basic block, // which violates SSA. assert!(other_region == spill_region || other_region != NodeID::new(0)); // If an other user is a phi, we need to be a little careful // about how we insert unspilling code for `obj`. Instead of // inserting an unspill in the same block as the user, we need // to insert one in each predecessor of the phi that corresponds // to a use of `obj`. Since this requires modifying individual // uses in a phi, just rebuild the node entirely. if let Node::Phi { control, data } = edit.get_node(other_user).clone() { assert_eq!(control, other_region); let mut new_data = vec![]; for (pred, data) in zip(control_subgraph.preds(control), data) { let pred = imm_dom_reg[pred.idx()]; if data == obj { let unspill_id = edit.add_node(Node::Write { collect: obj, data: if pred == spill_region { spill_id } else { dummy_phis[pred.idx()] }, indices: empty().collect(), }); new_data.push(unspill_id); } else { new_data.push(data); } } let new_phi = edit.add_node(Node::Phi { control, data: new_data.into_boxed_slice(), }); edit = edit.replace_all_uses(other_user, new_phi)?; edit = edit.delete_node(other_user)?; } else { let unspill_id = edit.add_node(Node::Write { collect: obj, data: if other_region == spill_region { spill_id } else { dummy_phis[other_region.idx()] }, indices: empty().collect(), }); edit = edit.replace_all_uses_where(obj, unspill_id, |id| *id == other_user)?; } } // Create and hook up all the real phis. Phi elimination will clean // this up. 
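// One real phi is created per region (each block that is its own immediate
// dominating region, skipping the start node). For each predecessor, the phi
// selects the spill write if the predecessor maps to the spill region, the
// original spill slot constant if it maps to the start block, and the
// corresponding dummy phi otherwise; the dummy phis are then replaced with
// these real phis and deleted in a follow-up edit.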
let mut real_phis = vec![NodeID::new(0); imm_dom_reg.len()]; for (idx, reg) in imm_dom_reg.iter().enumerate().skip(1) { if idx == reg.idx() { real_phis[idx] = edit.add_node(Node::Phi { control: *reg, data: control_subgraph .preds(*reg) .map(|pred| { let pred = imm_dom_reg[pred.idx()]; if pred == spill_region { spill_id } else if pred == NodeID::new(0) { slot_id } else { dummy_phis[pred.idx()] } }) .collect(), }); } } for (dummy, real) in zip(dummy_phis.iter(), real_phis) { if *dummy != real { edit = edit.replace_all_uses(*dummy, real)?; } } Ok(edit) }); success = success && editor.edit(|mut edit| { for dummy in dummy_phis { if dummy != NodeID::new(0) { edit = edit.delete_node(dummy)?; } } Ok(edit) }); assert!(success, "PANIC: GCM cannot fail to edit a function, as it needs to legalize the reference semantics of every function before code generation."); true } else { false } } type Liveness = BTreeMap<NodeID, Vec<BTreeSet<NodeID>>>; /* * Liveness dataflow analysis on scheduled Hercules IR. Just look at nodes that * involve collections. */ fn liveness_dataflow( function: &Function, func_id: FunctionID, control_subgraph: &Subgraph, objects: &CollectionObjects, bbs: &BasicBlocks, ) -> Liveness { let mut po = control_subgraph.rev_po(NodeID::new(0)); po.reverse(); let mut liveness = Liveness::default(); for (bb_idx, insts) in bbs.1.iter().enumerate() { liveness.insert(NodeID::new(bb_idx), vec![BTreeSet::new(); insts.len() + 1]); } let mut num_phis_reduces = vec![0; function.nodes.len()]; let mut has_phi = vec![false; function.nodes.len()]; let mut has_seq_reduce = vec![false; function.nodes.len()]; for (node_idx, bb) in bbs.0.iter().enumerate() { let node = &function.nodes[node_idx]; if node.is_phi() || node.is_reduce() { num_phis_reduces[bb.idx()] += 1; } has_phi[bb.idx()] = node.is_phi(); has_seq_reduce[bb.idx()] = node.is_reduce() && !function.schedules[node_idx].contains(&Schedule::ParallelReduce); assert!(!node.is_phi() || !node.is_reduce()); } let is_obj = |id: NodeID| !objects[&func_id].objects(id).is_empty(); loop { let mut changed = false; for bb in po.iter() { // First, calculate the liveness set for the bottom of this block. let last_pt = bbs.1[bb.idx()].len(); let old_value = &liveness[&bb][last_pt]; let mut new_value = BTreeSet::new(); for succ in control_subgraph .succs(*bb) .chain(if has_seq_reduce[bb.idx()] { Either::Left(once(*bb)) } else { Either::Right(empty()) }) { // The liveness at the bottom of a basic block is the union of: // 1. The liveness of each successor right after its phis and // reduces. // 2. Every data use in a phi or reduce that corresponds to this // block as the predecessor. let after_phis_reduces_pt = num_phis_reduces[succ.idx()]; new_value.extend(&liveness[&succ][after_phis_reduces_pt]); for inst_idx in 0..after_phis_reduces_pt { let id = bbs.1[succ.idx()][inst_idx]; new_value.remove(&id); match function.nodes[id.idx()] { Node::Phi { control, ref data } if is_obj(data[0]) => { assert_eq!(control, succ); new_value.extend( zip(control_subgraph.preds(succ), data) .filter(|(pred, _)| *pred == *bb) .map(|(_, data)| *data), ); } Node::Reduce { control, init, reduct, } if is_obj(init) => { assert_eq!(control, succ); if succ == *bb { new_value.insert(reduct); } else if !function.schedules[id.idx()] .contains(&Schedule::ParallelReduce) { new_value.insert(init); } } _ => {} } } } changed |= *old_value != new_value; liveness.get_mut(&bb).unwrap()[last_pt] = new_value; // Second, calculate the liveness set above each instruction in this block.
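// This is the standard backward liveness transfer function, restricted to
// collection values: live_before(inst) = (live_after(inst) - {inst}) U
// uses(inst). Two special cases below: the `collect` input of a cloning write
// (a write with no indices) and the `init` input of a parallel reduce are not
// treated as live uses.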
for pt in (0..last_pt).rev() { let old_value = &liveness[&bb][pt]; let mut new_value = liveness[&bb][pt + 1].clone(); let id = bbs.1[bb.idx()][pt]; let uses = get_uses(&function.nodes[id.idx()]); let is_obj = |id: &NodeID| is_obj(*id); new_value.remove(&id); new_value.extend( if let Node::Write { collect: _, data, ref indices, } = function.nodes[id.idx()] && indices.is_empty() { // If this write is a cloning write, the `collect` input // isn't actually live, because its value doesn't // matter. Either::Left(once(data).filter(is_obj)) } else if let Node::Reduce { control: _, init: _, reduct, } = function.nodes[id.idx()] && function.schedules[id.idx()].contains(&Schedule::ParallelReduce) { // If this reduce is a parallel reduce, the `init` input // isn't actually live. Either::Left(once(reduct).filter(is_obj)) } else { Either::Right(uses.as_ref().into_iter().map(|id| *id).filter(is_obj)) }, ); changed |= *old_value != new_value; liveness.get_mut(&bb).unwrap()[pt] = new_value; } } if !changed { return liveness; } } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum UTerm { Node(NodeID), Device(Device), } fn unify( mut equations: VecDeque<(UTerm, UTerm)>, ) -> Result<BTreeMap<NodeID, Device>, BTreeMap<NodeID, Device>> { let mut theta = BTreeMap::new(); let mut no_progress_iters = 0; while no_progress_iters <= equations.len() && let Some((l, r)) = equations.pop_front() { match (l, r) { (UTerm::Node(_), UTerm::Node(_)) => { if l != r { equations.push_back((l, r)); } no_progress_iters += 1; } (UTerm::Node(n), UTerm::Device(d)) | (UTerm::Device(d), UTerm::Node(n)) => { theta.insert(n, d); for (l, r) in equations.iter_mut() { if *l == UTerm::Node(n) { *l = UTerm::Device(d); } if *r == UTerm::Node(n) { *r = UTerm::Device(d); } } no_progress_iters = 0; } (UTerm::Device(d1), UTerm::Device(d2)) if d1 == d2 => { no_progress_iters = 0; } _ => { return Err(theta); } } } Ok(theta) } /* * Determine what device each node produces a collection onto. Insert inter- * device clones when a single node may potentially be on different devices. */ fn color_nodes( editor: &mut FunctionEditor, typing: &Vec<TypeID>, objects: &CollectionObjects, devices: &Vec<Device>, node_colors: &NodeColors, ) -> Option<FunctionNodeColors> { let nodes = &editor.func().nodes; let func_id = editor.func_id(); let func_device = devices[func_id.idx()]; let mut func_colors = ( BTreeMap::new(), vec![None; editor.func().param_types.len()], None, ); // Assigning nodes to devices is tricky due to function calls. Technically, // all the information is there to decide what device to place nodes on, but // coherently expressing the constraints and deriving the devices is not // obvious. Express this as a unification problem, where we need to assign // types (devices) to each node. Each unification term is either a node ID // or a concrete device. Assemble a list of unification equations to solve. let mut equations = vec![]; for id in editor.node_ids() { match nodes[id.idx()] { Node::Phi { control: _, ref data, } if !editor.get_type(typing[id.idx()]).is_primitive() => { // Every input to a phi needs to be on the same device. The // phi itself is also on this device. 
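// E.g. for phi data inputs [a, b, c], the zip below yields the pairs (a, b),
// (b, c), and (c, phi), chaining all inputs and the phi itself into a single
// device equivalence class for unification.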
for (l, r) in zip(data.into_iter(), data.into_iter().skip(1).chain(once(&id))) { equations.push((UTerm::Node(*l), UTerm::Node(*r))); } } Node::Reduce { control: _, init: first, reduct: second, } | Node::Ternary { op: TernaryOperator::Select, first: _, second: first, third: second, } if !editor.get_type(typing[id.idx()]).is_primitive() => { // Every input to the reduce, and the reduce itself, are on // the same device. equations.push((UTerm::Node(first), UTerm::Node(second))); equations.push((UTerm::Node(second), UTerm::Node(id))); } Node::Constant { id: _ } if !editor.get_type(typing[id.idx()]).is_primitive() && func_device != Device::AsyncRust => { // Constants inside device functions are allocated on that // device. equations.push((UTerm::Node(id), UTerm::Device(func_device))); } Node::Read { collect, indices: _, } => { if editor.get_type(typing[id.idx()]).is_primitive() { // If this reads a primitive, then the collection needs to // be on the device of this function. equations.push(( UTerm::Node(collect), UTerm::Device(backing_device(func_device)), )); } else { // If this read just reads a sub-collection, then `collect` // and the read itself need to be on the same device. equations.push((UTerm::Node(collect), UTerm::Node(id))); } } Node::Write { collect, data, indices: _, } => { if func_device != Device::AsyncRust || editor.get_type(typing[data.idx()]).is_primitive() { // If this writes a primitive or this is in a device // function, then the collection needs to be on the backing // device of this function. equations.push(( UTerm::Node(collect), UTerm::Device(backing_device(func_device)), )); if func_device != Device::AsyncRust && !editor.get_type(typing[data.idx()]).is_primitive() { // We can only do inter-device copies in AsyncRust // functions. equations.push((UTerm::Node(collect), UTerm::Node(data))); } } equations.push((UTerm::Node(collect), UTerm::Node(id))); } Node::Call { control: _, function: callee, dynamic_constants: _, ref args, } => { // If the callee has a definite device for a parameter, add an // equation for the corresponding argument. for (idx, arg) in args.into_iter().enumerate() { if let Some(device) = node_colors[&callee].1[idx] { equations.push((UTerm::Node(*arg), UTerm::Device(device))); } } // If the callee has a definite device for the returned value, // add an equation for the call node itself. if let Some(device) = node_colors[&callee].2 { equations.push((UTerm::Node(id), UTerm::Device(device))); } // For any object that may be returned by the callee that // originates as a parameter in the callee, the device of the // corresponding argument and call node itself must be equal. for ret in objects[&callee].returned_objects() { if let Some(idx) = objects[&callee].origin(*ret).try_parameter() { equations.push((UTerm::Node(args[idx]), UTerm::Node(id))); } } } Node::LibraryCall { library_function: _, ref args, ty: _, device, } => { for arg in args { equations.push((UTerm::Node(*arg), UTerm::Device(device))); } equations.push((UTerm::Node(id), UTerm::Device(device))); } _ => {} } } // Solve the unification problem. I couldn't find a simple enough crate for // this, and the problems are usually pretty small, so just use a hand- // rolled implementation for now. match unify(VecDeque::from(equations)) { Ok(solve) => { func_colors.0 = solve; // Look at parameter and return nodes to get the device signature of // the function. 
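// A parameter's device is the solved color of its parameter node, and the
// return device is the solved color of the `data` input of the return node;
// callers consume this signature through `node_colors` when they add the
// equations for their own call nodes above.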
for id in editor.node_ids() { if let Node::Parameter { index } = nodes[id.idx()] && let Some(device) = func_colors.0.get(&id) { assert!(func_colors.1[index].is_none(), "PANIC: Found multiple parameter nodes for the same index in GCM. Please just run GVN first."); func_colors.1[index] = Some(*device); } else if let Node::Return { control: _, data } = nodes[id.idx()] && let Some(device) = func_colors.0.get(&data) { assert!(func_colors.2.is_none(), "PANIC: Found multiple return nodes in GCM. Contact Russel if you see this, it's an easy fix."); func_colors.2 = Some(*device); } } Some(func_colors) } Err(progress) => { // If unification failed, then there's some node using a node in // `progress` that's expecting a different type than what it got. // Pick one and add potentially inter-device copies on each def-use // edge. We'll clean these up later. let (id, _) = progress.into_iter().next().unwrap(); let users: Vec<_> = editor.get_users(id).collect(); let success = editor.edit(|mut edit| { let cons = edit.add_zero_constant(typing[id.idx()]); for user in users { let cons = edit.add_node(Node::Constant { id: cons }); edit = edit.add_schedule(cons, Schedule::NoResetConstant)?; let copy = edit.add_node(Node::Write { collect: cons, data: id, indices: Box::new([]), }); edit = edit.replace_all_uses_where(id, copy, |id| *id == user)?; } Ok(edit) }); assert!(success); None } } } fn align(edit: &mut FunctionEdit, mut acc: DynamicConstantID, align: usize) -> DynamicConstantID { assert_ne!(align, 0); if align != 1 { let align_dc = edit.add_dynamic_constant(DynamicConstant::Constant(align)); let align_m1_dc = edit.add_dynamic_constant(DynamicConstant::Constant(align - 1)); acc = edit.add_dynamic_constant(DynamicConstant::add(acc, align_m1_dc)); acc = edit.add_dynamic_constant(DynamicConstant::div(acc, align_dc)); acc = edit.add_dynamic_constant(DynamicConstant::mul(acc, align_dc)); } acc } /* * Determine the size of a type in terms of dynamic constants. */ fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) -> DynamicConstantID { let ty = edit.get_type(ty_id).clone(); let size = match ty { Type::Control => panic!(), Type::Boolean | Type::Integer8 | Type::UnsignedInteger8 | Type::Float8 => { edit.add_dynamic_constant(DynamicConstant::Constant(1)) } Type::Integer16 | Type::UnsignedInteger16 | Type::BFloat16 => { edit.add_dynamic_constant(DynamicConstant::Constant(2)) } Type::Integer32 | Type::UnsignedInteger32 | Type::Float32 => { edit.add_dynamic_constant(DynamicConstant::Constant(4)) } Type::Integer64 | Type::UnsignedInteger64 | Type::Float64 => { edit.add_dynamic_constant(DynamicConstant::Constant(8)) } Type::Product(fields) => { // The layout of product types is like the C-style layout. let mut acc_size = edit.add_dynamic_constant(DynamicConstant::Constant(0)); for field in fields { // Round up to the alignment of the field, then add the size of // the field. let field_size = type_size(edit, field, alignments); acc_size = align(edit, acc_size, alignments[field.idx()]); acc_size = edit.add_dynamic_constant(DynamicConstant::add(acc_size, field_size)); } // Finally, round up to the alignment of the whole product, since // the size needs to be a multiple of the alignment. acc_size = align(edit, acc_size, alignments[ty_id.idx()]); acc_size } Type::Summation(variants) => { // A summation holds every variant in the same memory. 
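// The size of a summation is therefore the size of its largest variant, plus
// one byte for the discriminant, rounded up to the summation's alignment.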
let mut acc_size = edit.add_dynamic_constant(DynamicConstant::Constant(0)); for variant in variants { // Pick the size of the largest variant, since that's the most // memory we would need. let variant_size = type_size(edit, variant, alignments); acc_size = edit.add_dynamic_constant(DynamicConstant::max(acc_size, variant_size)); } // Add one byte for the discriminant and align the whole summation. let one = edit.add_dynamic_constant(DynamicConstant::Constant(1)); acc_size = edit.add_dynamic_constant(DynamicConstant::add(acc_size, one)); acc_size = align(edit, acc_size, alignments[ty_id.idx()]); acc_size } Type::Array(elem, bounds) => { // The layout of an array is row-major linear in memory. let mut acc_size = type_size(edit, elem, alignments); for bound in bounds { acc_size = edit.add_dynamic_constant(DynamicConstant::mul(acc_size, bound)); } acc_size } }; size } /* * Allocate objects in a function. Relies on the allocations of all called * functions. */ fn object_allocation( editor: &mut FunctionEditor, typing: &Vec<TypeID>, node_colors: &FunctionNodeColors, alignments: &Vec<usize>, _liveness: &Liveness, backing_allocations: &BackingAllocations, ) -> FunctionBackingAllocation { let mut fba = BTreeMap::new(); let node_ids = editor.node_ids(); editor.edit(|mut edit| { // For now, just allocate each object to its own slot. let zero = edit.add_dynamic_constant(DynamicConstant::Constant(0)); for id in node_ids { match *edit.get_node(id) { Node::Constant { id: _ } => { if !edit.get_type(typing[id.idx()]).is_primitive() { let device = node_colors.0[&id]; let (total, offsets) = fba.entry(device).or_insert_with(|| (zero, BTreeMap::new())); *total = align(&mut edit, *total, alignments[typing[id.idx()].idx()]); offsets.insert(id, *total); let type_size = type_size(&mut edit, typing[id.idx()], alignments); *total = edit.add_dynamic_constant(DynamicConstant::add(*total, type_size)); } } Node::Call { control: _, function: callee, ref dynamic_constants, args: _, } => { let dynamic_constants = dynamic_constants.to_vec(); let dc_args = (0..dynamic_constants.len()) .map(|i| edit.add_dynamic_constant(DynamicConstant::Parameter(i))); let substs = dc_args .zip(dynamic_constants.into_iter()) .collect::<HashMap<_, _>>(); for device in BACKED_DEVICES { if let Some(mut callee_backing_size) = backing_allocations[&callee] .get(&device) .map(|(callee_total, _)| *callee_total) { let (total, offsets) = fba.entry(device).or_insert_with(|| (zero, BTreeMap::new())); // We don't know the alignment requirement of the memory // in the callee, so just assume the largest alignment. *total = align(&mut edit, *total, LARGEST_ALIGNMENT); offsets.insert(id, *total); // Substitute the dynamic constant parameters in the // callee's backing size. callee_backing_size = substitute_dynamic_constants( &substs, callee_backing_size, &mut edit, ); *total = edit.add_dynamic_constant(DynamicConstant::add( *total, callee_backing_size, )); } } } _ => {} } } Ok(edit) }); fba }
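// Illustrative only: a self-contained sanity check (not part of the pass) of
// the round-up-to-alignment arithmetic that `align` above encodes with dynamic
// constants, reproduced here on plain `usize` values.
#[cfg(test)]
mod align_formula_example {
    // Round `size` up to the next multiple of `alignment` using the same
    // (size + alignment - 1) / alignment * alignment formula as `align`.
    fn round_up(size: usize, alignment: usize) -> usize {
        assert_ne!(alignment, 0);
        if alignment == 1 {
            size
        } else {
            (size + alignment - 1) / alignment * alignment
        }
    }

    #[test]
    fn rounds_up_to_multiples() {
        assert_eq!(round_up(0, 8), 0);
        assert_eq!(round_up(1, 8), 8);
        assert_eq!(round_up(8, 8), 8);
        assert_eq!(round_up(9, 8), 16);
        assert_eq!(round_up(5, 1), 5);
    }
}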