Commit 35168d9c authored by Josh Rosen

Fix sys.path bug in PySpark SparkContext.addPyFile

parent 7b9e96c9
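The bug: addPyFile tried to make a shipped file importable by prepending its basename to the PYTHONPATH environment variable on the driver. PYTHONPATH is only consulted at interpreter startup, so this changed nothing for the driver's sys.path, let alone for worker processes that are already running. The fix removes that code and instead has each worker append the SparkFiles root directory (where addFile/addPyFile downloads land) to sys.path.

A hedged usage sketch of what the fix enables (not part of the commit; the file path is hypothetical):

    from pyspark.context import SparkContext

    sc = SparkContext('local[4]', 'AddPyFileExample')
    # Ship a dependency to every worker; after this commit the worker puts
    # its download directory on sys.path, so the module is importable in tasks.
    sc.addPyFile('/path/to/userlibrary.py')

    def func(x):
        from userlibrary import UserClass   # resolved on the worker
        return UserClass().hello()

    print(sc.parallelize(range(2)).map(func).first())   # Hello World!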
python/pyspark/context.py
@@ -215,8 +215,6 @@ class SparkContext(object):
         """
         self.addFile(path)
         filename = path.split("/")[-1]
-        os.environ["PYTHONPATH"] = \
-            "%s:%s" % (filename, os.environ["PYTHONPATH"])
 
     def setCheckpointDir(self, dirName, useExisting=False):
         """
...
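Why the removed lines never worked — a minimal illustration, assuming standard CPython, where PYTHONPATH is read only when the interpreter starts:

    import os
    import sys

    # Mutating PYTHONPATH in a running interpreter does not touch sys.path,
    # and it certainly cannot reach worker processes that already started.
    os.environ["PYTHONPATH"] = "/tmp/somewhere:" + os.environ.get("PYTHONPATH", "")
    print("/tmp/somewhere" in sys.path)   # False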
python/pyspark/tests.py
@@ -9,21 +9,32 @@ import time
 import unittest
 
 from pyspark.context import SparkContext
+from pyspark.java_gateway import SPARK_HOME
 
 
-class TestCheckpoint(unittest.TestCase):
+class PySparkTestCase(unittest.TestCase):
 
     def setUp(self):
-        self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2)
-        self.checkpointDir = NamedTemporaryFile(delete=False)
-        os.unlink(self.checkpointDir.name)
-        self.sc.setCheckpointDir(self.checkpointDir.name)
+        class_name = self.__class__.__name__
+        self.sc = SparkContext('local[4]', class_name, batchSize=2)
 
     def tearDown(self):
         self.sc.stop()
         # To avoid Akka rebinding to the same port, since it doesn't unbind
         # immediately on shutdown
         self.sc.jvm.System.clearProperty("spark.master.port")
+
+
+class TestCheckpoint(PySparkTestCase):
+
+    def setUp(self):
+        PySparkTestCase.setUp(self)
+        self.checkpointDir = NamedTemporaryFile(delete=False)
+        os.unlink(self.checkpointDir.name)
+        self.sc.setCheckpointDir(self.checkpointDir.name)
+
+    def tearDown(self):
+        PySparkTestCase.tearDown(self)
         shutil.rmtree(self.checkpointDir.name)
...
@@ -57,5 +68,22 @@ class TestCheckpoint(unittest.TestCase):
         self.assertEquals([1, 2, 3, 4], recovered.collect())
 
 
+class TestAddFile(PySparkTestCase):
+
+    def test_add_py_file(self):
+        # To ensure that we're actually testing addPyFile's effects, check that
+        # this job fails due to `userlibrary` not being on the Python path:
+        def func(x):
+            from userlibrary import UserClass
+            return UserClass().hello()
+        self.assertRaises(Exception,
+                          self.sc.parallelize(range(2)).map(func).first)
+        # Add the file, so the job should now succeed:
+        path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
+        self.sc.addPyFile(path)
+        res = self.sc.parallelize(range(2)).map(func).first()
+        self.assertEqual("Hello World!", res)
+
+
 if __name__ == "__main__":
     unittest.main()
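A note on the test's structure: map is a lazy transformation, so the missing-module failure only surfaces when an action forces evaluation. That is why assertRaises is handed the bound method self.sc.parallelize(range(2)).map(func).first as its callable, rather than the (already raised) result of calling it.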
python/pyspark/worker.py
@@ -26,6 +26,7 @@ def main():
     split_index = read_int(sys.stdin)
     spark_files_dir = load_pickle(read_with_length(sys.stdin))
     SparkFiles._root_directory = spark_files_dir
+    sys.path.append(spark_files_dir)
     num_broadcast_variables = read_int(sys.stdin)
     for _ in range(num_broadcast_variables):
         bid = read_long(sys.stdin)
...
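This one-line worker change is the actual fix: every file shipped with addPyFile lands in spark_files_dir, so putting that directory on sys.path makes the shipped modules importable from task code. A minimal sketch of the mechanism (the directory name is hypothetical):

    import sys

    spark_files_dir = "/tmp/spark-files-123"   # hypothetical download dir
    sys.path.append(spark_files_dir)
    # From here on, `import userlibrary` inside a task resolves to
    # /tmp/spark-files-123/userlibrary.py, if that file was shipped.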
"""
Used to test shipping of code depenencies with SparkContext.addPyFile().
"""
class UserClass(object):
def hello(self):
return "Hello World!"