From 7fa6978a1e8822cf377fbb1e8a8d23adc4ebe12e Mon Sep 17 00:00:00 2001
From: Mridul Muralidharan <mridul@gmail.com>
Date: Sun, 28 Apr 2013 23:08:10 +0530
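Subject: [PATCH] Allow CheckpointWriter pending tasks to finish

CheckpointWriter.stop() now marks the writer as stopped at most once, shuts
the executor down and waits up to 10 seconds for it to terminate, logging
whether it did. The checkpoint write loop no longer returns early when the
writer has been stopped, so pending writes are allowed to run.

Also move the `batchDuration = duration` assignment in DStreamGraph inside
the block that precedes it.
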
---
 .../src/main/scala/spark/streaming/Checkpoint.scala | 13 +++++++------
 .../main/scala/spark/streaming/DStreamGraph.scala   |  2 +-
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
index 7bd104b8d5..4bbad908d0 100644
--- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
@@ -42,7 +42,7 @@ class CheckpointWriter(checkpointDir: String) extends Logging {
   private val writeFile = new Path(file.getParent, file.getName + ".next")
   private val bakFile = new Path(file.getParent, file.getName + ".bk")
 
-  @volatile private var stopped = false
+  private var stopped = false
 
   val conf = new Configuration()
   var fs = file.getFileSystem(conf)
@@ -57,10 +57,6 @@ class CheckpointWriter(checkpointDir: String) extends Logging {
       var attempts = 0
       val startTime = System.currentTimeMillis()
       while (attempts < maxAttempts) {
-        if (stopped) {
-          logInfo("Already stopped, ignore checkpoint attempt for " + file)
-          return
-        }
         attempts += 1
         try {
           logDebug("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'")
@@ -99,8 +95,13 @@ class CheckpointWriter(checkpointDir: String) extends Logging {
   }
 
   def stop() {
-    stopped = true
+    synchronized {
+      if (stopped) return ;
+      stopped = true
+    }
     executor.shutdown()
+    val terminated = executor.awaitTermination(10, java.util.concurrent.TimeUnit.SECONDS)
+    logInfo("CheckpointWriter executor terminated ? " + terminated)
   }
 }
 
diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
index adb7f3a24d..3b331956f5 100644
--- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
@@ -54,8 +54,8 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
         throw new Exception("Batch duration already set as " + batchDuration +
           ". cannot set it again.")
       }
+      batchDuration = duration
     }
-    batchDuration = duration
   }
 
   def remember(duration: Duration) {
-- 
GitLab