From c7e348faec45ad1d996d16639015c4bc4fc3bc92 Mon Sep 17 00:00:00 2001
From: Andre Schumacher <schumach@icsi.berkeley.edu>
Date: Thu, 15 Aug 2013 16:01:19 -0700
Subject: [PATCH] Implementing SPARK-878 for PySpark: adding zip and egg files
 to the context and passing them down to workers, which add them to their
 sys.path

---
 .../main/scala/spark/api/python/PythonRDD.scala  |   9 ++++++++-
 python/pyspark/context.py                        |  14 +++++++++++---
 python/pyspark/rdd.py                            |   4 +++-
 python/pyspark/tests.py                          |  11 +++++++++++
 python/pyspark/worker.py                         |  13 ++++++++++++-
 python/test_support/userlib-0.1-py2.7.egg        | Bin 0 -> 1945 bytes
 6 files changed, 45 insertions(+), 6 deletions(-)
 create mode 100644 python/test_support/userlib-0.1-py2.7.egg

diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 2dd79f7100..49671437d0 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -33,6 +33,7 @@ private[spark] class PythonRDD[T: ClassManifest](
     parent: RDD[T],
     command: Seq[String],
     envVars: JMap[String, String],
+    pythonIncludes: JList[String],
     preservePartitoning: Boolean,
     pythonExec: String,
     broadcastVars: JList[Broadcast[Array[Byte]]],
@@ -44,10 +45,11 @@ private[spark] class PythonRDD[T: ClassManifest](
   // Similar to Runtime.exec(), if we are given a single string, split it into words
   // using a standard StringTokenizer (i.e. by spaces)
   def this(parent: RDD[T], command: String, envVars: JMap[String, String],
+      pythonIncludes: JList[String],
       preservePartitoning: Boolean, pythonExec: String,
       broadcastVars: JList[Broadcast[Array[Byte]]],
       accumulator: Accumulator[JList[Array[Byte]]]) =
-    this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec,
+    this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec,
       broadcastVars, accumulator)
 
   override def getPartitions = parent.partitions
@@ -79,6 +81,11 @@ private[spark] class PythonRDD[T: ClassManifest](
             dataOut.writeInt(broadcast.value.length)
             dataOut.write(broadcast.value)
           }
+          // Python includes (*.zip and *.egg files)
+          dataOut.writeInt(pythonIncludes.length)
+          for (f <- pythonIncludes) {
+            PythonRDD.writeAsPickle(f, dataOut)
+          }
           dataOut.flush()
           // Serialized user code
           for (elem <- command) {
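
With this hunk, the per-task stream that the JVM writes to each Python worker
gains one more section, between the broadcast variables and the serialized
command: an integer count followed by one length-prefixed pickled filename per
include. As a standalone illustration of that framing (the real read side is
the worker.py change below; helper names are modeled on pyspark.serializers,
and the layout shown covers only this new section):

    import struct
    import cPickle

    def read_int(stream):
        # 4-byte big-endian int, matching java.io.DataOutputStream.writeInt
        return struct.unpack("!i", stream.read(4))[0]

    def read_pickled_string(stream):
        # writeAsPickle emits a length prefix followed by the pickled payload
        length = read_int(stream)
        return cPickle.loads(stream.read(length))

    def read_includes(stream):
        # count written by dataOut.writeInt, then one pickled name per include
        return [read_pickled_string(stream) for _ in range(read_int(stream))]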
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index c2b49ff37a..2803ce90f3 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -46,6 +46,7 @@ class SparkContext(object):
     _next_accum_id = 0
     _active_spark_context = None
     _lock = Lock()
+    _python_includes = None # zip and egg files that need to be added to PYTHONPATH
 
     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
         environment=None, batchSize=1024):
@@ -103,11 +104,14 @@ class SparkContext(object):
         # send.
         self._pickled_broadcast_vars = set()
 
+        SparkFiles._sc = self
+        root_dir = SparkFiles.getRootDirectory()
+        sys.path.append(root_dir)
+
         # Deploy any code dependencies specified in the constructor
+        self._python_includes = list()
         for path in (pyFiles or []):
             self.addPyFile(path)
-        SparkFiles._sc = self
-        sys.path.append(SparkFiles.getRootDirectory())
 
         # Create a temporary directory inside spark.local.dir:
         local_dir = self._jvm.spark.Utils.getLocalDir()
@@ -257,7 +261,11 @@ class SparkContext(object):
         HTTP, HTTPS or FTP URI.
         """
         self.addFile(path)
-        filename = path.split("/")[-1]
+        (dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix
+
+        if filename.lower().endswith(('.zip', '.egg')):
+            self._python_includes.append(filename)
+            sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename)) # for tests in local mode
 
     def setCheckpointDir(self, dirName, useExisting=False):
         """
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 51c2cb9806..99f5967a8e 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -758,8 +758,10 @@ class PipelinedRDD(RDD):
         class_manifest = self._prev_jrdd.classManifest()
         env = MapConverter().convert(self.ctx.environment,
                                      self.ctx._gateway._gateway_client)
+        includes = ListConverter().convert(self.ctx._python_includes,
+                                     self.ctx._gateway._gateway_client)
         python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
-            pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
+            pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec,
             broadcast_vars, self.ctx._javaAccumulator, class_manifest)
         self._jrdd_val = python_rdd.asJavaRDD()
         return self._jrdd_val
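
The new ListConverter call mirrors the MapConverter pattern already used for
the environment: Py4J does not convert Python lists automatically here, so the
includes must be converted explicitly into a java.util.List before being passed
to the PythonRDD constructor. A self-contained sketch of that conversion (the
gateway setup and the entry name are illustrative):

    from py4j.java_gateway import JavaGateway
    from py4j.java_collections import ListConverter

    gateway = JavaGateway()  # assumes a Py4J GatewayServer is running on the JVM side
    includes = ListConverter().convert(["dep-0.1-py2.7.egg"],   # illustrative name
                                       gateway._gateway_client)
    # `includes` can now be handed to any JVM method expecting java.util.List<String>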
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index f75215a781..29d6a128f6 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -125,6 +125,17 @@ class TestAddFile(PySparkTestCase):
         from userlibrary import UserClass
         self.assertEqual("Hello World!", UserClass().hello())
 
+    def test_add_egg_file_locally(self):
+        # To ensure that we're actually testing addPyFile's effects, check that
+        # this fails due to `userlib` not being on the Python path:
+        def func():
+            from userlib import UserClass
+        self.assertRaises(ImportError, func)
+        path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1-py2.7.egg")
+        self.sc.addPyFile(path)
+        from userlib import UserClass
+        self.assertEqual("Hello World from inside a package!", UserClass().hello())
+
 
 class TestIO(PySparkTestCase):
 
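The checked-in userlib-0.1-py2.7.egg is a binary fixture, but an equivalent egg
could be produced from a trivial setuptools project; a sketch of such a
setup.py (illustrative, not the script that generated the fixture):

    # setup.py -- build with: python setup.py bdist_egg
    from setuptools import setup, find_packages

    setup(
        name="userlib",
        version="0.1",
        packages=find_packages(),  # expects a userlib/ package with UserClass
    )
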
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 75d692beeb..695f6dfb84 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -49,15 +49,26 @@ def main(infile, outfile):
     split_index = read_int(infile)
     if split_index == -1:  # for unit tests
         return
+
+    # fetch name of workdir
     spark_files_dir = load_pickle(read_with_length(infile))
     SparkFiles._root_directory = spark_files_dir
     SparkFiles._is_running_on_worker = True
-    sys.path.append(spark_files_dir)
+
+    # fetch names and values of broadcast variables
     num_broadcast_variables = read_int(infile)
     for _ in range(num_broadcast_variables):
         bid = read_long(infile)
         value = read_with_length(infile)
         _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
+
+    # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
+    sys.path.append(spark_files_dir) # *.py files that were added will be copied here
+    num_python_includes = read_int(infile)
+    for _ in range(num_python_includes):
+        sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile))))
+
+    # now load the user-supplied function
     func = load_obj(infile)
     bypassSerializer = load_obj(infile)
     if bypassSerializer:
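
After this change, a worker that receives one egg ends up with both the
SparkFiles root and the archive itself on its import path; Python's zipimport
machinery then makes packages inside the .egg/.zip importable directly. A small
sketch of the resulting layout, with illustrative names:

    import os
    import sys

    spark_files_dir = "/tmp/spark-files"              # illustrative SparkFiles root
    includes = ["userlib-0.1-py2.7.egg"]              # names received from the JVM

    sys.path.append(spark_files_dir)                  # plain *.py dependencies land here
    for filename in includes:
        # zipimport makes the archive's contents importable once it is on sys.path
        sys.path.append(os.path.join(spark_files_dir, filename))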
diff --git a/python/test_support/userlib-0.1-py2.7.egg b/python/test_support/userlib-0.1-py2.7.egg
new file mode 100644
index 0000000000000000000000000000000000000000..1674c9cb2227e160aec55c74e9ba1ddba850f580
GIT binary patch
literal 1945
zcmWIWW@Zs#U|`^2Sdq%-{OPC4zkVQZ3lQ@IacOaCQBG!(er{rBo?bzvZr?^OCPM+2
z-!(^Pyju5>Ifws9kif>ilh5?x&a}Ph`t|+UfpA{q4-(mHcXF|6ID3k0S`rmJ^Pu27
zX2pP|PDj;+_Dt5>5pw-nXZI!d?|Lntzt1q@eSGhRp0S;^U`hJyZ6+W7I>sGuc2o~I
z;$-<*VantEcQTf&4@>aB-66eQMnKwrV%L&?hZb7x-!-%B+8>4hxZ7LO_?&A_9oL=$
zbn|W?2Kfl)_W1bByv&mLc%b`}lmDGLz4eV;d4LhKhymk9W&!bThE5TMq?8kh2`5rh
zI+9Wnd=is5t|hQ}rTk-O;&oG4-6HJKBDP%2^0|tb`0vbyFVi<Lbly2~W{<*zPN@|x
z8|5Z=Xq`RddC}kNq>kRj(<jepuf1^AWV69#e_#LW`mD32PM*0$UVi?}>5)S6>%#*g
z4>4Xznml`c(5%T>?3=boqzIt-qG4rZelQ~gLjn^6g8-5*pfQlVbi#eF!v-Slm#_J*
zt#~NyvVK)UwfP<aL3Sl2=O~vuzMiu$>*@J=O@3Im_WVolA58N~>g61`pSxe0_vE)<
z^ZdsjUjEk5J^74tL*7Q7qYqkq3~ia@R@xk@{p5N=T`@m+{)dK?i}POB9N~!PDabP2
znxJE9d~uu8)sXNlS7bEW#Y9#FH7t}rbS~rCraZfArw&~@*7N!K)x8r0<{n>Yd*t<p
zZFxE~4XyS0U!S|^XTY%6CcWKi#iX=}zhkOq+l41zfBofxU3X24&l$;Q%pNOWu70Po
zJaoCV))c{v6B(yZS+0I)Qncfr)9rrtO;P)c67>waS|`|O?l3T$=A=7qdQyq0*MfDE
zS-ZRgRd|K9^1{m+-Ws>d>^fX$@A>54M%x1?ijPX2D3<;<+wF7fk=$4Lh4pO9-<?S~
zEmJkwv}FJDf>@=?k}n_0dd;3x|M^$py6en$SghCm=SC!>>%c6)9BG#!3k-N^AQnc7
z0HkE(t$pfz=n3C5XM#_h(LJHB7099Mr(1N+Q%CO^6IygS{lCb@1vG^Nh{b@|)!kj!
z)6dOcza+mPJ}0#-HAk<cqC^|0|I|sJN9ejIvt2i02I>N)2L?f8-Bp<dy2XiUsgLpL
z0QmxLUrK60YF<ieUUFr8PG(+qG00jB-8W}?^dtgxtpH+B7J};z_74qmb`1u*)lXMb
z_k#9{Ct46kU+V8af9AAKXTQ7HX<(2zKYG1%+49$E&z3%Y&0ewlb=9w(Ka)?f{BCEn
z+|BlSYRHrXuVqWbmONX!<e6&tvF9nzQpJ<QLnp-x%%2=JZ5Hz+HnaeIaORC_DbOjK
zffywS0=yw1`Po<3)5lZC^9rxGuGYCT=Qjr#Trqy|$@`4&)*uaNxHfr9zK|?po&M~Z
znpkL*X3MY2-#^=b7gj!bzU-Bl%ChB3Ei&HTMxVMb{M@y2XXi{yj==e6o?O|Jq>`(*
z^5{v|nXIY_yAzriEjv{+f4c7E5CpoBkx7IZ6hzp|O`rrAY-t2hu#yOxhmkd7E4Uz9
zfrW769wg03=`R`G1oT1!VL}ry>7ZGUq8nR^N9bk-CO(*MB>T~=M^EGk0|J2tz!MQl
zl1DcKJ*gwi=tnjKmhkau2c>%$*wT0qSv$5|fNm6eCO{bVAK56REP-wUdYVO;Fo^{z
YJYXp}z?+o~q=XX)eSyWk87qhf04R`+K>z>%

literal 0
HcmV?d00001

-- 
GitLab