diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index dcded0774180eb806674aa4fd848c00a7abb59a6..420303408451f4e8dda5c1f6060888b6b1d4d285 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -81,8 +81,12 @@ trait CaseConversionExpression {
   def dataType: DataType = StringType
 
   override def eval(input: Row): Any = {
-    val converted = child.eval(input)
-    convert(converted.toString)
+    val evaluated = child.eval(input)
+    if (evaluated == null) {
+      null
+    } else {
+      convert(evaluated.toString)
+    }
   }
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 95860e6683f67348ea583b3b32e59d30080dcdce..e2ad3915d31343a328f7726918e18c63dbcc3b89 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -322,6 +322,13 @@ class SQLQuerySuite extends QueryTest {
         (2, "B"),
         (3, "C"),
         (4, "D")))
+
+    checkAnswer(
+      sql("SELECT n, UPPER(s) FROM nullStrings"),
+      Seq(
+        (1, "ABC"),
+        (2, "ABC"),
+        (3, null)))
   }
 
   test("system function lower()") {
@@ -334,6 +341,13 @@ class SQLQuerySuite extends QueryTest {
         (4, "d"),
         (5, "e"),
         (6, "f")))
+
+    checkAnswer(
+      sql("SELECT n, LOWER(s) FROM nullStrings"),
+      Seq(
+        (1, "abc"),
+        (2, "abc"),
+        (3, null)))
   }
 
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 944f520e43515ea47d8a6c79250fa2d5c47f73e7..876bd1636aab3ac63955e1b5da51efd404af6a51 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -106,4 +106,12 @@ object TestData {
       NullInts(null) :: Nil
     )
   nullInts.registerAsTable("nullInts")
+
+  case class NullStrings(n: Int, s: String)
+  val nullStrings =
+    TestSQLContext.sparkContext.parallelize(
+      NullStrings(1, "abc") ::
+      NullStrings(2, "ABC") ::
+      NullStrings(3, null) :: Nil)
+  nullStrings.registerAsTable("nullStrings")
 }