From 35168d9c89904f0dc0bb470c1799f5ca3b04221f Mon Sep 17 00:00:00 2001
From: Josh Rosen <joshrosen@eecs.berkeley.edu>
Date: Tue, 22 Jan 2013 17:54:11 -0800
Subject: [PATCH] Fix sys.path bug in PySpark SparkContext.addPyFile
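
The old implementation prepended the dependency's bare filename to the
PYTHONPATH environment variable in the driver process. This never
affected the separately-launched Python worker processes, and the bare
filename is not a resolvable import path on the workers anyway, so
modules shipped with addPyFile() could not actually be imported inside
tasks. The worker now appends the SparkFiles root directory (where
files added through addFile()/addPyFile() are downloaded) to its own
sys.path before executing tasks, so the usual pattern works, for
example (mirroring the new regression test; the path is illustrative):

    sc.addPyFile("/path/to/userlibrary.py")    # ship the module

    def func(x):
        from userlibrary import UserClass      # importable on workers now
        return UserClass().hello()

    sc.parallelize(range(2)).map(func).collect()

Also extracts the SparkContext setUp/tearDown logic from TestCheckpoint
into a reusable PySparkTestCase base class, and adds a regression test
plus a python/test_support/userlibrary.py fixture to exercise
addPyFile() end to end.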

---
 python/pyspark/context.py          |  2 --
 python/pyspark/tests.py            | 38 ++++++++++++++++++++++++++----
 python/pyspark/worker.py           |  1 +
 python/test_support/userlibrary.py |  7 ++++++
 4 files changed, 41 insertions(+), 7 deletions(-)
 create mode 100755 python/test_support/userlibrary.py

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index ec0cc7c2f9..b8d7dc05af 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -215,8 +215,6 @@ class SparkContext(object):
         """
         self.addFile(path)
         filename = path.split("/")[-1]
-        os.environ["PYTHONPATH"] = \
-            "%s:%s" % (filename, os.environ["PYTHONPATH"])
 
     def setCheckpointDir(self, dirName, useExisting=False):
         """
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index b0a403b580..4d70ee4f12 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -9,21 +9,32 @@ import time
 import unittest
 
 from pyspark.context import SparkContext
+from pyspark.java_gateway import SPARK_HOME
 
 
-class TestCheckpoint(unittest.TestCase):
+class PySparkTestCase(unittest.TestCase):
 
     def setUp(self):
-        self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2)
-        self.checkpointDir = NamedTemporaryFile(delete=False)
-        os.unlink(self.checkpointDir.name)
-        self.sc.setCheckpointDir(self.checkpointDir.name)
+        class_name = self.__class__.__name__
+        self.sc = SparkContext('local[4]', class_name, batchSize=2)
 
     def tearDown(self):
         self.sc.stop()
         # To avoid Akka rebinding to the same port, since it doesn't unbind
         # immediately on shutdown
         self.sc.jvm.System.clearProperty("spark.master.port")
+
+
+class TestCheckpoint(PySparkTestCase):
+
+    def setUp(self):
+        PySparkTestCase.setUp(self)
+        self.checkpointDir = NamedTemporaryFile(delete=False)
+        os.unlink(self.checkpointDir.name)
+        self.sc.setCheckpointDir(self.checkpointDir.name)
+
+    def tearDown(self):
+        PySparkTestCase.tearDown(self)
         shutil.rmtree(self.checkpointDir.name)
 
     def test_basic_checkpointing(self):
@@ -57,5 +68,22 @@ class TestCheckpoint(unittest.TestCase):
         self.assertEquals([1, 2, 3, 4], recovered.collect())
 
 
+class TestAddFile(PySparkTestCase):
+
+    def test_add_py_file(self):
+        # To ensure that we're actually testing addPyFile's effects, check that
+        # this job fails due to `userlibrary` not being on the Python path:
+        def func(x):
+            from userlibrary import UserClass
+            return UserClass().hello()
+        self.assertRaises(Exception,
+                          self.sc.parallelize(range(2)).map(func).first)
+        # Add the file, so the job should now succeed:
+        path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
+        self.sc.addPyFile(path)
+        res = self.sc.parallelize(range(2)).map(func).first()
+        self.assertEqual("Hello World!", res)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index e7bdb7682b..4bf643da66 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -26,6 +26,7 @@ def main():
     split_index = read_int(sys.stdin)
     spark_files_dir = load_pickle(read_with_length(sys.stdin))
     SparkFiles._root_directory = spark_files_dir
+    sys.path.append(spark_files_dir)
     num_broadcast_variables = read_int(sys.stdin)
     for _ in range(num_broadcast_variables):
         bid = read_long(sys.stdin)
diff --git a/python/test_support/userlibrary.py b/python/test_support/userlibrary.py
new file mode 100755
index 0000000000..5bb6f5009f
--- /dev/null
+++ b/python/test_support/userlibrary.py
@@ -0,0 +1,7 @@
+"""
+Used to test shipping of code dependencies with SparkContext.addPyFile().
+"""
+
+class UserClass(object):
+    def hello(self):
+        return "Hello World!"
-- 
GitLab