From 2d25e34814f81f288587f3277324cb655a5fb38d Mon Sep 17 00:00:00 2001
From: Ankur Dave <ankurdave@gmail.com>
Date: Wed, 23 Jul 2014 20:11:28 -0700
Subject: [PATCH] Replace RoutingTableMessage with pair

RoutingTableMessage was used to construct routing tables to enable
joining VertexRDDs with partitioned edges. It stored three elements: the
destination vertex ID, the source edge partition ID, and a byte
specifying the position (src, dst, or both) in which the edge partition
referenced the vertex, enabling join elimination.

However, this was incompatible with sort-based shuffle (SPARK-2045). It
was also slightly wasteful, because partition IDs are usually much
smaller than 2^32, though this was mitigated by a custom serializer that
used variable-length encoding.

This commit replaces RoutingTableMessage with a pair of (VertexId, Int)
where the Int encodes both the source partition ID (in the lower 30
bits) and the position (in the top 2 bits).
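
A minimal sketch of the packing in Scala, with illustrative values (the
toMessage / pidFromMessage / positionFromMessage helpers added below apply
the same bit manipulation):

    // Pack an illustrative vid/pid/position into a (VertexId, Int) message.
    // pid must fit in 30 bits; position is 0x1 (src), 0x2 (dst), or 0x3 (both).
    val vid: Long = 42L
    val pid: Int = 7
    val position: Byte = 0x3
    val packed: Int = (position << 30) | (pid & 0x3FFFFFFF)
    val msg: (Long, Int) = (vid, packed)

    // Unpack: only the low 2 bits of the position matter, so the sign
    // extension from the arithmetic shift is harmless.
    assert((msg._2 & 0x3FFFFFFF) == pid)
    assert(((msg._2 >> 30) & 0x3) == position)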

Author: Ankur Dave <ankurdave@gmail.com>

Closes #1553 from ankurdave/remove-RoutingTableMessage and squashes the following commits:

697e17b [Ankur Dave] Replace RoutingTableMessage with pair
---
 .../spark/graphx/GraphKryoRegistrator.scala   |  1 -
 .../graphx/impl/RoutingTablePartition.scala   | 47 +++++++++++--------
 .../spark/graphx/impl/Serializers.scala       | 16 +++----
 .../org/apache/spark/graphx/package.scala     |  2 +-
 4 files changed, 36 insertions(+), 30 deletions(-)

diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
index eea9fe9520..1948c978c3 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
@@ -35,7 +35,6 @@ class GraphKryoRegistrator extends KryoRegistrator {
 
   def registerClasses(kryo: Kryo) {
     kryo.register(classOf[Edge[Object]])
-    kryo.register(classOf[RoutingTableMessage])
     kryo.register(classOf[(VertexId, Object)])
     kryo.register(classOf[EdgePartition[Object, Object]])
     kryo.register(classOf[BitSet])
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
index 502b112d31..a565d3b28b 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
@@ -27,26 +27,13 @@ import org.apache.spark.util.collection.{BitSet, PrimitiveVector}
 import org.apache.spark.graphx._
 import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
 
-/**
- * A message from the edge partition `pid` to the vertex partition containing `vid` specifying that
- * the edge partition references `vid` in the specified `position` (src, dst, or both).
-*/
-private[graphx]
-class RoutingTableMessage(
-    var vid: VertexId,
-    var pid: PartitionID,
-    var position: Byte)
-  extends Product2[VertexId, (PartitionID, Byte)] with Serializable {
-  override def _1 = vid
-  override def _2 = (pid, position)
-  override def canEqual(that: Any): Boolean = that.isInstanceOf[RoutingTableMessage]
-}
+import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage
 
 private[graphx]
 class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) {
   /** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */
   def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = {
-    new ShuffledRDD[VertexId, (PartitionID, Byte), (PartitionID, Byte), RoutingTableMessage](
+    new ShuffledRDD[VertexId, Int, Int, RoutingTableMessage](
       self, partitioner).setSerializer(new RoutingTableMessageSerializer)
   }
 }
@@ -62,6 +49,23 @@ object RoutingTableMessageRDDFunctions {
 
 private[graphx]
 object RoutingTablePartition {
+  /**
+   * A message from an edge partition to a vertex specifying the position in which the edge
+   * partition references the vertex (src, dst, or both). The edge partition is encoded in the lower
+   * 30 bits of the Int, and the position is encoded in the upper 2 bits of the Int.
+   */
+  type RoutingTableMessage = (VertexId, Int)
+
+  private def toMessage(vid: VertexId, pid: PartitionID, position: Byte): RoutingTableMessage = {
+    val positionUpper2 = position << 30
+    val pidLower30 = pid & 0x3FFFFFFF
+    (vid, positionUpper2 | pidLower30)
+  }
+
+  private def vidFromMessage(msg: RoutingTableMessage): VertexId = msg._1
+  private def pidFromMessage(msg: RoutingTableMessage): PartitionID = msg._2 & 0x3FFFFFFF
+  private def positionFromMessage(msg: RoutingTableMessage): Byte = (msg._2 >> 30).toByte
+
   val empty: RoutingTablePartition = new RoutingTablePartition(Array.empty)
 
   /** Generate a `RoutingTableMessage` for each vertex referenced in `edgePartition`. */
@@ -77,7 +81,9 @@ object RoutingTablePartition {
       map.changeValue(dstId, 0x2, (b: Byte) => (b | 0x2).toByte)
     }
     map.iterator.map { vidAndPosition =>
-      new RoutingTableMessage(vidAndPosition._1, pid, vidAndPosition._2)
+      val vid = vidAndPosition._1
+      val position = vidAndPosition._2
+      toMessage(vid, pid, position)
     }
   }
 
@@ -88,9 +94,12 @@ object RoutingTablePartition {
     val srcFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean])
     val dstFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean])
     for (msg <- iter) {
-      pid2vid(msg.pid) += msg.vid
-      srcFlags(msg.pid) += (msg.position & 0x1) != 0
-      dstFlags(msg.pid) += (msg.position & 0x2) != 0
+      val vid = vidFromMessage(msg)
+      val pid = pidFromMessage(msg)
+      val position = positionFromMessage(msg)
+      pid2vid(pid) += vid
+      srcFlags(pid) += (position & 0x1) != 0
+      dstFlags(pid) += (position & 0x2) != 0
     }
 
     new RoutingTablePartition(pid2vid.zipWithIndex.map {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
index 2d98c24d69..3909efcdfc 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
@@ -24,9 +24,11 @@ import java.nio.ByteBuffer
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.graphx._
 import org.apache.spark.serializer._
 
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage
+
 private[graphx]
 class RoutingTableMessageSerializer extends Serializer with Serializable {
   override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
@@ -35,10 +37,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable {
       new ShuffleSerializationStream(s) {
         def writeObject[T: ClassTag](t: T): SerializationStream = {
           val msg = t.asInstanceOf[RoutingTableMessage]
-          writeVarLong(msg.vid, optimizePositive = false)
-          writeUnsignedVarInt(msg.pid)
-          // TODO: Write only the bottom two bits of msg.position
-          s.write(msg.position)
+          writeVarLong(msg._1, optimizePositive = false)
+          writeInt(msg._2)
           this
         }
       }
@@ -47,10 +47,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable {
       new ShuffleDeserializationStream(s) {
         override def readObject[T: ClassTag](): T = {
           val a = readVarLong(optimizePositive = false)
-          val b = readUnsignedVarInt()
-          val c = s.read()
-          if (c == -1) throw new EOFException
-          new RoutingTableMessage(a, b, c.toByte).asInstanceOf[T]
+          val b = readInt()
+          (a, b).asInstanceOf[T]
         }
       }
   }
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/package.scala b/graphx/src/main/scala/org/apache/spark/graphx/package.scala
index ff17edeaf8..6aab28ff05 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/package.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/package.scala
@@ -30,7 +30,7 @@ package object graphx {
    */
   type VertexId = Long
 
-  /** Integer identifer of a graph partition. */
+  /** Integer identifier of a graph partition. Must be less than 2^30. */
   // TODO: Consider using Char.
   type PartitionID = Int
 
-- 
GitLab