From dafcb05c2ef8e09f45edfb7eabf58116c23975a0 Mon Sep 17 00:00:00 2001
From: Sameer Agarwal <sameer@databricks.com>
Date: Sun, 22 May 2016 23:32:39 -0700
Subject: [PATCH] [SPARK-15425][SQL] Disallow cross joins by default

## What changes were proposed in this pull request?

In order to prevent users from inadvertently writing queries with cartesian joins, this patch introduces a new conf `spark.sql.crossJoin.enabled` (set to `false` by default) that if not set, results in a `SparkException` if the query contains one or more cartesian products.

## How was this patch tested?

Added a test to verify the new behavior in `JoinSuite`. Additionally, `SQLQuerySuite` and `SQLMetricsSuite` were modified to explicitly enable cartesian products.

Author: Sameer Agarwal <sameer@databricks.com>

Closes #13209 from sameeragarwal/disallow-cartesian.
---
 .../spark/sql/execution/SparkStrategies.scala |  3 +-
 .../joins/BroadcastNestedLoopJoinExec.scala   | 14 +++++-
 .../joins/CartesianProductExec.scala          | 11 +++++
 .../apache/spark/sql/internal/SQLConf.scala   |  9 +++-
 .../org/apache/spark/sql/JoinSuite.scala      | 31 +++++++++----
 .../org/apache/spark/sql/SQLQuerySuite.scala  | 32 +++++++------
 .../sql/execution/joins/InnerJoinSuite.scala  |  3 +-
 .../execution/metric/SQLMetricsSuite.scala    | 46 ++++++++++---------
 .../execution/HiveCompatibilitySuite.scala    |  4 ++
 .../sql/hive/execution/HiveQuerySuite.scala   |  6 +++
 10 files changed, 113 insertions(+), 46 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 555a2f4c01..c46cecc71f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -190,7 +190,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           }
         // This join could be very slow or OOM
         joins.BroadcastNestedLoopJoinExec(
-          planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
+          planLater(left), planLater(right), buildSide, joinType, condition,
+          withinBroadcastThreshold = false) :: Nil
 
       // --- Cases where this strategy does not apply ---------------------------------------------
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index 2a250ecce6..4d43765f8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -19,12 +19,14 @@ package org.apache.spark.sql.execution.joins
 
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.collection.{BitSet, CompactBuffer}
 
 case class BroadcastNestedLoopJoinExec(
@@ -32,7 +34,8 @@ case class BroadcastNestedLoopJoinExec(
     right: SparkPlan,
     buildSide: BuildSide,
     joinType: JoinType,
-    condition: Option[Expression]) extends BinaryExecNode {
+    condition: Option[Expression],
+    withinBroadcastThreshold: Boolean = true) extends BinaryExecNode {
 
   override private[sql] lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -337,6 +340,15 @@ case class BroadcastNestedLoopJoinExec(
     )
   }
 
+  protected override def doPrepare(): Unit = {
+    if (!withinBroadcastThreshold && !sqlContext.conf.crossJoinEnabled) {
+      throw new AnalysisException("Both sides of this join are outside the broadcasting " +
+        "threshold and computing it could be prohibitively expensive. To explicitly enable it, " +
+        s"please set ${SQLConf.CROSS_JOINS_ENABLED.key} = true")
+    }
+    super.doPrepare()
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 8d7ecc442a..88f78a7a73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -19,11 +19,13 @@ package org.apache.spark.sql.execution.joins
 
 import org.apache.spark._
 import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
 import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 
@@ -88,6 +90,15 @@ case class CartesianProductExec(
   override private[sql] lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
+  protected override def doPrepare(): Unit = {
+    if (!sqlContext.conf.crossJoinEnabled) {
+      throw new AnalysisException("Cartesian joins could be prohibitively expensive and are " +
+        "disabled by default. To explicitly enable them, please set " +
+        s"${SQLConf.CROSS_JOINS_ENABLED.key} = true")
+    }
+    super.doPrepare()
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 35d67ca2d8..f3064eb6ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -338,9 +338,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val CROSS_JOINS_ENABLED = SQLConfigBuilder("spark.sql.crossJoin.enabled")
+    .doc("When false, we will throw an error if a query contains a cross join")
+    .booleanConf
+    .createWithDefault(false)
+
   val ORDER_BY_ORDINAL = SQLConfigBuilder("spark.sql.orderByOrdinal")
     .doc("When true, the ordinal numbers are treated as the position in the select list. " +
-         "When false, the ordinal numbers in order/sort By clause are ignored.")
+         "When false, the ordinal numbers in order/sort by clause are ignored.")
     .booleanConf
     .createWithDefault(true)
 
@@ -622,6 +627,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
 
   def bucketingEnabled: Boolean = getConf(SQLConf.BUCKETING_ENABLED)
 
+  def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED)
+
   // Do not use a value larger than 4000 as the default value of this property.
   // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information.
   def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index a5d8cb19ea..5583673708 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -62,7 +62,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
   test("join operator selection") {
     spark.cacheManager.clearCache()
 
-    withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") {
+    withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0",
+      SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
       Seq(
         ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
           classOf[SortMergeJoinExec]),
@@ -204,13 +205,27 @@ class JoinSuite extends QueryTest with SharedSQLContext {
       testData.rdd.flatMap(row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
   }
 
-  test("cartisian product join") {
-    checkAnswer(
-      testData3.join(testData3),
-      Row(1, null, 1, null) ::
-        Row(1, null, 2, 2) ::
-        Row(2, 2, 1, null) ::
-        Row(2, 2, 2, 2) :: Nil)
+  test("cartesian product join") {
+    withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      checkAnswer(
+        testData3.join(testData3),
+        Row(1, null, 1, null) ::
+          Row(1, null, 2, 2) ::
+          Row(2, 2, 1, null) ::
+          Row(2, 2, 2, 2) :: Nil)
+    }
+    withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") {
+      val e = intercept[Exception] {
+        checkAnswer(
+          testData3.join(testData3),
+          Row(1, null, 1, null) ::
+            Row(1, null, 2, 2) ::
+            Row(2, 2, 1, null) ::
+            Row(2, 2, 2, 2) :: Nil)
+      }
+      assert(e.getMessage.contains("Cartesian joins could be prohibitively expensive and are " +
+        "disabled by default"))
+    }
   }
 
   test("left outer join") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 460e34a5ff..b1f848fdc8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -104,9 +104,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     ).toDF("a", "b", "c").createOrReplaceTempView("cachedData")
 
     spark.catalog.cacheTable("cachedData")
-    checkAnswer(
-      sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"),
-      Row(0) :: Row(81) :: Nil)
+    withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      checkAnswer(
+        sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"),
+        Row(0) :: Row(81) :: Nil)
+    }
   }
 
   test("self join with aliases") {
@@ -435,10 +437,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }
 
   test("left semi greater than predicate") {
-    checkAnswer(
-      sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"),
-      Seq(Row(3, 1), Row(3, 2))
-    )
+    withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      checkAnswer(
+        sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"),
+        Seq(Row(3, 1), Row(3, 2))
+      )
+    }
   }
 
   test("left semi greater than predicate and equal operator") {
@@ -824,12 +828,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }
 
   test("cartesian product join") {
-    checkAnswer(
-      testData3.join(testData3),
-      Row(1, null, 1, null) ::
-      Row(1, null, 2, 2) ::
-      Row(2, 2, 1, null) ::
-      Row(2, 2, 2, 2) :: Nil)
+    withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      checkAnswer(
+        testData3.join(testData3),
+        Row(1, null, 1, null) ::
+          Row(1, null, 2, 2) ::
+          Row(2, 2, 1, null) ::
+          Row(2, 2, 2, 2) :: Nil)
+    }
   }
 
   test("left outer join") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 7caeb3be54..27f6abcd95 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -187,7 +187,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
     }
 
     test(s"$testName using CartesianProduct") {
-      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1",
+        SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
         checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
           CartesianProductExec(left, right, Some(condition())),
           expectedAnswer.map(Row.fromTuple),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 7a89b484eb..12940c86fe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.execution.SparkPlanInfo
 import org.apache.spark.sql.execution.ui.SparkPlanGraph
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.util.{AccumulatorContext, JsonProtocol, Utils}
 
@@ -237,16 +238,18 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
   test("BroadcastNestedLoopJoin metrics") {
     val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
     testDataForJoin.createOrReplaceTempView("testDataForJoin")
-    withTempTable("testDataForJoin") {
-      // Assume the execution plan is
-      // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
-      val df = spark.sql(
-        "SELECT * FROM testData2 left JOIN testDataForJoin ON " +
-          "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a")
-      testSparkPlanMetrics(df, 3, Map(
-        1L -> ("BroadcastNestedLoopJoin", Map(
-          "number of output rows" -> 12L)))
-      )
+    withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      withTempTable("testDataForJoin") {
+        // Assume the execution plan is
+        // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
+        val df = spark.sql(
+          "SELECT * FROM testData2 left JOIN testDataForJoin ON " +
+            "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a")
+        testSparkPlanMetrics(df, 3, Map(
+          1L -> ("BroadcastNestedLoopJoin", Map(
+            "number of output rows" -> 12L)))
+        )
+      }
     }
   }
 
@@ -263,17 +266,18 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
   }
 
   test("CartesianProduct metrics") {
-    val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
-    testDataForJoin.createOrReplaceTempView("testDataForJoin")
-    withTempTable("testDataForJoin") {
-      // Assume the execution plan is
-      // ... -> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0)
-      val df = spark.sql(
-        "SELECT * FROM testData2 JOIN testDataForJoin")
-      testSparkPlanMetrics(df, 1, Map(
-        0L -> ("CartesianProduct", Map(
-          "number of output rows" -> 12L)))
-      )
+    withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
+      testDataForJoin.createOrReplaceTempView("testDataForJoin")
+      withTempTable("testDataForJoin") {
+        // Assume the execution plan is
+        // ... -> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0)
+        val df = spark.sql(
+          "SELECT * FROM testData2 JOIN testDataForJoin")
+        testSparkPlanMetrics(df, 1, Map(
+          0L -> ("CartesianProduct", Map("number of output rows" -> 12L)))
+        )
+      }
     }
   }
 
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 54fb440b33..a8645f7cd3 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -40,6 +40,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
   private val originalColumnBatchSize = TestHive.conf.columnBatchSize
   private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning
   private val originalConvertMetastoreOrc = TestHive.sessionState.convertMetastoreOrc
