From d6395d86f90d1c47c5b6ad17c618b56e00b7fc85 Mon Sep 17 00:00:00 2001
From: Takuya UESHIN <ueshin@happy-camper.st>
Date: Mon, 26 May 2014 00:17:20 -0700
Subject: [PATCH] [SPARK-1914] [SQL] Simplify CountFunction not to traverse to
 evaluate all child expressions.

`CountFunction` should count up only if the child's evaluated value is not null.

Because it traverses to evaluate all child expressions, even if the child is null, it counts up if one of the all children is not null.

Author: Takuya UESHIN <ueshin@happy-camper.st>

Closes #861 from ueshin/issues/SPARK-1914 and squashes the following commits:

3b37315 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-1914
2afa238 [Takuya UESHIN] Simplify CountFunction not to traverse to evaluate all child expressions.
---
 .../apache/spark/sql/catalyst/expressions/aggregates.scala   | 4 ++--
 .../src/test/scala/org/apache/spark/sql/DslQuerySuite.scala  | 5 +++++
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 1bcd4e2276..79937b129a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -298,8 +298,8 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
   var count: Long = _
 
   override def update(input: Row): Unit = {
-    val evaluatedExpr = expr.map(_.eval(input))
-    if (evaluatedExpr.map(_ != null).reduceLeft(_ || _)) {
+    val evaluatedExpr = expr.eval(input)
+    if (evaluatedExpr != null) {
       count += 1L
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 692569a73f..8197e8a18d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -125,6 +125,11 @@ class DslQuerySuite extends QueryTest {
       Seq((1,0), (2, 1))
     )
 
+    checkAnswer(
+      testData3.groupBy('a)('a, Count('a + 'b)),
+      Seq((1,0), (2, 1))
+    )
+
     checkAnswer(
       testData3.groupBy()(Count('a), Count('b), Count(1), CountDistinct('a :: Nil), CountDistinct('b :: Nil)),
       (2, 1, 2, 2, 1) :: Nil
-- 
GitLab