From d4dfab503a9222b5acf5c4bf69b91c16f298e4aa Mon Sep 17 00:00:00 2001
From: Tathagata Das <tathagata.das1565@gmail.com>
Date: Tue, 24 Dec 2013 14:01:13 -0800
Subject: [PATCH] Fixed Python API for sc.setCheckpointDir. Also other fixes
 based on Reynold's comments on PR 289.

---
 .../main/scala/org/apache/spark/SparkContext.scala    |  4 ++--
 .../scala/org/apache/spark/rdd/CheckpointRDD.scala    |  2 --
 .../org/apache/spark/rdd/RDDCheckpointData.scala      |  2 +-
 python/pyspark/context.py                             |  9 ++-------
 python/pyspark/tests.py                               |  4 ++--
 .../spark/streaming/dstream/FileInputDStream.scala    |  6 +++---
 .../spark/streaming/scheduler/JobGenerator.scala      | 11 ++++++-----
 7 files changed, 16 insertions(+), 22 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index c30f896cf1..cc87febf33 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -861,12 +861,12 @@ class SparkContext(
    * be a HDFS path if running on a cluster.
    */
   def setCheckpointDir(directory: String) {
-    checkpointDir = Option(directory).map(dir => {
+    checkpointDir = Option(directory).map { dir =>
       val path = new Path(dir, UUID.randomUUID().toString)
       val fs = path.getFileSystem(hadoopConfiguration)
       fs.mkdirs(path)
       fs.getFileStatus(path).getPath().toString
-    })
+    }
   }
 
   /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index 80385fce57..293a7d1f68 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -18,9 +18,7 @@
 package org.apache.spark.rdd
 
 import java.io.IOException
-
 import scala.reflect.ClassTag
-import java.io.{IOException}
 import org.apache.spark._
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.deploy.SparkHadoopUtil
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index 5a565d7e78..091a6fdb54 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -95,7 +95,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](rdd: RDD[T])
     rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString, broadcastedConf) _)
     val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
     if (newRDD.partitions.size != rdd.partitions.size) {
-      throw new Exception(
+      throw new SparkException(
         "Checkpoint RDD " + newRDD + "("+ newRDD.partitions.size + ") has different " +
           "number of partitions than original RDD " + rdd + "(" + rdd.partitions.size + ")")
     }
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 0604f6836c..108f36576a 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -320,17 +320,12 @@ class SparkContext(object):
             self._python_includes.append(filename)
             sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename)) # for tests in local mode
 
-    def setCheckpointDir(self, dirName, useExisting=False):
+    def setCheckpointDir(self, dirName):
         """
         Set the directory under which RDDs are going to be checkpointed. The
         directory must be a HDFS path if running on a cluster.
-
-        If the directory does not exist, it will be created. If the directory
-        exists and C{useExisting} is set to true, then the exisiting directory
-        will be used.  Otherwise an exception will be thrown to prevent
-        accidental overriding of checkpoint files in the existing directory.
         """
-        self._jsc.sc().setCheckpointDir(dirName, useExisting)
+        self._jsc.sc().setCheckpointDir(dirName)
 
     def _getJavaStorageLevel(self, storageLevel):
         """
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 3987642bf4..7acb6eaf10 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -73,8 +73,8 @@ class TestCheckpoint(PySparkTestCase):
         time.sleep(1)  # 1 second
         self.assertTrue(flatMappedRDD.isCheckpointed())
         self.assertEqual(flatMappedRDD.collect(), result)
-        self.assertEqual(self.checkpointDir.name,
-                         os.path.dirname(flatMappedRDD.getCheckpointFile()))
+        self.assertEqual("file:" + self.checkpointDir.name,
+                         os.path.dirname(os.path.dirname(flatMappedRDD.getCheckpointFile())))
 
     def test_checkpoint_and_restore(self):
         parCollection = self.sc.parallelize([1, 2, 3, 4])
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
index 4a7c5cf29c..d6514a1fb1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
@@ -123,7 +123,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
           reset()
       }
     }
-    (Seq(), -1, Seq())
+    (Seq.empty, -1, Seq.empty)
   }
 
   /** Generate one RDD from an array of files */
@@ -193,7 +193,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
    * been seen before (i.e. the file should not be in lastModTimeFiles)
    */
   private[streaming]
-  class CustomPathFilter(currentTime: Long) extends PathFilter() {
+  class CustomPathFilter(currentTime: Long) extends PathFilter {
     // Latest file mod time seen in this round of fetching files and its corresponding files
     var latestModTime = 0L
     val latestModTimeFiles = new HashSet[String]()
@@ -209,7 +209,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
         logDebug("Rejected by filter " + path)
         return false
       } else {              // Accept file only if
-      val modTime = fs.getFileStatus(path).getModificationTime()
+        val modTime = fs.getFileStatus(path).getModificationTime()
         logDebug("Mod time for " + path + " is " + modTime)
         if (modTime < prevModTime) {
           logDebug("Mod time less than last mod time")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 2552d51654..921a33a4cb 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -17,16 +17,17 @@
 
 package org.apache.spark.streaming.scheduler
 
+import akka.actor.{Props, Actor}
 import org.apache.spark.SparkEnv
 import org.apache.spark.Logging
 import org.apache.spark.streaming.{Checkpoint, Time, CheckpointWriter}
 import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock}
-import akka.actor.{Props, Actor}
 
-sealed trait JobGeneratorEvent
-case class GenerateJobs(time: Time) extends JobGeneratorEvent
-case class ClearOldMetadata(time: Time) extends JobGeneratorEvent
-case class DoCheckpoint(time: Time) extends JobGeneratorEvent
+/** Event classes for JobGenerator */
+private[scheduler] sealed trait JobGeneratorEvent
+private[scheduler] case class GenerateJobs(time: Time) extends JobGeneratorEvent
+private[scheduler] case class ClearOldMetadata(time: Time) extends JobGeneratorEvent
+private[scheduler] case class DoCheckpoint(time: Time) extends JobGeneratorEvent
 
 /**
  * This class generates jobs from DStreams as well as drives checkpointing and cleaning
-- 
GitLab