From e4c11eef2f64df0b6a432f40b669486d91ca6352 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN <ueshin@happy-camper.st> Date: Thu, 5 Jun 2014 12:00:31 -0700 Subject: [PATCH] [SPARK-2036] [SQL] CaseConversionExpression should check if the evaluated value is null. `CaseConversionExpression` should check if the evaluated value is `null`. Author: Takuya UESHIN <ueshin@happy-camper.st> Closes #982 from ueshin/issues/SPARK-2036 and squashes the following commits: 61e1c54 [Takuya UESHIN] Add check if the evaluated value is null. --- .../catalyst/expressions/stringOperations.scala | 8 ++++++-- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 14 ++++++++++++++ .../test/scala/org/apache/spark/sql/TestData.scala | 8 ++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index dcded07741..4203034084 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -81,8 +81,12 @@ trait CaseConversionExpression { def dataType: DataType = StringType override def eval(input: Row): Any = { - val converted = child.eval(input) - convert(converted.toString) + val evaluated = child.eval(input) + if (evaluated == null) { + null + } else { + convert(evaluated.toString) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 95860e6683..e2ad3915d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -322,6 +322,13 @@ class SQLQuerySuite extends QueryTest { (2, "B"), (3, "C"), (4, "D"))) + + checkAnswer( + sql("SELECT n, UPPER(s) FROM nullStrings"), + Seq( + (1, "ABC"), + (2, "ABC"), + (3, null))) } test("system function lower()") { @@ -334,6 +341,13 @@ class SQLQuerySuite extends QueryTest { (4, "d"), (5, "e"), (6, "f"))) + + checkAnswer( + sql("SELECT n, LOWER(s) FROM nullStrings"), + Seq( + (1, "abc"), + (2, "abc"), + (3, null))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 944f520e43..876bd1636a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -106,4 +106,12 @@ object TestData { NullInts(null) :: Nil ) nullInts.registerAsTable("nullInts") + + case class NullStrings(n: Int, s: String) + val nullStrings = + TestSQLContext.sparkContext.parallelize( + NullStrings(1, "abc") :: + NullStrings(2, "ABC") :: + NullStrings(3, null) :: Nil) + nullStrings.registerAsTable("nullStrings") } -- GitLab