diff --git a/core/pom.xml b/core/pom.xml
index 262a3320db106e7ebc57e7c0fa5c5f22873a1397..bfa49d0d6dc2533228db91f0f448788ab3373a23 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -361,6 +361,16 @@
       <artifactId>junit</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.hamcrest</groupId>
+      <artifactId>hamcrest-core</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.hamcrest</groupId>
+      <artifactId>hamcrest-library</artifactId>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>com.novocode</groupId>
       <artifactId>junit-interface</artifactId>
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java
new file mode 100644
index 0000000000000000000000000000000000000000..3f746b886bc9b0c3dde792a50e1a5fcd2c70d007
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.nio.ByteBuffer;
+
+import scala.reflect.ClassTag;
+
+import org.apache.spark.serializer.DeserializationStream;
+import org.apache.spark.serializer.SerializationStream;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
+ * Our shuffle write path doesn't actually use this serializer (since we end up calling the
+ * `write()` OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
+ * around this, we pass a dummy no-op serializer.
+ */
+final class DummySerializerInstance extends SerializerInstance {
+
+  public static final DummySerializerInstance INSTANCE = new DummySerializerInstance();
+
+  private DummySerializerInstance() { }
+
+  @Override
+  public SerializationStream serializeStream(final OutputStream s) {
+    return new SerializationStream() {
+      @Override
+      public void flush() {
+        // Need to implement this because DiskBlockObjectWriter uses it to flush the compression stream
+        try {
+          s.flush();
+        } catch (IOException e) {
+          PlatformDependent.throwException(e);
+        }
+      }
+
+      @Override
+      public <T> SerializationStream writeObject(T t, ClassTag<T> ev1) {
+        throw new UnsupportedOperationException();
+      }
+
+      @Override
+      public void close() {
+        // Need to implement this because DiskBlockObjectWriter uses it to close the compression stream
+        try {
+          s.close();
+        } catch (IOException e) {
+          PlatformDependent.throwException(e);
+        }
+      }
+    };
+  }
+
+  @Override
+  public <T> ByteBuffer serialize(T t, ClassTag<T> ev1) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public DeserializationStream deserializeStream(InputStream s) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public <T> T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag<T> ev1) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public <T> T deserialize(ByteBuffer bytes, ClassTag<T> ev1) {
+    throw new UnsupportedOperationException();
+  }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
new file mode 100644
index 0000000000000000000000000000000000000000..4ee6a82c0423ebadb0dbad8fa5cb990cb2baf8b8
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+/**
+ * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer.
+ * <p>
+ * Within the long, the data is laid out as follows:
+ * <pre>
+ *   [24 bit partition number][13 bit memory page number][27 bit offset in page]
+ * </pre>
+ * This implies that the maximum addressable page size is 2^27 bytes = 128 megabytes, assuming that
+ * our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the
+ * 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this
+ * implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task.
+ * <p>
+ * Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this
+ * optimization to future work as it will require more careful design to ensure that addresses are
+ * properly aligned (e.g. by padding records).
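+ * <p>
+ * For example, packing partition id 5 with a record pointer whose page number is 3 and page
+ * offset is 64 yields the long {@code (5L << 40) | (3L << 27) | 64}; unpacking that value
+ * recovers partition id 5 and a record pointer with the same page number and offset.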
+ */
+final class PackedRecordPointer {
+
+  static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27;  // 128 megabytes
+
+  /**
+   * The maximum partition identifier that can be encoded. Note that partition ids start from 0.
+   */
+  static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1;  // 16777215
+
+  /** Bit mask for the lower 40 bits of a long. */
+  private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1;
+
+  /** Bit mask for the upper 24 bits of a long. */
+  private static final long MASK_LONG_UPPER_24_BITS = ~MASK_LONG_LOWER_40_BITS;
+
+  /** Bit mask for the lower 27 bits of a long. */
+  private static final long MASK_LONG_LOWER_27_BITS = (1L << 27) - 1;
+
+  /** Bit mask for the lower 51 bits of a long. */
+  private static final long MASK_LONG_LOWER_51_BITS = (1L << 51) - 1;
+
+  /** Bit mask for the upper 13 bits of a long. */
+  private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS;
+
+  /**
+   * Pack a record address and partition id into a single word.
+   *
+   * @param recordPointer a record pointer encoded by TaskMemoryManager.
+   * @param partitionId a shuffle partition id (maximum value of 2^24 - 1).
+   * @return a packed pointer that can be decoded using the {@link PackedRecordPointer} class.
+   */
+  public static long packPointer(long recordPointer, int partitionId) {
+    assert (partitionId <= MAXIMUM_PARTITION_ID);
+    // Note that without word alignment we can address 2^27 bytes = 128 megabytes per page.
+    // Also note that this relies on some internals of how TaskMemoryManager encodes its addresses.
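+    // The 13-bit page number occupies bits [51, 64) of the TaskMemoryManager-encoded address;
+    // shifting it right by 24 moves it into bits [27, 40), directly above the 27-bit page offset.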
+    final long pageNumber = (recordPointer & MASK_LONG_UPPER_13_BITS) >>> 24;
+    final long compressedAddress = pageNumber | (recordPointer & MASK_LONG_LOWER_27_BITS);
+    return (((long) partitionId) << 40) | compressedAddress;
+  }
+
+  private long packedRecordPointer;
+
+  public void set(long packedRecordPointer) {
+    this.packedRecordPointer = packedRecordPointer;
+  }
+
+  public int getPartitionId() {
+    return (int) ((packedRecordPointer & MASK_LONG_UPPER_24_BITS) >>> 40);
+  }
+
+  public long getRecordPointer() {
+    final long pageNumber = (packedRecordPointer << 24) & MASK_LONG_UPPER_13_BITS;
+    final long offsetInPage = packedRecordPointer & MASK_LONG_LOWER_27_BITS;
+    return pageNumber | offsetInPage;
+  }
+
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
new file mode 100644
index 0000000000000000000000000000000000000000..7bac0dc0bbeb6da15eb1cc6f7718fe6af08986a3
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.File;
+
+import org.apache.spark.storage.TempShuffleBlockId;
+
+/**
+ * Metadata for a block of data written by {@link UnsafeShuffleExternalSorter}.
+ */
+final class SpillInfo {
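+  /** The number of bytes written for each partition in {@link #file}. */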
+  final long[] partitionLengths;
+  final File file;
+  final TempShuffleBlockId blockId;
+
+  public SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) {
+    this.partitionLengths = new long[numPartitions];
+    this.file = file;
+    this.blockId = blockId;
+  }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
new file mode 100644
index 0000000000000000000000000000000000000000..9e9ed94b7890cd5f7acb7ad1a60ca7b045c011a6
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -0,0 +1,422 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.LinkedList;
+
+import scala.Tuple2;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.*;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+/**
+ * An external sorter that is specialized for sort-based shuffle.
+ * <p>
+ * Incoming records are appended to data pages. When all records have been inserted (or when the
+ * current thread's shuffle memory limit is reached), the in-memory records are sorted according to
+ * their partition ids (using an {@link UnsafeShuffleInMemorySorter}). The sorted records are then
+ * written to a single output file (or multiple files, if we've spilled). The format of the output
+ * files is the same as the format of the final output file written by
+ * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are
+ * written as a single serialized, compressed stream that can be read with a new decompression and
+ * deserialization stream.
+ * <p>
+ * Unlike {@link org.apache.spark.util.collection.ExternalSorter}, this sorter does not merge its
+ * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a
+ * specialized merge procedure that avoids extra serialization/deserialization.
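+ * <p>
+ * A typical lifecycle, as driven by {@link UnsafeShuffleWriter}, is: construct the sorter, call
+ * {@link #insertRecord} once per serialized record (which may trigger spills when the task's
+ * shuffle memory is exhausted), and finally call {@link #closeAndGetSpills()}, which writes any
+ * remaining in-memory records and returns metadata for every file produced.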
+ */
+final class UnsafeShuffleExternalSorter {
+
+  private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class);
+
+  private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES;
+  @VisibleForTesting
+  static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
+  @VisibleForTesting
+  static final int MAX_RECORD_SIZE = PAGE_SIZE - 4;
+
+  private final int initialSize;
+  private final int numPartitions;
+  private final TaskMemoryManager memoryManager;
+  private final ShuffleMemoryManager shuffleMemoryManager;
+  private final BlockManager blockManager;
+  private final TaskContext taskContext;
+  private final ShuffleWriteMetrics writeMetrics;
+
+  /** The buffer size to use when writing spills using DiskBlockObjectWriter */
+  private final int fileBufferSizeBytes;
+
+  /**
+   * Memory pages that hold the records being sorted. The pages in this list are freed when
+   * spilling, although in principle we could recycle these pages across spills (on the other hand,
+   * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager
+   * itself).
+   */
+  private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>();
+
+  private final LinkedList<SpillInfo> spills = new LinkedList<SpillInfo>();
+
+  // These variables are reset after spilling:
+  private UnsafeShuffleInMemorySorter sorter;
+  private MemoryBlock currentPage = null;
+  private long currentPagePosition = -1;
+  private long freeSpaceInCurrentPage = 0;
+
+  public UnsafeShuffleExternalSorter(
+      TaskMemoryManager memoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
+      BlockManager blockManager,
+      TaskContext taskContext,
+      int initialSize,
+      int numPartitions,
+      SparkConf conf,
+      ShuffleWriteMetrics writeMetrics) throws IOException {
+    this.memoryManager = memoryManager;
+    this.shuffleMemoryManager = shuffleMemoryManager;
+    this.blockManager = blockManager;
+    this.taskContext = taskContext;
+    this.initialSize = initialSize;
+    this.numPartitions = numPartitions;
+    // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
+    this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+
+    this.writeMetrics = writeMetrics;
+    initializeForWriting();
+  }
+
+  /**
+   * Allocates new sort data structures. Called when creating the sorter and after each spill.
+   */
+  private void initializeForWriting() throws IOException {
+    // TODO: move this sizing calculation logic into a static method of sorter:
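+    // The in-memory sorter stores one 8-byte packed record pointer per record, so reserve
+    // initialSize * 8 bytes from the shuffle memory manager before allocating the pointer array.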
+    final long memoryRequested = initialSize * 8L;
+    final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested);
+    if (memoryAcquired != memoryRequested) {
+      shuffleMemoryManager.release(memoryAcquired);
+      throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
+    }
+
+    this.sorter = new UnsafeShuffleInMemorySorter(initialSize);
+  }
+
+  /**
+   * Sorts the in-memory records and writes the sorted records to an on-disk file.
+   * This method does not free the sort data structures.
+   *
+   * @param isLastFile if true, this indicates that we're writing the final output file and that the
+   *                   bytes written should be counted towards shuffle write metrics rather than
+   *                   shuffle spill metrics.
+   */
+  private void writeSortedFile(boolean isLastFile) throws IOException {
+
+    final ShuffleWriteMetrics writeMetricsToUse;
+
+    if (isLastFile) {
+      // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes.
+      writeMetricsToUse = writeMetrics;
+    } else {
+      // We're spilling, so bytes written should be counted towards spill rather than write.
+      // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count
+      // them towards shuffle bytes written.
+      writeMetricsToUse = new ShuffleWriteMetrics();
+    }
+
+    // This call performs the actual sort.
+    final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords =
+      sorter.getSortedIterator();
+
+    // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
+    // after SPARK-5581 is fixed.
+    BlockObjectWriter writer;
+
+    // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
+    // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
+    // data through a byte array. This array does not need to be large enough to hold a single
+    // record; records larger than the buffer are copied to the disk writer in multiple chunks.
+    final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];
+
+    // Because this output will be read during shuffle, its compression codec must be controlled by
+    // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
+    // createTempShuffleBlock here; see SPARK-3426 for more details.
+    final Tuple2<TempShuffleBlockId, File> spilledFileInfo =
+      blockManager.diskBlockManager().createTempShuffleBlock();
+    final File file = spilledFileInfo._2();
+    final TempShuffleBlockId blockId = spilledFileInfo._1();
+    final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId);
+
+    // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
+    // Our write path doesn't actually use this serializer (since we end up calling the `write()`
+    // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
+    // around this, we pass a dummy no-op serializer.
+    final SerializerInstance ser = DummySerializerInstance.INSTANCE;
+
+    writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
+
+    int currentPartition = -1;
+    while (sortedRecords.hasNext()) {
+      sortedRecords.loadNext();
+      final int partition = sortedRecords.packedRecordPointer.getPartitionId();
+      assert (partition >= currentPartition);
+      if (partition != currentPartition) {
+        // Switch to the new partition
+        if (currentPartition != -1) {
+          writer.commitAndClose();
+          spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
+        }
+        currentPartition = partition;
+        writer =
+          blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
+      }
+
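+      // Each record was written into its data page as a 4-byte length header followed by the
+      // record's bytes, so read the length first and then stream the payload to the disk writer
+      // in DISK_WRITE_BUFFER_SIZE-sized chunks.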
+      final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
+      final Object recordPage = memoryManager.getPage(recordPointer);
+      final long recordOffsetInPage = memoryManager.getOffsetInPage(recordPointer);
+      int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage);
+      long recordReadPosition = recordOffsetInPage + 4; // skip over record length
+      while (dataRemaining > 0) {
+        final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining);
+        PlatformDependent.copyMemory(
+          recordPage,
+          recordReadPosition,
+          writeBuffer,
+          PlatformDependent.BYTE_ARRAY_OFFSET,
+          toTransfer);
+        writer.write(writeBuffer, 0, toTransfer);
+        recordReadPosition += toTransfer;
+        dataRemaining -= toTransfer;
+      }
+      writer.recordWritten();
+    }
+
+    if (writer != null) {
+      writer.commitAndClose();
+      // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted,
+      // then the file might be empty. Note that it might be better to avoid calling
+      // writeSortedFile() in that case.
+      if (currentPartition != -1) {
+        spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
+        spills.add(spillInfo);
+      }
+    }
+
+    if (!isLastFile) {  // i.e. this is a spill file
+      // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records
+      // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter
+      // relies on its `recordWritten()` method being called in order to trigger periodic updates to
+      // `shuffleBytesWritten`. If we were to remove the `recordWritten()` call and increment that
+      // counter at a higher-level, then the in-progress metrics for records written and bytes
+      // written would get out of sync.
+      //
+      // When writing the last file, we pass `writeMetrics` directly to the DiskBlockObjectWriter;
+      // in all other cases, we pass in a dummy ShuffleWriteMetrics instance to capture the metrics,
+      // then copy those metrics to the true write metrics here. We perform this copying so that
+      // spilled bytes are not reported as shuffle write bytes.
+      //
+      // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`.
+      // Consistent with ExternalSorter, we do not count this IO towards shuffle write time.
+      // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this.
+      writeMetrics.incShuffleRecordsWritten(writeMetricsToUse.shuffleRecordsWritten());
+      taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.shuffleBytesWritten());
+    }
+  }
+
+  /**
+   * Sort and spill the current records in response to memory pressure.
+   */
+  @VisibleForTesting
+  void spill() throws IOException {
+    logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
+      Thread.currentThread().getId(),
+      Utils.bytesToString(getMemoryUsage()),
+      spills.size(),
+      spills.size() > 1 ? "times" : "time");
+
+    writeSortedFile(false);
+    final long sorterMemoryUsage = sorter.getMemoryUsage();
+    sorter = null;
+    shuffleMemoryManager.release(sorterMemoryUsage);
+    final long spillSize = freeMemory();
+    taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
+
+    initializeForWriting();
+  }
+
+  private long getMemoryUsage() {
+    return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE);
+  }
+
+  private long freeMemory() {
+    long memoryFreed = 0;
+    for (MemoryBlock block : allocatedPages) {
+      memoryManager.freePage(block);
+      shuffleMemoryManager.release(block.size());
+      memoryFreed += block.size();
+    }
+    allocatedPages.clear();
+    currentPage = null;
+    currentPagePosition = -1;
+    freeSpaceInCurrentPage = 0;
+    return memoryFreed;
+  }
+
+  /**
+   * Force all memory and spill files to be deleted; called by shuffle error-handling code.
+   */
+  public void cleanupAfterError() {
+    freeMemory();
+    for (SpillInfo spill : spills) {
+      if (spill.file.exists() && !spill.file.delete()) {
+        logger.error("Unable to delete spill file {}", spill.file.getPath());
+      }
+    }
+    if (sorter != null) {
+      shuffleMemoryManager.release(sorter.getMemoryUsage());
+      sorter = null;
+    }
+  }
+
+  /**
+   * Checks whether there is enough space to insert a new record into the sorter.
+   *
+   * @param requiredSpace the required space in the data page, in bytes, including space for storing
+   *                      the record size.
+   * @return true if the record can be inserted without requiring more allocations, false otherwise.
+   */
+  private boolean haveSpaceForRecord(int requiredSpace) {
+    assert (requiredSpace > 0);
+    return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
+  }
+
+  /**
+   * Allocates more memory in order to insert an additional record. This will request additional
+   * memory from the {@link ShuffleMemoryManager} and spill if the requested memory cannot be
+   * obtained.
+   *
+   * @param requiredSpace the required space in the data page, in bytes, including space for storing
+   *                      the record size.
+   */
+  private void allocateSpaceForRecord(int requiredSpace) throws IOException {
+    if (!sorter.hasSpaceForAnotherRecord()) {
+      logger.debug("Attempting to expand sort pointer array");
+      final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage();
+      final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
+      final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
+      if (memoryAcquired < memoryToGrowPointerArray) {
+        shuffleMemoryManager.release(memoryAcquired);
+        spill();
+      } else {
+        sorter.expandPointerArray();
+        shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
+      }
+    }
+    if (requiredSpace > freeSpaceInCurrentPage) {
+      logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
+        freeSpaceInCurrentPage);
+      // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
+      // without using the free space at the end of the current page. We should also do this for
+      // BytesToBytesMap.
+      if (requiredSpace > PAGE_SIZE) {
+        throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
+          PAGE_SIZE + ")");
+      } else {
+        final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
+        if (memoryAcquired < PAGE_SIZE) {
+          shuffleMemoryManager.release(memoryAcquired);
+          spill();
+          final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
+          if (memoryAcquiredAfterSpilling != PAGE_SIZE) {
+            shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
+            throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory");
+          }
+        }
+        currentPage = memoryManager.allocatePage(PAGE_SIZE);
+        currentPagePosition = currentPage.getBaseOffset();
+        freeSpaceInCurrentPage = PAGE_SIZE;
+        allocatedPages.add(currentPage);
+      }
+    }
+  }
+
+  /**
+   * Write a record to the shuffle sorter.
+   */
+  public void insertRecord(
+      Object recordBaseObject,
+      long recordBaseOffset,
+      int lengthInBytes,
+      int partitionId) throws IOException {
+    // Need 4 bytes to store the record length.
+    final int totalSpaceRequired = lengthInBytes + 4;
+    if (!haveSpaceForRecord(totalSpaceRequired)) {
+      allocateSpaceForRecord(totalSpaceRequired);
+    }
+
+    final long recordAddress =
+      memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
+    final Object dataPageBaseObject = currentPage.getBaseObject();
+    PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes);
+    currentPagePosition += 4;
+    freeSpaceInCurrentPage -= 4;
+    PlatformDependent.copyMemory(
+      recordBaseObject,
+      recordBaseOffset,
+      dataPageBaseObject,
+      currentPagePosition,
+      lengthInBytes);
+    currentPagePosition += lengthInBytes;
+    freeSpaceInCurrentPage -= lengthInBytes;
+    sorter.insertRecord(recordAddress, partitionId);
+  }
+
+  /**
+   * Close the sorter, causing any buffered data to be sorted and written out to disk.
+   *
+   * @return metadata for the spill files written by this sorter. If no records were ever inserted
+   *         into this sorter, then this will return an empty array.
+   * @throws IOException
+   */
+  public SpillInfo[] closeAndGetSpills() throws IOException {
+    try {
+      if (sorter != null) {
+        // Do not count the final file towards the spill count.
+        writeSortedFile(true);
+        freeMemory();
+      }
+      return spills.toArray(new SpillInfo[spills.size()]);
+    } catch (IOException e) {
+      cleanupAfterError();
+      throw e;
+    }
+  }
+
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java
new file mode 100644
index 0000000000000000000000000000000000000000..5bab501da9364430b2b647e9397931e3aeaba9fc
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.util.Comparator;
+
+import org.apache.spark.util.collection.Sorter;
+
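+/**
+ * Sorts an array of packed record pointers (see {@link PackedRecordPointer}) by partition id
+ * using {@link org.apache.spark.util.collection.Sorter}. Because each entry is a single long
+ * whose upper bits hold the partition id, the sort operates entirely on the pointer array and
+ * never needs to dereference the underlying records.
+ */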
+final class UnsafeShuffleInMemorySorter {
+
+  private final Sorter<PackedRecordPointer, long[]> sorter;
+  private static final class SortComparator implements Comparator<PackedRecordPointer> {
+    @Override
+    public int compare(PackedRecordPointer left, PackedRecordPointer right) {
+      return left.getPartitionId() - right.getPartitionId();
+    }
+  }
+  private static final SortComparator SORT_COMPARATOR = new SortComparator();
+
+  /**
+   * An array of record pointers and partition ids that have been encoded by
+   * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating
+   * records.
+   */
+  private long[] pointerArray;
+
+  /**
+   * The position in the pointer array where new records can be inserted.
+   */
+  private int pointerArrayInsertPosition = 0;
+
+  public UnsafeShuffleInMemorySorter(int initialSize) {
+    assert (initialSize > 0);
+    this.pointerArray = new long[initialSize];
+    this.sorter = new Sorter<PackedRecordPointer, long[]>(UnsafeShuffleSortDataFormat.INSTANCE);
+  }
+
+  public void expandPointerArray() {
+    final long[] oldArray = pointerArray;
+    // Guard against overflow:
+    final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE;
+    pointerArray = new long[newLength];
+    System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length);
+  }
+
+  public boolean hasSpaceForAnotherRecord() {
+    return pointerArrayInsertPosition + 1 < pointerArray.length;
+  }
+
+  public long getMemoryUsage() {
+    return pointerArray.length * 8L;
+  }
+
+  /**
+   * Inserts a record to be sorted.
+   *
+   * @param recordPointer a pointer to the record, encoded by the task memory manager. Due to
+   *                      certain pointer compression techniques used by the sorter, the sort can
+   *                      only operate on pointers that point to locations in the first
+   *                      {@link PackedRecordPointer#MAXIMUM_PAGE_SIZE_BYTES} bytes of a data page.
+   * @param partitionId the partition id, which must be less than or equal to
+   *                    {@link PackedRecordPointer#MAXIMUM_PARTITION_ID}.
+   */
+  public void insertRecord(long recordPointer, int partitionId) {
+    if (!hasSpaceForAnotherRecord()) {
+      if (pointerArray.length == Integer.MAX_VALUE) {
+        throw new IllegalStateException("Sort pointer array has reached maximum size");
+      } else {
+        expandPointerArray();
+      }
+    }
+    pointerArray[pointerArrayInsertPosition] =
+        PackedRecordPointer.packPointer(recordPointer, partitionId);
+    pointerArrayInsertPosition++;
+  }
+
+  /**
+   * An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining.
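+   * Rather than allocating a new object per record, {@link #loadNext()} decodes the next packed
+   * pointer into the shared {@link #packedRecordPointer} field.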
+   */
+  public static final class UnsafeShuffleSorterIterator {
+
+    private final long[] pointerArray;
+    private final int numRecords;
+    final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
+    private int position = 0;
+
+    public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) {
+      this.numRecords = numRecords;
+      this.pointerArray = pointerArray;
+    }
+
+    public boolean hasNext() {
+      return position < numRecords;
+    }
+
+    public void loadNext() {
+      packedRecordPointer.set(pointerArray[position]);
+      position++;
+    }
+  }
+
+  /**
+   * Return an iterator over record pointers in sorted order.
+   */
+  public UnsafeShuffleSorterIterator getSortedIterator() {
+    sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR);
+    return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
+  }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
new file mode 100644
index 0000000000000000000000000000000000000000..a66d74ee44782ccab0d6a239a01acccfbf42b015
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import org.apache.spark.util.collection.SortDataFormat;
+
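+/**
+ * {@link SortDataFormat} implementation that sorts a {@code long[]} of packed record pointers in
+ * place, reusing a single {@link PackedRecordPointer} as the key wrapper to avoid allocating a
+ * key object per comparison.
+ */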
+final class UnsafeShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
+
+  public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat();
+
+  private UnsafeShuffleSortDataFormat() { }
+
+  @Override
+  public PackedRecordPointer getKey(long[] data, int pos) {
+    // Since we re-use keys, this method shouldn't be called.
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public PackedRecordPointer newKey() {
+    return new PackedRecordPointer();
+  }
+
+  @Override
+  public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) {
+    reuse.set(data[pos]);
+    return reuse;
+  }
+
+  @Override
+  public void swap(long[] data, int pos0, int pos1) {
+    final long temp = data[pos0];
+    data[pos0] = data[pos1];
+    data[pos1] = temp;
+  }
+
+  @Override
+  public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
+    dst[dstPos] = src[srcPos];
+  }
+
+  @Override
+  public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
+    System.arraycopy(src, srcPos, dst, dstPos, length);
+  }
+
+  @Override
+  public long[] allocate(int length) {
+    return new long[length];
+  }
+
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
new file mode 100644
index 0000000000000000000000000000000000000000..ad7eb04afcd8c44bc7295f8c6b8e44618d391072
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -0,0 +1,438 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.*;
+import java.nio.channels.FileChannel;
+import java.util.Iterator;
+import javax.annotation.Nullable;
+
+import scala.Option;
+import scala.Product2;
+import scala.collection.JavaConversions;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.io.ByteStreams;
+import com.google.common.io.Closeables;
+import com.google.common.io.Files;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.*;
+import org.apache.spark.annotation.Private;
+import org.apache.spark.io.CompressionCodec;
+import org.apache.spark.io.CompressionCodec$;
+import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.scheduler.MapStatus$;
+import org.apache.spark.serializer.SerializationStream;
+import org.apache.spark.serializer.Serializer;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.TimeTrackingOutputStream;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+@Private
+public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
+
+  private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);
+
+  private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
+
+  @VisibleForTesting
+  static final int INITIAL_SORT_BUFFER_SIZE = 4096;
+
+  private final BlockManager blockManager;
+  private final IndexShuffleBlockResolver shuffleBlockResolver;
+  private final TaskMemoryManager memoryManager;
+  private final ShuffleMemoryManager shuffleMemoryManager;
+  private final SerializerInstance serializer;
+  private final Partitioner partitioner;
+  private final ShuffleWriteMetrics writeMetrics;
+  private final int shuffleId;
+  private final int mapId;
+  private final TaskContext taskContext;
+  private final SparkConf sparkConf;
+  private final boolean transferToEnabled;
+
+  private MapStatus mapStatus = null;
+  private UnsafeShuffleExternalSorter sorter = null;
+
+  /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
+  private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
+    public MyByteArrayOutputStream(int size) { super(size); }
+    public byte[] getBuf() { return buf; }
+  }
+
+  private MyByteArrayOutputStream serBuffer;
+  private SerializationStream serOutputStream;
+
+  /**
+   * Are we in the process of stopping? Because map tasks can call stop() with success = true
+   * and then call stop() with success = false if they get an exception, we want to make sure
+   * we don't try deleting files, etc., twice.
+   */
+  private boolean stopping = false;
+
+  public UnsafeShuffleWriter(
+      BlockManager blockManager,
+      IndexShuffleBlockResolver shuffleBlockResolver,
+      TaskMemoryManager memoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
+      UnsafeShuffleHandle<K, V> handle,
+      int mapId,
+      TaskContext taskContext,
+      SparkConf sparkConf) throws IOException {
+    final int numPartitions = handle.dependency().partitioner().numPartitions();
+    if (numPartitions > UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) {
+      throw new IllegalArgumentException(
+        "UnsafeShuffleWriter can only be used for shuffles with at most " +
+          UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions");
+    }
+    this.blockManager = blockManager;
+    this.shuffleBlockResolver = shuffleBlockResolver;
+    this.memoryManager = memoryManager;
+    this.shuffleMemoryManager = shuffleMemoryManager;
+    this.mapId = mapId;
+    final ShuffleDependency<K, V, V> dep = handle.dependency();
+    this.shuffleId = dep.shuffleId();
+    this.serializer = Serializer.getSerializer(dep.serializer()).newInstance();
+    this.partitioner = dep.partitioner();
+    this.writeMetrics = new ShuffleWriteMetrics();
+    taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
+    this.taskContext = taskContext;
+    this.sparkConf = sparkConf;
+    this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
+    open();
+  }
+
+  /**
+   * This convenience method should only be called in test code.
+   */
+  @VisibleForTesting
+  public void write(Iterator<Product2<K, V>> records) throws IOException {
+    write(JavaConversions.asScalaIterator(records));
+  }
+
+  @Override
+  public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
+    boolean success = false;
+    try {
+      while (records.hasNext()) {
+        insertRecordIntoSorter(records.next());
+      }
+      closeAndWriteOutput();
+      success = true;
+    } finally {
+      if (!success) {
+        sorter.cleanupAfterError();
+      }
+    }
+  }
+
+  private void open() throws IOException {
+    assert (sorter == null);
+    sorter = new UnsafeShuffleExternalSorter(
+      memoryManager,
+      shuffleMemoryManager,
+      blockManager,
+      taskContext,
+      INITIAL_SORT_BUFFER_SIZE,
+      partitioner.numPartitions(),
+      sparkConf,
+      writeMetrics);
+    serBuffer = new MyByteArrayOutputStream(1024 * 1024);
+    serOutputStream = serializer.serializeStream(serBuffer);
+  }
+
+  @VisibleForTesting
+  void closeAndWriteOutput() throws IOException {
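+    // The serialization buffer and stream are no longer needed once all records have been
+    // inserted into the sorter, so drop the references to let them be garbage collected.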
+    serBuffer = null;
+    serOutputStream = null;
+    final SpillInfo[] spills = sorter.closeAndGetSpills();
+    sorter = null;
+    final long[] partitionLengths;
+    try {
+      partitionLengths = mergeSpills(spills);
+    } finally {
+      for (SpillInfo spill : spills) {
+        if (spill.file.exists() && ! spill.file.delete()) {
+          logger.error("Error while deleting spill file {}", spill.file.getPath());
+        }
+      }
+    }
+    shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+    mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+  }
+
+  @VisibleForTesting
+  void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
+    final K key = record._1();
+    final int partitionId = partitioner.getPartition(key);
+    serBuffer.reset();
+    serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
+    serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
+    serOutputStream.flush();
+
+    final int serializedRecordSize = serBuffer.size();
+    assert (serializedRecordSize > 0);
+
+    sorter.insertRecord(
+      serBuffer.getBuf(), PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
+  }
+
+  @VisibleForTesting
+  void forceSorterToSpill() throws IOException {
+    assert (sorter != null);
+    sorter.spill();
+  }
+
+  /**
+   * Merge zero or more spill files together, choosing the fastest merging strategy based on the
+   * number of spills and the IO compression codec.
+   *
+   * @return the partition lengths in the merged file.
+   */
+  private long[] mergeSpills(SpillInfo[] spills) throws IOException {
+    final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+    final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
+    final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
+    final boolean fastMergeEnabled =
+      sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
+    final boolean fastMergeIsSupported =
+      !compressionEnabled || compressionCodec instanceof LZFCompressionCodec;
+    try {
+      if (spills.length == 0) {
+        new FileOutputStream(outputFile).close(); // Create an empty file
+        return new long[partitioner.numPartitions()];
+      } else if (spills.length == 1) {
+        // Here, we don't need to perform any metrics updates because the bytes written to this
+        // output file would have already been counted as shuffle bytes written.
+        Files.move(spills[0].file, outputFile);
+        return spills[0].partitionLengths;
+      } else {
+        final long[] partitionLengths;
+        // There are multiple spills to merge, so none of these spill files' lengths were counted
+        // towards our shuffle write count or shuffle write time. If we use the slow merge path,
+        // then the final output file's size won't necessarily be equal to the sum of the spill
+        // files' sizes. To guard against this case, we look at the output file's actual size when
+        // computing shuffle bytes written.
+        //
+        // We allow the individual merge methods to report their own IO times since different merge
+        // strategies use different IO techniques. We count IO during merge towards the shuffle
+        // write time, which appears to be consistent with the "not bypassing merge-sort"
+        // branch in ExternalSorter.
+        if (fastMergeEnabled && fastMergeIsSupported) {
+          // Compression is disabled or we are using an IO compression codec that supports
+          // decompression of concatenated compressed streams, so we can perform a fast spill merge
+          // that doesn't need to interpret the spilled bytes.
+          if (transferToEnabled) {
+            logger.debug("Using transferTo-based fast merge");
+            partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
+          } else {
+            logger.debug("Using fileStream-based fast merge");
+            partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null);
+          }
+        } else {
+          logger.debug("Using slow merge");
+          partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
+        }
+        // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has
+        // in-memory records, we write out the in-memory records to a file but do not count that
+        // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs
+        // to be counted as shuffle write, but this will lead to double-counting of the final
+        // SpillInfo's bytes.
+        writeMetrics.decShuffleBytesWritten(spills[spills.length - 1].file.length());
+        writeMetrics.incShuffleBytesWritten(outputFile.length());
+        return partitionLengths;
+      }
+    } catch (IOException e) {
+      if (outputFile.exists() && !outputFile.delete()) {
+        logger.error("Unable to delete output file {}", outputFile.getPath());
+      }
+      throw e;
+    }
+  }
+
+  /**
+   * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge,
+   * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in
+   * cases where the IO compression codec does not support concatenation of compressed data, or in
+   * cases where users have explicitly disabled use of {@code transferTo} in order to work around
+   * kernel bugs.
+   *
+   * @param spills the spills to merge.
+   * @param outputFile the file to write the merged data to.
+   * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled.
+   * @return the partition lengths in the merged file.
+   */
+  private long[] mergeSpillsWithFileStream(
+      SpillInfo[] spills,
+      File outputFile,
+      @Nullable CompressionCodec compressionCodec) throws IOException {
+    assert (spills.length >= 2);
+    final int numPartitions = partitioner.numPartitions();
+    final long[] partitionLengths = new long[numPartitions];
+    final InputStream[] spillInputStreams = new FileInputStream[spills.length];
+    OutputStream mergedFileOutputStream = null;
+
+    boolean threwException = true;
+    try {
+      for (int i = 0; i < spills.length; i++) {
+        spillInputStreams[i] = new FileInputStream(spills[i].file);
+      }
+      for (int partition = 0; partition < numPartitions; partition++) {
+        final long initialFileLength = outputFile.length();
+        mergedFileOutputStream =
+          new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true));
+        if (compressionCodec != null) {
+          mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream);
+        }
+
+        for (int i = 0; i < spills.length; i++) {
+          final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+          if (partitionLengthInSpill > 0) {
+            InputStream partitionInputStream =
+              new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill);
+            if (compressionCodec != null) {
+              partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
+            }
+            ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
+          }
+        }
+        mergedFileOutputStream.flush();
+        mergedFileOutputStream.close();
+        partitionLengths[partition] = (outputFile.length() - initialFileLength);
+      }
+      threwException = false;
+    } finally {
+      // To avoid masking exceptions that caused us to prematurely enter the finally block, only
+      // throw exceptions during cleanup if threwException == false.
+      for (InputStream stream : spillInputStreams) {
+        Closeables.close(stream, threwException);
+      }
+      Closeables.close(mergedFileOutputStream, threwException);
+    }
+    return partitionLengths;
+  }
+
+  /**
+   * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes.
+   * This is only safe when the IO compression codec and serializer support concatenation of
+   * serialized streams.
+   *
+   * @return the partition lengths in the merged file.
+   */
+  private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
+    assert (spills.length >= 2);
+    final int numPartitions = partitioner.numPartitions();
+    final long[] partitionLengths = new long[numPartitions];
+    final FileChannel[] spillInputChannels = new FileChannel[spills.length];
+    final long[] spillInputChannelPositions = new long[spills.length];
+    FileChannel mergedFileOutputChannel = null;
+
+    boolean threwException = true;
+    try {
+      for (int i = 0; i < spills.length; i++) {
+        spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
+      }
+      // This file needs to be opened in append mode in order to work around a Linux kernel bug that
+      // affects transferTo; see SPARK-3948 for more details.
+      mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel();
+
+      long bytesWrittenToMergedFile = 0;
+      for (int partition = 0; partition < numPartitions; partition++) {
+        for (int i = 0; i < spills.length; i++) {
+          final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+          long bytesToTransfer = partitionLengthInSpill;
+          final FileChannel spillInputChannel = spillInputChannels[i];
+          final long writeStartTime = System.nanoTime();
+          while (bytesToTransfer > 0) {
+            final long actualBytesTransferred = spillInputChannel.transferTo(
+              spillInputChannelPositions[i],
+              bytesToTransfer,
+              mergedFileOutputChannel);
+            spillInputChannelPositions[i] += actualBytesTransferred;
+            bytesToTransfer -= actualBytesTransferred;
+          }
+          writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime);
+          bytesWrittenToMergedFile += partitionLengthInSpill;
+          partitionLengths[partition] += partitionLengthInSpill;
+        }
+      }
+      // Check the output channel's position after the transferTo loop and raise an exception if it
+      // does not match the expected number of bytes. In kernel version 2.6.32 the position is not
+      // advanced to the expected length after calling transferTo. This issue is described at
+      // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948.
+      if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) {
+        throw new IOException(
+          "Current position " + mergedFileOutputChannel.position() + " does not equal expected " +
+            "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" +
+            " version to see if it is 2.6.32, as there is a kernel bug which will lead to " +
+            "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " +
+            "to disable this NIO feature."
+        );
+      }
+      threwException = false;
+    } finally {
+      // To avoid masking exceptions that caused us to prematurely enter the finally block, only
+      // throw exceptions during cleanup if threwException == false.
+      for (int i = 0; i < spills.length; i++) {
+        assert(spillInputChannelPositions[i] == spills[i].file.length());
+        Closeables.close(spillInputChannels[i], threwException);
+      }
+      Closeables.close(mergedFileOutputChannel, threwException);
+    }
+    return partitionLengths;
+  }
+
+  @Override
+  public Option<MapStatus> stop(boolean success) {
+    try {
+      if (stopping) {
+        return Option.apply(null);
+      } else {
+        stopping = true;
+        if (success) {
+          if (mapStatus == null) {
+            throw new IllegalStateException("Cannot call stop(true) without having called write()");
+          }
+          return Option.apply(mapStatus);
+        } else {
+          // The map task failed, so delete our output data.
+          shuffleBlockResolver.removeDataByMap(shuffleId, mapId);
+          return Option.apply(null);
+        }
+      }
+    } finally {
+      if (sorter != null) {
+        // If sorter is non-null, then this implies that we called stop() in response to an error,
+        // so we need to clean up memory and spill files created by the sorter
+        sorter.cleanupAfterError();
+      }
+    }
+  }
+}
diff --git a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java
new file mode 100644
index 0000000000000000000000000000000000000000..dc2aa30466cc69c01a452a6741177af4990e960d
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage;
+
+import java.io.IOException;
+import java.io.OutputStream;
+
+import org.apache.spark.annotation.Private;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+
+/**
+ * Intercepts write calls and tracks total time spent writing in order to update shuffle write
+ * metrics. Not thread safe.
+ */
+@Private
+public final class TimeTrackingOutputStream extends OutputStream {
+
+  private final ShuffleWriteMetrics writeMetrics;
+  private final OutputStream outputStream;
+
+  public TimeTrackingOutputStream(ShuffleWriteMetrics writeMetrics, OutputStream outputStream) {
+    this.writeMetrics = writeMetrics;
+    this.outputStream = outputStream;
+  }
+
+  @Override
+  public void write(int b) throws IOException {
+    final long startTime = System.nanoTime();
+    outputStream.write(b);
+    writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime);
+  }
+
+  @Override
+  public void write(byte[] b) throws IOException {
+    final long startTime = System.nanoTime();
+    outputStream.write(b);
+    writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime);
+  }
+
+  @Override
+  public void write(byte[] b, int off, int len) throws IOException {
+    final long startTime = System.nanoTime();
+    outputStream.write(b, off, len);
+    writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime);
+  }
+
+  @Override
+  public void flush() throws IOException {
+    final long startTime = System.nanoTime();
+    outputStream.flush();
+    writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime);
+  }
+
+  @Override
+  public void close() throws IOException {
+    final long startTime = System.nanoTime();
+    outputStream.close();
+    writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime);
+  }
+}
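
For illustration, a hedged usage sketch of the extracted `TimeTrackingOutputStream`: it wraps an arbitrary `OutputStream` so that time spent in `write`/`flush`/`close` is accumulated into a `ShuffleWriteMetrics` instance. The file path and buffer size here are placeholders, not values from this patch:

```scala
import java.io.{BufferedOutputStream, FileOutputStream}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.storage.TimeTrackingOutputStream

val writeMetrics = new ShuffleWriteMetrics()
// Wrap the raw file stream; a buffered stream on top still batches small writes as usual.
val timed = new TimeTrackingOutputStream(writeMetrics, new FileOutputStream("/tmp/example.bin"))
val out = new BufferedOutputStream(timed, 32 * 1024)
out.write(Array.fill[Byte](1024)(1.toByte))
out.close() // time spent in flush() and close() is counted as well
println(s"shuffle write time (ns): ${writeMetrics.shuffleWriteTime}")
```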
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 0c4d28f786edda1599b2d32e68fbab10ba0b28be..a5d831c7e68ad504309b5429b05ce6bdddb2f27a 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -313,7 +313,8 @@ object SparkEnv extends Logging {
     // Let the user specify short names for shuffle managers
     val shortShuffleMgrNames = Map(
       "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager",
-      "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")
+      "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager",
+      "tungsten-sort" -> "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager")
     val shuffleMgrName = conf.get("spark.shuffle.manager", "sort")
     val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
     val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)
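
With the new short name registered, the manager can be selected purely through configuration. A hedged example of application code (not part of this patch) that opts into the new path and pairs it with a relocation-capable serializer:

```scala
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setAppName("tungsten-sort-demo")
  .setMaster("local[4]")
  // Select the new shuffle manager by its short name; shuffles that do not meet its
  // requirements fall back to the regular sort-based path automatically.
  .set("spark.shuffle.manager", "tungsten-sort")
  // The optimized path requires a serializer that supports relocation of serialized
  // objects, e.g. Kryo.
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
val sc = new SparkContext(conf)
```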
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index dfbde7c8a1b0d99cb356e56c64183655004ae660..698d1384d580d5b335f1c6fdc6158afdb11d3373 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -121,6 +121,8 @@ class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
   private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100)
   private var extraDebugInfo = conf.getBoolean("spark.serializer.extraDebugInfo", true)
 
+  protected def this() = this(new SparkConf())  // For deserialization only
+
   override def newInstance(): SerializerInstance = {
     val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
     new JavaSerializerInstance(counterReset, extraDebugInfo, classLoader)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
index 6ad427bcac7f983b79dfc192435aa5bc762569ae..6c3b3080d26057c2f076127b85348ee3f53597f6 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
@@ -76,7 +76,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
   private val consolidateShuffleFiles =
     conf.getBoolean("spark.shuffle.consolidateFiles", false)
 
-  // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided 
+  // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
   private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
 
   /**
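
The corrected comment refers to `getSizeAsKb`'s unit handling: a bare number is interpreted as kibibytes, so configurations written before explicit size units were introduced keep their old meaning. A small illustrative sketch:

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
// Legacy style: a bare number is interpreted as kibibytes by getSizeAsKb.
conf.set("spark.shuffle.file.buffer", "32")
assert(conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") == 32L)
// New style: explicit units are honored.
conf.set("spark.shuffle.file.buffer", "64k")
assert(conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") == 64L)
```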
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
index f6e6fe5defe09a6aee972b99dcf4494a66878920..4cc4ef5f1886e4b42cc32e505d3cee3d8cc8a6fc 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
@@ -17,14 +17,17 @@
 
 package org.apache.spark.shuffle
 
+import java.io.IOException
+
 import org.apache.spark.scheduler.MapStatus
 
 /**
  * Obtained inside a map task to write out records to the shuffle system.
  */
-private[spark] trait ShuffleWriter[K, V] {
+private[spark] abstract class ShuffleWriter[K, V] {
   /** Write a sequence of records to this task's output */
-  def write(records: Iterator[_ <: Product2[K, V]]): Unit
+  @throws[IOException]
+  def write(records: Iterator[Product2[K, V]]): Unit
 
   /** Close this writer, passing along whether the map completed */
   def stop(success: Boolean): Option[MapStatus]
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 897f0a5dc5bcce714d89570c4511470782bc21b6..eb87cee15903c9c98ddcea3f878f492fcdb99336 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -49,7 +49,7 @@ private[spark] class HashShuffleWriter[K, V](
     writeMetrics)
 
   /** Write a bunch of records to this task's output */
-  override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+  override def write(records: Iterator[Product2[K, V]]): Unit = {
     val iter = if (dep.aggregator.isDefined) {
       if (dep.mapSideCombine) {
         dep.aggregator.get.combineValuesByKey(records, context)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 15842941daaab0b7c0baefde1e7f091f61c5cd0d..d7fab351ca3b88d29deaca22873739a7242f5a89 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -72,7 +72,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
     true
   }
 
-  override def shuffleBlockResolver: IndexShuffleBlockResolver = {
+  override val shuffleBlockResolver: IndexShuffleBlockResolver = {
     indexShuffleBlockResolver
   }
 
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index add2656294ca2d22cff6e2945ebec8902a21ff5c..c9dd6bfc4c2199a6cebd13ec5d162d94cf5f9b14 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -48,7 +48,7 @@ private[spark] class SortShuffleWriter[K, V, C](
   context.taskMetrics.shuffleWriteMetrics = Some(writeMetrics)
 
   /** Write a bunch of records to this task's output */
-  override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+  override def write(records: Iterator[Product2[K, V]]): Unit = {
     if (dep.mapSideCombine) {
       require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
       sorter = new ExternalSorter[K, V, C](
diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
new file mode 100644
index 0000000000000000000000000000000000000000..f2bfef376d3ca3addac69648126dc2ee12cc53f4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe
+
+import java.util.Collections
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark._
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle._
+import org.apache.spark.shuffle.sort.SortShuffleManager
+
+/**
+ * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle.
+ */
+private[spark] class UnsafeShuffleHandle[K, V](
+    shuffleId: Int,
+    numMaps: Int,
+    dependency: ShuffleDependency[K, V, V])
+  extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
+}
+
+private[spark] object UnsafeShuffleManager extends Logging {
+
+  /**
+   * The maximum number of shuffle output partitions that UnsafeShuffleManager supports.
+   */
+  val MAX_SHUFFLE_OUTPUT_PARTITIONS = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1
+
+  /**
+   * Helper method for determining whether a shuffle should use the optimized unsafe shuffle
+   * path or whether it should fall back to the original sort-based shuffle.
+   */
+  def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = {
+    val shufId = dependency.shuffleId
+    val serializer = Serializer.getSerializer(dependency.serializer)
+    if (!serializer.supportsRelocationOfSerializedObjects) {
+      log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " +
+        s"${serializer.getClass.getName}, does not support object relocation")
+      false
+    } else if (dependency.aggregator.isDefined) {
+      log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined")
+      false
+    } else if (dependency.keyOrdering.isDefined) {
+      log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because a key ordering is defined")
+      false
+    } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) {
+      log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " +
+        s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions")
+      false
+    } else {
+      log.debug(s"Can use UnsafeShuffle for shuffle $shufId")
+      true
+    }
+  }
+}
+
+/**
+ * A shuffle implementation that uses directly-managed memory to implement several performance
+ * optimizations for certain types of shuffles. In cases where the new performance optimizations
+ * cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those
+ * shuffles.
+ *
+ * UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold:
+ *
+ *  - The shuffle dependency specifies no aggregation or output ordering.
+ *  - The shuffle serializer supports relocation of serialized values (this is currently supported
+ *    by KryoSerializer and Spark SQL's custom serializers).
+ *  - The shuffle produces fewer than 16777216 output partitions.
+ *  - No individual record is larger than 128 MB when serialized.
+ *
+ * In addition, extra spill-merging optimizations are automatically applied when the shuffle
+ * compression codec supports concatenation of serialized streams. This is currently supported by
+ * Spark's LZF compression codec.
+ *
+ * At a high level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager.
+ * In sort-based shuffle, incoming records are sorted according to their target partition ids, then
+ * written to a single map output file. Reducers fetch contiguous regions of this file in order to
+ * read their portion of the map output. In cases where the map output data is too large to fit in
+ * memory, sorted subsets of the output are spilled to disk and those on-disk files are merged
+ * to produce the final output file.
+ *
+ * UnsafeShuffleManager optimizes this process in several ways:
+ *
+ *  - Its sort operates on serialized binary data rather than Java objects, which reduces memory
+ *    consumption and GC overheads. This optimization requires the record serializer to have certain
+ *    properties to allow serialized records to be re-ordered without requiring deserialization.
+ *    See SPARK-4550, where this optimization was first proposed and implemented, for more details.
+ *
+ *  - It uses a specialized cache-efficient sorter ([[UnsafeShuffleExternalSorter]]) that sorts
+ *    arrays of compressed record pointers and partition ids. By using only 8 bytes of space per
+ *    record in the sorting array, this fits more of the array into cache.
+ *
+ *  - The spill merging procedure operates on blocks of serialized records that belong to the same
+ *    partition and does not need to deserialize records during the merge.
+ *
+ *  - When the spill compression codec supports concatenation of compressed data, the spill merge
+ *    simply concatenates the serialized and compressed spill partitions to produce the final output
+ *    partition.  This allows efficient data copying methods, like NIO's `transferTo`, to be used
+ *    and avoids the need to allocate decompression or copying buffers during the merge.
+ *
+ * For more details on UnsafeShuffleManager's design, see SPARK-7081.
+ */
+private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
+
+  if (!conf.getBoolean("spark.shuffle.spill", true)) {
+    logWarning(
+      "spark.shuffle.spill was set to false, but this is ignored by the tungsten-sort shuffle " +
+      "manager; its optimized shuffles will continue to spill to disk when necessary.")
+  }
+
+  private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
+  private[this] val shufflesThatFellBackToSortShuffle =
+    Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]())
+  private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]()
+
+  /**
+   * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
+   */
+  override def registerShuffle[K, V, C](
+      shuffleId: Int,
+      numMaps: Int,
+      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+    if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) {
+      new UnsafeShuffleHandle[K, V](
+        shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+    } else {
+      new BaseShuffleHandle(shuffleId, numMaps, dependency)
+    }
+  }
+
+  /**
+   * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
+   * Called on executors by reduce tasks.
+   */
+  override def getReader[K, C](
+      handle: ShuffleHandle,
+      startPartition: Int,
+      endPartition: Int,
+      context: TaskContext): ShuffleReader[K, C] = {
+    sortShuffleManager.getReader(handle, startPartition, endPartition, context)
+  }
+
+  /** Get a writer for a given partition. Called on executors by map tasks. */
+  override def getWriter[K, V](
+      handle: ShuffleHandle,
+      mapId: Int,
+      context: TaskContext): ShuffleWriter[K, V] = {
+    handle match {
+      case unsafeShuffleHandle: UnsafeShuffleHandle[K, V] =>
+        numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps)
+        val env = SparkEnv.get
+        new UnsafeShuffleWriter(
+          env.blockManager,
+          shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
+          context.taskMemoryManager(),
+          env.shuffleMemoryManager,
+          unsafeShuffleHandle,
+          mapId,
+          context,
+          env.conf)
+      case other =>
+        shufflesThatFellBackToSortShuffle.add(handle.shuffleId)
+        sortShuffleManager.getWriter(handle, mapId, context)
+    }
+  }
+
+  /** Remove a shuffle's metadata from the ShuffleManager. */
+  override def unregisterShuffle(shuffleId: Int): Boolean = {
+    if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) {
+      sortShuffleManager.unregisterShuffle(shuffleId)
+    } else {
+      Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps =>
+        (0 until numMaps).foreach { mapId =>
+          shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
+        }
+      }
+      true
+    }
+  }
+
+  override val shuffleBlockResolver: IndexShuffleBlockResolver = {
+    sortShuffleManager.shuffleBlockResolver
+  }
+
+  /** Shut down this ShuffleManager. */
+  override def stop(): Unit = {
+    sortShuffleManager.stop()
+  }
+}
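
As a rough illustration of the eligibility rules documented above (no aggregation, no key ordering, a relocation-capable serializer, and at most 2^24 output partitions), the sketch below shows two shuffles under the configuration from the earlier example; only the first can take the optimized path. This is illustrative application code, not part of the patch, and it assumes the SparkContext `sc` from the earlier sketch (tungsten-sort + Kryo):

```scala
import org.apache.spark.HashPartitioner

val pairs = sc.parallelize(1 to 1000000).map(i => (i % 4096, i))

// No aggregator or key ordering and only 4096 partitions: eligible for the unsafe path.
val repartitioned = pairs.partitionBy(new HashPartitioner(4096))

// reduceByKey defines an aggregator, so registerShuffle returns a BaseShuffleHandle
// and the shuffle is handled by the regular SortShuffleManager.
val reduced = pairs.reduceByKey(_ + _)
```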
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 8bc4e205bc3c6609fed2e59bf8a8206a902a0815..a33f22ef52687bddea0e430ed489a9f49af79e4d 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -86,16 +86,6 @@ private[spark] class DiskBlockObjectWriter(
   extends BlockObjectWriter(blockId)
   with Logging
 {
-  /** Intercepts write calls and tracks total time spent writing. Not thread safe. */
-  private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream {
-    override def write(i: Int): Unit = callWithTiming(out.write(i))
-    override def write(b: Array[Byte]): Unit = callWithTiming(out.write(b))
-    override def write(b: Array[Byte], off: Int, len: Int): Unit = {
-      callWithTiming(out.write(b, off, len))
-    }
-    override def close(): Unit = out.close()
-    override def flush(): Unit = out.flush()
-  }
 
   /** The file channel, used for repositioning / truncating the file. */
   private var channel: FileChannel = null
@@ -136,7 +126,7 @@ private[spark] class DiskBlockObjectWriter(
       throw new IllegalStateException("Writer already closed. Cannot be reopened.")
     }
     fos = new FileOutputStream(file, true)
-    ts = new TimeTrackingOutputStream(fos)
+    ts = new TimeTrackingOutputStream(writeMetrics, fos)
     channel = fos.getChannel()
     bs = compressStream(new BufferedOutputStream(ts, bufferSize))
     objOut = serializerInstance.serializeStream(bs)
@@ -150,9 +140,9 @@ private[spark] class DiskBlockObjectWriter(
         if (syncWrites) {
           // Force outstanding writes to disk and track how long it takes
           objOut.flush()
-          callWithTiming {
-            fos.getFD.sync()
-          }
+          val start = System.nanoTime()
+          fos.getFD.sync()
+          writeMetrics.incShuffleWriteTime(System.nanoTime() - start)
         }
       } {
         objOut.close()
@@ -251,12 +241,6 @@ private[spark] class DiskBlockObjectWriter(
     reportedPosition = pos
   }
 
-  private def callWithTiming(f: => Unit) = {
-    val start = System.nanoTime()
-    f
-    writeMetrics.incShuffleWriteTime(System.nanoTime() - start)
-  }
-
   // For testing
   private[spark] override def flush() {
     objOut.flush()
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index b8509731450777c931b6f2bd58593e436a38ae42..df2d6ad3b41a4a4f59f39ed2b4ef40ab8bee99bd 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -90,7 +90,7 @@ class ExternalAppendOnlyMap[K, V, C](
   // Number of bytes spilled in total
   private var _diskBytesSpilled = 0L
   
-  // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided
+  // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
   private val fileBufferSize = 
     sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
 
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 7d5cf7b61e56a40890c816a26599c16212f9458c..3b9d14f9372b618daf2b7d15de148e33fe9e1049 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -110,7 +110,7 @@ private[spark] class ExternalSorter[K, V, C](
   private val conf = SparkEnv.get.conf
   private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true)
   
-  // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided
+  // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
   private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
   private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true)
 
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
new file mode 100644
index 0000000000000000000000000000000000000000..db9e82759090af957f7151fb3f0dffe93252de98
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import static org.apache.spark.shuffle.unsafe.PackedRecordPointer.*;
+
+public class PackedRecordPointerSuite {
+
+  @Test
+  public void heap() {
+    final TaskMemoryManager memoryManager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+    final MemoryBlock page0 = memoryManager.allocatePage(100);
+    final MemoryBlock page1 = memoryManager.allocatePage(100);
+    final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
+      page1.getBaseOffset() + 42);
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
+    assertEquals(360, packedPointer.getPartitionId());
+    final long recordPointer = packedPointer.getRecordPointer();
+    assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer));
+    assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer));
+    assertEquals(addressInPage1, recordPointer);
+    memoryManager.cleanUpAllAllocatedMemory();
+  }
+
+  @Test
+  public void offHeap() {
+    final TaskMemoryManager memoryManager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
+    final MemoryBlock page0 = memoryManager.allocatePage(100);
+    final MemoryBlock page1 = memoryManager.allocatePage(100);
+    final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
+      page1.getBaseOffset() + 42);
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
+    assertEquals(360, packedPointer.getPartitionId());
+    final long recordPointer = packedPointer.getRecordPointer();
+    assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer));
+    assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer));
+    assertEquals(addressInPage1, recordPointer);
+    memoryManager.cleanUpAllAllocatedMemory();
+  }
+
+  @Test
+  public void maximumPartitionIdCanBeEncoded() {
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID));
+    assertEquals(MAXIMUM_PARTITION_ID, packedPointer.getPartitionId());
+  }
+
+  @Test
+  public void partitionIdsGreaterThanMaximumPartitionIdWillOverflowOrTriggerError() {
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    try {
+      // Partition ids greater than the maximum partition ID will overflow or trigger an assertion error
+      packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID + 1));
+      assertFalse(MAXIMUM_PARTITION_ID + 1 == packedPointer.getPartitionId());
+    } catch (AssertionError e) {
+      // pass
+    }
+  }
+
+  @Test
+  public void maximumOffsetInPageCanBeEncoded() {
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES - 1);
+    packedPointer.set(PackedRecordPointer.packPointer(address, 0));
+    assertEquals(address, packedPointer.getRecordPointer());
+  }
+
+  @Test
+  public void offsetsPastMaxOffsetInPageWillOverflow() {
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES);
+    packedPointer.set(PackedRecordPointer.packPointer(address, 0));
+    assertEquals(0, packedPointer.getRecordPointer());
+  }
+}
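
For readers following the tests above, here is a rough sketch of the bit layout being exercised, assuming the packed 64-bit value is split as [24-bit partition id | 13-bit page number | 27-bit offset in page], which is consistent with MAXIMUM_PARTITION_ID = 2^24 - 1 and MAXIMUM_PAGE_SIZE_BYTES = 2^27. This is an illustration under that assumption, not the production encoding code:

```scala
// Hypothetical re-implementation of the assumed packing layout, for illustration only.
val PartitionBits = 24
val PageBits = 13
val OffsetBits = 27

def pack(partitionId: Long, pageNumber: Long, offsetInPage: Long): Long = {
  require(partitionId < (1L << PartitionBits))
  require(pageNumber < (1L << PageBits))
  require(offsetInPage < (1L << OffsetBits))
  (partitionId << (PageBits + OffsetBits)) | (pageNumber << OffsetBits) | offsetInPage
}

def partitionId(packed: Long): Int = (packed >>> (PageBits + OffsetBits)).toInt

// Mirrors the "heap" test above: partition 360, page 1, offset 42.
assert(partitionId(pack(360, 1, 42)) == 360)
```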
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
new file mode 100644
index 0000000000000000000000000000000000000000..8fa72597db24d21f69ff09ad7cbd71f064d31254
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.util.Arrays;
+import java.util.Random;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+public class UnsafeShuffleInMemorySorterSuite {
+
+  private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) {
+    final byte[] strBytes = new byte[strLength];
+    PlatformDependent.copyMemory(
+      baseObject,
+      baseOffset,
+      strBytes,
+      PlatformDependent.BYTE_ARRAY_OFFSET, strLength);
+    return new String(strBytes);
+  }
+
+  @Test
+  public void testSortingEmptyInput() {
+    final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(100);
+    final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
+    assert(!iter.hasNext());
+  }
+
+  @Test
+  public void testBasicSorting() throws Exception {
+    final String[] dataToSort = new String[] {
+      "Boba",
+      "Pearls",
+      "Tapioca",
+      "Taho",
+      "Condensed Milk",
+      "Jasmine",
+      "Milk Tea",
+      "Lychee",
+      "Mango"
+    };
+    final TaskMemoryManager memoryManager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+    final MemoryBlock dataPage = memoryManager.allocatePage(2048);
+    final Object baseObject = dataPage.getBaseObject();
+    final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
+    final HashPartitioner hashPartitioner = new HashPartitioner(4);
+
+    // Write the records into the data page and store pointers into the sorter
+    long position = dataPage.getBaseOffset();
+    for (String str : dataToSort) {
+      final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position);
+      final byte[] strBytes = str.getBytes("utf-8");
+      PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length);
+      position += 4;
+      PlatformDependent.copyMemory(
+        strBytes,
+        PlatformDependent.BYTE_ARRAY_OFFSET,
+        baseObject,
+        position,
+        strBytes.length);
+      position += strBytes.length;
+      sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str));
+    }
+
+    // Sort the records
+    final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
+    int prevPartitionId = -1;
+    Arrays.sort(dataToSort);
+    for (int i = 0; i < dataToSort.length; i++) {
+      Assert.assertTrue(iter.hasNext());
+      iter.loadNext();
+      final int partitionId = iter.packedRecordPointer.getPartitionId();
+      Assert.assertTrue(partitionId >= 0 && partitionId <= 3);
+      Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId,
+        partitionId >= prevPartitionId);
+      final long recordAddress = iter.packedRecordPointer.getRecordPointer();
+      final int recordLength = PlatformDependent.UNSAFE.getInt(
+        memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress));
+      final String str = getStringFromDataPage(
+        memoryManager.getPage(recordAddress),
+        memoryManager.getOffsetInPage(recordAddress) + 4, // skip over record length
+        recordLength);
+      Assert.assertTrue(Arrays.binarySearch(dataToSort, str) >= 0);
+    }
+    Assert.assertFalse(iter.hasNext());
+  }
+
+  @Test
+  public void testSortingManyNumbers() throws Exception {
+    UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
+    int[] numbersToSort = new int[128000];
+    Random random = new Random(16);
+    for (int i = 0; i < numbersToSort.length; i++) {
+      numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1);
+      sorter.insertRecord(0, numbersToSort[i]);
+    }
+    Arrays.sort(numbersToSort);
+    int[] sorterResult = new int[numbersToSort.length];
+    UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
+    int j = 0;
+    while (iter.hasNext()) {
+      iter.loadNext();
+      sorterResult[j] = iter.packedRecordPointer.getPartitionId();
+      j += 1;
+    }
+    Assert.assertArrayEquals(numbersToSort, sorterResult);
+  }
+}
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
new file mode 100644
index 0000000000000000000000000000000000000000..730d265c87f882b048caa3c82e34305b2155078c
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -0,0 +1,527 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.*;
+import java.nio.ByteBuffer;
+import java.util.*;
+
+import scala.*;
+import scala.collection.Iterator;
+import scala.reflect.ClassTag;
+import scala.runtime.AbstractFunction1;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.io.ByteStreams;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.lessThan;
+import static org.junit.Assert.*;
+import static org.mockito.AdditionalAnswers.returnsFirstArg;
+import static org.mockito.Answers.RETURNS_SMART_NULLS;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.*;
+import org.apache.spark.io.CompressionCodec$;
+import org.apache.spark.io.LZ4CompressionCodec;
+import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.io.SnappyCompressionCodec;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.serializer.*;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.*;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+public class UnsafeShuffleWriterSuite {
+
+  static final int NUM_PARTITIONS = 4;
+  final TaskMemoryManager taskMemoryManager =
+    new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+  final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITIONS);
+  File mergedOutputFile;
+  File tempDir;
+  long[] partitionSizesInMergedFile;
+  final LinkedList<File> spillFilesCreated = new LinkedList<File>();
+  SparkConf conf;
+  final Serializer serializer = new KryoSerializer(new SparkConf());
+  TaskMetrics taskMetrics;
+
+  @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
+  @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
+  @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver;
+  @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
+  @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
+  @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;
+
+  private final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
+    @Override
+    public OutputStream apply(OutputStream stream) {
+      if (conf.getBoolean("spark.shuffle.compress", true)) {
+        return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream);
+      } else {
+        return stream;
+      }
+    }
+  }
+
+  @After
+  public void tearDown() {
+    Utils.deleteRecursively(tempDir);
+    final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
+    if (leakedMemory != 0) {
+      fail("Test leaked " + leakedMemory + " bytes of managed memory");
+    }
+  }
+
+  @Before
+  @SuppressWarnings("unchecked")
+  public void setUp() throws IOException {
+    MockitoAnnotations.initMocks(this);
+    tempDir = Utils.createTempDir("test", "test");
+    mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
+    partitionSizesInMergedFile = null;
+    spillFilesCreated.clear();
+    conf = new SparkConf();
+    taskMetrics = new TaskMetrics();
+
+    when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
+
+    when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
+    when(blockManager.getDiskWriter(
+      any(BlockId.class),
+      any(File.class),
+      any(SerializerInstance.class),
+      anyInt(),
+      any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
+      @Override
+      public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
+        Object[] args = invocationOnMock.getArguments();
+
+        return new DiskBlockObjectWriter(
+          (BlockId) args[0],
+          (File) args[1],
+          (SerializerInstance) args[2],
+          (Integer) args[3],
+          new CompressStream(),
+          false,
+          (ShuffleWriteMetrics) args[4]
+        );
+      }
+    });
+    when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer(
+      new Answer<InputStream>() {
+        @Override
+        public InputStream answer(InvocationOnMock invocation) throws Throwable {
+          assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
+          InputStream is = (InputStream) invocation.getArguments()[1];
+          if (conf.getBoolean("spark.shuffle.compress", true)) {
+            return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is);
+          } else {
+            return is;
+          }
+        }
+      }
+    );
+
+    when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer(
+      new Answer<OutputStream>() {
+        @Override
+        public OutputStream answer(InvocationOnMock invocation) throws Throwable {
+          assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
+          OutputStream os = (OutputStream) invocation.getArguments()[1];
+          if (conf.getBoolean("spark.shuffle.compress", true)) {
+            return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os);
+          } else {
+            return os;
+          }
+        }
+      }
+    );
+
+    when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile);
+    doAnswer(new Answer<Void>() {
+      @Override
+      public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
+        partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
+        return null;
+      }
+    }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class));
+
+    when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
+      new Answer<Tuple2<TempShuffleBlockId, File>>() {
+        @Override
+        public Tuple2<TempShuffleBlockId, File> answer(
+          InvocationOnMock invocationOnMock) throws Throwable {
+          TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
+          File file = File.createTempFile("spillFile", ".spill", tempDir);
+          spillFilesCreated.add(file);
+          return Tuple2$.MODULE$.apply(blockId, file);
+        }
+      });
+
+    when(taskContext.taskMetrics()).thenReturn(taskMetrics);
+
+    when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer));
+    when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
+  }
+
+  private UnsafeShuffleWriter<Object, Object> createWriter(
+      boolean transferToEnabled) throws IOException {
+    conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
+    return new UnsafeShuffleWriter<Object, Object>(
+      blockManager,
+      shuffleBlockResolver,
+      taskMemoryManager,
+      shuffleMemoryManager,
+      new UnsafeShuffleHandle<Object, Object>(0, 1, shuffleDep),
+      0, // map id
+      taskContext,
+      conf
+    );
+  }
+
+  private void assertSpillFilesWereCleanedUp() {
+    for (File spillFile : spillFilesCreated) {
+      assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
+        spillFile.exists());
+    }
+  }
+
+  private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
+    final ArrayList<Tuple2<Object, Object>> recordsList = new ArrayList<Tuple2<Object, Object>>();
+    long startOffset = 0;
+    for (int i = 0; i < NUM_PARTITIONS; i++) {
+      final long partitionSize = partitionSizesInMergedFile[i];
+      if (partitionSize > 0) {
+        InputStream in = new FileInputStream(mergedOutputFile);
+        ByteStreams.skipFully(in, startOffset);
+        in = new LimitedInputStream(in, partitionSize);
+        if (conf.getBoolean("spark.shuffle.compress", true)) {
+          in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in);
+        }
+        DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in);
+        Iterator<Tuple2<Object, Object>> records = recordsStream.asKeyValueIterator();
+        while (records.hasNext()) {
+          Tuple2<Object, Object> record = records.next();
+          assertEquals(i, hashPartitioner.getPartition(record._1()));
+          recordsList.add(record);
+        }
+        recordsStream.close();
+        startOffset += partitionSize;
+      }
+    }
+    return recordsList;
+  }
+
+  @Test(expected=IllegalStateException.class)
+  public void mustCallWriteBeforeSuccessfulStop() throws IOException {
+    createWriter(false).stop(true);
+  }
+
+  @Test
+  public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException {
+    createWriter(false).stop(false);
+  }
+
+  @Test
+  public void writeEmptyIterator() throws Exception {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
+    writer.write(Collections.<Product2<Object, Object>>emptyIterator());
+    final Option<MapStatus> mapStatus = writer.stop(true);
+    assertTrue(mapStatus.isDefined());
+    assertTrue(mergedOutputFile.exists());
+    assertArrayEquals(new long[NUM_PARTITIONS], partitionSizesInMergedFile);
+    assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleRecordsWritten());
+    assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleBytesWritten());
+    assertEquals(0, taskMetrics.diskBytesSpilled());
+    assertEquals(0, taskMetrics.memoryBytesSpilled());
+  }
+
+  @Test
+  public void writeWithoutSpilling() throws Exception {
+    // In this example, each partition should have exactly one record:
+    final ArrayList<Product2<Object, Object>> dataToWrite =
+      new ArrayList<Product2<Object, Object>>();
+    for (int i = 0; i < NUM_PARTITIONS; i++) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, i));
+    }
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
+    writer.write(dataToWrite.iterator());
+    final Option<MapStatus> mapStatus = writer.stop(true);
+    assertTrue(mapStatus.isDefined());
+    assertTrue(mergedOutputFile.exists());
+
+    long sumOfPartitionSizes = 0;
+    for (long size: partitionSizesInMergedFile) {
+      // All partitions should be the same size:
+      assertEquals(partitionSizesInMergedFile[0], size);
+      sumOfPartitionSizes += size;
+    }
+    assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
+    assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
+    assertSpillFilesWereCleanedUp();
+    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+    assertEquals(0, taskMetrics.diskBytesSpilled());
+    assertEquals(0, taskMetrics.memoryBytesSpilled());
+    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+  }
+
+  private void testMergingSpills(
+      boolean transferToEnabled,
+      String compressionCodecName) throws IOException {
+    if (compressionCodecName != null) {
+      conf.set("spark.shuffle.compress", "true");
+      conf.set("spark.io.compression.codec", compressionCodecName);
+    } else {
+      conf.set("spark.shuffle.compress", "false");
+    }
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
+    final ArrayList<Product2<Object, Object>> dataToWrite =
+      new ArrayList<Product2<Object, Object>>();
+    for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, i));
+    }
+    writer.insertRecordIntoSorter(dataToWrite.get(0));
+    writer.insertRecordIntoSorter(dataToWrite.get(1));
+    writer.insertRecordIntoSorter(dataToWrite.get(2));
+    writer.insertRecordIntoSorter(dataToWrite.get(3));
+    writer.forceSorterToSpill();
+    writer.insertRecordIntoSorter(dataToWrite.get(4));
+    writer.insertRecordIntoSorter(dataToWrite.get(5));
+    writer.closeAndWriteOutput();
+    final Option<MapStatus> mapStatus = writer.stop(true);
+    assertTrue(mapStatus.isDefined());
+    assertTrue(mergedOutputFile.exists());
+    assertEquals(2, spillFilesCreated.size());
+
+    long sumOfPartitionSizes = 0;
+    for (long size: partitionSizesInMergedFile) {
+      sumOfPartitionSizes += size;
+    }
+    assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
+
+    assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
+    assertSpillFilesWereCleanedUp();
+    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndLZF() throws Exception {
+    testMergingSpills(true, LZFCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndLZF() throws Exception {
+    testMergingSpills(false, LZFCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndLZ4() throws Exception {
+    testMergingSpills(true, LZ4CompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndLZ4() throws Exception {
+    testMergingSpills(false, LZ4CompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndSnappy() throws Exception {
+    testMergingSpills(true, SnappyCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndSnappy() throws Exception {
+    testMergingSpills(false, SnappyCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndNoCompression() throws Exception {
+    testMergingSpills(true, null);
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
+    testMergingSpills(false, null);
+  }
+
+  @Test
+  public void writeEnoughDataToTriggerSpill() throws Exception {
+    when(shuffleMemoryManager.tryToAcquire(anyLong()))
+      .then(returnsFirstArg()) // Allocate initial sort buffer
+      .then(returnsFirstArg()) // Allocate initial data page
+      .thenReturn(0L) // Deny request to allocate new data page
+      .then(returnsFirstArg());  // Grant new sort buffer and data page.
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+    final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128];
+    for (int i = 0; i < 128 + 1; i++) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, bigByteArray));
+    }
+    writer.write(dataToWrite.iterator());
+    verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+    assertEquals(2, spillFilesCreated.size());
+    writer.stop(true);
+    readRecordsFromFile();
+    assertSpillFilesWereCleanedUp();
+    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+  }
+
+  @Test
+  public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
+    when(shuffleMemoryManager.tryToAcquire(anyLong()))
+      .then(returnsFirstArg()) // Allocate initial sort buffer
+      .then(returnsFirstArg()) // Allocate initial data page
+      .thenReturn(0L) // Deny request to grow sort buffer
+      .then(returnsFirstArg());  // Grant new sort buffer and data page.
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+    for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, i));
+    }
+    writer.write(dataToWrite.iterator());
+    verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+    assertEquals(2, spillFilesCreated.size());
+    writer.stop(true);
+    readRecordsFromFile();
+    assertSpillFilesWereCleanedUp();
+    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+  }
+
+  @Test
+  public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    final ArrayList<Product2<Object, Object>> dataToWrite =
+      new ArrayList<Product2<Object, Object>>();
+    final byte[] bytes = new byte[(int) (UnsafeShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)];
+    new Random(42).nextBytes(bytes);
+    dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(bytes)));
+    writer.write(dataToWrite.iterator());
+    writer.stop(true);
+    assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
+    assertSpillFilesWereCleanedUp();
+  }
+
+  @Test
+  public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception {
+    // Use a custom serializer so that we have exact control over the size of serialized data.
+    final Serializer byteArraySerializer = new Serializer() {
+      @Override
+      public SerializerInstance newInstance() {
+        return new SerializerInstance() {
+          @Override
+          public SerializationStream serializeStream(final OutputStream s) {
+            return new SerializationStream() {
+              @Override
+              public void flush() { }
+
+              @Override
+              public <T> SerializationStream writeObject(T t, ClassTag<T> ev1) {
+                byte[] bytes = (byte[]) t;
+                try {
+                  s.write(bytes);
+                } catch (IOException e) {
+                  throw new RuntimeException(e);
+                }
+                return this;
+              }
+
+              @Override
+              public void close() { }
+            };
+          }
+          public <T> ByteBuffer serialize(T t, ClassTag<T> ev1) { return null; }
+          public DeserializationStream deserializeStream(InputStream s) { return null; }
+          public <T> T deserialize(ByteBuffer b, ClassLoader l, ClassTag<T> ev1) { return null; }
+          public <T> T deserialize(ByteBuffer bytes, ClassTag<T> ev1) { return null; }
+        };
+      }
+    };
+    when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(byteArraySerializer));
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    // Insert a record and force a spill so that there's something to clean up:
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(new byte[1], new byte[1]));
+    writer.forceSorterToSpill();
+    // We should be able to write a record that's right _at_ the max record size
+    final byte[] atMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE];
+    new Random(42).nextBytes(atMaxRecordSize);
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(new byte[0], atMaxRecordSize));
+    writer.forceSorterToSpill();
+    // Inserting a record that's larger than the max record size should fail:
+    final byte[] exceedsMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE + 1];
+    new Random(42).nextBytes(exceedsMaxRecordSize);
+    Product2<Object, Object> hugeRecord =
+      new Tuple2<Object, Object>(new byte[0], exceedsMaxRecordSize);
+    try {
+      // Here, we write through the public `write()` interface instead of the test-only
+      // `insertRecordIntoSorter` interface:
+      writer.write(Collections.singletonList(hugeRecord).iterator());
+      fail("Expected exception to be thrown");
+    } catch (IOException e) {
+      // Pass
+    }
+    assertSpillFilesWereCleanedUp();
+  }
+
+  @Test
+  public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+    writer.forceSorterToSpill();
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+    writer.stop(false);
+    assertSpillFilesWereCleanedUp();
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
index 8c6035fb367fe965aba7fa8e3113ef32741b599c..cf6a143537889ef93b1d72188176a9943e55443f 100644
--- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
+++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.io
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
 
+import com.google.common.io.ByteStreams
 import org.scalatest.FunSuite
 
 import org.apache.spark.SparkConf
@@ -62,6 +63,14 @@ class CompressionCodecSuite extends FunSuite {
     testCodec(codec)
   }
 
+  test("lz4 does not support concatenation of serialized streams") {
+    val codec = CompressionCodec.createCodec(conf, classOf[LZ4CompressionCodec].getName)
+    assert(codec.getClass === classOf[LZ4CompressionCodec])
+    intercept[Exception] {
+      testConcatenationOfSerializedStreams(codec)
+    }
+  }
+
   test("lzf compression codec") {
     val codec = CompressionCodec.createCodec(conf, classOf[LZFCompressionCodec].getName)
     assert(codec.getClass === classOf[LZFCompressionCodec])
@@ -74,6 +83,12 @@ class CompressionCodecSuite extends FunSuite {
     testCodec(codec)
   }
 
+  test("lzf supports concatenation of serialized streams") {
+    val codec = CompressionCodec.createCodec(conf, classOf[LZFCompressionCodec].getName)
+    assert(codec.getClass === classOf[LZFCompressionCodec])
+    testConcatenationOfSerializedStreams(codec)
+  }
+
   test("snappy compression codec") {
     val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName)
     assert(codec.getClass === classOf[SnappyCompressionCodec])
@@ -86,9 +101,38 @@ class CompressionCodecSuite extends FunSuite {
     testCodec(codec)
   }
 
+  test("snappy does not support concatenation of serialized streams") {
+    val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName)
+    assert(codec.getClass === classOf[SnappyCompressionCodec])
+    intercept[Exception] {
+      testConcatenationOfSerializedStreams(codec)
+    }
+  }
+
   test("bad compression codec") {
     intercept[IllegalArgumentException] {
       CompressionCodec.createCodec(conf, "foobar")
     }
   }
+
+  private def testConcatenationOfSerializedStreams(codec: CompressionCodec): Unit = {
+    val bytes1: Array[Byte] = {
+      val baos = new ByteArrayOutputStream()
+      val out = codec.compressedOutputStream(baos)
+      (0 to 64).foreach(out.write)
+      out.close()
+      baos.toByteArray
+    }
+    val bytes2: Array[Byte] = {
+      val baos = new ByteArrayOutputStream()
+      val out = codec.compressedOutputStream(baos)
+      (65 to 127).foreach(out.write)
+      out.close()
+      baos.toByteArray
+    }
+    val concatenatedBytes = codec.compressedInputStream(new ByteArrayInputStream(bytes1 ++ bytes2))
+    val decompressed: Array[Byte] = new Array[Byte](128)
+    ByteStreams.readFully(concatenatedBytes, decompressed)
+    assert(decompressed.toSeq === (0 to 127))
+  }
 }
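The concatenation tests above pin down which codecs can decompress two independently compressed streams that were simply appended to one another (LZF can; LZ4 and Snappy, at the versions exercised here, cannot). That property is what allows a shuffle writer to merge spill files by copying raw compressed bytes instead of decompressing and recompressing them. The sketch below illustrates that merge idea only; the Spill case class and its offsets layout are hypothetical, not Spark APIs.

import java.io.{File, FileInputStream, FileOutputStream}

import com.google.common.io.ByteStreams

// Hypothetical spill layout for illustration: spill i stores the compressed bytes of
// reduce partition p in the byte range [offsets(p), offsets(p + 1)).
case class Spill(file: File, offsets: Array[Long])

object ConcatenationMergeSketch {
  /**
   * Merge spills by concatenating raw compressed partition segments. This is only safe
   * for codecs where concatenated compressed streams decompress to the concatenation of
   * the plaintexts, which is exactly the property the tests above verify.
   */
  def merge(spills: Seq[Spill], numPartitions: Int, outputFile: File): Unit = {
    val out = new FileOutputStream(outputFile)
    try {
      for (partition <- 0 until numPartitions; spill <- spills) {
        val in = new FileInputStream(spill.file)
        try {
          val start = spill.offsets(partition)
          val length = spill.offsets(partition + 1) - start
          // Seek to this partition's segment and copy its compressed bytes verbatim.
          ByteStreams.skipFully(in, start)
          ByteStreams.copy(ByteStreams.limit(in, length), out)
        } finally {
          in.close()
        }
      }
    } finally {
      out.close()
    }
  }
}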
diff --git a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..ed4d8ce632e162a8a2f43d82e15c30e1551f3b74
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import org.scalatest.FunSuite
+import org.apache.spark.SparkConf
+
+class JavaSerializerSuite extends FunSuite {
+  test("JavaSerializer instances are serializable") {
+    val serializer = new JavaSerializer(new SparkConf())
+    val instance = serializer.newInstance()
+    instance.deserialize[JavaSerializer](instance.serialize(serializer))
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..49a04a2a45280a896751775d4b4f2be37ede17cd
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe
+
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.{FunSuite, Matchers}
+
+import org.apache.spark._
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer}
+
+/**
+ * Tests for the fallback logic in UnsafeShuffleManager. Actual tests of shuffling data are
+ * performed in other suites.
+ */
+class UnsafeShuffleManagerSuite extends FunSuite with Matchers {
+
+  import UnsafeShuffleManager.canUseUnsafeShuffle
+
+  private class RuntimeExceptionAnswer extends Answer[Object] {
+    override def answer(invocation: InvocationOnMock): Object = {
+      throw new RuntimeException("Called non-stubbed method, " + invocation.getMethod.getName)
+    }
+  }
+
+  private def shuffleDep(
+      partitioner: Partitioner,
+      serializer: Option[Serializer],
+      keyOrdering: Option[Ordering[Any]],
+      aggregator: Option[Aggregator[Any, Any, Any]],
+      mapSideCombine: Boolean): ShuffleDependency[Any, Any, Any] = {
+    val dep = mock(classOf[ShuffleDependency[Any, Any, Any]], new RuntimeExceptionAnswer())
+    doReturn(0).when(dep).shuffleId
+    doReturn(partitioner).when(dep).partitioner
+    doReturn(serializer).when(dep).serializer
+    doReturn(keyOrdering).when(dep).keyOrdering
+    doReturn(aggregator).when(dep).aggregator
+    doReturn(mapSideCombine).when(dep).mapSideCombine
+    dep
+  }
+
+  test("supported shuffle dependencies") {
+    val kryo = Some(new KryoSerializer(new SparkConf()))
+
+    assert(canUseUnsafeShuffle(shuffleDep(
+      partitioner = new HashPartitioner(2),
+      serializer = kryo,
+      keyOrdering = None,
+      aggregator = None,
+      mapSideCombine = false
+    )))
+
+    val rangePartitioner = mock(classOf[RangePartitioner[Any, Any]])
+    when(rangePartitioner.numPartitions).thenReturn(2)
+    assert(canUseUnsafeShuffle(shuffleDep(
+      partitioner = rangePartitioner,
+      serializer = kryo,
+      keyOrdering = None,
+      aggregator = None,
+      mapSideCombine = false
+    )))
+
+  }
+
+  test("unsupported shuffle dependencies") {
+    val kryo = Some(new KryoSerializer(new SparkConf()))
+    val java = Some(new JavaSerializer(new SparkConf()))
+
+    // We only support serializers that support object relocation
+    assert(!canUseUnsafeShuffle(shuffleDep(
+      partitioner = new HashPartitioner(2),
+      serializer = java,
+      keyOrdering = None,
+      aggregator = None,
+      mapSideCombine = false
+    )))
+
+    // We do not support shuffles with more than 16 million output partitions
+    assert(!canUseUnsafeShuffle(shuffleDep(
+      partitioner = new HashPartitioner(UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS + 1),
+      serializer = kryo,
+      keyOrdering = None,
+      aggregator = None,
+      mapSideCombine = false
+    )))
+
+    // We do not support shuffles that perform any kind of aggregation or sorting of keys
+    assert(!canUseUnsafeShuffle(shuffleDep(
+      partitioner = new HashPartitioner(2),
+      serializer = kryo,
+      keyOrdering = Some(mock(classOf[Ordering[Any]])),
+      aggregator = None,
+      mapSideCombine = false
+    )))
+    assert(!canUseUnsafeShuffle(shuffleDep(
+      partitioner = new HashPartitioner(2),
+      serializer = kryo,
+      keyOrdering = None,
+      aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
+      mapSideCombine = false
+    )))
+    // Nor do we support shuffles that combine key ordering, aggregation, and map-side combine
+    assert(!canUseUnsafeShuffle(shuffleDep(
+      partitioner = new HashPartitioner(2),
+      serializer = kryo,
+      keyOrdering = Some(mock(classOf[Ordering[Any]])),
+      aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
+      mapSideCombine = true
+    )))
+  }
+
+}
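For reference while reading the suite, here is a hedged restatement of the eligibility checks these tests imply. It is not UnsafeShuffleManager's actual implementation: it assumes the dependency specifies a serializer explicitly (ignoring any default-serializer fallback), and it is placed under an org.apache.spark package only because supportsRelocationOfSerializedObjects is private[spark].

package org.apache.spark.shuffle.unsafe

import org.apache.spark.ShuffleDependency

object UnsafeShuffleEligibilitySketch {
  // The checks implied by the "supported/unsupported shuffle dependencies" tests above:
  // a relocation-friendly serializer, no aggregation or key ordering, and a partition
  // count no larger than MAX_SHUFFLE_OUTPUT_PARTITIONS.
  def eligible(dep: ShuffleDependency[_, _, _]): Boolean = {
    val serializerOk = dep.serializer.exists(_.supportsRelocationOfSerializedObjects)
    val noAggregation = dep.aggregator.isEmpty && !dep.mapSideCombine
    val noKeyOrdering = dep.keyOrdering.isEmpty
    val fewEnoughPartitions =
      dep.partitioner.numPartitions <= UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS
    serializerOk && noAggregation && noKeyOrdering && fewEnoughPartitions
  }
}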
diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..6351539e91e973f5268c170d5aa1513aada1cfbf
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe
+
+import java.io.File
+
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.FileUtils
+import org.apache.commons.io.filefilter.TrueFileFilter
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkContext, ShuffleSuite}
+import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.util.Utils
+
+class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
+
+  // This test suite should run all tests in ShuffleSuite with unsafe-based shuffle.
+
+  override def beforeAll() {
+    conf.set("spark.shuffle.manager", "tungsten-sort")
+    // UnsafeShuffleManager requires at least 128 MB of memory per task in order to be able to sort
+    // shuffle records.
+    conf.set("spark.shuffle.memoryFraction", "0.5")
+  }
+
+  test("UnsafeShuffleManager properly cleans up files for shuffles that use the new shuffle path") {
+    val tmpDir = Utils.createTempDir()
+    try {
+      val myConf = conf.clone()
+        .set("spark.local.dir", tmpDir.getAbsolutePath)
+      sc = new SparkContext("local", "test", myConf)
+      // Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path
+      val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+      val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+        .setSerializer(new KryoSerializer(myConf))
+      val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+      assert(UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
+      def getAllFiles: Set[File] =
+        FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+      val filesBeforeShuffle = getAllFiles
+      // Force the shuffle to be performed
+      shuffledRdd.count()
+      // Ensure that the shuffle actually created files that will need to be cleaned up
+      val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+      assert(filesCreatedByShuffle.map(_.getName) ===
+        Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
+      // Check that the cleanup actually removes the files
+      sc.env.blockManager.master.removeShuffle(0, blocking = true)
+      for (file <- filesCreatedByShuffle) {
+        assert (!file.exists(), s"Shuffle file $file was not cleaned up")
+      }
+    } finally {
+      Utils.deleteRecursively(tmpDir)
+    }
+  }
+
+  test("UnsafeShuffleManager properly cleans up files for shuffles that use the old shuffle path") {
+    val tmpDir = Utils.createTempDir()
+    try {
+      val myConf = conf.clone()
+        .set("spark.local.dir", tmpDir.getAbsolutePath)
+      sc = new SparkContext("local", "test", myConf)
+      // Create a shuffled RDD and verify that it will actually use the old SortShuffle path
+      val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+      val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+        .setSerializer(new JavaSerializer(myConf))
+      val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+      assert(!UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
+      def getAllFiles: Set[File] =
+        FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+      val filesBeforeShuffle = getAllFiles
+      // Force the shuffle to be performed
+      shuffledRdd.count()
+      // Ensure that the shuffle actually created files that will need to be cleaned up
+      val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+      assert(filesCreatedByShuffle.map(_.getName) ===
+        Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
+      // Check that the cleanup actually removes the files
+      sc.env.blockManager.master.removeShuffle(0, blocking = true)
+      for (file <- filesCreatedByShuffle) {
+        assert (!file.exists(), s"Shuffle file $file was not cleaned up")
+      }
+    } finally {
+      Utils.deleteRecursively(tmpDir)
+    }
+  }
+}
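Outside of tests, opting into the new path looks the same as the suite's beforeAll configuration. A minimal sketch mirroring the suite (the app name, data sizes, and partition counts are arbitrary, and the memory fraction is simply the value the suite itself uses):

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.serializer.KryoSerializer

object TungstenSortShuffleExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local")
      .setAppName("tungsten-sort-example")
      .set("spark.shuffle.manager", "tungsten-sort")
      .set("spark.shuffle.memoryFraction", "0.5")
    val sc = new SparkContext(conf)
    try {
      // As in the suite: a hash-partitioned, Kryo-serialized shuffle with no aggregation
      // or key ordering is eligible for the new write path.
      val rdd = sc.parallelize(1 to 100, 2).map(x => (x, x))
      val shuffled = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
        .setSerializer(new KryoSerializer(conf))
      println(shuffled.count())
    } finally {
      sc.stop()
    }
  }
}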
diff --git a/pom.xml b/pom.xml
index cf9279ea5a2a611e894664ab9f63e9aa3a949559..564a443466e5a1aea7bbdf260ed02b5db17a62c9 100644
--- a/pom.xml
+++ b/pom.xml
@@ -669,7 +669,7 @@
       <dependency>
         <groupId>org.mockito</groupId>
         <artifactId>mockito-all</artifactId>
-        <version>1.9.0</version>
+        <version>1.9.5</version>
         <scope>test</scope>
       </dependency>
       <dependency>
@@ -684,6 +684,18 @@
         <version>4.10</version>
         <scope>test</scope>
       </dependency>
+      <dependency>
+        <groupId>org.hamcrest</groupId>
+        <artifactId>hamcrest-core</artifactId>
+        <version>1.3</version>
+        <scope>test</scope>
+      </dependency>
+      <dependency>
+        <groupId>org.hamcrest</groupId>
+        <artifactId>hamcrest-library</artifactId>
+        <version>1.3</version>
+        <scope>test</scope>
+      </dependency>
       <dependency>
         <groupId>com.novocode</groupId>
         <artifactId>junit-interface</artifactId>
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index fba7290dcb0b54a4972295e8747f0d86eb33ce27..487062a31f77f310363146128343b7b20f082c77 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -131,6 +131,12 @@ object MimaExcludes {
             // SPARK-7530 Added StreamingContext.getState()
             ProblemFilters.exclude[MissingMethodProblem](
               "org.apache.spark.streaming.StreamingContext.state_=")
+          ) ++ Seq(
+            // SPARK-7081 changed ShuffleWriter from a trait to an abstract class and removed some
+            // unnecessary type bounds in order to fix some compiler warnings that occurred when
+            // implementing this interface in Java. Note that ShuffleWriter is private[spark].
+            ProblemFilters.exclude[IncompatibleTemplateDefProblem](
+              "org.apache.spark.shuffle.ShuffleWriter")
           )
 
         case v if v.startsWith("1.3") =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index c3d2c7019a54aa19fa0267ef88cb1be74a81cf1f..3e46596ecf6ace5efcfe98467799fa992615d002 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -17,17 +17,18 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.sql.{SQLContext, Row}
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.shuffle.unsafe.UnsafeShuffleManager
 import org.apache.spark.sql.catalyst.errors.attachTree
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.{SQLContext, Row}
 import org.apache.spark.util.MutablePair
 
 object Exchange {
@@ -85,7 +86,9 @@ case class Exchange(
     // corner-cases where a partitioner constructed with `numPartitions` partitions may output
     // fewer partitions (like RangePartitioner, for example).
     val conf = child.sqlContext.sparkContext.conf
-    val sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
+    val shuffleManager = SparkEnv.get.shuffleManager
+    val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] ||
+      shuffleManager.isInstanceOf[UnsafeShuffleManager]
     val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
     val serializeMapOutputs = conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true)
     if (newOrdering.nonEmpty) {
@@ -93,11 +96,11 @@ case class Exchange(
       // which requires a defensive copy.
       true
     } else if (sortBasedShuffleOn) {
-      // Spark's sort-based shuffle also uses `ExternalSorter` to buffer records in memory.
-      // However, there are two special cases where we can avoid the copy, described below:
-      if (partitioner.numPartitions <= bypassMergeThreshold) {
-        // If the number of output partitions is sufficiently small, then Spark will fall back to
-        // the old hash-based shuffle write path which doesn't buffer deserialized records.
+      val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
+      if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) {
+        // If we're using the original SortShuffleManager and the number of output partitions is
+        // sufficiently small, then Spark will fall back to the hash-based shuffle write path, which
+        // doesn't buffer deserialized records.
         // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass.
         false
       } else if (serializeMapOutputs && serializer.supportsRelocationOfSerializedObjects) {
@@ -105,9 +108,14 @@ case class Exchange(
         // them. This optimization is guarded by a feature-flag and is only applied in cases where
         // shuffle dependency does not specify an ordering and the record serializer has certain
         // properties. If this optimization is enabled, we can safely avoid the copy.
+        //
+        // This optimization also applies to UnsafeShuffleManager (added in SPARK-7081).
         false
       } else {
-        // None of the special cases held, so we must copy.
+        // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory. This code
+        // path is used both when SortShuffleManager is used and when UnsafeShuffleManager falls
+        // back to SortShuffleManager to perform a shuffle that the new fast path can't handle. In
+        // both cases, we must copy.
         true
       }
     } else {
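Restated outside the diff, the branch structure above decides whether Exchange must defensively copy rows before handing them to the shuffle. The sketch below uses parameter names local to the sketch rather than Spark's internals, and the final branch (non-sort-based shuffle, which falls outside the hunk shown) is assumed to return false.

object ExchangeCopyDecisionSketch {
  // serializerSupportsRelocation stands in for serializer.supportsRelocationOfSerializedObjects.
  def needToCopyBeforeShuffle(
      hasNewOrdering: Boolean,
      sortBasedShuffleOn: Boolean,
      usingOriginalSortShuffleManager: Boolean,
      numPartitions: Int,
      bypassMergeThreshold: Int,
      serializeMapOutputs: Boolean,
      serializerSupportsRelocation: Boolean): Boolean = {
    if (hasNewOrdering) {
      // First branch above: a new ordering requires a defensive copy.
      true
    } else if (sortBasedShuffleOn) {
      if (usingOriginalSortShuffleManager && numPartitions <= bypassMergeThreshold) {
        // Bypass path: records are written out immediately without buffering them.
        false
      } else if (serializeMapOutputs && serializerSupportsRelocation) {
        // Records are serialized as soon as they are inserted, so no copy is needed.
        false
      } else {
        // ExternalSorter buffers deserialized records, so a copy is required.
        true
      }
    } else {
      // Assumed: the non-sort-based branch (not shown in the hunk) does not buffer
      // deserialized records, so no copy is needed.
      false
    }
  }
}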
diff --git a/unsafe/pom.xml b/unsafe/pom.xml
index 5b0733206b2bc5de85b57cc45771cc55b8170cb3..9e151fc7a9141a2435e502beff4114ddccba364f 100644
--- a/unsafe/pom.xml
+++ b/unsafe/pom.xml
@@ -42,6 +42,10 @@
       <groupId>com.google.code.findbugs</groupId>
       <artifactId>jsr305</artifactId>
     </dependency>
+    <dependency>
+      <groupId>com.google.guava</groupId>
+      <artifactId>guava</artifactId>
+    </dependency>
 
     <!-- Provided dependencies -->
     <dependency>
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
index 9224988e6ad69cd0561a27618e7cd92979c6d101..2906ac8abad1a00b34ce2aa02fe711aa306737ab 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
@@ -19,6 +19,7 @@ package org.apache.spark.unsafe.memory;
 
 import java.util.*;
 
+import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -47,10 +48,18 @@ public final class TaskMemoryManager {
 
   private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class);
 
-  /**
-   * The number of entries in the page table.
-   */
-  private static final int PAGE_TABLE_SIZE = 1 << 13;
+  /** The number of bits used to address the page table. */
+  private static final int PAGE_NUMBER_BITS = 13;
+
+  /** The number of bits used to encode offsets in data pages. */
+  @VisibleForTesting
+  static final int OFFSET_BITS = 64 - PAGE_NUMBER_BITS;  // 51
+
+  /** The number of entries in the page table. */
+  private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS;
+
+  /** Maximum supported data page size */
+  private static final long MAXIMUM_PAGE_SIZE = (1L << OFFSET_BITS);
 
   /** Bit mask for the lower 51 bits of a long. */
   private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL;
@@ -101,11 +110,9 @@ public final class TaskMemoryManager {
    * intended for allocating large blocks of memory that will be shared between operators.
    */
   public MemoryBlock allocatePage(long size) {
-    if (logger.isTraceEnabled()) {
-      logger.trace("Allocating {} byte page", size);
-    }
-    if (size >= (1L << 51)) {
-      throw new IllegalArgumentException("Cannot allocate a page with more than 2^51 bytes");
+    if (size > MAXIMUM_PAGE_SIZE) {
+      throw new IllegalArgumentException(
+        "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE + " bytes");
     }
 
     final int pageNumber;
@@ -120,8 +127,8 @@ public final class TaskMemoryManager {
     final MemoryBlock page = executorMemoryManager.allocate(size);
     page.pageNumber = pageNumber;
     pageTable[pageNumber] = page;
-    if (logger.isDebugEnabled()) {
-      logger.debug("Allocate page number {} ({} bytes)", pageNumber, size);
+    if (logger.isTraceEnabled()) {
+      logger.trace("Allocate page number {} ({} bytes)", pageNumber, size);
     }
     return page;
   }
@@ -130,9 +137,6 @@ public final class TaskMemoryManager {
    * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}.
    */
   public void freePage(MemoryBlock page) {
-    if (logger.isTraceEnabled()) {
-      logger.trace("Freeing page number {} ({} bytes)", page.pageNumber, page.size());
-    }
     assert (page.pageNumber != -1) :
       "Called freePage() on memory that wasn't allocated with allocatePage()";
     executorMemoryManager.free(page);
@@ -140,8 +144,8 @@ public final class TaskMemoryManager {
       allocatedPages.clear(page.pageNumber);
     }
     pageTable[page.pageNumber] = null;
-    if (logger.isDebugEnabled()) {
-      logger.debug("Freed page number {} ({} bytes)", page.pageNumber, page.size());
+    if (logger.isTraceEnabled()) {
+      logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size());
     }
   }
 
@@ -173,14 +177,36 @@ public final class TaskMemoryManager {
   /**
    * Given a memory page and offset within that page, encode this address into a 64-bit long.
    * This address will remain valid as long as the corresponding page has not been freed.
+   *
+   * @param page a data page allocated by {@link TaskMemoryManager#allocatePage(long)}.
+   * @param offsetInPage an offset in this page which incorporates the base offset. In other words,
+   *                     this should be the value that you would pass as the base offset into an
+   *                     UNSAFE call (e.g. page.baseOffset() + something).
+   * @return an encoded page address.
    */
   public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) {
-    if (inHeap) {
-      assert (page.pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
-      return (((long) page.pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
-    } else {
-      return offsetInPage;
+    if (!inHeap) {
+      // In off-heap mode, an offset is an absolute address that may require a full 64 bits to
+      // encode. Due to our page size limitation, though, we can convert this into an offset that's
+      // relative to the page's base offset; this relative offset will fit in 51 bits.
+      offsetInPage -= page.getBaseOffset();
     }
+    return encodePageNumberAndOffset(page.pageNumber, offsetInPage);
+  }
+
+  @VisibleForTesting
+  public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) {
+    assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
+    return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
+  }
+
+  @VisibleForTesting
+  public static int decodePageNumber(long pagePlusOffsetAddress) {
+    return (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS);
+  }
+
+  private static long decodeOffset(long pagePlusOffsetAddress) {
+    return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS);
   }
 
   /**
@@ -189,7 +215,7 @@ public final class TaskMemoryManager {
    */
   public Object getPage(long pagePlusOffsetAddress) {
     if (inHeap) {
-      final int pageNumber = (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> 51);
+      final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
       assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
       final Object page = pageTable[pageNumber].getBaseObject();
       assert (page != null);
@@ -204,10 +230,15 @@ public final class TaskMemoryManager {
    * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
    */
   public long getOffsetInPage(long pagePlusOffsetAddress) {
+    final long offsetInPage = decodeOffset(pagePlusOffsetAddress);
     if (inHeap) {
-      return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS);
+      return offsetInPage;
     } else {
-      return pagePlusOffsetAddress;
+      // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we
+      // converted the absolute address into a relative address. Here, we invert that operation:
+      final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
+      assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
+      return pageTable[pageNumber].getBaseOffset() + offsetInPage;
     }
   }
 
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
index 932882f1ca24817147f060b9e1ae14ae31c7d408..06fb0811836593383f9b68c934a59451ec3af1aa 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
@@ -38,4 +38,27 @@ public class TaskMemoryManagerSuite {
     Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory());
   }
 
+  @Test
+  public void encodePageNumberAndOffsetOffHeap() {
+    final TaskMemoryManager manager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
+    final MemoryBlock dataPage = manager.allocatePage(256);
+    // In off-heap mode, an offset is an absolute address that may require more than 51 bits to
+    // encode. This test exercises that corner-case:
+    final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10);
+    final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, offset);
+    Assert.assertEquals(null, manager.getPage(encodedAddress));
+    Assert.assertEquals(offset, manager.getOffsetInPage(encodedAddress));
+  }
+
+  @Test
+  public void encodePageNumberAndOffsetOnHeap() {
+    final TaskMemoryManager manager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+    final MemoryBlock dataPage = manager.allocatePage(256);
+    final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64);
+    Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress));
+    Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress));
+  }
+
 }