diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index 04fba17be4bfab4fa76811e8673c921e82ba3820..e86116680a57a1e51617edd2987007b222e04a4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -111,17 +111,27 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
     val columnsBatchInput = (output zip colVars).map { case (attr, colVar) =>
       genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable)
     }
+    val localIdx = ctx.freshName("localIdx")
+    val localEnd = ctx.freshName("localEnd")
+    val numRows = ctx.freshName("numRows")
+    val shouldStop = if (isShouldStopRequired) {
+      s"if (shouldStop()) { $idx = $rowidx + 1; return; }"
+    } else {
+      "// shouldStop check is eliminated"
+    }
     s"""
        |if ($batch == null) {
        |  $nextBatch();
        |}
        |while ($batch != null) {
-       |  int numRows = $batch.numRows();
-       |  while ($idx < numRows) {
-       |    int $rowidx = $idx++;
+       |  int $numRows = $batch.numRows();
+       |  int $localEnd = $numRows - $idx;
+       |  for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
+       |    int $rowidx = $idx + $localIdx;
        |    ${consume(ctx, columnsBatchInput).trim}
-       |    if (shouldStop()) return;
+       |    $shouldStop
        |  }
+       |  $idx = $numRows;
        |  $batch = null;
        |  $nextBatch();
        |}
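
For reference, below is a rough, hand-written Java sketch of the two loop shapes this patch is about; the class, method, and variable names (RowConsumerSketch, batchIdx, rowIdx, and so on) are invented for readability and are not taken from Spark's actual generated code. The idea is that the rewritten inner loop uses a zero-based local induction variable with a loop-invariant bound and updates the progress field only once per batch, a shape that the JIT generally handles better than a loop that mutates a field on every iteration, and that the per-row shouldStop() call can be dropped entirely when isShouldStopRequired is false.

// Illustrative only: a hand-written mock of the generated loop shapes; names are
// invented for readability and are not taken from Spark's actual codegen output.
public class RowConsumerSketch {
  private int batchIdx = 0;   // stand-in for the generated progress field ($idx)

  private boolean shouldStop() { return false; }  // stand-in for the codegen shouldStop() check

  // Old shape: the loop mutates the field batchIdx on every iteration and
  // unconditionally pays a shouldStop() check per row.
  long consumeOldShape(int numRows) {
    long sum = 0;
    while (batchIdx < numRows) {
      int rowIdx = batchIdx++;
      sum += rowIdx;                       // stand-in for consume(...) on row rowIdx
      if (shouldStop()) return sum;
    }
    return sum;
  }

  // New shape: a zero-based local induction variable with a loop-invariant bound,
  // with the field updated once after the loop; the shouldStop() check is dropped,
  // mirroring the isShouldStopRequired == false branch of the patch.
  long consumeNewShape(int numRows) {
    long sum = 0;
    int localEnd = numRows - batchIdx;
    for (int localIdx = 0; localIdx < localEnd; localIdx++) {
      int rowIdx = batchIdx + localIdx;
      sum += rowIdx;                       // stand-in for consume(...) on row rowIdx
    }
    batchIdx = numRows;
    return sum;
  }

  public static void main(String[] args) {
    System.out.println(new RowConsumerSketch().consumeNewShape(1000)); // prints 499500
  }
}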