Skip to content
Snippets Groups Projects
Commit f8306849 authored by Josh Rosen's avatar Josh Rosen
Browse files

Fix for SPARK-1025: PySpark hang on missing files.

parent 61569906
No related branches found
No related tags found
No related merge requests found
...@@ -52,6 +52,8 @@ private[spark] class PythonRDD[T: ClassTag]( ...@@ -52,6 +52,8 @@ private[spark] class PythonRDD[T: ClassTag](
val env = SparkEnv.get val env = SparkEnv.get
val worker = env.createPythonWorker(pythonExec, envVars.toMap) val worker = env.createPythonWorker(pythonExec, envVars.toMap)
@volatile var readerException: Exception = null
// Start a thread to feed the process input from our parent's iterator // Start a thread to feed the process input from our parent's iterator
new Thread("stdin writer for " + pythonExec) { new Thread("stdin writer for " + pythonExec) {
override def run() { override def run() {
...@@ -82,6 +84,10 @@ private[spark] class PythonRDD[T: ClassTag]( ...@@ -82,6 +84,10 @@ private[spark] class PythonRDD[T: ClassTag](
dataOut.flush() dataOut.flush()
worker.shutdownOutput() worker.shutdownOutput()
} catch { } catch {
case e: java.io.FileNotFoundException =>
readerException = e
// Kill the Python worker process:
worker.shutdownOutput()
case e: IOException => case e: IOException =>
// This can happen for legitimate reasons if the Python code stops returning data before we are done // This can happen for legitimate reasons if the Python code stops returning data before we are done
// passing elements through, e.g., for take(). Just log a message to say it happened. // passing elements through, e.g., for take(). Just log a message to say it happened.
...@@ -106,6 +112,9 @@ private[spark] class PythonRDD[T: ClassTag]( ...@@ -106,6 +112,9 @@ private[spark] class PythonRDD[T: ClassTag](
} }
private def read(): Array[Byte] = { private def read(): Array[Byte] = {
if (readerException != null) {
throw readerException
}
try { try {
stream.readInt() match { stream.readInt() match {
case length if length > 0 => case length if length > 0 =>
......
...@@ -168,6 +168,17 @@ class TestRDDFunctions(PySparkTestCase): ...@@ -168,6 +168,17 @@ class TestRDDFunctions(PySparkTestCase):
self.assertEqual("Hello World!", x.strip()) self.assertEqual("Hello World!", x.strip())
self.assertEqual("Hello World!", y.strip()) self.assertEqual("Hello World!", y.strip())
def test_deleting_input_files(self):
    # Regression test for SPARK-1025: counting an RDD whose backing input
    # file has been deleted must raise an error instead of hanging.
    input_file = NamedTemporaryFile(delete=False)
    input_file.write("Hello World!")
    input_file.close()
    rdd = self.sc.textFile(input_file.name)
    filtered = rdd.filter(lambda x: True)
    # Sanity check: the file is readable before deletion.
    self.assertEqual(1, filtered.count())
    os.unlink(input_file.name)
    # Re-evaluating after deletion should surface an exception.
    self.assertRaises(Exception, lambda: filtered.count())
class TestIO(PySparkTestCase): class TestIO(PySparkTestCase):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment