From 2639c3ed03075d37f07042a03d93a4237366c6a5 Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun <dongjoon@apache.org>
Date: Mon, 12 Jun 2017 21:18:43 -0700
Subject: [PATCH] [SPARK-19910][SQL] `stack` should not reject NULL values due to type mismatch

## What changes were proposed in this pull request?

Since `stack` function generates a table with nullable columns, it should allow mixed null values.

```scala
scala> sql("select stack(3, 1, 2, 3)").printSchema
root
 |-- col0: integer (nullable = true)

scala> sql("select stack(3, 1, 2, null)").printSchema
org.apache.spark.sql.AnalysisException: cannot resolve 'stack(3, 1, 2, NULL)' due to data type mismatch: Argument 1 (IntegerType) != Argument 3 (NullType); line 1 pos 7;
```

## How was this patch tested?

Pass the Jenkins with a new test case.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #17251 from dongjoon-hyun/SPARK-19910.
---
 .../sql/catalyst/analysis/TypeCoercion.scala  | 17 ++++++
 .../sql/catalyst/expressions/generators.scala | 19 +++++++
 .../catalyst/analysis/TypeCoercionSuite.scala | 57 +++++++++++++++++++
 .../spark/sql/GeneratorFunctionSuite.scala    |  4 ++
 4 files changed, 97 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index e1dd010d37..1f21739051 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -54,6 +54,7 @@ object TypeCoercion {
       FunctionArgumentConversion ::
       CaseWhenCoercion ::
       IfCoercion ::
+      StackCoercion ::
       Division ::
       PropagateTypes ::
       ImplicitTypeCasts ::
@@ -648,6 +649,22 @@ object TypeCoercion {
     }
   }
 
+  /**
+   * Coerces NullTypes in the Stack expression to the column types of the corresponding positions.
+   */
+  object StackCoercion extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+      case s @ Stack(children) if s.childrenResolved && s.hasFoldableNumRows =>
+        Stack(children.zipWithIndex.map {
+          // The first child is the number of rows for stack.
+          case (e, 0) => e
+          case (Literal(null, NullType), index: Int) =>
+            Literal.create(null, s.findDataType(index))
+          case (e, _) => e
+        })
+    }
+  }
+
   /**
    * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType
    * to TimeAdd/TimeSub
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index e84796f2ed..e023f0567e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -138,6 +138,13 @@ case class Stack(children: Seq[Expression]) extends Generator {
   private lazy val numRows = children.head.eval().asInstanceOf[Int]
   private lazy val numFields = Math.ceil((children.length - 1.0) / numRows).toInt
 
+  /**
+   * Return true iff the first child exists and has a foldable IntegerType.
+   */
+  def hasFoldableNumRows: Boolean = {
+    children.nonEmpty && children.head.dataType == IntegerType && children.head.foldable
+  }
+
   override def checkInputDataTypes(): TypeCheckResult = {
     if (children.length <= 1) {
       TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least 2 arguments.")
@@ -156,6 +163,18 @@ case class Stack(children: Seq[Expression]) extends Generator {
     }
   }
 
+  def findDataType(index: Int): DataType = {
+    // Find the first data type except NullType.
+    val firstDataIndex = ((index - 1) % numFields) + 1
+    for (i <- firstDataIndex until children.length by numFields) {
+      if (children(i).dataType != NullType) {
+        return children(i).dataType
+      }
+    }
+    // If all values of the column are NullType, use it.
+    NullType
+  }
+
   override def elementSchema: StructType =
     StructType(children.tail.take(numFields).zipWithIndex.map {
       case (e, index) => StructField(s"col$index", e.dataType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 2ac11598e6..7358f401ed 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -768,6 +768,63 @@ class TypeCoercionSuite extends PlanTest {
     )
   }
 
+  test("type coercion for Stack") {
+    val rule = TypeCoercion.StackCoercion
+
+    ruleTest(rule,
+      Stack(Seq(Literal(3), Literal(1), Literal(2), Literal(null))),
+      Stack(Seq(Literal(3), Literal(1), Literal(2), Literal.create(null, IntegerType))))
+    ruleTest(rule,
+      Stack(Seq(Literal(3), Literal(1.0), Literal(null), Literal(3.0))),
+      Stack(Seq(Literal(3), Literal(1.0), Literal.create(null, DoubleType), Literal(3.0))))
+    ruleTest(rule,
+      Stack(Seq(Literal(3), Literal(null), Literal("2"), Literal("3"))),
+      Stack(Seq(Literal(3), Literal.create(null, StringType), Literal("2"), Literal("3"))))
+    ruleTest(rule,
+      Stack(Seq(Literal(3), Literal(null), Literal(null), Literal(null))),
+      Stack(Seq(Literal(3), Literal(null), Literal(null), Literal(null))))
+
+    ruleTest(rule,
+      Stack(Seq(Literal(2),
+        Literal(1), Literal("2"),
+        Literal(null), Literal(null))),
+      Stack(Seq(Literal(2),
+        Literal(1), Literal("2"),
+        Literal.create(null, IntegerType), Literal.create(null, StringType))))
+
+    ruleTest(rule,
+      Stack(Seq(Literal(2),
+        Literal(1), Literal(null),
+        Literal(null), Literal("2"))),
+      Stack(Seq(Literal(2),
+        Literal(1), Literal.create(null, StringType),
+        Literal.create(null, IntegerType), Literal("2"))))
+
+    ruleTest(rule,
+      Stack(Seq(Literal(2),
+        Literal(null), Literal(1),
+        Literal("2"), Literal(null))),
+      Stack(Seq(Literal(2),
+        Literal.create(null, StringType), Literal(1),
+        Literal("2"), Literal.create(null, IntegerType))))
+
+    ruleTest(rule,
+      Stack(Seq(Literal(2),
+        Literal(null), Literal(null),
+        Literal(1), Literal("2"))),
+      Stack(Seq(Literal(2),
+        Literal.create(null, IntegerType), Literal.create(null, StringType),
+        Literal(1), Literal("2"))))
+
+    ruleTest(rule,
+      Stack(Seq(Subtract(Literal(3), Literal(1)),
+        Literal(1), Literal("2"),
+        Literal(null), Literal(null))),
+      Stack(Seq(Subtract(Literal(3), Literal(1)),
+        Literal(1), Literal("2"),
+        Literal.create(null, IntegerType), Literal.create(null, StringType))))
+  }
+
   test("BooleanEquality type cast") {
     val be = TypeCoercion.BooleanEquality
     // Use something more than a literal to avoid triggering the simplification rules.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
index 539c63d3cb..6b98209fd4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
@@ -43,6 +43,10 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
     checkAnswer(df.selectExpr("stack(3, 1, 1.1, 'a', 2, 2.2, 'b', 3, 3.3, 'c')"),
       Row(1, 1.1, "a") :: Row(2, 2.2, "b") :: Row(3, 3.3, "c") :: Nil)
 
+    // Null values
+    checkAnswer(df.selectExpr("stack(3, 1, 1.1, null, 2, null, 'b', null, 3.3, 'c')"),
+      Row(1, 1.1, null) :: Row(2, null, "b") :: Row(null, 3.3, "c") :: Nil)
+
     // Repeat generation at every input row
     checkAnswer(spark.range(2).selectExpr("stack(2, 1, 2, 3)"),
       Row(1, 2) :: Row(3, null) :: Row(1, 2) :: Row(3, null) :: Nil)
-- 
GitLab