+  private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled
 
   def testCases: Seq[(String, File)] = {
     hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f)
@@ -61,6 +62,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     // Ensures that the plans generation use metastore relation and not OrcRelation
     // Was done because SqlBuilder does not work with plans having logical relation
     TestHive.setConf(HiveUtils.CONVERT_METASTORE_ORC, false)
+    // Ensures that cross joins are enabled so that we can test them
+    TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true)
     RuleExecutor.resetTime()
   }
 
@@ -72,6 +75,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
       TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
       TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
       TestHive.setConf(HiveUtils.CONVERT_METASTORE_ORC, originalConvertMetastoreOrc)
+      TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled)
       TestHive.sessionState.functionRegistry.restore()
 
       // For debugging dump some statistics about how much time was spent in various optimizer rules
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 2aaaaadb6a..e179021491 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec
 import org.apache.spark.sql.hive._
 import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
 import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.internal.SQLConf
 
 case class TestData(a: Int, b: String)
 
@@ -48,6 +49,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
 
   import org.apache.spark.sql.hive.test.TestHive.implicits._
 
+  private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled
+
   override def beforeAll() {
     super.beforeAll()
     TestHive.setCacheTables(true)
@@ -55,6 +58,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
     TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
     // Add Locale setting
     Locale.setDefault(Locale.US)
+    // Ensures that cross joins are enabled so that we can test them
+    TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true)
   }
 
   override def afterAll() {
@@ -63,6 +68,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
       TimeZone.setDefault(originalTimeZone)
       Locale.setDefault(originalLocale)
       sql("DROP TEMPORARY FUNCTION IF EXISTS udtf_count2")
+      TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled)
     } finally {
       super.afterAll()
     }
-- 
GitLab