diff --git a/core/src/main/scala/spark/KryoSerialization.scala b/core/src/main/scala/spark/KryoSerializer.scala
similarity index 92%
rename from core/src/main/scala/spark/KryoSerialization.scala
rename to core/src/main/scala/spark/KryoSerializer.scala
index ba34a5452adc453c846da134edd8030681083113..658b7b7d0f4938a59ac8d7882c022a2a64ff7a1f 100644
--- a/core/src/main/scala/spark/KryoSerialization.scala
+++ b/core/src/main/scala/spark/KryoSerializer.scala
@@ -10,6 +10,10 @@ import scala.collection.mutable
 import com.esotericsoftware.kryo._
 import com.esotericsoftware.kryo.{Serializer => KSerializer}
 
+/**
+ * Zig-zag encoder used to write object sizes to serialization streams.
+ * Based on Kryo's integer encoder.
+ */
 object ZigZag {
   def writeInt(n: Int, out: OutputStream) {
     var value = n
@@ -110,12 +114,15 @@ trait KryoRegistrator {
 class KryoSerializer extends Serializer with Logging {
   val kryo = createKryo()
 
+  val bufferSize =
+    System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
+
   val threadBuf = new ThreadLocal[ObjectBuffer] {
-    override def initialValue = new ObjectBuffer(kryo, 257*1024*1024)
+    override def initialValue = new ObjectBuffer(kryo, bufferSize)
   }
 
   val threadByteBuf = new ThreadLocal[ByteBuffer] {
-    override def initialValue = ByteBuffer.allocate(257*1024*1024)
+    override def initialValue = ByteBuffer.allocate(bufferSize)
   }
 
   def createKryo(): Kryo = {
diff --git a/core/src/main/scala/spark/ShuffleMapTask.scala b/core/src/main/scala/spark/ShuffleMapTask.scala
index 9c00ab5d0834bb2184070f545763f304a03a3b2c..974346e3670d6ceaf4b54816f9d0019b884182db 100644
--- a/core/src/main/scala/spark/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/ShuffleMapTask.scala
@@ -6,7 +6,7 @@ import scala.collection.mutable.HashMap
 
 class ShuffleMapTask(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_], val partition: Int, locs: Seq[String])
-extends DAGTask[String](stageId) {
+extends DAGTask[String](stageId) with Logging {
   val split = rdd.splits(partition)
 
   override def run: String = {
@@ -23,11 +23,12 @@ extends DAGTask[String](stageId) {
         case None => aggregator.createCombiner(v)
       }
     }
+    val ser = SparkEnv.get.serializer.newInstance()
     for (i <- 0 until numOutputSplits) {
       val file = LocalFileShuffle.getOutputFile(dep.shuffleId, partition, i)
       // TODO: use Serializer instead of ObjectInputStream
       // TODO: have some kind of EOF marker
-      val out = new ObjectOutputStream(new FileOutputStream(file))
+      val out = ser.outputStream(new FileOutputStream(file))
       buckets(i).foreach(pair => out.writeObject(pair))
       out.close()
     }
diff --git a/core/src/main/scala/spark/SimpleShuffleFetcher.scala b/core/src/main/scala/spark/SimpleShuffleFetcher.scala
index ac5b979753a54f2bce70cedbc775d1293ef28a00..d9f1351f2203fcebc23e755b9e226c28dbe6598e 100644
--- a/core/src/main/scala/spark/SimpleShuffleFetcher.scala
+++ b/core/src/main/scala/spark/SimpleShuffleFetcher.scala
@@ -11,6 +11,7 @@ import scala.collection.mutable.HashMap
 class SimpleShuffleFetcher extends ShuffleFetcher with Logging {
   def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
     logInfo("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
+    val ser = SparkEnv.get.serializer.newInstance()
     val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
     val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId)
     for ((serverUri, index) <- serverUris.zipWithIndex) {
@@ -20,10 +21,9 @@ class SimpleShuffleFetcher extends ShuffleFetcher with Logging {
       for (i <- inputIds) {
         try {
           val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId)
-          // TODO: use Serializer instead of ObjectInputStream
           // TODO: multithreaded fetch
           // TODO: would be nice to retry multiple times
-          val inputStream = new ObjectInputStream(new URL(url).openStream())
+          val inputStream = ser.inputStream(new URL(url).openStream())
           try {
             while (true) {
               val pair = inputStream.readObject().asInstanceOf[(K, V)]
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index 1bfd0172d7438b79acd5707272d89ed6674317f5..ec48564d438d03ac594f159df96440ceeff7fc93 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -19,12 +19,10 @@ object SparkEnv {
   }
 
   def createFromSystemProperties(isMaster: Boolean): SparkEnv = {
-    val cacheClass = System.getProperty("spark.cache.class",
-      "spark.SoftReferenceCache")
+    val cacheClass = System.getProperty("spark.cache.class", "spark.SoftReferenceCache")
     val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
 
-    val serClass = System.getProperty("spark.serializer",
-      "spark.JavaSerializer")
+    val serClass = System.getProperty("spark.serializer", "spark.JavaSerializer")
     val ser = Class.forName(serClass).newInstance().asInstanceOf[Serializer]
 
     val cacheTracker = new CacheTracker(isMaster, cache)