diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5f1fe99f75c9d327665ce4a97cf3351719726423..d57b6eaf40b09bd0a80aa9b9c85f7300f25dbbb2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -155,8 +155,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   object BroadcastNestedLoopJoin extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Join(left, right, joinType, condition) =>
+        val buildSide =
+          if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) BuildRight else BuildLeft
         execution.BroadcastNestedLoopJoin(
-          planLater(left), planLater(right), joinType, condition) :: Nil
+          planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
       case _ => Nil
     }
   }
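
The strategy now picks a build side the same way the size-based broadcast decisions elsewhere in the planner do: whichever child the optimizer estimates to be smaller is broadcast, with ties going to the right. A minimal standalone sketch of that decision rule (`PlanStats` here is an illustrative stand-in for the `sizeInBytes` estimate carried by a plan's statistics, not Spark's actual class):

```scala
sealed trait BuildSide
case object BuildLeft extends BuildSide
case object BuildRight extends BuildSide

// Illustrative stand-in for a logical plan's size estimate.
case class PlanStats(sizeInBytes: BigInt)

// Broadcast the smaller side; ties go to the right, matching the `<=` above.
def chooseBuildSide(left: PlanStats, right: PlanStats): BuildSide =
  if (right.sizeInBytes <= left.sizeInBytes) BuildRight else BuildLeft

assert(chooseBuildSide(PlanStats(1000), PlanStats(10)) == BuildRight)
assert(chooseBuildSide(PlanStats(10), PlanStats(1000)) == BuildLeft)
```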
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index 2750ddbce896f7e5b0a1d6bd9e93236b383a82ad..b068579db75cddfcbaa8865d74ff51e6adec0d8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -314,10 +314,19 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod
  */
 @DeveloperApi
 case class BroadcastNestedLoopJoin(
-    streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression])
-  extends BinaryNode {
+    left: SparkPlan,
+    right: SparkPlan,
+    buildSide: BuildSide,
+    joinType: JoinType,
+    condition: Option[Expression]) extends BinaryNode {
   // TODO: Override requiredChildDistribution.
 
+  /** BuildRight means the right relation is the broadcast relation. */
+  val (streamed, broadcast) = buildSide match {
+    case BuildRight => (left, right)
+    case BuildLeft => (right, left)
+  }
+
   override def outputPartitioning: Partitioning = streamed.outputPartitioning
 
   override def output = {
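
The destructuring above keeps `left` and `right` fixed for attribute resolution and only remaps them to execution roles. Output columns must still appear in left-then-right order whichever side is broadcast, which is why the matching loop further down flips the `JoinedRow` arguments for `BuildLeft`. A toy illustration of that invariant (plain Scala, no Spark types):

```scala
type Row = Seq[Any]
def joinedRow(l: Row, r: Row): Row = l ++ r // left columns, then right columns

val leftRow: Row = Seq("l1", "l2")
val rightRow: Row = Seq("r1")

// BuildRight: streamed = left, broadcast = right.
val (streamedR, broadcastR) = (leftRow, rightRow)
assert(joinedRow(streamedR, broadcastR) == Seq("l1", "l2", "r1"))

// BuildLeft: streamed = right, broadcast = left, so call sites flip the
// arguments to keep output columns in left-then-right order.
val (streamedL, broadcastL) = (rightRow, leftRow)
assert(joinedRow(broadcastL, streamedL) == Seq("l1", "l2", "r1"))
```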
@@ -333,11 +342,6 @@ case class BroadcastNestedLoopJoin(
     }
   }
 
-  /** The Streamed Relation */
-  def left = streamed
-  /** The Broadcast relation */
-  def right = broadcast
-
   @transient lazy val boundCondition =
     InterpretedPredicate(
       condition
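
`boundCondition` (truncated above) compiles the optional join condition into a row predicate once, outside the row loop; with no condition supplied, every pair of rows matches and the operator degenerates into an outer cross product. A hedged sketch of that defaulting behavior, with a plain function standing in for Spark's `InterpretedPredicate` (the real code also binds attribute references against `left.output ++ right.output` first):

```scala
type Row = Seq[Any]
type Predicate = Row => Boolean

// No condition means "always true": every streamed/broadcast pair joins.
def bindCondition(condition: Option[Predicate]): Predicate =
  condition.getOrElse((_: Row) => true)

val always = bindCondition(None)
assert(always(Seq(1, "a", 2, "b")))

val firstEqualsThird = bindCondition(Some(row => row(0) == row(2)))
assert(firstEqualsThird(Seq(1, "a", 1, "b")))
```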
@@ -348,57 +352,78 @@ case class BroadcastNestedLoopJoin(
     val broadcastedRelation =
       sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
 
-    val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
+    /** Rows that match the condition, plus streamed rows joined with nulls for outer joins. */
+    val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter =>
       val matchedRows = new ArrayBuffer[Row]
       // TODO: Use Spark's BitSet.
-      val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size)
+      val includedBroadcastTuples =
+        new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
       val joinedRow = new JoinedRow
+      val leftNulls = new GenericMutableRow(left.output.size)
       val rightNulls = new GenericMutableRow(right.output.size)
 
       streamedIter.foreach { streamedRow =>
         var i = 0
-        var matched = false
+        var streamRowMatched = false
 
         while (i < broadcastedRelation.value.size) {
           // TODO: One bitset per partition instead of per row.
           val broadcastedRow = broadcastedRelation.value(i)
-          if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
-            matchedRows += joinedRow(streamedRow, broadcastedRow).copy()
-            matched = true
-            includedBroadcastTuples += i
+          buildSide match {
+            case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
+              matchedRows += joinedRow(streamedRow, broadcastedRow).copy()
+              streamRowMatched = true
+              includedBroadcastTuples += i
+            case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) =>
+              matchedRows += joinedRow(broadcastedRow, streamedRow).copy()
+              streamRowMatched = true
+              includedBroadcastTuples += i
+            case _ =>
           }
           i += 1
         }
 
-        if (!matched && (joinType == LeftOuter || joinType == FullOuter)) {
-          matchedRows += joinedRow(streamedRow, rightNulls).copy()
+        (streamRowMatched, joinType, buildSide) match {
+          case (false, LeftOuter | FullOuter, BuildRight) =>
+            matchedRows += joinedRow(streamedRow, rightNulls).copy()
+          case (false, RightOuter | FullOuter, BuildLeft) =>
+            matchedRows += joinedRow(leftNulls, streamedRow).copy()
+          case _ =>
         }
       }
       Iterator((matchedRows, includedBroadcastTuples))
     }
 
-    val includedBroadcastTuples = streamedPlusMatches.map(_._2)
+    val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2)
     val allIncludedBroadcastTuples =
       if (includedBroadcastTuples.count == 0) {
         new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
       } else {
-        streamedPlusMatches.map(_._2).reduce(_ ++ _)
+        includedBroadcastTuples.reduce(_ ++ _)
       }
 
     val leftNulls = new GenericMutableRow(left.output.size)
-    val rightOuterMatches: Seq[Row] =
-      if (joinType == RightOuter || joinType == FullOuter) {
-        broadcastedRelation.value.zipWithIndex.filter {
-          case (row, i) => !allIncludedBroadcastTuples.contains(i)
-        }.map {
-          case (row, _) => new JoinedRow(leftNulls, row)
+    val rightNulls = new GenericMutableRow(right.output.size)
+    /** Unmatched rows from the broadcast relation, joined with nulls. */
+    val broadcastRowsWithNulls: Seq[Row] = {
+      val arrBuf = new collection.mutable.ArrayBuffer[Row]()
+      var i = 0
+      val rel = broadcastedRelation.value
+      while (i < rel.length) {
+        if (!allIncludedBroadcastTuples.contains(i)) {
+          (joinType, buildSide) match {
+            case (RightOuter | FullOuter, BuildRight) => arrBuf += new JoinedRow(leftNulls, rel(i))
+            case (LeftOuter | FullOuter, BuildLeft) => arrBuf += new JoinedRow(rel(i), rightNulls)
+            case _ =>
+          }
         }
-      } else {
-        Vector()
+        i += 1
       }
+      arrBuf.toSeq
+    }
 
     // TODO: Breaks lineage.
     sparkContext.union(
-      streamedPlusMatches.flatMap(_._1), sparkContext.makeRDD(rightOuterMatches))
+      matchesOrStreamedRowsWithNulls.flatMap(_._1), sparkContext.makeRDD(broadcastRowsWithNulls))
   }
 }
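
Taken together, the execute path above is: broadcast the build side, stream the other side through a nested loop, record which broadcast rows ever matched in a per-partition bitset, OR the bitsets together, and finally null-pad whatever never matched on either side. A self-contained sketch of the FullOuter/BuildRight path over plain Scala collections (no Spark; `Seq[Any]` stands in for `Row` and null-filled sequences for `GenericMutableRow`):

```scala
import scala.collection.mutable.{ArrayBuffer, BitSet}

type Row = Seq[Any]

/** Full outer nested-loop join of a streamed side against an in-memory
  * broadcast side, mirroring the FullOuter/BuildRight path above. */
def fullOuterNestedLoopJoin(
    streamed: Seq[Row],
    broadcast: IndexedSeq[Row],
    condition: (Row, Row) => Boolean,
    streamedWidth: Int,
    broadcastWidth: Int): Seq[Row] = {
  val matchedRows = new ArrayBuffer[Row]
  val includedBroadcastTuples = new BitSet(broadcast.size)
  val rightNulls: Row = Seq.fill(broadcastWidth)(null)

  for (streamedRow <- streamed) {
    var streamRowMatched = false
    var i = 0
    while (i < broadcast.size) {
      if (condition(streamedRow, broadcast(i))) {
        matchedRows += streamedRow ++ broadcast(i)
        streamRowMatched = true
        includedBroadcastTuples += i // remember this broadcast row matched
      }
      i += 1
    }
    // Unmatched streamed rows are padded with nulls on the broadcast side.
    if (!streamRowMatched) matchedRows += streamedRow ++ rightNulls
  }

  // Broadcast rows that never matched any streamed row, padded on the left.
  val leftNulls: Row = Seq.fill(streamedWidth)(null)
  val broadcastRowsWithNulls = broadcast.indices.collect {
    case i if !includedBroadcastTuples.contains(i) => leftNulls ++ broadcast(i)
  }

  matchedRows.toSeq ++ broadcastRowsWithNulls
}

// Usage: full outer join on equal first columns.
val out = fullOuterNestedLoopJoin(
  streamed = Seq(Seq(1, "a"), Seq(2, "b")),
  broadcast = IndexedSeq(Seq(2, "x"), Seq(3, "y")),
  condition = (s, b) => s.head == b.head,
  streamedWidth = 2,
  broadcastWidth = 2)
// out: Seq(1, "a", null, null), Seq(2, "b", 2, "x"), Seq(null, null, 3, "y")
```

In the real operator the per-row loop runs inside `mapPartitions`, so there is one bitset per stream partition rather than one global one, and the bitsets are merged with `reduce(_ ++ _)` before the unmatched broadcast rows are emitted.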