diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
index e3e79471154df896afc6307c5f60251b53b38526..1bc924d424c02718a8dd1aa032b21d4de419606b 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
@@ -51,6 +51,6 @@ public class MemoryBlock extends MemoryLocation {
    * Creates a memory block pointing to the memory used by the long array.
    */
   public static MemoryBlock fromLongArray(final long[] array) {
-    return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8);
+    return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8L);
   }
 }
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
index 75a0e807d76f5c7a7e4b5be5c3a2b78fef7ec2d0..dc36809d8911fc570d95cd93e80464f048ed3a7a 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
@@ -22,12 +22,12 @@ import java.util.Comparator;
 import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.util.collection.Sorter;
 import org.apache.spark.util.collection.unsafe.sort.RadixSort;
 
 final class ShuffleInMemorySorter {
 
-  private final Sorter<PackedRecordPointer, LongArray> sorter;
   private static final class SortComparator implements Comparator<PackedRecordPointer> {
     @Override
     public int compare(PackedRecordPointer left, PackedRecordPointer right) {
@@ -44,6 +44,9 @@ final class ShuffleInMemorySorter {
    * An array of record pointers and partition ids that have been encoded by
    * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating
    * records.
+   *
+   * Only part of the array will be used to store the pointers, the rest part is preserved as
+   * temporary buffer for sorting.
    */
   private LongArray array;
 
@@ -54,14 +57,14 @@ final class ShuffleInMemorySorter {
   private final boolean useRadixSort;
 
   /**
-   * Set to 2x for radix sort to reserve extra memory for sorting, otherwise 1x.
+   * The position in the pointer array where new records can be inserted.
    */
-  private final int memoryAllocationFactor;
+  private int pos = 0;
 
   /**
-   * The position in the pointer array where new records can be inserted.
+   * How many records could be inserted, because part of the array should be left for sorting.
    */
-  private int pos = 0;
+  private int usableCapacity = 0;
 
   private int initialSize;
 
@@ -70,9 +73,14 @@ final class ShuffleInMemorySorter {
     assert (initialSize > 0);
     this.initialSize = initialSize;
     this.useRadixSort = useRadixSort;
-    this.memoryAllocationFactor = useRadixSort ? 2 : 1;
     this.array = consumer.allocateArray(initialSize);
-    this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE);
+    this.usableCapacity = getUsableCapacity();
+  }
+
+  private int getUsableCapacity() {
+    // Radix sort requires same amount of used memory as buffer, Tim sort requires
+    // half of the used memory as buffer.
+    return (int) (array.size() / (useRadixSort ? 2 : 1.5));
   }
 
   public void free() {
@@ -89,7 +97,8 @@ final class ShuffleInMemorySorter {
   public void reset() {
     if (consumer != null) {
       consumer.freeArray(array);
-      this.array = consumer.allocateArray(initialSize);
+      array = consumer.allocateArray(initialSize);
+      usableCapacity = getUsableCapacity();
     }
     pos = 0;
   }
@@ -101,14 +110,15 @@ final class ShuffleInMemorySorter {
       array.getBaseOffset(),
       newArray.getBaseObject(),
       newArray.getBaseOffset(),
-      array.size() * (8 / memoryAllocationFactor)
+      pos * 8L
     );
     consumer.freeArray(array);
     array = newArray;
+    usableCapacity = getUsableCapacity();
   }
 
   public boolean hasSpaceForAnotherRecord() {
-    return pos < array.size() / memoryAllocationFactor;
+    return pos < usableCapacity;
   }
 
   public long getMemoryUsage() {
@@ -170,6 +180,14 @@ final class ShuffleInMemorySorter {
         PackedRecordPointer.PARTITION_ID_START_BYTE_INDEX,
         PackedRecordPointer.PARTITION_ID_END_BYTE_INDEX, false, false);
     } else {
+      MemoryBlock unused = new MemoryBlock(
+        array.getBaseObject(),
+        array.getBaseOffset() + pos * 8L,
+        (array.size() - pos) * 8L);
+      LongArray buffer = new LongArray(unused);
+      Sorter<PackedRecordPointer, LongArray> sorter =
+        new Sorter<>(new ShuffleSortDataFormat(buffer));
+
       sorter.sort(array, 0, pos, SORT_COMPARATOR);
     }
     return new ShuffleSorterIterator(pos, array, offset);
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
index 1e924d2aec442fb1058f1adb0d9b072c019aa6ef..717bdd79d47ef2a64d586575de72ce97536b9417 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
@@ -19,14 +19,15 @@ package org.apache.spark.shuffle.sort;
 
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.LongArray;
-import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.util.collection.SortDataFormat;
 
 final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, LongArray> {
 
-  public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat();
+  private final LongArray buffer;
 
-  private ShuffleSortDataFormat() { }
+  ShuffleSortDataFormat(LongArray buffer) {
+    this.buffer = buffer;
+  }
 
   @Override
   public PackedRecordPointer getKey(LongArray data, int pos) {
@@ -70,8 +71,8 @@ final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, Lo
 
   @Override
   public LongArray allocate(int length) {
-    // This buffer is used temporary (usually small), so it's fine to allocated from JVM heap.
-    return new LongArray(MemoryBlock.fromLongArray(new long[length]));
+    assert (length <= buffer.size()) :
+      "the buffer is smaller than required: " + buffer.size() + " < " + length;
+    return buffer;
   }
-
 }
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 6c00608302c4e2e056a53d1c24ec723a88cf1f81..dc04025692909632a95bfb872d0139f88183f01b 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -221,7 +221,8 @@ public final class BytesToBytesMap extends MemoryConsumer {
       SparkEnv.get() != null ? SparkEnv.get().blockManager() :  null,
       SparkEnv.get() != null ? SparkEnv.get().serializerManager() :  null,
       initialCapacity,
-      0.70,
+      // In order to re-use the longArray for sorting, the load factor cannot be larger than 0.5.
+      0.5,
       pageSizeBytes,
       enablePerfMetrics);
   }
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 0cce792f33d34bd36a996ea1a7f05e422b3f8276..c7b070f519f88766cb2eb268decb6a5e0d38497a 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -25,6 +25,7 @@ import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.util.collection.Sorter;
 
 /**
@@ -69,8 +70,6 @@ public final class UnsafeInMemorySorter {
   private final MemoryConsumer consumer;
   private final TaskMemoryManager memoryManager;
   @Nullable
-  private final Sorter<RecordPointerAndKeyPrefix, LongArray> sorter;
-  @Nullable
   private final Comparator<RecordPointerAndKeyPrefix> sortComparator;
 
   /**
@@ -79,14 +78,12 @@ public final class UnsafeInMemorySorter {
   @Nullable
   private final PrefixComparators.RadixSortSupport radixSortSupport;
 
-  /**
-   * Set to 2x for radix sort to reserve extra memory for sorting, otherwise 1x.
-   */
-  private final int memoryAllocationFactor;
-
   /**
    * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at
    * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
+   *
+   * Only part of the array will be used to store the pointers, the rest part is preserved as
+   * temporary buffer for sorting.
    */
   private LongArray array;
 
@@ -95,6 +92,11 @@ public final class UnsafeInMemorySorter {
    */
   private int pos = 0;
 
+  /**
+   * How many records could be inserted, because part of the array should be left for sorting.
+   */
+  private int usableCapacity = 0;
+
   private long initialSize;
 
   private long totalSortTimeNanos = 0L;
@@ -121,7 +123,6 @@ public final class UnsafeInMemorySorter {
     this.memoryManager = memoryManager;
     this.initialSize = array.size();
     if (recordComparator != null) {
-      this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
       this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
       if (canUseRadixSort && prefixComparator instanceof PrefixComparators.RadixSortSupport) {
         this.radixSortSupport = (PrefixComparators.RadixSortSupport)prefixComparator;
@@ -129,12 +130,17 @@ public final class UnsafeInMemorySorter {
         this.radixSortSupport = null;
       }
     } else {
-      this.sorter = null;
       this.sortComparator = null;
       this.radixSortSupport = null;
     }
-    this.memoryAllocationFactor = this.radixSortSupport != null ? 2 : 1;
     this.array = array;
+    this.usableCapacity = getUsableCapacity();
+  }
+
+  private int getUsableCapacity() {
+    // Radix sort requires same amount of used memory as buffer, Tim sort requires
+    // half of the used memory as buffer.
+    return (int) (array.size() / (radixSortSupport != null ? 2 : 1.5));
   }
 
   /**
@@ -150,7 +156,8 @@ public final class UnsafeInMemorySorter {
   public void reset() {
     if (consumer != null) {
       consumer.freeArray(array);
-      this.array = consumer.allocateArray(initialSize);
+      array = consumer.allocateArray(initialSize);
+      usableCapacity = getUsableCapacity();
     }
     pos = 0;
   }
@@ -174,7 +181,7 @@ public final class UnsafeInMemorySorter {
   }
 
   public boolean hasSpaceForAnotherRecord() {
-    return pos + 1 < (array.size() / memoryAllocationFactor);
+    return pos + 1 < usableCapacity;
   }
 
   public void expandPointerArray(LongArray newArray) {
@@ -186,9 +193,10 @@ public final class UnsafeInMemorySorter {
       array.getBaseOffset(),
       newArray.getBaseObject(),
       newArray.getBaseOffset(),
-      array.size() * (8 / memoryAllocationFactor));
+      pos * 8L);
     consumer.freeArray(array);
     array = newArray;
+    usableCapacity = getUsableCapacity();
   }
 
   /**
@@ -275,13 +283,20 @@ public final class UnsafeInMemorySorter {
   public SortedIterator getSortedIterator() {
     int offset = 0;
     long start = System.nanoTime();
-    if (sorter != null) {
+    if (sortComparator != null) {
       if (this.radixSortSupport != null) {
         // TODO(ekl) we should handle NULL values before radix sort for efficiency, since they
         // force a full-width sort (and we cannot radix-sort nullable long fields at all).
         offset = RadixSort.sortKeyPrefixArray(
           array, pos / 2, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned());
       } else {
+        MemoryBlock unused = new MemoryBlock(
+          array.getBaseObject(),
+          array.getBaseOffset() + pos * 8L,
+          (array.size() - pos) * 8L);
+        LongArray buffer = new LongArray(unused);
+        Sorter<RecordPointerAndKeyPrefix, LongArray> sorter =
+          new Sorter<>(new UnsafeSortDataFormat(buffer));
         sorter.sort(array, 0, pos / 2, sortComparator);
       }
     }
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
index 7bda76907f4c352534311fac32256a88ad94f0e1..430bf677edbdf58277a08084bdde0c5007318e35 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
@@ -19,7 +19,6 @@ package org.apache.spark.util.collection.unsafe.sort;
 
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.LongArray;
-import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.util.collection.SortDataFormat;
 
 /**
@@ -32,9 +31,11 @@ import org.apache.spark.util.collection.SortDataFormat;
 public final class UnsafeSortDataFormat
   extends SortDataFormat<RecordPointerAndKeyPrefix, LongArray> {
 
-  public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat();
+  private final LongArray buffer;
 
-  private UnsafeSortDataFormat() { }
+  public UnsafeSortDataFormat(LongArray buffer) {
+    this.buffer = buffer;
+  }
 
   @Override
   public RecordPointerAndKeyPrefix getKey(LongArray data, int pos) {
@@ -83,9 +84,9 @@ public final class UnsafeSortDataFormat
 
   @Override
   public LongArray allocate(int length) {
-    assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large";
-    // This is used as temporary buffer, it's fine to allocate from JVM heap.
-    return new LongArray(MemoryBlock.fromLongArray(new long[length * 2]));
+    assert (length * 2 <= buffer.size()) :
+      "the buffer is smaller than required: " + buffer.size() + " < " + (length * 2);
+    return buffer;
   }
 
 }
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index f9dc20d8b751bae20bec5d65491d4b9fa1b103e1..7dd61f85abefd40a64597c6587c545a5164deb57 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -21,12 +21,15 @@ import java.io.*;
 import java.nio.ByteBuffer;
 import java.util.*;
 
-import scala.*;
+import scala.Option;
+import scala.Product2;
+import scala.Tuple2;
+import scala.Tuple2$;
 import scala.collection.Iterator;
 import scala.runtime.AbstractFunction1;
 
-import com.google.common.collect.Iterators;
 import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Iterators;
 import com.google.common.io.ByteStreams;
 import org.junit.After;
 import org.junit.Before;
@@ -35,29 +38,33 @@ import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.greaterThan;
-import static org.hamcrest.Matchers.lessThan;
-import static org.junit.Assert.*;
-import static org.mockito.Answers.RETURNS_SMART_NULLS;
-import static org.mockito.Mockito.*;
 
-import org.apache.spark.*;
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.ShuffleDependency;
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
 import org.apache.spark.io.CompressionCodec$;
 import org.apache.spark.io.LZ4CompressionCodec;
 import org.apache.spark.io.LZFCompressionCodec;
 import org.apache.spark.io.SnappyCompressionCodec;
-import org.apache.spark.executor.ShuffleWriteMetrics;
-import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.memory.TestMemoryManager;
 import org.apache.spark.network.util.LimitedInputStream;
-import org.apache.spark.serializer.*;
 import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.serializer.*;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.storage.*;
-import org.apache.spark.memory.TestMemoryManager;
-import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.util.Utils;
 
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.lessThan;
+import static org.junit.Assert.*;
+import static org.mockito.Answers.RETURNS_SMART_NULLS;
+import static org.mockito.Mockito.*;
+
 public class UnsafeShuffleWriterSuite {
 
   static final int NUM_PARTITITONS = 4;
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 84b82f5a4742cdce097e1e08d34d141ee6cbe559..fc127f07c8d690a418fb4aebcb5e5e807e7539b1 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -589,7 +589,7 @@ public abstract class AbstractBytesToBytesMapSuite {
   @Test
   public void multipleValuesForSameKey() {
     BytesToBytesMap map =
-      new BytesToBytesMap(taskMemoryManager, blockManager, serializerManager, 1, 0.75, 1024, false);
+      new BytesToBytesMap(taskMemoryManager, blockManager, serializerManager, 1, 0.5, 1024, false);
     try {
       int i;
       for (i = 0; i < 1024; i++) {
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index 699f7fa1f272755041da0e27402d8eebd056b933..6bcc601e13ecc1b34393b3aae02e586d14f0bb1b 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -106,8 +106,10 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
     // that can trigger copyRange() in TimSort.mergeLo() or TimSort.mergeHi()
     val ref = Array.tabulate[Long](size) { i => if (i < size / 2) size / 2 + i else i }
     val buf = new LongArray(MemoryBlock.fromLongArray(ref))
+    val tmp = new Array[Long](size/2)
+    val tmpBuf = new LongArray(MemoryBlock.fromLongArray(tmp))
 
-    new Sorter(UnsafeSortDataFormat.INSTANCE).sort(
+    new Sorter(new UnsafeSortDataFormat(tmpBuf)).sort(
       buf, 0, size, new Comparator[RecordPointerAndKeyPrefix] {
         override def compare(
             r1: RecordPointerAndKeyPrefix,
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
index def0752b46f6af4618688ebe4e1481d4feda46c0..1d26d4a8307cfa80b1a8f7e14da41e809149d878 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
@@ -93,7 +93,8 @@ class RadixSortSuite extends SparkFunSuite with Logging {
   }
 
   private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) {
-    new Sorter(UnsafeSortDataFormat.INSTANCE).sort(
+    val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt)))
+    new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(
       buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] {
         override def compare(
             r1: RecordPointerAndKeyPrefix,
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 38dbfef76caeee19147ac596d5a9fef69772d89f..bb823cd07be5eeaeaeb82a5c1cd0ff9232970c79 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -73,6 +73,8 @@ public final class UnsafeKVExternalSorter {
     PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(keySchema);
     BaseOrdering ordering = GenerateOrdering.create(keySchema);
     KVComparator recordComparator = new KVComparator(ordering, keySchema.length());
+    boolean canUseRadixSort = keySchema.length() == 1 &&
+      SortPrefixUtils.canSortFullyWithPrefix(keySchema.apply(0));
 
     TaskMemoryManager taskMemoryManager = taskContext.taskMemoryManager();
 
@@ -86,14 +88,16 @@ public final class UnsafeKVExternalSorter {
         prefixComparator,
         /* initialSize */ 4096,
         pageSizeBytes,
-        keySchema.length() == 1 && SortPrefixUtils.canSortFullyWithPrefix(keySchema.apply(0)));
+        canUseRadixSort);
     } else {
+      // The array will be used to do in-place sort, which require half of the space to be empty.
+      assert(map.numKeys() <= map.getArray().size() / 2);
       // During spilling, the array in map will not be used, so we can borrow that and use it
       // as the underline array for in-memory sorter (it's always large enough).
       // Since we will not grow the array, it's fine to pass `null` as consumer.
       final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
         null, taskMemoryManager, recordComparator, prefixComparator, map.getArray(),
-        false /* TODO(ekl) we can only radix sort if the BytesToBytes load factor is <= 0.5 */);
+        canUseRadixSort);
 
       // We cannot use the destructive iterator here because we are reusing the existing memory
       // pages in BytesToBytesMap to hold records during sorting.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index cd6b97a855412fbfd4df9c709ea8868657442c91..412e8c54ca308028519d59c37811c898442c8f20 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -540,7 +540,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
       Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
         cursor - Platform.LONG_ARRAY_OFFSET)
       page = newPage
-      freeMemory(used * 8)
+      freeMemory(used * 8L)
     }
 
     // copy the bytes of UnsafeRow
@@ -599,7 +599,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
       i += 2
     }
     old_array = null  // release the reference to old array
