diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index ba6896dda3526bc6df20fd5fdb671369b24022ff..6831f9b7f8b95aac5e82f7d16cb0597289a086a8 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -1,8 +1,6 @@
 import os
-import atexit
 import shutil
 import sys
-import tempfile
 from threading import Lock
 from tempfile import NamedTemporaryFile
 
@@ -94,6 +92,11 @@ class SparkContext(object):
         SparkFiles._sc = self
         sys.path.append(SparkFiles.getRootDirectory())
 
+        # Create a temporary directory inside spark.local.dir:
+        local_dir = self._jvm.spark.Utils.getLocalDir()
+        self._temp_dir = \
+            self._jvm.spark.Utils.createTempDir(local_dir).getAbsolutePath()
+
     @property
     def defaultParallelism(self):
         """
@@ -126,8 +129,7 @@ class SparkContext(object):
         # Calling the Java parallelize() method with an ArrayList is too slow,
         # because it sends O(n) Py4J commands.  As an alternative, serialized
         # objects are written to a file and loaded through textFile().
-        tempFile = NamedTemporaryFile(delete=False)
-        atexit.register(lambda: os.unlink(tempFile.name))
+        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
         if self.batchSize != 1:
             c = batched(c, self.batchSize)
         for x in c:
@@ -247,7 +249,9 @@ class SparkContext(object):
 
 
 def _test():
+    import atexit
     import doctest
+    import tempfile
     globs = globals().copy()
     globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
     globs['tempdir'] = tempfile.mkdtemp()
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d7cad2f3725aa2c8c0a5b53a8061c4c1779139d2..41ea6e6e14c07b9c044f9e54372a80947dd46349 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1,4 +1,3 @@
-import atexit
 from base64 import standard_b64encode as b64enc
 import copy
 from collections import defaultdict
@@ -264,12 +263,8 @@ class RDD(object):
         # Transferring lots of data through Py4J can be slow because
         # socket.readline() is inefficient.  Instead, we'll dump the data to a
         # file and read it back.
-        tempFile = NamedTemporaryFile(delete=False)
+        tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
         tempFile.close()
-        def clean_up_file():
-            try: os.unlink(tempFile.name)
-            except: pass
-        atexit.register(clean_up_file)
         self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
         # Read the data into Python and deserialize it:
         with open(tempFile.name, 'rb') as tempFile: