From 3446d5c8d6b385106ac85e46320d92faa8efb4e6 Mon Sep 17 00:00:00 2001
From: Patrick Wendell <pwendell@gmail.com>
Date: Thu, 31 Jan 2013 18:02:28 -0800
Subject: [PATCH] SPARK-673: Capture and re-throw Python exceptions

This patch alters the Python <-> executor protocol to pass on
exception data when they occur in user Python code.
---
 .../scala/spark/api/python/PythonRDD.scala    | 40 ++++++++++++-------
 python/pyspark/worker.py                      | 10 ++++-
 2 files changed, 34 insertions(+), 16 deletions(-)

diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index f43a152ca7..6b9ef62529 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -103,21 +103,30 @@ private[spark] class PythonRDD[T: ClassManifest](
 
       private def read(): Array[Byte] = {
         try {
-          val length = stream.readInt()
-          if (length != -1) {
-            val obj = new Array[Byte](length)
-            stream.readFully(obj)
-            obj
-          } else {
-            // We've finished the data section of the output, but we can still read some
-            // accumulator updates; let's do that, breaking when we get EOFException
-            while (true) {
-              val len2 = stream.readInt()
-              val update = new Array[Byte](len2)
-              stream.readFully(update)
-              accumulator += Collections.singletonList(update)
+          stream.readInt() match {
+            case length if length > 0 => {
+              val obj = new Array[Byte](length)
+              stream.readFully(obj)
+              obj
             }
-            new Array[Byte](0)
+            case -2 => {
+              // Signals that an exception has been thrown in python
+              val exLength = stream.readInt()
+              val obj = new Array[Byte](exLength)
+              stream.readFully(obj)
+              throw new PythonException(new String(obj))
+            }
+            case -1 => {
+              // We've finished the data section of the output, but we can still read some
+              // accumulator updates; let's do that, breaking when we get EOFException
+              while (true) {
+                val len2 = stream.readInt()
+                val update = new Array[Byte](len2)
+                stream.readFully(update)
+                accumulator += Collections.singletonList(update)
+              }
+              new Array[Byte](0)
+          }
           }
         } catch {
           case eof: EOFException => {
@@ -140,6 +149,9 @@ private[spark] class PythonRDD[T: ClassManifest](
   val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 }
 
+/** Thrown for exceptions in user Python code. */
+private class PythonException(msg: String) extends Exception(msg)
+
 /**
  * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
  * This is used by PySpark's shuffle operations.
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index d33d6dd15f..9622e0cfe4 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -2,6 +2,7 @@
 Worker that receives input from Piped RDD.
 """
 import sys
+import traceback
 from base64 import standard_b64decode
 # CloudPickler needs to be imported so that depicklers are registered using the
 # copy_reg module.
@@ -40,8 +41,13 @@ def main():
     else:
         dumps = dump_pickle
     iterator = read_from_pickle_file(sys.stdin)
-    for obj in func(split_index, iterator):
-        write_with_length(dumps(obj), old_stdout)
+    try:
+        for obj in func(split_index, iterator):
+           write_with_length(dumps(obj), old_stdout)
+    except Exception as e:
+        write_int(-2, old_stdout)
+        write_with_length(traceback.format_exc(), old_stdout)
+        sys.exit(-1)
     # Mark the beginning of the accumulators section of the output
     write_int(-1, old_stdout)
     for aid, accum in _accumulatorRegistry.items():
-- 
GitLab