diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 910519d0e68144b1a38fbe1c97876ef6c7e05665..df0f73049921148575252b8cf1c27d6691c61b15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution
 
+import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
 import org.apache.spark.sql.{execution, Strategy}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -77,33 +78,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
    */
   object EquiJoinSelection extends Strategy with PredicateHelper {
 
-    private[this] def makeBroadcastHashJoin(
-        leftKeys: Seq[Expression],
-        rightKeys: Seq[Expression],
-        left: LogicalPlan,
-        right: LogicalPlan,
-        condition: Option[Expression],
-        side: joins.BuildSide): Seq[SparkPlan] = {
-      val broadcastHashJoin = execution.joins.BroadcastHashJoin(
-        leftKeys, rightKeys, side, planLater(left), planLater(right))
-      condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
-    }
-
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
 
       // --- Inner joins --------------------------------------------------------------------------
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
-        makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight)
+        joins.BroadcastHashJoin(
+          leftKeys, rightKeys, BuildRight, condition, planLater(left), planLater(right)) :: Nil
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
-        makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
+        joins.BroadcastHashJoin(
+          leftKeys, rightKeys, BuildLeft, condition, planLater(left), planLater(right)) :: Nil
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
         if RowOrdering.isOrderable(leftKeys) =>
-        val mergeJoin =
-          joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
-        condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
+        joins.SortMergeJoin(
+          leftKeys, rightKeys, condition, planLater(left), planLater(right)) :: Nil
 
       // --- Outer joins --------------------------------------------------------------------------
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 0a818cc2c2a279f7b3738e87139b9278f852a9ba..c9ea579b5e809b51607cbac03537fc896c3555ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -39,6 +39,7 @@ case class BroadcastHashJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
     buildSide: BuildSide,
