From e211f405bcb3cf02c3ae589cf81d9c9dfc70bc03 Mon Sep 17 00:00:00 2001
From: Josh Rosen <joshrosen@eecs.berkeley.edu>
Date: Fri, 1 Feb 2013 11:48:11 -0800
Subject: [PATCH] Use spark.local.dir for PySpark temp files (SPARK-580).
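
PySpark previously created the temp files backing parallelize() and
collect() with NamedTemporaryFile in the system default temp directory,
registering a per-file atexit handler to unlink each one. This patch
instead creates one temporary directory per SparkContext inside
spark.local.dir (via spark.Utils.getLocalDir and spark.Utils.createTempDir
on the JVM side) and places all temp files there, so they land on the
disks configured for Spark scratch space and are cleaned up with that
directory rather than file-by-file from Python.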

---
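Note: the resulting temp-file pattern looks like the sketch below. It is
illustrative only and assumes a running SparkContext bound to the name sc;
_jvm is the Py4J gateway through which Python reaches the Scala spark.Utils
helpers that this patch calls.

    from tempfile import NamedTemporaryFile

    # Ask the JVM where spark.local.dir points, then create a
    # per-context scratch directory under it.
    local_dir = sc._jvm.spark.Utils.getLocalDir()
    temp_dir = sc._jvm.spark.Utils.createTempDir(local_dir).getAbsolutePath()

    # Temp files now go inside that directory instead of the system
    # default temp dir; the JVM owns the directory and is expected to
    # remove it (and everything in it) at shutdown, so the per-file
    # atexit handlers removed below become unnecessary.
    temp_file = NamedTemporaryFile(delete=False, dir=temp_dir)
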
 python/pyspark/context.py | 12 ++++++++----
 python/pyspark/rdd.py     |  7 +------
 2 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index ba6896dda3..6831f9b7f8 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -1,8 +1,6 @@
 import os
-import atexit
 import shutil
 import sys
-import tempfile
 from threading import Lock
 from tempfile import NamedTemporaryFile
 
@@ -94,6 +92,11 @@ class SparkContext(object):
         SparkFiles._sc = self
         sys.path.append(SparkFiles.getRootDirectory())
 
+        # Create a temporary directory inside spark.local.dir:
+        local_dir = self._jvm.spark.Utils.getLocalDir()
+        self._temp_dir = \
+            self._jvm.spark.Utils.createTempDir(local_dir).getAbsolutePath()
+
     @property
     def defaultParallelism(self):
         """
@@ -126,8 +129,7 @@ class SparkContext(object):
         # Calling the Java parallelize() method with an ArrayList is too slow,
         # because it sends O(n) Py4J commands.  As an alternative, serialized
         # objects are written to a file and loaded through textFile().
-        tempFile = NamedTemporaryFile(delete=False)
-        atexit.register(lambda: os.unlink(tempFile.name))
+        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
         if self.batchSize != 1:
             c = batched(c, self.batchSize)
         for x in c:
@@ -247,7 +249,9 @@ class SparkContext(object):
 
 
 def _test():
+    import atexit
     import doctest
+    import tempfile
     globs = globals().copy()
     globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
     globs['tempdir'] = tempfile.mkdtemp()
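
Note on the parallelize() hunk above: pushing elements through Py4J one
at a time costs a round trip per element, so the data is pickled into a
single temp file and handed to the JVM by name. A self-contained sketch
of that write side, using plain pickle in place of pyspark's serializers
(write_batches is a made-up name for illustration):

    import pickle
    from itertools import islice
    from tempfile import NamedTemporaryFile

    def write_batches(items, batch_size, temp_dir):
        # Dump the input as pickled batches into one file under the
        # Spark-managed temp dir and return its name for the JVM side.
        out = NamedTemporaryFile(delete=False, dir=temp_dir)
        it = iter(items)
        while True:
            batch = list(islice(it, batch_size))
            if not batch:
                break
            pickle.dump(batch, out)
        out.close()
        return out.name
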
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d7cad2f372..41ea6e6e14 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1,4 +1,3 @@
-import atexit
 from base64 import standard_b64encode as b64enc
 import copy
 from collections import defaultdict
@@ -264,12 +263,8 @@ class RDD(object):
         # Transferring lots of data through Py4J can be slow because
         # socket.readline() is inefficient.  Instead, we'll dump the data to a
         # file and read it back.
-        tempFile = NamedTemporaryFile(delete=False)
+        tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
         tempFile.close()
-        def clean_up_file():
-            try: os.unlink(tempFile.name)
-            except: pass
-        atexit.register(clean_up_file)
         self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
         # Read the data into Python and deserialize it:
         with open(tempFile.name, 'rb') as tempFile:
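
Note on the collect() hunk above: the matching read side unpickles
batches from the temp file until end-of-file. A minimal sketch, with
read_batches a made-up name (the real reader lives in pyspark's
serializers module):

    import pickle

    def read_batches(path):
        # Yield elements back out of a pickled-batch file like the one
        # written by the sketch above; pickle.load raises EOFError once
        # the file is exhausted.
        with open(path, 'rb') as f:
            while True:
                try:
                    batch = pickle.load(f)
                except EOFError:
                    return
                for item in batch:
                    yield item
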
-- 
GitLab