diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index a438b43fdce028e0b4b373291f1518e1f29dea4f..8beb8e2ae91a4ea71b9695799c63d44c770e2386 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -123,6 +123,10 @@ class SparkContext(object):
         jrdd = self._jsc.textFile(name, minSplits)
         return RDD(jrdd, self)
 
+    def _checkpointFile(self, name):
+        jrdd = self._jsc.checkpointFile(name)
+        return RDD(jrdd, self)
+
     def union(self, rdds):
         """
         Build the union of a list of RDDs.
@@ -145,7 +149,7 @@ class SparkContext(object):
     def accumulator(self, value, accum_param=None):
         """
         Create an C{Accumulator} with the given initial value, using a given
-        AccumulatorParam helper object to define how to add values of the data 
+        AccumulatorParam helper object to define how to add values of the data
         type if provided. Default AccumulatorParams are used for integers and
         floating-point numbers if you do not provide one. For other types, the
         AccumulatorParam must implement two methods:
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 9b676cae4ae4cd724353477b835347727abb5a1e..2a2ff9b271d7d7930c97b841cd7acc0c103159d6 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -32,6 +32,7 @@ class RDD(object):
     def __init__(self, jrdd, ctx):
         self._jrdd = jrdd
         self.is_cached = False
+        self.is_checkpointed = False
         self.ctx = ctx
 
     @property
@@ -65,6 +66,7 @@ class RDD(object):
         (ii) This RDD has been made to persist in memory. Otherwise saving it
         on a file will require recomputation.
         """
+        self.is_checkpointed = True
         self._jrdd.rdd().checkpoint()
 
     def isCheckpointed(self):
@@ -696,7 +698,7 @@ class PipelinedRDD(RDD):
     20
     """
     def __init__(self, prev, func, preservesPartitioning=False):
-        if isinstance(prev, PipelinedRDD) and not prev.is_cached:
+        if isinstance(prev, PipelinedRDD) and prev._is_pipelinable:
             prev_func = prev.func
             def pipeline_func(split, iterator):
                 return func(split, prev_func(split, iterator))
@@ -709,6 +711,7 @@ class PipelinedRDD(RDD):
             self.preservesPartitioning = preservesPartitioning
             self._prev_jrdd = prev._jrdd
         self.is_cached = False
+        self.is_checkpointed = False
         self.ctx = prev.ctx
         self.prev = prev
         self._jrdd_val = None
@@ -741,6 +744,10 @@ class PipelinedRDD(RDD):
         self._jrdd_val = python_rdd.asJavaRDD()
         return self._jrdd_val
 
+    @property
+    def _is_pipelinable(self):
+        return not (self.is_cached or self.is_checkpointed)
+
 
 def _test():
     import doctest
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index c959d5dec73df805c1559ba6e6a949dc4e7b3ea4..83283fca4fdfac6b480d1f702983e088f65f2452 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -19,6 +19,9 @@ class TestCheckpoint(unittest.TestCase):
 
     def tearDown(self):
         self.sc.stop()
+        # To avoid Akka rebinding to the same port, since it doesn't unbind
+        # immediately on shutdown
+        self.sc.jvm.System.clearProperty("spark.master.port")
 
     def test_basic_checkpointing(self):
         checkpointDir = NamedTemporaryFile(delete=False)
@@ -41,6 +44,27 @@ class TestCheckpoint(unittest.TestCase):
 
         atexit.register(lambda: shutil.rmtree(checkpointDir.name))
 
+    def test_checkpoint_and_restore(self):
+        checkpointDir = NamedTemporaryFile(delete=False)
+        os.unlink(checkpointDir.name)
+        self.sc.setCheckpointDir(checkpointDir.name)
+
+        parCollection = self.sc.parallelize([1, 2, 3, 4])
+        flatMappedRDD = parCollection.flatMap(lambda x: [x])
+
+        self.assertFalse(flatMappedRDD.isCheckpointed())
+        self.assertIsNone(flatMappedRDD.getCheckpointFile())
+
+        flatMappedRDD.checkpoint()
+        flatMappedRDD.count()  # forces a checkpoint to be computed
+        time.sleep(1)  # 1 second
+
+        self.assertIsNotNone(flatMappedRDD.getCheckpointFile())
+        recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
+        self.assertEquals([1, 2, 3, 4], recovered.collect())
+
+        atexit.register(lambda: shutil.rmtree(checkpointDir.name))
+
 
 if __name__ == "__main__":
     unittest.main()