diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index bc36bacd00b13226065f3dbd12002123b7593b73..cb055cd74a5e5d638fd00485d203f2e83f4c2156 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -104,40 +104,29 @@ private[sql] case class InMemoryColumnarTableScan( override def execute() = { relation.cachedColumnBuffers.mapPartitions { iterator => // Find the ordinals of the requested columns. If none are requested, use the first. - val requestedColumns = - if (attributes.isEmpty) { - Seq(0) - } else { - attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId)) - } - - new Iterator[Row] { - private[this] var columnBuffers: Array[ByteBuffer] = null - private[this] var columnAccessors: Seq[ColumnAccessor] = null - nextBatch() - - private[this] val nextRow = new GenericMutableRow(columnAccessors.length) - - def nextBatch() = { - columnBuffers = iterator.next() - columnAccessors = requestedColumns.map(columnBuffers(_)).map(ColumnAccessor(_)) - } + val requestedColumns = if (attributes.isEmpty) { + Seq(0) + } else { + attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId)) + } - override def next() = { - if (!columnAccessors.head.hasNext) { - nextBatch() - } + iterator + .map(batch => requestedColumns.map(batch(_)).map(ColumnAccessor(_))) + .flatMap { columnAccessors => + val nextRow = new GenericMutableRow(columnAccessors.length) + new Iterator[Row] { + override def next() = { + var i = 0 + while (i < nextRow.length) { + columnAccessors(i).extractTo(nextRow, i) + i += 1 + } + nextRow + } - var i = 0 - while (i < nextRow.length) { - columnAccessors(i).extractTo(nextRow, i) - i += 1 + override def hasNext = columnAccessors.head.hasNext } - nextRow } - - override def hasNext = columnAccessors.head.hasNext || iterator.hasNext - } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index c3ec82fb6977839fd329bca167da0ec34d4b60c8..eb33a61c6e811ab5006e12675d98e4c5c371ad80 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -151,4 +151,9 @@ object TestData { TimestampField(new Timestamp(i)) }) timestamps.registerTempTable("timestamps") + + case class IntField(i: Int) + // An RDD with 4 elements and 8 partitions + val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8) + withEmptyParts.registerTempTable("withEmptyParts") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index fdd2799a53268cd08249ad92950bd40fa5b7dbb6..0e3c67f5eed2973595b5f5d0f8495750f94f2b93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.columnar -import org.apache.spark.sql.{QueryTest, TestData} import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.execution.SparkLogicalPlan import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.{SQLConf, QueryTest, TestData} class InMemoryColumnarQuerySuite extends QueryTest { - import TestData._ - import TestSQLContext._ + import org.apache.spark.sql.TestData._ + import org.apache.spark.sql.test.TestSQLContext._ test("simple columnar query") { val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan @@ -93,4 +92,16 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT time FROM timestamps"), timestamps.collect().toSeq) } + + test("SPARK-3320 regression: batched column buffer building should work with empty partitions") { + checkAnswer( + sql("SELECT * FROM withEmptyParts"), + withEmptyParts.collect().toSeq) + + TestSQLContext.cacheTable("withEmptyParts") + + checkAnswer( + sql("SELECT * FROM withEmptyParts"), + withEmptyParts.collect().toSeq) + } }