diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 6173fd3a69fc7545a700c4048a43d3458bcb02f5..42d58682a1e235e7deb11d4979fb25f9804da51c 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -28,6 +28,7 @@ import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
 import org.apache.spark.util.ByteBufferInputStream
+import org.apache.spark.util.io.ByteArrayChunkOutputStream
 
 /**
  * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
@@ -201,29 +202,12 @@ private object TorrentBroadcast extends Logging {
   }
 
   def blockifyObject[T: ClassTag](obj: T): Array[ByteBuffer] = {
-    // TODO: Create a special ByteArrayOutputStream that splits the output directly into chunks
-    // so we don't need to do the extra memory copy.
-    val bos = new ByteArrayOutputStream()
+    val bos = new ByteArrayChunkOutputStream(BLOCK_SIZE)
     val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos
     val ser = SparkEnv.get.serializer.newInstance()
     val serOut = ser.serializeStream(out)
     serOut.writeObject[T](obj).close()
-    val byteArray = bos.toByteArray
-    val bais = new ByteArrayInputStream(byteArray)
-    val numBlocks = math.ceil(byteArray.length.toDouble / BLOCK_SIZE).toInt
-    val blocks = new Array[ByteBuffer](numBlocks)
-
-    var blockId = 0
-    for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
-      val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
-      val tempByteArray = new Array[Byte](thisBlockSize)
-      bais.read(tempByteArray, 0, thisBlockSize)
-
-      blocks(blockId) = ByteBuffer.wrap(tempByteArray)
-      blockId += 1
-    }
-    bais.close()
-    blocks
+    bos.toArrays.map(ByteBuffer.wrap)
   }
 
   def unBlockifyObject[T: ClassTag](blocks: Array[ByteBuffer]): T = {
diff --git a/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala
new file mode 100644
index 0000000000000000000000000000000000000000..daac6f971eb20ee835750c9c39f215e33a6d1ade
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.io
+
+import java.io.OutputStream
+
+import scala.collection.mutable.ArrayBuffer
+
+
+/**
+ * An OutputStream that writes to fixed-size chunks of byte arrays.
+ *
+ * @param chunkSize size of each chunk, in bytes.
+ */
+private[spark]
+class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream {
+
+  private val chunks = new ArrayBuffer[Array[Byte]]
+
+  /** Index of the last chunk. Starting with -1 when the chunks array is empty. */
+  private var lastChunkIndex = -1
+
+  /**
+   * Next position to write in the last chunk.
+   *
+   * If this equals chunkSize, it means for next write we need to allocate a new chunk.
+   * This can also never be 0.
+   */
+  private var position = chunkSize
+
+  override def write(b: Int): Unit = {
+    allocateNewChunkIfNeeded()
+    chunks(lastChunkIndex)(position) = b.toByte
+    position += 1
+  }
+
+  override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
+    var written = 0
+    while (written < len) {
+      allocateNewChunkIfNeeded()
+      val thisBatch = math.min(chunkSize - position, len - written)
+      System.arraycopy(bytes, written + off, chunks(lastChunkIndex), position, thisBatch)
+      written += thisBatch
+      position += thisBatch
+    }
+  }
+
+  @inline
+  private def allocateNewChunkIfNeeded(): Unit = {
+    if (position == chunkSize) {
+      chunks += new Array[Byte](chunkSize)
+      lastChunkIndex += 1
+      position = 0
+    }
+  }
+
+  def toArrays: Array[Array[Byte]] = {
+    if (lastChunkIndex == -1) {
+      new Array[Array[Byte]](0)
+    } else {
+      // Copy the first n-1 chunks to the output, and then create an array that fits the last chunk.
+      // An alternative would have been returning an array of ByteBuffers, with the last buffer
+      // bounded to only the last chunk's position. However, given our use case in Spark (to put
+      // the chunks in block manager), only limiting the view bound of the buffer would still
+      // require the block manager to store the whole chunk.
+      val ret = new Array[Array[Byte]](chunks.size)
+      for (i <- 0 until chunks.size - 1) {
+        ret(i) = chunks(i)
+      }
+      if (position == chunkSize) {
+        ret(lastChunkIndex) = chunks(lastChunkIndex)
+      } else {
+        ret(lastChunkIndex) = new Array[Byte](position)
+        System.arraycopy(chunks(lastChunkIndex), 0, ret(lastChunkIndex), 0, position)
+      }
+      ret
+    }
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..f855831b8e3673689a9c58802b00c75f9b6aedf3
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.io
+
+import scala.util.Random
+
+import org.scalatest.FunSuite
+
+
+class ByteArrayChunkOutputStreamSuite extends FunSuite {
+
+  test("empty output") {
+    val o = new ByteArrayChunkOutputStream(1024)
+    assert(o.toArrays.length === 0)
+  }
+
+  test("write a single byte") {
+    val o = new ByteArrayChunkOutputStream(1024)
+    o.write(10)
+    assert(o.toArrays.length === 1)
+    assert(o.toArrays.head.toSeq === Seq(10.toByte))
+  }
+
+  test("write a single near boundary") {
+    val o = new ByteArrayChunkOutputStream(10)
+    o.write(new Array[Byte](9))
+    o.write(99)
+    assert(o.toArrays.length === 1)
+    assert(o.toArrays.head(9) === 99.toByte)
+  }
+
+  test("write a single at boundary") {
+    val o = new ByteArrayChunkOutputStream(10)
+    o.write(new Array[Byte](10))
+    o.write(99)
+    assert(o.toArrays.length === 2)
+    assert(o.toArrays(1).length === 1)
+    assert(o.toArrays(1)(0) === 99.toByte)
+  }
+
+  test("single chunk output") {
+    val ref = new Array[Byte](8)
+    Random.nextBytes(ref)
+    val o = new ByteArrayChunkOutputStream(10)
+    o.write(ref)
+    val arrays = o.toArrays
+    assert(arrays.length === 1)
+    assert(arrays.head.length === ref.length)
+    assert(arrays.head.toSeq === ref.toSeq)
+  }
+
+  test("single chunk output at boundary size") {
+    val ref = new Array[Byte](10)
+    Random.nextBytes(ref)
+    val o = new ByteArrayChunkOutputStream(10)
+    o.write(ref)
+    val arrays = o.toArrays
+    assert(arrays.length === 1)
+    assert(arrays.head.length === ref.length)
+    assert(arrays.head.toSeq === ref.toSeq)
+  }
+
+  test("multiple chunk output") {
+    val ref = new Array[Byte](26)
+    Random.nextBytes(ref)
+    val o = new ByteArrayChunkOutputStream(10)
+    o.write(ref)
+    val arrays = o.toArrays
+    assert(arrays.length === 3)
+    assert(arrays(0).length === 10)
+    assert(arrays(1).length === 10)
+    assert(arrays(2).length === 6)
+
+    assert(arrays(0).toSeq === ref.slice(0, 10))
+    assert(arrays(1).toSeq === ref.slice(10, 20))
+    assert(arrays(2).toSeq === ref.slice(20, 26))
+  }
+
+  test("multiple chunk output at boundary size") {
+    val ref = new Array[Byte](30)
+    Random.nextBytes(ref)
+    val o = new ByteArrayChunkOutputStream(10)
+    o.write(ref)
+    val arrays = o.toArrays
+    assert(arrays.length === 3)
+    assert(arrays(0).length === 10)
+    assert(arrays(1).length === 10)
+    assert(arrays(2).length === 10)
+
+    assert(arrays(0).toSeq === ref.slice(0, 10))
+    assert(arrays(1).toSeq === ref.slice(10, 20))
+    assert(arrays(2).toSeq === ref.slice(20, 30))
+  }
+}