diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index e10ab9790d7671c2d0fd51a55d9dbdffad4cd064..d5ac01500b15156a8fd7d0de69448b6b26cca790 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -23,6 +23,7 @@ private[spark] trait CatalystConf { def caseSensitiveAnalysis: Boolean def orderByOrdinal: Boolean + def groupByOrdinal: Boolean /** * Returns the [[Resolver]] for the current configuration, which can be used to determin if two @@ -48,11 +49,16 @@ object EmptyConf extends CatalystConf { override def orderByOrdinal: Boolean = { throw new UnsupportedOperationException } + override def groupByOrdinal: Boolean = { + throw new UnsupportedOperationException + } } /** A CatalystConf that can be used for local testing. */ case class SimpleCatalystConf( caseSensitiveAnalysis: Boolean, - orderByOrdinal: Boolean = true) + orderByOrdinal: Boolean = true, + groupByOrdinal: Boolean = true) + extends CatalystConf { } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 07b0f5ee705b1c7bdf4ade93f64f02adb3457178..d0a31e7620bb01876a2b2ddd60647ac7e682e04c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -85,6 +85,7 @@ class Analyzer( ResolveGroupingAnalytics :: ResolvePivot :: ResolveUpCast :: + ResolveOrdinalInOrderByAndGroupBy :: ResolveSortReferences :: ResolveGenerate :: ResolveFunctions :: @@ -385,7 +386,13 @@ class Analyzer( p.copy(projectList = buildExpandedProjectList(p.projectList, p.child)) // If the aggregate function argument contains Stars, expand it. 
case a: Aggregate if containsStar(a.aggregateExpressions) => - a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child)) + if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty)) { + failAnalysis( + "Group by position: star is not allowed to use in the select list " + + "when using ordinals in group by") + } else { + a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child)) + } // If the script transformation input contains Stars, expand it. case t: ScriptTransformation if containsStar(t.input) => t.copy( @@ -634,21 +641,23 @@ class Analyzer( } } - /** - * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT - * clause. This rule detects such queries and adds the required attributes to the original - * projection, so that they will be available during sorting. Another projection is added to - * remove these attributes after sorting. - * - * This rule also resolves the position number in sort references. This support is introduced - * in Spark 2.0. Before Spark 2.0, the integers in Order By has no effect on output sorting. - * - When the sort references are not integer but foldable expressions, ignore them. - * - When spark.sql.orderByOrdinal is set to false, ignore the position numbers too. - */ - object ResolveSortReferences extends Rule[LogicalPlan] { + /** + * In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by + * clauses. This rule converts ordinal positions to the corresponding expressions in the + * select list. This support was introduced in Spark 2.0. + * + * - When the sort references or group by expressions are foldable expressions other than + * integer literals (e.g. `1 + 2`), they are not treated as ordinals. + * - When spark.sql.orderByOrdinal/spark.sql.groupByOrdinal is set to false, ignore the position + * numbers too. 
+ * + * Before the release of Spark 2.0, the literals in order/sort by and group by clauses + * had no effect on the results. + */ + object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case s: Sort if !s.child.resolved => s - // Replace the index with the related attribute for ORDER BY + case p if !p.childrenResolved => p + // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. case s @ Sort(orders, global, child) if conf.orderByOrdinal && orders.exists(o => IntegerIndex.unapply(o.child).nonEmpty) => @@ -665,10 +674,41 @@ class Analyzer( } Sort(newOrders, global, child) + // Replace the index with the corresponding expression in aggregateExpressions. The index is + // a 1-based position in aggregateExpressions, i.e. the output columns (select list). + case a @ Aggregate(groups, aggs, child) + if conf.groupByOrdinal && aggs.forall(_.resolved) && + groups.exists(IntegerIndex.unapply(_).nonEmpty) => + val newGroups = groups.map { + case IntegerIndex(index) if index > 0 && index <= aggs.size => + aggs(index - 1) match { + case e if ResolveAggregateFunctions.containsAggregate(e) => + throw new UnresolvedException(a, + s"Group by position: the '$index'th column in the select contains an " + + s"aggregate function: ${e.sql}. Aggregate functions are not allowed in GROUP BY") + case o => o + } + case IntegerIndex(index) => + throw new UnresolvedException(a, + s"Group by position: '$index' exceeds the size of the select list '${aggs.size}'.") + case o => o + } + Aggregate(newGroups, aggs, child) + } + } + + /** + * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT + * clause. This rule detects such queries and adds the required attributes to the original + * projection, so that they will be available during sorting. 
Another projection is added to + * remove these attributes after sorting. + */ + object ResolveSortReferences extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions case sa @ Sort(_, _, child: Aggregate) => sa - case s @ Sort(order, _, child) if !s.resolved => + case s @ Sort(order, _, child) if !s.resolved && child.resolved => try { val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder]) val requiredAttrs = AttributeSet(newOrder).filter(_.resolved) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index ada842477116cc291433f71143a24414be078595..9c927077d0ec8837263d728fb578a5c0f3eeb817 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -210,7 +210,8 @@ object Unions { object IntegerIndex { def unapply(a: Any): Option[Int] = a match { case Literal(a: Int, IntegerType) => Some(a) - // When resolving ordinal in Sort, negative values are extracted for issuing error messages. + // When resolving ordinal in Sort and Group By, negative values are extracted + // for issuing error messages. case UnaryMinus(IntegerLiteral(v)) => Some(-v) case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 863a876afe9c3ddaf2455bd9d0a000e91f632c1a..77af0e000b5ce53fb0b5fbea865b96043321517f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -445,6 +445,11 @@ object SQLConf { doc = "When true, the ordinal numbers are treated as the position in the select list. 
" + "When false, the ordinal numbers in order/sort By clause are ignored.") + val GROUP_BY_ORDINAL = booleanConf("spark.sql.groupByOrdinal", + defaultValue = Some(true), + doc = "When true, the ordinal numbers in group by clauses are treated as the position " + + "in the select list. When false, the ordinal numbers are ignored.") + // The output committer class used by HadoopFsRelation. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. // @@ -668,6 +673,7 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL) + override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index eb486a135f00a80dda9d00359637ea5453fc6461..61358fda7665d2f3980c9ef7348d9db740969ffc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -23,6 +23,7 @@ import java.sql.Timestamp import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, CartesianProduct, SortMergeJoin} import org.apache.spark.sql.functions._ @@ -459,25 +460,103 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Seq(Row(1, 3), Row(2, 3), Row(3, 3))) } - test("literal in agg grouping expressions") { + test("Group By Ordinal - basic") { checkAnswer( - sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), - Seq(Row(1, 2), Row(2, 2), Row(3, 
2))) - checkAnswer( - sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + sql("SELECT a, sum(b) FROM testData2 GROUP BY 1"), + sql("SELECT a, sum(b) FROM testData2 GROUP BY a")) + // duplicate group-by columns checkAnswer( sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY 1, 2"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + } + + test("Group By Ordinal - non aggregate expressions") { + checkAnswer( + sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, 2"), + sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2")) + + checkAnswer( + sql("SELECT a, b + 2 as c, count(2) FROM testData2 GROUP BY a, 2"), + sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2")) + } + + test("Group By Ordinal - non-foldable constant expression") { + checkAnswer( + sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b, 1 + 0"), + sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b")) + checkAnswer( sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + } + + test("Group By Ordinal - alias") { + checkAnswer( + sql("SELECT a, (b + 2) as c, count(2) FROM testData2 GROUP BY a, 2"), + sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2")) + + checkAnswer( + sql("SELECT a as b, b as a, sum(b) FROM testData2 GROUP BY 1, 2"), + sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b")) + } + + test("Group By Ordinal - constants") { checkAnswer( sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), sql("SELECT 1, 2, sum(b) FROM testData2")) } + test("Group By Ordinal - negative cases") { + intercept[UnresolvedException[Aggregate]] { + sql("SELECT a, b FROM testData2 GROUP BY -1") + } + + intercept[UnresolvedException[Aggregate]] { + sql("SELECT a, b FROM testData2 GROUP BY 3") + } + + var e = 
intercept[UnresolvedException[Aggregate]]( + sql("SELECT SUM(a) FROM testData2 GROUP BY 1")) + assert(e.getMessage contains + "Invalid call to Group by position: the '1'th column in the select contains " + + "an aggregate function") + + e = intercept[UnresolvedException[Aggregate]]( + sql("SELECT SUM(a) + 1 FROM testData2 GROUP BY 1")) + assert(e.getMessage contains + "Invalid call to Group by position: the '1'th column in the select contains " + + "an aggregate function") + + var ae = intercept[AnalysisException]( + sql("SELECT a, rand(0), sum(b) FROM testData2 GROUP BY a, 2")) + assert(ae.getMessage contains + "nondeterministic expression rand(0) should not appear in grouping expression") + + ae = intercept[AnalysisException]( + sql("SELECT * FROM testData2 GROUP BY a, b, 1")) + assert(ae.getMessage contains + "Group by position: star is not allowed to use in the select list " + + "when using ordinals in group by") + } + + test("Group By Ordinal: spark.sql.groupByOrdinal=false") { + withSQLConf(SQLConf.GROUP_BY_ORDINAL.key -> "false") { + // If spark.sql.groupByOrdinal=false, the ordinal is not resolved, so analysis fails. + intercept[AnalysisException] { + sql("SELECT a, sum(b) FROM testData2 GROUP BY 1") + } + // With ordinals disabled, '*' is allowed in the select list and the literal 1 is ignored. + checkAnswer( + sql("SELECT * FROM testData2 GROUP BY a, b, 1"), + sql("SELECT * FROM testData2 GROUP BY a, b")) + } + } + test("aggregates with nulls") { checkAnswer( sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," + @@ -2174,7 +2253,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer( sql("SELECT * FROM testData2 ORDER BY 1 + 0 DESC, b ASC"), sql("SELECT * FROM testData2 ORDER BY b ASC")) - checkAnswer( sql("SELECT * FROM testData2 ORDER BY 1 DESC, b ASC"), sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"))