diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 1bcd4e22766a9cfd9dfc20c5de8e56075295213d..79937b129aeae8018faafb9c4c00ad3ee7e16555 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -298,8 +298,8 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
   var count: Long = _
 
   override def update(input: Row): Unit = {
-    val evaluatedExpr = expr.map(_.eval(input))
-    if (evaluatedExpr.map(_ != null).reduceLeft(_ || _)) {
+    val evaluatedExpr = expr.eval(input)
+    if (evaluatedExpr != null) {
       count += 1L
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 692569a73ffcfd264a70e48f349bc84b9da92e76..8197e8a18d447c4b1959cd0b0ce827898254403e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -125,6 +125,11 @@ class DslQuerySuite extends QueryTest {
       Seq((1,0), (2, 1))
     )
 
+    checkAnswer(
+      testData3.groupBy('a)('a, Count('a + 'b)),
+      Seq((1,0), (2, 1))
+    )
+
     checkAnswer(
       testData3.groupBy()(Count('a), Count('b), Count(1), CountDistinct('a :: Nil), CountDistinct('b :: Nil)),
       (2, 1, 2, 2, 1) :: Nil