diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 5d10a1f84493cd3e852f1e6be3e46e57874b1258..1f7d2dc838ebce62173f6b80ab8f55f2735018fb 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -286,30 +286,32 @@ class ExternalAppendOnlyMap[K, V, C](
     private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)
 
     inputStreams.foreach { it =>
-      val kcPairs = getMorePairs(it)
+      val kcPairs = new ArrayBuffer[(K, C)]
+      readNextHashCode(it, kcPairs)
       if (kcPairs.length > 0) {
         mergeHeap.enqueue(new StreamBuffer(it, kcPairs))
       }
     }
 
     /**
-     * Fetch from the given iterator until a key of different hash is retrieved.
+     * Fill a buffer with the next set of keys with the same hash code from a given iterator. We
+     * read streams one hash code at a time to ensure we don't miss elements when they are merged.
+     *
+     * Assumes the given iterator is in sorted order of hash code.
      *
-     * In the event of key hash collisions, this ensures no pairs are hidden from being merged.
-     * Assume the given iterator is in sorted order.
+     * @param it iterator to read from
+     * @param buf buffer to write the results into
      */
-    private def getMorePairs(it: BufferedIterator[(K, C)]): ArrayBuffer[(K, C)] = {
-      val kcPairs = new ArrayBuffer[(K, C)]
+    private def readNextHashCode(it: BufferedIterator[(K, C)], buf: ArrayBuffer[(K, C)]): Unit = {
       if (it.hasNext) {
         var kc = it.next()
-        kcPairs += kc
+        buf += kc
         val minHash = hashKey(kc)
         while (it.hasNext && it.head._1.hashCode() == minHash) {
           kc = it.next()
-          kcPairs += kc
+          buf += kc
         }
       }
-      kcPairs
     }
 
     /**
@@ -321,7 +323,9 @@ class ExternalAppendOnlyMap[K, V, C](
       while (i < buffer.pairs.length) {
         val pair = buffer.pairs(i)
         if (pair._1 == key) {
-          buffer.pairs.remove(i)
+          // Note that there's at most one pair in the buffer with a given key, since we always
+          // merge stuff in a map before spilling, so it's safe to return after the first we find
+          removeFromBuffer(buffer.pairs, i)
           return mergeCombiners(baseCombiner, pair._2)
         }
         i += 1
@@ -329,6 +333,19 @@ class ExternalAppendOnlyMap[K, V, C](
       baseCombiner
     }
 
+    /**
+     * Remove the index'th element from an ArrayBuffer in constant time, swapping another element
+     * into its place. This is more efficient than the ArrayBuffer.remove method because it does
+     * not have to shift all the elements in the array over. It works for our array buffers because
+     * we don't care about the order of elements inside, we just want to search them for a key.
+     */
+    private def removeFromBuffer[T](buffer: ArrayBuffer[T], index: Int): T = {
+      val elem = buffer(index)
+      buffer(index) = buffer(buffer.size - 1)  // This also works if index == buffer.size - 1
+      buffer.reduceToSize(buffer.size - 1)
+      elem
+    }
+
     /**
      * Return true if there exists an input stream that still has unvisited pairs.
      */
@@ -346,7 +363,7 @@ class ExternalAppendOnlyMap[K, V, C](
       val minBuffer = mergeHeap.dequeue()
       val minPairs = minBuffer.pairs
       val minHash = minBuffer.minKeyHash
-      val minPair = minPairs.remove(0)
+      val minPair = removeFromBuffer(minPairs, 0)
       val minKey = minPair._1
       var minCombiner = minPair._2
       assert(hashKey(minPair) == minHash)
@@ -363,7 +380,7 @@ class ExternalAppendOnlyMap[K, V, C](
       // Repopulate each visited stream buffer and add it back to the queue if it is non-empty
       mergedBuffers.foreach { buffer =>
         if (buffer.isEmpty) {
-          buffer.pairs ++= getMorePairs(buffer.iterator)
+          readNextHashCode(buffer.iterator, buffer.pairs)
         }
         if (!buffer.isEmpty) {
           mergeHeap.enqueue(buffer)
@@ -375,10 +392,13 @@ class ExternalAppendOnlyMap[K, V, C](
 
     /**
      * A buffer for streaming from a map iterator (in-memory or on-disk) sorted by key hash.
-     * Each buffer maintains the lowest-ordered keys in the corresponding iterator. Due to
-     * hash collisions, it is possible for multiple keys to be "tied" for being the lowest.
+     * Each buffer maintains all of the key-value pairs with what is currently the lowest hash
+     * code among keys in the stream. There may be multiple keys if there are hash collisions.
+     * Note that because when we spill data out, we only spill one value for each key, there is
+     * at most one element for each key.
      *
-     * StreamBuffers are ordered by the minimum key hash found across all of their own pairs.
+     * StreamBuffers are ordered by the minimum key hash currently available in their stream so
+     * that we can put them into a heap and sort that.
      */
     private class StreamBuffer(
         val iterator: BufferedIterator[(K, C)],