diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 3b6200e74f1e11d42eafcea8f11bca68873bcfe7..610ace30f8a62b722276a3b7654a0f36de4ce5ab 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -258,6 +258,11 @@ public final class BytesToBytesMap extends MemoryConsumer { this.destructive = destructive; if (destructive) { destructiveIterator = this; + // longArray will not be used anymore if destructive is true, release it now. + if (longArray != null) { + freeArray(longArray); + longArray = null; + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 50d8e3024598dd8c713ac7b9ec4089b183587af6..d194f58cd1cdd27dc7d0309047528e160443903f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -127,9 +127,10 @@ class UnsafeFixedWidthAggregationMapSuite PAGE_SIZE_BYTES ) val groupKey = InternalRow(UTF8String.fromString("cats")) + val row = map.getAggregationBuffer(groupKey) // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts) - assert(map.getAggregationBuffer(groupKey) != null) + assert(row != null) val iter = map.iterator() assert(iter.next()) iter.getKey.getString(0) should be ("cats") @@ -138,7 +139,7 @@ class UnsafeFixedWidthAggregationMapSuite // Modifications to rows retrieved from the map should update the values in the map iter.getValue.setInt(0, 42) - map.getAggregationBuffer(groupKey).getInt(0) should be (42) + row.getInt(0) should be (42) map.free() }