diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index bc36bacd00b13226065f3dbd12002123b7593b73..cb055cd74a5e5d638fd00485d203f2e83f4c2156 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -104,40 +104,29 @@ private[sql] case class InMemoryColumnarTableScan(
   override def execute() = {
     relation.cachedColumnBuffers.mapPartitions { iterator =>
       // Find the ordinals of the requested columns.  If none are requested, use the first.
-      val requestedColumns =
-        if (attributes.isEmpty) {
-          Seq(0)
-        } else {
-          attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId))
-        }
-
-      new Iterator[Row] {
-        private[this] var columnBuffers: Array[ByteBuffer] = null
-        private[this] var columnAccessors: Seq[ColumnAccessor] = null
-        nextBatch()
-
-        private[this] val nextRow = new GenericMutableRow(columnAccessors.length)
-
-        def nextBatch() = {
-          columnBuffers = iterator.next()
-          columnAccessors = requestedColumns.map(columnBuffers(_)).map(ColumnAccessor(_))
-        }
+      val requestedColumns = if (attributes.isEmpty) {
+        Seq(0)
+      } else {
+        attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId))
+      }
 
-        override def next() = {
-          if (!columnAccessors.head.hasNext) {
-            nextBatch()
-          }
+      iterator
+        .map(batch => requestedColumns.map(batch(_)).map(ColumnAccessor(_)))
+        .flatMap { columnAccessors =>
+          val nextRow = new GenericMutableRow(columnAccessors.length)
+          new Iterator[Row] {
+            override def next() = {
+              var i = 0
+              while (i < nextRow.length) {
+                columnAccessors(i).extractTo(nextRow, i)
+                i += 1
+              }
+              nextRow
+            }
 
-          var i = 0
-          while (i < nextRow.length) {
-            columnAccessors(i).extractTo(nextRow, i)
-            i += 1
+            override def hasNext = columnAccessors.head.hasNext
           }
-          nextRow
         }
-
-        override def hasNext = columnAccessors.head.hasNext || iterator.hasNext
-      }
     }
   }
 }
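
Why this fixes SPARK-3320: the old hand-rolled Iterator[Row] invoked nextBatch(), and therefore iterator.next(), from its constructor, before hasNext was ever consulted. A partition holding zero column batches thus threw a NoSuchElementException as soon as the scan iterator was built. The rewritten version goes through map/flatMap, which only touches a batch if one actually exists, so an empty partition simply contributes zero rows. A minimal standalone sketch of the two shapes, with plain Scala iterators standing in for the real column batches (illustrative only, not part of the patch):

    object EmptyPartitionSketch extends App {
      // Hypothetical stand-in for one partition's stream of column batches.
      val emptyPartition: Iterator[Seq[Int]] = Iterator.empty

      // Old shape: the first batch was fetched eagerly while constructing the
      // row iterator, so an empty partition threw before any row was asked for:
      //   val firstBatch = emptyPartition.next()  // NoSuchElementException

      // New shape: flatMap evaluates its body only when a batch exists, so an
      // empty partition yields an empty row iterator instead of crashing.
      val rows = emptyPartition.flatMap(batch => batch.iterator)
      assert(!rows.hasNext)
      println("empty partition handled without error")
    }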
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index c3ec82fb6977839fd329bca167da0ec34d4b60c8..eb33a61c6e811ab5006e12675d98e4c5c371ad80 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -151,4 +151,9 @@ object TestData {
     TimestampField(new Timestamp(i))
   })
   timestamps.registerTempTable("timestamps")
+
+  case class IntField(i: Int)
+  // An RDD with 4 elements spread over 8 partitions, so at least 4 of them are empty
+  val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8)
+  withEmptyParts.registerTempTable("withEmptyParts")
 }
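
Note on the fixture: parallelize never pads its input, so requesting more slices than there are elements is the simplest way to manufacture empty partitions. A quick sanity check, assuming a live SparkContext named sc (a sketch, not part of the patch):

    // 4 elements over 8 slices leaves at least 4 partitions empty.
    val sizes = sc.parallelize(1 to 4, 8).glom().map(_.length).collect()
    assert(sizes.length == 8 && sizes.count(_ == 0) >= 4)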
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index fdd2799a53268cd08249ad92950bd40fa5b7dbb6..0e3c67f5eed2973595b5f5d0f8495750f94f2b93 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -17,14 +17,13 @@
 
 package org.apache.spark.sql.columnar
 
-import org.apache.spark.sql.{QueryTest, TestData}
 import org.apache.spark.sql.catalyst.expressions.Row
-import org.apache.spark.sql.execution.SparkLogicalPlan
 import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.{SQLConf, QueryTest, TestData}
 
 class InMemoryColumnarQuerySuite extends QueryTest {
-  import TestData._
-  import TestSQLContext._
+  import org.apache.spark.sql.TestData._
+  import org.apache.spark.sql.test.TestSQLContext._
 
   test("simple columnar query") {
     val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan
@@ -93,4 +92,16 @@ class InMemoryColumnarQuerySuite extends QueryTest {
       sql("SELECT time FROM timestamps"),
       timestamps.collect().toSeq)
   }
+
+  test("SPARK-3320 regression: batched column buffer building should work with empty partitions") {
+    checkAnswer(
+      sql("SELECT * FROM withEmptyParts"),
+      withEmptyParts.collect().toSeq)
+
+    TestSQLContext.cacheTable("withEmptyParts")
+
+    checkAnswer(
+      sql("SELECT * FROM withEmptyParts"),
+      withEmptyParts.collect().toSeq)
+  }
 }
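
The test scans the table twice: the first checkAnswer covers the ordinary uncached path, cacheTable then builds the batched in-memory column buffers, and the second checkAnswer forces the columnar scan that used to fail. The pre-patch failure can be replayed by hand against the Spark test harness of this era (a sketch, not part of the patch):

    import org.apache.spark.sql.test.TestSQLContext
    import org.apache.spark.sql.test.TestSQLContext._

    case class IntField(i: Int)
    val rdd = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8)
    rdd.registerTempTable("withEmptyParts")
    TestSQLContext.cacheTable("withEmptyParts")
    // Without the InMemoryColumnarTableScan change above, this scan dies in
    // the eager iterator.next() call (NoSuchElementException) as soon as it
    // hits an empty partition.
    sql("SELECT * FROM withEmptyParts").collect()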