From c3576ffcd7910e38928f233a824dd9e037cde05f Mon Sep 17 00:00:00 2001
From: Aaron Davidson <aaron@databricks.com>
Date: Sun, 25 May 2014 18:37:44 -0700
Subject: [PATCH] [SQL] Minor: Introduce SchemaRDD#aggregate() for simple
 aggregations

```scala
rdd.aggregate(Sum('val))
```
is just shorthand for

```scala
rdd.groupBy()(Sum('val))
```

but seems more natural than doing a groupBy with no grouping expressions when you really just want an aggregation over all rows.
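
For context, a minimal, hypothetical sketch of how the new method reads in practice (shell-style; it assumes a `SparkContext` named `sc`, the implicit conversions brought in by `import sqlContext._`, and an illustrative `Sale` case class that is not part of this patch):

```scala
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.Sum

// Hypothetical schema, for illustration only.
case class Sale(region: String, amount: Int)

val sqlContext = new SQLContext(sc)
import sqlContext._  // implicit RDD => SchemaRDD conversion and the Symbol DSL

val sales = sc.parallelize(Seq(Sale("east", 10), Sale("west", 20), Sale("east", 5)))

// Aggregate over all rows -- no grouping expressions required.
val total = sales.aggregate(Sum('amount) as 'totalSales)

// Equivalent to the pre-existing groupBy form with an empty grouping list.
val totalViaGroupBy = sales.groupBy()(Sum('amount) as 'totalSales)

total.collect()  // a single Row containing the sum of 'amount over all rows
```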

Did not add a JavaSchemaRDD or Python API, as these already seem to be lacking several other methods such as groupBy() -- leaving that cleanup for future patches.

Author: Aaron Davidson <aaron@databricks.com>

Closes #874 from aarondav/schemardd and squashes the following commits:

e9e68ee [Aaron Davidson] Add comment
db6afe2 [Aaron Davidson] Introduce SchemaRDD#aggregate() for simple aggregations
---
 .../scala/org/apache/spark/sql/SchemaRDD.scala | 18 ++++++++++++++++--
 .../org/apache/spark/sql/DslQuerySuite.scala   |  8 ++++++++
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 9883ebc0b3..e855f36256 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -59,7 +59,7 @@ import java.util.{Map => JMap}
  *  // Importing the SQL context gives access to all the SQL functions and implicit conversions.
  *  import sqlContext._
  *
- *  val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_\$i")))
+ *  val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i")))
  *  // Any RDD containing case classes can be registered as a table.  The schema of the table is
  *  // automatically inferred using scala reflection.
  *  rdd.registerAsTable("records")
@@ -204,6 +204,20 @@ class SchemaRDD(
     new SchemaRDD(sqlContext, Aggregate(groupingExprs, aliasedExprs, logicalPlan))
   }
 
+  /**
+   * Performs an aggregation over all Rows in this RDD.
+   * This is equivalent to a groupBy with no grouping expressions.
+   *
+   * {{{
+   *   schemaRDD.aggregate(Sum('sales) as 'totalSales)
+   * }}}
+   *
+   * @group Query
+   */
+  def aggregate(aggregateExprs: Expression*): SchemaRDD = {
+    groupBy()(aggregateExprs: _*)
+  }
+
   /**
    * Applies a qualifier to the attributes of this relation.  Can be used to disambiguate attributes
    * with the same name, for example, when performing self-joins.
@@ -281,7 +295,7 @@ class SchemaRDD(
    * supports features such as filter pushdown.
    */
   @Experimental
-  override def count(): Long = groupBy()(Count(Literal(1))).collect().head.getLong(0)
+  override def count(): Long = aggregate(Count(Literal(1))).collect().head.getLong(0)
 
   /**
    * :: Experimental ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 94ba13b14b..692569a73f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -39,6 +39,14 @@ class DslQuerySuite extends QueryTest {
       testData2.groupBy('a)('a, Sum('b)),
       Seq((1,3),(2,3),(3,3))
     )
+    checkAnswer(
+      testData2.groupBy('a)('a, Sum('b) as 'totB).aggregate(Sum('totB)),
+      9
+    )
+    checkAnswer(
+      testData2.aggregate(Sum('b)),
+      9
+    )
   }
 
   test("select *") {
-- 
GitLab