diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2c269478ee7eff4db157b5632dd5fe62ba54fe79..9a92330f75f6f45ca6174a42004ef7ec4aa1f2a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -177,14 +177,16 @@ class Analyzer( private def assignAliases(exprs: Seq[NamedExpression]) = { exprs.zipWithIndex.map { case (expr, i) => - expr.transformUp { case u @ UnresolvedAlias(child, optionalAliasName) => + expr.transformUp { case u @ UnresolvedAlias(child, optGenAliasFunc) => child match { case ne: NamedExpression => ne case e if !e.resolved => u case g: Generator => MultiAlias(g, Nil) case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)() case e: ExtractValue => Alias(e, toPrettySQL(e))() - case e => Alias(e, optionalAliasName.getOrElse(toPrettySQL(e)))() + case e if optGenAliasFunc.isDefined => + Alias(child, optGenAliasFunc.get.apply(e))() + case e => Alias(e, toPrettySQL(e))() } } }.asInstanceOf[Seq[NamedExpression]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 1f1897dc36df212d23a2f7dd1f3dfab035d09424..e953eda7843c96974d7763665a5ef997acc84df0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -325,10 +325,13 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression) * Holds the expression that has yet to be aliased. * * @param child The computation that is needs to be resolved during analysis. 
- * @param aliasName The name if specified to be associated with the result of computing [[child]] + * @param aliasFunc The function, if specified, that is called to generate an alias to associate + * with the result of computing [[child]] * */ -case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None) +case class UnresolvedAlias( + child: Expression, + aliasFunc: Option[Expression => String] = None) extends UnaryExpression with NamedExpression with Unevaluable { override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 9b8334d334e4d1e19959a8f2502d3a2ab48e4531..204af719b2c59c9fe2b7245b05d1dfb900606044 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -24,6 +24,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression @@ -37,6 +38,14 @@ private[sql] object Column { def apply(expr: Expression): Column = new Column(expr) def unapply(col: Column): Option[Expression] = Some(col.expr) + + private[sql] def generateAlias(e: Expression): String = { + e match { + case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => + a.aggregateFunction.toString + case expr => usePrettyExpression(expr).sql + } + } } /** @@ -145,7 +154,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { case jt: JsonTuple => MultiAlias(jt, Nil) - 
case func: UnresolvedFunction => UnresolvedAlias(func, Some(usePrettyExpression(func).sql)) + case func: UnresolvedFunction => UnresolvedAlias(func, Some(Column.generateAlias)) // If we have a top level Cast, there is a chance to give it a better alias, if there is a // NamedExpression under this Cast. @@ -156,9 +165,14 @@ class Column(protected[sql] val expr: Expression) extends Logging { case other => Alias(expr, usePrettyExpression(expr).sql)() } + case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => + UnresolvedAlias(a, Some(Column.generateAlias)) + case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() } + + override def toString: String = usePrettyExpression(expr).sql override def equals(that: Any): Boolean = that match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 4f5bf633fab2eb7f87d67e86ed08d0b8b0c46e3f..b0e48a6553a45733aa0b2ef0a11978c685f5d1a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Pivot} import org.apache.spark.sql.catalyst.util.usePrettyExpression +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.NumericType @@ -73,6 +74,8 @@ class RelationalGroupedDataset protected[sql]( private[this] def alias(expr: Expression): NamedExpression = expr match { case u: UnresolvedAttribute => UnresolvedAlias(u) case expr: NamedExpression => expr + case a: AggregateExpression if (a.aggregateFunction.isInstanceOf[TypedAggregateExpression]) => + UnresolvedAlias(a, 
Some(Column.generateAlias)) case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index f1585ca3ff3180e48876af414a06ba34ad1a3488..ead7bd9642ecabde6aad28ee48377d8a5119dc98 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -240,4 +240,16 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { val df2 = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j") checkAnswer(df2.agg(RowAgg.toColumn as "b").select("b"), Row(6) :: Nil) } + + test("spark-15114 shorter system generated alias names") { + val ds = Seq(1, 3, 2, 5).toDS() + assert(ds.select(typed.sum((i: Int) => i)).columns.head === "TypedSumDouble(int)") + val ds2 = ds.select(typed.sum((i: Int) => i), typed.avg((i: Int) => i)) + assert(ds2.columns.head === "TypedSumDouble(int)") + assert(ds2.columns.last === "TypedAverage(int)") + val df = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j") + assert(df.groupBy($"j").agg(RowAgg.toColumn).columns.last == + "RowAgg(org.apache.spark.sql.Row)") + assert(df.groupBy($"j").agg(RowAgg.toColumn as "agg1").columns.last == "agg1") + } }