diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index a392b5341244fce7a8cfcae811dbd5902aebfb74..010ed7f5008ebaffb1105addb0be9efb55230c3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -219,48 +219,62 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   }
 
   /**
-   * Runs this query returning the result as an array.
+   * Packing the UnsafeRows into byte array for faster serialization.
+   * The byte arrays are in the following format:
+   * [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1]
+   *
+   * UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also
+   * compressed.
    */
-  def executeCollect(): Array[InternalRow] = {
-    // Packing the UnsafeRows into byte array for faster serialization.
-    // The byte arrays are in the following format:
-    // [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1]
-    //
-    // UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also
-    // compressed.
-    val byteArrayRdd = execute().mapPartitionsInternal { iter =>
+  private def getByteArrayRdd(n: Int = -1): RDD[Array[Byte]] = {
+    execute().mapPartitionsInternal { iter =>
+      var count = 0
       val buffer = new Array[Byte](4 << 10)  // 4K
       val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
       val bos = new ByteArrayOutputStream()
       val out = new DataOutputStream(codec.compressedOutputStream(bos))
-      while (iter.hasNext) {
+      while (iter.hasNext && (n < 0 || count < n)) {
         val row = iter.next().asInstanceOf[UnsafeRow]
         out.writeInt(row.getSizeInBytes)
         row.writeToStream(out, buffer)
+        count += 1
       }
       out.writeInt(-1)
       out.flush()
       out.close()
       Iterator(bos.toByteArray)
     }
+  }
 
-    // Collect the byte arrays back to driver, then decode them as UnsafeRows.
+  /**
+   * Decode the byte arrays back to UnsafeRows and put them into buffer.
+   */
+  private def decodeUnsafeRows(bytes: Array[Byte], buffer: ArrayBuffer[InternalRow]): Unit = {
     val nFields = schema.length
-    val results = ArrayBuffer[InternalRow]()
 
+    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+    val bis = new ByteArrayInputStream(bytes)
+    val ins = new DataInputStream(codec.compressedInputStream(bis))
+    var sizeOfNextRow = ins.readInt()
+    while (sizeOfNextRow >= 0) {
+      val bs = new Array[Byte](sizeOfNextRow)
+      ins.readFully(bs)
+      val row = new UnsafeRow(nFields)
+      row.pointTo(bs, sizeOfNextRow)
+      buffer += row
+      sizeOfNextRow = ins.readInt()
+    }
+  }
+
+  /**
+   * Runs this query returning the result as an array.
+   */
+  def executeCollect(): Array[InternalRow] = {
+    val byteArrayRdd = getByteArrayRdd()
+
+    val results = ArrayBuffer[InternalRow]()
     byteArrayRdd.collect().foreach { bytes =>
-      val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
-      val bis = new ByteArrayInputStream(bytes)
-      val ins = new DataInputStream(codec.compressedInputStream(bis))
-      var sizeOfNextRow = ins.readInt()
-      while (sizeOfNextRow >= 0) {
-        val bs = new Array[Byte](sizeOfNextRow)
-        ins.readFully(bs)
-        val row = new UnsafeRow(nFields)
-        row.pointTo(bs, sizeOfNextRow)
-        results += row
-        sizeOfNextRow = ins.readInt()
-      }
+      decodeUnsafeRows(bytes, results)
     }
     results.toArray
   }
@@ -283,7 +297,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
       return new Array[InternalRow](0)
     }
 
-    val childRDD = execute().map(_.copy())
+    val childRDD = getByteArrayRdd(n)
 
     val buf = new ArrayBuffer[InternalRow]
     val totalParts = childRDD.partitions.length
@@ -307,13 +321,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
       val left = n - buf.size
       val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
       val sc = sqlContext.sparkContext
-      val res = sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p)
+      val res = sc.runJob(childRDD,
+        (it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty, p)
+
+      res.foreach { r =>
+        decodeUnsafeRows(r.asInstanceOf[Array[Byte]], buf)
+      }
 
-      res.foreach(buf ++= _.take(n - buf.size))
       partsScanned += p.size
     }
 
-    buf.toArray
+    if (buf.size > n) {
+      buf.take(n).toArray
+    } else {
+      buf.toArray
+    }
   }
 
   private[this] def isTesting: Boolean = sys.props.contains("spark.testing")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index b6051b07c8093ac94a422b852ac3ef1d9c1c44bb..d293ff66fbc6bb2adfdd1973d472041503217c59 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -465,4 +465,25 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     collect 4 millions                       3193 / 3895          0.3        3044.7       0.1X
      */
   }
+
+  ignore("collect limit") {
+    val N = 1 << 20
+
+    val benchmark = new Benchmark("collect limit", N)
+    benchmark.addCase("collect limit 1 million") { iter =>
+      sqlContext.range(N * 4).limit(N).collect()
+    }
+    benchmark.addCase("collect limit 2 millions") { iter =>
+      sqlContext.range(N * 4).limit(N * 2).collect()
+    }
+    benchmark.run()
+
+    /**
+    model name      : Westmere E56xx/L56xx/X56xx (Nehalem-C)
+    collect limit:                      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    collect limit 1 million                   833 / 1284          1.3         794.4       1.0X
+    collect limit 2 millions                 3348 / 4005          0.3        3193.3       0.2X
+    */
+  }
 }
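
Reviewer note, not part of the patch: the sketch below is a minimal, self-contained illustration of the [size] [bytes of row] ... [-1] framing that getByteArrayRdd writes and decodeUnsafeRows reads back. It uses plain byte arrays in place of UnsafeRow and omits the CompressionCodec wrapping; the object and method names (FramingSketch, encode, decode) are made up for illustration only.

// Illustrative sketch of the length-prefixed framing used above; assumptions noted in the lead-in.
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import scala.collection.mutable.ArrayBuffer

object FramingSketch {
  // Write each payload as a 4-byte length followed by its bytes, then a -1 sentinel,
  // mirroring what getByteArrayRdd does per partition (minus compression).
  def encode(rows: Iterator[Array[Byte]]): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val out = new DataOutputStream(bos)
    rows.foreach { row =>
      out.writeInt(row.length)
      out.write(row)
    }
    out.writeInt(-1)
    out.flush()
    bos.toByteArray
  }

  // Read length-prefixed payloads until the -1 sentinel, mirroring decodeUnsafeRows.
  def decode(bytes: Array[Byte]): Seq[Array[Byte]] = {
    val ins = new DataInputStream(new ByteArrayInputStream(bytes))
    val buffer = new ArrayBuffer[Array[Byte]]()
    var sizeOfNextRow = ins.readInt()
    while (sizeOfNextRow >= 0) {
      val bs = new Array[Byte](sizeOfNextRow)
      ins.readFully(bs)
      buffer += bs
      sizeOfNextRow = ins.readInt()
    }
    buffer.toSeq
  }

  def main(args: Array[String]): Unit = {
    val rows = Seq("a", "bb", "ccc").map(_.getBytes("UTF-8"))
    val packed = encode(rows.iterator)
    assert(decode(packed).map(new String(_, "UTF-8")) == Seq("a", "bb", "ccc"))
  }
}

Because the sentinel terminates the stream, executeTake can ask each task for its single packed byte array and stop decoding early on the driver, instead of shipping row objects and taking a prefix per partition.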