From 5eea332307cbed5fc44427959f070afc16a12c02 Mon Sep 17 00:00:00 2001
From: Takeshi YAMAMURO <linguin.m.s@gmail.com>
Date: Wed, 1 Jun 2016 22:23:00 -0700
Subject: [PATCH] [SPARK-13484][SQL] Prevent illegal NULL propagation when
 filtering outer-join results

## What changes were proposed in this pull request?
This PR add a rule at the end of analyzer to correct nullable fields of attributes in a logical plan by using nullable fields of the corresponding attributes in its children logical plans (these plans generate the input rows).

This is another approach for addressing SPARK-13484 (the first approach is https://github.com/apache/spark/pull/11371).

Close #113711

Author: Takeshi YAMAMURO <linguin.m.s@gmail.com>
Author: Yin Huai <yhuai@databricks.com>

Closes #13290 from yhuai/SPARK-13484.
---
 .../sql/catalyst/analysis/Analyzer.scala      | 37 ++++++++++++++++++-
 .../analysis/ResolveNaturalJoinSuite.scala    |  2 +-
 .../apache/spark/sql/DataFrameJoinSuite.scala | 21 +++++++++++
 3 files changed, 58 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index eb46c0e72e..02966796af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -113,6 +113,8 @@ class Analyzer(
       PullOutNondeterministic),
     Batch("UDF", Once,
       HandleNullInputsForUDF),
+    Batch("FixNullability", Once,
+      FixNullability),
     Batch("Cleanup", fixedPoint,
       CleanupAliases)
   )
@@ -1451,6 +1453,40 @@ class Analyzer(
     }
   }
 
+  /**
+   * Fixes nullability of Attributes in a resolved LogicalPlan by using the nullability of
+   * corresponding Attributes of its children output Attributes. This step is needed because
+   * users can use a resolved AttributeReference in the Dataset API and outer joins
+   * can change the nullability of an AttribtueReference. Without the fix, a nullable column's
+   * nullable field can be actually set as non-nullable, which cause illegal optimization
+   * (e.g., NULL propagation) and wrong answers.
+   * See SPARK-13484 and SPARK-13801 for the concrete queries of this case.
+   */
+  object FixNullability extends Rule[LogicalPlan] {
+
+    def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+      case p if !p.resolved => p // Skip unresolved nodes.
+      case p: LogicalPlan if p.resolved =>
+        val childrenOutput = p.children.flatMap(c => c.output).groupBy(_.exprId).flatMap {
+          case (exprId, attributes) =>
+            // If there are multiple Attributes having the same ExprId, we need to resolve
+            // the conflict of nullable field. We do not really expect this happen.
+            val nullable = attributes.exists(_.nullable)
+            attributes.map(attr => attr.withNullability(nullable))
+        }.toSeq
+        // At here, we create an AttributeMap that only compare the exprId for the lookup
+        // operation. So, we can find the corresponding input attribute's nullability.
+        val attributeMap = AttributeMap[Attribute](childrenOutput.map(attr => attr -> attr))
+        // For an Attribute used by the current LogicalPlan, if it is from its children,
+        // we fix the nullable field by using the nullability setting of the corresponding
+        // output Attribute from the children.
+        p.transformExpressions {
+          case attr: Attribute if attributeMap.contains(attr) =>
+            attr.withNullability(attributeMap(attr).nullable)
+        }
+    }
+  }
+
   /**
    * Extracts [[WindowExpression]]s from the projectList of a [[Project]] operator and
    * aggregateExpressions of an [[Aggregate]] operator and creates individual [[Window]]
@@ -2133,4 +2169,3 @@ object TimeWindowing extends Rule[LogicalPlan] {
       }
   }
 }
-
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
index 1423a8705a..748579df41 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
@@ -100,7 +100,7 @@ class ResolveNaturalJoinSuite extends AnalysisTest {
     val naturalPlan = r3.join(r4, NaturalJoin(FullOuter), None)
     val usingPlan = r3.join(r4, UsingJoin(FullOuter, Seq(UnresolvedAttribute("b"))), None)
     val expected = r3.join(r4, FullOuter, Some(EqualTo(bNotNull, bNotNull))).select(
-      Alias(Coalesce(Seq(bNotNull, bNotNull)), "b")(), a, c)
+      Alias(Coalesce(Seq(b, b)), "b")(), a, c)
     checkAnalysis(naturalPlan, expected)
     checkAnalysis(usingPlan, expected)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 031e66b57c..4342c039ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -204,4 +204,25 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
       leftJoin2Inner,
       Row(1, 2, "1", 1, 3, "1") :: Nil)
   }
+
+  test("process outer join results using the non-nullable columns in the join input") {
+    // Filter data using a non-nullable column from a right table
+    val df1 = Seq((0, 0), (1, 0), (2, 0), (3, 0), (4, 0)).toDF("id", "count")
+    val df2 = Seq(Tuple1(0), Tuple1(1)).toDF("id").groupBy("id").count
+    checkAnswer(
+      df1.join(df2, df1("id") === df2("id"), "left_outer").filter(df2("count").isNull),
+      Row(2, 0, null, null) ::
+      Row(3, 0, null, null) ::
+      Row(4, 0, null, null) :: Nil
+    )
+
+    // Coalesce data using non-nullable columns in input tables
+    val df3 = Seq((1, 1)).toDF("a", "b")
+    val df4 = Seq((2, 2)).toDF("a", "b")
+    checkAnswer(
+      df3.join(df4, df3("a") === df4("a"), "outer")
+        .select(coalesce(df3("a"), df3("b")), coalesce(df4("a"), df4("b"))),
+      Row(1, null) :: Row(null, 2) :: Nil
+    )
+  }
 }
-- 
GitLab