From 91c07a33d90ab0357e8713507134ecef5c14e28a Mon Sep 17 00:00:00 2001
From: Matei Zaharia <matei@eecs.berkeley.edu>
Date: Sat, 21 May 2011 22:50:08 -0700
Subject: [PATCH] Various fixes to serialization

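Rename KryoSerialization.scala to KryoSerializer.scala to match the class
it defines, make the Kryo object buffer size configurable through the
spark.kryoserializer.buffer.mb property (default 2 MB, replacing the
hard-coded 257 MB per-thread buffers), and route shuffle map output and
shuffle fetching through the configured Serializer instead of raw Java
ObjectOutputStream/ObjectInputStream. Also includes minor formatting
cleanup in SparkEnv.createFromSystemProperties.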
---
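A minimal, illustrative sketch of how a driver could opt into the Kryo
serializer and size its per-thread buffers under this change (the property
names are the ones read in SparkEnv.createFromSystemProperties and
KryoSerializer; the 32 MB value is only an example, and the buffer must be
large enough to hold the largest single object being serialized):

    // Hypothetical driver-side setup; these properties have to be set
    // before SparkEnv.createFromSystemProperties builds the environment.
    System.setProperty("spark.serializer", "spark.KryoSerializer")
    System.setProperty("spark.kryoserializer.buffer.mb", "32")
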
 .../{KryoSerialization.scala => KryoSerializer.scala} | 11 +++++++++--
 core/src/main/scala/spark/ShuffleMapTask.scala        |  5 +++--
 core/src/main/scala/spark/SimpleShuffleFetcher.scala  |  4 ++--
 core/src/main/scala/spark/SparkEnv.scala              |  6 ++----
 4 files changed, 16 insertions(+), 10 deletions(-)
 rename core/src/main/scala/spark/{KryoSerialization.scala => KryoSerializer.scala} (92%)

diff --git a/core/src/main/scala/spark/KryoSerialization.scala b/core/src/main/scala/spark/KryoSerializer.scala
similarity index 92%
rename from core/src/main/scala/spark/KryoSerialization.scala
rename to core/src/main/scala/spark/KryoSerializer.scala
index ba34a5452a..658b7b7d0f 100644
--- a/core/src/main/scala/spark/KryoSerialization.scala
+++ b/core/src/main/scala/spark/KryoSerializer.scala
@@ -10,6 +10,10 @@ import scala.collection.mutable
 import com.esotericsoftware.kryo._
 import com.esotericsoftware.kryo.{Serializer => KSerializer}
 
+/**
+ * Zig-zag encoder used to write object sizes to serialization streams.
+ * Based on Kryo's integer encoder.
+ */
 object ZigZag {
   def writeInt(n: Int, out: OutputStream) {
     var value = n
@@ -110,12 +114,15 @@ trait KryoRegistrator {
 class KryoSerializer extends Serializer with Logging {
   val kryo = createKryo()
 
+  val bufferSize =
+    System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
+
   val threadBuf = new ThreadLocal[ObjectBuffer] {
-    override def initialValue = new ObjectBuffer(kryo, 257*1024*1024)
+    override def initialValue = new ObjectBuffer(kryo, bufferSize)
   }
 
   val threadByteBuf = new ThreadLocal[ByteBuffer] {
-    override def initialValue = ByteBuffer.allocate(257*1024*1024)
+    override def initialValue = ByteBuffer.allocate(bufferSize)
   }
 
   def createKryo(): Kryo = {
diff --git a/core/src/main/scala/spark/ShuffleMapTask.scala b/core/src/main/scala/spark/ShuffleMapTask.scala
index 9c00ab5d08..974346e367 100644
--- a/core/src/main/scala/spark/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/ShuffleMapTask.scala
@@ -6,7 +6,7 @@ import scala.collection.mutable.HashMap
 
 
 class ShuffleMapTask(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_], val partition: Int, locs: Seq[String])
-extends DAGTask[String](stageId) {
+extends DAGTask[String](stageId) with Logging {
   val split = rdd.splits(partition)
 
   override def run: String = {
@@ -23,11 +23,11 @@ extends DAGTask[String](stageId) {
         case None => aggregator.createCombiner(v)
       }
     }
+    val ser = SparkEnv.get.serializer.newInstance()
     for (i <- 0 until numOutputSplits) {
       val file = LocalFileShuffle.getOutputFile(dep.shuffleId, partition, i)
-      // TODO: use Serializer instead of ObjectInputStream
       // TODO: have some kind of EOF marker
-      val out = new ObjectOutputStream(new FileOutputStream(file))
+      val out = ser.outputStream(new FileOutputStream(file))
       buckets(i).foreach(pair => out.writeObject(pair))
       out.close()
     }
diff --git a/core/src/main/scala/spark/SimpleShuffleFetcher.scala b/core/src/main/scala/spark/SimpleShuffleFetcher.scala
index ac5b979753..d9f1351f22 100644
--- a/core/src/main/scala/spark/SimpleShuffleFetcher.scala
+++ b/core/src/main/scala/spark/SimpleShuffleFetcher.scala
@@ -11,6 +11,7 @@ import scala.collection.mutable.HashMap
 class SimpleShuffleFetcher extends ShuffleFetcher with Logging {
   def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
     logInfo("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
+    val ser = SparkEnv.get.serializer.newInstance()
     val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
     val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId)
     for ((serverUri, index) <- serverUris.zipWithIndex) {
@@ -20,10 +21,9 @@ class SimpleShuffleFetcher extends ShuffleFetcher with Logging {
       for (i <- inputIds) {
         try {
           val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId)
-          // TODO: use Serializer instead of ObjectInputStream
           // TODO: multithreaded fetch
           // TODO: would be nice to retry multiple times
-          val inputStream = new ObjectInputStream(new URL(url).openStream())
+          val inputStream = ser.inputStream(new URL(url).openStream())
           try {
             while (true) {
               val pair = inputStream.readObject().asInstanceOf[(K, V)]
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index 1bfd0172d7..ec48564d43 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -19,12 +19,10 @@ object SparkEnv {
   }
 
   def createFromSystemProperties(isMaster: Boolean): SparkEnv = {
-    val cacheClass = System.getProperty("spark.cache.class",
-      "spark.SoftReferenceCache")
+    val cacheClass = System.getProperty("spark.cache.class", "spark.SoftReferenceCache")
     val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
     
-    val serClass = System.getProperty("spark.serializer",
-      "spark.JavaSerializer")
+    val serClass = System.getProperty("spark.serializer", "spark.JavaSerializer")
     val ser = Class.forName(serClass).newInstance().asInstanceOf[Serializer]
 
     val cacheTracker = new CacheTracker(isMaster, cache)
-- 
GitLab