diff --git a/core/src/main/scala/spark/SparkFiles.java b/core/src/main/scala/spark/SparkFiles.java index b59d8ce93f23d96ba2bfcd82e5a328c765ea1fb3..566aec622c096076f3d91dbd433297d9f2c53a8e 100644 --- a/core/src/main/scala/spark/SparkFiles.java +++ b/core/src/main/scala/spark/SparkFiles.java @@ -3,23 +3,23 @@ package spark; import java.io.File; /** - * Resolves paths to files added through `addFile(). + * Resolves paths to files added through `SparkContext.addFile()`. */ public class SparkFiles { private SparkFiles() {} /** - * Get the absolute path of a file added through `addFile()`. + * Get the absolute path of a file added through `SparkContext.addFile()`. */ public static String get(String filename) { return new File(getRootDirectory(), filename).getAbsolutePath(); } /** - * Get the root directory that contains files added through `addFile()`. + * Get the root directory that contains files added through `SparkContext.addFile()`. */ public static String getRootDirectory() { return SparkEnv.get().sparkFilesDir(); } -} \ No newline at end of file +} diff --git a/python/pyspark/context.py b/python/pyspark/context.py index b8d7dc05af43d2fdd236d88a10813ebfe1e3aef8..3e33776af0add33499645513229fa6ea5f55811d 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -1,12 +1,15 @@ import os import atexit import shutil +import sys import tempfile +from threading import Lock from tempfile import NamedTemporaryFile from pyspark import accumulators from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast +from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import dump_pickle, write_with_length, batched from pyspark.rdd import RDD @@ -27,6 +30,8 @@ class SparkContext(object): _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile _takePartition = jvm.PythonRDD.takePartition _next_accum_id = 0 + _active_spark_context = None + _lock = Lock() def __init__(self, master, jobName, sparkHome=None, pyFiles=None, environment=None, batchSize=1024): @@ -46,6 +51,11 @@ class SparkContext(object): Java object. Set 1 to disable batching or -1 to use an unlimited batch size. """ + with SparkContext._lock: + if SparkContext._active_spark_context: + raise ValueError("Cannot run multiple SparkContexts at once") + else: + SparkContext._active_spark_context = self self.master = master self.jobName = jobName self.sparkHome = sparkHome or None # None becomes null in Py4J @@ -75,6 +85,8 @@ class SparkContext(object): # Deploy any code dependencies specified in the constructor for path in (pyFiles or []): self.addPyFile(path) + SparkFiles._sc = self + sys.path.append(SparkFiles.getRootDirectory()) @property def defaultParallelism(self): @@ -85,17 +97,20 @@ class SparkContext(object): return self._jsc.sc().defaultParallelism() def __del__(self): - if self._jsc: - self._jsc.stop() - if self._accumulatorServer: - self._accumulatorServer.shutdown() + self.stop() def stop(self): """ Shut down the SparkContext. """ - self._jsc.stop() - self._jsc = None + if self._jsc: + self._jsc.stop() + self._jsc = None + if self._accumulatorServer: + self._accumulatorServer.shutdown() + self._accumulatorServer = None + with SparkContext._lock: + SparkContext._active_spark_context = None def parallelize(self, c, numSlices=None): """ diff --git a/python/pyspark/files.py b/python/pyspark/files.py index de1334f046c6b96819a071482f1d6b376f7fbb04..98f6a399cc26b53835db185d694b4a0176371e98 100644 --- a/python/pyspark/files.py +++ b/python/pyspark/files.py @@ -4,13 +4,15 @@ import os class SparkFiles(object): """ Resolves paths to files added through - L{addFile()<pyspark.context.SparkContext.addFile>}. + L{SparkContext.addFile()<pyspark.context.SparkContext.addFile>}. SparkFiles contains only classmethods; users should not create SparkFiles instances. """ _root_directory = None + _is_running_on_worker = False + _sc = None def __init__(self): raise NotImplementedError("Do not construct SparkFiles objects") @@ -18,7 +20,19 @@ class SparkFiles(object): @classmethod def get(cls, filename): """ - Get the absolute path of a file added through C{addFile()}. + Get the absolute path of a file added through C{SparkContext.addFile()}. """ - path = os.path.join(SparkFiles._root_directory, filename) + path = os.path.join(SparkFiles.getRootDirectory(), filename) return os.path.abspath(path) + + @classmethod + def getRootDirectory(cls): + """ + Get the root directory that contains files added through + C{SparkContext.addFile()}. + """ + if cls._is_running_on_worker: + return cls._root_directory + else: + # This will have to change if we support multiple SparkContexts: + return cls._sc.jvm.spark.SparkFiles.getRootDirectory() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 4d70ee4f125006312438a9184339a1dc57c8b5cf..46ab34f063b2b9a739186f2dda91c4b02fbf9dcc 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -4,22 +4,26 @@ individual modules. """ import os import shutil +import sys from tempfile import NamedTemporaryFile import time import unittest from pyspark.context import SparkContext +from pyspark.files import SparkFiles from pyspark.java_gateway import SPARK_HOME class PySparkTestCase(unittest.TestCase): def setUp(self): + self._old_sys_path = list(sys.path) class_name = self.__class__.__name__ self.sc = SparkContext('local[4]', class_name , batchSize=2) def tearDown(self): self.sc.stop() + sys.path = self._old_sys_path # To avoid Akka rebinding to the same port, since it doesn't unbind # immediately on shutdown self.sc.jvm.System.clearProperty("spark.master.port") @@ -84,6 +88,25 @@ class TestAddFile(PySparkTestCase): res = self.sc.parallelize(range(2)).map(func).first() self.assertEqual("Hello World!", res) + def test_add_file_locally(self): + path = os.path.join(SPARK_HOME, "python/test_support/hello.txt") + self.sc.addFile(path) + download_path = SparkFiles.get("hello.txt") + self.assertNotEqual(path, download_path) + with open(download_path) as test_file: + self.assertEquals("Hello World!\n", test_file.readline()) + + def test_add_py_file_locally(self): + # To ensure that we're actually testing addPyFile's effects, check that + # this fails due to `userlibrary` not being on the Python path: + def func(): + from userlibrary import UserClass + self.assertRaises(ImportError, func) + path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") + self.sc.addFile(path) + from userlibrary import UserClass + self.assertEqual("Hello World!", UserClass().hello()) + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 4bf643da66d1dea1cdda65b055cd5fbf50145ad1..d33d6dd15f0de69a851fc66a40adfe1ec78d03f4 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -26,6 +26,7 @@ def main(): split_index = read_int(sys.stdin) spark_files_dir = load_pickle(read_with_length(sys.stdin)) SparkFiles._root_directory = spark_files_dir + SparkFiles._is_running_on_worker = True sys.path.append(spark_files_dir) num_broadcast_variables = read_int(sys.stdin) for _ in range(num_broadcast_variables): diff --git a/python/test_support/hello.txt b/python/test_support/hello.txt new file mode 100755 index 0000000000000000000000000000000000000000..980a0d5f19a64b4b30a87d4206aade58726b60e3 --- /dev/null +++ b/python/test_support/hello.txt @@ -0,0 +1 @@ +Hello World!