From 3c434cbfd0d6821e5bcf572be792b787a514018b Mon Sep 17 00:00:00 2001
From: Tathagata Das <tathagata.das1565@gmail.com>
Date: Wed, 20 May 2015 16:21:23 -0700
Subject: [PATCH] [SPARK-7767] [STREAMING] Added test for checkpoint
 serialization in StreamingContext.start()

Currently, the background checkpointing thread fails silently if the checkpoint is not serializable. This is hard to debug, so it is better to fail fast in `start()` when checkpointing is enabled and the checkpoint is not serializable.
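
For illustration, a minimal sketch of the user-facing effect (not part of the patch; the socket source, master, port, and checkpoint path are placeholder assumptions):

    import java.io.NotSerializableException

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    object FailFastExample {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setMaster("local[2]").setAppName("fail-fast-example")
        val ssc = new StreamingContext(conf, Seconds(1))
        ssc.checkpoint("/tmp/fail-fast-example-checkpoint")  // enables checkpointing

        val notSerializable = new Object()  // java.lang.Object is not Serializable
        ssc.socketTextStream("localhost", 9999).foreachRDD { rdd =>
          // Capturing `notSerializable` makes this DStream function non-serializable
          rdd.count()
          notSerializable.hashCode()
        }

        try {
          // Before this patch the context started and the background checkpoint writer
          // failed silently; with this patch, start() throws NotSerializableException
          // immediately, including the serialization stack from SerializationDebugger.
          ssc.start()
        } catch {
          case e: NotSerializableException =>
            println("Failed fast at start(): " + e.getMessage)
        } finally {
          ssc.stop()
        }
      }
    }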

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #6292 from tdas/SPARK-7767 and squashes the following commits:

51304e6 [Tathagata Das] Addressed comments.
c35237b [Tathagata Das] Added test for checkpoint serialization in StreamingContext.start()
---
 .../serializer/SerializationDebugger.scala    |  2 +-
 .../apache/spark/streaming/Checkpoint.scala   | 70 +++++++++++--------
 .../spark/streaming/StreamingContext.scala    | 26 ++++++-
 .../streaming/StreamingContextSuite.scala     | 27 +++++--
 4 files changed, 89 insertions(+), 36 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
index 5abfa467c0..bb5db54553 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
@@ -27,7 +27,7 @@ import scala.util.control.NonFatal
 
 import org.apache.spark.Logging
 
-private[serializer] object SerializationDebugger extends Logging {
+private[spark] object SerializationDebugger extends Logging {
 
   /**
    * Improve the given NotSerializableException with the serialization path leading from the given
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index 7bfae253c3..d8dc4e4101 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -102,6 +102,44 @@ object Checkpoint extends Logging {
       Seq.empty
     }
   }
+
+  /** Serialize the checkpoint, or throw any exception that occurs */
+  def serialize(checkpoint: Checkpoint, conf: SparkConf): Array[Byte] = {
+    val compressionCodec = CompressionCodec.createCodec(conf)
+    val bos = new ByteArrayOutputStream()
+    val zos = compressionCodec.compressedOutputStream(bos)
+    val oos = new ObjectOutputStream(zos)
+    Utils.tryWithSafeFinally {
+      oos.writeObject(checkpoint)
+    } {
+      oos.close()
+    }
+    bos.toByteArray
+  }
+
+  /** Deserialize a checkpoint from the input stream, or throw any exception that occurs */
+  def deserialize(inputStream: InputStream, conf: SparkConf): Checkpoint = {
+    val compressionCodec = CompressionCodec.createCodec(conf)
+    var ois: ObjectInputStreamWithLoader = null
+    Utils.tryWithSafeFinally {
+
+      // ObjectInputStream uses the last defined user-defined class loader in the stack
+      // to find classes, which may be the wrong class loader. Hence, an inherited version
+      // of ObjectInputStream is used to explicitly use the current thread's default class
+      // loader to find and load classes. This is a well-known Java issue and has popped up
+      // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
+      val zis = compressionCodec.compressedInputStream(inputStream)
+      ois = new ObjectInputStreamWithLoader(zis,
+        Thread.currentThread().getContextClassLoader)
+      val cp = ois.readObject.asInstanceOf[Checkpoint]
+      cp.validate()
+      cp
+    } {
+      if (ois != null) {
+        ois.close()
+      }
+    }
+  }
 }
 
 
@@ -189,17 +227,10 @@ class CheckpointWriter(
   }
 
   def write(checkpoint: Checkpoint, clearCheckpointDataLater: Boolean) {
-    val bos = new ByteArrayOutputStream()
-    val zos = compressionCodec.compressedOutputStream(bos)
-    val oos = new ObjectOutputStream(zos)
-    Utils.tryWithSafeFinally {
-      oos.writeObject(checkpoint)
-    } {
-      oos.close()
-    }
     try {
+      val bytes = Checkpoint.serialize(checkpoint, conf)
       executor.execute(new CheckpointWriteHandler(
-        checkpoint.checkpointTime, bos.toByteArray, clearCheckpointDataLater))
+        checkpoint.checkpointTime, bytes, clearCheckpointDataLater))
       logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue")
     } catch {
       case rej: RejectedExecutionException =>
@@ -264,25 +295,8 @@ object CheckpointReader extends Logging {
     checkpointFiles.foreach(file => {
       logInfo("Attempting to load checkpoint from file " + file)
       try {
-        var ois: ObjectInputStreamWithLoader = null
-        var cp: Checkpoint = null
-        Utils.tryWithSafeFinally {
-          val fis = fs.open(file)
-          // ObjectInputStream uses the last defined user-defined class loader in the stack
-          // to find classes, which maybe the wrong class loader. Hence, a inherited version
-          // of ObjectInputStream is used to explicitly use the current thread's default class
-          // loader to find and load classes. This is a well know Java issue and has popped up
-          // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
-          val zis = compressionCodec.compressedInputStream(fis)
-          ois = new ObjectInputStreamWithLoader(zis,
-            Thread.currentThread().getContextClassLoader)
-          cp = ois.readObject.asInstanceOf[Checkpoint]
-        } {
-          if (ois != null) {
-            ois.close()
-          }
-        }
-        cp.validate()
+        val fis = fs.open(file)
+        val cp = Checkpoint.deserialize(fis, conf)
         logInfo("Checkpoint successfully loaded from file " + file)
         logInfo("Checkpoint was generated at time " + cp.checkpointTime)
         return Some(cp)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index fe614c4be5..95063692e1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.streaming
 
-import java.io.InputStream
+import java.io.{InputStream, NotSerializableException}
 import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
 
 import scala.collection.Map
@@ -35,6 +35,7 @@ import org.apache.spark._
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.input.FixedLengthBinaryInputFormat
 import org.apache.spark.rdd.{RDD, RDDOperationScope}
+import org.apache.spark.serializer.SerializationDebugger
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.StreamingContextState._
 import org.apache.spark.streaming.dstream._
@@ -235,6 +236,10 @@ class StreamingContext private[streaming] (
     }
   }
 
+  private[streaming] def isCheckpointingEnabled: Boolean = {
+    checkpointDir != null
+  }
+
   private[streaming] def initialCheckpoint: Checkpoint = {
     if (isCheckpointPresent) cp_ else null
   }
@@ -523,11 +528,26 @@ class StreamingContext private[streaming] (
     assert(graph != null, "Graph is null")
     graph.validate()
 
-    assert(
-      checkpointDir == null || checkpointDuration != null,
+    require(
+      !isCheckpointingEnabled || checkpointDuration != null,
       "Checkpoint directory has been set, but the graph checkpointing interval has " +
         "not been set. Please use StreamingContext.checkpoint() to set the interval."
     )
+
+    // Verify whether the DStream checkpoint is serializable
+    if (isCheckpointingEnabled) {
+      val checkpoint = new Checkpoint(this, Time.apply(0))
+      try {
+        Checkpoint.serialize(checkpoint, conf)
+      } catch {
+        case e: NotSerializableException =>
+          throw new NotSerializableException(
+            "DStream checkpointing has been enabled but the DStreams with their functions " +
+              "are not serializable\nSerialization stack:\n" +
+              SerializationDebugger.find(checkpoint).map("\t- " + _).mkString("\n")
+          )
+      }
+    }
   }
 
   /**
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index 4b12affbb0..3a958bf3a3 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -17,21 +17,21 @@
 
 package org.apache.spark.streaming
 
-import java.io.File
+import java.io.{File, NotSerializableException}
 import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.commons.io.FileUtils
-import org.scalatest.{Assertions, BeforeAndAfter, FunSuite}
-import org.scalatest.concurrent.Timeouts
 import org.scalatest.concurrent.Eventually._
+import org.scalatest.concurrent.Timeouts
 import org.scalatest.exceptions.TestFailedDueToTimeoutException
 import org.scalatest.time.SpanSugar._
+import org.scalatest.{Assertions, BeforeAndAfter, FunSuite}
 
-import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.dstream.DStream
 import org.apache.spark.streaming.receiver.Receiver
 import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException}
 
 
 class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging {
@@ -132,6 +132,25 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
     }
   }
 
+  test("start with non-seriazable DStream checkpoints") {
+    val checkpointDir = Utils.createTempDir()
+    ssc = new StreamingContext(conf, batchDuration)
+    ssc.checkpoint(checkpointDir.getAbsolutePath)
+    addInputStream(ssc).foreachRDD { rdd =>
+      // Refer to this.appName from inside closure so that this closure refers to
+      // the instance of StreamingContextSuite, and is therefore not serializable
+      rdd.count() + appName
+    }
+
+    // Test whether start() fails early when checkpointing is enabled
+    val exception = intercept[NotSerializableException] {
+      ssc.start()
+    }
+    assert(exception.getMessage().contains("DStreams with their functions are not serializable"))
+    assert(ssc.getState() !== StreamingContextState.ACTIVE)
+    assert(StreamingContext.getActive().isEmpty)
+  }
+
   test("start multiple times") {
     ssc = new StreamingContext(master, appName, batchDuration)
     addInputStream(ssc).register()
-- 
GitLab