From b1436c7496161da1f4fa6950a06de62c9007c40c Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@databricks.com>
Date: Mon, 12 Jun 2017 14:07:51 -0700
Subject: [PATCH] [SPARK-21059][SQL] LikeSimplification can NPE on null pattern

## What changes were proposed in this pull request?
This patch fixes a bug in LikeSimplification that can cause a NullPointerException when the pattern for `like` is null.
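
As a rough illustration (not taken from the original report), a query whose LIKE pattern constant-folds to a null string literal can reach this rule. The sketch below is a hypothetical reproduction: the table `t`, its string column `col`, and the exact trigger are assumptions, not part of this patch.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical reproduction sketch; not part of this patch. A table `t` with a
// string column `col` is assumed to exist (e.g. registered as a temp view).
val spark = SparkSession.builder().master("local[*]").appName("SPARK-21059-repro").getOrCreate()

// CAST(NULL AS STRING) constant-folds to Literal(null, StringType), which
// LikeSimplification then pattern-matches on.
val df = spark.sql("SELECT * FROM t WHERE col LIKE CAST(NULL AS STRING)")

// Before this patch, optimizing the plan could throw a NullPointerException,
// because the rule called pattern.toString on the null literal.
// After this patch, the predicate is rewritten to Literal(null, BooleanType):
// "col LIKE NULL" evaluates to NULL, so the filter matches no rows.
df.explain(true)
```

The unit test added below exercises the same rewrite directly at the Catalyst level.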

## How was this patch tested?
Added a new unit test case in LikeSimplificationSuite.

Author: Reynold Xin <rxin@databricks.com>

Closes #18273 from rxin/SPARK-21059.
---
 .../sql/catalyst/optimizer/expressions.scala  | 37 +++++++++++--------
 .../optimizer/LikeSimplificationSuite.scala   |  8 +++-
 2 files changed, 28 insertions(+), 17 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 51f749a8bf..66b8ca62e5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -383,22 +383,27 @@ object LikeSimplification extends Rule[LogicalPlan] {
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
     case Like(input, Literal(pattern, StringType)) =>
-      pattern.toString match {
-        case startsWith(prefix) if !prefix.endsWith("\\") =>
-          StartsWith(input, Literal(prefix))
-        case endsWith(postfix) =>
-          EndsWith(input, Literal(postfix))
-        // 'a%a' pattern is basically same with 'a%' && '%a'.
-        // However, the additional `Length` condition is required to prevent 'a' match 'a%a'.
-        case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") =>
-          And(GreaterThanOrEqual(Length(input), Literal(prefix.size + postfix.size)),
-            And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix))))
-        case contains(infix) if !infix.endsWith("\\") =>
-          Contains(input, Literal(infix))
-        case equalTo(str) =>
-          EqualTo(input, Literal(str))
-        case _ =>
-          Like(input, Literal.create(pattern, StringType))
+      if (pattern == null) {
+        // If pattern is null, return null value directly, since "col like null" == null.
+        Literal(null, BooleanType)
+      } else {
+        pattern.toString match {
+          case startsWith(prefix) if !prefix.endsWith("\\") =>
+            StartsWith(input, Literal(prefix))
+          case endsWith(postfix) =>
+            EndsWith(input, Literal(postfix))
+          // 'a%a' pattern is basically same with 'a%' && '%a'.
+          // However, the additional `Length` condition is required to prevent 'a' match 'a%a'.
+          case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") =>
+            And(GreaterThanOrEqual(Length(input), Literal(prefix.length + postfix.length)),
+              And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix))))
+          case contains(infix) if !infix.endsWith("\\") =>
+            Contains(input, Literal(infix))
+          case equalTo(str) =>
+            EqualTo(input, Literal(str))
+          case _ =>
+            Like(input, Literal.create(pattern, StringType))
+        }
       }
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala
index fdde89d079..50398788c6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala
@@ -17,13 +17,13 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-/* Implicit conversions */
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.types.{BooleanType, StringType}
 
 class LikeSimplificationSuite extends PlanTest {
 
@@ -100,4 +100,10 @@ class LikeSimplificationSuite extends PlanTest {
 
     comparePlans(optimized, correctAnswer)
   }
+
+  test("null pattern") {
+    val originalQuery = testRelation.where('a like Literal(null, StringType)).analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, testRelation.where(Literal(null, BooleanType)).analyze)
+  }
 }
-- 
GitLab