diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 837b8525fed55be6af0a5d2d488af2f4bf053d05..c96ed6ef4101604da4429a76f4228133dff27867 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -76,20 +76,23 @@ case class BroadcastHashJoinExec(
     streamedPlan.asInstanceOf[CodegenSupport].inputRDDs()
   }
 
-  override def needCopyResult: Boolean = joinType match {
+  private def multipleOutputForOneInput: Boolean = joinType match {
     case _: InnerLike | LeftOuter | RightOuter =>
       // For inner and outer joins, one row from the streamed side may produce multiple result rows,
-      // if the build side has duplicated keys. Then we need to copy the result rows before putting
-      // them in a buffer, because these result rows share one UnsafeRow instance. Note that here
-      // we wait for the broadcast to be finished, which is a no-op because it's already finished
-      // when we wait it in `doProduce`.
+      // if the build side has duplicated keys. Note that here we wait for the broadcast to be
+      // finished, which is a no-op because it's already finished when we wait it in `doProduce`.
       !buildPlan.executeBroadcast[HashedRelation]().value.keyIsUnique
 
     // Other joins types(semi, anti, existence) can at most produce one result row for one input
-    // row from the streamed side, so no need to copy the result rows.
+    // row from the streamed side.
     case _ => false
   }
 
+  // If the streaming side needs to copy result, this join plan needs to copy too. Otherwise,
+  // this join plan only needs to copy result if it may output multiple rows for one input.
+  override def needCopyResult: Boolean =
+    streamedPlan.asInstanceOf[CodegenSupport].needCopyResult || multipleOutputForOneInput
+
   override def doProduce(ctx: CodegenContext): String = {
     streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 226cc3028b13571de5c89e451b9db7bbba867183..771e1186e63abc114a818f8559f577fd4ffde69c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
-import org.apache.spark.sql.execution.SortExec
+import org.apache.spark.sql.execution.{BinaryExecNode, SortExec}
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -857,4 +857,29 @@ class JoinSuite extends QueryTest with SharedSQLContext {
 
     joinQueries.foreach(assertJoinOrdering)
   }
+
+  test("SPARK-22445 Respect stream-side child's needCopyResult in BroadcastHashJoin") {
+    val df1 = Seq((2, 3), (2, 5), (2, 2), (3, 8), (2, 1)).toDF("k", "v1")
+    val df2 = Seq((2, 8), (3, 7), (3, 4), (1, 2)).toDF("k", "v2")
+    val df3 = Seq((1, 1), (3, 2), (4, 3), (5, 1)).toDF("k", "v3")
+
+    withSQLConf(
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+        SQLConf.JOIN_REORDER_ENABLED.key -> "false") {
+      val df = df1.join(df2, "k").join(functions.broadcast(df3), "k")
+      val plan = df.queryExecution.sparkPlan
+
+      // Check if `needCopyResult` in `BroadcastHashJoin` is correct when smj->bhj
+      val joins = new collection.mutable.ArrayBuffer[BinaryExecNode]()
+      plan.foreachUp {
+        case j: BroadcastHashJoinExec => joins += j
+        case j: SortMergeJoinExec => joins += j
+        case _ =>
+      }
+      assert(joins.size == 2)
+      assert(joins(0).isInstanceOf[SortMergeJoinExec])
+      assert(joins(1).isInstanceOf[BroadcastHashJoinExec])
+      checkAnswer(df, Row(3, 8, 7, 2) :: Row(3, 8, 4, 2) :: Nil)
+    }
+  }
 }