diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 5d10a1f84493cd3e852f1e6be3e46e57874b1258..1f7d2dc838ebce62173f6b80ab8f55f2735018fb 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -286,30 +286,32 @@ class ExternalAppendOnlyMap[K, V, C]( private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered) inputStreams.foreach { it => - val kcPairs = getMorePairs(it) + val kcPairs = new ArrayBuffer[(K, C)] + readNextHashCode(it, kcPairs) if (kcPairs.length > 0) { mergeHeap.enqueue(new StreamBuffer(it, kcPairs)) } } /** - * Fetch from the given iterator until a key of different hash is retrieved. + * Fill a buffer with the next set of keys with the same hash code from a given iterator. We + * read streams one hash code at a time to ensure we don't miss elements when they are merged. + * + * Assumes the given iterator is in sorted order of hash code. * - * In the event of key hash collisions, this ensures no pairs are hidden from being merged. - * Assume the given iterator is in sorted order. + * @param it iterator to read from + * @param buf buffer to write the results into */ - private def getMorePairs(it: BufferedIterator[(K, C)]): ArrayBuffer[(K, C)] = { - val kcPairs = new ArrayBuffer[(K, C)] + private def readNextHashCode(it: BufferedIterator[(K, C)], buf: ArrayBuffer[(K, C)]): Unit = { if (it.hasNext) { var kc = it.next() - kcPairs += kc + buf += kc val minHash = hashKey(kc) while (it.hasNext && it.head._1.hashCode() == minHash) { kc = it.next() - kcPairs += kc + buf += kc } } - kcPairs } /** @@ -321,7 +323,9 @@ class ExternalAppendOnlyMap[K, V, C]( while (i < buffer.pairs.length) { val pair = buffer.pairs(i) if (pair._1 == key) { - buffer.pairs.remove(i) + // Note that there's at most one pair in the buffer with a given key, since we always + // merge stuff in a map before spilling, so it's safe to return after the first we find + removeFromBuffer(buffer.pairs, i) return mergeCombiners(baseCombiner, pair._2) } i += 1 @@ -329,6 +333,19 @@ class ExternalAppendOnlyMap[K, V, C]( baseCombiner } + /** + * Remove the index'th element from an ArrayBuffer in constant time, swapping another element + * into its place. This is more efficient than the ArrayBuffer.remove method because it does + * not have to shift all the elements in the array over. It works for our array buffers because + * we don't care about the order of elements inside, we just want to search them for a key. + */ + private def removeFromBuffer[T](buffer: ArrayBuffer[T], index: Int): T = { + val elem = buffer(index) + buffer(index) = buffer(buffer.size - 1) // This also works if index == buffer.size - 1 + buffer.reduceToSize(buffer.size - 1) + elem + } + /** * Return true if there exists an input stream that still has unvisited pairs. */ @@ -346,7 +363,7 @@ class ExternalAppendOnlyMap[K, V, C]( val minBuffer = mergeHeap.dequeue() val minPairs = minBuffer.pairs val minHash = minBuffer.minKeyHash - val minPair = minPairs.remove(0) + val minPair = removeFromBuffer(minPairs, 0) val minKey = minPair._1 var minCombiner = minPair._2 assert(hashKey(minPair) == minHash) @@ -363,7 +380,7 @@ class ExternalAppendOnlyMap[K, V, C]( // Repopulate each visited stream buffer and add it back to the queue if it is non-empty mergedBuffers.foreach { buffer => if (buffer.isEmpty) { - buffer.pairs ++= getMorePairs(buffer.iterator) + readNextHashCode(buffer.iterator, buffer.pairs) } if (!buffer.isEmpty) { mergeHeap.enqueue(buffer) @@ -375,10 +392,13 @@ class ExternalAppendOnlyMap[K, V, C]( /** * A buffer for streaming from a map iterator (in-memory or on-disk) sorted by key hash. - * Each buffer maintains the lowest-ordered keys in the corresponding iterator. Due to - * hash collisions, it is possible for multiple keys to be "tied" for being the lowest. + * Each buffer maintains all of the key-value pairs with what is currently the lowest hash + * code among keys in the stream. There may be multiple keys if there are hash collisions. + * Note that because when we spill data out, we only spill one value for each key, there is + * at most one element for each key. * - * StreamBuffers are ordered by the minimum key hash found across all of their own pairs. + * StreamBuffers are ordered by the minimum key hash currently available in their stream so + * that we can put them into a heap and sort that. */ private class StreamBuffer( val iterator: BufferedIterator[(K, C)],