diff --git a/core/src/main/scala/spark/KryoSerialization.scala b/core/src/main/scala/spark/KryoSerializer.scala
similarity index 92%
rename from core/src/main/scala/spark/KryoSerialization.scala
rename to core/src/main/scala/spark/KryoSerializer.scala
index ba34a5452adc453c846da134edd8030681083113..658b7b7d0f4938a59ac8d7882c022a2a64ff7a1f 100644
--- a/core/src/main/scala/spark/KryoSerialization.scala
+++ b/core/src/main/scala/spark/KryoSerializer.scala
@@ -10,6 +10,10 @@ import scala.collection.mutable
 import com.esotericsoftware.kryo._
 import com.esotericsoftware.kryo.{Serializer => KSerializer}
 
+/**
+ * Zig-zag encoder used to write object sizes to serialization streams.
+ * Based on Kryo's integer encoder.
+ */
 object ZigZag {
   def writeInt(n: Int, out: OutputStream) {
     var value = n
@@ -110,12 +114,15 @@ trait KryoRegistrator {
 class KryoSerializer extends Serializer with Logging {
   val kryo = createKryo()
 
+  val bufferSize =
+    System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
+
   val threadBuf = new ThreadLocal[ObjectBuffer] {
-    override def initialValue = new ObjectBuffer(kryo, 257*1024*1024)
+    override def initialValue = new ObjectBuffer(kryo, bufferSize)
   }
 
   val threadByteBuf = new ThreadLocal[ByteBuffer] {
-    override def initialValue = ByteBuffer.allocate(257*1024*1024)
+    override def initialValue = ByteBuffer.allocate(bufferSize)
   }
 
   def createKryo(): Kryo = {
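
The hunk above replaces the hard-coded 257 MB per-thread Kryo buffers with a size read from the spark.kryoserializer.buffer.mb system property, defaulting to 2 MB. A minimal sketch of the sizing arithmetic on its own, assuming only that the property is read before the buffers are allocated; the object name and example value below are illustrative, only the property name comes from the patch:

    object BufferSizeDemo {
      def main(args: Array[String]) {
        // The property is in megabytes; the ThreadLocal ObjectBuffer/ByteBuffer are
        // allocated in bytes, so the default "2" yields 2,097,152 bytes per thread.
        val mb = System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt
        val bufferSize = mb * 1024 * 1024
        println("Kryo buffer per thread: " + bufferSize + " bytes")
      }
    }
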
diff --git a/core/src/main/scala/spark/ShuffleMapTask.scala b/core/src/main/scala/spark/ShuffleMapTask.scala
index 9c00ab5d0834bb2184070f545763f304a03a3b2c..974346e3670d6ceaf4b54816f9d0019b884182db 100644
--- a/core/src/main/scala/spark/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/ShuffleMapTask.scala
@@ -6,7 +6,7 @@ import scala.collection.mutable.HashMap
 
 
 class ShuffleMapTask(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_], val partition: Int, locs: Seq[String])
-extends DAGTask[String](stageId) {
+extends DAGTask[String](stageId) with Logging {
   val split = rdd.splits(partition)
 
   override def run: String = {
@@ -23,11 +23,12 @@ extends DAGTask[String](stageId) {
         case None => aggregator.createCombiner(v)
       }
     }
+    val ser = SparkEnv.get.serializer.newInstance()
     for (i <- 0 until numOutputSplits) {
       val file = LocalFileShuffle.getOutputFile(dep.shuffleId, partition, i)
       // TODO: use Serializer instead of ObjectInputStream
       // TODO: have some kind of EOF marker
-      val out = new ObjectOutputStream(new FileOutputStream(file))
+      val out = ser.outputStream(new FileOutputStream(file))
       buckets(i).foreach(pair => out.writeObject(pair))
       out.close()
     }
diff --git a/core/src/main/scala/spark/SimpleShuffleFetcher.scala b/core/src/main/scala/spark/SimpleShuffleFetcher.scala
index ac5b979753a54f2bce70cedbc775d1293ef28a00..d9f1351f2203fcebc23e755b9e226c28dbe6598e 100644
--- a/core/src/main/scala/spark/SimpleShuffleFetcher.scala
+++ b/core/src/main/scala/spark/SimpleShuffleFetcher.scala
@@ -11,6 +11,7 @@ import scala.collection.mutable.HashMap
 class SimpleShuffleFetcher extends ShuffleFetcher with Logging {
   def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
     logInfo("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
+    val ser = SparkEnv.get.serializer.newInstance()
     val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
     val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId)
     for ((serverUri, index) <- serverUris.zipWithIndex) {
@@ -20,10 +21,9 @@ class SimpleShuffleFetcher extends ShuffleFetcher with Logging {
       for (i <- inputIds) {
         try {
           val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId)
-          // TODO: use Serializer instead of ObjectInputStream
           // TODO: multithreaded fetch
           // TODO: would be nice to retry multiple times
-          val inputStream = new ObjectInputStream(new URL(url).openStream())
+          val inputStream = ser.inputStream(new URL(url).openStream())
           try {
             while (true) {
               val pair = inputStream.readObject().asInstanceOf[(K, V)]
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index 1bfd0172d7438b79acd5707272d89ed6674317f5..ec48564d438d03ac594f159df96440ceeff7fc93 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -19,12 +19,10 @@ object SparkEnv {
   }
 
   def createFromSystemProperties(isMaster: Boolean): SparkEnv = {
-    val cacheClass = System.getProperty("spark.cache.class",
-      "spark.SoftReferenceCache")
+    val cacheClass = System.getProperty("spark.cache.class", "spark.SoftReferenceCache")
     val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
     
-    val serClass = System.getProperty("spark.serializer",
-      "spark.JavaSerializer")
+    val serClass = System.getProperty("spark.serializer", "spark.JavaSerializer")
     val ser = Class.forName(serClass).newInstance().asInstanceOf[Serializer]
 
     val cacheTracker = new CacheTracker(isMaster, cache)
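
Since both the cache and the serializer are now resolved from system properties in SparkEnv.createFromSystemProperties, switching a job onto the Kryo shuffle path is a configuration change rather than a code change. A minimal sketch, assuming the properties are set before the environment is created; the object name is illustrative, while the property and class names come from the patch:

    object ConfigureEnv {
      def main(args: Array[String]) {
        // Defaults shown above are spark.JavaSerializer and spark.SoftReferenceCache.
        System.setProperty("spark.serializer", "spark.KryoSerializer")
        // Raise the Kryo buffer if large objects are shuffled (see the first hunk).
        System.setProperty("spark.kryoserializer.buffer.mb", "8")
        // ... create the SparkContext / SparkEnv and run the job after this point ...
      }
    }
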