diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ac05dd3d0ef9e7c2cef2d4dd4eeb85b82701de52..c459fe587859e027f492c56f0116bebd8dbc57fb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -252,6 +252,8 @@ object FunctionRegistry {
     expression[VarianceSamp]("variance"),
     expression[VariancePop]("var_pop"),
     expression[VarianceSamp]("var_samp"),
+    expression[CollectList]("collect_list"),
+    expression[CollectSet]("collect_set"),
 
     // string functions
     expression[Ascii]("ascii"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
new file mode 100644
index 0000000000000000000000000000000000000000..1f4ff9c4b184ead52f447db53087d6841a4236bd
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import scala.collection.generic.Growable
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.types._
+
+/**
+ * The Collect aggregate function collects all seen expression values into a list of values.
+ *
+ * The operator is bound to the slower sort-based aggregation path because the number of
+ * elements (and their memory usage) cannot be determined in advance. This also means that the
+ * collected elements are stored on the heap, and that too many elements can cause GC pauses and
+ * eventually OutOfMemoryErrors.
+ */
+abstract class Collect extends ImperativeAggregate {
+
+  val child: Expression
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  override def dataType: DataType = ArrayType(child.dataType)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+  override def supportsPartial: Boolean = false
+
+  override def aggBufferAttributes: Seq[AttributeReference] = Nil
+
+  override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+
+  override def inputAggBufferAttributes: Seq[AttributeReference] = Nil
+
+  protected[this] val buffer: Growable[Any] with Iterable[Any]
+
+  override def initialize(b: MutableRow): Unit = {
+    buffer.clear()
+  }
+
+  override def update(b: MutableRow, input: InternalRow): Unit = {
+    buffer += child.eval(input)
+  }
+
+  override def merge(buffer: MutableRow, input: InternalRow): Unit = {
+    sys.error("Collect cannot be used in partial aggregations.")
+  }
+
+  override def eval(input: InternalRow): Any = {
+    new GenericArrayData(buffer.toArray)
+  }
+}
+
+/**
+ * Collect a list of elements.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - Collects and returns a list of non-unique elements.")
+case class CollectList(
+    child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0) extends Collect {
+
+  def this(child: Expression) = this(child, 0, 0)
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def prettyName: String = "collect_list"
+
+  override protected[this] val buffer: mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
+}
+
+/**
+ * Collect a set of unique elements.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - Collects and returns a set of unique elements.")
+case class CollectSet(
+    child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0) extends Collect {
+
+  def this(child: Expression) = this(child, 0, 0)
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def prettyName: String = "collect_set"
+
+  override protected[this] val buffer: mutable.HashSet[Any] = mutable.HashSet.empty
+}
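Note: the buffer contract above is the core of the design: one heap-resident Scala collection per group, cleared on initialize, appended to on update, and snapshotted on eval, with the concrete collection type (ArrayBuffer vs. HashSet) deciding list vs. set semantics. A minimal standalone sketch of that lifecycle, using illustrative names only (this is not Spark API) and the same Scala 2.11-era Growable imported above:

import scala.collection.generic.Growable
import scala.collection.mutable

object CollectSketch {
  // Runs the initialize/update/eval cycle once for a single group.
  def collect(values: Seq[Any], buffer: Growable[Any] with Iterable[Any]): Array[Any] = {
    buffer.clear()                // initialize(): reset state for a new group
    values.foreach(buffer += _)   // update(): one append per input row
    buffer.toArray                // eval(): snapshot the collected values
  }

  def main(args: Array[String]): Unit = {
    val input = Seq(1, 2, 2, 3)
    println(collect(input, mutable.ArrayBuffer.empty[Any]).mkString(", "))  // 1, 2, 2, 3
    println(collect(input, mutable.HashSet.empty[Any]).length)              // 3, duplicates dropped
  }
}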
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 3e295c20b6d9fd413df14256db6d28876f5286b2..07f55042eeb404f94faff526a6b3c85d3676e948 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -195,18 +195,14 @@ object functions {
   /**
    * Aggregate function: returns a list of objects with duplicates.
    *
-   * For now this is an alias for the collect_list Hive UDAF.
-   *
    * @group agg_funcs
    * @since 1.6.0
    */
-  def collect_list(e: Column): Column = callUDF("collect_list", e)
+  def collect_list(e: Column): Column = withAggregateFunction { CollectList(e.expr) }
 
   /**
    * Aggregate function: returns a list of objects with duplicates.
    *
-   * For now this is an alias for the collect_list Hive UDAF.
-   *
    * @group agg_funcs
    * @since 1.6.0
    */
@@ -215,18 +211,14 @@
   /**
    * Aggregate function: returns a set of objects with duplicate elements eliminated.
    *
-   * For now this is an alias for the collect_set Hive UDAF.
-   *
    * @group agg_funcs
    * @since 1.6.0
    */
-  def collect_set(e: Column): Column = callUDF("collect_set", e)
+  def collect_set(e: Column): Column = withAggregateFunction { CollectSet(e.expr) }
 
   /**
    * Aggregate function: returns a set of objects with duplicate elements eliminated.
    *
-   * For now this is an alias for the collect_set Hive UDAF.
-   *
    * @group agg_funcs
    * @since 1.6.0
    */
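For reference, a hypothetical REPL-style use of the rewritten entry points; it assumes a SparkSession named `spark` with its implicits imported, and expected results are shown as comments:

import org.apache.spark.sql.functions.{collect_list, collect_set}
import spark.implicits._

val df = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b")
df.select(collect_list($"b")).first()  // Row(WrappedArray(2, 2, 4))
df.select(collect_set($"b")).first()   // Row(WrappedArray(2, 4)); element order is not guaranteed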
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 8a99866a33c75a9d72346936dcc66687605858a2..69a990789bcfd16e670f921b0f10f393b50afabf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -431,6 +431,32 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       Row(null, null, null, null, null))
   }
 
+  test("collect functions") {
+    val df = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b")
+    checkAnswer(
+      df.select(collect_list($"a"), collect_list($"b")),
+      Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4)))
+    )
+    checkAnswer(
+      df.select(collect_set($"a"), collect_set($"b")),
+      Seq(Row(Seq(1, 2, 3), Seq(2, 4)))
+    )
+  }
+
+  test("collect functions structs") {
+    val df = Seq((1, 2, 2), (2, 2, 2), (3, 4, 1))
+      .toDF("a", "x", "y")
+      .select($"a", struct($"x", $"y").as("b"))
+    checkAnswer(
+      df.select(collect_list($"a"), sort_array(collect_list($"b"))),
+      Seq(Row(Seq(1, 2, 3), Seq(Row(2, 2), Row(2, 2), Row(4, 1))))
+    )
+    checkAnswer(
+      df.select(collect_set($"a"), sort_array(collect_set($"b"))),
+      Seq(Row(Seq(1, 2, 3), Seq(Row(2, 2), Row(4, 1))))
+    )
+  }
+
   test("SPARK-14664: Decimal sum/avg over window should work.") {
     checkAnswer(
       spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 75a252ccba569d0be36b278852e4718cfc6d0a90..4f8aac8c2fcdd62e855257447057b572e6850409 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -222,20 +222,4 @@ private[sql] class HiveSessionCatalog(
       }
     }
   }
-
-  // Pre-load a few commonly used Hive built-in functions.
-  HiveSessionCatalog.preloadedHiveBuiltinFunctions.foreach {
-    case (functionName, clazz) =>
-      val builder = makeFunctionBuilder(functionName, clazz)
-      val info = new ExpressionInfo(clazz.getCanonicalName, functionName)
-      createTempFunction(functionName, info, builder, ignoreIfExists = false)
-  }
-}
-
-private[sql] object HiveSessionCatalog {
-  // This is the list of Hive's built-in functions that are commonly used and we want to
-  // pre-load when we create the FunctionRegistry.
-  val preloadedHiveBuiltinFunctions =
-    ("collect_set", classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectSet]) ::
-    ("collect_list", classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList]) :: Nil
 }
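Because collect_list and collect_set are now registered in the built-in FunctionRegistry (first hunk), they resolve in plain SQL without any Hive dependency, which is what makes the Hive UDAF preloading removed above obsolete. A hypothetical example, borrowing the inline-VALUES syntax used by the window test above:

spark.sql("select collect_list(a) from values 1, 2, 2, 3 T(a)").show()  // [1, 2, 2, 3]
spark.sql("select collect_set(a) from values 1, 2, 2, 3 T(a)").show()   // [1, 2, 3], in some order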
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
index 57f96e725a044a4e0623e4a8b3ddbb4778381ff0..cc41c04c71e16235401c182b2ddcda19fa8c9c74 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
@@ -58,17 +58,6 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with
     )
   }
 
-  test("collect functions") {
-    checkAnswer(
-      testData.select(collect_list($"a"), collect_list($"b")),
-      Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4)))
-    )
-    checkAnswer(
-      testData.select(collect_set($"a"), collect_set($"b")),
-      Seq(Row(Seq(1, 2, 3), Seq(2, 4)))
-    )
-  }
-
   test("cube") {
     checkAnswer(
       testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),