diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index 04fba17be4bfab4fa76811e8673c921e82ba3820..e86116680a57a1e51617edd2987007b222e04a4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -111,17 +111,27 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
     val columnsBatchInput = (output zip colVars).map { case (attr, colVar) =>
       genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable)
     }
+    val localIdx = ctx.freshName("localIdx")
+    val localEnd = ctx.freshName("localEnd")
+    val numRows = ctx.freshName("numRows")
+    val shouldStop = if (isShouldStopRequired) {
+      s"if (shouldStop()) { $idx = $rowidx + 1; return; }"
+    } else {
+      "// shouldStop check is eliminated"
+    }
     s"""
        |if ($batch == null) {
        |  $nextBatch();
        |}
        |while ($batch != null) {
-       |  int numRows = $batch.numRows();
-       |  while ($idx < numRows) {
-       |    int $rowidx = $idx++;
+       |  int $numRows = $batch.numRows();
+       |  int $localEnd = $numRows - $idx;
+       |  for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
+       |    int $rowidx = $idx + $localIdx;
        |    ${consume(ctx, columnsBatchInput).trim}
-       |    if (shouldStop()) return;
+       |    $shouldStop
        |  }
+       |  $idx = $numRows;
        |  $batch = null;
        |  $nextBatch();
        |}