Commit f6f2eeb1 authored by Cheng Hao's avatar Cheng Hao Committed by Reynold Xin

[SPARK-7322][SQL] Window functions in DataFrame

This closes #6104.

Author: Cheng Hao <hao.cheng@intel.com>
Author: Reynold Xin <rxin@databricks.com>

Closes #6343 from rxin/window-df and squashes the following commits:

026d587 [Reynold Xin] Address code review feedback.
dc448fe [Reynold Xin] Fixed Hive tests.
9794d9d [Reynold Xin] Moved Java test package.
9331605 [Reynold Xin] Refactored API.
3313e2a [Reynold Xin] Merge pull request #6104 from chenghao-intel/df_window
d625a64 [Cheng Hao] Update the dataframe window API as suggsted
c141fb1 [Cheng Hao] hide all of properties of the WindowFunctionDefinition
3b1865f [Cheng Hao] scaladoc typos
f3fd2d0 [Cheng Hao] polish the unit test
6847825 [Cheng Hao] Add additional analystcs functions
57e3bc0 [Cheng Hao] typos
24a08ec [Cheng Hao] scaladoc
28222ed [Cheng Hao] fix bug of range/row Frame
1d91865 [Cheng Hao] style issue
53f89f2 [Cheng Hao] remove the over from the functions.scala
964c013 [Cheng Hao] add more unit tests and window functions
64e18a7 [Cheng Hao] Add Window Function support for DataFrame
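For orientation, a minimal sketch of the DataFrame-side usage this patch enables, mirroring the scaladoc example added to Column.scala below (the DataFrame `df` and its columns are hypothetical):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Build a window partitioned by "name" and ordered by "id", then reuse it
// with different frame specifications.
val w = Window.partitionBy("name").orderBy("id")
df.select(
  sum("price").over(w.rangeBetween(Long.MinValue, 2)),
  avg("price").over(w.rowsBetween(0, 4)))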
parent 2728c3df
Showing with 807 additions and 7 deletions
@@ -18,13 +18,13 @@
package org.apache.spark.sql
import scala.language.implicitConversions
-import scala.collection.JavaConversions._
import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
+import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.analysis.{MultiAlias, UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.types._
@@ -889,6 +889,22 @@ class Column(protected[sql] val expr: Expression) extends Logging {
*/
def bitwiseXOR(other: Any): Column = BitwiseXor(expr, lit(other).expr)
/**
* Define a windowing column.
*
* {{{
* val w = Window.partitionBy("name").orderBy("id")
* df.select(
* sum("price").over(w.rangeBetween(Long.MinValue, 2)),
* avg("price").over(w.rowsBetween(0, 4))
* )
* }}}
*
* @group expr_ops
* @since 1.4.0
*/
def over(window: expressions.WindowSpec): Column = window.withAggregate(this)
}
......
@@ -37,7 +37,7 @@
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
-import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, ScalaReflection, SqlParser}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
import org.apache.spark.sql.json.JacksonGenerator
import org.apache.spark.sql.sources.CreateTableUsingAsSelect
@@ -411,7 +411,7 @@ class DataFrame private[sql](
joined.left,
joined.right,
joinType = Inner,
-Some(expressions.EqualTo(
+Some(catalyst.expressions.EqualTo(
joined.left.resolve(usingColumn),
joined.right.resolve(usingColumn))))
)
@@ -480,8 +480,9 @@ class DataFrame private[sql](
// By the time we get here, since we have already run analysis, all attributes should've been
// resolved and become AttributeReference.
val cond = plan.condition.map { _.transform {
-case expressions.EqualTo(a: AttributeReference, b: AttributeReference) if a.sameRef(b) =>
-  expressions.EqualTo(plan.left.resolve(a.name), plan.right.resolve(b.name))
+case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference)
+    if a.sameRef(b) =>
+  catalyst.expressions.EqualTo(plan.left.resolve(a.name), plan.right.resolve(b.name))
}}
plan.copy(condition = cond)
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.expressions
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions._
/**
* :: Experimental ::
* Utility functions for defining windows in DataFrames.
*
* {{{
* // PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
* Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0)
*
* // PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING
* Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3)
* }}}
*
* @since 1.4.0
*/
@Experimental
object Window {
/**
* Creates a [[WindowSpec]] with the partitioning defined.
* @since 1.4.0
*/
@scala.annotation.varargs
def partitionBy(colName: String, colNames: String*): WindowSpec = {
spec.partitionBy(colName, colNames : _*)
}
/**
* Creates a [[WindowSpec]] with the partitioning defined.
* @since 1.4.0
*/
@scala.annotation.varargs
def partitionBy(cols: Column*): WindowSpec = {
spec.partitionBy(cols : _*)
}
/**
* Creates a [[WindowSpec]] with the ordering defined.
* @since 1.4.0
*/
@scala.annotation.varargs
def orderBy(colName: String, colNames: String*): WindowSpec = {
spec.orderBy(colName, colNames : _*)
}
/**
* Creates a [[WindowSpec]] with the ordering defined.
* @since 1.4.0
*/
@scala.annotation.varargs
def orderBy(cols: Column*): WindowSpec = {
spec.orderBy(cols : _*)
}
private def spec: WindowSpec = {
new WindowSpec(Seq.empty, Seq.empty, UnspecifiedFrame)
}
}
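A small sketch of how this builder is meant to be combined with a frame specification, reusing the country/date example from the scaladoc above (assumes `org.apache.spark.sql.functions._` is imported; the DataFrame `df` and its `sales` column are hypothetical):

// partitionBy/orderBy return a WindowSpec, so a spec can be built once and reused.
val byCountry = Window.partitionBy("country").orderBy("date")
df.select(
  avg("sales").over(byCountry.rowsBetween(-3, 3)),            // ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING
  sum("sales").over(byCountry.rowsBetween(Long.MinValue, 0))) // ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW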
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.expressions
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.{Column, catalyst}
import org.apache.spark.sql.catalyst.expressions._
/**
* :: Experimental ::
* A window specification that defines the partitioning, ordering, and frame boundaries.
*
* Use the static methods in [[Window]] to create a [[WindowSpec]].
*
* @since 1.4.0
*/
@Experimental
class WindowSpec private[sql](
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
frame: catalyst.expressions.WindowFrame) {
/**
* Defines the partitioning columns in a [[WindowSpec]].
* @since 1.4.0
*/
@scala.annotation.varargs
def partitionBy(colName: String, colNames: String*): WindowSpec = {
partitionBy((colName +: colNames).map(Column(_)): _*)
}
/**
* Defines the partitioning columns in a [[WindowSpec]].
* @since 1.4.0
*/
@scala.annotation.varargs
def partitionBy(cols: Column*): WindowSpec = {
new WindowSpec(cols.map(_.expr), orderSpec, frame)
}
/**
* Defines the ordering columns in a [[WindowSpec]].
* @since 1.4.0
*/
@scala.annotation.varargs
def orderBy(colName: String, colNames: String*): WindowSpec = {
orderBy((colName +: colNames).map(Column(_)): _*)
}
/**
* Defines the ordering columns in a [[WindowSpec]].
* @since 1.4.0
*/
@scala.annotation.varargs
def orderBy(cols: Column*): WindowSpec = {
val sortOrder: Seq[SortOrder] = cols.map { col =>
col.expr match {
case expr: SortOrder =>
expr
case expr: Expression =>
SortOrder(expr, Ascending)
}
}
new WindowSpec(partitionSpec, sortOrder, frame)
}
/**
* Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive).
*
* Both `start` and `end` are relative positions from the current row. For example, "0" means
* "current row", while "-1" means the row before the current row, and "5" means the fifth row
* after the current row.
*
* @param start boundary start, inclusive.
* The frame is unbounded if this is the minimum long value.
* @param end boundary end, inclusive.
* The frame is unbounded if this is the maximum long value.
* @since 1.4.0
*/
def rowsBetween(start: Long, end: Long): WindowSpec = {
between(RowFrame, start, end)
}
/**
* Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive).
*
* Both `start` and `end` are relative to the current row. For example, "0" means "current row",
* while "-1" means one before the current row, and "5" means five after the
* current row.
*
* @param start boundary start, inclusive.
* The frame is unbounded if this is the minimum long value.
* @param end boundary end, inclusive.
* The frame is unbounded if this is the maximum long value.
* @since 1.4.0
*/
def rangeBetween(start: Long, end: Long): WindowSpec = {
between(RangeFrame, start, end)
}
private def between(typ: FrameType, start: Long, end: Long): WindowSpec = {
val boundaryStart = start match {
case 0 => CurrentRow
case Long.MinValue => UnboundedPreceding
case x if x < 0 => ValuePreceding(-start.toInt)
case x if x > 0 => ValueFollowing(start.toInt)
}
val boundaryEnd = end match {
case 0 => CurrentRow
case Long.MaxValue => UnboundedFollowing
case x if x < 0 => ValuePreceding(-end.toInt)
case x if x > 0 => ValueFollowing(end.toInt)
}
new WindowSpec(
partitionSpec,
orderSpec,
SpecifiedWindowFrame(typ, boundaryStart, boundaryEnd))
}
/**
* Converts this [[WindowSpec]] into a [[Column]] with an aggregate expression.
*/
private[sql] def withAggregate(aggregate: Column): Column = {
val windowExpr = aggregate.expr match {
case Average(child) => WindowExpression(
UnresolvedWindowFunction("avg", child :: Nil),
WindowSpecDefinition(partitionSpec, orderSpec, frame))
case Sum(child) => WindowExpression(
UnresolvedWindowFunction("sum", child :: Nil),
WindowSpecDefinition(partitionSpec, orderSpec, frame))
case Count(child) => WindowExpression(
UnresolvedWindowFunction("count", child :: Nil),
WindowSpecDefinition(partitionSpec, orderSpec, frame))
case First(child) => WindowExpression(
// TODO this is a hack for Hive UDAF first_value
UnresolvedWindowFunction("first_value", child :: Nil),
WindowSpecDefinition(partitionSpec, orderSpec, frame))
case Last(child) => WindowExpression(
// TODO this is a hack for Hive UDAF last_value
UnresolvedWindowFunction("last_value", child :: Nil),
WindowSpecDefinition(partitionSpec, orderSpec, frame))
case Min(child) => WindowExpression(
UnresolvedWindowFunction("min", child :: Nil),
WindowSpecDefinition(partitionSpec, orderSpec, frame))
case Max(child) => WindowExpression(
UnresolvedWindowFunction("max", child :: Nil),
WindowSpecDefinition(partitionSpec, orderSpec, frame))
case wf: WindowFunction => WindowExpression(
wf,
WindowSpecDefinition(partitionSpec, orderSpec, frame))
case x =>
throw new UnsupportedOperationException(s"$x is not supported in window operation.")
}
new Column(windowExpr)
}
}
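As a quick reference for the boundary mapping performed by `between` above, the long offsets accepted by `rowsBetween`/`rangeBetween` translate to frame boundaries as follows (column names are hypothetical; SQL equivalents shown as comments):

val w = Window.partitionBy("k").orderBy("v")
w.rowsBetween(Long.MinValue, 0)   // ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
w.rowsBetween(-1, 1)              // ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
w.rangeBetween(0, Long.MaxValue)  // RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING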
@@ -37,6 +37,7 @@ import org.apache.spark.util.Utils
* @groupname sort_funcs Sorting functions
* @groupname normal_funcs Non-aggregate functions
* @groupname math_funcs Math functions
+* @groupname window_funcs Window functions
* @groupname Ungrouped Support functions for DataFrames.
* @since 1.3.0
*/
@@ -320,6 +321,233 @@ object functions {
*/
def max(columnName: String): Column = max(Column(columnName))
//////////////////////////////////////////////////////////////////////////////////////////////
// Window functions
//////////////////////////////////////////////////////////////////////////////////////////////
/**
* Window function: returns the value of the column one row before the current row,
* or null when that row falls before the beginning of the window.
*
* @group window_funcs
* @since 1.4.0
*/
def lag(columnName: String): Column = {
lag(columnName, 1)
}
/**
* Window function: returns the value of the expression one row before the current row,
* or null when that row falls before the beginning of the window.
*
* @group window_funcs
* @since 1.4.0
*/
def lag(e: Column): Column = {
lag(e, 1)
}
/**
* Window function: returns the value of the expression `count` rows before the current row,
* or null when that row falls before the beginning of the window.
*
* @group window_funcs
* @since 1.4.0
*/
def lag(e: Column, count: Int): Column = {
lag(e, count, null)
}
/**
* Window function: returns the value of the column `count` rows before the current row,
* or null when that row falls before the beginning of the window.
*
* @group window_funcs
* @since 1.4.0
*/
def lag(columnName: String, count: Int): Column = {
lag(columnName, count, null)
}
/**
* Window function: returns the value of the column `count` rows before the current row,
* or `defaultValue` when that row falls before the beginning
* of the window.
*
* @group window_funcs
* @since 1.4.0
*/
def lag(columnName: String, count: Int, defaultValue: Any): Column = {
lag(Column(columnName), count, defaultValue)
}
/**
* Window function: returns the value of the expression `count` rows before the current row,
* or `defaultValue` when that row falls before the beginning
* of the window.
*
* @group window_funcs
* @since 1.4.0
*/
def lag(e: Column, count: Int, defaultValue: Any): Column = {
UnresolvedWindowFunction("lag", e.expr :: Literal(count) :: Literal(defaultValue) :: Nil)
}
/**
* Window function: returns the value of the column one row after the current row,
* or null when that row falls beyond the end of the window.
*
* @group window_funcs
* @since 1.4.0
*/
def lead(columnName: String): Column = {
lead(columnName, 1)
}
/**
* Window function: returns the value of the expression one row after the current row,
* or null when that row falls beyond the end of the window.
*
* @group window_funcs
* @since 1.4.0
*/
def lead(e: Column): Column = {
lead(e, 1)
}
/**
* Window function: returns the value of the column `count` rows after the current row,
* or null when that row falls beyond the end of the window.
*
* @group window_funcs
* @since 1.4.0
*/
def lead(columnName: String, count: Int): Column = {
lead(columnName, count, null)
}
/**
* Window function: returns the value of the expression `count` rows after the current row,
* or null when that row falls beyond the end of the window.
*
* @group window_funcs
* @since 1.4.0
*/
def lead(e: Column, count: Int): Column = {
lead(e, count, null)
}
/**
* Window function: returns the value of the column `count` rows after the current row,
* or `defaultValue` when that row falls beyond the end of the window.
*
* @group window_funcs
* @since 1.4.0
*/
def lead(columnName: String, count: Int, defaultValue: Any): Column = {
lead(Column(columnName), count, defaultValue)
}
/**
* Window function: returns the value of the expression `count` rows after the current row,
* or `defaultValue` when that row falls beyond the end of the window.
*
* @group window_funcs
* @since 1.4.0
*/
def lead(e: Column, count: Int, defaultValue: Any): Column = {
UnresolvedWindowFunction("lead", e.expr :: Literal(count) :: Literal(defaultValue) :: Nil)
}
/**
* NTILE for specified expression.
* NTILE allows easy calculation of tertiles, quartiles, deciles and other
* common summary statistics. This function divides an ordered partition into a specified
* number of groups called buckets and assigns a bucket number to each row in the partition.
*
* @group window_funcs
* @since 1.4.0
*/
def ntile(e: Column): Column = {
UnresolvedWindowFunction("ntile", e.expr :: Nil)
}
/**
* NTILE for specified column.
* NTILE allows easy calculation of tertiles, quartiles, deciles and other
* common summary statistics. This function divides an ordered partition into a specified
* number of groups called buckets and assigns a bucket number to each row in the partition.
*
* @group window_funcs
* @since 1.4.0
*/
def ntile(columnName: String): Column = {
ntile(Column(columnName))
}
/**
* Assigns a unique number (sequentially, starting from 1, as defined by ORDER BY) to each
* row within the partition.
*
* @group window_funcs
* @since 1.4.0
*/
def rowNumber(): Column = {
UnresolvedWindowFunction("row_number", Nil)
}
/**
* The difference between RANK and DENSE_RANK is that DENSE_RANK leaves no gaps in ranking
* sequence when there are ties. That is, if you were ranking a competition using DENSE_RANK
* and had three people tie for second place, you would say that all three were in second
* place and that the next person came in third.
*
* @group window_funcs
* @since 1.4.0
*/
def denseRank(): Column = {
UnresolvedWindowFunction("dense_rank", Nil)
}
/**
* The difference between RANK and DENSE_RANK is that DENSE_RANK leaves no gaps in ranking
* sequence when there are ties. That is, if you were ranking a competition using DENSE_RANK
* and had three people tie for second place, you would say that all three were in second
* place and that the next person came in third.
*
* @group window_funcs
* @since 1.4.0
*/
def rank(): Column = {
UnresolvedWindowFunction("rank", Nil)
}
/**
* CUME_DIST (defined as the inverse of percentile in some statistical books) computes
* the position of a specified value relative to a set of values.
* To compute the CUME_DIST of a value x in a set S of size N, you use the formula:
* CUME_DIST(x) = number of values in S coming before and including x in the specified order / N
*
* @group window_funcs
* @since 1.4.0
*/
def cumeDist(): Column = {
UnresolvedWindowFunction("cume_dist", Nil)
}
/**
* PERCENT_RANK is similar to CUME_DIST, but it uses rank values rather than row counts
* in its numerator.
* The formula:
* (rank of row in its partition - 1) / (number of rows in the partition - 1)
*
* @group window_funcs
* @since 1.4.0
*/
def percentRank(): Column = {
UnresolvedWindowFunction("percent_rank", Nil)
}
//////////////////////////////////////////////////////////////////////////////////////////////
// Non-aggregate functions
//////////////////////////////////////////////////////////////////////////////////////////////
......
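A brief sketch combining the new function entry points above with a window specification (hypothetical DataFrame `df` with columns `key` and `value`; assumes `Window` and `org.apache.spark.sql.functions._` are imported):

val w = Window.partitionBy("value").orderBy("key")
df.select(
  lag("key", 1).over(w),      // "key" from the previous row, null at the start of each partition
  lead("key", 1, 0).over(w),  // "key" from the next row, default 0 past the end of each partition
  rowNumber().over(w),
  denseRank().over(w),
  ntile("key").over(w))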
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package test.org.apache.spark.sql.hive;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.hive.HiveContext;
import org.apache.spark.sql.hive.test.TestHive$;
public class JavaDataFrameSuite {
private transient JavaSparkContext sc;
private transient HiveContext hc;
DataFrame df;
private void checkAnswer(DataFrame actual, List<Row> expected) {
String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
if (errorMessage != null) {
Assert.fail(errorMessage);
}
}
@Before
public void setUp() throws IOException {
hc = TestHive$.MODULE$;
sc = new JavaSparkContext(hc.sparkContext());
List<String> jsonObjects = new ArrayList<String>(10);
for (int i = 0; i < 10; i++) {
jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}");
}
df = hc.jsonRDD(sc.parallelize(jsonObjects));
df.registerTempTable("window_table");
}
@After
public void tearDown() throws IOException {
// Clean up tables.
hc.sql("DROP TABLE IF EXISTS window_table");
}
@Test
public void saveTableAndQueryIt() {
checkAnswer(
df.select(functions.avg("key").over(
Window.partitionBy("value").orderBy("key").rowsBetween(-1, 1))),
hc.sql("SELECT avg(key) " +
"OVER (PARTITION BY value " +
" ORDER BY key " +
" ROWS BETWEEN 1 preceding and 1 following) " +
"FROM window_table").collectAsList());
}
}
@@ -14,7 +14,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql.hive;
+package test.org.apache.spark.sql.hive;
import java.io.File;
import java.io.IOException;
@@ -36,6 +37,7 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.QueryTest$;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.hive.HiveContext;
import org.apache.spark.sql.hive.test.TestHive$;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hive
import org.apache.spark.sql.{Row, QueryTest}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.hive.test.TestHive.implicits._
class HiveDataFrameWindowSuite extends QueryTest {
test("reuse window partitionBy") {
val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
val w = Window.partitionBy("key").orderBy("value")
checkAnswer(
df.select(
lead("key").over(w),
lead("value").over(w)),
Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil)
}
test("reuse window orderBy") {
val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
val w = Window.orderBy("value").partitionBy("key")
checkAnswer(
df.select(
lead("key").over(w),
lead("value").over(w)),
Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil)
}
test("lead") {
val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
df.registerTempTable("window_table")
checkAnswer(
df.select(
lead("value").over(Window.partitionBy($"key").orderBy($"value"))),
sql(
"""SELECT
| lead(value) OVER (PARTITION BY key ORDER BY value)
| FROM window_table""".stripMargin).collect())
}
test("lag") {
val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
df.registerTempTable("window_table")
checkAnswer(
df.select(
lag("value").over(
Window.partitionBy($"key")
.orderBy($"value"))),
sql(
"""SELECT
| lag(value) OVER (PARTITION BY key ORDER BY value)
| FROM window_table""".stripMargin).collect())
}
test("lead with default value") {
val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"),
(2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
df.registerTempTable("window_table")
checkAnswer(
df.select(
lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))),
sql(
"""SELECT
| lead(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value)
| FROM window_table""".stripMargin).collect())
}
test("lag with default value") {
val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"),
(2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
df.registerTempTable("window_table")
checkAnswer(
df.select(
lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))),
sql(
"""SELECT
| lag(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value)
| FROM window_table""".stripMargin).collect())
}
test("rank functions in unspecific window") {
val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value")
df.registerTempTable("window_table")
checkAnswer(
df.select(
$"key",
max("key").over(Window.partitionBy("value").orderBy("key")),
min("key").over(Window.partitionBy("value").orderBy("key")),
mean("key").over(Window.partitionBy("value").orderBy("key")),
count("key").over(Window.partitionBy("value").orderBy("key")),
sum("key").over(Window.partitionBy("value").orderBy("key")),
ntile("key").over(Window.partitionBy("value").orderBy("key")),
ntile($"key").over(Window.partitionBy("value").orderBy("key")),
rowNumber().over(Window.partitionBy("value").orderBy("key")),
denseRank().over(Window.partitionBy("value").orderBy("key")),
rank().over(Window.partitionBy("value").orderBy("key")),
cumeDist().over(Window.partitionBy("value").orderBy("key")),
percentRank().over(Window.partitionBy("value").orderBy("key"))),
sql(
s"""SELECT
|key,
|max(key) over (partition by value order by key),
|min(key) over (partition by value order by key),
|avg(key) over (partition by value order by key),
|count(key) over (partition by value order by key),
|sum(key) over (partition by value order by key),
|ntile(key) over (partition by value order by key),
|ntile(key) over (partition by value order by key),
|row_number() over (partition by value order by key),
|dense_rank() over (partition by value order by key),
|rank() over (partition by value order by key),
|cume_dist() over (partition by value order by key),
|percent_rank() over (partition by value order by key)
|FROM window_table""".stripMargin).collect())
}
test("aggregation and rows between") {
val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
df.registerTempTable("window_table")
checkAnswer(
df.select(
avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))),
sql(
"""SELECT
| avg(key) OVER
| (PARTITION BY value ORDER BY key ROWS BETWEEN 1 preceding and 2 following)
| FROM window_table""".stripMargin).collect())
}
test("aggregation and range betweens") {
val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
df.registerTempTable("window_table")
checkAnswer(
df.select(
avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))),
sql(
"""SELECT
| avg(key) OVER
| (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and 1 following)
| FROM window_table""".stripMargin).collect())
}
test("aggregation and rows betweens with unbounded") {
val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value")
df.registerTempTable("window_table")
checkAnswer(
df.select(
$"key",
last("value").over(
Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)),
last("value").over(
Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)),
last("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 3))),
sql(
"""SELECT
| key,
| last_value(value) OVER
| (PARTITION BY value ORDER BY key ROWS between current row and unbounded following),
| last_value(value) OVER
| (PARTITION BY value ORDER BY key ROWS between unbounded preceding and current row),
| last_value(value) OVER
| (PARTITION BY value ORDER BY key ROWS between 1 preceding and 3 following)
| FROM window_table""".stripMargin).collect())
}
test("aggregation and range betweens with unbounded") {
val df = Seq((1, "1"), (2, "2"), (2, "2"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
df.registerTempTable("window_table")
checkAnswer(
df.select(
$"key",
last("value").over(
Window.partitionBy($"value").orderBy($"key").rangeBetween(1, Long.MaxValue))
.equalTo("2")
.as("last_v"),
avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1))
.as("avg_key1"),
avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, Long.MaxValue))
.as("avg_key2"),
avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0))
.as("avg_key3")
),
sql(
"""SELECT
| key,
| last_value(value) OVER
| (PARTITION BY value ORDER BY key RANGE 1 preceding) == "2",
| avg(key) OVER
| (PARTITION BY value ORDER BY key RANGE BETWEEN unbounded preceding and 1 following),
| avg(key) OVER
| (PARTITION BY value ORDER BY key RANGE BETWEEN current row and unbounded following),
| avg(key) OVER
| (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and current row)
| FROM window_table""".stripMargin).collect())
}
}