diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index 087270e46dd90e4fe84317789e1da3c774d50854..b3698ffa44d57f192c052e03d5c6551c1b9a2fdf 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -307,16 +307,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
     implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
     JavaPairRDD.fromRDD(rdd.keyBy(f))
   }
-  
-  /**
-   * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir`
-   * (set using setCheckpointDir()) and all references to its parent RDDs will be removed.
-   * This is used to truncate very long lineages. In the current implementation, Spark will save
-   * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done.
-   * Hence, it is strongly recommended to use checkpoint() on RDDs when
-   * (i) checkpoint() is called before the any job has been executed on this RDD.
-   * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will
-   * require recomputation.
+
+  /**
+   * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
+   * directory set with SparkContext.setCheckpointDir() and all references to its parent
+   * RDDs will be removed. This function must be called before any job has been
+   * executed on this RDD. It is strongly recommended that this RDD be persisted in
+   * memory; otherwise saving it to a file will require recomputation.
    */
   def checkpoint() = rdd.checkpoint()
 
diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
index fa2f14113d7c8e638feb2eccfa2ea2af609130f7..14699961adbaa1a3ddb4da92f41a408b56b6b6fc 100644
--- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
@@ -357,20 +357,21 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
   }
 
   /**
-   * Set the directory under which RDDs are going to be checkpointed. This method will
-   * create this directory and will throw an exception of the path already exists (to avoid
-   * overwriting existing files may be overwritten). The directory will be deleted on exit
-   * if indicated.
+   * Set the directory under which RDDs are going to be checkpointed. The directory must
+   * be an HDFS path if running on a cluster. If the directory does not exist, it will
+   * be created. If the directory exists and useExisting is set to true, then the
+   * existing directory will be used. Otherwise an exception will be thrown to
+   * prevent accidental overwriting of checkpoint files in the existing directory.
    */
   def setCheckpointDir(dir: String, useExisting: Boolean) {
     sc.setCheckpointDir(dir, useExisting)
   }
 
   /**
-   * Set the directory under which RDDs are going to be checkpointed. This method will
-   * create this directory and will throw an exception of the path already exists (to avoid
-   * overwriting existing files may be overwritten). The directory will be deleted on exit
-   * if indicated.
+   * Set the directory under which RDDs are going to be checkpointed. The directory must
+   * be an HDFS path if running on a cluster. If the directory does not exist, it will
+   * be created. If the directory already exists, an exception will be thrown to prevent
+   * accidental overwriting of checkpoint files.
    */
   def setCheckpointDir(dir: String) {
     sc.setCheckpointDir(dir)
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index e4c0530241556ed4d5584a20e175ec52ecf6d7ae..5526406a20902c2150fca130fc4c76696e17cf07 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -135,8 +135,6 @@ private[spark] class PythonRDD[T: ClassManifest](
     }
   }
 
-  override def checkpoint() { }
-
   val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 }
 
@@ -152,7 +150,6 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
       case Seq(a, b) => (a, b)
       case x          => throw new Exception("PairwiseRDD: unexpected value: " + x)
     }
-  override def checkpoint() { }
   val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
 }
 
diff --git a/python/epydoc.conf b/python/epydoc.conf
index 91ac984ba294c99844afaca438b6e5fda3ea6e0e..45102cd9fece4bc801bb8577f30b859becf7df07 100644
--- a/python/epydoc.conf
+++ b/python/epydoc.conf
@@ -16,4 +16,4 @@ target: docs/
 private: no
 
 exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
-         pyspark.java_gateway pyspark.examples pyspark.shell
+         pyspark.java_gateway pyspark.examples pyspark.shell pyspark.tests
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 1e2f845f9cfdedec261534204d2ef9ba48b0c619..dcbed37270fc319b6d3aa39c2e77b0cff5cc7df8 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -123,6 +123,10 @@ class SparkContext(object):
         jrdd = self._jsc.textFile(name, minSplits)
         return RDD(jrdd, self)
 
+    def _checkpointFile(self, name):
+        jrdd = self._jsc.checkpointFile(name)
+        return RDD(jrdd, self)
+
     def union(self, rdds):
         """
         Build the union of a list of RDDs.
@@ -145,7 +149,7 @@ class SparkContext(object):
     def accumulator(self, value, accum_param=None):
         """
         Create an C{Accumulator} with the given initial value, using a given
-        AccumulatorParam helper object to define how to add values of the data 
+        AccumulatorParam helper object to define how to add values of the data
         type if provided. Default AccumulatorParams are used for integers and
         floating-point numbers if you do not provide one. For other types, the
         AccumulatorParam must implement two methods:
@@ -195,3 +199,15 @@ class SparkContext(object):
         filename = path.split("/")[-1]
         os.environ["PYTHONPATH"] = \
             "%s:%s" % (filename, os.environ["PYTHONPATH"])
+
+    def setCheckpointDir(self, dirName, useExisting=False):
+        """
+        Set the directory under which RDDs are going to be checkpointed. The
+        directory must be an HDFS path if running on a cluster.
+
+        If the directory does not exist, it will be created. If the directory
+        exists and C{useExisting} is set to true, then the existing directory
+        will be used. Otherwise an exception will be thrown to prevent
+        accidental overwriting of checkpoint files in the existing directory.
+        """
+        self._jsc.sc().setCheckpointDir(dirName, useExisting)
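Usage note (not part of the patch): a minimal sketch of how the new checkpoint directory and recovery helpers above are expected to be used from the driver, assuming a local PySpark of this vintage; the directory path and job name are illustrative.

    from pyspark.context import SparkContext

    sc = SparkContext('local[2]', 'CheckpointDemo')
    # The directory must be set before any RDD is checkpointed; on a real
    # cluster this should be an HDFS path rather than a local directory.
    sc.setCheckpointDir('/tmp/pyspark-checkpoints', useExisting=True)

    squares = sc.parallelize(range(10)).map(lambda x: x * x)
    squares.checkpoint()   # mark for checkpointing before any job runs on this RDD
    squares.count()        # the first job also writes the checkpoint
    # (the tests added below still sleep briefly before relying on the
    # checkpoint file being visible)

    # _checkpointFile() is the internal recovery hook exercised by the new tests.
    restored = sc._checkpointFile(squares.getCheckpointFile())
    print(restored.collect())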
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index b58bf24e3e46c8c2f44b6b7c0282cb2c26696ede..d53355a8f1db49c710f4fff094cf2f481cac8695 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -32,6 +32,7 @@ class RDD(object):
     def __init__(self, jrdd, ctx):
         self._jrdd = jrdd
         self.is_cached = False
+        self.is_checkpointed = False
         self.ctx = ctx
         self._partitionFunc = None
 
@@ -50,6 +51,34 @@ class RDD(object):
         self._jrdd.cache()
         return self
 
+    def checkpoint(self):
+        """
+        Mark this RDD for checkpointing. It will be saved to a file inside the
+        checkpoint directory set with L{SparkContext.setCheckpointDir()} and
+        all references to its parent RDDs will be removed. This function must
+        be called before any job has been executed on this RDD. It is strongly
+        recommended that this RDD be persisted in memory; otherwise saving it
+        to a file will require recomputation.
+        """
+        self.is_checkpointed = True
+        self._jrdd.rdd().checkpoint()
+
+    def isCheckpointed(self):
+        """
+        Return whether this RDD has been checkpointed or not.
+        """
+        return self._jrdd.rdd().isCheckpointed()
+
+    def getCheckpointFile(self):
+        """
+        Return the name of the file to which this RDD was checkpointed.
+        """
+        checkpointFile = self._jrdd.rdd().getCheckpointFile()
+        if checkpointFile.isDefined():
+            return checkpointFile.get()
+        else:
+            return None
+
     # TODO persist(self, storageLevel)
 
     def map(self, f, preservesPartitioning=False):
@@ -667,7 +696,7 @@ class PipelinedRDD(RDD):
     20
     """
     def __init__(self, prev, func, preservesPartitioning=False):
-        if isinstance(prev, PipelinedRDD) and not prev.is_cached:
+        if isinstance(prev, PipelinedRDD) and prev._is_pipelinable():
             prev_func = prev.func
             def pipeline_func(split, iterator):
                 return func(split, prev_func(split, iterator))
@@ -680,6 +709,7 @@ class PipelinedRDD(RDD):
             self.preservesPartitioning = preservesPartitioning
             self._prev_jrdd = prev._jrdd
         self.is_cached = False
+        self.is_checkpointed = False
         self.ctx = prev.ctx
         self.prev = prev
         self._jrdd_val = None
@@ -712,6 +742,9 @@ class PipelinedRDD(RDD):
         self._jrdd_val = python_rdd.asJavaRDD()
         return self._jrdd_val
 
+    def _is_pipelinable(self):
+        return not (self.is_cached or self.is_checkpointed)
+
 
 def _test():
     import doctest
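A note on the PipelinedRDD change above (sketch only; the directory path and names are illustrative): _is_pipelinable() keeps a cached or checkpointed PipelinedRDD from being fused into a downstream Python pipeline, so its materialized or checkpointed data can actually be reused instead of its function being re-run.

    from pyspark.context import SparkContext

    sc = SparkContext('local', 'PipelineDemo')
    sc.setCheckpointDir('/tmp/pyspark-checkpoints', useExisting=True)

    base = sc.parallelize([1, 2, 3, 4]).map(lambda x: x + 1)  # PipelinedRDD
    fused = base.map(lambda x: x * 10)
    # base is neither cached nor checkpointed, so fused combines both lambdas
    # into a single Python stage over base's parent.

    base.checkpoint()                  # sets base.is_checkpointed = True
    unfused = base.map(lambda x: x * 10)
    # base is no longer pipelinable, so unfused is built on base's JavaRDD;
    # once the checkpoint has been written, that lineage is read back from
    # the checkpoint file rather than recomputed from base's lambda.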
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0a403b58054aaa1f993df58169e844e24d6cac7
--- /dev/null
+++ b/python/pyspark/tests.py
@@ -0,0 +1,61 @@
+"""
+Unit tests for PySpark; additional tests are implemented as doctests in
+individual modules.
+"""
+import os
+import shutil
+from tempfile import NamedTemporaryFile
+import time
+import unittest
+
+from pyspark.context import SparkContext
+
+
+class TestCheckpoint(unittest.TestCase):
+
+    def setUp(self):
+        self.sc = SparkContext('local[4]', 'TestCheckpoint', batchSize=2)
+        self.checkpointDir = NamedTemporaryFile(delete=False)
+        os.unlink(self.checkpointDir.name)
+        self.sc.setCheckpointDir(self.checkpointDir.name)
+
+    def tearDown(self):
+        self.sc.stop()
+        # To avoid Akka rebinding to the same port, since it doesn't unbind
+        # immediately on shutdown
+        self.sc.jvm.System.clearProperty("spark.master.port")
+        shutil.rmtree(self.checkpointDir.name)
+
+    def test_basic_checkpointing(self):
+        parCollection = self.sc.parallelize([1, 2, 3, 4])
+        flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))
+
+        self.assertFalse(flatMappedRDD.isCheckpointed())
+        self.assertIsNone(flatMappedRDD.getCheckpointFile())
+
+        flatMappedRDD.checkpoint()
+        result = flatMappedRDD.collect()
+        time.sleep(1)  # 1 second
+        self.assertTrue(flatMappedRDD.isCheckpointed())
+        self.assertEqual(flatMappedRDD.collect(), result)
+        self.assertEqual(self.checkpointDir.name,
+                         os.path.dirname(flatMappedRDD.getCheckpointFile()))
+
+    def test_checkpoint_and_restore(self):
+        parCollection = self.sc.parallelize([1, 2, 3, 4])
+        flatMappedRDD = parCollection.flatMap(lambda x: [x])
+
+        self.assertFalse(flatMappedRDD.isCheckpointed())
+        self.assertIsNone(flatMappedRDD.getCheckpointFile())
+
+        flatMappedRDD.checkpoint()
+        flatMappedRDD.count()  # forces a checkpoint to be computed
+        time.sleep(1)  # 1 second
+
+        self.assertIsNotNone(flatMappedRDD.getCheckpointFile())
+        recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
+        self.assertEqual([1, 2, 3, 4], recovered.collect())
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/run-tests b/python/run-tests
index 32470911f934b20ce602cbf5f7782280e38d874f..ce214e98a8f713b0d1ab4238414084e9b0198220 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -14,6 +14,9 @@ FAILED=$(($?||$FAILED))
 $FWDIR/pyspark -m doctest pyspark/accumulators.py
 FAILED=$(($?||$FAILED))
 
+$FWDIR/pyspark -m unittest pyspark.tests
+FAILED=$(($?||$FAILED))
+
 if [[ $FAILED != 0 ]]; then
     echo -en "\033[31m"  # Red
     echo "Had test failures; see logs."