From dafcb05c2ef8e09f45edfb7eabf58116c23975a0 Mon Sep 17 00:00:00 2001
From: Sameer Agarwal <sameer@databricks.com>
Date: Sun, 22 May 2016 23:32:39 -0700
Subject: [PATCH] [SPARK-15425][SQL] Disallow cross joins by default

## What changes were proposed in this pull request?

To prevent users from inadvertently writing queries with cartesian joins, this patch introduces a new conf `spark.sql.crossJoin.enabled` (set to `false` by default). While the conf is disabled, any query containing one or more cartesian products fails with an `AnalysisException` when it is executed.

## How was this patch tested?

Added a test to verify the new behavior in `JoinSuite`. Additionally, `SQLQuerySuite` and `SQLMetricsSuite` were modified to explicitly enable cartesian products.

Author: Sameer Agarwal <sameer@databricks.com>

Closes #13209 from sameeragarwal/disallow-cartesian.
---
 .../spark/sql/execution/SparkStrategies.scala |  3 +-
 .../joins/BroadcastNestedLoopJoinExec.scala   | 14 +++++-
 .../joins/CartesianProductExec.scala          | 11 +++++
 .../apache/spark/sql/internal/SQLConf.scala   |  9 +++-
 .../org/apache/spark/sql/JoinSuite.scala      | 31 +++++++++----
 .../org/apache/spark/sql/SQLQuerySuite.scala  | 32 +++++++------
 .../sql/execution/joins/InnerJoinSuite.scala  |  3 +-
 .../execution/metric/SQLMetricsSuite.scala    | 46 ++++++++++---------
 .../execution/HiveCompatibilitySuite.scala    |  4 ++
 .../sql/hive/execution/HiveQuerySuite.scala   |  6 +++
 10 files changed, 113 insertions(+), 46 deletions(-)
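
As a quick, self-contained illustration of the new behavior (an editorial sketch against the Spark 2.0 public API, not part of the patch; the app name and column names are invented):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[2]")
  .appName("cross-join-demo")
  .getOrCreate()
import spark.implicits._

val left = Seq(1, 2, 3).toDF("a")
val right = Seq("x", "y").toDF("b")

// With spark.sql.crossJoin.enabled = false (the new default), a join without
// a join condition plans fine but fails once execution is prepared:
//   org.apache.spark.sql.AnalysisException: Cartesian joins could be
//   prohibitively expensive and are disabled by default. ...
// left.join(right).collect()

// Opting in explicitly lets the cartesian product run:
spark.conf.set("spark.sql.crossJoin.enabled", "true")
left.join(right).show() // 3 x 2 = 6 rows
```
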
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 555a2f4c01..c46cecc71f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -190,7 +190,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         }
         // This join could be very slow or OOM
         joins.BroadcastNestedLoopJoinExec(
-          planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
+          planLater(left), planLater(right), buildSide, joinType, condition,
+          withinBroadcastThreshold = false) :: Nil

       // --- Cases where this strategy does not apply ---------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index 2a250ecce6..4d43765f8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -19,12 +19,14 @@ package org.apache.spark.sql.execution.joins

 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.collection.{BitSet, CompactBuffer}

 case class BroadcastNestedLoopJoinExec(
@@ -32,7 +34,8 @@ case class BroadcastNestedLoopJoinExec(
     right: SparkPlan,
     buildSide: BuildSide,
     joinType: JoinType,
-    condition: Option[Expression]) extends BinaryExecNode {
+    condition: Option[Expression],
+    withinBroadcastThreshold: Boolean = true) extends BinaryExecNode {

   override private[sql] lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -337,6 +340,15 @@ case class BroadcastNestedLoopJoinExec(
     )
   }

+  protected override def doPrepare(): Unit = {
+    if (!withinBroadcastThreshold && !sqlContext.conf.crossJoinEnabled) {
+      throw new AnalysisException("Both sides of this join are outside the broadcasting " +
+        "threshold and computing it could be prohibitively expensive. To explicitly enable it, " +
+        s"please set ${SQLConf.CROSS_JOINS_ENABLED.key} = true")
+    }
+    super.doPrepare()
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 8d7ecc442a..88f78a7a73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -19,11 +19,13 @@ package org.apache.spark.sql.execution.joins

 import org.apache.spark._
 import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
 import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter

@@ -88,6 +90,15 @@ case class CartesianProductExec(
   override private[sql] lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

+  protected override def doPrepare(): Unit = {
+    if (!sqlContext.conf.crossJoinEnabled) {
+      throw new AnalysisException("Cartesian joins could be prohibitively expensive and are " +
+        "disabled by default. To explicitly enable them, please set " +
+        s"${SQLConf.CROSS_JOINS_ENABLED.key} = true")
+    }
+    super.doPrepare()
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 35d67ca2d8..f3064eb6ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -338,9 +338,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)

+  val CROSS_JOINS_ENABLED = SQLConfigBuilder("spark.sql.crossJoin.enabled")
+    .doc("When false, we will throw an error if a query contains a cross join")
+    .booleanConf
+    .createWithDefault(false)
+
   val ORDER_BY_ORDINAL = SQLConfigBuilder("spark.sql.orderByOrdinal")
     .doc("When true, the ordinal numbers are treated as the position in the select list. " +
" + - "When false, the ordinal numbers in order/sort By clause are ignored.") + "When false, the ordinal numbers in order/sort by clause are ignored.") .booleanConf .createWithDefault(true) @@ -622,6 +627,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def bucketingEnabled: Boolean = getConf(SQLConf.BUCKETING_ENABLED) + def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED) + // Do not use a value larger than 4000 as the default value of this property. // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information. def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index a5d8cb19ea..5583673708 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -62,7 +62,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { test("join operator selection") { spark.cacheManager.clearCache() - withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") { + withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[SortMergeJoinExec]), @@ -204,13 +205,27 @@ class JoinSuite extends QueryTest with SharedSQLContext { testData.rdd.flatMap(row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } - test("cartisian product join") { - checkAnswer( - testData3.join(testData3), - Row(1, null, 1, null) :: - Row(1, null, 2, 2) :: - Row(2, 2, 1, null) :: - Row(2, 2, 2, 2) :: Nil) + test("cartesian product join") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + checkAnswer( + testData3.join(testData3), + Row(1, null, 1, null) :: + Row(1, null, 2, 2) :: + Row(2, 2, 1, null) :: + Row(2, 2, 2, 2) :: Nil) + } + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") { + val e = intercept[Exception] { + checkAnswer( + testData3.join(testData3), + Row(1, null, 1, null) :: + Row(1, null, 2, 2) :: + Row(2, 2, 1, null) :: + Row(2, 2, 2, 2) :: Nil) + } + assert(e.getMessage.contains("Cartesian joins could be prohibitively expensive and are " + + "disabled by default")) + } } test("left outer join") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 460e34a5ff..b1f848fdc8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -104,9 +104,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ).toDF("a", "b", "c").createOrReplaceTempView("cachedData") spark.catalog.cacheTable("cachedData") - checkAnswer( - sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"), - Row(0) :: Row(81) :: Nil) + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + checkAnswer( + sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"), + Row(0) :: Row(81) :: Nil) + } } test("self join with aliases") { @@ -435,10 +437,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("left semi greater than predicate") { - checkAnswer( - sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"), - Seq(Row(3, 1), Row(3, 2)) - ) + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + checkAnswer( + sql("SELECT * FROM testData2 x 
+        Seq(Row(3, 1), Row(3, 2))
+      )
+    }
   }

   test("left semi greater than predicate and equal operator") {
@@ -824,12 +828,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }

   test("cartesian product join") {
-    checkAnswer(
-      testData3.join(testData3),
-      Row(1, null, 1, null) ::
-        Row(1, null, 2, 2) ::
-        Row(2, 2, 1, null) ::
-        Row(2, 2, 2, 2) :: Nil)
+    withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      checkAnswer(
+        testData3.join(testData3),
+        Row(1, null, 1, null) ::
+          Row(1, null, 2, 2) ::
+          Row(2, 2, 1, null) ::
+          Row(2, 2, 2, 2) :: Nil)
+    }
   }

   test("left outer join") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 7caeb3be54..27f6abcd95 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -187,7 +187,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
     }

     test(s"$testName using CartesianProduct") {
-      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1",
+        SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
         checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
           CartesianProductExec(left, right, Some(condition())),
           expectedAnswer.map(Row.fromTuple),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 7a89b484eb..12940c86fe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.execution.SparkPlanInfo
 import org.apache.spark.sql.execution.ui.SparkPlanGraph
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.util.{AccumulatorContext, JsonProtocol, Utils}

@@ -237,16 +238,18 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
   test("BroadcastNestedLoopJoin metrics") {
     val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
     testDataForJoin.createOrReplaceTempView("testDataForJoin")
-    withTempTable("testDataForJoin") {
-      // Assume the execution plan is
-      // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
-      val df = spark.sql(
-        "SELECT * FROM testData2 left JOIN testDataForJoin ON " +
-          "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a")
-      testSparkPlanMetrics(df, 3, Map(
-        1L -> ("BroadcastNestedLoopJoin", Map(
-          "number of output rows" -> 12L)))
-      )
+    withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      withTempTable("testDataForJoin") {
+        // Assume the execution plan is
+        // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
+        val df = spark.sql(
+          "SELECT * FROM testData2 left JOIN testDataForJoin ON " +
+            "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a")
+        testSparkPlanMetrics(df, 3, Map(
+          1L -> ("BroadcastNestedLoopJoin", Map(
+            "number of output rows" -> 12L)))
+        )
+      }
     }
   }

@@ -263,17 +266,18 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
   }

   test("CartesianProduct metrics") {
-    val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
-    testDataForJoin.createOrReplaceTempView("testDataForJoin")
-    withTempTable("testDataForJoin") {
-      // Assume the execution plan is
-      // ... -> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0)
-      val df = spark.sql(
-        "SELECT * FROM testData2 JOIN testDataForJoin")
-      testSparkPlanMetrics(df, 1, Map(
-        0L -> ("CartesianProduct", Map(
-          "number of output rows" -> 12L)))
-      )
+    withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
+      testDataForJoin.createOrReplaceTempView("testDataForJoin")
+      withTempTable("testDataForJoin") {
+        // Assume the execution plan is
+        // ... -> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0)
+        val df = spark.sql(
+          "SELECT * FROM testData2 JOIN testDataForJoin")
+        testSparkPlanMetrics(df, 1, Map(
+          0L -> ("CartesianProduct", Map("number of output rows" -> 12L)))
+        )
+      }
     }
   }
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 54fb440b33..a8645f7cd3 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -40,6 +40,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
   private val originalColumnBatchSize = TestHive.conf.columnBatchSize
   private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning
   private val originalConvertMetastoreOrc = TestHive.sessionState.convertMetastoreOrc
+  private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled

   def testCases: Seq[(String, File)] = {
     hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f)
@@ -61,6 +62,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     // Ensures that the plans generation use metastore relation and not OrcRelation
     // Was done because SqlBuilder does not work with plans having logical relation
     TestHive.setConf(HiveUtils.CONVERT_METASTORE_ORC, false)
+    // Ensures that cross joins are enabled so that we can test them
+    TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true)
     RuleExecutor.resetTime()
   }

@@ -72,6 +75,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
     TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
     TestHive.setConf(HiveUtils.CONVERT_METASTORE_ORC, originalConvertMetastoreOrc)
+    TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled)
     TestHive.sessionState.functionRegistry.restore()

     // For debugging dump some statistics about how much time was spent in various optimizer rules
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 2aaaaadb6a..e179021491 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec
 import org.apache.spark.sql.hive._
 import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
 import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.internal.SQLConf

 case class TestData(a: Int, b: String)

@@ -48,6 +49,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {

   import org.apache.spark.sql.hive.test.TestHive.implicits._

+  private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled
+
   override def beforeAll() {
     super.beforeAll()
     TestHive.setCacheTables(true)
@@ -55,6 +58,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
     TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
     // Add Locale setting
     Locale.setDefault(Locale.US)
+    // Ensures that cross joins are enabled so that we can test them
+    TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true)
   }

   override def afterAll() {
@@ -63,6 +68,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
       TimeZone.setDefault(originalTimeZone)
       Locale.setDefault(originalLocale)
       sql("DROP TEMPORARY FUNCTION IF EXISTS udtf_count2")
+      TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled)
     } finally {
       super.afterAll()
     }
-- 
GitLab
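
A note on when the new check fires: because both guards live in `doPrepare()`, a cartesian query still parses, analyzes, and optimizes cleanly; the `AnalysisException` is thrown only when the physical plan is prepared for execution. A small sketch of how calling code observes this (illustrative only; the session setup and view names are invented, not part of the patch):

```scala
import org.apache.spark.sql.{AnalysisException, SparkSession}

val spark = SparkSession.builder().master("local[2]").appName("cross-join-check").getOrCreate()
import spark.implicits._

Seq((1, "a"), (2, "b")).toDF("k", "v").createOrReplaceTempView("t1")
Seq((1, "x")).toDF("k", "v").createOrReplaceTempView("t2")

spark.conf.set("spark.sql.crossJoin.enabled", "false")

// Planning succeeds; the guard fires only once the plan is prepared to run.
val df = spark.sql("SELECT * FROM t1 JOIN t2")
try {
  df.collect()
} catch {
  case e: AnalysisException =>
    // Message text introduced by CartesianProductExec.doPrepare() in this patch.
    assert(e.getMessage.contains("Cartesian joins could be prohibitively expensive"))
}
```
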
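The test changes above rely on `withSQLConf`, a helper from Spark's internal test utilities that sets SQL confs for the duration of a block and then restores the previous values. For readers outside the Spark test tree, a minimal re-implementation of the same save-and-restore idea might look like this (a sketch, not the actual helper):

```scala
import scala.util.Try
import org.apache.spark.sql.SparkSession

def withSQLConf[T](spark: SparkSession)(pairs: (String, String)*)(f: => T): T = {
  val keys = pairs.map(_._1)
  // Remember the previous value of each key; None means the key was unset.
  val previous = keys.map(k => Try(spark.conf.get(k)).toOption)
  pairs.foreach { case (k, v) => spark.conf.set(k, v) }
  try f finally {
    keys.zip(previous).foreach {
      case (k, Some(v)) => spark.conf.set(k, v)
      case (k, None) => spark.conf.unset(k)
    }
  }
}
```

This save-and-restore shape is what lets `JoinSuite` run one branch with the flag on and another with it off without leaking conf state between tests.
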
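`HiveCompatibilitySuite` and `HiveQuerySuite` cannot wrap individual queries, so they flip the flag in `beforeAll()` and restore the saved value in `afterAll()`. The same shape, extracted as a reusable ScalaTest mixin (the trait name and `spark` member are invented for illustration):

```scala
import scala.util.Try
import org.scalatest.{BeforeAndAfterAll, Suite}
import org.apache.spark.sql.SparkSession

trait WithCrossJoinsEnabled extends BeforeAndAfterAll { self: Suite =>
  def spark: SparkSession
  private var original: Option[String] = None

  override protected def beforeAll(): Unit = {
    super.beforeAll()
    // Save the caller's setting so the suite leaves no trace behind.
    original = Try(spark.conf.get("spark.sql.crossJoin.enabled")).toOption
    spark.conf.set("spark.sql.crossJoin.enabled", "true")
  }

  override protected def afterAll(): Unit = {
    try {
      original match {
        case Some(v) => spark.conf.set("spark.sql.crossJoin.enabled", v)
        case None => spark.conf.unset("spark.sql.crossJoin.enabled")
      }
    } finally {
      super.afterAll()
    }
  }
}
```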