diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
index 63f86ad09412cf46fe89b4017f5cbaa410ea8cc9..6e6cc6962c0077636b047fccefe476fae86ddd92 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
@@ -31,5 +31,6 @@ package org.apache.spark.sql.catalyst.plans.logical
  *
  * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it
  *                    defaults to the product of children's `sizeInBytes`.
+ * @param isBroadcastable If true, output is small enough to be used in a broadcast join.
  */
-private[sql] case class Statistics(sizeInBytes: BigInt, isBroadcastable: Boolean = false)
+case class Statistics(sizeInBytes: BigInt, isBroadcastable: Boolean = false)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index e681b88685b58336eb016aa04eeae64103e85af0..be4bf5b447565353e5f77db5104148caf4d1fabc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.joins
 
 import scala.reflect.ClassTag
 
-import org.scalatest.BeforeAndAfterAll
-
 import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext}
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
@@ -68,10 +66,11 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
     }
   }
 
-  private def testBroadcastJoin[T: ClassTag](joinType: String,
-      forceBroadcast: Boolean = false): SparkPlan = {
+  private def testBroadcastJoin[T: ClassTag](
+      joinType: String,
+      forceBroadcast: Boolean = false): SparkPlan = {
     val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
-    var df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+    val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
 
     // Comparison at the end is for broadcast left semi join
     val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
@@ -80,11 +79,9 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
     } else {
       df1.join(df2, joinExpression, joinType)
     }
-    val plan =
-      EnsureRequirements(spark.sessionState.conf).apply(df3.queryExecution.sparkPlan)
+    val plan = EnsureRequirements(spark.sessionState.conf).apply(df3.queryExecution.sparkPlan)
     assert(plan.collect { case p: T => p }.size === 1)
-
-    return plan
+    plan
   }
 
   test("unsafe broadcast hash join updates peak execution memory") {