From 2639c3ed03075d37f07042a03d93a4237366c6a5 Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun <dongjoon@apache.org>
Date: Mon, 12 Jun 2017 21:18:43 -0700
Subject: [PATCH] [SPARK-19910][SQL] `stack` should not reject NULL values due
 to type mismatch

## What changes were proposed in this pull request?

Since `stack` function generates a table with nullable columns, it should allow mixed null values.

```scala
scala> sql("select stack(3, 1, 2, 3)").printSchema
root
 |-- col0: integer (nullable = true)

scala> sql("select stack(3, 1, 2, null)").printSchema
org.apache.spark.sql.AnalysisException: cannot resolve 'stack(3, 1, 2, NULL)' due to data type mismatch: Argument 1 (IntegerType) != Argument 3 (NullType); line 1 pos 7;
```

## How was this patch tested?

Pass the Jenkins with a new test case.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #17251 from dongjoon-hyun/SPARK-19910.
---
 .../sql/catalyst/analysis/TypeCoercion.scala  | 17 ++++++
 .../sql/catalyst/expressions/generators.scala | 19 +++++++
 .../catalyst/analysis/TypeCoercionSuite.scala | 57 +++++++++++++++++++
 .../spark/sql/GeneratorFunctionSuite.scala    |  4 ++
 4 files changed, 97 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index e1dd010d37..1f21739051 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -54,6 +54,7 @@ object TypeCoercion {
       FunctionArgumentConversion ::
       CaseWhenCoercion ::
       IfCoercion ::
+      StackCoercion ::
       Division ::
       PropagateTypes ::
       ImplicitTypeCasts ::
@@ -648,6 +649,22 @@ object TypeCoercion {
     }
   }
 
+  /**
+   * Coerces NullTypes in the Stack expression to the column types of the corresponding positions.
+   */
+  object StackCoercion extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+      case s @ Stack(children) if s.childrenResolved && s.hasFoldableNumRows =>
+        Stack(children.zipWithIndex.map {
+          // The first child is the number of rows for stack.
+          case (e, 0) => e
+          case (Literal(null, NullType), index: Int) =>
+            Literal.create(null, s.findDataType(index))
+          case (e, _) => e
+        })
+    }
+  }
+
   /**
    * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType
    * to TimeAdd/TimeSub
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index e84796f2ed..e023f0567e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -138,6 +138,13 @@ case class Stack(children: Seq[Expression]) extends Generator {
   private lazy val numRows = children.head.eval().asInstanceOf[Int]
   private lazy val numFields = Math.ceil((children.length - 1.0) / numRows).toInt
 
+  /**
+   * Return true iff the first child exists and has a foldable IntegerType.
+   */
+  def hasFoldableNumRows: Boolean = {
+    children.nonEmpty && children.head.dataType == IntegerType && children.head.foldable
+  }
+
   override def checkInputDataTypes(): TypeCheckResult = {
     if (children.length <= 1) {
       TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least 2 arguments.")
@@ -156,6 +163,18 @@ case class Stack(children: Seq[Expression]) extends Generator {
     }
   }
 
+  def findDataType(index: Int): DataType = {
+    // Find the first data type except NullType.
+    val firstDataIndex = ((index - 1) % numFields) + 1
+    for (i <- firstDataIndex until children.length by numFields) {
+      if (children(i).dataType != NullType) {
+        return children(i).dataType
+      }
+    }
+    // If all values of the column are NullType, use it.
+    NullType
+  }
+
   override def elementSchema: StructType =
     StructType(children.tail.take(numFields).zipWithIndex.map {
       case (e, index) => StructField(s"col$index", e.dataType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 2ac11598e6..7358f401ed 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -768,6 +768,63 @@ class TypeCoercionSuite extends PlanTest {
     )
   }
 
+  test("type coercion for Stack") {
+    val rule = TypeCoercion.StackCoercion
+
+    ruleTest(rule,
+      Stack(Seq(Literal(3), Literal(1), Literal(2), Literal(null))),
+      Stack(Seq(Literal(3), Literal(1), Literal(2), Literal.create(null, IntegerType))))
+    ruleTest(rule,
+      Stack(Seq(Literal(3), Literal(1.0), Literal(null), Literal(3.0))),
+      Stack(Seq(Literal(3), Literal(1.0), Literal.create(null, DoubleType), Literal(3.0))))
+    ruleTest(rule,
+      Stack(Seq(Literal(3), Literal(null), Literal("2"), Literal("3"))),
+      Stack(Seq(Literal(3), Literal.create(null, StringType), Literal("2"), Literal("3"))))
+    ruleTest(rule,
+      Stack(Seq(Literal(3), Literal(null), Literal(null), Literal(null))),
+      Stack(Seq(Literal(3), Literal(null), Literal(null), Literal(null))))
+
+    ruleTest(rule,
+      Stack(Seq(Literal(2),
+        Literal(1), Literal("2"),
+        Literal(null), Literal(null))),
+      Stack(Seq(Literal(2),
+        Literal(1), Literal("2"),
+        Literal.create(null, IntegerType), Literal.create(null, StringType))))
+
+    ruleTest(rule,
+      Stack(Seq(Literal(2),
+        Literal(1), Literal(null),
+        Literal(null), Literal("2"))),
+      Stack(Seq(Literal(2),
+        Literal(1), Literal.create(null, StringType),
+        Literal.create(null, IntegerType), Literal("2"))))
+
+    ruleTest(rule,
+      Stack(Seq(Literal(2),
+        Literal(null), Literal(1),
+        Literal("2"), Literal(null))),
+      Stack(Seq(Literal(2),
+        Literal.create(null, StringType), Literal(1),
+        Literal("2"), Literal.create(null, IntegerType))))
+
+    ruleTest(rule,
+      Stack(Seq(Literal(2),
+        Literal(null), Literal(null),
+        Literal(1), Literal("2"))),
+      Stack(Seq(Literal(2),
+        Literal.create(null, IntegerType), Literal.create(null, StringType),
+        Literal(1), Literal("2"))))
+
+    ruleTest(rule,
+      Stack(Seq(Subtract(Literal(3), Literal(1)),
+        Literal(1), Literal("2"),
+        Literal(null), Literal(null))),
+      Stack(Seq(Subtract(Literal(3), Literal(1)),
+        Literal(1), Literal("2"),
+        Literal.create(null, IntegerType), Literal.create(null, StringType))))
+  }
+
   test("BooleanEquality type cast") {
     val be = TypeCoercion.BooleanEquality
     // Use something more than a literal to avoid triggering the simplification rules.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
index 539c63d3cb..6b98209fd4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
@@ -43,6 +43,10 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
     checkAnswer(df.selectExpr("stack(3, 1, 1.1, 'a', 2, 2.2, 'b', 3, 3.3, 'c')"),
       Row(1, 1.1, "a") :: Row(2, 2.2, "b") :: Row(3, 3.3, "c") :: Nil)
 
+    // Null values
+    checkAnswer(df.selectExpr("stack(3, 1, 1.1, null, 2, null, 'b', null, 3.3, 'c')"),
+      Row(1, 1.1, null) :: Row(2, null, "b") :: Row(null, 3.3, "c") :: Nil)
+
     // Repeat generation at every input row
     checkAnswer(spark.range(2).selectExpr("stack(2, 1, 2, 3)"),
       Row(1, 2) :: Row(3, null) :: Row(1, 2) :: Row(3, null) :: Nil)
-- 
GitLab