From 750ed64cd9db4f81a53caaf1fd6c8a6a0c07887d Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh <simonh@tw.ibm.com>
Date: Thu, 17 Mar 2016 23:24:44 -0700
Subject: [PATCH] [SPARK-13930] [SQL] Apply fast serialization on collect limit operator

## What changes were proposed in this pull request?

JIRA: https://issues.apache.org/jira/browse/SPARK-13930

Fast serialization was recently introduced for collecting DataFrame/Dataset results (#11664).
The same technique can be applied to the collect limit operator as well.

## How was this patch tested?

Added a benchmark for collect limit to `BenchmarkWholeStageCodegen`.

Without this patch:

model name      : Westmere E56xx/L56xx/X56xx (Nehalem-C)
collect limit:                      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------
collect limit 1 million                  3413 / 3768          0.3        3255.0       1.0X
collect limit 2 millions                9728 / 10440          0.1        9277.3       0.4X

With this patch:

model name      : Westmere E56xx/L56xx/X56xx (Nehalem-C)
collect limit:                      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------
collect limit 1 million                   833 / 1284          1.3         794.4       1.0X
collect limit 2 millions                 3348 / 4005          0.3        3193.3       0.2X

Author: Liang-Chi Hsieh <simonh@tw.ibm.com>

Closes #11759 from viirya/execute-take.
---
 .../spark/sql/execution/SparkPlan.scala       | 78 ++++++++++++-------
 .../BenchmarkWholeStageCodegen.scala          | 21 +++++
 2 files changed, 71 insertions(+), 28 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index a392b53412..010ed7f500 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -219,48 +219,62 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   }
 
   /**
-   * Runs this query returning the result as an array.
+   * Packing the UnsafeRows into byte array for faster serialization.
+   * The byte arrays are in the following format:
+   * [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1]
+   *
+   * UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also
+   * compressed.
    */
-  def executeCollect(): Array[InternalRow] = {
-    // Packing the UnsafeRows into byte array for faster serialization.
-    // The byte arrays are in the following format:
-    // [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1]
-    //
-    // UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also
-    // compressed.
-    val byteArrayRdd = execute().mapPartitionsInternal { iter =>
+  private def getByteArrayRdd(n: Int = -1): RDD[Array[Byte]] = {
+    execute().mapPartitionsInternal { iter =>
+      var count = 0
       val buffer = new Array[Byte](4 << 10)  // 4K
       val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
       val bos = new ByteArrayOutputStream()
       val out = new DataOutputStream(codec.compressedOutputStream(bos))
-      while (iter.hasNext) {
+      while (iter.hasNext && (n < 0 || count < n)) {
         val row = iter.next().asInstanceOf[UnsafeRow]
         out.writeInt(row.getSizeInBytes)
         row.writeToStream(out, buffer)
+        count += 1
       }
       out.writeInt(-1)
       out.flush()
       out.close()
       Iterator(bos.toByteArray)
     }
+  }
 
-    // Collect the byte arrays back to driver, then decode them as UnsafeRows.
+  /**
+   * Decode the byte arrays back to UnsafeRows and put them into buffer.
+   */
+  private def decodeUnsafeRows(bytes: Array[Byte], buffer: ArrayBuffer[InternalRow]): Unit = {
     val nFields = schema.length
-    val results = ArrayBuffer[InternalRow]()
 
+    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+    val bis = new ByteArrayInputStream(bytes)
+    val ins = new DataInputStream(codec.compressedInputStream(bis))
+    var sizeOfNextRow = ins.readInt()
+    while (sizeOfNextRow >= 0) {
+      val bs = new Array[Byte](sizeOfNextRow)
+      ins.readFully(bs)
+      val row = new UnsafeRow(nFields)
+      row.pointTo(bs, sizeOfNextRow)
+      buffer += row
+      sizeOfNextRow = ins.readInt()
+    }
+  }
+
+  /**
+   * Runs this query returning the result as an array.
+   */
+  def executeCollect(): Array[InternalRow] = {
+    val byteArrayRdd = getByteArrayRdd()
+
+    val results = ArrayBuffer[InternalRow]()
     byteArrayRdd.collect().foreach { bytes =>
-      val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
-      val bis = new ByteArrayInputStream(bytes)
-      val ins = new DataInputStream(codec.compressedInputStream(bis))
-      var sizeOfNextRow = ins.readInt()
-      while (sizeOfNextRow >= 0) {
-        val bs = new Array[Byte](sizeOfNextRow)
-        ins.readFully(bs)
-        val row = new UnsafeRow(nFields)
-        row.pointTo(bs, sizeOfNextRow)
-        results += row
-        sizeOfNextRow = ins.readInt()
-      }
+      decodeUnsafeRows(bytes, results)
     }
     results.toArray
   }
@@ -283,7 +297,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
       return new Array[InternalRow](0)
     }
 
-    val childRDD = execute().map(_.copy())
+    val childRDD = getByteArrayRdd(n)
 
     val buf = new ArrayBuffer[InternalRow]
     val totalParts = childRDD.partitions.length
@@ -307,13 +321,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
       val left = n - buf.size
       val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
       val sc = sqlContext.sparkContext
-      val res = sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p)
+      val res = sc.runJob(childRDD,
+        (it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty, p)
+
+      res.foreach { r =>
+        decodeUnsafeRows(r.asInstanceOf[Array[Byte]], buf)
+      }
 
-      res.foreach(buf ++= _.take(n - buf.size))
       partsScanned += p.size
     }
 
-    buf.toArray
+    if (buf.size > n) {
+      buf.take(n).toArray
+    } else {
+      buf.toArray
+    }
   }
 
   private[this] def isTesting: Boolean = sys.props.contains("spark.testing")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index b6051b07c8..d293ff66fb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -465,4 +465,25 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     collect 4 millions                        3193 / 3895          0.3        3044.7       0.1X
      */
   }
+
+  ignore("collect limit") {
+    val N = 1 << 20
+
+    val benchmark = new Benchmark("collect limit", N)
+    benchmark.addCase("collect limit 1 million") { iter =>
+      sqlContext.range(N * 4).limit(N).collect()
+    }
+    benchmark.addCase("collect limit 2 millions") { iter =>
+      sqlContext.range(N * 4).limit(N * 2).collect()
+    }
+    benchmark.run()
+
+    /**
+    model name      : Westmere E56xx/L56xx/X56xx (Nehalem-C)
+    collect limit:                      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    collect limit 1 million                   833 / 1284          1.3         794.4       1.0X
+    collect limit 2 millions                 3348 / 4005          0.3        3193.3       0.2X
+     */
+  }
 }
-- 
GitLab
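
For readers who want to see the row-framing scheme outside of Spark, the sketch below is a minimal, self-contained illustration of the `[size] [bytes] ... [-1]` format that `getByteArrayRdd` writes and `decodeUnsafeRows` reads. It stands in plain byte arrays for `UnsafeRow`, skips the `CompressionCodec` wrapping, and the `FramingSketch`, `encodeRows`, and `decodeRows` names are invented for this example only; it is not code from the patch.

```scala
// Hypothetical, simplified sketch of the framing used by the patch:
// each payload is written as a 4-byte length followed by its bytes,
// and a -1 length terminates the stream. The real code frames
// UnsafeRows and wraps the streams in a compression codec.
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import scala.collection.mutable.ArrayBuffer

object FramingSketch {
  // Encode up to `limit` payloads (limit < 0 means "all"), mirroring getByteArrayRdd(n).
  def encodeRows(rows: Iterator[Array[Byte]], limit: Int = -1): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val out = new DataOutputStream(bos)
    var count = 0
    while (rows.hasNext && (limit < 0 || count < limit)) {
      val row = rows.next()
      out.writeInt(row.length)   // [size]
      out.write(row)             // [bytes]
      count += 1
    }
    out.writeInt(-1)             // end-of-stream marker
    out.flush()
    bos.toByteArray
  }

  // Decode frames until the -1 marker, appending them to `buffer`,
  // mirroring decodeUnsafeRows(bytes, buffer).
  def decodeRows(bytes: Array[Byte], buffer: ArrayBuffer[Array[Byte]]): Unit = {
    val ins = new DataInputStream(new ByteArrayInputStream(bytes))
    var size = ins.readInt()
    while (size >= 0) {
      val bs = new Array[Byte](size)
      ins.readFully(bs)
      buffer += bs
      size = ins.readInt()
    }
  }

  def main(args: Array[String]): Unit = {
    // Encode three "rows" but stop after two, like a partition capped at n rows.
    val encoded = encodeRows(Iterator("a".getBytes, "bb".getBytes, "ccc".getBytes), limit = 2)
    val decoded = ArrayBuffer[Array[Byte]]()
    decodeRows(encoded, decoded)
    println(decoded.map(new String(_)))  // ArrayBuffer(a, bb)
  }
}
```

The per-partition `limit` mirrors how `executeTake` asks `getByteArrayRdd(n)` to stop encoding after `n` rows in each partition; since several partitions may each return up to `n` rows, the driver can over-collect, which is why the patch ends with `buf.take(n)` when `buf.size > n`.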