diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 89f7c316dc7798b754fc1a821a2529a8c908375b..8c38262dd8b3267aa44e1301931a8ed224d6c0ba 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -135,8 +135,6 @@ private[spark] class PythonRDD[T: ClassManifest](
     }
   }

-  override def checkpoint() { }
-
   val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 }

@@ -152,7 +150,6 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
       case Seq(a, b) => (a, b)
       case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
     }
-  override def checkpoint() { }
   val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
 }

diff --git a/python/epydoc.conf b/python/epydoc.conf
index 91ac984ba294c99844afaca438b6e5fda3ea6e0e..45102cd9fece4bc801bb8577f30b859becf7df07 100644
--- a/python/epydoc.conf
+++ b/python/epydoc.conf
@@ -16,4 +16,4 @@ target: docs/
 private: no

 exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
-         pyspark.java_gateway pyspark.examples pyspark.shell
+         pyspark.java_gateway pyspark.examples pyspark.shell pyspark.tests
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 1e2f845f9cfdedec261534204d2ef9ba48b0c619..a438b43fdce028e0b4b373291f1518e1f29dea4f 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -195,3 +195,12 @@ class SparkContext(object):
         filename = path.split("/")[-1]
         os.environ["PYTHONPATH"] = \
             "%s:%s" % (filename, os.environ["PYTHONPATH"])
+
+    def setCheckpointDir(self, dirName, useExisting=False):
+        """
+        Set the directory under which RDDs are going to be checkpointed. This
+        method creates the directory and throws an exception if the path
+        already exists (to avoid accidentally overwriting existing checkpoint
+        files), unless useExisting is set to True.
+        """
+        self._jsc.sc().setCheckpointDir(dirName, useExisting)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d705f0f9e102909c48e53e8e08134b6db6757002..9b676cae4ae4cd724353477b835347727abb5a1e 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -49,6 +49,40 @@ class RDD(object):
         self._jrdd.cache()
         return self

+    def checkpoint(self):
+        """
+        Mark this RDD for checkpointing. The RDD will be saved to a file
+        inside `checkpointDir` (set using setCheckpointDir()) and all
+        references to its parent RDDs will be removed. This is used to
+        truncate very long lineages. In the current implementation, Spark
+        will save this RDD to a file (using saveAsObjectFile()) after the
+        first job using this RDD is done. Hence, it is strongly recommended
+        that:
+
+        (i) checkpoint() is called before any job has been executed on this
+        RDD, and
+
+        (ii) this RDD has been persisted in memory; otherwise saving it to a
+        file will require recomputation.
+        """
+        self._jrdd.rdd().checkpoint()
+
+    def isCheckpointed(self):
+        """
+        Return whether this RDD has been checkpointed or not
+        """
+        return self._jrdd.rdd().isCheckpointed()
+
+    def getCheckpointFile(self):
+        """
+        Gets the name of the file to which this RDD was checkpointed
+        """
+        checkpointFile = self._jrdd.rdd().getCheckpointFile()
+        if checkpointFile.isDefined():
+            return checkpointFile.get()
+        else:
+            return None
+
     # TODO persist(self, storageLevel)

     def map(self, f, preservesPartitioning=False):
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
new file mode 100644
index 0000000000000000000000000000000000000000..c959d5dec73df805c1559ba6e6a949dc4e7b3ea4
--- /dev/null
+++ b/python/pyspark/tests.py
@@ -0,0 +1,46 @@
+"""
+Unit tests for PySpark; additional tests are implemented as doctests in
+individual modules.
+"""
+import atexit
+import os
+import shutil
+from tempfile import NamedTemporaryFile
+import time
+import unittest
+
+from pyspark.context import SparkContext
+
+
+class TestCheckpoint(unittest.TestCase):
+
+    def setUp(self):
+        self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2)
+
+    def tearDown(self):
+        self.sc.stop()
+
+    def test_basic_checkpointing(self):
+        checkpointDir = NamedTemporaryFile(delete=False)
+        os.unlink(checkpointDir.name)
+        self.sc.setCheckpointDir(checkpointDir.name)
+
+        parCollection = self.sc.parallelize([1, 2, 3, 4])
+        flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))
+
+        self.assertFalse(flatMappedRDD.isCheckpointed())
+        self.assertIsNone(flatMappedRDD.getCheckpointFile())
+
+        flatMappedRDD.checkpoint()
+        result = flatMappedRDD.collect()
+        time.sleep(1)  # 1 second
+        self.assertTrue(flatMappedRDD.isCheckpointed())
+        self.assertEqual(flatMappedRDD.collect(), result)
+        self.assertEqual(checkpointDir.name,
+                         os.path.dirname(flatMappedRDD.getCheckpointFile()))
+
+        atexit.register(lambda: shutil.rmtree(checkpointDir.name))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/run-tests b/python/run-tests
index 32470911f934b20ce602cbf5f7782280e38d874f..ce214e98a8f713b0d1ab4238414084e9b0198220 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -14,6 +14,9 @@ FAILED=$(($?||$FAILED))
 $FWDIR/pyspark -m doctest pyspark/accumulators.py
 FAILED=$(($?||$FAILED))

+$FWDIR/pyspark -m unittest pyspark.tests
+FAILED=$(($?||$FAILED))
+
 if [[ $FAILED != 0 ]]; then
     echo -en "\033[31m" # Red
     echo "Had test failures; see logs."
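A minimal usage sketch of the PySpark checkpointing API introduced by this patch (not part of the diff; the checkpoint directory path, app name, and example RDD below are illustrative assumptions):

    from pyspark.context import SparkContext

    sc = SparkContext('local', 'checkpoint-demo')
    # The directory must not already exist unless useExisting=True is passed.
    sc.setCheckpointDir('/tmp/pyspark-checkpoint-demo')

    rdd = sc.parallelize([1, 2, 3, 4]).flatMap(lambda x: range(1, x + 1))
    rdd.cache()       # keep it in memory so checkpointing avoids recomputation
    rdd.checkpoint()  # mark for checkpointing before running any job

    rdd.count()       # the first job triggers the actual checkpoint
    # The checkpoint is written after the job finishes, so it can take a moment
    # before isCheckpointed() returns True (the unit test above sleeps 1 second).
    print(rdd.isCheckpointed())
    print(rdd.getCheckpointFile())  # path under the checkpoint dir, or None

    sc.stop()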