diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
index 8542541fe6d6950a98220e65790717e0a6f24edd..8bb4ee3bfa22e3ad1233d778e59feedc892115f7 100644
--- a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
@@ -35,6 +35,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
   private var capacity = nextPowerOf2(initialCapacity)
   private var mask = capacity - 1
   private var curSize = 0
+  private var growThreshold = LOAD_FACTOR * capacity
 
   // Holds keys and values in the same array for memory locality; specifically, the order of
   // elements is key0, value0, key1, value1, key2, value2, etc.
@@ -80,9 +81,23 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
       haveNullValue = true
       return
     }
-    val isNewEntry = putInto(data, k, value.asInstanceOf[AnyRef])
-    if (isNewEntry) {
-      incrementSize()
+    var pos = rehash(key.hashCode) & mask
+    var i = 1
+    while (true) {
+      val curKey = data(2 * pos)
+      if (curKey.eq(null)) {
+        data(2 * pos) = k
+        data(2 * pos + 1) = value.asInstanceOf[AnyRef]
+        incrementSize()  // Since we added a new key
+        return
+      } else if (k.eq(curKey) || k.equals(curKey)) {
+        data(2 * pos + 1) = value.asInstanceOf[AnyRef]
+        return
+      } else {
+        val delta = i
+        pos = (pos + delta) & mask
+        i += 1
+      }
     }
   }
 
@@ -161,7 +176,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
   /** Increase table size by 1, rehashing if necessary */
   private def incrementSize() {
     curSize += 1
-    if (curSize > LOAD_FACTOR * capacity) {
+    if (curSize > growThreshold) {
       growTable()
     }
   }
@@ -174,33 +189,6 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
     it.unimi.dsi.fastutil.HashCommon.murmurHash3(h)
   }
 
-  /**
-   * Put an entry into a table represented by data, returning true if
-   * this increases the size of the table or false otherwise. Assumes
-   * that "data" has at least one empty slot.
-   */
-  private def putInto(data: Array[AnyRef], key: AnyRef, value: AnyRef): Boolean = {
-    val mask = (data.length / 2) - 1
-    var pos = rehash(key.hashCode) & mask
-    var i = 1
-    while (true) {
-      val curKey = data(2 * pos)
-      if (curKey.eq(null)) {
-        data(2 * pos) = key
-        data(2 * pos + 1) = value.asInstanceOf[AnyRef]
-        return true
-      } else if (curKey.eq(key) || curKey.equals(key)) {
-        data(2 * pos + 1) = value.asInstanceOf[AnyRef]
-        return false
-      } else {
-        val delta = i
-        pos = (pos + delta) & mask
-        i += 1
-      }
-    }
-    return false  // Never reached but needed to keep compiler happy
-  }
-
   /** Double the table's size and re-hash everything */
   private def growTable() {
     val newCapacity = capacity * 2
@@ -210,16 +198,36 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
       throw new Exception("Can't make capacity bigger than 2^29 elements")
     }
     val newData = new Array[AnyRef](2 * newCapacity)
-    var pos = 0
-    while (pos < capacity) {
-      if (!data(2 * pos).eq(null)) {
-        putInto(newData, data(2 * pos), data(2 * pos + 1))
+    val newMask = newCapacity - 1
+    // Insert all our old values into the new array. Note that because our old keys are
+    // unique, there's no need to check for equality here when we insert.
+    var oldPos = 0
+    while (oldPos < capacity) {
+      if (!data(2 * oldPos).eq(null)) {
+        val key = data(2 * oldPos)
+        val value = data(2 * oldPos + 1)
+        var newPos = rehash(key.hashCode) & newMask
+        var i = 1
+        var keepGoing = true
+        while (keepGoing) {
+          val curKey = newData(2 * newPos)
+          if (curKey.eq(null)) {
+            newData(2 * newPos) = key
+            newData(2 * newPos + 1) = value
+            keepGoing = false
+          } else {
+            val delta = i
+            newPos = (newPos + delta) & newMask
+            i += 1
+          }
+        }
       }
-      pos += 1
+      oldPos += 1
     }
     data = newData
     capacity = newCapacity
-    mask = newCapacity - 1
+    mask = newMask
+    growThreshold = LOAD_FACTOR * newCapacity
   }
 
   private def nextPowerOf2(n: Int): Int = {