use std::cell::Ref;
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque};
use std::iter::{empty, once, zip};
use bitvec::prelude::*;
use either::Either;
use union_find::{QuickFindUf, UnionBySize, UnionFind};
use crate::*;
/*
* Top level function to legalize the reference semantics of a Hercules IR
* function. Hercules IR is a value semantics representation, meaning that all
* program state is in the form of copyable values, and mutation takes place by
* making a new value that is a copy of the old value with some modification.
* This representation is extremely convenient for optimization, but is not good
* for code generation, where we need to generate code with references to get
* good performance. Hercules IR can alternatively be interpreted using
* reference semantics, where pointers to collection objects are passed around,
* read from, and written to. However, the value semantics and reference
* semantics interpretation of a Hercules IR function may not be equal - this
* pass transforms a Hercules IR function such that its new value semantics is
* the same as its old value semantics and that its new reference semantics is
* the same as its new value semantics. This pass returns a placement of nodes
* into ordered basic blocks, since the reference semantics of a function
* depends on the order of execution with respect to anti-dependencies. This
* is analogous to global code motion from the original sea of nodes paper.
*
* Our strategy for handling multiple mutating users of a collection is to treat
* the problem similar to register allocation; we perform a liveness analysis,
* spill constants into newly allocated constants, and read back the spilled
* contents when they are used after the first mutation. It's not obvious how
* many spills are needed upfront, and newly spilled constants may affect the
* liveness analysis result, so every spill restarts the process of checking for
* spills. Once no more spills are found, the process terminates. When a spill
* is found, the basic block assignments, and all the other analyses, are not
* necessarily valid anymore, so this function is called in a loop in the pass
* manager until no more spills are found.
*
* GCM also generally tries to massage the code to be properly formed for the
* device backends in other ways. For example, a reduction cycle that passes
* through the `init` input of an inner reduce and through a use of a
* non-parallel outer reduce in nested fork-joins is not schedulable, since the
* outer reduce doesn't dominate the fork corresponding to the inner reduce. In
* such cases, the outer fork-join must be split and unforkified.
*
* GCM is additionally complicated by the need to generate code that references
* objects across multiple devices. In particular, GCM makes sure that every
* object lives on exactly one device, so that references to that object always
* live on a single device. Additionally, GCM makes sure that the objects that a
* node may produce are all on the same device, so that a pointer produced by,
* for example, a select node can only refer to memory on a single device. Extra
* collection constants and potentially inter-device copies are inserted as
* necessary to make sure this is true - an inter-device copy is represented by
* a write where the `collect` and `data` inputs are on different devices. This
* is only valid in RT functions - it is asserted that this isn't necessary in
* device functions. This process "colors" the nodes in the function.
*
* GCM has one final responsibility - object allocation. Each Hercules function
* receives a pointer to a "backing" memory where collection constants live. The
* backing memory a function receives is for the constants in that function and
* the constants of every called function. Concretely, a function will pass a
* sub-region of its backing memory to a callee, which during the call is that
* callee's backing memory. Object allocation consists of finding the required
* sizes of all collection constants and functions in terms of dynamic constants
* (dynamic constant math is expressive enough to represent sizes of types,
* which is very convenient) and determining the concrete offsets into the
* backing memory where constants and callee sub-regions live. When two users of
* backing memory are never live at once, they may share backing memory. This is
* done after nodes are given a single device color, since we need to know what
* values are on what devices before we can allocate them to backing memory,
* since there are separate backing memories per-device.
*/
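/*
 * The restart-on-spill protocol described above amounts to a small fixed-point
 * loop in the caller. The sketch below is illustrative only - the names
 * `recompute_analyses` and `run_gcm_once` are hypothetical stand-ins for the
 * pass manager's machinery, not APIs of this crate. A `None` from `gcm` means
 * a spill or fixup invalidated the analyses and scheduling must be retried.
 *
 *     fn schedule_to_fixed_point(
 *         mut recompute_analyses: impl FnMut(),
 *         mut run_gcm_once: impl FnMut() -> Option<BasicBlocks>,
 *     ) -> BasicBlocks {
 *         loop {
 *             // Analyses (def-use, dominators, collection objects, ...) must
 *             // be fresh before every attempt, since edits invalidate them.
 *             recompute_analyses();
 *             if let Some(bbs) = run_gcm_once() {
 *                 return bbs;
 *             }
 *         }
 *     }
 */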
pub fn gcm(
editor: &mut FunctionEditor,
def_use: &ImmutableDefUseMap,
reverse_postorder: &Vec<NodeID>,
typing: &Vec<TypeID>,
control_subgraph: &Subgraph,
dom: &DomTree,
fork_join_map: &HashMap<NodeID, NodeID>,
reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
loops: &LoopTree,
fork_join_nest: &HashMap<NodeID, Vec<NodeID>>,
objects: &CollectionObjects,
devices: &Vec<Device>,
node_colors: &FunctionNodeColors,
backing_allocations: &BackingAllocations,
) -> Option<(BasicBlocks, FunctionNodeColors, FunctionBackingAllocation)> {
if preliminary_fixups(editor, fork_join_map, loops, reduce_cycles) {
return None;
}
let bbs = basic_blocks(
editor.func(),
editor.func_id(),
def_use,
reverse_postorder,
reduce_cycles,
fork_join_map,
objects,
devices,
);
let liveness = liveness_dataflow(
editor.func(),
editor.func_id(),
control_subgraph,
objects,
&bbs,
);
if spill_clones(editor, typing, control_subgraph, objects, &bbs, &liveness) {
return None;
}
if add_extra_collection_dims(
editor,
typing,
fork_join_map,
fork_join_nest,
objects,
devices,
&bbs,
) {
return None;
}
let Some(node_colors) = color_nodes(editor, typing, &objects, &devices, node_colors) else {
return None;
};
let mut alignments = vec![];
Ref::map(editor.get_types(), |types| {
for idx in 0..types.len() {
if types[idx].is_control() || types[idx].is_multireturn() {
alignments.push(0);
} else {
alignments.push(get_type_alignment(types, TypeID::new(idx)));
}
}
&()
});
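// The backing allocation computed below is conceptually a per-device bump
// allocator over objects whose sizes are dynamic constant expressions, aligned
// per type. A minimal sketch of that offset computation, with plain usizes
// standing in for dynamic constant math (an illustrative simplification):
//
//     fn bump(offset: &mut usize, size: usize, align: usize) -> usize {
//         // Round the running offset up to the required (nonzero) alignment,
//         // then reserve `size` bytes starting there.
//         *offset = (*offset + align - 1) / align * align;
//         let start = *offset;
//         *offset += size;
//         start
//     }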
let backing_allocation = object_allocation(
editor,
typing,
&node_colors,
&alignments,
&liveness,
backing_allocations,
);
Some((bbs, node_colors, backing_allocation))
}
/*
* Do misc. fixups on the IR, such as unforkifying sequential outer forks with
* problematic reduces.
*/
fn preliminary_fixups(
editor: &mut FunctionEditor,
fork_join_map: &HashMap<NodeID, NodeID>,
loops: &LoopTree,
reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
) -> bool {
let nodes = &editor.func().nodes;
let schedules = &editor.func().schedules;
// Sequentialize non-parallel forks that contain problematic reduce cycles.
for (reduce, cycle) in reduce_cycles {
if !schedules[reduce.idx()].contains(&Schedule::ParallelReduce)
&& cycle.into_iter().any(|id| nodes[id.idx()].is_reduce())
{
let join = nodes[reduce.idx()].try_reduce().unwrap().0;
let fork = fork_join_map
.into_iter()
.filter(|(_, j)| join == **j)
.map(|(f, _)| *f)
.next()
.unwrap();
let (forks, _) = split_fork(editor, fork, join, reduce_cycles).unwrap();
if forks.len() > 1 {
return true;
}
unforkify(editor, fork, join, loops);
return true;
}
}
// Get rid of the backward edge on parallel reduces in fork-joins.
for (_, join) in fork_join_map {
let parallel_reduces: Vec<_> = editor
.get_users(*join)
.filter(|id| {
nodes[id.idx()].is_reduce()
&& schedules[id.idx()].contains(&Schedule::ParallelReduce)
})
.collect();
for reduce in parallel_reduces {
if reduce_cycles[&reduce].is_empty() {
continue;
}
let (_, init, _) = nodes[reduce.idx()].try_reduce().unwrap();
// Replace uses of the reduce in its cycle with the init.
let success = editor.edit(|edit| {
edit.replace_all_uses_where(reduce, init, |id| reduce_cycles[&reduce].contains(id))
});
assert!(success);
return true;
}
}
false
}
/*
* Top level global code motion function. Assigns each data node to one of its
* immediate control use / user nodes, forming (unordered) basic blocks. Returns
* the control node / basic block each node is in. Takes in a partial
* partitioning that must be respected. Based on the schedule-early-schedule-
* late method from Cliff Click's PhD thesis. May fail if an anti-dependency
* edge can't be satisfied - in this case, a clone that has to be induced is
* returned instead.
*/
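/*
 * The "schedule late" position of a node is the least common ancestor, in the
 * dominator tree, of every block that needs to see the node's value. A
 * self-contained sketch of that LCA computation over a dominator tree stored
 * as immediate-dominator and depth arrays (toy data, not this crate's
 * `DomTree` API):
 *
 *     fn lca(idom: &[usize], depth: &[usize], mut a: usize, mut b: usize) -> usize {
 *         // Walk the deeper block up until both blocks sit at the same depth,
 *         // then walk both up in lockstep until they meet.
 *         while depth[a] > depth[b] { a = idom[a]; }
 *         while depth[b] > depth[a] { b = idom[b]; }
 *         while a != b { a = idom[a]; b = idom[b]; }
 *         a
 *     }
 *
 *     // Diamond CFG 0 -> {1, 2} -> 3, where 3's immediate dominator is 0:
 *     // lca(&[0, 0, 0, 0], &[0, 1, 1, 1], 1, 2) == 0
 */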
fn basic_blocks(
function: &Function,
func_id: FunctionID,
def_use: &ImmutableDefUseMap,
reverse_postorder: &Vec<NodeID>,
reduce_cycles: &HashMap<NodeID, HashSet<NodeID>>,
fork_join_map: &HashMap<NodeID, NodeID>,
objects: &CollectionObjects,
devices: &Vec<Device>,
) -> BasicBlocks {
let mut bbs: Vec<Option<NodeID>> = vec![None; function.nodes.len()];
// Step 1: assign the basic block locations of all nodes that must be in a
// specific block. This includes control nodes as well as some special data
// nodes, such as phis.
for idx in 0..function.nodes.len() {
match function.nodes[idx] {
Node::Phi { control, data: _ } => bbs[idx] = Some(control),
Node::ThreadID {
control,
dimension: _,
} => bbs[idx] = Some(control),
Node::Reduce {
control,
init: _,
reduct: _,
} => bbs[idx] = Some(control),
Node::Call {
control,
function: _,
dynamic_constants: _,
args: _,
} => bbs[idx] = Some(control),
Node::DataProjection { data, selection: _ } => {
let Node::Call { control, .. } = function.nodes[data.idx()] else {
panic!();
};
bbs[idx] = Some(control);
}
Node::Parameter { index: _ } => bbs[idx] = Some(NodeID::new(0)),
_ if function.nodes[idx].is_control() => bbs[idx] = Some(NodeID::new(idx)),
_ => {}
}
}
// Step 2: schedule early. Place nodes in the earliest position they could
// go - use worklist to iterate nodes.
let mut schedule_early = bbs.clone();
let mut worklist = VecDeque::from(reverse_postorder.clone());
while let Some(id) = worklist.pop_front() {
if schedule_early[id.idx()].is_some() {
continue;
}
// For every use, check what block is its "schedule early" block. This
// node goes in the lowest block amongst those blocks.
let use_places: Option<Vec<NodeID>> = get_uses(&function.nodes[id.idx()])
.as_ref()
.into_iter()
.map(|id| *id)
.map(|id| schedule_early[id.idx()])
.collect();
if let Some(use_places) = use_places {
// If every use has been placed, we can place this node as the
// lowest place in the domtree that dominates all of the use places.
let lowest = dom.lowest_amongst(use_places.into_iter());
schedule_early[id.idx()] = Some(lowest);
} else {
// If not, then just push this node back on the worklist.
worklist.push_back(id);
}
}
// Step 3: find anti-dependence edges. An anti-dependence edge needs to be
// drawn between a collection reading node and a collection mutating node
// when the following conditions are true:
//
// 1: The reading and mutating nodes may involve the same collection.
// 2: The node producing the collection used by the reading node is in a
// schedule early block that dominates the schedule early block of the
// mutating node. The node producing the collection used by the reading
// node may be an originator of a collection, phi or reduce, or mutator,
// but not forwarding read - forwarding reads are collapsed, and the
// bottom read is treated as reading from the transitive parent of the
// forwarding read(s).
// 3: If the node producing the collection is a reduce node, then any read
// users that aren't in the reduce's cycle shouldn't anti-depend on any
// mutators in the reduce cycle.
//
// Because we do a liveness analysis based spill of collections, anti-
// dependencies can be best effort. Thus, when we encounter a read and
// mutator where the read doesn't dominate the mutator, but an anti-dependence
// edge is derived for the pair, we just don't draw the edge since it would
// break the scheduler.
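// As a concrete picture of the constraint these edges encode, consider the
// plain-Rust analogue of one collection with a reading user and a mutating
// user (an illustration of the ordering requirement, not Hercules IR):
//
//     let mut c = vec![0u8; 4];
//     let r = c[0]; // reading user of the collection
//     c[0] = 42;    // mutating user of the same collection
//     assert_eq!(r, 0);
//
// Under reference semantics, running the store before the load would make `r`
// observe 42 and diverge from value semantics; the read -> mutator
// anti-dependence edge records exactly this required ordering.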
let mut antideps = BTreeSet::new();
for id in reverse_postorder.iter() {
// Find a terminating read node and the collections it reads.
let terminating_reads: BTreeSet<_> =
terminating_reads(function, func_id, *id, objects).collect();
if !terminating_reads.is_empty() {
// Walk forwarding reads to find anti-dependency roots.
let mut workset = terminating_reads.clone();
let mut roots = BTreeSet::new();
while let Some(pop) = workset.pop_first() {
let forwarded: BTreeSet<_> =
forwarding_reads(function, func_id, pop, objects).collect();
if forwarded.is_empty() {
roots.insert(pop);
} else {
workset.extend(forwarded);
}
}
// For each root, find mutating nodes dominated by the root that
// modify an object read on any input of the current node (the
// terminating read).
// TODO: make this less outrageously inefficient.
let func_objects = &objects[&func_id];
for root in roots.iter() {
let root_is_reduce_and_read_isnt_in_cycle = reduce_cycles
.get(root)
.map(|cycle| !cycle.contains(id))
.unwrap_or(false);
let root_early = schedule_early[root.idx()].unwrap();
let mut root_block_iterated_users: BTreeSet<NodeID> = BTreeSet::new();
let mut workset = BTreeSet::new();
workset.insert(*root);
while let Some(pop) = workset.pop_first() {
let users = def_use.get_users(pop).into_iter().filter(|user| {
!function.nodes[user.idx()].is_phi()
&& !function.nodes[user.idx()].is_reduce()
&& schedule_early[user.idx()].unwrap() == root_early
});
workset.extend(users.clone());
root_block_iterated_users.extend(users);
}
let read_objs: BTreeSet<_> = terminating_reads
.iter()
.map(|read_use| func_objects.objects(*read_use).into_iter())
.flatten()
.map(|id| *id)
.collect();
for mutator in reverse_postorder.iter() {
let mutator_early = schedule_early[mutator.idx()].unwrap();
if dom.does_dom(root_early, mutator_early)
&& (root_early != mutator_early
|| root_block_iterated_users.contains(mutator))
&& mutating_objects(function, func_id, *mutator, objects)
.any(|mutated| read_objs.contains(&mutated))
&& id != mutator
&& (!root_is_reduce_and_read_isnt_in_cycle
|| !reduce_cycles
.get(root)
.map(|cycle| cycle.contains(mutator))
.unwrap_or(false))
&& dom.does_dom(schedule_early[id.idx()].unwrap(), mutator_early)
{
antideps.insert((*id, *mutator));
}
}
}
}
}
let mut antideps_uses = vec![vec![]; function.nodes.len()];
let mut antideps_users = vec![vec![]; function.nodes.len()];
for (reader, mutator) in antideps.iter() {
antideps_uses[mutator.idx()].push(*reader);
antideps_users[reader.idx()].push(*mutator);
}
// Step 4: schedule late and pick each node's final position. Since the late
// schedule of each node depends on the final positions of its users, these
// two steps must be fused. Compute their latest position, then use the
// control dependent + shallow loop heuristic to actually place them. A
// placement might not necessarily be found due to anti-dependency edges.
// These are optional and not necessary to consider, but we do since obeying
// them can reduce the number of clones. If the worklist stops making
// progress, stop considering the anti-dependency edges.
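// As a toy picture of the placement heuristic: walking the dominator chain
// from the late (LCA) block up toward the early block, the node is hoisted
// whenever the candidate block sits in a strictly shallower loop nest
// (hypothetical nesting depths, not this crate's `LoopTree` API):
//
//     let nesting_late_to_early = [2usize, 2, 1, 1, 0];
//     let mut place = 0;
//     for i in 1..nesting_late_to_early.len() {
//         if nesting_late_to_early[i] < nesting_late_to_early[place] {
//             place = i;
//         }
//     }
//     assert_eq!(place, 4); // the shallowest block on the chain wins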
let join_fork_map: HashMap<NodeID, NodeID> = fork_join_map
.into_iter()
.map(|(fork, join)| (*join, *fork))
.collect();
let mut worklist = VecDeque::from_iter(reverse_postorder.into_iter().map(|id| *id).rev());
let mut num_skip_iters = 0;
let mut consider_antidependencies = true;
while let Some(id) = worklist.pop_front() {
if num_skip_iters >= worklist.len() {
consider_antidependencies = false;
}
if bbs[id.idx()].is_some() {
num_skip_iters = 0;
continue;
}
// Calculate the least common ancestor of user blocks, a.k.a. the "late"
// schedule.
let calculate_lca = || -> Option<_> {
let mut lca = None;
// Helper to incrementally update the LCA.
let mut update_lca = |a| {
if let Some(acc) = lca {
lca = Some(dom.least_common_ancestor(acc, a));
} else {
lca = Some(a);
}
};
// For every user, consider where we need to be to directly dominate the
// user.
for user in def_use
.get_users(id)
.as_ref()
.into_iter()
.chain(if consider_antidependencies {
Either::Left(antideps_users[id.idx()].iter())
} else {
Either::Right(empty())
})
.map(|id| *id)
{
if let Node::Phi { control, data } = &function.nodes[user.idx()] {
// For phis, we need to dominate the block jumping to the phi in
// the slot that corresponds to our use.
for (control, data) in
zip(get_uses(&function.nodes[control.idx()]).as_ref(), data)
{
if id == *data {
update_lca(*control);
}
}
} else if let Node::Reduce {
control,
init,
reduct,
} = &function.nodes[user.idx()]
{
// For reduces, we need to either dominate the block right
// before the fork if we're the init input, or we need to
// dominate the join if we're the reduct input.
if id == *init {
let before_fork = function.nodes[join_fork_map[control].idx()]
.try_fork()
.unwrap()
.0;
update_lca(before_fork);
} else {
assert_eq!(id, *reduct);
update_lca(*control);
}
} else {
// For everything else, we just need to dominate the user.
update_lca(bbs[user.idx()]?);
}
}
Some(lca)
};
// Check if all users have been placed. If one of them hasn't, then add
// this node back on to the worklist.
let Some(lca) = calculate_lca() else {
worklist.push_back(id);
num_skip_iters += 1;
continue;
};
// Look between the LCA and the schedule early location to place the
// node.
let schedule_early = schedule_early[id.idx()].unwrap();
// If the node has no users, then it doesn't really matter where we
// place it - just place it at the early placement.
let schedule_late = lca.unwrap_or(schedule_early);
let mut chain = dom.chain(schedule_late, schedule_early);
if let Some(mut location) = chain.next() {
while let Some(control_node) = chain.next() {
// If the next node further up the dominator tree is in a shallower
// loop nest or if we can get out of a reduce loop when we don't
// need to be in one, place this data node in a higher-up location.
// Only do this if the node isn't a constant or undef - if a
// node is a constant or undef, we want its placement to be as
// control dependent as possible, even inside loops. In GPU
// functions specifically, lift constants that may be returned
// outside fork-joins.
let is_constant_or_undef = (function.nodes[id.idx()].is_constant()
|| function.nodes[id.idx()].is_undef())
&& !types[typing[id.idx()].idx()].is_primitive();
let is_gpu_returned = devices[func_id.idx()] == Device::CUDA
&& objects[&func_id].objects(id).into_iter().any(|obj| {
objects[&func_id]
.all_returned_objects()
.any(|ret| ret == *obj)
});
let old_nest = loops
.header_of(location)
.map(|header| loops.nesting(header).unwrap());
let new_nest = loops
.header_of(control_node)
.map(|header| loops.nesting(header).unwrap());
let shallower_nest = if let (Some(old_nest), Some(new_nest)) = (old_nest, new_nest)
{
old_nest > new_nest
} else {
// If the new location isn't a loop, its nesting level should
// be considered "shallower" if the current location is in a
// loop.
old_nest.is_some()
};
// This will move all nodes that don't need to be in reduce loops
// outside of reduce loops. Nodes that do need to be in a reduce
// loop use the reduce node forming the loop, so the dominator chain
// will consist of one block, and this loop won't ever iterate.
let currently_at_join = function.nodes[location.idx()].is_join()
&& !function.nodes[control_node.idx()].is_join();
if (!is_constant_or_undef || is_gpu_returned)
&& (shallower_nest || currently_at_join)
{
location = control_node;
}
}
bbs[id.idx()] = Some(location);
num_skip_iters = 0;
} else {
// If there is no valid location for this node, then it's a node
// reading a collection that can't be placed above a mutation that
// anti-depends on it. Push the node back on the list, and we'll
// stop considering anti-dependencies soon. Don't immediately stop
// considering anti-dependencies, as we may be able to eke out some
// more use of them.
worklist.push_back(id);
num_skip_iters += 1;
continue;
}
}
let bbs: Vec<_> = bbs.into_iter().map(Option::unwrap).collect();
// Calculate the number of phis and reduces per basic block. We use this to
// emit phis and reduces at the top of basic blocks. We want to emit phis
// and reduces first into ordered basic blocks for two reasons:
// 1. This is useful for liveness analysis.
// 2. This is needed for some backends - LLVM expects phis to be at the top
// of basic blocks.
let mut num_phis_reduces = vec![0; function.nodes.len()];
for (node_idx, bb) in bbs.iter().enumerate() {
let node = &function.nodes[node_idx];
if node.is_phi() || node.is_reduce() {
num_phis_reduces[bb.idx()] += 1;
}
}
// Step 5: determine the order of nodes inside each block. Use worklist to
// add nodes to blocks in order that obeys dependencies.
let mut order: Vec<Vec<NodeID>> = vec![vec![]; function.nodes.len()];
let mut worklist = VecDeque::from_iter(
reverse_postorder
.into_iter()
.filter(|id| !function.nodes[id.idx()].is_control()),
);
let mut visited = bitvec![u8, Lsb0; 0; function.nodes.len()];
let mut num_skip_iters = 0;
let mut consider_antidependencies = true;
while let Some(id) = worklist.pop_front() {
// If the worklist isn't making progress, then there's at least one
// reading node of a collection that is in an anti-dependence + normal dependence
// use cycle with a mutating node. See above comment about anti-
// dependencies being optional; we just stop considering them here.
if num_skip_iters >= worklist.len() {
consider_antidependencies = false;
}
// Phis and reduces always get emitted. Other nodes need to obey
// dependency relationships and need to come after phis and reduces.
let node = &function.nodes[id.idx()];
let bb = bbs[id.idx()];
if node.is_phi()
|| node.is_reduce()
|| (num_phis_reduces[bb.idx()] == 0
&& get_uses(node)
.as_ref()
.into_iter()
.chain(if consider_antidependencies {
Either::Left(antideps_uses[id.idx()].iter())
} else {
Either::Right(empty())
})
.all(|u| {
function.nodes[u.idx()].is_control()
|| bbs[u.idx()] != bbs[id.idx()]
|| visited[u.idx()]
}))
{
order[bb.idx()].push(*id);
visited.set(id.idx(), true);
num_skip_iters = 0;
if node.is_phi() || node.is_reduce() {
num_phis_reduces[bb.idx()] -= 1;
}
} else {
worklist.push_back(id);
num_skip_iters += 1;
}
}
(bbs, order)
}
fn terminating_reads<'a>(
function: &'a Function,
func_id: FunctionID,
reader: NodeID,
objects: &'a CollectionObjects,
) -> Box<dyn Iterator<Item = NodeID> + 'a> {
match function.nodes[reader.idx()] {
Node::Read {
collect,
indices: _,
} if objects[&func_id].objects(reader).is_empty() => Box::new(once(collect)),
Node::Write {
collect: _,
data,
indices: _,
} if !objects[&func_id].objects(data).is_empty() => Box::new(once(data)),
Node::Call {
control: _,
function: callee,
dynamic_constants: _,
ref args,
} => Box::new(args.into_iter().enumerate().filter_map(move |(idx, arg)| {
let objects = &objects[&callee];
let param_obj = objects.param_to_object(idx)?;
let mut returns = objects.all_returned_objects();
if !objects.is_mutated(param_obj) && !returns.any(|ret| ret == param_obj) {
Some(*arg)
} else {
None
}
})),
Node::LibraryCall {
library_function,
ref args,
ty: _,
device: _,
} => match library_function {
LibraryFunction::GEMM => Box::new(once(args[1]).chain(once(args[2]))),
},
_ => Box::new(empty()),
}
}
fn forwarding_reads<'a>(
function: &'a Function,
func_id: FunctionID,
reader: NodeID,
objects: &'a CollectionObjects,
) -> Box<dyn Iterator<Item = NodeID> + 'a> {
match function.nodes[reader.idx()] {
Node::Read {
collect,
indices: _,
} if !objects[&func_id].objects(reader).is_empty() => Box::new(once(collect)),
Node::Ternary {
op: TernaryOperator::Select,
first: _,
second,
third,
} if !objects[&func_id].objects(reader).is_empty() => {
Box::new(once(second).chain(once(third)))
}
Node::Call {
control: _,
function: callee,
dynamic_constants: _,
ref args,
} => Box::new(args.into_iter().enumerate().filter_map(move |(idx, arg)| {
let objects = &objects[&callee];
let param_obj = objects.param_to_object(idx)?;
let mut returns = objects.all_returned_objects();
if !objects.is_mutated(param_obj) && returns.any(|ret| ret == param_obj) {
Some(*arg)
} else {
None
}
})),
_ => Box::new(empty()),
}
}
fn mutating_objects<'a>(
function: &'a Function,
func_id: FunctionID,
mutator: NodeID,
objects: &'a CollectionObjects,
) -> Box<dyn Iterator<Item = CollectionObjectID> + 'a> {
match function.nodes[mutator.idx()] {
Node::Write {
collect,
data: _,
indices: _,
} => Box::new(objects[&func_id].objects(collect).into_iter().map(|id| *id)),
Node::Call {
control: _,
function: callee,
dynamic_constants: _,
ref args,
} => Box::new(
args.into_iter()
.enumerate()
.filter_map(move |(idx, arg)| {
let callee_objects = &objects[&callee];
let param_obj = callee_objects.param_to_object(idx)?;
if callee_objects.is_mutated(param_obj) {
Some(objects[&func_id].objects(*arg).into_iter().map(|id| *id))
} else {
None
}
})
.flatten(),
),
Node::LibraryCall {
library_function,
ref args,
ty: _,
device: _,
} => match library_function {
LibraryFunction::GEMM => {
Box::new(objects[&func_id].objects(args[0]).into_iter().map(|id| *id))
}
},
_ => Box::new(empty()),
}
}
fn mutating_writes<'a>(
function: &'a Function,
mutator: NodeID,
objects: &'a CollectionObjects,
) -> Box<dyn Iterator<Item = NodeID> + 'a> {
match function.nodes[mutator.idx()] {
Node::Write {
collect,
data: _,
indices: _,
} => Box::new(once(collect)),
Node::Call {
control: _,
function: callee,
dynamic_constants: _,
ref args,
} => Box::new(args.into_iter().enumerate().filter_map(move |(idx, arg)| {
let callee_objects = &objects[&callee];
let param_obj = callee_objects.param_to_object(idx)?;
if callee_objects.is_mutated(param_obj) {
Some(*arg)
} else {
None
}
})),
Node::LibraryCall {
library_function,
ref args,
ty: _,
device: _,
} => match library_function {
LibraryFunction::GEMM => Box::new(once(args[0])),
},
_ => Box::new(empty()),
}
}
/*
* Top level function to find implicit clones that need to be spilled. Returns
* whether a clone was spilled, in which case the whole scheduling process must
* be restarted.
*/
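/*
 * In value-semantics terms, a spill materializes the implicit clone that two
 * conflicting users of one collection require. A plain-Rust analogue (purely
 * illustrative, not Hercules IR):
 *
 *     let a = vec![1, 2, 3];
 *     let mut b = a.clone(); // the "spill": a fresh object for the mutator
 *     b[0] = 0;              // the mutating user writes the copy
 *     assert_eq!(a[0], 1);   // the other user still observes the original
 */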
fn spill_clones(
editor: &mut FunctionEditor,
typing: &Vec<TypeID>,
control_subgraph: &Subgraph,
objects: &CollectionObjects,
bbs: &BasicBlocks,
liveness: &BTreeMap<NodeID, Vec<BTreeSet<NodeID>>>,
) -> bool {
// Step 1: compute an interference graph from the liveness result. This
// graph contains a vertex per node ID producing a collection value and an
// edge per pair of node IDs that interfere. Nodes A and B interfere if node
// A is defined right above a point where node B is live and A != B. Extra
// edges are drawn for forwarding reads - when there is a node A that is a
// forwarding read of a node B, A and B really have the same live range for
// the purpose of determining when spills are necessary, since forwarding
// reads can be thought of as nothing but pointer math. For this purpose, we
// maintain a union-find of nodes that form a forwarding read DAG (notably,
// phis and reduces are not considered forwarding reads). The more precise
// version of the interference condition is: nodes A and B interfere if node
// A is defined right above a point where a node C is live, where C is in the
// same union-find class as B.
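// A toy version of this interference construction, with usizes standing in
// for node IDs: the value defined at each point interferes with everything
// live immediately below that point.
//
//     let defs = [0usize, 1, 2];                      // value defined per point
//     let live: [&[usize]; 3] = [&[2], &[0, 2], &[]]; // live sets below each point
//     let mut edges = Vec::new();
//     for (point, def) in defs.iter().enumerate() {
//         for l in live[point] {
//             if l != def {
//                 edges.push((*def, *l));
//             }
//         }
//     }
//     assert_eq!(edges, vec![(0usize, 2usize), (1, 0), (1, 2)]);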
// Assemble the union-find to group forwarding read DAGs.
let mut union_find = QuickFindUf::<UnionBySize>::new(editor.func().nodes.len());
for id in editor.node_ids() {
for forwarding_read in forwarding_reads(editor.func(), editor.func_id(), id, objects) {
union_find.union(id.idx(), forwarding_read.idx());
}
}
// Figure out which classes contain which node IDs, since we need to iterate
// the disjoint sets.
let mut disjoint_sets: BTreeMap<usize, Vec<NodeID>> = BTreeMap::new();
for id in editor.node_ids() {
disjoint_sets
.entry(union_find.find(id.idx()))
.or_default()
.push(id);
}
// Create the graph.
let mut edges = vec![];
for (bb, liveness) in liveness {
let insts = &bbs.1[bb.idx()];
for (node, live) in zip(insts, liveness.into_iter().skip(1)) {
for live_node in live {
for live_node in disjoint_sets[&union_find.find(live_node.idx())].iter() {
if *node != *live_node {
edges.push((*node, *live_node));
}
}
}
}
}
// Step 2: filter edges (A, B) to just see edges where A uses B and A
// mutates B. These are the edges that may require a spill.
let mut spill_edges = edges.into_iter().filter(|(a, b)| {
mutating_writes(editor.func(), *a, objects).any(|id| id == *b)
|| (get_uses(&editor.func().nodes[a.idx()])
.as_ref()
.into_iter()
.any(|u| *u == *b)
&& (editor.func().nodes[a.idx()].is_phi()
|| editor.func().nodes[a.idx()].is_reduce())
&& !editor.func().nodes[a.idx()]
.try_reduce()
.map(|(_, init, reduct)| {
(init == *b || reduct == *b)
&& editor.func().schedules[a.idx()].contains(&Schedule::ParallelReduce)
})
.unwrap_or(false)
&& !editor.func().nodes[a.idx()]
.try_phi()
.map(|(_, data)| {
data.contains(b)
&& editor.func().schedules[a.idx()].contains(&Schedule::ParallelReduce)
})
.unwrap_or(false))
});
// Step 3: if there is a spill edge, spill it and return true. Otherwise,
// return false.
if let Some((user, obj)) = spill_edges.next() {
// Figure out the most immediate dominating region for every basic
// block. These are the points where spill slot phis get placed.
let nodes = &editor.func().nodes;
let mut imm_dom_reg = vec![NodeID::new(0); editor.func().nodes.len()];
for (idx, node) in nodes.into_iter().enumerate() {
if node.is_region() {
imm_dom_reg[idx] = NodeID::new(idx);
}
}
let rev_po = control_subgraph.rev_po(NodeID::new(0));
for bb in rev_po.iter() {
if !nodes[bb.idx()].is_region() && !nodes[bb.idx()].is_start() {
imm_dom_reg[bb.idx()] =
imm_dom_reg[control_subgraph.preds(*bb).next().unwrap().idx()];
}
}
let other_obj_users: Vec<_> = editor.get_users(obj).filter(|id| *id != user).collect();
let mut dummy_phis = vec![NodeID::new(0); imm_dom_reg.len()];
let mut success = editor.edit(|mut edit| {
// Construct the spill slot. This is just a constant that gets phi-
// ed throughout the entire function.
let cons_id = edit.add_zero_constant(typing[obj.idx()]);
let slot_id = edit.add_node(Node::Constant { id: cons_id });
edit = edit.add_schedule(slot_id, Schedule::NoResetConstant)?;
// Allocate IDs for phis that move the spill slot throughout the
// function without implicit clones. These are dummy phis, since
// there are potentially cycles between them. We will replace them
// later.
for (idx, reg) in imm_dom_reg.iter().enumerate().skip(1) {
if idx == reg.idx() {
dummy_phis[idx] = edit.add_node(Node::Phi {
control: *reg,
data: empty().collect(),
});
}
}
// Spill `obj` before `user` potentially modifies it.
let spill_region = imm_dom_reg[bbs.0[obj.idx()].idx()];
let spill_id = edit.add_node(Node::Write {
collect: if spill_region == NodeID::new(0) {
slot_id
} else {
dummy_phis[spill_region.idx()]
},
data: obj,
indices: empty().collect(),
});
// Before each other user, unspill `obj`.
for other_user in other_obj_users {
let other_region = imm_dom_reg[bbs.0[other_user.idx()].idx()];
// If this assert fails, then `obj` is not in the first basic
// block, but it has a user that is in the first basic block,
// which violates SSA.
assert!(other_region == spill_region || other_region != NodeID::new(0));
// If an other user is a phi, we need to be a little careful
// about how we insert unspilling code for `obj`. Instead of
// inserting an unspill in the same block as the user, we need
// to insert one in each predecessor of the phi that corresponds
// to a use of `obj`. Since this requires modifying individual
// uses in a phi, just rebuild the node entirely.
if let Node::Phi { control, data } = edit.get_node(other_user).clone() {
assert_eq!(control, other_region);
let mut new_data = vec![];
for (pred, data) in zip(control_subgraph.preds(control), data) {
let pred = imm_dom_reg[pred.idx()];
if data == obj {
let unspill_id = edit.add_node(Node::Write {
collect: obj,
data: if pred == spill_region {
spill_id
} else {
dummy_phis[pred.idx()]
},
indices: empty().collect(),
});
new_data.push(unspill_id);
} else {
new_data.push(data);
}
}
let new_phi = edit.add_node(Node::Phi {
control,
data: new_data.into_boxed_slice(),
});
edit = edit.replace_all_uses(other_user, new_phi)?;
edit = edit.delete_node(other_user)?;
} else {
let unspill_id = edit.add_node(Node::Write {
collect: obj,
data: if other_region == spill_region {
spill_id
} else {
dummy_phis[other_region.idx()]
},
indices: empty().collect(),
});
edit = edit.replace_all_uses_where(obj, unspill_id, |id| *id == other_user)?;