diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index bf5f965a9d8dc748854ee63964983aff06eac13b..dec7fcfa0ddc10d21d97618df9405ca678826a82 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -428,7 +428,7 @@ public final class UnsafeExternalSorter {
 
   public UnsafeSorterIterator getSortedIterator() throws IOException {
     assert(inMemSorter != null);
-    final UnsafeSorterIterator inMemoryIterator = inMemSorter.getSortedIterator();
+    final UnsafeInMemorySorter.SortedIterator inMemoryIterator = inMemSorter.getSortedIterator();
     int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);
     if (spillWriters.isEmpty()) {
       return inMemoryIterator;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 31314653919003c4bc53f9cd774325557b442b73..1e4b8a116e11af37958a412990d5da2199caaab6 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -133,7 +133,7 @@ public final class UnsafeInMemorySorter {
     pointerArrayInsertPosition++;
   }
 
-  private static final class SortedIterator extends UnsafeSorterIterator {
+  public static final class SortedIterator extends UnsafeSorterIterator {
 
     private final TaskMemoryManager memoryManager;
     private final int sortBufferInsertPosition;
@@ -144,7 +144,7 @@ public final class UnsafeInMemorySorter {
     private long keyPrefix;
     private int recordLength;
 
-    SortedIterator(
+    private SortedIterator(
         TaskMemoryManager memoryManager,
         int sortBufferInsertPosition,
         long[] sortBuffer) {
@@ -186,7 +186,7 @@ public final class UnsafeInMemorySorter {
    * Return an iterator over record pointers in sorted order. For efficiency, all calls to
    * {@code next()} will return the same mutable object.
    */
-  public UnsafeSorterIterator getSortedIterator() {
+  public SortedIterator getSortedIterator() {
     sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator);
     return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray);
   }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index f6b017686306b6fb02e2bb95dd5b0b223437f8e7..312ec8ea0dd9dc7681d9b06b2f1991c168dfe037 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -134,7 +134,7 @@ public final class UnsafeKVExternalSorter {
       value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), prefix);
   }
 
-  public KVIterator<UnsafeRow, UnsafeRow> sortedIterator() throws IOException {
+  public KVSorterIterator sortedIterator() throws IOException {
     try {
       final UnsafeSorterIterator underlying = sorter.getSortedIterator();
       if (!underlying.hasNext()) {
@@ -142,58 +142,7 @@ public final class UnsafeKVExternalSorter {
         // here in order to prevent memory leaks.
         cleanupResources();
       }
-
-      return new KVIterator<UnsafeRow, UnsafeRow>() {
-        private UnsafeRow key = new UnsafeRow();
-        private UnsafeRow value = new UnsafeRow();
-        private int numKeyFields = keySchema.size();
-        private int numValueFields = valueSchema.size();
-
-        @Override
-        public boolean next() throws IOException {
-          try {
-            if (underlying.hasNext()) {
-              underlying.loadNext();
-
-              Object baseObj = underlying.getBaseObject();
-              long recordOffset = underlying.getBaseOffset();
-              int recordLen = underlying.getRecordLength();
-
-              // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
-              int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset);
-              int valueLen = recordLen - keyLen - 4;
-
-              key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
-              value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen);
-
-              return true;
-            } else {
-              key = null;
-              value = null;
-              cleanupResources();
-              return false;
-            }
-          } catch (IOException e) {
-            cleanupResources();
-            throw e;
-          }
-        }
-
-        @Override
-        public UnsafeRow getKey() {
-          return key;
-        }
-
-        @Override
-        public UnsafeRow getValue() {
-          return value;
-        }
-
-        @Override
-        public void close() {
-          cleanupResources();
-        }
-      };
+      return new KVSorterIterator(underlying);
     } catch (IOException e) {
       cleanupResources();
       throw e;
@@ -233,4 +182,61 @@ public final class UnsafeKVExternalSorter {
       return ordering.compare(row1, row2);
     }
   }
+
+  public class KVSorterIterator extends KVIterator<UnsafeRow, UnsafeRow> {
+    private UnsafeRow key = new UnsafeRow();
+    private UnsafeRow value = new UnsafeRow();
+    private final int numKeyFields = keySchema.size();
+    private final int numValueFields = valueSchema.size();
+    private final UnsafeSorterIterator underlying;
+
+    private KVSorterIterator(UnsafeSorterIterator underlying) {
+      this.underlying = underlying;
+    }
+
+    @Override
+    public boolean next() throws IOException {
+      try {
+        if (underlying.hasNext()) {
+          underlying.loadNext();
+
+          Object baseObj = underlying.getBaseObject();
+          long recordOffset = underlying.getBaseOffset();
+          int recordLen = underlying.getRecordLength();
+
+          // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
+          int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset);
+          int valueLen = recordLen - keyLen - 4;
+
+          key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
+          value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen);
+
+          return true;
+        } else {
+          key = null;
+          value = null;
+          cleanupResources();
+          return false;
+        }
+      } catch (IOException e) {
+        cleanupResources();
+        throw e;
+      }
+    }
+
+    @Override
+    public UnsafeRow getKey() {
+      return key;
+    }
+
+    @Override
+    public UnsafeRow getValue() {
+      return value;
+    }
+
+    @Override
+    public void close() {
+      cleanupResources();
+    }
+  };
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala
index 37d34eb7ccf09def949af717448416a6d8591122..b465787fe8cbd38a73bc1c618cdf04a73557f2be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala
@@ -17,12 +17,12 @@
 
 package org.apache.spark.sql.execution.aggregate
 
-import org.apache.spark.sql.execution.{UnsafeKeyValueSorter, UnsafeFixedWidthAggregationMap}
 import org.apache.spark.unsafe.KVIterator
 import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap}
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -230,7 +230,7 @@ class UnsafeHybridAggregationIterator(
     }
 
     // Step 5: Get the sorted iterator from the externalSorter.
-    val sortedKVIterator: KVIterator[UnsafeRow, UnsafeRow] = externalSorter.sortedIterator()
+    val sortedKVIterator: UnsafeKVExternalSorter#KVSorterIterator = externalSorter.sortedIterator()
 
     // Step 6: We now create a SortBasedAggregationIterator based on sortedKVIterator.
     // For a aggregate function with mode Partial, its mode in the SortBasedAggregationIterator
@@ -368,31 +368,5 @@ object UnsafeHybridAggregationIterator {
       newMutableProjection,
       outputsUnsafeRows)
   }
-
-  def createFromKVIterator(
-      groupingKeyAttributes: Seq[Attribute],
-      valueAttributes: Seq[Attribute],
-      inputKVIterator: KVIterator[UnsafeRow, InternalRow],
-      nonCompleteAggregateExpressions: Seq[AggregateExpression2],
-      nonCompleteAggregateAttributes: Seq[Attribute],
-      completeAggregateExpressions: Seq[AggregateExpression2],
-      completeAggregateAttributes: Seq[Attribute],
-      initialInputBufferOffset: Int,
-      resultExpressions: Seq[NamedExpression],
-      newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
-      outputsUnsafeRows: Boolean): UnsafeHybridAggregationIterator = {
-    new UnsafeHybridAggregationIterator(
-      groupingKeyAttributes,
-      valueAttributes,
-      inputKVIterator,
-      nonCompleteAggregateExpressions,
-      nonCompleteAggregateAttributes,
-      completeAggregateExpressions,
-      completeAggregateAttributes,
-      initialInputBufferOffset,
-      resultExpressions,
-      newMutableProjection,
-      outputsUnsafeRows)
-  }
   // scalastyle:on
 }