+    condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan)
   extends BinaryNode with HashJoin {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 7f9d9daa5ab201ce3a367a65d474c597bea38d26..8ef854001f4de063afd4038649fcbd165061451c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.joins
 
+import java.util.NoSuchElementException
+
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
@@ -29,6 +31,7 @@ trait HashJoin {
   val leftKeys: Seq[Expression]
   val rightKeys: Seq[Expression]
   val buildSide: BuildSide
+  val condition: Option[Expression]
   val left: SparkPlan
   val right: SparkPlan
 
@@ -50,6 +53,12 @@ trait HashJoin {
   protected def streamSideKeyGenerator: Projection =
     UnsafeProjection.create(streamedKeys, streamedPlan.output)
 
+  @transient private[this] lazy val boundCondition = if (condition.isDefined) {
+    newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+  } else {
+    (r: InternalRow) => true
+  }
+
   protected def hashJoin(
       streamIter: Iterator[InternalRow],
       numStreamRows: LongSQLMetric,
@@ -68,44 +77,52 @@ trait HashJoin {
 
       private[this] val joinKeys = streamSideKeyGenerator
 
-      override final def hasNext: Boolean =
-        (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
-          (streamIter.hasNext && fetchNext())
+      override final def hasNext: Boolean = {
+        while (true) {
+          // check if it's end of current matches
+          if (currentHashMatches != null && currentMatchPosition == currentHashMatches.length) {
+            currentHashMatches = null
+            currentMatchPosition = -1
+          }
 
-      override final def next(): InternalRow = {
-        val ret = buildSide match {
-          case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
-          case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
-        }
-        currentMatchPosition += 1
-        numOutputRows += 1
-        resultProjection(ret)
-      }
+          // find the next match
+          while (currentHashMatches == null && streamIter.hasNext) {
+            currentStreamedRow = streamIter.next()
+            numStreamRows += 1
+            val key = joinKeys(currentStreamedRow)
+            if (!key.anyNull) {
+              currentHashMatches = hashedRelation.get(key)
+              if (currentHashMatches != null) {
+                currentMatchPosition = 0
+              }
+            }
+          }
+          if (currentHashMatches == null) {
+            return false
+          }
 
-      /**
-       * Searches the streamed iterator for the next row that has at least one match in hashtable.
-       *
-       * @return true if the search is successful, and false if the streamed iterator runs out of
-       *         tuples.
-       */
-      private final def fetchNext(): Boolean = {
-        currentHashMatches = null
-        currentMatchPosition = -1
-
-        while (currentHashMatches == null && streamIter.hasNext) {
-          currentStreamedRow = streamIter.next()
-          numStreamRows += 1
-          val key = joinKeys(currentStreamedRow)
-          if (!key.anyNull) {
-            currentHashMatches = hashedRelation.get(key)
+          // found some matches
+          buildSide match {
+            case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
+            case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
+          }
+          if (boundCondition(joinRow)) {
+            return true
+          } else {
+            currentMatchPosition += 1
           }
         }
+        false  // unreachable
+      }
 
-        if (currentHashMatches == null) {
-          false
+      override final def next(): InternalRow = {
+        // next() could be called without calling hasNext()
+        if (hasNext) {
+          currentMatchPosition += 1
+          numOutputRows += 1
+          resultProjection(joinRow)
         } else {
-          currentMatchPosition = 0
-          true
+          throw new NoSuchElementException
         }
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
index 6d464d6946b78f5919fc9342f19a175d990fa73e..9e614309de129e5d5203118685b67ffeff6d11c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
@@ -78,8 +78,11 @@ trait HashOuterJoin {
 
   @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
   @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
-  @transient private[this] lazy val boundCondition =
+  @transient private[this] lazy val boundCondition = if (condition.isDefined) {
     newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+  } else {
+    (row: InternalRow) => true
+  }
 
   // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
   // iterator for performance purpose.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 812f881d06fb8872fcf69803a329984411af4849..322a954b4f79299ab67bfda7d7b2d8e92b39078b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
 case class SortMergeJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
+    condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan) extends BinaryNode {
 
@@ -64,6 +65,13 @@ case class SortMergeJoin(
     val numOutputRows = longMetric("numOutputRows")
 
     left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
+      val boundCondition: (InternalRow) => Boolean = {
+        condition.map { cond =>
+          newPredicate(cond, left.output ++ right.output)
+        }.getOrElse {
+          (r: InternalRow) => true
+        }
+      }
       new RowIterator {
         // The projection used to extract keys from input rows of the left child.
         private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output)
@@ -89,26 +97,34 @@ case class SortMergeJoin(
         private[this] val resultProjection: (InternalRow) => InternalRow =
           UnsafeProjection.create(schema)
 
+        if (smjScanner.findNextInnerJoinRows()) {
+          currentRightMatches = smjScanner.getBufferedMatches
+          currentLeftRow = smjScanner.getStreamedRow
+          currentMatchIdx = 0
+        }
+
         override def advanceNext(): Boolean = {
-          if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) {
-            if (smjScanner.findNextInnerJoinRows()) {
-              currentRightMatches = smjScanner.getBufferedMatches
-              currentLeftRow = smjScanner.getStreamedRow
-              currentMatchIdx = 0
-            } else {
-              currentRightMatches = null
-              currentLeftRow = null
-              currentMatchIdx = -1
+          while (currentMatchIdx >= 0) {
+            if (currentMatchIdx == currentRightMatches.length) {
+              if (smjScanner.findNextInnerJoinRows()) {
+                currentRightMatches = smjScanner.getBufferedMatches
+                currentLeftRow = smjScanner.getStreamedRow
+                currentMatchIdx = 0
+              } else {
+                currentRightMatches = null
+                currentLeftRow = null
+                currentMatchIdx = -1
+                return false
+              }
             }
-          }
-          if (currentLeftRow != null) {
             joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
             currentMatchIdx += 1
-            numOutputRows += 1
-            true
-          } else {
-            false
+            if (boundCondition(joinRow)) {
+              numOutputRows += 1
+              return true
+            }
           }
+          false
         }
 
         override def getRow: InternalRow = resultProjection(joinRow)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 42fadaa8e2215237bb78136a71a2f7ce4bc582ca..ab81b702596af4372e8b2e03b54fdad2d7186e2c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.execution.joins
 
-import org.apache.spark.sql.{execution, DataFrame, Row, SQLConf}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.Inner
@@ -25,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Join
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+import org.apache.spark.sql.{DataFrame, Row, SQLConf}
 
 class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
   import testImplicits.localSeqToDataFrameHolder
@@ -88,9 +88,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
         leftPlan: SparkPlan,
         rightPlan: SparkPlan,
         side: BuildSide) = {
-      val broadcastHashJoin =
-        execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan)
-      boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
+      joins.BroadcastHashJoin(leftKeys, rightKeys, side, boundCondition, leftPlan, rightPlan)
     }
 
     def makeSortMergeJoin(
@@ -100,9 +98,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
         leftPlan: SparkPlan,
         rightPlan: SparkPlan) = {
       val sortMergeJoin =
-        execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan)
-      val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
-      EnsureRequirements(sqlContext).apply(filteredJoin)
+        joins.SortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan)
+      EnsureRequirements(sqlContext).apply(sortMergeJoin)
     }
 
     test(s"$testName using BroadcastHashJoin (build=left)") {