Skip to content
Snippets Groups Projects
Commit c3576ffc authored by Aaron Davidson's avatar Aaron Davidson Committed by Reynold Xin
Browse files

[SQL] Minor: Introduce SchemaRDD#aggregate() for simple aggregations

```scala
rdd.aggregate(Sum('val))
```
is just shorthand for

```scala
rdd.groupBy()(Sum('val))
```

but seems to be more natural than doing a groupBy with no grouping expressions when you really just want an aggregation over all rows.

Did not add a JavaSchemaRDD or Python API, as these seem to be lacking several other methods like groupBy() already -- leaving that cleanup for future patches.

Author: Aaron Davidson <aaron@databricks.com>

Closes #874 from aarondav/schemardd and squashes the following commits:

e9e68ee [Aaron Davidson] Add comment
db6afe2 [Aaron Davidson] Introduce SchemaRDD#aggregate() for simple aggregations
parent 06595296
No related branches found
No related tags found
No related merge requests found
...@@ -59,7 +59,7 @@ import java.util.{Map => JMap} ...@@ -59,7 +59,7 @@ import java.util.{Map => JMap}
* // Importing the SQL context gives access to all the SQL functions and implicit conversions. * // Importing the SQL context gives access to all the SQL functions and implicit conversions.
* import sqlContext._ * import sqlContext._
* *
* val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_\$i"))) * val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i")))
* // Any RDD containing case classes can be registered as a table. The schema of the table is * // Any RDD containing case classes can be registered as a table. The schema of the table is
* // automatically inferred using scala reflection. * // automatically inferred using scala reflection.
* rdd.registerAsTable("records") * rdd.registerAsTable("records")
...@@ -204,6 +204,20 @@ class SchemaRDD( ...@@ -204,6 +204,20 @@ class SchemaRDD(
new SchemaRDD(sqlContext, Aggregate(groupingExprs, aliasedExprs, logicalPlan)) new SchemaRDD(sqlContext, Aggregate(groupingExprs, aliasedExprs, logicalPlan))
} }
/**
 * Aggregates across every Row of this RDD without partitioning the rows into groups.
 * Shorthand for calling groupBy with an empty list of grouping expressions.
 *
 * {{{
 *   schemaRDD.aggregate(Sum('sales) as 'totalSales)
 * }}}
 *
 * @param aggregateExprs the aggregate expressions, forwarded unchanged to groupBy
 * @group Query
 */
def aggregate(aggregateExprs: Expression*): SchemaRDD =
  groupBy()(aggregateExprs: _*)
/** /**
* Applies a qualifier to the attributes of this relation. Can be used to disambiguate attributes * Applies a qualifier to the attributes of this relation. Can be used to disambiguate attributes
* with the same name, for example, when performing self-joins. * with the same name, for example, when performing self-joins.
...@@ -281,7 +295,7 @@ class SchemaRDD( ...@@ -281,7 +295,7 @@ class SchemaRDD(
* supports features such as filter pushdown. * supports features such as filter pushdown.
*/ */
@Experimental @Experimental
override def count(): Long = groupBy()(Count(Literal(1))).collect().head.getLong(0) override def count(): Long = aggregate(Count(Literal(1))).collect().head.getLong(0)
/** /**
* :: Experimental :: * :: Experimental ::
......
...@@ -39,6 +39,14 @@ class DslQuerySuite extends QueryTest { ...@@ -39,6 +39,14 @@ class DslQuerySuite extends QueryTest {
testData2.groupBy('a)('a, Sum('b)), testData2.groupBy('a)('a, Sum('b)),
Seq((1,3),(2,3),(3,3)) Seq((1,3),(2,3),(3,3))
) )
checkAnswer(
testData2.groupBy('a)('a, Sum('b) as 'totB).aggregate(Sum('totB)),
9
)
checkAnswer(
testData2.aggregate(Sum('b)),
9
)
} }
test("select *") { test("select *") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment