Commit db160676 authored by Reynold Xin

[SPARK-3135] Avoid extra mem copy in TorrentBroadcast via ByteArrayChunkOutputStream

This also enables supporting broadcast variables larger than 2G.

Author: Reynold Xin <rxin@apache.org>

Closes #2054 from rxin/ByteArrayChunkOutputStream and squashes the following commits:

618d9c8 [Reynold Xin] Code review.
93f5a51 [Reynold Xin] Added comments.
ee88e73 [Reynold Xin] to -> until
bbd1cb1 [Reynold Xin] Renamed a variable.
36f4d01 [Reynold Xin] Sort imports.
8f1a8eb [Reynold Xin] [SPARK-3135] Created ByteArrayChunkOutputStream and used it to avoid memory copy in TorrentBroadcast.
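The gist of the change, as a REPL-style sketch (illustrative only, not code from this commit; the 32-byte payload is arbitrary): the old path buffered the whole serialized object in one ByteArrayOutputStream, whose toByteArray call copies the internal buffer into a fresh array, and a single JVM byte array cannot exceed Integer.MAX_VALUE bytes, which is what capped a broadcast at roughly 2G.

import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer

// Old approach: everything goes through one contiguous buffer.
val bos = new ByteArrayOutputStream()
bos.write(Array.fill[Byte](32)(1.toByte))
// toByteArray allocates a new array and copies the stream's internal buffer
// into it -- the extra memory copy this commit removes -- and that single
// array is also what bounds the total size to under 2G.
val whole: Array[Byte] = bos.toByteArray
val block = ByteBuffer.wrap(whole)

The new ByteArrayChunkOutputStream introduced below writes straight into fixed-size chunks, so each chunk can be wrapped in a ByteBuffer without another copy and no single array has to hold the whole payload.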
parent 1f98add9
TorrentBroadcast.scala

@@ -28,6 +28,7 @@ import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
 import org.apache.spark.util.ByteBufferInputStream
+import org.apache.spark.util.io.ByteArrayChunkOutputStream
 
 /**
  * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
@@ -201,29 +202,12 @@ private object TorrentBroadcast extends Logging {
   }
 
   def blockifyObject[T: ClassTag](obj: T): Array[ByteBuffer] = {
-    // TODO: Create a special ByteArrayOutputStream that splits the output directly into chunks
-    // so we don't need to do the extra memory copy.
-    val bos = new ByteArrayOutputStream()
+    val bos = new ByteArrayChunkOutputStream(BLOCK_SIZE)
     val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos
     val ser = SparkEnv.get.serializer.newInstance()
     val serOut = ser.serializeStream(out)
     serOut.writeObject[T](obj).close()
-    val byteArray = bos.toByteArray
-    val bais = new ByteArrayInputStream(byteArray)
-    val numBlocks = math.ceil(byteArray.length.toDouble / BLOCK_SIZE).toInt
-    val blocks = new Array[ByteBuffer](numBlocks)
-
-    var blockId = 0
-    for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
-      val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
-      val tempByteArray = new Array[Byte](thisBlockSize)
-      bais.read(tempByteArray, 0, thisBlockSize)
-
-      blocks(blockId) = ByteBuffer.wrap(tempByteArray)
-      blockId += 1
-    }
-    bais.close()
-    blocks
+    bos.toArrays.map(ByteBuffer.wrap)
   }
 
   def unBlockifyObject[T: ClassTag](blocks: Array[ByteBuffer]): T = {
ByteArrayChunkOutputStream.scala (new file)
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util.io

import java.io.OutputStream

import scala.collection.mutable.ArrayBuffer

/**
 * An OutputStream that writes to fixed-size chunks of byte arrays.
 *
 * @param chunkSize size of each chunk, in bytes.
 */
private[spark]
class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream {

  private val chunks = new ArrayBuffer[Array[Byte]]

  /** Index of the last chunk. Starting with -1 when the chunks array is empty. */
  private var lastChunkIndex = -1

  /**
   * Next position to write in the last chunk.
   *
   * If this equals chunkSize, it means for next write we need to allocate a new chunk.
   * This can also never be 0.
   */
  private var position = chunkSize

  override def write(b: Int): Unit = {
    allocateNewChunkIfNeeded()
    chunks(lastChunkIndex)(position) = b.toByte
    position += 1
  }

  override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
    var written = 0
    while (written < len) {
      allocateNewChunkIfNeeded()
      val thisBatch = math.min(chunkSize - position, len - written)
      System.arraycopy(bytes, written + off, chunks(lastChunkIndex), position, thisBatch)
      written += thisBatch
      position += thisBatch
    }
  }

  @inline
  private def allocateNewChunkIfNeeded(): Unit = {
    if (position == chunkSize) {
      chunks += new Array[Byte](chunkSize)
      lastChunkIndex += 1
      position = 0
    }
  }

  def toArrays: Array[Array[Byte]] = {
    if (lastChunkIndex == -1) {
      new Array[Array[Byte]](0)
    } else {
      // Copy the first n-1 chunks to the output, and then create an array that fits the last chunk.
      // An alternative would have been returning an array of ByteBuffers, with the last buffer
      // bounded to only the last chunk's position. However, given our use case in Spark (to put
      // the chunks in block manager), only limiting the view bound of the buffer would still
      // require the block manager to store the whole chunk.
      val ret = new Array[Array[Byte]](chunks.size)
      for (i <- 0 until chunks.size - 1) {
        ret(i) = chunks(i)
      }
      if (position == chunkSize) {
        ret(lastChunkIndex) = chunks(lastChunkIndex)
      } else {
        ret(lastChunkIndex) = new Array[Byte](position)
        System.arraycopy(chunks(lastChunkIndex), 0, ret(lastChunkIndex), 0, position)
      }
      ret
    }
  }
}
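A REPL-style usage sketch of the stream above, mirroring how blockifyObject now consumes it (the 4 KB chunk size and 10000-byte payload are arbitrary; since the class is private[spark], a snippet like this has to live under the org.apache.spark package, as the test suite below does):

import java.nio.ByteBuffer
import org.apache.spark.util.io.ByteArrayChunkOutputStream

// Bytes land directly in fixed-size chunks as they are written.
val out = new ByteArrayChunkOutputStream(4096)
out.write(new Array[Byte](10000))
// toArrays hands back the chunks themselves (only the last, partial chunk is
// copied to trim it), so wrapping them costs no further copy.
val blocks: Array[ByteBuffer] = out.toArrays.map(ByteBuffer.wrap)
assert(blocks.map(_.remaining).sum == 10000)   // 4096 + 4096 + 1808

The accompanying test suite, ByteArrayChunkOutputStreamSuite, follows and exercises the empty, boundary, and multi-chunk cases.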
ByteArrayChunkOutputStreamSuite.scala (new file)

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util.io

import scala.util.Random

import org.scalatest.FunSuite

class ByteArrayChunkOutputStreamSuite extends FunSuite {

  test("empty output") {
    val o = new ByteArrayChunkOutputStream(1024)
    assert(o.toArrays.length === 0)
  }

  test("write a single byte") {
    val o = new ByteArrayChunkOutputStream(1024)
    o.write(10)
    assert(o.toArrays.length === 1)
    assert(o.toArrays.head.toSeq === Seq(10.toByte))
  }

  test("write a single near boundary") {
    val o = new ByteArrayChunkOutputStream(10)
    o.write(new Array[Byte](9))
    o.write(99)
    assert(o.toArrays.length === 1)
    assert(o.toArrays.head(9) === 99.toByte)
  }

  test("write a single at boundary") {
    val o = new ByteArrayChunkOutputStream(10)
    o.write(new Array[Byte](10))
    o.write(99)
    assert(o.toArrays.length === 2)
    assert(o.toArrays(1).length === 1)
    assert(o.toArrays(1)(0) === 99.toByte)
  }

  test("single chunk output") {
    val ref = new Array[Byte](8)
    Random.nextBytes(ref)
    val o = new ByteArrayChunkOutputStream(10)
    o.write(ref)
    val arrays = o.toArrays
    assert(arrays.length === 1)
    assert(arrays.head.length === ref.length)
    assert(arrays.head.toSeq === ref.toSeq)
  }

  test("single chunk output at boundary size") {
    val ref = new Array[Byte](10)
    Random.nextBytes(ref)
    val o = new ByteArrayChunkOutputStream(10)
    o.write(ref)
    val arrays = o.toArrays
    assert(arrays.length === 1)
    assert(arrays.head.length === ref.length)
    assert(arrays.head.toSeq === ref.toSeq)
  }

  test("multiple chunk output") {
    val ref = new Array[Byte](26)
    Random.nextBytes(ref)
    val o = new ByteArrayChunkOutputStream(10)
    o.write(ref)
    val arrays = o.toArrays
    assert(arrays.length === 3)
    assert(arrays(0).length === 10)
    assert(arrays(1).length === 10)
    assert(arrays(2).length === 6)

    assert(arrays(0).toSeq === ref.slice(0, 10))
    assert(arrays(1).toSeq === ref.slice(10, 20))
    assert(arrays(2).toSeq === ref.slice(20, 26))
  }

  test("multiple chunk output at boundary size") {
    val ref = new Array[Byte](30)
    Random.nextBytes(ref)
    val o = new ByteArrayChunkOutputStream(10)
    o.write(ref)
    val arrays = o.toArrays
    assert(arrays.length === 3)
    assert(arrays(0).length === 10)
    assert(arrays(1).length === 10)
    assert(arrays(2).length === 10)

    assert(arrays(0).toSeq === ref.slice(0, 10))
    assert(arrays(1).toSeq === ref.slice(10, 20))
    assert(arrays(2).toSeq === ref.slice(20, 30))
  }
}