From 86664338f25f58b2f59db93b68cd57de671a4c0b Mon Sep 17 00:00:00 2001
From: Wenchen Fan <wenchen@databricks.com>
Date: Mon, 24 Jul 2017 10:18:28 -0700
Subject: [PATCH] [SPARK-17528][SQL][FOLLOWUP] remove unnecessary data copy in
 object hash aggregate

## What changes were proposed in this pull request?

In #18483, we fixed the data copy bug when saving into `InternalRow` and removed all workarounds for this bug in the aggregate code path. However, the object hash aggregate was missed; this PR fixes it.
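
For context, the hazard those workarounds guarded against is buffer reuse: an `UnsafeRow`-backed iterator may return the same mutable object from every `next()` call, overwriting its underlying buffer as it steps forward, so a reference held across iterations silently ends up pointing at later rows' data. Below is a minimal standalone sketch of that pattern (the `MutableRow` class is hypothetical, not Spark's API; the deep copy plays the role of the removed `safeProjection`/`copy()` calls):

```scala
// Hypothetical illustration of the buffer-reuse hazard; not Spark internals.
final class MutableRow(var value: Int)

object BufferReuseDemo extends App {
  // Like an UnsafeRow-backed iterator, this reuses a single object and
  // overwrites its contents in place on every step.
  val shared = new MutableRow(0)
  val rows: Iterator[MutableRow] = Iterator(1, 2, 3).map { v =>
    shared.value = v
    shared
  }

  val references = scala.collection.mutable.Buffer.empty[MutableRow]
  val deepCopies = scala.collection.mutable.Buffer.empty[MutableRow]
  rows.foreach { r =>
    references += r                       // false sharing: every entry aliases `shared`
    deepCopies += new MutableRow(r.value) // deep copy, as the removed workarounds did
  }

  println(references.map(_.value)) // Buffer(3, 3, 3) -- corrupted by reuse
  println(deepCopies.map(_.value)) // Buffer(1, 2, 3) -- correct
}
```

Since #18483, saving a row into the aggregation buffer already copies unsafe data, so these per-row defensive copies became redundant and can be removed from the object hash aggregate as well.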

This patch is also a prerequisite for #17419, which shows that the DataFrame version is slower than the RDD version because of this issue.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #18712 from cloud-fan/minor.
---
 .../aggregate/ObjectAggregationIterator.scala | 20 ++++---------------
 1 file changed, 4 insertions(+), 16 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
index 6e47f9d611..eef2c4e843 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
@@ -70,10 +70,6 @@ class ObjectAggregationIterator(
     generateProcessRow(newExpressions, newFunctions, newInputAttributes)
   }
 
-  // A safe projection used to do deep clone of input rows to prevent false sharing.
-  private[this] val safeProjection: Projection =
-    FromUnsafeProjection(outputAttributes.map(_.dataType))
-
   /**
    * Start processing input rows.
    */
@@ -151,12 +147,11 @@ class ObjectAggregationIterator(
       val groupingKey = groupingProjection.apply(null)
       val buffer: InternalRow = getAggregationBufferByKey(hashMap, groupingKey)
       while (inputRows.hasNext) {
-        val newInput = safeProjection(inputRows.next())
-        processRow(buffer, newInput)
+        processRow(buffer, inputRows.next())
       }
     } else {
       while (inputRows.hasNext && !sortBased) {
-        val newInput = safeProjection(inputRows.next())
+        val newInput = inputRows.next()
         val groupingKey = groupingProjection.apply(newInput)
         val buffer: InternalRow = getAggregationBufferByKey(hashMap, groupingKey)
         processRow(buffer, newInput)
@@ -266,9 +261,7 @@ class SortBasedAggregator(
           // Firstly, update the aggregation buffer with input rows.
           while (hasNextInput &&
             groupingKeyOrdering.compare(inputIterator.getKey, groupingKey) == 0) {
-            // Since `inputIterator.getValue` is an `UnsafeRow` whose underlying buffer will be
-            // overwritten when `inputIterator` steps forward, we need to do a deep copy here.
-            processRow(result.aggregationBuffer, inputIterator.getValue.copy())
+            processRow(result.aggregationBuffer, inputIterator.getValue)
             hasNextInput = inputIterator.next()
           }
 
@@ -277,12 +270,7 @@ class SortBasedAggregator(
           // be called after calling processRow.
           while (hasNextAggBuffer &&
             groupingKeyOrdering.compare(initialAggBufferIterator.getKey, groupingKey) == 0) {
-            mergeAggregationBuffers(
-              result.aggregationBuffer,
-              // Since `inputIterator.getValue` is an `UnsafeRow` whose underlying buffer will be
-              // overwritten when `inputIterator` steps forward, we need to do a deep copy here.
-              initialAggBufferIterator.getValue.copy()
-            )
+            mergeAggregationBuffers(result.aggregationBuffer, initialAggBufferIterator.getValue)
             hasNextAggBuffer = initialAggBufferIterator.next()
           }
 
-- 
GitLab