diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index ba6896dda3526bc6df20fd5fdb671369b24022ff..6831f9b7f8b95aac5e82f7d16cb0597289a086a8 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -1,8 +1,6 @@
 import os
-import atexit
 import shutil
 import sys
-import tempfile
 from threading import Lock
 from tempfile import NamedTemporaryFile
 
@@ -94,6 +92,11 @@ class SparkContext(object):
         SparkFiles._sc = self
         sys.path.append(SparkFiles.getRootDirectory())
 
+        # Create a temporary directory inside spark.local.dir:
+        local_dir = self._jvm.spark.Utils.getLocalDir()
+        self._temp_dir = \
+            self._jvm.spark.Utils.createTempDir(local_dir).getAbsolutePath()
+
     @property
     def defaultParallelism(self):
         """
@@ -126,8 +129,7 @@ class SparkContext(object):
         # Calling the Java parallelize() method with an ArrayList is too slow,
         # because it sends O(n) Py4J commands. As an alternative, serialized
         # objects are written to a file and loaded through textFile().
-        tempFile = NamedTemporaryFile(delete=False)
-        atexit.register(lambda: os.unlink(tempFile.name))
+        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
         if self.batchSize != 1:
             c = batched(c, self.batchSize)
         for x in c:
@@ -247,7 +249,9 @@ class SparkContext(object):
 
 
 def _test():
+    import atexit
     import doctest
+    import tempfile
     globs = globals().copy()
     globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
     globs['tempdir'] = tempfile.mkdtemp()
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d7cad2f3725aa2c8c0a5b53a8061c4c1779139d2..41ea6e6e14c07b9c044f9e54372a80947dd46349 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1,4 +1,3 @@
-import atexit
 from base64 import standard_b64encode as b64enc
 import copy
 from collections import defaultdict
@@ -264,12 +263,8 @@ class RDD(object):
         # Transferring lots of data through Py4J can be slow because
         # socket.readline() is inefficient. Instead, we'll dump the data to a
         # file and read it back.
-        tempFile = NamedTemporaryFile(delete=False)
+        tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
         tempFile.close()
-        def clean_up_file():
-            try: os.unlink(tempFile.name)
-            except: pass
-        atexit.register(clean_up_file)
         self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
         # Read the data into Python and deserialize it:
         with open(tempFile.name, 'rb') as tempFile:
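The pattern this diff introduces is worth spelling out: rather than registering one `atexit` handler per temporary file, the context now owns a single temp directory (created inside `spark.local.dir` by the JVM-side `spark.Utils.createTempDir`, whose cleanup Spark manages itself), and every Python-side temp file is created inside it with `NamedTemporaryFile(delete=False, dir=...)`. Below is a minimal standalone sketch of that pattern, not PySpark's actual code: it substitutes `tempfile.mkdtemp()` and an explicit `stop()` for the JVM-managed directory, and the names `ToyContext` and `write_batch` are hypothetical.

```python
import shutil
import tempfile
from tempfile import NamedTemporaryFile


class ToyContext(object):
    """Hypothetical stand-in for SparkContext's temp-file bookkeeping."""

    def __init__(self):
        # One directory per context: every temp file lands inside it, so a
        # single recursive delete at shutdown replaces a pile of per-file
        # atexit handlers. (PySpark gets this path from the JVM instead.)
        self._temp_dir = tempfile.mkdtemp(prefix="toy-spark-")

    def write_batch(self, items):
        # delete=False is essential: the file must survive close() so that
        # another reader (the JVM, in PySpark's case) can open it by name.
        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
        for x in items:
            tempFile.write(("%s\n" % (x,)).encode("utf-8"))
        tempFile.close()
        return tempFile.name

    def stop(self):
        # Cleanup is one rmtree, no matter how many temp files were created.
        shutil.rmtree(self._temp_dir, ignore_errors=True)


ctx = ToyContext()
path = ctx.write_batch(range(5))
with open(path) as f:
    print(f.read())  # still readable until the context is stopped
ctx.stop()
```

A side effect of anchoring the directory under `spark.local.dir`, as the patch does, is that these files land on whatever disks the operator configured for Spark's scratch space rather than accumulating in the default temp location until interpreter exit.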