diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 910519d0e68144b1a38fbe1c97876ef6c7e05665..df0f73049921148575252b8cf1c27d6691c61b15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution
 
+import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
 import org.apache.spark.sql.{execution, Strategy}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -77,33 +78,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
    */
   object EquiJoinSelection extends Strategy with PredicateHelper {
 
-    private[this] def makeBroadcastHashJoin(
-        leftKeys: Seq[Expression],
-        rightKeys: Seq[Expression],
-        left: LogicalPlan,
-        right: LogicalPlan,
-        condition: Option[Expression],
-        side: joins.BuildSide): Seq[SparkPlan] = {
-      val broadcastHashJoin = execution.joins.BroadcastHashJoin(
-        leftKeys, rightKeys, side, planLater(left), planLater(right))
-      condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
-    }
-
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
 
       // --- Inner joins --------------------------------------------------------------------------
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
-        makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight)
+        joins.BroadcastHashJoin(
+          leftKeys, rightKeys, BuildRight, condition, planLater(left), planLater(right)) :: Nil
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
-        makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
+        joins.BroadcastHashJoin(
+          leftKeys, rightKeys, BuildLeft, condition, planLater(left), planLater(right)) :: Nil
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
         if RowOrdering.isOrderable(leftKeys) =>
-        val mergeJoin =
-          joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
-        condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
+        joins.SortMergeJoin(
+          leftKeys, rightKeys, condition, planLater(left), planLater(right)) :: Nil
 
       // --- Outer joins --------------------------------------------------------------------------
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 0a818cc2c2a279f7b3738e87139b9278f852a9ba..c9ea579b5e809b51607cbac03537fc896c3555ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -39,6 +39,7 @@ case class BroadcastHashJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
     buildSide: BuildSide,
+    condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan)
   extends BinaryNode with HashJoin {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 7f9d9daa5ab201ce3a367a65d474c597bea38d26..8ef854001f4de063afd4038649fcbd165061451c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.joins
 
+import java.util.NoSuchElementException
+
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
@@ -29,6 +31,7 @@ trait HashJoin {
   val leftKeys: Seq[Expression]
   val rightKeys: Seq[Expression]
   val buildSide: BuildSide
+  val condition: Option[Expression]
   val left: SparkPlan
   val right: SparkPlan
 
@@ -50,6 +53,12 @@
   protected def streamSideKeyGenerator: Projection =
     UnsafeProjection.create(streamedKeys, streamedPlan.output)
 
+  @transient private[this] lazy val boundCondition = if (condition.isDefined) {
+    newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+  } else {
+    (r: InternalRow) => true
+  }
+
   protected def hashJoin(
       streamIter: Iterator[InternalRow],
       numStreamRows: LongSQLMetric,
@@ -68,44 +77,52 @@ trait HashJoin {
 
       private[this] val joinKeys = streamSideKeyGenerator
 
-      override final def hasNext: Boolean =
-        (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
-          (streamIter.hasNext && fetchNext())
+      override final def hasNext: Boolean = {
+        while (true) {
+          // check if it's end of current matches
+          if (currentHashMatches != null && currentMatchPosition == currentHashMatches.length) {
+            currentHashMatches = null
+            currentMatchPosition = -1
+          }
 
-      override final def next(): InternalRow = {
-        val ret = buildSide match {
-          case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
-          case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
-        }
-        currentMatchPosition += 1
-        numOutputRows += 1
-        resultProjection(ret)
-      }
+          // find the next match
+          while (currentHashMatches == null && streamIter.hasNext) {
+            currentStreamedRow = streamIter.next()
+            numStreamRows += 1
+            val key = joinKeys(currentStreamedRow)
+            if (!key.anyNull) {
+              currentHashMatches = hashedRelation.get(key)
+              if (currentHashMatches != null) {
+                currentMatchPosition = 0
+              }
+            }
+          }
+          if (currentHashMatches == null) {
+            return false
+          }
 
-      /**
-       * Searches the streamed iterator for the next row that has at least one match in hashtable.
-       *
-       * @return true if the search is successful, and false if the streamed iterator runs out of
-       *         tuples.
-       */
-      private final def fetchNext(): Boolean = {
-        currentHashMatches = null
-        currentMatchPosition = -1
-
-        while (currentHashMatches == null && streamIter.hasNext) {
-          currentStreamedRow = streamIter.next()
-          numStreamRows += 1
-          val key = joinKeys(currentStreamedRow)
-          if (!key.anyNull) {
-            currentHashMatches = hashedRelation.get(key)
+          // found some matches
+          buildSide match {
+            case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
+            case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
+          }
+          if (boundCondition(joinRow)) {
+            return true
+          } else {
+            currentMatchPosition += 1
          }
        }
+        false // unreachable
+      }
 
-        if (currentHashMatches == null) {
-          false
+      override final def next(): InternalRow = {
+        // next() could be called without calling hasNext()
+        if (hasNext) {
+          currentMatchPosition += 1
+          numOutputRows += 1
+          resultProjection(joinRow)
         } else {
-          currentMatchPosition = 0
-          true
+          throw new NoSuchElementException
         }
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
index 6d464d6946b78f5919fc9342f19a175d990fa73e..9e614309de129e5d5203118685b67ffeff6d11c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
@@ -78,8 +78,11 @@
 
   @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
   @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
-  @transient private[this] lazy val boundCondition =
+  @transient private[this] lazy val boundCondition = if (condition.isDefined) {
     newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+  } else {
+    (row: InternalRow) => true
+  }
 
   // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
   // iterator for performance purpose.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 812f881d06fb8872fcf69803a329984411af4849..322a954b4f79299ab67bfda7d7b2d8e92b39078b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
 case class SortMergeJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
+    condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan) extends BinaryNode {
 
@@ -64,6 +65,13 @@
     val numOutputRows = longMetric("numOutputRows")
 
     left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
+      val boundCondition: (InternalRow) => Boolean = {
+        condition.map { cond =>
+          newPredicate(cond, left.output ++ right.output)
+        }.getOrElse {
+          (r: InternalRow) => true
+        }
+      }
       new RowIterator {
         // The projection used to extract keys from input rows of the left child.
         private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output)
@@ -89,26 +97,34 @@
         private[this] val resultProjection: (InternalRow) => InternalRow =
           UnsafeProjection.create(schema)
 
+        if (smjScanner.findNextInnerJoinRows()) {
+          currentRightMatches = smjScanner.getBufferedMatches
+          currentLeftRow = smjScanner.getStreamedRow
+          currentMatchIdx = 0
+        }
+
         override def advanceNext(): Boolean = {
-          if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) {
-            if (smjScanner.findNextInnerJoinRows()) {
-              currentRightMatches = smjScanner.getBufferedMatches
-              currentLeftRow = smjScanner.getStreamedRow
-              currentMatchIdx = 0
-            } else {
-              currentRightMatches = null
-              currentLeftRow = null
-              currentMatchIdx = -1
+          while (currentMatchIdx >= 0) {
+            if (currentMatchIdx == currentRightMatches.length) {
+              if (smjScanner.findNextInnerJoinRows()) {
+                currentRightMatches = smjScanner.getBufferedMatches
+                currentLeftRow = smjScanner.getStreamedRow
+                currentMatchIdx = 0
+              } else {
+                currentRightMatches = null
+                currentLeftRow = null
+                currentMatchIdx = -1
+                return false
+              }
             }
-          }
-          if (currentLeftRow != null) {
             joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
             currentMatchIdx += 1
-            numOutputRows += 1
-            true
-          } else {
-            false
+            if (boundCondition(joinRow)) {
+              numOutputRows += 1
+              return true
+            }
           }
+          false
         }
 
         override def getRow: InternalRow = resultProjection(joinRow)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 42fadaa8e2215237bb78136a71a2f7ce4bc582ca..ab81b702596af4372e8b2e03b54fdad2d7186e2c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.execution.joins
 
-import org.apache.spark.sql.{execution, DataFrame, Row, SQLConf}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.Inner
@@ -25,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Join
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+import org.apache.spark.sql.{DataFrame, Row, SQLConf}
 
 class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
   import testImplicits.localSeqToDataFrameHolder
@@ -88,9 +88,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
         leftPlan: SparkPlan,
         rightPlan: SparkPlan,
         side: BuildSide) = {
-      val broadcastHashJoin =
-        execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan)
-      boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
+      joins.BroadcastHashJoin(leftKeys, rightKeys, side, boundCondition, leftPlan, rightPlan)
     }
 
     def makeSortMergeJoin(
@@ -100,9 +98,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
         leftPlan: SparkPlan,
         rightPlan: SparkPlan) = {
       val sortMergeJoin =
-        execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan)
-      val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
-      EnsureRequirements(sqlContext).apply(filteredJoin)
+        joins.SortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan)
+      EnsureRequirements(sqlContext).apply(sortMergeJoin)
     }
 
     test(s"$testName using BroadcastHashJoin (build=left)") {
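
Note: the net effect of this patch is that an inner join's extra predicate is no longer planned as a separate Filter operator above the join. Instead, the optional condition is bound once against left.output ++ right.output and evaluated per candidate match inside the join loop, so pairs that fail it are skipped before being counted or projected. The sketch below illustrates that pattern on plain Scala collections; JoinConditionSketch, hashJoinWithCondition, Row, and the key/condition signatures are illustrative stand-ins, not Spark APIs.

// Sketch only: a minimal hash join whose probe loop applies an optional extra
// condition to each candidate pair before emitting it, instead of filtering
// the joined output afterwards. This mirrors the patch's approach of binding
// `condition` once and checking it per match in hasNext()/advanceNext().
object JoinConditionSketch {
  type Row = Seq[Any]

  def hashJoinWithCondition(
      build: Seq[Row],
      stream: Seq[Row],
      buildKey: Row => Any,
      streamKey: Row => Any,
      condition: Option[Row => Boolean]): Iterator[Row] = {
    // Bind the condition once, defaulting to "always true" when it is absent,
    // like `if (condition.isDefined) newPredicate(...) else (r => true)`.
    val boundCondition: Row => Boolean = condition.getOrElse((_: Row) => true)
    // Build the hash table once, keyed on the build side's join key.
    val hashed: Map[Any, Seq[Row]] = build.groupBy(buildKey)
    for {
      streamedRow <- stream.iterator
      buildRow <- hashed.getOrElse(streamKey(streamedRow), Seq.empty).iterator
      joined = streamedRow ++ buildRow
      if boundCondition(joined) // skip non-qualifying pairs before emitting
    } yield joined
  }

  def main(args: Array[String]): Unit = {
    val left = Seq(Seq(1, "a"), Seq(2, "b"))
    val right = Seq(Seq(1, 10), Seq(1, 20), Seq(2, 30))
    // Equi-join on column 0, with the extra condition "right column 1 > 15"
    // (column 3 of the joined row).
    hashJoinWithCondition(
      build = right,
      stream = left,
      buildKey = row => row(0),
      streamKey = row => row(0),
      condition = Some(row => row(3).asInstanceOf[Int] > 15)
    ).foreach(println) // prints List(1, a, 1, 20) then List(2, b, 2, 30)
  }
}

Evaluating the condition during the probe also matters for the iterator contract: a streamed row whose every match fails the condition must not make hasNext return true, which is why the patch advances currentMatchPosition inside hasNext (and has next() throw NoSuchElementException when nothing qualifies) rather than filtering rows after they are produced.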