diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index 6e47f9d6111999b1c2d27b77b7453a884745a2b9..eef2c4e843f35bb92d512aefe8494f4c89a7a9c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -70,10 +70,6 @@ class ObjectAggregationIterator( generateProcessRow(newExpressions, newFunctions, newInputAttributes) } - // A safe projection used to do deep clone of input rows to prevent false sharing. - private[this] val safeProjection: Projection = - FromUnsafeProjection(outputAttributes.map(_.dataType)) - /** * Start processing input rows. */ @@ -151,12 +147,11 @@ class ObjectAggregationIterator( val groupingKey = groupingProjection.apply(null) val buffer: InternalRow = getAggregationBufferByKey(hashMap, groupingKey) while (inputRows.hasNext) { - val newInput = safeProjection(inputRows.next()) - processRow(buffer, newInput) + processRow(buffer, inputRows.next()) } } else { while (inputRows.hasNext && !sortBased) { - val newInput = safeProjection(inputRows.next()) + val newInput = inputRows.next() val groupingKey = groupingProjection.apply(newInput) val buffer: InternalRow = getAggregationBufferByKey(hashMap, groupingKey) processRow(buffer, newInput) @@ -266,9 +261,7 @@ class SortBasedAggregator( // Firstly, update the aggregation buffer with input rows. while (hasNextInput && groupingKeyOrdering.compare(inputIterator.getKey, groupingKey) == 0) { - // Since `inputIterator.getValue` is an `UnsafeRow` whose underlying buffer will be - // overwritten when `inputIterator` steps forward, we need to do a deep copy here. - processRow(result.aggregationBuffer, inputIterator.getValue.copy()) + processRow(result.aggregationBuffer, inputIterator.getValue) hasNextInput = inputIterator.next() } @@ -277,12 +270,7 @@ class SortBasedAggregator( // be called after calling processRow. while (hasNextAggBuffer && groupingKeyOrdering.compare(initialAggBufferIterator.getKey, groupingKey) == 0) { - mergeAggregationBuffers( - result.aggregationBuffer, - // Since `inputIterator.getValue` is an `UnsafeRow` whose underlying buffer will be - // overwritten when `inputIterator` steps forward, we need to do a deep copy here. - initialAggBufferIterator.getValue.copy() - ) + mergeAggregationBuffers(result.aggregationBuffer, initialAggBufferIterator.getValue) hasNextAggBuffer = initialAggBufferIterator.next() }