-    freeMemory(n * 8)
+    freeMemory(n * 8L)
   }
 
   /**
@@ -610,7 +610,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
     // Convert to dense mode if it does not require more memory or could fit within L1 cache
     if (range < array.length || range < 1024) {
       try {
-        ensureAcquireMemory((range + 1) * 8)
+        ensureAcquireMemory((range + 1) * 8L)
       } catch {
         case e: SparkException =>
           // there is no enough memory to convert
@@ -628,7 +628,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
       val old_length = array.length
       array = denseArray
       isDense = true
-      freeMemory(old_length * 8)
+      freeMemory(old_length * 8L)
     }
   }
 
@@ -637,11 +637,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
    */
   def free(): Unit = {
     if (page != null) {
-      freeMemory(page.length * 8)
+      freeMemory(page.length * 8L)
       page = null
     }
     if (array != null) {
-      freeMemory(array.length * 8)
+      freeMemory(array.length * 8L)
       array = null
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
index 0e1868dd665655b06703bde7af6d24a418a993f1..9964b7373fc20891431c2b0c9a1ac66ea5b7f376 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
@@ -36,7 +36,8 @@ import org.apache.spark.util.random.XORShiftRandom
 class SortBenchmark extends BenchmarkBase {
 
   private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) {
-    new Sorter(UnsafeSortDataFormat.INSTANCE).sort(
+    val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt)))
+    new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(
       buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] {
         override def compare(
           r1: RecordPointerAndKeyPrefix,