diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 3b6200e74f1e11d42eafcea8f11bca68873bcfe7..610ace30f8a62b722276a3b7654a0f36de4ce5ab 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -258,6 +258,11 @@ public final class BytesToBytesMap extends MemoryConsumer {
       this.destructive = destructive;
       if (destructive) {
         destructiveIterator = this;
+        // longArray will not be used anymore if destructive is true, release it now.
+        if (longArray != null) {
+          freeArray(longArray);
+          longArray = null;
+        }
       }
     }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 50d8e3024598dd8c713ac7b9ec4089b183587af6..d194f58cd1cdd27dc7d0309047528e160443903f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -127,9 +127,10 @@ class UnsafeFixedWidthAggregationMapSuite
       PAGE_SIZE_BYTES
     )
     val groupKey = InternalRow(UTF8String.fromString("cats"))
+    val row = map.getAggregationBuffer(groupKey)
 
     // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts)
-    assert(map.getAggregationBuffer(groupKey) != null)
+    assert(row != null)
     val iter = map.iterator()
     assert(iter.next())
     iter.getKey.getString(0) should be ("cats")
@@ -138,7 +139,7 @@ class UnsafeFixedWidthAggregationMapSuite
 
     // Modifications to rows retrieved from the map should update the values in the map
     iter.getValue.setInt(0, 42)
-    map.getAggregationBuffer(groupKey).getInt(0) should be (42)
+    row.getInt(0) should be (42)
 
     map.free()
   }