Skip to content
Snippets Groups Projects
Commit 2639c3ed authored by Dongjoon Hyun's avatar Dongjoon Hyun Committed by Xiao Li
Browse files

[SPARK-19910][SQL] `stack` should not reject NULL values due to type mismatch

## What changes were proposed in this pull request?

Since the `stack` function generates a table with nullable columns, it should accept NULL arguments mixed with values of other types instead of rejecting them as a type mismatch.

```scala
scala> sql("select stack(3, 1, 2, 3)").printSchema
root
 |-- col0: integer (nullable = true)

scala> sql("select stack(3, 1, 2, null)").printSchema
org.apache.spark.sql.AnalysisException: cannot resolve 'stack(3, 1, 2, NULL)' due to data type mismatch: Argument 1 (IntegerType) != Argument 3 (NullType); line 1 pos 7;
```

## How was this patch tested?

Passes the Jenkins tests with a new test case.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #17251 from dongjoon-hyun/SPARK-19910.
parent fc0e6944
No related branches found
No related tags found
No related merge requests found
......@@ -54,6 +54,7 @@ object TypeCoercion {
FunctionArgumentConversion ::
CaseWhenCoercion ::
IfCoercion ::
StackCoercion ::
Division ::
PropagateTypes ::
ImplicitTypeCasts ::
......@@ -648,6 +649,22 @@ object TypeCoercion {
}
}
/**
 * Gives every untyped NULL literal among the value arguments of a Stack expression the data
 * type that Stack reports for the column at that argument's position.
 */
object StackCoercion extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
    // Only rewrite once the arguments are resolved and the row count is a foldable integer.
    case stack @ Stack(args) if stack.childrenResolved && stack.hasFoldableNumRows =>
      val coercedArgs = args.zipWithIndex.map {
        // Position 0 holds the number of rows, so only positions after it are value arguments.
        case (Literal(null, NullType), position) if position > 0 =>
          Literal.create(null, stack.findDataType(position))
        case (arg, _) => arg
      }
      Stack(coercedArgs)
  }
}
/**
* Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType
* to TimeAdd/TimeSub
......
......@@ -138,6 +138,13 @@ case class Stack(children: Seq[Expression]) extends Generator {
private lazy val numRows = children.head.eval().asInstanceOf[Int]
private lazy val numFields = Math.ceil((children.length - 1.0) / numRows).toInt
/**
 * Tells whether the row-count argument is usable: there is a first child, its data type is
 * IntegerType, and it is foldable.
 */
def hasFoldableNumRows: Boolean =
  children.headOption.exists { rowCount =>
    rowCount.dataType == IntegerType && rowCount.foldable
  }
