diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 6173fd3a69fc7545a700c4048a43d3458bcb02f5..42d58682a1e235e7deb11d4979fb25f9804da51c 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -28,6 +28,7 @@ import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.ByteBufferInputStream +import org.apache.spark.util.io.ByteArrayChunkOutputStream /** * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]]. @@ -201,29 +202,12 @@ private object TorrentBroadcast extends Logging { } def blockifyObject[T: ClassTag](obj: T): Array[ByteBuffer] = { - // TODO: Create a special ByteArrayOutputStream that splits the output directly into chunks - // so we don't need to do the extra memory copy. - val bos = new ByteArrayOutputStream() + val bos = new ByteArrayChunkOutputStream(BLOCK_SIZE) val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos val ser = SparkEnv.get.serializer.newInstance() val serOut = ser.serializeStream(out) serOut.writeObject[T](obj).close() - val byteArray = bos.toByteArray - val bais = new ByteArrayInputStream(byteArray) - val numBlocks = math.ceil(byteArray.length.toDouble / BLOCK_SIZE).toInt - val blocks = new Array[ByteBuffer](numBlocks) - - var blockId = 0 - for (i <- 0 until (byteArray.length, BLOCK_SIZE)) { - val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i) - val tempByteArray = new Array[Byte](thisBlockSize) - bais.read(tempByteArray, 0, thisBlockSize) - - blocks(blockId) = ByteBuffer.wrap(tempByteArray) - blockId += 1 - } - bais.close() - blocks + bos.toArrays.map(ByteBuffer.wrap) } def unBlockifyObject[T: ClassTag](blocks: Array[ByteBuffer]): T = { diff --git a/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala new file mode 100644 index 0000000000000000000000000000000000000000..daac6f971eb20ee835750c9c39f215e33a6d1ade --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.io + +import java.io.OutputStream + +import scala.collection.mutable.ArrayBuffer + + +/** + * An OutputStream that writes to fixed-size chunks of byte arrays. + * + * @param chunkSize size of each chunk, in bytes. 
+ */ +private[spark] +class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream { + + private val chunks = new ArrayBuffer[Array[Byte]] + + /** Index of the last chunk. Starting with -1 when the chunks array is empty. */ + private var lastChunkIndex = -1 + + /** + * Next position to write in the last chunk. + * + * If this equals chunkSize, it means for next write we need to allocate a new chunk. + * This can also never be 0. + */ + private var position = chunkSize + + override def write(b: Int): Unit = { + allocateNewChunkIfNeeded() + chunks(lastChunkIndex)(position) = b.toByte + position += 1 + } + + override def write(bytes: Array[Byte], off: Int, len: Int): Unit = { + var written = 0 + while (written < len) { + allocateNewChunkIfNeeded() + val thisBatch = math.min(chunkSize - position, len - written) + System.arraycopy(bytes, written + off, chunks(lastChunkIndex), position, thisBatch) + written += thisBatch + position += thisBatch + } + } + + @inline + private def allocateNewChunkIfNeeded(): Unit = { + if (position == chunkSize) { + chunks += new Array[Byte](chunkSize) + lastChunkIndex += 1 + position = 0 + } + } + + def toArrays: Array[Array[Byte]] = { + if (lastChunkIndex == -1) { + new Array[Array[Byte]](0) + } else { + // Copy the first n-1 chunks to the output, and then create an array that fits the last chunk. + // An alternative would have been returning an array of ByteBuffers, with the last buffer + // bounded to only the last chunk's position. However, given our use case in Spark (to put + // the chunks in block manager), only limiting the view bound of the buffer would still + // require the block manager to store the whole chunk. + val ret = new Array[Array[Byte]](chunks.size) + for (i <- 0 until chunks.size - 1) { + ret(i) = chunks(i) + } + if (position == chunkSize) { + ret(lastChunkIndex) = chunks(lastChunkIndex) + } else { + ret(lastChunkIndex) = new Array[Byte](position) + System.arraycopy(chunks(lastChunkIndex), 0, ret(lastChunkIndex), 0, position) + } + ret + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..f855831b8e3673689a9c58802b00c75f9b6aedf3 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.io + +import scala.util.Random + +import org.scalatest.FunSuite + + +class ByteArrayChunkOutputStreamSuite extends FunSuite { + + test("empty output") { + val o = new ByteArrayChunkOutputStream(1024) + assert(o.toArrays.length === 0) + } + + test("write a single byte") { + val o = new ByteArrayChunkOutputStream(1024) + o.write(10) + assert(o.toArrays.length === 1) + assert(o.toArrays.head.toSeq === Seq(10.toByte)) + } + + test("write a single near boundary") { + val o = new ByteArrayChunkOutputStream(10) + o.write(new Array[Byte](9)) + o.write(99) + assert(o.toArrays.length === 1) + assert(o.toArrays.head(9) === 99.toByte) + } + + test("write a single at boundary") { + val o = new ByteArrayChunkOutputStream(10) + o.write(new Array[Byte](10)) + o.write(99) + assert(o.toArrays.length === 2) + assert(o.toArrays(1).length === 1) + assert(o.toArrays(1)(0) === 99.toByte) + } + + test("single chunk output") { + val ref = new Array[Byte](8) + Random.nextBytes(ref) + val o = new ByteArrayChunkOutputStream(10) + o.write(ref) + val arrays = o.toArrays + assert(arrays.length === 1) + assert(arrays.head.length === ref.length) + assert(arrays.head.toSeq === ref.toSeq) + } + + test("single chunk output at boundary size") { + val ref = new Array[Byte](10) + Random.nextBytes(ref) + val o = new ByteArrayChunkOutputStream(10) + o.write(ref) + val arrays = o.toArrays + assert(arrays.length === 1) + assert(arrays.head.length === ref.length) + assert(arrays.head.toSeq === ref.toSeq) + } + + test("multiple chunk output") { + val ref = new Array[Byte](26) + Random.nextBytes(ref) + val o = new ByteArrayChunkOutputStream(10) + o.write(ref) + val arrays = o.toArrays + assert(arrays.length === 3) + assert(arrays(0).length === 10) + assert(arrays(1).length === 10) + assert(arrays(2).length === 6) + + assert(arrays(0).toSeq === ref.slice(0, 10)) + assert(arrays(1).toSeq === ref.slice(10, 20)) + assert(arrays(2).toSeq === ref.slice(20, 26)) + } + + test("multiple chunk output at boundary size") { + val ref = new Array[Byte](30) + Random.nextBytes(ref) + val o = new ByteArrayChunkOutputStream(10) + o.write(ref) + val arrays = o.toArrays + assert(arrays.length === 3) + assert(arrays(0).length === 10) + assert(arrays(1).length === 10) + assert(arrays(2).length === 10) + + assert(arrays(0).toSeq === ref.slice(0, 10)) + assert(arrays(1).toSeq === ref.slice(10, 20)) + assert(arrays(2).toSeq === ref.slice(20, 30)) + } +}
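
For reference, below is a minimal standalone sketch of how the new stream behaves, using only the API introduced in this patch (the constructor, write, and toArrays). The chunkSize of 16, the payload, and the wrapper object are illustrative assumptions; TorrentBroadcast itself uses its BLOCK_SIZE constant and wires the stream through the compression codec and serializer as shown in the blockifyObject change above.

import java.nio.ByteBuffer

import org.apache.spark.util.io.ByteArrayChunkOutputStream

object ByteArrayChunkOutputStreamExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical chunk size for illustration only; TorrentBroadcast uses BLOCK_SIZE.
    val chunkSize = 16

    val out = new ByteArrayChunkOutputStream(chunkSize)

    // Writes go directly into fixed-size chunks, so no single contiguous copy of
    // the whole payload is ever materialized (the motivation for this patch).
    out.write(Array.tabulate[Byte](40)(_.toByte))
    out.write(0x7f)

    // toArrays returns one Array[Byte] per chunk. Earlier chunks are handed back
    // as-is; only the last chunk is copied, and only when it is not full.
    val chunks: Array[Array[Byte]] = out.toArrays
    assert(chunks.length == 3)                        // 16 + 16 + 9 bytes
    assert(chunks.last.length == 41 - 2 * chunkSize)  // trailing partial chunk

    // Mirrors what blockifyObject now does: wrap each chunk without another copy.
    val blocks: Array[ByteBuffer] = chunks.map(ByteBuffer.wrap)
    println(blocks.map(_.remaining()).mkString(", "))  // 16, 16, 9
  }
}

This also illustrates the trade-off noted in toArrays: returning plain byte arrays (rather than length-limited ByteBuffers over full-size chunks) means the last, partially filled chunk is trimmed with one small copy, so downstream consumers such as the block manager never hold on to unused tail capacity.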