Skip to content
Snippets Groups Projects
Commit 4bf3de71 authored by Matei Zaharia, committed by Michael Armbrust
Browse files

[SPARK-3085] [SQL] Use compact data structures in SQL joins

This reuses the CompactBuffer from Spark Core to save memory and pointer
dereferences. I also tried AppendOnlyMap instead of java.util.HashMap,
but unfortunately that slows things down: it seems to make more
equals() calls, and equals() on GenericRow, and especially on JoinedRow,
is pretty expensive.

Author: Matei Zaharia <matei@databricks.com>

Closes #1993 from mateiz/spark-3085 and squashes the following commits:

188221e [Matei Zaharia] Remove unneeded import
5f903ee [Matei Zaharia] [SPARK-3085] [SQL] Use compact data structures in SQL joins
parent 6a13dca1
No related branches found
No related tags found
No related merge requests found
...@@ -19,16 +19,15 @@ package org.apache.spark.sql.execution ...@@ -19,16 +19,15 @@ package org.apache.spark.sql.execution
import java.util.{HashMap => JavaHashMap} import java.util.{HashMap => JavaHashMap}
import scala.collection.mutable.{ArrayBuffer, BitSet}
import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent._ import scala.concurrent._
import scala.concurrent.duration._ import scala.concurrent.duration._
import org.apache.spark.annotation.DeveloperApi import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.util.collection.CompactBuffer
@DeveloperApi @DeveloperApi
sealed abstract class BuildSide sealed abstract class BuildSide
...@@ -67,7 +66,7 @@ trait HashJoin { ...@@ -67,7 +66,7 @@ trait HashJoin {
def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = { def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = {
// TODO: Use Spark's HashMap implementation. // TODO: Use Spark's HashMap implementation.
val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]() val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]()
var currentRow: Row = null var currentRow: Row = null
// Create a mapping of buildKeys -> rows // Create a mapping of buildKeys -> rows
...@@ -77,7 +76,7 @@ trait HashJoin { ...@@ -77,7 +76,7 @@ trait HashJoin {
if (!rowKey.anyNull) { if (!rowKey.anyNull) {
val existingMatchList = hashTable.get(rowKey) val existingMatchList = hashTable.get(rowKey)
val matchList = if (existingMatchList == null) { val matchList = if (existingMatchList == null) {
val newMatchList = new ArrayBuffer[Row]() val newMatchList = new CompactBuffer[Row]()
hashTable.put(rowKey, newMatchList) hashTable.put(rowKey, newMatchList)
newMatchList newMatchList
} else { } else {
...@@ -89,7 +88,7 @@ trait HashJoin { ...@@ -89,7 +88,7 @@ trait HashJoin {
new Iterator[Row] { new Iterator[Row] {
private[this] var currentStreamedRow: Row = _ private[this] var currentStreamedRow: Row = _
private[this] var currentHashMatches: ArrayBuffer[Row] = _ private[this] var currentHashMatches: CompactBuffer[Row] = _
private[this] var currentMatchPosition: Int = -1 private[this] var currentMatchPosition: Int = -1
// Mutable per row objects. // Mutable per row objects.
...@@ -140,7 +139,7 @@ trait HashJoin { ...@@ -140,7 +139,7 @@ trait HashJoin {
/** /**
* :: DeveloperApi :: * :: DeveloperApi ::
* Performs a hash based outer join for two child relations by shuffling the data using * Performs a hash based outer join for two child relations by shuffling the data using
* the join keys. This operator requires loading the associated partition in both side into memory. * the join keys. This operator requires loading the associated partition in both side into memory.
*/ */
@DeveloperApi @DeveloperApi
...@@ -179,26 +178,26 @@ case class HashOuterJoin( ...@@ -179,26 +178,26 @@ case class HashOuterJoin(
@transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row]
// TODO we need to rewrite all of the iterators with our own implementation instead of the Scala // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
// iterator for performance purpose. // iterator for performance purpose.
private[this] def leftOuterIterator( private[this] def leftOuterIterator(
key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
val joinedRow = new JoinedRow() val joinedRow = new JoinedRow()
val rightNullRow = new GenericRow(right.output.length) val rightNullRow = new GenericRow(right.output.length)
val boundCondition = val boundCondition =
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
leftIter.iterator.flatMap { l => leftIter.iterator.flatMap { l =>
joinedRow.withLeft(l) joinedRow.withLeft(l)
var matched = false var matched = false
(if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) =>
matched = true matched = true
joinedRow.copy joinedRow.copy
} else { } else {
Nil Nil
}) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => {
// DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row,
// as we don't know whether we need to append it until finish iterating all of the // as we don't know whether we need to append it until finish iterating all of the
// records in right side. // records in right side.
// If we didn't get any proper row, then append a single row with empty right // If we didn't get any proper row, then append a single row with empty right
joinedRow.withRight(rightNullRow).copy joinedRow.withRight(rightNullRow).copy
...@@ -210,20 +209,20 @@ case class HashOuterJoin( ...@@ -210,20 +209,20 @@ case class HashOuterJoin(
key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
val joinedRow = new JoinedRow() val joinedRow = new JoinedRow()
val leftNullRow = new GenericRow(left.output.length) val leftNullRow = new GenericRow(left.output.length)
val boundCondition = val boundCondition =
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
rightIter.iterator.flatMap { r => rightIter.iterator.flatMap { r =>
joinedRow.withRight(r) joinedRow.withRight(r)
var matched = false var matched = false
(if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) =>
matched = true matched = true
joinedRow.copy joinedRow.copy
} else { } else {
Nil Nil
}) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => {
// DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row,
// as we don't know whether we need to append it until finish iterating all of the // as we don't know whether we need to append it until finish iterating all of the
// records in left side. // records in left side.
// If we didn't get any proper row, then append a single row with empty left. // If we didn't get any proper row, then append a single row with empty left.
joinedRow.withLeft(leftNullRow).copy joinedRow.withLeft(leftNullRow).copy
...@@ -236,7 +235,7 @@ case class HashOuterJoin( ...@@ -236,7 +235,7 @@ case class HashOuterJoin(
val joinedRow = new JoinedRow() val joinedRow = new JoinedRow()
val leftNullRow = new GenericRow(left.output.length) val leftNullRow = new GenericRow(left.output.length)
val rightNullRow = new GenericRow(right.output.length) val rightNullRow = new GenericRow(right.output.length)
val boundCondition = val boundCondition =
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
if (!key.anyNull) { if (!key.anyNull) {
...@@ -246,8 +245,8 @@ case class HashOuterJoin( ...@@ -246,8 +245,8 @@ case class HashOuterJoin(
leftIter.iterator.flatMap[Row] { l => leftIter.iterator.flatMap[Row] { l =>
joinedRow.withLeft(l) joinedRow.withLeft(l)
var matched = false var matched = false
rightIter.zipWithIndex.collect { rightIter.zipWithIndex.collect {
// 1. For those matched (satisfy the join condition) records with both sides filled, // 1. For those matched (satisfy the join condition) records with both sides filled,
// append them directly // append them directly
case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> {
...@@ -260,7 +259,7 @@ case class HashOuterJoin( ...@@ -260,7 +259,7 @@ case class HashOuterJoin(
// 2. For those unmatched records in left, append additional records with empty right. // 2. For those unmatched records in left, append additional records with empty right.
// DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row,
// as we don't know whether we need to append it until finish iterating all // as we don't know whether we need to append it until finish iterating all
// of the records in right side. // of the records in right side.
// If we didn't get any proper row, then append a single row with empty right. // If we didn't get any proper row, then append a single row with empty right.
joinedRow.withRight(rightNullRow).copy joinedRow.withRight(rightNullRow).copy
...@@ -268,8 +267,8 @@ case class HashOuterJoin( ...@@ -268,8 +267,8 @@ case class HashOuterJoin(
} ++ rightIter.zipWithIndex.collect { } ++ rightIter.zipWithIndex.collect {
// 3. For those unmatched records in right, append additional records with empty left. // 3. For those unmatched records in right, append additional records with empty left.
// Re-visiting the records in right, and append additional row with empty left, if its not // Re-visiting the records in right, and append additional row with empty left, if its not
// in the matched set. // in the matched set.
case (r, idx) if (!rightMatchedSet.contains(idx)) => { case (r, idx) if (!rightMatchedSet.contains(idx)) => {
joinedRow(leftNullRow, r).copy joinedRow(leftNullRow, r).copy
} }
...@@ -284,15 +283,15 @@ case class HashOuterJoin( ...@@ -284,15 +283,15 @@ case class HashOuterJoin(
} }
private[this] def buildHashTable( private[this] def buildHashTable(
iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, ArrayBuffer[Row]] = { iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = {
val hashTable = new JavaHashMap[Row, ArrayBuffer[Row]]() val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]()
while (iter.hasNext) { while (iter.hasNext) {
val currentRow = iter.next() val currentRow = iter.next()
val rowKey = keyGenerator(currentRow) val rowKey = keyGenerator(currentRow)
var existingMatchList = hashTable.get(rowKey) var existingMatchList = hashTable.get(rowKey)
if (existingMatchList == null) { if (existingMatchList == null) {
existingMatchList = new ArrayBuffer[Row]() existingMatchList = new CompactBuffer[Row]()
hashTable.put(rowKey, existingMatchList) hashTable.put(rowKey, existingMatchList)
} }
...@@ -311,20 +310,20 @@ case class HashOuterJoin( ...@@ -311,20 +310,20 @@ case class HashOuterJoin(
val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
import scala.collection.JavaConversions._ import scala.collection.JavaConversions._
val boundCondition = val boundCondition =
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
joinType match { joinType match {
case LeftOuter => leftHashTable.keysIterator.flatMap { key => case LeftOuter => leftHashTable.keysIterator.flatMap { key =>
leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST),
rightHashTable.getOrElse(key, EMPTY_LIST)) rightHashTable.getOrElse(key, EMPTY_LIST))
} }
case RightOuter => rightHashTable.keysIterator.flatMap { key => case RightOuter => rightHashTable.keysIterator.flatMap { key =>
rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST),
rightHashTable.getOrElse(key, EMPTY_LIST)) rightHashTable.getOrElse(key, EMPTY_LIST))
} }
case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
fullOuterIterator(key, fullOuterIterator(key,
leftHashTable.getOrElse(key, EMPTY_LIST), leftHashTable.getOrElse(key, EMPTY_LIST),
rightHashTable.getOrElse(key, EMPTY_LIST)) rightHashTable.getOrElse(key, EMPTY_LIST))
} }
case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType")
...@@ -550,7 +549,7 @@ case class BroadcastNestedLoopJoin( ...@@ -550,7 +549,7 @@ case class BroadcastNestedLoopJoin(
/** All rows that either match both-way, or rows from streamed joined with nulls. */ /** All rows that either match both-way, or rows from streamed joined with nulls. */
val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter =>
val matchedRows = new ArrayBuffer[Row] val matchedRows = new CompactBuffer[Row]
// TODO: Use Spark's BitSet. // TODO: Use Spark's BitSet.
val includedBroadcastTuples = val includedBroadcastTuples =
new scala.collection.mutable.BitSet(broadcastedRelation.value.size) new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
...@@ -602,20 +601,20 @@ case class BroadcastNestedLoopJoin( ...@@ -602,20 +601,20 @@ case class BroadcastNestedLoopJoin(
val rightNulls = new GenericMutableRow(right.output.size) val rightNulls = new GenericMutableRow(right.output.size)
/** Rows from broadcasted joined with nulls. */ /** Rows from broadcasted joined with nulls. */
val broadcastRowsWithNulls: Seq[Row] = { val broadcastRowsWithNulls: Seq[Row] = {
val arrBuf: collection.mutable.ArrayBuffer[Row] = collection.mutable.ArrayBuffer() val buf: CompactBuffer[Row] = new CompactBuffer()
var i = 0 var i = 0
val rel = broadcastedRelation.value val rel = broadcastedRelation.value
while (i < rel.length) { while (i < rel.length) {
if (!allIncludedBroadcastTuples.contains(i)) { if (!allIncludedBroadcastTuples.contains(i)) {
(joinType, buildSide) match { (joinType, buildSide) match {
case (RightOuter | FullOuter, BuildRight) => arrBuf += new JoinedRow(leftNulls, rel(i)) case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i))
case (LeftOuter | FullOuter, BuildLeft) => arrBuf += new JoinedRow(rel(i), rightNulls) case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls)
case _ => case _ =>
} }
} }
i += 1 i += 1
} }
arrBuf.toSeq buf.toSeq
} }
// TODO: Breaks lineage. // TODO: Breaks lineage.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment