Skip to content
Snippets Groups Projects
Commit 323d51f1 authored by Davies Liu's avatar Davies Liu Committed by Davies Liu
Browse files

[SPARK-12700] [SQL] embed condition into SMJ and BroadcastHashJoin

Currently SortMergeJoin and BroadcastHashJoin do not support condition, the need a followed Filter for that, the result projection to generate UnsafeRow could be very expensive if they generate lots of rows and could be filtered mostly by condition.

This PR brings the support of condition for SortMergeJoin and BroadcastHashJoin, just like other outer joins do.

This could improve the performance of Q72 by 7x (from 120s to 16.5s).

Author: Davies Liu <davies@databricks.com>

Closes #10653 from davies/filter_join.
parent 39ac56fc
No related branches found
No related tags found
No related merge requests found
......@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
import org.apache.spark.sql.{execution, Strategy}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
......@@ -77,33 +78,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
*/
object EquiJoinSelection extends Strategy with PredicateHelper {
private[this] def makeBroadcastHashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
left: LogicalPlan,
right: LogicalPlan,
condition: Option[Expression],
side: joins.BuildSide): Seq[SparkPlan] = {
val broadcastHashJoin = execution.joins.BroadcastHashJoin(
leftKeys, rightKeys, side, planLater(left), planLater(right))
condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
}
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// --- Inner joins --------------------------------------------------------------------------
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight)
joins.BroadcastHashJoin(
leftKeys, rightKeys, BuildRight, condition, planLater(left), planLater(right)) :: Nil
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
joins.BroadcastHashJoin(
leftKeys, rightKeys, BuildLeft, condition, planLater(left), planLater(right)) :: Nil
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
if RowOrdering.isOrderable(leftKeys) =>
val mergeJoin =
joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
joins.SortMergeJoin(
leftKeys, rightKeys, condition, planLater(left), planLater(right)) :: Nil
// --- Outer joins --------------------------------------------------------------------------
......
......@@ -39,6 +39,7 @@ case class BroadcastHashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
buildSide: BuildSide,
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan)
extends BinaryNode with HashJoin {
......
......@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.joins
import java.util.NoSuchElementException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
......@@ -29,6 +31,7 @@ trait HashJoin {
val leftKeys: Seq[Expression]
val rightKeys: Seq[Expression]
val buildSide: BuildSide
val condition: Option[Expression]
val left: SparkPlan
val right: SparkPlan
......@@ -50,6 +53,12 @@ trait HashJoin {
protected def streamSideKeyGenerator: Projection =
UnsafeProjection.create(streamedKeys, streamedPlan.output)
@transient private[this] lazy val boundCondition = if (condition.isDefined) {
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
} else {
(r: InternalRow) => true
}
protected def hashJoin(
streamIter: Iterator[InternalRow],
numStreamRows: LongSQLMetric,
......@@ -68,44 +77,52 @@ trait HashJoin {
private[this] val joinKeys = streamSideKeyGenerator
override final def hasNext: Boolean =
(currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
(streamIter.hasNext && fetchNext())
override final def hasNext: Boolean = {
while (true) {
// check if it's end of current matches
if (currentHashMatches != null && currentMatchPosition == currentHashMatches.length) {
currentHashMatches = null
currentMatchPosition = -1
}
override final def next(): InternalRow = {
val ret = buildSide match {
case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
}
currentMatchPosition += 1
numOutputRows += 1
resultProjection(ret)
}
// find the next match
while (currentHashMatches == null && streamIter.hasNext) {
currentStreamedRow = streamIter.next()
numStreamRows += 1
val key = joinKeys(currentStreamedRow)
if (!key.anyNull) {
currentHashMatches = hashedRelation.get(key)
if (currentHashMatches != null) {
currentMatchPosition = 0
}
}
}
if (currentHashMatches == null) {
return false
}
/**
* Searches the streamed iterator for the next row that has at least one match in hashtable.
*
* @return true if the search is successful, and false if the streamed iterator runs out of
* tuples.
*/
private final def fetchNext(): Boolean = {
currentHashMatches = null
currentMatchPosition = -1
while (currentHashMatches == null && streamIter.hasNext) {
currentStreamedRow = streamIter.next()
numStreamRows += 1
val key = joinKeys(currentStreamedRow)
if (!key.anyNull) {
currentHashMatches = hashedRelation.get(key)
// found some matches
buildSide match {
case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
}
if (boundCondition(joinRow)) {
return true
} else {
currentMatchPosition += 1
}
}
false // unreachable
}
if (currentHashMatches == null) {
false
override final def next(): InternalRow = {
// next() could be called without calling hasNext()
if (hasNext) {
currentMatchPosition += 1
numOutputRows += 1
resultProjection(joinRow)
} else {
currentMatchPosition = 0
true
throw new NoSuchElementException
}
}
}
......
......@@ -78,8 +78,11 @@ trait HashOuterJoin {
@transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
@transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
@transient private[this] lazy val boundCondition =
@transient private[this] lazy val boundCondition = if (condition.isDefined) {
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
} else {
(row: InternalRow) => true
}
// TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
// iterator for performance purpose.
......
......@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
case class SortMergeJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryNode {
......@@ -64,6 +65,13 @@ case class SortMergeJoin(
val numOutputRows = longMetric("numOutputRows")
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
val boundCondition: (InternalRow) => Boolean = {
condition.map { cond =>
newPredicate(cond, left.output ++ right.output)
}.getOrElse {
(r: InternalRow) => true
}
}
new RowIterator {
// The projection used to extract keys from input rows of the left child.
private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output)
......@@ -89,26 +97,34 @@ case class SortMergeJoin(
private[this] val resultProjection: (InternalRow) => InternalRow =
UnsafeProjection.create(schema)
if (smjScanner.findNextInnerJoinRows()) {
currentRightMatches = smjScanner.getBufferedMatches
currentLeftRow = smjScanner.getStreamedRow
currentMatchIdx = 0
}
override def advanceNext(): Boolean = {
if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) {
if (smjScanner.findNextInnerJoinRows()) {
currentRightMatches = smjScanner.getBufferedMatches
currentLeftRow = smjScanner.getStreamedRow
currentMatchIdx = 0
} else {
currentRightMatches = null
currentLeftRow = null
currentMatchIdx = -1
while (currentMatchIdx >= 0) {
if (currentMatchIdx == currentRightMatches.length) {
if (smjScanner.findNextInnerJoinRows()) {
currentRightMatches = smjScanner.getBufferedMatches
currentLeftRow = smjScanner.getStreamedRow
currentMatchIdx = 0
} else {
currentRightMatches = null
currentLeftRow = null
currentMatchIdx = -1
return false
}
}
}
if (currentLeftRow != null) {
joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
currentMatchIdx += 1
numOutputRows += 1
true
} else {
false
if (boundCondition(joinRow)) {
numOutputRows += 1
return true
}
}
false
}
override def getRow: InternalRow = resultProjection(joinRow)
......
......@@ -17,7 +17,6 @@
package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.{execution, DataFrame, Row, SQLConf}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
......@@ -25,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.execution._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.sql.{DataFrame, Row, SQLConf}
class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
import testImplicits.localSeqToDataFrameHolder
......@@ -88,9 +88,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
leftPlan: SparkPlan,
rightPlan: SparkPlan,
side: BuildSide) = {
val broadcastHashJoin =
execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan)
boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
joins.BroadcastHashJoin(leftKeys, rightKeys, side, boundCondition, leftPlan, rightPlan)
}
def makeSortMergeJoin(
......@@ -100,9 +98,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
leftPlan: SparkPlan,
rightPlan: SparkPlan) = {
val sortMergeJoin =
execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan)
val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
EnsureRequirements(sqlContext).apply(filteredJoin)
joins.SortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan)
EnsureRequirements(sqlContext).apply(sortMergeJoin)
}
test(s"$testName using BroadcastHashJoin (build=left)") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment