From 32b18dd52cf8920903819f23e406271ecd8ac6bb Mon Sep 17 00:00:00 2001
From: Cheng Lian <lian.cs.zju@gmail.com>
Date: Fri, 29 Aug 2014 18:16:47 -0700
Subject: [PATCH] [SPARK-3320][SQL] Made batched in-memory column buffer
 building work for SchemaRDDs with empty partitions

Author: Cheng Lian <lian.cs.zju@gmail.com>

Closes #2213 from liancheng/spark-3320 and squashes the following commits:

45a0139 [Cheng Lian] Fixed typo in InMemoryColumnarQuerySuite
f67067d [Cheng Lian] Fixed SPARK-3320
---
 .../columnar/InMemoryColumnarTableScan.scala  | 49 +++++++------------
 .../scala/org/apache/spark/sql/TestData.scala |  5 ++
 .../columnar/InMemoryColumnarQuerySuite.scala | 19 +++++--
 3 files changed, 39 insertions(+), 34 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index bc36bacd00..cb055cd74a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -104,40 +104,29 @@ private[sql] case class InMemoryColumnarTableScan(
   override def execute() = {
     relation.cachedColumnBuffers.mapPartitions { iterator =>
       // Find the ordinals of the requested columns.  If none are requested, use the first.
-      val requestedColumns =
-        if (attributes.isEmpty) {
-          Seq(0)
-        } else {
-          attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId))
-        }
-
-      new Iterator[Row] {
-        private[this] var columnBuffers: Array[ByteBuffer] = null
-        private[this] var columnAccessors: Seq[ColumnAccessor] = null
-        nextBatch()
-
-        private[this] val nextRow = new GenericMutableRow(columnAccessors.length)
-
-        def nextBatch() = {
-          columnBuffers = iterator.next()
-          columnAccessors = requestedColumns.map(columnBuffers(_)).map(ColumnAccessor(_))
-        }
+      val requestedColumns = if (attributes.isEmpty) {
+        Seq(0)
+      } else {
+        attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId))
+      }
 
-        override def next() = {
-          if (!columnAccessors.head.hasNext) {
-            nextBatch()
-          }
+      iterator
+        .map(batch => requestedColumns.map(batch(_)).map(ColumnAccessor(_)))
+        .flatMap { columnAccessors =>
+          val nextRow = new GenericMutableRow(columnAccessors.length)
+          new Iterator[Row] {
+            override def next() = {
+              var i = 0
+              while (i < nextRow.length) {
+                columnAccessors(i).extractTo(nextRow, i)
+                i += 1
+              }
+              nextRow
+            }
 
-          var i = 0
-          while (i < nextRow.length) {
-            columnAccessors(i).extractTo(nextRow, i)
-            i += 1
+            override def hasNext = columnAccessors.head.hasNext
           }
-          nextRow
         }
-
-        override def hasNext = columnAccessors.head.hasNext || iterator.hasNext
-      }
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index c3ec82fb69..eb33a61c6e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -151,4 +151,9 @@ object TestData {
     TimestampField(new Timestamp(i))
   })
   timestamps.registerTempTable("timestamps")
+
+  case class IntField(i: Int)
+  // An RDD with 4 elements and 8 partitions
+  val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8)
+  withEmptyParts.registerTempTable("withEmptyParts")
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index fdd2799a53..0e3c67f5ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -17,14 +17,13 @@
 
 package org.apache.spark.sql.columnar
 
-import org.apache.spark.sql.{QueryTest, TestData}
 import org.apache.spark.sql.catalyst.expressions.Row
-import org.apache.spark.sql.execution.SparkLogicalPlan
 import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.{SQLConf, QueryTest, TestData}
 
 class InMemoryColumnarQuerySuite extends QueryTest {
-  import TestData._
-  import TestSQLContext._
+  import org.apache.spark.sql.TestData._
+  import org.apache.spark.sql.test.TestSQLContext._
 
   test("simple columnar query") {
     val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan
@@ -93,4 +92,16 @@ class InMemoryColumnarQuerySuite extends QueryTest {
       sql("SELECT time FROM timestamps"),
       timestamps.collect().toSeq)
   }
+
+  test("SPARK-3320 regression: batched column buffer building should work with empty partitions") {
+    checkAnswer(
+      sql("SELECT * FROM withEmptyParts"),
+      withEmptyParts.collect().toSeq)
+
+    TestSQLContext.cacheTable("withEmptyParts")
+
+    checkAnswer(
+      sql("SELECT * FROM withEmptyParts"),
+      withEmptyParts.collect().toSeq)
+  }
 }
-- 
GitLab