override def checkInputDataTypes(): TypeCheckResult = {
if (children.length <= 1) {
TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least 2 arguments.")
......@@ -156,6 +163,18 @@ case class Stack(children: Seq[Expression]) extends Generator {
}
}
/**
 * Finds the data type of the column that the child at `index` belongs to.
 *
 * Child 0 is the row count, so `index` is expected to be at least 1. The children of one
 * column are `numFields` positions apart; they are scanned from the first row onwards and
 * the first data type that is not NullType wins.
 *
 * @param index position in `children` of a value argument
 * @return the first non-null data type found in that column, or NullType when every value
 *         in the column is NullType or no column layout can be derived
 */
def findDataType(index: Int): DataType = {
  if (numFields <= 0) {
    // numFields is ceil((children.length - 1.0) / numRows); a negative row count whose
    // magnitude exceeds the number of value arguments makes it 0, and `% 0` below would
    // throw ArithmeticException. Negative values already produced an empty scan, so a
    // non-positive field count uniformly yields NullType.
    NullType
  } else {
    // Position of this column's value in the first row.
    val firstDataIndex = ((index - 1) % numFields) + 1
    // Lazily walk the column and stop at the first data type except NullType.
    (firstDataIndex until children.length by numFields)
      .iterator
      .map(i => children(i).dataType)
      .find(_ != NullType)
      // If all values of the column are NullType, use it.
      .getOrElse(NullType)
  }
}
override def elementSchema: StructType =
StructType(children.tail.take(numFields).zipWithIndex.map {
case (e, index) => StructField(s"col$index", e.dataType)
......
......@@ -768,6 +768,63 @@ class TypeCoercionSuite extends PlanTest {
)
}
test("type coercion for Stack") {
  val stackRule = TypeCoercion.StackCoercion
  // Typed NULL literals that the rule is expected to produce.
  val nullInt = Literal.create(null, IntegerType)
  val nullDouble = Literal.create(null, DoubleType)
  val nullString = Literal.create(null, StringType)

  // One column, three rows: NULL in the last row.
  ruleTest(stackRule,
    Stack(Seq(Literal(3), Literal(1), Literal(2), Literal(null))),
    Stack(Seq(Literal(3), Literal(1), Literal(2), nullInt)))

  // One column, three rows: NULL in the middle row.
  ruleTest(stackRule,
    Stack(Seq(Literal(3), Literal(1.0), Literal(null), Literal(3.0))),
    Stack(Seq(Literal(3), Literal(1.0), nullDouble, Literal(3.0))))

  // One column, three rows: NULL in the first row.
  ruleTest(stackRule,
    Stack(Seq(Literal(3), Literal(null), Literal("2"), Literal("3"))),
    Stack(Seq(Literal(3), nullString, Literal("2"), Literal("3"))))

  // A column made only of NULLs is left as it is.
  ruleTest(stackRule,
    Stack(Seq(Literal(3), Literal(null), Literal(null), Literal(null))),
    Stack(Seq(Literal(3), Literal(null), Literal(null), Literal(null))))

  // Two columns, two rows: both NULLs in the second row.
  ruleTest(stackRule,
    Stack(Seq(Literal(2),
      Literal(1), Literal("2"),
      Literal(null), Literal(null))),
    Stack(Seq(Literal(2),
      Literal(1), Literal("2"),
      nullInt, nullString)))

  // Two columns, two rows: NULLs on the diagonal (second column first).
  ruleTest(stackRule,
    Stack(Seq(Literal(2),
      Literal(1), Literal(null),
      Literal(null), Literal("2"))),
    Stack(Seq(Literal(2),
      Literal(1), nullString,
      nullInt, Literal("2"))))

  // Two columns, two rows: NULLs on the other diagonal.
  ruleTest(stackRule,
    Stack(Seq(Literal(2),
      Literal(null), Literal(1),
      Literal("2"), Literal(null))),
    Stack(Seq(Literal(2),
      nullString, Literal(1),
      Literal("2"), nullInt)))

  // Two columns, two rows: both NULLs in the first row.
  ruleTest(stackRule,
    Stack(Seq(Literal(2),
      Literal(null), Literal(null),
      Literal(1), Literal("2"))),
    Stack(Seq(Literal(2),
      nullInt, nullString,
      Literal(1), Literal("2"))))

  // The row count is a foldable expression rather than a plain literal.
  ruleTest(stackRule,
    Stack(Seq(Subtract(Literal(3), Literal(1)),
      Literal(1), Literal("2"),
      Literal(null), Literal(null))),
    Stack(Seq(Subtract(Literal(3), Literal(1)),
      Literal(1), Literal("2"),
      nullInt, nullString)))
}
test("BooleanEquality type cast") {
val be = TypeCoercion.BooleanEquality
// Use something more than a literal to avoid triggering the simplification rules.
......
......@@ -43,6 +43,10 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
checkAnswer(df.selectExpr("stack(3, 1, 1.1, 'a', 2, 2.2, 'b', 3, 3.3, 'c')"),
Row(1, 1.1, "a") :: Row(2, 2.2, "b") :: Row(3, 3.3, "c") :: Nil)
// Null values
checkAnswer(df.selectExpr("stack(3, 1, 1.1, null, 2, null, 'b', null, 3.3, 'c')"),
Row(1, 1.1, null) :: Row(2, null, "b") :: Row(null, 3.3, "c") :: Nil)
// Repeat generation at every input row
checkAnswer(spark.range(2).selectExpr("stack(2, 1, 2, 3)"),
Row(1, 2) :: Row(3, null) :: Row(1, 2) :: Row(3, null) :: Nil)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment