Skip to content
Snippets Groups Projects
Commit 2c0fe818 authored by Takeshi Yamamuro's avatar Takeshi Yamamuro Committed by Wenchen Fan
Browse files

[SPARK-22445][SQL][FOLLOW-UP] Respect stream-side child's needCopyResult in BroadcastHashJoin

## What changes were proposed in this pull request?
I found #19656 causes some bugs, for example, it changed the result set of `q6` in tpcds (I keep tracking TPCDS results daily [here](https://github.com/maropu/spark-tpcds-datagen/tree/master/reports/tests)):
- w/o pr19658
```
+-----+---+
|state|cnt|
+-----+---+
|   MA| 10|
|   AK| 10|
|   AZ| 11|
|   ME| 13|
|   VT| 14|
|   NV| 15|
|   NH| 16|
|   UT| 17|
|   NJ| 21|
|   MD| 22|
|   WY| 25|
|   NM| 26|
|   OR| 31|
|   WA| 36|
|   ND| 38|
|   ID| 39|
|   SC| 45|
|   WV| 50|
|   FL| 51|
|   OK| 53|
|   MT| 53|
|   CO| 57|
|   AR| 58|
|   NY| 58|
|   PA| 62|
|   AL| 63|
|   LA| 63|
|   SD| 70|
|   WI| 80|
| null| 81|
|   MI| 82|
|   NC| 82|
|   MS| 83|
|   CA| 84|
|   MN| 85|
|   MO| 88|
|   IL| 95|
|   IA|102|
|   TN|102|
|   IN|103|
|   KY|104|
|   NE|113|
|   OH|114|
|   VA|130|
|   KS|139|
|   GA|168|
|   TX|216|
+-----+---+
```
- w/   pr19658
```
+-----+---+
|state|cnt|
+-----+---+
|   RI| 14|
|   AK| 16|
|   FL| 20|
|   NJ| 21|
|   NM| 21|
|   NV| 22|
|   MA| 22|
|   MD| 22|
|   UT| 22|
|   AZ| 25|
|   SC| 28|
|   AL| 36|
|   MT| 36|
|   WA| 39|
|   ND| 41|
|   MI| 44|
|   AR| 45|
|   OR| 47|
|   OK| 52|
|   PA| 53|
|   LA| 55|
|   CO| 55|
|   NY| 64|
|   WV| 66|
|   SD| 72|
|   MS| 73|
|   NC| 79|
|   IN| 82|
| null| 85|
|   ID| 88|
|   MN| 91|
|   WI| 95|
|   IL| 96|
|   MO| 97|
|   CA|109|
|   CA|109|
|   TN|114|
|   NE|115|
|   KY|128|
|   OH|131|
|   IA|156|
|   TX|160|
|   VA|182|
|   KS|211|
|   GA|230|
+-----+---+
```
This pr is to keep the original logic of `CodegenContext.copyResult` in `BroadcastHashJoinExec`.

## How was this patch tested?
Existing tests

Author: Takeshi Yamamuro <yamamuro@apache.org>

Closes #19781 from maropu/SPARK-22445-bugfix.
parent e0d7665c
No related branches found
No related tags found
No related merge requests found
...@@ -76,20 +76,23 @@ case class BroadcastHashJoinExec( ...@@ -76,20 +76,23 @@ case class BroadcastHashJoinExec(
streamedPlan.asInstanceOf[CodegenSupport].inputRDDs() streamedPlan.asInstanceOf[CodegenSupport].inputRDDs()
} }
override def needCopyResult: Boolean = joinType match { private def multipleOutputForOneInput: Boolean = joinType match {
case _: InnerLike | LeftOuter | RightOuter => case _: InnerLike | LeftOuter | RightOuter =>
// For inner and outer joins, one row from the streamed side may produce multiple result rows, // For inner and outer joins, one row from the streamed side may produce multiple result rows,
// if the build side has duplicated keys. Then we need to copy the result rows before putting // if the build side has duplicated keys. Note that here we wait for the broadcast to be
// them in a buffer, because these result rows share one UnsafeRow instance. Note that here // finished, which is a no-op because it's already finished when we wait it in `doProduce`.
// we wait for the broadcast to be finished, which is a no-op because it's already finished
// when we wait it in `doProduce`.
!buildPlan.executeBroadcast[HashedRelation]().value.keyIsUnique !buildPlan.executeBroadcast[HashedRelation]().value.keyIsUnique
// Other joins types(semi, anti, existence) can at most produce one result row for one input // Other joins types(semi, anti, existence) can at most produce one result row for one input
// row from the streamed side, so no need to copy the result rows. // row from the streamed side.
case _ => false case _ => false
} }
// If the streaming side needs to copy result, this join plan needs to copy too. Otherwise,
// this join plan only needs to copy result if it may output multiple rows for one input.
override def needCopyResult: Boolean =
streamedPlan.asInstanceOf[CodegenSupport].needCopyResult || multipleOutputForOneInput
override def doProduce(ctx: CodegenContext): String = { override def doProduce(ctx: CodegenContext): String = {
streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this) streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
} }
......
...@@ -25,7 +25,7 @@ import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} ...@@ -25,7 +25,7 @@ import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder} import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
import org.apache.spark.sql.execution.SortExec import org.apache.spark.sql.execution.{BinaryExecNode, SortExec}
import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SharedSQLContext
...@@ -857,4 +857,29 @@ class JoinSuite extends QueryTest with SharedSQLContext { ...@@ -857,4 +857,29 @@ class JoinSuite extends QueryTest with SharedSQLContext {
joinQueries.foreach(assertJoinOrdering) joinQueries.foreach(assertJoinOrdering)
} }
test("SPARK-22445 Respect stream-side child's needCopyResult in BroadcastHashJoin") {
val df1 = Seq((2, 3), (2, 5), (2, 2), (3, 8), (2, 1)).toDF("k", "v1")
val df2 = Seq((2, 8), (3, 7), (3, 4), (1, 2)).toDF("k", "v2")
val df3 = Seq((1, 1), (3, 2), (4, 3), (5, 1)).toDF("k", "v3")
withSQLConf(
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.JOIN_REORDER_ENABLED.key -> "false") {
val df = df1.join(df2, "k").join(functions.broadcast(df3), "k")
val plan = df.queryExecution.sparkPlan
// Check if `needCopyResult` in `BroadcastHashJoin` is correct when smj->bhj
val joins = new collection.mutable.ArrayBuffer[BinaryExecNode]()
plan.foreachUp {
case j: BroadcastHashJoinExec => joins += j
case j: SortMergeJoinExec => joins += j
case _ =>
}
assert(joins.size == 2)
assert(joins(0).isInstanceOf[SortMergeJoinExec])
assert(joins(1).isInstanceOf[BroadcastHashJoinExec])
checkAnswer(df, Row(3, 8, 7, 2) :: Row(3, 8, 4, 2) :: Nil)
}
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment