diff --git a/unsafe/pom.xml b/unsafe/pom.xml
index 9e151fc7a9141a2435e502beff4114ddccba364f..2fd17267ac42727baa0c7e709e261af8eb266350 100644
--- a/unsafe/pom.xml
+++ b/unsafe/pom.xml
@@ -65,6 +65,11 @@
       <artifactId>junit-interface</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-all</artifactId>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
   <build>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 19d6a169fd2ad786a531b4de80c2178baf9b7ad8..bd4ca74cc7764dc269d7dd63cd566db4f83cabd1 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -23,6 +23,8 @@ import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.List;
 
+import com.google.common.annotations.VisibleForTesting;
+
 import org.apache.spark.unsafe.*;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.array.LongArray;
@@ -36,9 +38,8 @@ import org.apache.spark.unsafe.memory.*;
  * This is backed by a power-of-2-sized hash table, using quadratic probing with triangular numbers,
  * which is guaranteed to exhaust the space.
  * <p>
- * The map can support up to 2^31 keys because we use 32 bit MurmurHash. If the key cardinality is
- * higher than this, you should probably be using sorting instead of hashing for better cache
- * locality.
+ * The map can support up to 2^29 keys. If the key cardinality is higher than this, you should
+ * probably be using sorting instead of hashing for better cache locality.
  * <p>
  * This class is not thread safe.
  */
