From 52989c8a2c8c10d7f5610c033f6782e58fd3abc2 Mon Sep 17 00:00:00 2001
From: Josh Rosen <joshrosen@eecs.berkeley.edu>
Date: Fri, 19 Oct 2012 10:24:49 -0700
Subject: [PATCH] Update Python API for v0.6.0 compatibility.

---
 .../scala/spark/api/python/PythonRDD.scala     | 18 +++++++++++-------
 .../main/scala/spark/broadcast/Broadcast.scala |  2 +-
 pyspark/pyspark/broadcast.py                   | 18 +++++++++---------
 pyspark/pyspark/context.py                     |  2 +-
 pyspark/pyspark/java_gateway.py                |  3 ++-
 pyspark/pyspark/serializers.py                 | 18 ++++++++++++++----
 pyspark/pyspark/worker.py                      |  8 ++++----
 7 files changed, 42 insertions(+), 27 deletions(-)

diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 4d3bdb3963..528885fe5c 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -5,11 +5,15 @@ import java.io._
 import scala.collection.Map
 import scala.collection.JavaConversions._
 import scala.io.Source
-import spark._
-import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
-import broadcast.Broadcast
-import scala.collection
-import java.nio.charset.Charset
+
+import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
+import spark.broadcast.Broadcast
+import spark.SparkEnv
+import spark.Split
+import spark.RDD
+import spark.OneToOneDependency
+import spark.rdd.PipedRDD
+
 
 trait PythonRDDBase {
   def compute[T](split: Split, envVars: Map[String, String],
@@ -43,9 +47,9 @@ trait PythonRDDBase {
         SparkEnv.set(env)
         val out = new PrintWriter(proc.getOutputStream)
         val dOut = new DataOutputStream(proc.getOutputStream)
-        out.println(broadcastVars.length)
+        dOut.writeInt(broadcastVars.length)
         for (broadcast <- broadcastVars) {
-          out.print(broadcast.uuid.toString)
+          dOut.writeLong(broadcast.id)
           dOut.writeInt(broadcast.value.length)
           dOut.write(broadcast.value)
           dOut.flush()
diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala
index 6055bfd045..2ffe7f741d 100644
--- a/core/src/main/scala/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/spark/broadcast/Broadcast.scala
@@ -5,7 +5,7 @@ import java.util.concurrent.atomic.AtomicLong
 
 import spark._
 
-abstract class Broadcast[T](id: Long) extends Serializable {
+abstract class Broadcast[T](private[spark] val id: Long) extends Serializable {
   def value: T
 
   // We cannot have an abstract readObject here due to some weird issues with
diff --git a/pyspark/pyspark/broadcast.py b/pyspark/pyspark/broadcast.py
index 1ea17d59af..4cff02b36d 100644
--- a/pyspark/pyspark/broadcast.py
+++ b/pyspark/pyspark/broadcast.py
@@ -6,7 +6,7 @@
 [1, 2, 3, 4, 5]
 
 >>> from pyspark.broadcast import _broadcastRegistry
->>> _broadcastRegistry[b.uuid] = b
+>>> _broadcastRegistry[b.bid] = b
 >>> from cPickle import dumps, loads
 >>> loads(dumps(b)).value
 [1, 2, 3, 4, 5]
@@ -14,27 +14,27 @@
 >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
 [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
 """
-# Holds broadcasted data received from Java, keyed by UUID.
+# Holds broadcasted data received from Java, keyed by its id.
 _broadcastRegistry = {}
 
 
-def _from_uuid(uuid):
+def _from_id(bid):
     from pyspark.broadcast import _broadcastRegistry
-    if uuid not in _broadcastRegistry:
-        raise Exception("Broadcast variable '%s' not loaded!" % uuid)
-    return _broadcastRegistry[uuid]
+    if bid not in _broadcastRegistry:
+        raise Exception("Broadcast variable '%s' not loaded!" % bid)
+    return _broadcastRegistry[bid]
 
 
 class Broadcast(object):
-    def __init__(self, uuid, value, java_broadcast=None, pickle_registry=None):
+    def __init__(self, bid, value, java_broadcast=None, pickle_registry=None):
         self.value = value
-        self.uuid = uuid
+        self.bid = bid
         self._jbroadcast = java_broadcast
         self._pickle_registry = pickle_registry
 
     def __reduce__(self):
         self._pickle_registry.add(self)
-        return (_from_uuid, (self.uuid, ))
+        return (_from_id, (self.bid, ))
 
 
 def _test():
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index 04932c93f2..3f4db26644 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -66,5 +66,5 @@ class SparkContext(object):
 
     def broadcast(self, value):
         jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
-        return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast,
+        return Broadcast(jbroadcast.id(), value, jbroadcast,
                          self._pickled_broadcast_vars)
diff --git a/pyspark/pyspark/java_gateway.py b/pyspark/pyspark/java_gateway.py
index bcb405ba72..3726bcbf17 100644
--- a/pyspark/pyspark/java_gateway.py
+++ b/pyspark/pyspark/java_gateway.py
@@ -7,7 +7,8 @@ SPARK_HOME = os.environ["SPARK_HOME"]
 
 
 assembly_jar = glob.glob(os.path.join(SPARK_HOME, "core/target") + \
-    "/spark-core-assembly-*-SNAPSHOT.jar")[0]
+    "/spark-core-assembly-*.jar")[0]
+    # TODO: what if multiple assembly jars are found?
 
 
 def launch_gateway():
diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py
index faa1e683c7..21ef8b106c 100644
--- a/pyspark/pyspark/serializers.py
+++ b/pyspark/pyspark/serializers.py
@@ -9,16 +9,26 @@ def dump_pickle(obj):
 load_pickle = cPickle.loads
 
 
+def read_long(stream):
+    length = stream.read(8)
+    if length == "":
+        raise EOFError
+    return struct.unpack("!q", length)[0]
+
+
+def read_int(stream):
+    length = stream.read(4)
+    if length == "":
+        raise EOFError
+    return struct.unpack("!i", length)[0]
+
 def write_with_length(obj, stream):
     stream.write(struct.pack("!i", len(obj)))
     stream.write(obj)
 
 
 def read_with_length(stream):
-    length = stream.read(4)
-    if length == "":
-        raise EOFError
-    length = struct.unpack("!i", length)[0]
+    length = read_int(stream)
     obj = stream.read(length)
     if obj == "":
         raise EOFError
diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py
index a9ed71892f..62824a1c9b 100644
--- a/pyspark/pyspark/worker.py
+++ b/pyspark/pyspark/worker.py
@@ -8,7 +8,7 @@ from base64 import standard_b64decode
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
 from pyspark.serializers import write_with_length, read_with_length, \
-    dump_pickle, load_pickle
+    read_long, read_int, dump_pickle, load_pickle
 
 
 # Redirect stdout to stderr so that users must return values from functions.
@@ -29,11 +29,11 @@ def read_input():
 
 
 def main():
-    num_broadcast_variables = int(sys.stdin.readline().strip())
+    num_broadcast_variables = read_int(sys.stdin)
     for _ in range(num_broadcast_variables):
-        uuid = sys.stdin.read(36)
+        bid = read_long(sys.stdin)
         value = read_with_length(sys.stdin)
-        _broadcastRegistry[uuid] = Broadcast(uuid, load_pickle(value))
+        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
     func = load_obj()
     bypassSerializer = load_obj()
     if bypassSerializer:
-- 
GitLab