From f24bfd2dd1f5c271b05ac9f166b9d1b6d938a440 Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@cs.berkeley.edu>
Date: Tue, 27 Nov 2012 19:20:45 -0800
Subject: [PATCH] For size compression, compress non zero values into non zero
 values.

---
 .../main/scala/spark/MapOutputTracker.scala   | 29 ++++++++++---------
 .../scala/spark/MapOutputTrackerSuite.scala   |  4 +--
 2 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index 45441aa5e5..fcf725a255 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -2,6 +2,10 @@ package spark
 
 import java.io._
 import java.util.concurrent.ConcurrentHashMap
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
 
 import akka.actor._
 import akka.dispatch._
@@ -11,16 +15,13 @@ import akka.util.Duration
 import akka.util.Timeout
 import akka.util.duration._
 
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-
-import scheduler.MapStatus
+import spark.scheduler.MapStatus
 import spark.storage.BlockManagerId
-import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+
 
 private[spark] sealed trait MapOutputTrackerMessage
 private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String)
-  extends MapOutputTrackerMessage 
+  extends MapOutputTrackerMessage
 private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
 
 private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
@@ -88,14 +89,14 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
     }
     mapStatuses.put(shuffleId, new Array[MapStatus](numMaps))
   }
-  
+
   def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
     var array = mapStatuses.get(shuffleId)
     array.synchronized {
       array(mapId) = status
     }
   }
-  
+
   def registerMapOutputs(
       shuffleId: Int,
       statuses: Array[MapStatus],
@@ -119,10 +120,10 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
       throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
     }
   }
-  
+
   // Remembers which map output locations are currently being fetched on a worker
   val fetching = new HashSet[Int]
-  
+
   // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
   def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
     val statuses = mapStatuses.get(shuffleId)
@@ -149,7 +150,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
       val host = System.getProperty("spark.hostname", Utils.localHostName)
       val fetchedBytes = askTracker(GetMapOutputStatuses(shuffleId, host)).asInstanceOf[Array[Byte]]
       val fetchedStatuses = deserializeStatuses(fetchedBytes)
-      
+
       logInfo("Got the output locations")
       mapStatuses.put(shuffleId, fetchedStatuses)
       fetching.synchronized {
@@ -254,8 +255,10 @@ private[spark] object MapOutputTracker {
    * sizes up to 35 GB with at most 10% error.
    */
   def compressSize(size: Long): Byte = {
-    if (size <= 1L) {
+    if (size == 0) {
       0
+    } else if (size <= 1L) {
+      1
     } else {
       math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte
     }
@@ -266,7 +269,7 @@ private[spark] object MapOutputTracker {
    */
   def decompressSize(compressedSize: Byte): Long = {
     if (compressedSize == 0) {
-      1
+      0
     } else {
       math.pow(LOG_BASE, (compressedSize & 0xFF)).toLong
     }
diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
index 4e9717d871..dee45b6e8f 100644
--- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
@@ -5,7 +5,7 @@ import org.scalatest.FunSuite
 class MapOutputTrackerSuite extends FunSuite {
   test("compressSize") {
     assert(MapOutputTracker.compressSize(0L) === 0)
-    assert(MapOutputTracker.compressSize(1L) === 0)
+    assert(MapOutputTracker.compressSize(1L) === 1)
     assert(MapOutputTracker.compressSize(2L) === 8)
     assert(MapOutputTracker.compressSize(10L) === 25)
     assert((MapOutputTracker.compressSize(1000000L) & 0xFF) === 145)
@@ -15,7 +15,7 @@ class MapOutputTrackerSuite extends FunSuite {
   }
 
   test("decompressSize") {
-    assert(MapOutputTracker.decompressSize(0) === 1)
+    assert(MapOutputTracker.decompressSize(0) === 0)
     for (size <- Seq(2L, 10L, 100L, 50000L, 1000000L, 1000000000L)) {
       val size2 = MapOutputTracker.decompressSize(MapOutputTracker.compressSize(size))
       assert(size2 >= 0.99 * size && size2 <= 1.11 * size,
-- 
GitLab