@@ -48,6 +49,11 @@ public final class BytesToBytesMap {
 
   private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING;
 
+  /**
+   * Special record length that is placed after the last record in a data page.
+   */
+  private static final int END_OF_PAGE_MARKER = -1;
+
   private final TaskMemoryManager memoryManager;
 
   /**
@@ -64,7 +70,7 @@ public final class BytesToBytesMap {
 
   /**
    * Offset into `currentDataPage` that points to the location where new data can be inserted into
-   * the page.
+   * the page. This does not incorporate the page's base offset.
    */
   private long pageCursor = 0;
 
@@ -74,6 +80,15 @@ public final class BytesToBytesMap {
    */
   private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes
 
+  /**
+   * The maximum number of keys that BytesToBytesMap supports. The hash table has to be
+   * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, since
+   * that's the largest power-of-2 that's less than Integer.MAX_VALUE. We need two long array
+   * entries per key, giving us a maximum capacity of (1 << 29).
+   */
+  @VisibleForTesting
+  static final int MAX_CAPACITY = (1 << 29);
+
   // This choice of page table size and page size means that we can address up to 500 gigabytes
   // of memory.
 
@@ -143,6 +158,13 @@ public final class BytesToBytesMap {
     this.loadFactor = loadFactor;
     this.loc = new Location();
     this.enablePerfMetrics = enablePerfMetrics;
+    if (initialCapacity <= 0) {
+      throw new IllegalArgumentException("Initial capacity must be greater than 0");
+    }
+    if (initialCapacity > MAX_CAPACITY) {
+      throw new IllegalArgumentException(
+        "Initial capacity " + initialCapacity + " exceeds maximum capacity of " + MAX_CAPACITY);
+    }
     allocate(initialCapacity);
   }
 
@@ -162,6 +184,55 @@ public final class BytesToBytesMap {
    */
   public int size() { return size; }
 
+  private static final class BytesToBytesMapIterator implements Iterator<Location> {
+
+    private final int numRecords;
+    private final Iterator<MemoryBlock> dataPagesIterator;
+    private final Location loc;
+
+    private int currentRecordNumber = 0;
+    private Object pageBaseObject;
+    private long offsetInPage;
+
+    BytesToBytesMapIterator(int numRecords, Iterator<MemoryBlock> dataPagesIterator, Location loc) {
+      this.numRecords = numRecords;
+      this.dataPagesIterator = dataPagesIterator;
+      this.loc = loc;
+      if (dataPagesIterator.hasNext()) {
+        advanceToNextPage();
+      }
+    }
+
+    private void advanceToNextPage() {
+      final MemoryBlock currentPage = dataPagesIterator.next();
+      pageBaseObject = currentPage.getBaseObject();
+      offsetInPage = currentPage.getBaseOffset();
+    }
+
+    @Override
+    public boolean hasNext() {
+      return currentRecordNumber != numRecords;
+    }
+
+    @Override
+    public Location next() {
+      int keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage);
+      if (keyLength == END_OF_PAGE_MARKER) {
+        advanceToNextPage();
+        keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage);
+      }
+      loc.with(pageBaseObject, offsetInPage);
+      offsetInPage += 8 + 8 + keyLength + loc.getValueLength();
+      currentRecordNumber++;
+      return loc;
+    }
+
+    @Override
+    public void remove() {
+      throw new UnsupportedOperationException();
+    }
+  }
+
   /**
    * Returns an iterator for iterating over the entries of this map.
    *
@@ -171,27 +242,7 @@ public final class BytesToBytesMap {
    * `lookup()`, the behavior of the returned iterator is undefined.
    */
   public Iterator<Location> iterator() {
-    return new Iterator<Location>() {
-
-      private int nextPos = bitset.nextSetBit(0);
-
-      @Override
-      public boolean hasNext() {
-        return nextPos != -1;
-      }
-
-      @Override
-      public Location next() {
-        final int pos = nextPos;
-        nextPos = bitset.nextSetBit(nextPos + 1);
-        return loc.with(pos, 0, true);
-      }
-
-      @Override
-      public void remove() {
-        throw new UnsupportedOperationException();
-      }
-    };
+    return new BytesToBytesMapIterator(size, dataPages.iterator(), loc);
   }
 
   /**
@@ -268,8 +319,11 @@ public final class BytesToBytesMap {
     private int valueLength;
 
     private void updateAddressesAndSizes(long fullKeyAddress) {
-        final Object page = memoryManager.getPage(fullKeyAddress);
-        final long keyOffsetInPage = memoryManager.getOffsetInPage(fullKeyAddress);
+      updateAddressesAndSizes(
+        memoryManager.getPage(fullKeyAddress), memoryManager.getOffsetInPage(fullKeyAddress));
+    }
+
+    private void updateAddressesAndSizes(Object page, long keyOffsetInPage) {
         long position = keyOffsetInPage;
         keyLength = (int) PlatformDependent.UNSAFE.getLong(page, position);
         position += 8; // word used to store the key size
@@ -291,6 +345,12 @@ public final class BytesToBytesMap {
       return this;
     }
 
+    Location with(Object page, long keyOffsetInPage) {
+      this.isDefined = true;
+      updateAddressesAndSizes(page, keyOffsetInPage);
+      return this;
+    }
+
     /**
      * Returns true if the key is defined at this position, and false otherwise.
      */
@@ -345,6 +405,8 @@ public final class BytesToBytesMap {
      * <p>
      * It is only valid to call this method immediately after calling `lookup()` using the same key.
      * <p>
+     * The key and value must be word-aligned (that is, their sizes must multiples of 8).
+     * <p>
      * After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length`
      * will return information on the data stored by this `putNewKey` call.
      * <p>
@@ -370,17 +432,27 @@ public final class BytesToBytesMap {
       isDefined = true;
       assert (keyLengthBytes % 8 == 0);
       assert (valueLengthBytes % 8 == 0);
+      if (size == MAX_CAPACITY) {
+        throw new IllegalStateException("BytesToBytesMap has reached maximum capacity");
+      }
       // Here, we'll copy the data into our data pages. Because we only store a relative offset from
       // the key address instead of storing the absolute address of the value, the key and value
       // must be stored in the same memory page.
       // (8 byte key length) (key) (8 byte value length) (value)
       final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes;
-      assert(requiredSize <= PAGE_SIZE_BYTES);
+      assert (requiredSize <= PAGE_SIZE_BYTES - 8); // Reserve 8 bytes for the end-of-page marker.
       size++;
       bitset.set(pos);
 
-      // If there's not enough space in the current page, allocate a new page:
-      if (currentDataPage == null || PAGE_SIZE_BYTES - pageCursor < requiredSize) {
+      // If there's not enough space in the current page, allocate a new page (8 bytes are reserved
+      // for the end-of-page marker).
+      if (currentDataPage == null || PAGE_SIZE_BYTES - 8 - pageCursor < requiredSize) {
+        if (currentDataPage != null) {
+          // There wasn't enough space in the current page, so write an end-of-page marker:
+          final Object pageBaseObject = currentDataPage.getBaseObject();
+          final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
+          PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
+        }
         MemoryBlock newPage = memoryManager.allocatePage(PAGE_SIZE_BYTES);
         dataPages.add(newPage);
         pageCursor = 0;
@@ -414,7 +486,7 @@ public final class BytesToBytesMap {
       longArray.set(pos * 2 + 1, keyHashcode);
       updateAddressesAndSizes(storedKeyAddress);
       isDefined = true;
-      if (size > growthThreshold) {
+      if (size > growthThreshold && longArray.size() < MAX_CAPACITY) {
         growAndRehash();
       }
     }
@@ -427,8 +499,11 @@ public final class BytesToBytesMap {
    * @param capacity the new map capacity
    */
   private void allocate(int capacity) {
-    capacity = Math.max((int) Math.min(Integer.MAX_VALUE, nextPowerOf2(capacity)), 64);
-    longArray = new LongArray(memoryManager.allocate(capacity * 8 * 2));
+    assert (capacity >= 0);
+    // The capacity needs to be divisible by 64 so that our bit set can be sized properly
+    capacity = Math.max((int) Math.min(MAX_CAPACITY, nextPowerOf2(capacity)), 64);
+    assert (capacity <= MAX_CAPACITY);
+    longArray = new LongArray(memoryManager.allocate(capacity * 8L * 2));
     bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64]));
 
     this.growthThreshold = (int) (capacity * loadFactor);
@@ -494,10 +569,16 @@ public final class BytesToBytesMap {
     return numHashCollisions;
   }
 
+  @VisibleForTesting
+  int getNumDataPages() {
+    return dataPages.size();
+  }
+
   /**
    * Grows the size of the hash table and re-hash everything.
    */
-  private void growAndRehash() {
+  @VisibleForTesting
+  void growAndRehash() {
     long resizeStartTime = -1;
     if (enablePerfMetrics) {
       resizeStartTime = System.nanoTime();
@@ -508,7 +589,7 @@ public final class BytesToBytesMap {
     final int oldCapacity = (int) oldBitSet.capacity();
 
     // Allocate the new data structures
-    allocate(Math.min(Integer.MAX_VALUE, growthStrategy.nextCapacity(oldCapacity)));
+    allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY));
 
     // Re-mask (we don't recompute the hashcode because we stored all 32 bits of it)
     for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) {
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
index 7c321baffe82d5cd31a23fa968e646b44b497c79..20654e4eeaa0288c4dc77e031b33e010ebde4e9a 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
@@ -32,7 +32,9 @@ public interface HashMapGrowthStrategy {
   class Doubling implements HashMapGrowthStrategy {
     @Override
     public int nextCapacity(int currentCapacity) {
-      return currentCapacity * 2;
+      assert (currentCapacity > 0);
+      // Guard against overflow
+      return (currentCapacity * 2 > 0) ? (currentCapacity * 2) : Integer.MAX_VALUE;
     }
   }
 
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
index 2906ac8abad1a00b34ce2aa02fe711aa306737ab..10881969dbc784d3cebafd52988efa125e33c338 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
@@ -44,7 +44,7 @@ import org.slf4j.LoggerFactory;
  * maximum size of a long[] array, allowing us to address 8192 * 2^32 * 8 bytes, which is
  * approximately 35 terabytes of memory.
  */
-public final class TaskMemoryManager {
+public class TaskMemoryManager {
 
   private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class);
 
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 7a5c0622d1ffbabd08cd61722548d319e98d1ba7..81315f7c946450557547fb34f712100a9b814259 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -25,24 +25,40 @@ import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import static org.mockito.AdditionalMatchers.geq;
+import static org.mockito.Mockito.*;
 
 import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.memory.*;
 import org.apache.spark.unsafe.PlatformDependent;
 import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET;
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-import org.apache.spark.unsafe.memory.MemoryLocation;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import static org.apache.spark.unsafe.PlatformDependent.LONG_ARRAY_OFFSET;
+
 
 public abstract class AbstractBytesToBytesMapSuite {
 
   private final Random rand = new Random(42);
 
   private TaskMemoryManager memoryManager;
+  private TaskMemoryManager sizeLimitedMemoryManager;
 
   @Before
   public void setup() {
     memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator()));
+    // Mocked memory manager for tests that check the maximum array size, since actually allocating
+    // such large arrays will cause us to run out of memory in our tests.
+    sizeLimitedMemoryManager = spy(memoryManager);
+    when(sizeLimitedMemoryManager.allocate(geq(1L << 20))).thenAnswer(new Answer<MemoryBlock>() {
+      @Override
+      public MemoryBlock answer(InvocationOnMock invocation) throws Throwable {
+        if (((Long) invocation.getArguments()[0] / 8) > Integer.MAX_VALUE) {
+          throw new OutOfMemoryError("Requested array size exceeds VM limit");
+        }
+        return memoryManager.allocate(1L << 20);
+      }
+    });
   }
 
   @After
@@ -101,6 +117,7 @@ public abstract class AbstractBytesToBytesMapSuite {
       final int keyLengthInBytes = keyLengthInWords * 8;
       final byte[] key = getRandomByteArray(keyLengthInWords);
       Assert.assertFalse(map.lookup(key, BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined());
+      Assert.assertFalse(map.iterator().hasNext());
     } finally {
       map.free();
     }
@@ -159,7 +176,7 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   @Test
   public void iteratorTest() throws Exception {
-    final int size = 128;
+    final int size = 4096;
     BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2);
     try {
       for (long i = 0; i < size; i++) {
@@ -167,14 +184,26 @@ public abstract class AbstractBytesToBytesMapSuite {
         final BytesToBytesMap.Location loc =
           map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8);
         Assert.assertFalse(loc.isDefined());
-        loc.putNewKey(
-          value,
-          PlatformDependent.LONG_ARRAY_OFFSET,
-          8,
-          value,
-          PlatformDependent.LONG_ARRAY_OFFSET,
-          8
-        );
+        // Ensure that we store some zero-length keys
+        if (i % 5 == 0) {
+          loc.putNewKey(
+            null,
+            PlatformDependent.LONG_ARRAY_OFFSET,
+            0,
+            value,
+            PlatformDependent.LONG_ARRAY_OFFSET,
+            8
+          );
+        } else {
+          loc.putNewKey(
+            value,
+            PlatformDependent.LONG_ARRAY_OFFSET,
+            8,
+            value,
+            PlatformDependent.LONG_ARRAY_OFFSET,
+            8
+          );
+        }
       }
       final java.util.BitSet valuesSeen = new java.util.BitSet(size);
       final Iterator<BytesToBytesMap.Location> iter = map.iterator();
@@ -183,11 +212,16 @@ public abstract class AbstractBytesToBytesMapSuite {
         Assert.assertTrue(loc.isDefined());
         final MemoryLocation keyAddress = loc.getKeyAddress();
         final MemoryLocation valueAddress = loc.getValueAddress();
-        final long key =  PlatformDependent.UNSAFE.getLong(
-          keyAddress.getBaseObject(), keyAddress.getBaseOffset());
         final long value = PlatformDependent.UNSAFE.getLong(
           valueAddress.getBaseObject(), valueAddress.getBaseOffset());
-        Assert.assertEquals(key, value);
+        final long keyLength = loc.getKeyLength();
+        if (keyLength == 0) {
+          Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0);
+        } else {
+        final long key = PlatformDependent.UNSAFE.getLong(
+          keyAddress.getBaseObject(), keyAddress.getBaseOffset());
+          Assert.assertEquals(value, key);
+        }
         valuesSeen.set((int) value);
       }
       Assert.assertEquals(size, valuesSeen.cardinality());
@@ -196,6 +230,74 @@ public abstract class AbstractBytesToBytesMapSuite {
     }
   }
 
+  @Test
+  public void iteratingOverDataPagesWithWastedSpace() throws Exception {
+    final int NUM_ENTRIES = 1000 * 1000;
+    final int KEY_LENGTH = 16;
+    final int VALUE_LENGTH = 40;
+    final BytesToBytesMap map = new BytesToBytesMap(memoryManager, NUM_ENTRIES);
+    // Each record will take 8 + 8 + 16 + 40 = 72 bytes of space in the data page. Our 64-megabyte
+    // pages won't be evenly-divisible by records of this size, which will cause us to waste some
+    // space at the end of the page. This is necessary in order for us to take the end-of-record
+    // handling branch in iterator().
+    try {
+      for (int i = 0; i < NUM_ENTRIES; i++) {
+        final long[] key = new long[] { i, i };  // 2 * 8 = 16 bytes
+        final long[] value = new long[] { i, i, i, i, i }; // 5 * 8 = 40 bytes
+        final BytesToBytesMap.Location loc = map.lookup(
+          key,
+          LONG_ARRAY_OFFSET,
+          KEY_LENGTH
+        );
+        Assert.assertFalse(loc.isDefined());
+        loc.putNewKey(
+          key,
+          LONG_ARRAY_OFFSET,
+          KEY_LENGTH,
+          value,
+          LONG_ARRAY_OFFSET,
+          VALUE_LENGTH
+        );
+      }
+      Assert.assertEquals(2, map.getNumDataPages());
+
+      final java.util.BitSet valuesSeen = new java.util.BitSet(NUM_ENTRIES);
+      final Iterator<BytesToBytesMap.Location> iter = map.iterator();
+      final long key[] = new long[KEY_LENGTH / 8];
+      final long value[] = new long[VALUE_LENGTH / 8];
+      while (iter.hasNext()) {
+        final BytesToBytesMap.Location loc = iter.next();
+        Assert.assertTrue(loc.isDefined());
+        Assert.assertEquals(KEY_LENGTH, loc.getKeyLength());
+        Assert.assertEquals(VALUE_LENGTH, loc.getValueLength());
+        PlatformDependent.copyMemory(
+          loc.getKeyAddress().getBaseObject(),
+          loc.getKeyAddress().getBaseOffset(),
+          key,
+          LONG_ARRAY_OFFSET,
+          KEY_LENGTH
+        );
+        PlatformDependent.copyMemory(
+          loc.getValueAddress().getBaseObject(),
+          loc.getValueAddress().getBaseOffset(),
+          value,
+          LONG_ARRAY_OFFSET,
+          VALUE_LENGTH
+        );
+        for (long j : key) {
+          Assert.assertEquals(key[0], j);
+        }
+        for (long j : value) {
+          Assert.assertEquals(key[0], j);
+        }
+        valuesSeen.set((int) key[0]);
+      }
+      Assert.assertEquals(NUM_ENTRIES, valuesSeen.cardinality());
+    } finally {
+      map.free();
+    }
+  }
+
   @Test
   public void randomizedStressTest() {
     final int size = 65536;
@@ -247,4 +349,35 @@ public abstract class AbstractBytesToBytesMapSuite {
       map.free();
     }
   }
+
+  @Test
+  public void initialCapacityBoundsChecking() {
+    try {
+      new BytesToBytesMap(sizeLimitedMemoryManager, 0);
+      Assert.fail("Expected IllegalArgumentException to be thrown");
+    } catch (IllegalArgumentException e) {
+      // expected exception
+    }
+
+    try {
+      new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY + 1);
+      Assert.fail("Expected IllegalArgumentException to be thrown");
+    } catch (IllegalArgumentException e) {
+      // expected exception
+    }
+
+   // Can allocate _at_ the max capacity
+    BytesToBytesMap map =
+      new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY);
+    map.free();
+  }
+
+  @Test
+  public void resizingLargeMap() {
+    // As long as a map's capacity is below the max, we should be able to resize up to the max
+    BytesToBytesMap map =
+      new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY - 64);
+    map.growAndRehash();
+    map.free();
+  }
 }