diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 837b8525fed55be6af0a5d2d488af2f4bf053d05..c96ed6ef4101604da4429a76f4228133dff27867 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -76,20 +76,23 @@ case class BroadcastHashJoinExec(
     streamedPlan.asInstanceOf[CodegenSupport].inputRDDs()
   }
 
-  override def needCopyResult: Boolean = joinType match {
+  private def multipleOutputForOneInput: Boolean = joinType match {
     case _: InnerLike | LeftOuter | RightOuter =>
       // For inner and outer joins, one row from the streamed side may produce multiple result rows,
-      // if the build side has duplicated keys. Then we need to copy the result rows before putting
-      // them in a buffer, because these result rows share one UnsafeRow instance. Note that here
-      // we wait for the broadcast to be finished, which is a no-op because it's already finished
-      // when we wait it in `doProduce`.
+      // if the build side has duplicated keys. Note that here we wait for the broadcast to be
+      // finished, which is a no-op because it's already finished when we wait it in `doProduce`.
       !buildPlan.executeBroadcast[HashedRelation]().value.keyIsUnique
 
     // Other joins types(semi, anti, existence) can at most produce one result row for one input
-    // row from the streamed side, so no need to copy the result rows.
+    // row from the streamed side.
     case _ => false
   }
 
+  // If the streaming side needs to copy result, this join plan needs to copy too. Otherwise,
+  // this join plan only needs to copy result if it may output multiple rows for one input.
+  override def needCopyResult: Boolean =
+    streamedPlan.asInstanceOf[CodegenSupport].needCopyResult || multipleOutputForOneInput
+
   override def doProduce(ctx: CodegenContext): String = {
     streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 226cc3028b13571de5c89e451b9db7bbba867183..771e1186e63abc114a818f8559f577fd4ffde69c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
-import org.apache.spark.sql.execution.SortExec
+import org.apache.spark.sql.execution.{BinaryExecNode, SortExec}
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -857,4 +857,29 @@ class JoinSuite extends QueryTest with SharedSQLContext {
 
     joinQueries.foreach(assertJoinOrdering)
   }
+
+  test("SPARK-22445 Respect stream-side child's needCopyResult in BroadcastHashJoin") {
+    val df1 = Seq((2, 3), (2, 5), (2, 2), (3, 8), (2, 1)).toDF("k", "v1")
+    val df2 = Seq((2, 8), (3, 7), (3, 4), (1, 2)).toDF("k", "v2")
+    val df3 = Seq((1, 1), (3, 2), (4, 3), (5, 1)).toDF("k", "v3")
+
+    withSQLConf(
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+        SQLConf.JOIN_REORDER_ENABLED.key -> "false") {
+      val df = df1.join(df2, "k").join(functions.broadcast(df3), "k")
+      val plan = df.queryExecution.sparkPlan
+
+      // Check if `needCopyResult` in `BroadcastHashJoin` is correct when smj->bhj
+      val joins = new collection.mutable.ArrayBuffer[BinaryExecNode]()
+      plan.foreachUp {
+        case j: BroadcastHashJoinExec => joins += j
+        case j: SortMergeJoinExec => joins += j
+        case _ =>
+      }
+      assert(joins.size == 2)
+      assert(joins(0).isInstanceOf[SortMergeJoinExec])
+      assert(joins(1).isInstanceOf[BroadcastHashJoinExec])
+      checkAnswer(df, Row(3, 8, 7, 2) :: Row(3, 8, 4, 2) :: Nil)
+    }
+  }
 }
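Note on the patch: the comment on `multipleOutputForOneInput` is the crux of `needCopyResult`. Whole-stage codegen operators typically write every output row into a single reused UnsafeRow, so any consumer that buffers its input must copy each row before storing it, or all buffered entries end up aliasing the producer's last row. Below is a minimal sketch of that hazard in plain Scala; `MutableRow` and `CopyResultSketch` are hypothetical stand-ins for illustration, not Spark's actual generated code:

import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-in for a reused UnsafeRow: one mutable object that the
// producer overwrites for every output row.
final class MutableRow(var value: Int) {
  def copy(): MutableRow = new MutableRow(value)
}

object CopyResultSketch {
  def main(args: Array[String]): Unit = {
    val shared = new MutableRow(0)
    val buffered = ArrayBuffer[MutableRow]()
    for (v <- Seq(1, 2, 3)) {
      shared.value = v    // producer rewrites the shared row in place
      buffered += shared  // buffering without a copy: every entry aliases `shared`
    }
    println(buffered.map(_.value).mkString(", "))  // prints "3, 3, 3" -- corrupted
    // A consumer that honors needCopyResult buffers shared.copy() instead
    // and sees 1, 2, 3 as expected.
  }
}

This is why the fix propagates the flag from the streamed child: a BroadcastHashJoin with a unique-key build side needs no copy for its own sake, but if its streamed child (here a SortMergeJoin) reuses row objects, the rows flowing out of the join still alias each other.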
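For completeness, a standalone reproduction of the scenario the regression test pins down. This is a sketch under assumptions: a local SparkSession, the Spark 2.2-era config keys, and the same data as the test; `Spark22445Repro` is an illustrative name, not part of the patch:

import org.apache.spark.sql.{functions, SparkSession}

object Spark22445Repro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("SPARK-22445 repro")
      // Disable auto-broadcast so df1 JOIN df2 stays a SortMergeJoin, and disable
      // CBO join reordering so the broadcast join ends up on top of it.
      .config("spark.sql.autoBroadcastJoinThreshold", "-1")
      .config("spark.sql.cbo.joinReorder.enabled", "false")
      .getOrCreate()
    import spark.implicits._

    val df1 = Seq((2, 3), (2, 5), (2, 2), (3, 8), (2, 1)).toDF("k", "v1")
    val df2 = Seq((2, 8), (3, 7), (3, 4), (1, 2)).toDF("k", "v2")
    val df3 = Seq((1, 1), (3, 2), (4, 3), (5, 1)).toDF("k", "v3")

    // SortMergeJoin (needCopyResult = true) feeding a BroadcastHashJoin whose
    // build side has unique keys (needs no copy for itself). Before the fix,
    // the top join reported needCopyResult = false and could emit aliased rows.
    val df = df1.join(df2, "k").join(functions.broadcast(df3), "k")
    df.explain()
    df.show()  // expected rows: (3, 8, 7, 2) and (3, 8, 4, 2)

    spark.stop()
  }
}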