diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index ec6b209932b5e17dd85743cec791744d2a5b7104..fabe0bec2d3a696f4385254f7af6a64d5ba3fec2 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -23,7 +23,6 @@ import spark.partial.BoundedDouble
 import spark.partial.PartialResult
 import spark.rdd._
 import spark.SparkContext._
-import java.lang.ref.WeakReference
 
 /**
  * Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
@@ -625,20 +624,20 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
 }
 
 private[spark]
-class MappedValuesRDD[K, V, U](prev: WeakReference[RDD[(K, V)]], f: V => U)
-  extends RDD[(K, U)](prev.get) {
+class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U)
+  extends RDD[(K, U)](prev) {
 
-  override def splits = firstParent[(K, V)].splits
+  override def getSplits = firstParent[(K, V)].splits
   override val partitioner = firstParent[(K, V)].partitioner
   override def compute(split: Split, context: TaskContext) =
     firstParent[(K, V)].iterator(split, context).map{ case (k, v) => (k, f(v)) }
 }
 
 private[spark]
-class FlatMappedValuesRDD[K, V, U](prev: WeakReference[RDD[(K, V)]], f: V => TraversableOnce[U])
-  extends RDD[(K, U)](prev.get) {
+class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U])
+  extends RDD[(K, U)](prev) {
 
-  override def splits = firstParent[(K, V)].splits
+  override def getSplits = firstParent[(K, V)].splits
   override val partitioner = firstParent[(K, V)].partitioner
   override def compute(split: Split, context: TaskContext) = {
     firstParent[(K, V)].iterator(split, context).flatMap { case (k, v) => f(v).map(x => (k, x)) }
diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala
index 68416a78d0a197a80221725bd25459f875ffbf59..ede933c9e9a20febcc82282e2c8fe5c4309118c8 100644
--- a/core/src/main/scala/spark/ParallelCollection.scala
+++ b/core/src/main/scala/spark/ParallelCollection.scala
@@ -30,26 +30,30 @@ private[spark] class ParallelCollection[T: ClassManifest](
   extends RDD[T](sc, Nil) {
   // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
   // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
-  // instead. UPDATE: With the new changes to enable checkpointing, this an be done.
+  // instead.
+  // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal.
 
   @transient
-  val splits_ = {
+  var splits_ : Array[Split] = {
     val slices = ParallelCollection.slice(data, numSlices).toArray
     slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray
   }
 
-  override def splits = splits_.asInstanceOf[Array[Split]]
-
+  override def getSplits = splits_.asInstanceOf[Array[Split]]
 
   override def compute(s: Split, context: TaskContext) =
     s.asInstanceOf[ParallelCollectionSplit[T]].iterator
 
-  override def preferredLocations(s: Split): Seq[String] = {
-    locationPrefs.get(splits_.indexOf(s)) match {
+  override def getPreferredLocations(s: Split): Seq[String] = {
+    locationPrefs.get(s.index) match {
       case Some(s) => s
       case _ => Nil
     }
   }
+
+  override def clearDependencies() {
+    splits_ = null
+  }
 }
 
 
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index cf3ed067735c673381488a1f15fa038069b3bf4d..2c3acc1b69aef095e28c3d4cdf9c9ac96c0e9560 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -81,78 +81,34 @@ abstract class RDD[T: ClassManifest](
   def this(@transient oneParent: RDD[_]) =
     this(oneParent.context , List(new OneToOneDependency(oneParent)))
 
-  // Methods that must be implemented by subclasses:
-
-  /** Set of partitions in this RDD. */
-  def splits: Array[Split]
+  // =======================================================================
+  // Methods that should be implemented by subclasses of RDD
+  // =======================================================================
 
   /** Function for computing a given partition. */
   def compute(split: Split, context: TaskContext): Iterator[T]
 
+  /** Set of partitions in this RDD. */
+  protected def getSplits(): Array[Split]
+
   /** How this RDD depends on any parent RDDs. */
-  def dependencies: List[Dependency[_]] = dependencies_
+  protected def getDependencies(): List[Dependency[_]] = dependencies_
 
-  /** Record user function generating this RDD. */
-  private[spark] val origin = Utils.getSparkCallSite
+  /** Optionally overridden by subclasses to specify placement preferences. */
+  protected def getPreferredLocations(split: Split): Seq[String] = Nil
 
   /** Optionally overridden by subclasses to specify how they are partitioned. */
   val partitioner: Option[Partitioner] = None
 
-  /** Optionally overridden by subclasses to specify placement preferences. */
-  def preferredLocations(split: Split): Seq[String] = {
-    if (isCheckpointed) {
-      checkpointRDD.preferredLocations(split)
-    } else {
-      Nil
-    }
-  }
 
-  /** The [[spark.SparkContext]] that this RDD was created on. */
-  def context = sc
 
-  private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
+  // =======================================================================
+  // Methods and fields available on all RDDs
+  // =======================================================================
 
   /** A unique ID for this RDD (within its SparkContext). */
   val id = sc.newRddId()
 
-  // Variables relating to persistence
-  private var storageLevel: StorageLevel = StorageLevel.NONE
-
-  /** Returns the first parent RDD */
-  protected[spark] def firstParent[U: ClassManifest] = dependencies.head.rdd.asInstanceOf[RDD[U]]
-
-  /** Returns the `i` th parent RDD */
-  protected[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]]
-
-  //////////////////////////////////////////////////////////////////////////////
-  // Checkpointing related variables
-  //////////////////////////////////////////////////////////////////////////////
-
-  // override to set this to false to avoid checkpointing an RDD
-  protected val isCheckpointable = true
-
-  // set to true when an RDD is marked for checkpointing
-  protected var shouldCheckpoint = false
-
-  // set to true when checkpointing is in progress
-  protected var isCheckpointInProgress = false
-
-  // set to true after checkpointing is completed
-  protected[spark] var isCheckpointed = false
-
-  // set to the checkpoint file after checkpointing is completed
-  protected[spark] var checkpointFile: String = null
-
-  // set to the HadoopRDD of the checkpoint file
-  protected var checkpointRDD: RDD[T] = null
-
-  // set to the splits of the Hadoop RDD
-  protected var checkpointRDDSplits: Seq[Split] = null
-
-  //////////////////////////////////////////////////////////////////////////////
-  // Methods available on all RDDs
-  //////////////////////////////////////////////////////////////////////////////
-
   /**
    * Set this RDD's storage level to persist its values across operations after the first time
    * it is computed. Can only be called once on each RDD.
@@ -177,81 +133,39 @@ abstract class RDD[T: ClassManifest](
   def getStorageLevel = storageLevel
 
   /**
-   * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir`
-   * (set using setCheckpointDir()) and all references to its parent RDDs will be removed.
-   * This is used to truncate very long lineages. In the current implementation, Spark will save
-   * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done.
-   * Hence, it is strongly recommended to use checkpoint() on RDDs when
-   * (i) Checkpoint() is called before the any job has been executed on this RDD.
-   * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will
-   * require recomputation.
+   * Get the preferred location of a split, taking into account whether the
+   * RDD is checkpointed or not.
    */
-  protected[spark] def checkpoint() {
-    synchronized {
-      if (isCheckpointed || shouldCheckpoint || isCheckpointInProgress) {
-        // do nothing
-      } else if (isCheckpointable) {
-        if (sc.checkpointDir == null) {
-          throw new Exception("Checkpoint directory has not been set in the SparkContext.")
-        }
-        shouldCheckpoint = true
-      } else {
-        throw new Exception(this + " cannot be checkpointed")
-      }
-    }
-  }
-
-  def getCheckpointData(): Any = {
-    synchronized {
-      checkpointFile
+  final def preferredLocations(split: Split): Seq[String] = {
+    if (isCheckpointed) {
+      checkpointData.get.getPreferredLocations(split)
+    } else {
+      getPreferredLocations(split)
     }
   }
 
   /**
-   * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler after
-   * a job using this RDD has completed (therefore the RDD has been materialized and potentially
-   * stored in memory). In case this RDD is not marked for checkpointing, doCheckpoint() is called
-   * recursively on the parent RDDs.
+   * Get the array of splits of this RDD, taking into account whether the
+   * RDD is checkpointed or not.
    */
-  private[spark] def doCheckpoint() {
-    val startCheckpoint = synchronized {
-      if (isCheckpointable && shouldCheckpoint && !isCheckpointInProgress) {
-        isCheckpointInProgress = true
-        true
-      } else {
-        false
-      }
-    }
-
-    if (startCheckpoint) {
-      val rdd = this
-      rdd.checkpointFile = new Path(context.checkpointDir, "rdd-" + id).toString
-      rdd.saveAsObjectFile(checkpointFile)
-      rdd.synchronized {
-        rdd.checkpointRDD = context.objectFile[T](checkpointFile, rdd.splits.size)
-        rdd.checkpointRDDSplits = rdd.checkpointRDD.splits
-        rdd.changeDependencies(rdd.checkpointRDD)
-        rdd.shouldCheckpoint = false
-        rdd.isCheckpointInProgress = false
-        rdd.isCheckpointed = true
-        logInfo("Done checkpointing RDD " + rdd.id + ", " + rdd + ", created RDD " +
-          rdd.checkpointRDD.id + ", " + rdd.checkpointRDD)
-      }
+  final def splits: Array[Split] = {
+    if (isCheckpointed) {
+      checkpointData.get.getSplits
     } else {
-      // Recursively call doCheckpoint() to perform checkpointing on parent RDD if they are marked
-      dependencies.foreach(_.rdd.doCheckpoint())
+      getSplits
     }
   }
 
   /**
-   * Changes the dependencies of this RDD from its original parents to the new
-   * [[spark.rdd.HadoopRDD]] (`newRDD`) created from the checkpoint file. This method must ensure
-   * that all references to the original parent RDDs must be removed to enable the parent RDDs to
-   * be garbage collected. Subclasses of RDD may override this method for implementing their own
-   * changing logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea.
+   * Get the list of dependencies of this RDD, taking into account whether the
+   * RDD is checkpointed or not.
    */
-  protected def changeDependencies(newRDD: RDD[_]) {
-    dependencies_ = List(new OneToOneDependency(newRDD))
+  final def dependencies: List[Dependency[_]] = {
+    if (isCheckpointed) {
+      dependencies_
+    } else {
+      getDependencies
+    }
   }
 
   /**
@@ -261,8 +175,7 @@ abstract class RDD[T: ClassManifest](
    */
   final def iterator(split: Split, context: TaskContext): Iterator[T] = {
     if (isCheckpointed) {
-      // ASSUMPTION: Checkpoint Hadoop RDD will have same number of splits as original
-      checkpointRDD.iterator(checkpointRDDSplits(split.index), context)
+      checkpointData.get.iterator(split, context)
     } else if (storageLevel != StorageLevel.NONE) {
       SparkEnv.get.cacheTracker.getOrCompute[T](this, split, context, storageLevel)
     } else {
@@ -614,18 +527,84 @@ abstract class RDD[T: ClassManifest](
     sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
   }
 
-  @throws(classOf[IOException])
-  private def writeObject(oos: ObjectOutputStream) {
-    synchronized {
-      oos.defaultWriteObject()
+  /**
+   * Mark this RDD for checkpointing. The RDD will be saved to a file inside `checkpointDir`
+   * (set using setCheckpointDir()) and all references to its parent RDDs will be removed.
+   * This is used to truncate very long lineages. In the current implementation, Spark will save
+   * this RDD to a file (using saveAsObjectFile()) after the first job using this RDD is done.
+   * Hence, it is strongly recommended to use checkpoint() on RDDs when
+   * (i) checkpoint() is called before the any job has been executed on this RDD.
+   * (ii) This RDD has been made to persist in memory. Otherwise saving it on a file will
+   * require recomputation.
+   */
+  def checkpoint() {
+    if (checkpointData.isEmpty) {
+      checkpointData = Some(new RDDCheckpointData(this))
+      checkpointData.get.markForCheckpoint()
     }
   }
 
-  @throws(classOf[IOException])
-  private def readObject(ois: ObjectInputStream) {
-    synchronized {
-      ois.defaultReadObject()
-    }
+  /**
+   * Return whether this RDD has been checkpointed or not
+   */
+  def isCheckpointed(): Boolean = {
+    if (checkpointData.isDefined) checkpointData.get.isCheckpointed() else false
+  }
+
+  /**
+   * Gets the name of the file to which this RDD was checkpointed
+   */
+  def getCheckpointFile(): Option[String] = {
+    if (checkpointData.isDefined) checkpointData.get.getCheckpointFile() else None
   }
 
+  // =======================================================================
+  // Other internal methods and fields
+  // =======================================================================
+
+  private var storageLevel: StorageLevel = StorageLevel.NONE
+
+  /** Record user function generating this RDD. */
+  private[spark] val origin = Utils.getSparkCallSite
+
+  private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
+
+  private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None
+
+  /** Returns the first parent RDD */
+  protected[spark] def firstParent[U: ClassManifest] = {
+    dependencies.head.rdd.asInstanceOf[RDD[U]]
+  }
+
+  /** The [[spark.SparkContext]] that this RDD was created on. */
+  def context = sc
+
+  /**
+   * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler
+   * after a job using this RDD has completed (therefore the RDD has been materialized and
+   * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs.
+   */
+  protected[spark] def doCheckpoint() {
+    if (checkpointData.isDefined) checkpointData.get.doCheckpoint()
+    dependencies.foreach(_.rdd.doCheckpoint())
+  }
+
+  /**
+   * Changes the dependencies of this RDD from its original parents to the new RDD
+   * (`newRDD`) created from the checkpoint file.
+   */
+  protected[spark] def changeDependencies(newRDD: RDD[_]) {
+    clearDependencies()
+    dependencies_ = List(new OneToOneDependency(newRDD))
+  }
+
+  /**
+   * Clears the dependencies of this RDD. This method must ensure that all references
+   * to the original parent RDDs is removed to enable the parent RDDs to be garbage
+   * collected. Subclasses of RDD may override this method for implementing their own cleaning
+   * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea.
+   */
+  protected[spark] def clearDependencies() {
+    dependencies_ = null
+  }
 }
diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala
new file mode 100644
index 0000000000000000000000000000000000000000..7af830940faf2514f4bf4f659ce40aece6a89a7f
--- /dev/null
+++ b/core/src/main/scala/spark/RDDCheckpointData.scala
@@ -0,0 +1,97 @@
+package spark
+
+import org.apache.hadoop.fs.Path
+import rdd.{CheckpointRDD, CoalescedRDD}
+import scheduler.{ResultTask, ShuffleMapTask}
+
+/**
+ * Enumeration to manage state transitions of an RDD through checkpointing
+ * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ]
+ */
+private[spark] object CheckpointState extends Enumeration {
+  type CheckpointState = Value
+  val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value
+}
+
+/**
+ * This class contains all the information of the regarding RDD checkpointing.
+ */
+private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
+extends Logging with Serializable {
+
+  import CheckpointState._
+
+  var cpState = Initialized
+  @transient var cpFile: Option[String] = None
+  @transient var cpRDD: Option[RDD[T]] = None
+
+  // Mark the RDD for checkpointing
+  def markForCheckpoint() {
+    RDDCheckpointData.synchronized {
+      if (cpState == Initialized) cpState = MarkedForCheckpoint
+    }
+  }
+
+  // Is the RDD already checkpointed
+  def isCheckpointed(): Boolean = {
+    RDDCheckpointData.synchronized { cpState == Checkpointed }
+  }
+
+  // Get the file to which this RDD was checkpointed to as an Option
+  def getCheckpointFile(): Option[String] = {
+    RDDCheckpointData.synchronized { cpFile }
+  }
+
+  // Do the checkpointing of the RDD. Called after the first job using that RDD is over.
+  def doCheckpoint() {
+    // If it is marked for checkpointing AND checkpointing is not already in progress,
+    // then set it to be in progress, else return
+    RDDCheckpointData.synchronized {
+      if (cpState == MarkedForCheckpoint) {
+        cpState = CheckpointingInProgress
+      } else {
+        return
+      }
+    }
+
+    // Save to file, and reload it as an RDD
+    val path = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString
+    rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _)
+    val newRDD = new CheckpointRDD[T](rdd.context, path)
+
+    // Change the dependencies and splits of the RDD
+    RDDCheckpointData.synchronized {
+      cpFile = Some(path)
+      cpRDD = Some(newRDD)
+      rdd.changeDependencies(newRDD)
+      cpState = Checkpointed
+      RDDCheckpointData.checkpointCompleted()
+      logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id)
+    }
+  }
+
+  // Get preferred location of a split after checkpointing
+  def getPreferredLocations(split: Split) = {
+    RDDCheckpointData.synchronized {
+      cpRDD.get.preferredLocations(split)
+    }
+  }
+
+  def getSplits: Array[Split] = {
+    RDDCheckpointData.synchronized {
+      cpRDD.get.splits
+    }
+  }
+
+  // Get iterator. This is called at the worker nodes.
+  def iterator(split: Split, context: TaskContext): Iterator[T] = {
+    rdd.firstParent[T].iterator(split, context)
+  }
+}
+
+private[spark] object RDDCheckpointData {
+  def checkpointCompleted() {
+    ShuffleMapTask.clearCache()
+    ResultTask.clearCache()
+  }
+}
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 70257193cf415b7aaa27e07dce8c5cbe0c77dd99..caa9a1794b6c70558a2e393d455435ff6e2f8fa9 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -37,12 +37,8 @@ import spark.broadcast._
 import spark.deploy.LocalSparkCluster
 import spark.partial.ApproximateEvaluator
 import spark.partial.PartialResult
-import spark.rdd.HadoopRDD
-import spark.rdd.NewHadoopRDD
-import spark.rdd.UnionRDD
-import spark.scheduler.ShuffleMapTask
-import spark.scheduler.DAGScheduler
-import spark.scheduler.TaskScheduler
+import rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD}
+import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler}
 import spark.scheduler.local.LocalScheduler
 import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
 import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
@@ -376,6 +372,13 @@ class SparkContext(
       .flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes))
   }
 
+
+  protected[spark] def checkpointFile[T: ClassManifest](
+      path: String
+    ): RDD[T] = {
+    new CheckpointRDD[T](this, path)
+  }
+
   /** Build the union of a list of RDDs. */
   def union[T: ClassManifest](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds)
 
@@ -494,6 +497,7 @@ class SparkContext(
     clearJars()
     SparkEnv.set(null)
     ShuffleMapTask.clearCache()
+    ResultTask.clearCache()
     logInfo("Successfully stopped SparkContext")
   }
 
@@ -629,10 +633,6 @@ class SparkContext(
  */
 object SparkContext {
 
-  // TODO: temporary hack for using HDFS as input in streaing
-  var inputFile: String = null
-  var idealPartitions: Int = 1
-
   implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
     def addInPlace(t1: Double, t2: Double): Double = t1 + t2
     def zero(initialValue: Double) = 0.0
@@ -728,9 +728,6 @@ object SparkContext {
 
   /** Find the JAR that contains the class of a particular object */
   def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass)
-
-  implicit def rddToWeakRefRDD[T: ClassManifest](rdd: RDD[T]) = new WeakReference(rdd)
-
 }
 
 
diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala
index 61bc5c90bab5151c8d4fee364b3ddd949df877f4..b1095a52b4b92a82398ac090be45f37492a9c2ed 100644
--- a/core/src/main/scala/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/spark/rdd/BlockRDD.scala
@@ -1,10 +1,8 @@
 package spark.rdd
 
 import scala.collection.mutable.HashMap
-
 import spark.{RDD, SparkContext, SparkEnv, Split, TaskContext}
 
-
 private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split {
   val index = idx
 }
@@ -14,7 +12,7 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St
   extends RDD[T](sc, Nil) {
 
   @transient
-  val splits_ = (0 until blockIds.size).map(i => {
+  var splits_ : Array[Split] = (0 until blockIds.size).map(i => {
     new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split]
   }).toArray
 
@@ -26,7 +24,7 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St
     HashMap(blockIds.zip(locations):_*)
   }
 
-  override def splits = splits_
+  override def getSplits = splits_
 
   override def compute(split: Split, context: TaskContext): Iterator[T] = {
     val blockManager = SparkEnv.get.blockManager
@@ -38,12 +36,11 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St
     }
   }
 
-  override def preferredLocations(split: Split) = {
-    if (isCheckpointed) {
-      checkpointRDD.preferredLocations(split)
-    } else {
-      locations_(split.asInstanceOf[BlockRDDSplit].blockId)
-    }
+  override def getPreferredLocations(split: Split) =
+    locations_(split.asInstanceOf[BlockRDDSplit].blockId)
+
+  override def clearDependencies() {
+    splits_ = null
   }
 }
 
diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala
index bc11b60e052ac65a94937f1c6f06aeb3ffbba1d5..79e7c24e7c749531e4f80cac26ff923c43c1f178 100644
--- a/core/src/main/scala/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala
@@ -1,13 +1,28 @@
 package spark.rdd
 
-import java.lang.ref.WeakReference
-
+import java.io.{ObjectOutputStream, IOException}
 import spark.{OneToOneDependency, NarrowDependency, RDD, SparkContext, Split, TaskContext}
 
 
 private[spark]
-class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable {
+class CartesianSplit(
+    idx: Int,
+    @transient rdd1: RDD[_],
+    @transient rdd2: RDD[_],
+    s1Index: Int,
+    s2Index: Int
+  ) extends Split {
+  var s1 = rdd1.splits(s1Index)
+  var s2 = rdd2.splits(s2Index)
   override val index: Int = idx
+
+  @throws(classOf[IOException])
+  private def writeObject(oos: ObjectOutputStream) {
+    // Update the reference to parent split at the time of task serialization
+    s1 = rdd1.splits(s1Index)
+    s2 = rdd2.splits(s2Index)
+    oos.defaultWriteObject()
+  }
 }
 
 private[spark]
@@ -26,20 +41,16 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
     val array = new Array[Split](rdd1.splits.size * rdd2.splits.size)
     for (s1 <- rdd1.splits; s2 <- rdd2.splits) {
       val idx = s1.index * numSplitsInRdd2 + s2.index
-      array(idx) = new CartesianSplit(idx, s1, s2)
+      array(idx) = new CartesianSplit(idx, rdd1, rdd2, s1.index, s2.index)
     }
     array
   }
 
-  override def splits = splits_
+  override def getSplits = splits_
 
-  override def preferredLocations(split: Split) = {
-    if (isCheckpointed) {
-      checkpointRDD.preferredLocations(split)
-    } else {
-      val currSplit = split.asInstanceOf[CartesianSplit]
-      rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
-    }
+  override def getPreferredLocations(split: Split) = {
+    val currSplit = split.asInstanceOf[CartesianSplit]
+    rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
   }
 
   override def compute(split: Split, context: TaskContext) = {
@@ -57,11 +68,11 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
     }
   )
 
-  override def dependencies = deps_
+  override def getDependencies = deps_
 
-  override protected def changeDependencies(newRDD: RDD[_]) {
-    deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]]))
-    splits_ = newRDD.splits
+  override def clearDependencies() {
+    deps_ = Nil
+    splits_ = null
     rdd1 = null
     rdd2 = null
   }
diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
new file mode 100644
index 0000000000000000000000000000000000000000..1a88d402c38e2606a8b6089c9a6328116c08dc3d
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
@@ -0,0 +1,124 @@
+package spark.rdd
+
+import spark._
+import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io.{NullWritable, BytesWritable}
+import org.apache.hadoop.util.ReflectionUtils
+import org.apache.hadoop.fs.Path
+import java.io.{File, IOException, EOFException}
+import java.text.NumberFormat
+
+private[spark] class CheckpointRDDSplit(idx: Int, val splitFile: String) extends Split {
+  override val index: Int = idx
+}
+
+class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String)
+  extends RDD[T](sc, Nil) {
+
+  @transient val path = new Path(checkpointPath)
+  @transient val fs = path.getFileSystem(new Configuration())
+
+  @transient val splits_ : Array[Split] = {
+    val splitFiles = fs.listStatus(path).map(_.getPath.toString).filter(_.contains("part-")).sorted
+    splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray
+  }
+
+  checkpointData = Some(new RDDCheckpointData[T](this))
+  checkpointData.get.cpFile = Some(checkpointPath)
+
+  override def getSplits = splits_
+
+  override def getPreferredLocations(split: Split): Seq[String] = {
+    val status = fs.getFileStatus(path)
+    val locations = fs.getFileBlockLocations(status, 0, status.getLen)
+    locations.firstOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
+  }
+
+  override def compute(split: Split, context: TaskContext): Iterator[T] = {
+    CheckpointRDD.readFromFile(split.asInstanceOf[CheckpointRDDSplit].splitFile, context)
+  }
+
+  override def checkpoint() {
+    // Do nothing. Hadoop RDD should not be checkpointed.
+  }
+}
+
+private[spark] object CheckpointRDD extends Logging {
+
+  def splitIdToFileName(splitId: Int): String = {
+    val numfmt = NumberFormat.getInstance()
+    numfmt.setMinimumIntegerDigits(5)
+    numfmt.setGroupingUsed(false)
+    "part-"  + numfmt.format(splitId)
+  }
+
+  def writeToFile[T](path: String, blockSize: Int = -1)(context: TaskContext, iterator: Iterator[T]) {
+    val outputDir = new Path(path)
+    val fs = outputDir.getFileSystem(new Configuration())
+
+    val finalOutputName = splitIdToFileName(context.splitId)
+    val finalOutputPath = new Path(outputDir, finalOutputName)
+    val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptId)
+
+    if (fs.exists(tempOutputPath)) {
+      throw new IOException("Checkpoint failed: temporary path " +
+        tempOutputPath + " already exists")
+    }
+    val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+
+    val fileOutputStream = if (blockSize < 0) {
+      fs.create(tempOutputPath, false, bufferSize)
+    } else {
+      // This is mainly for testing purpose
+      fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize)
+    }
+    val serializer = SparkEnv.get.serializer.newInstance()
+    val serializeStream = serializer.serializeStream(fileOutputStream)
+    serializeStream.writeAll(iterator)
+    fileOutputStream.close()
+
+    if (!fs.rename(tempOutputPath, finalOutputPath)) {
+      if (!fs.delete(finalOutputPath, true)) {
+        throw new IOException("Checkpoint failed: failed to delete earlier output of task "
+          + context.attemptId);
+      }
+      if (!fs.rename(tempOutputPath, finalOutputPath)) {
+        throw new IOException("Checkpoint failed: failed to save output of task: "
+          + context.attemptId)
+      }
+    }
+  }
+
+  def readFromFile[T](path: String, context: TaskContext): Iterator[T] = {
+    val inputPath = new Path(path)
+    val fs = inputPath.getFileSystem(new Configuration())
+    val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+    val fileInputStream = fs.open(inputPath, bufferSize)
+    val serializer = SparkEnv.get.serializer.newInstance()
+    val deserializeStream = serializer.deserializeStream(fileInputStream)
+
+    // Register an on-task-completion callback to close the input stream.
+    context.addOnCompleteCallback(() => deserializeStream.close())
+
+    deserializeStream.asIterator.asInstanceOf[Iterator[T]]
+  }
+
+  // Test whether CheckpointRDD generate expected number of splits despite
+  // each split file having multiple blocks. This needs to be run on a
+  // cluster (mesos or standalone) using HDFS.
+  def main(args: Array[String]) {
+    import spark._
+
+    val Array(cluster, hdfsPath) = args
+    val sc = new SparkContext(cluster, "CheckpointRDD Test")
+    val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
+    val path = new Path(hdfsPath, "temp")
+    val fs = path.getFileSystem(new Configuration())
+    sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 10) _)
+    val cpRDD = new CheckpointRDD[Int](sc, path.toString)
+    assert(cpRDD.splits.length == rdd.splits.length, "Number of splits is not the same")
+    assert(cpRDD.collect.toList == rdd.collect.toList, "Data of splits not the same")
+    fs.delete(path)
+  }
+}
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index ef8673909bd8a5bc553fe3c7f9f0724cdc40be60..759bea5e9dec822c43a31df6cdb837fd9bde9688 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -11,16 +11,17 @@ import spark.{Dependency, OneToOneDependency, ShuffleDependency}
 
 private[spark] sealed trait CoGroupSplitDep extends Serializable
 
-private[spark]
-case class NarrowCoGroupSplitDep(rdd: RDD[_], splitIndex: Int, var split: Split = null)
-  extends CoGroupSplitDep {
+private[spark] case class NarrowCoGroupSplitDep(
+    rdd: RDD[_],
+    splitIndex: Int,
+    var split: Split
+  ) extends CoGroupSplitDep {
+
   @throws(classOf[IOException])
   private def writeObject(oos: ObjectOutputStream) {
-    rdd.synchronized {
-      // Update the reference to parent split at the time of task serialization
-      split = rdd.splits(splitIndex)
-      oos.defaultWriteObject()
-    }
+    // Update the reference to parent split at the time of task serialization
+    split = rdd.splits(splitIndex)
+    oos.defaultWriteObject()
   }
 }
 
@@ -39,7 +40,6 @@ private[spark] class CoGroupAggregator
     { (b1, b2) => b1 ++ b2 })
   with Serializable
 
-
 class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
   extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) with Logging {
 
@@ -49,19 +49,19 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
   var deps_ = {
     val deps = new ArrayBuffer[Dependency[_]]
     for ((rdd, index) <- rdds.zipWithIndex) {
-      val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
-      if (mapSideCombinedRDD.partitioner == Some(part)) {
-        logInfo("Adding one-to-one dependency with " + mapSideCombinedRDD)
-        deps += new OneToOneDependency(mapSideCombinedRDD)
+      if (rdd.partitioner == Some(part)) {
+        logInfo("Adding one-to-one dependency with " + rdd)
+        deps += new OneToOneDependency(rdd)
       } else {
         logInfo("Adding shuffle dependency with " + rdd)
+        val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
         deps += new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part)
       }
     }
     deps.toList
   }
 
-  override def dependencies = deps_
+  override def getDependencies = deps_
 
   @transient
   var splits_ : Array[Split] = {
@@ -72,15 +72,15 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
           case s: ShuffleDependency[_, _] =>
             new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep
           case _ =>
-            new NarrowCoGroupSplitDep(r, i): CoGroupSplitDep
+            new NarrowCoGroupSplitDep(r, i, r.splits(i)): CoGroupSplitDep
         }
       }.toList)
     }
     array
   }
 
-  override def splits = splits_
-
+  override def getSplits = splits_
+  
   override val partitioner = Some(part)
 
   override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
@@ -111,9 +111,9 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
     map.iterator
   }
 
-  override protected def changeDependencies(newRDD: RDD[_]) {
-    deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]]))
-    splits_ = newRDD.splits
+  override def clearDependencies() {
+    deps_ = null
+    splits_ = null
     rdds = null
   }
 }
diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
index c5e2300d2638c2cbdbdecd7531e6b356036fb57b..167755bbba2674d73e91bd7802a2d89bbd3e8e8c 100644
--- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
@@ -1,11 +1,22 @@
 package spark.rdd
 
-import java.lang.ref.WeakReference
-
 import spark.{Dependency, OneToOneDependency, NarrowDependency, RDD, Split, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
 
+private[spark] case class CoalescedRDDSplit(
+    index: Int,
+    @transient rdd: RDD[_],
+    parentsIndices: Array[Int]
+  ) extends Split {
+  var parents: Seq[Split] = parentsIndices.map(rdd.splits(_))
 
-private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split
+  @throws(classOf[IOException])
+  private def writeObject(oos: ObjectOutputStream) {
+    // Update the reference to parent split at the time of task serialization
+    parents = parentsIndices.map(rdd.splits(_))
+    oos.defaultWriteObject()
+  }
+}
 
 /**
  * Coalesce the partitions of a parent RDD (`prev`) into fewer partitions, so that each partition of
@@ -23,17 +34,17 @@ class CoalescedRDD[T: ClassManifest](
   @transient var splits_ : Array[Split] = {
     val prevSplits = prev.splits
     if (prevSplits.length < maxPartitions) {
-      prevSplits.zipWithIndex.map{ case (s, idx) => new CoalescedRDDSplit(idx, Array(s)) }
+      prevSplits.map(_.index).map{idx => new CoalescedRDDSplit(idx, prev, Array(idx)) }
     } else {
       (0 until maxPartitions).map { i =>
         val rangeStart = (i * prevSplits.length) / maxPartitions
         val rangeEnd = ((i + 1) * prevSplits.length) / maxPartitions
-        new CoalescedRDDSplit(i, prevSplits.slice(rangeStart, rangeEnd))
+        new CoalescedRDDSplit(i, prev, (rangeStart until rangeEnd).toArray)
       }.toArray
     }
   }
 
-  override def splits = splits_
+  override def getSplits = splits_
 
   override def compute(split: Split, context: TaskContext): Iterator[T] = {
     split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { parentSplit =>
@@ -44,15 +55,15 @@ class CoalescedRDD[T: ClassManifest](
   var deps_ : List[Dependency[_]] = List(
     new NarrowDependency(prev) {
       def getParents(id: Int): Seq[Int] =
-        splits(id).asInstanceOf[CoalescedRDDSplit].parents.map(_.index)
+        splits(id).asInstanceOf[CoalescedRDDSplit].parentsIndices
     }
   )
 
-  override def dependencies = deps_
+  override def getDependencies() = deps_
 
-  override protected def changeDependencies(newRDD: RDD[_]) {
-    deps_ = List(new OneToOneDependency(newRDD))
-    splits_ = newRDD.splits
+  override def clearDependencies() {
+    deps_ = Nil
+    splits_ = null
     prev = null
   }
 }
diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala
index 70c4be7903a0aabc4adf01a4d6396decfc5cf9a7..b80e9bc07b45867eb3b03598144cd8892c706f5d 100644
--- a/core/src/main/scala/spark/rdd/FilteredRDD.scala
+++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala
@@ -1,14 +1,14 @@
 package spark.rdd
 
-import java.lang.ref.WeakReference
-
 import spark.{OneToOneDependency, RDD, Split, TaskContext}
 
-private[spark]
-class FilteredRDD[T: ClassManifest](prev: WeakReference[RDD[T]], f: T => Boolean)
-  extends RDD[T](prev.get) {
+private[spark] class FilteredRDD[T: ClassManifest](
+    prev: RDD[T],
+    f: T => Boolean)
+  extends RDD[T](prev) {
+
+  override def getSplits = firstParent[T].splits
 
-  override def splits = firstParent[T].splits
   override def compute(split: Split, context: TaskContext) =
     firstParent[T].iterator(split, context).filter(f)
-}
\ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
index 1ebbb4c9bd70a7ba26eb4b2d7c266a4093f0a6ee..1b604c66e2fa52c849c38b1f16738b0949a7b405 100644
--- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
@@ -1,17 +1,16 @@
 package spark.rdd
 
-import java.lang.ref.WeakReference
-
 import spark.{RDD, Split, TaskContext}
 
 
 private[spark]
 class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
-    prev: WeakReference[RDD[T]],
+    prev: RDD[T],
     f: T => TraversableOnce[U])
-  extends RDD[U](prev.get) {
+  extends RDD[U](prev) {
+
+  override def getSplits = firstParent[T].splits
 
-  override def splits = firstParent[T].splits
   override def compute(split: Split, context: TaskContext) =
     firstParent[T].iterator(split, context).flatMap(f)
 }
diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala
index 43661ae3f8c99b96bec0e5e1def3efa728a78e87..051bffed192bc39a9ac3084ffbec49acdb524340 100644
--- a/core/src/main/scala/spark/rdd/GlommedRDD.scala
+++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala
@@ -1,13 +1,12 @@
 package spark.rdd
 
-import java.lang.ref.WeakReference
-
 import spark.{RDD, Split, TaskContext}
 
-private[spark]
-class GlommedRDD[T: ClassManifest](prev: WeakReference[RDD[T]])
-  extends RDD[Array[T]](prev.get) {
-  override def splits = firstParent[T].splits
+private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T])
+  extends RDD[Array[T]](prev) {
+
+  override def getSplits = firstParent[T].splits
+
   override def compute(split: Split, context: TaskContext) =
     Array(firstParent[T].iterator(split, context).toArray).iterator
 }
diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala
index 7b5f8ac3e981aa13f9efeec3843cd3e052d4653c..f547f53812661da253353384e43dbe2702a3cb68 100644
--- a/core/src/main/scala/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala
@@ -22,9 +22,8 @@ import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskCo
  * A Spark split class that wraps around a Hadoop InputSplit.
  */
 private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit)
-  extends Split
-  with Serializable {
-
+  extends Split {
+  
   val inputSplit = new SerializableWritable[InputSplit](s)
 
   override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt
@@ -64,7 +63,7 @@ class HadoopRDD[K, V](
       .asInstanceOf[InputFormat[K, V]]
   }
 
-  override def splits = splits_
+  override def getSplits = splits_
 
   override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
     val split = theSplit.asInstanceOf[HadoopSplit]
@@ -110,11 +109,13 @@ class HadoopRDD[K, V](
     }
   }
 
-  override def preferredLocations(split: Split) = {
+  override def getPreferredLocations(split: Split) = {
     // TODO: Filtering out "localhost" in case of file:// URLs
     val hadoopSplit = split.asInstanceOf[HadoopSplit]
     hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost")
   }
 
-  override val isCheckpointable = false
+  override def checkpoint() {
+    // Do nothing. Hadoop RDD should not be checkpointed.
+  }
 }
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
index 991f4be73f347ecebac41baabe410ff306d220e8..073f7d7d2aad251c4240bd665b9fc02e90eec8a8 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
@@ -1,24 +1,20 @@
 package spark.rdd
 
-
-import spark.OneToOneDependency
-import spark.RDD
-import spark.Split
-import java.lang.ref.WeakReference
-
 import spark.{RDD, Split, TaskContext}
 
 
 private[spark]
 class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
-    prev: WeakReference[RDD[T]],
+    prev: RDD[T],
     f: Iterator[T] => Iterator[U],
     preservesPartitioning: Boolean = false)
-  extends RDD[U](prev.get) {
+  extends RDD[U](prev) {
+
+  override val partitioner =
+    if (preservesPartitioning) firstParent[T].partitioner else None
 
-  override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
+  override def getSplits = firstParent[T].splits
 
-  override def splits = firstParent[T].splits
   override def compute(split: Split, context: TaskContext) =
     f(firstParent[T].iterator(split, context))
 }
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
index e2e7753cded75a78f4a35fc92f3776621a0e0aa0..2ddc3d01b647a573d85aa0c3622341fdd3ed1adb 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
@@ -1,7 +1,5 @@
 package spark.rdd
 
-import java.lang.ref.WeakReference
-
 import spark.{RDD, Split, TaskContext}
 
 
@@ -12,13 +10,15 @@ import spark.{RDD, Split, TaskContext}
  */
 private[spark]
 class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
-    prev: WeakReference[RDD[T]],
+    prev: RDD[T],
     f: (Int, Iterator[T]) => Iterator[U],
-    preservesPartitioning: Boolean)
-  extends RDD[U](prev.get) {
+    preservesPartitioning: Boolean
+  ) extends RDD[U](prev) {
+
+  override def getSplits = firstParent[T].splits
+
+  override val partitioner = if (preservesPartitioning) prev.partitioner else None
 
-  override val partitioner = if (preservesPartitioning) prev.get.partitioner else None
-  override def splits = firstParent[T].splits
   override def compute(split: Split, context: TaskContext) =
     f(split.index, firstParent[T].iterator(split, context))
 }
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala
index 986cf35291ad35ea2e0a0605d238486a953e30cf..c6ceb272cdc7bfd06478e4b278534937665321d1 100644
--- a/core/src/main/scala/spark/rdd/MappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/MappedRDD.scala
@@ -1,17 +1,15 @@
 package spark.rdd
 
-import java.lang.ref.WeakReference
-
 import spark.{RDD, Split, TaskContext}
 
-
 private[spark]
 class MappedRDD[U: ClassManifest, T: ClassManifest](
-    prev: WeakReference[RDD[T]],
+    prev: RDD[T],
     f: T => U)
-  extends RDD[U](prev.get) {
+  extends RDD[U](prev) {
+
+  override def getSplits = firstParent[T].splits
 
-  override def splits = firstParent[T].splits
   override def compute(split: Split, context: TaskContext) =
     firstParent[T].iterator(split, context).map(f)
 }
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
index c7cc8d468533e5f39c5ba7bb34625992649047d3..bb22db073c50f736c4637e34da2bee0c340a4376 100644
--- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
@@ -52,7 +52,7 @@ class NewHadoopRDD[K, V](
     result
   }
 
-  override def splits = splits_
+  override def getSplits = splits_
 
   override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
     val split = theSplit.asInstanceOf[NewHadoopSplit]
@@ -87,10 +87,8 @@ class NewHadoopRDD[K, V](
     }
   }
 
-  override def preferredLocations(split: Split) = {
+  override def getPreferredLocations(split: Split) = {
     val theSplit = split.asInstanceOf[NewHadoopSplit]
     theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost")
   }
-
-  override val isCheckpointable = false
 }
diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala
index 076c6a64a0531de760da1fdef034d097d3c77a12..6631f83510cb6851a9a6002792415d6ca556f1c8 100644
--- a/core/src/main/scala/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/spark/rdd/PipedRDD.scala
@@ -1,7 +1,6 @@
 package spark.rdd
 
 import java.io.PrintWriter
-import java.lang.ref.WeakReference
 import java.util.StringTokenizer
 
 import scala.collection.Map
@@ -17,18 +16,18 @@ import spark.{RDD, SparkEnv, Split, TaskContext}
  * (printing them one per line) and returns the output as a collection of strings.
  */
 class PipedRDD[T: ClassManifest](
-    prev: WeakReference[RDD[T]],
+    prev: RDD[T],
     command: Seq[String],
     envVars: Map[String, String])
-  extends RDD[String](prev.get) {
+  extends RDD[String](prev) {
 
-  def this(prev: WeakReference[RDD[T]], command: Seq[String]) = this(prev, command, Map())
+  def this(prev: RDD[T], command: Seq[String]) = this(prev, command, Map())
 
   // Similar to Runtime.exec(), if we are given a single string, split it into words
   // using a standard StringTokenizer (i.e. by spaces)
-  def this(prev: WeakReference[RDD[T]], command: String) = this(prev, PipedRDD.tokenize(command))
+  def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command))
 
-  override def splits = firstParent[T].splits
+  override def getSplits = firstParent[T].splits
 
   override def compute(split: Split, context: TaskContext): Iterator[String] = {
     val pb = new ProcessBuilder(command)
diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala
index 0dc83c127fb26a8048bbc20fbb531fec8301e660..1bc9c96112dfad644aa6538b351005518adb12f5 100644
--- a/core/src/main/scala/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/spark/rdd/SampledRDD.scala
@@ -1,6 +1,5 @@
 package spark.rdd
 
-import java.lang.ref.WeakReference
 import java.util.Random
 
 import cern.jet.random.Poisson
@@ -14,21 +13,21 @@ class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Seriali
 }
 
 class SampledRDD[T: ClassManifest](
-    prev: WeakReference[RDD[T]],
-    withReplacement: Boolean,
+    prev: RDD[T],
+    withReplacement: Boolean, 
     frac: Double,
     seed: Int)
-  extends RDD[T](prev.get) {
+  extends RDD[T](prev) {
 
   @transient
-  val splits_ = {
+  var splits_ : Array[Split] = {
     val rg = new Random(seed)
     firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt))
   }
 
-  override def splits = splits_.asInstanceOf[Array[Split]]
+  override def getSplits = splits_.asInstanceOf[Array[Split]]
 
-  override def preferredLocations(split: Split) =
+  override def getPreferredLocations(split: Split) =
     firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev)
 
   override def compute(splitIn: Split, context: TaskContext) = {
@@ -50,4 +49,8 @@ class SampledRDD[T: ClassManifest](
       firstParent[T].iterator(split.prev, context).filter(x => (rand.nextDouble <= frac))
     }
   }
+
+  override def clearDependencies() {
+    splits_ = null
+  }
 }
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index 7d592ffc5eddeb9b213c844c7507c56b5f0a8d82..f40b56be64ca4f6cd9f2d20fe0f136cbf20944ac 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -1,10 +1,7 @@
 package spark.rdd
 
-import java.lang.ref.WeakReference
-
 import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Split, TaskContext}
 
-
 private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
   override val index = idx
   override def hashCode(): Int = idx
@@ -12,30 +9,29 @@ private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
 
 /**
  * The resulting RDD from a shuffle (e.g. repartitioning of data).
- * @param parent the parent RDD.
+ * @param prev the parent RDD.
  * @param part the partitioner used to partition the RDD
  * @tparam K the key class.
  * @tparam V the value class.
  */
 class ShuffledRDD[K, V](
-    @transient prev: WeakReference[RDD[(K, V)]],
+    prev: RDD[(K, V)],
     part: Partitioner)
-  extends RDD[(K, V)](prev.get.context, List(new ShuffleDependency(prev.get, part))) {
+  extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part))) {
 
   override val partitioner = Some(part)
 
   @transient
   var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
 
-  override def splits = splits_
+  override def getSplits = splits_
 
   override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = {
     val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
     SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index)
   }
 
-  override def changeDependencies(newRDD: RDD[_]) {
-    dependencies_ = Nil
-    splits_ = newRDD.splits
+  override def clearDependencies() {
+    splits_ = null
   }
 }
diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala
index 21965d3f5d282b53d8ff24e73f9218c02502b7d4..24a085df02aaef9fa0c6cdc7fdb4bd211ccd5c16 100644
--- a/core/src/main/scala/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/spark/rdd/UnionRDD.scala
@@ -1,20 +1,26 @@
 package spark.rdd
 
-import java.lang.ref.WeakReference
 import scala.collection.mutable.ArrayBuffer
-import spark.{Dependency, OneToOneDependency, RangeDependency, RDD, SparkContext, Split, TaskContext}
+import spark.{Dependency, RangeDependency, RDD, SparkContext, Split, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
 
+private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int)
+  extends Split {
 
-private[spark] class UnionSplit[T: ClassManifest](
-    idx: Int,
-    rdd: RDD[T],
-    split: Split)
-  extends Split
-  with Serializable {
+  var split: Split = rdd.splits(splitIndex)
 
   def iterator(context: TaskContext) = rdd.iterator(split, context)
+
   def preferredLocations() = rdd.preferredLocations(split)
+
   override val index: Int = idx
+
+  @throws(classOf[IOException])
+  private def writeObject(oos: ObjectOutputStream) {
+    // Update the reference to parent split at the time of task serialization
+    split = rdd.splits(splitIndex)
+    oos.defaultWriteObject()
+  }
 }
 
 class UnionRDD[T: ClassManifest](
@@ -27,13 +33,13 @@ class UnionRDD[T: ClassManifest](
     val array = new Array[Split](rdds.map(_.splits.size).sum)
     var pos = 0
     for (rdd <- rdds; split <- rdd.splits) {
-      array(pos) = new UnionSplit(pos, rdd, split)
+      array(pos) = new UnionSplit(pos, rdd, split.index)
       pos += 1
     }
     array
   }
 
-  override def splits = splits_
+  override def getSplits = splits_
 
   @transient var deps_ = {
     val deps = new ArrayBuffer[Dependency[_]]
@@ -45,22 +51,17 @@ class UnionRDD[T: ClassManifest](
     deps.toList
   }
 
-  override def dependencies = deps_
+  override def getDependencies = deps_
 
   override def compute(s: Split, context: TaskContext): Iterator[T] =
     s.asInstanceOf[UnionSplit[T]].iterator(context)
 
-  override def preferredLocations(s: Split): Seq[String] = {
-    if (isCheckpointed) {
-      checkpointRDD.preferredLocations(s)
-    } else {
-      s.asInstanceOf[UnionSplit[T]].preferredLocations()
-    }
-  }
+  override def getPreferredLocations(s: Split): Seq[String] =
+    s.asInstanceOf[UnionSplit[T]].preferredLocations()
 
-  override protected def changeDependencies(newRDD: RDD[_]) {
-    deps_ = List(new OneToOneDependency(newRDD))
-    splits_ = newRDD.splits
+  override def clearDependencies() {
+    deps_ = null
+    splits_ = null
     rdds = null
   }
 }
diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala
index 33b64e2d249a5a04d77bd37baea478095d28da6e..16e6cc0f1ba93858645c2f90416ad443ced7b472 100644
--- a/core/src/main/scala/spark/rdd/ZippedRDD.scala
+++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala
@@ -1,55 +1,66 @@
 package spark.rdd
 
 import spark.{OneToOneDependency, RDD, SparkContext, Split, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
 
 
 private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest](
     idx: Int,
-    rdd1: RDD[T],
-    rdd2: RDD[U],
-    split1: Split,
-    split2: Split)
-  extends Split
-  with Serializable {
+    @transient rdd1: RDD[T],
+    @transient rdd2: RDD[U]
+  ) extends Split {
 
-  def iterator(context: TaskContext): Iterator[(T, U)] =
-    rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context))
+  var split1 = rdd1.splits(idx)
+  var split2 = rdd1.splits(idx)
+  override val index: Int = idx
 
-  def preferredLocations(): Seq[String] =
-    rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
+  def splits = (split1, split2)
 
-  override val index: Int = idx
+  @throws(classOf[IOException])
+  private def writeObject(oos: ObjectOutputStream) {
+    // Update the reference to parent split at the time of task serialization
+    split1 = rdd1.splits(idx)
+    split2 = rdd2.splits(idx)
+    oos.defaultWriteObject()
+  }
 }
 
 class ZippedRDD[T: ClassManifest, U: ClassManifest](
     sc: SparkContext,
-    @transient rdd1: RDD[T],
-    @transient rdd2: RDD[U])
-  extends RDD[(T, U)](sc, Nil)
+    var rdd1: RDD[T],
+    var rdd2: RDD[U])
+  extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2)))
   with Serializable {
 
   // TODO: FIX THIS.
 
   @transient
-  val splits_ : Array[Split] = {
+  var splits_ : Array[Split] = {
     if (rdd1.splits.size != rdd2.splits.size) {
       throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
     }
     val array = new Array[Split](rdd1.splits.size)
     for (i <- 0 until rdd1.splits.size) {
-      array(i) = new ZippedSplit(i, rdd1, rdd2, rdd1.splits(i), rdd2.splits(i))
+      array(i) = new ZippedSplit(i, rdd1, rdd2)
     }
     array
   }
 
-  override def splits = splits_
+  override def getSplits = splits_
 
-  @transient
-  override val dependencies = List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))
+  override def compute(s: Split, context: TaskContext): Iterator[(T, U)] = {
+    val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits
+    rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context))
+  }
 
-  override def compute(s: Split, context: TaskContext): Iterator[(T, U)] =
-    s.asInstanceOf[ZippedSplit[T, U]].iterator(context)
+  override def getPreferredLocations(s: Split): Seq[String] = {
+    val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits
+    rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
+  }
 
-  override def preferredLocations(s: Split): Seq[String] =
-    s.asInstanceOf[ZippedSplit[T, U]].preferredLocations()
+  override def clearDependencies() {
+    splits_ = null
+    rdd1 = null
+    rdd2 = null
+  }
 }
diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala
index e492279b4ec6444dfc793f0a0a20309c60b8c399..7ec6564105632b077648f95dec3008d6cb041874 100644
--- a/core/src/main/scala/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/spark/scheduler/ResultTask.scala
@@ -1,17 +1,74 @@
 package spark.scheduler
 
 import spark._
+import java.io._
+import util.{MetadataCleaner, TimeStampedHashMap}
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+
+private[spark] object ResultTask {
+
+  // A simple map between the stage id to the serialized byte array of a task.
+  // Served as a cache for task serialization because serialization can be
+  // expensive on the master node if it needs to launch thousands of tasks.
+  val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
+
+  val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.cleanup)
+
+  def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = {
+    synchronized {
+      val old = serializedInfoCache.get(stageId).orNull
+      if (old != null) {
+        return old
+      } else {
+        val out = new ByteArrayOutputStream
+        val ser = SparkEnv.get.closureSerializer.newInstance
+        val objOut = ser.serializeStream(new GZIPOutputStream(out))
+        objOut.writeObject(rdd)
+        objOut.writeObject(func)
+        objOut.close()
+        val bytes = out.toByteArray
+        serializedInfoCache.put(stageId, bytes)
+        return bytes
+      }
+    }
+  }
+
+  def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = {
+    synchronized {
+      val loader = Thread.currentThread.getContextClassLoader
+      val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+      val ser = SparkEnv.get.closureSerializer.newInstance
+      val objIn = ser.deserializeStream(in)
+      val rdd = objIn.readObject().asInstanceOf[RDD[_]]
+      val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
+      return (rdd, func)
+    }
+  }
+
+  def clearCache() {
+    synchronized {
+      serializedInfoCache.clear()
+    }
+  }
+}
+
 
 private[spark] class ResultTask[T, U](
     stageId: Int,
-    rdd: RDD[T],
-    func: (TaskContext, Iterator[T]) => U,
-    val partition: Int,
+    var rdd: RDD[T],
+    var func: (TaskContext, Iterator[T]) => U,
+    var partition: Int,
     @transient locs: Seq[String],
     val outputId: Int)
-  extends Task[U](stageId) {
+  extends Task[U](stageId) with Externalizable {
 
-  val split = rdd.splits(partition)
+  def this() = this(0, null, null, 0, null, 0)
+
+  var split = if (rdd == null) {
+    null
+  } else {
+    rdd.splits(partition)
+  }
 
   override def run(attemptId: Long): U = {
     val context = new TaskContext(stageId, partition, attemptId)
@@ -23,4 +80,31 @@ private[spark] class ResultTask[T, U](
   override def preferredLocations: Seq[String] = locs
 
   override def toString = "ResultTask(" + stageId + ", " + partition + ")"
+
+  override def writeExternal(out: ObjectOutput) {
+    RDDCheckpointData.synchronized {
+      split = rdd.splits(partition)
+      out.writeInt(stageId)
+      val bytes = ResultTask.serializeInfo(
+        stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
+      out.writeInt(bytes.length)
+      out.write(bytes)
+      out.writeInt(partition)
+      out.writeInt(outputId)
+      out.writeObject(split)
+    }
+  }
+
+  override def readExternal(in: ObjectInput) {
+    val stageId = in.readInt()
+    val numBytes = in.readInt()
+    val bytes = new Array[Byte](numBytes)
+    in.readFully(bytes)
+    val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
+    rdd = rdd_.asInstanceOf[RDD[T]]
+    func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
+    partition = in.readInt()
+    val outputId = in.readInt()
+    split = in.readObject().asInstanceOf[Split]
+  }
 }
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 7fdc178d4b6c66dd0897273d8cd72337cdc76cfa..feb63abb618e4999f20fef6a182f63525cabbfd8 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -90,13 +90,16 @@ private[spark] class ShuffleMapTask(
   }
 
   override def writeExternal(out: ObjectOutput) {
-    out.writeInt(stageId)
-    val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
-    out.writeInt(bytes.length)
-    out.write(bytes)
-    out.writeInt(partition)
-    out.writeLong(generation)
-    out.writeObject(split)
+    RDDCheckpointData.synchronized {
+      split = rdd.splits(partition)
+      out.writeInt(stageId)
+      val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
+      out.writeInt(bytes.length)
+      out.write(bytes)
+      out.writeInt(partition)
+      out.writeLong(generation)
+      out.writeObject(split)
+    }
   }
 
   override def readExternal(in: ObjectInput) {
diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala
index 302a95db664fbbe8aa7ead5fa91f6ebf68e6a5f5..51573254cac7db9d3a1885f2ec51b84b437abe4b 100644
--- a/core/src/test/scala/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/spark/CheckpointSuite.scala
@@ -2,17 +2,16 @@ package spark
 
 import org.scalatest.{BeforeAndAfter, FunSuite}
 import java.io.File
-import rdd.{BlockRDD, CoalescedRDD, MapPartitionsWithSplitRDD}
+import spark.rdd._
 import spark.SparkContext._
 import storage.StorageLevel
-import java.util.concurrent.Semaphore
-import collection.mutable.ArrayBuffer
 
 class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
   initLogging()
 
   var sc: SparkContext = _
   var checkpointDir: File = _
+  val partitioner = new HashPartitioner(2)
 
   before {
     checkpointDir = File.createTempFile("temp", "")
@@ -35,14 +34,30 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
     }
   }
 
+  test("RDDs with one-to-one dependencies") {
+    testCheckpointing(_.map(x => x.toString))
+    testCheckpointing(_.flatMap(x => 1 to x))
+    testCheckpointing(_.filter(_ % 2 == 0))
+    testCheckpointing(_.sample(false, 0.5, 0))
+    testCheckpointing(_.glom())
+    testCheckpointing(_.mapPartitions(_.map(_.toString)))
+    testCheckpointing(r => new MapPartitionsWithSplitRDD(r,
+      (i: Int, iter: Iterator[Int]) => iter.map(_.toString), false ))
+    testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
+    testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
+    testCheckpointing(_.pipe(Seq("cat")))
+  }
+
   test("ParallelCollection") {
-    val parCollection = sc.makeRDD(1 to 4)
+    val parCollection = sc.makeRDD(1 to 4, 2)
+    val numSplits = parCollection.splits.size
     parCollection.checkpoint()
     assert(parCollection.dependencies === Nil)
     val result = parCollection.collect()
-    sleep(parCollection) // slightly extra time as loading classes for the first can take some time
-    assert(sc.objectFile[Int](parCollection.checkpointFile).collect() === result)
+    assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result)
     assert(parCollection.dependencies != Nil)
+    assert(parCollection.splits.length === numSplits)
+    assert(parCollection.splits.toList === parCollection.checkpointData.get.getSplits.toList)
     assert(parCollection.collect() === result)
   }
 
@@ -51,163 +66,292 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
     val blockManager = SparkEnv.get.blockManager
     blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY)
     val blockRDD = new BlockRDD[String](sc, Array(blockId))
+    val numSplits = blockRDD.splits.size
     blockRDD.checkpoint()
     val result = blockRDD.collect()
-    sleep(blockRDD)
-    assert(sc.objectFile[String](blockRDD.checkpointFile).collect() === result)
+    assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result)
     assert(blockRDD.dependencies != Nil)
+    assert(blockRDD.splits.length === numSplits)
+    assert(blockRDD.splits.toList === blockRDD.checkpointData.get.getSplits.toList)
     assert(blockRDD.collect() === result)
   }
 
-  test("RDDs with one-to-one dependencies") {
-    testCheckpointing(_.map(x => x.toString))
-    testCheckpointing(_.flatMap(x => 1 to x))
-    testCheckpointing(_.filter(_ % 2 == 0))
-    testCheckpointing(_.sample(false, 0.5, 0))
-    testCheckpointing(_.glom())
-    testCheckpointing(_.mapPartitions(_.map(_.toString)))
-    testCheckpointing(r => new MapPartitionsWithSplitRDD(r,
-      (i: Int, iter: Iterator[Int]) => iter.map(_.toString), false))
-    testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString), 1000)
-    testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x), 1000)
-    testCheckpointing(_.pipe(Seq("cat")))
-  }
-
   test("ShuffledRDD") {
-    testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _))
+    testCheckpointing(rdd => {
+      // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD
+      new ShuffledRDD(rdd.map(x => (x % 2, 1)), partitioner)
+    })
   }
 
   test("UnionRDD") {
-    testCheckpointing(_.union(sc.makeRDD(5 to 6, 4)))
+    def otherRDD = sc.makeRDD(1 to 10, 1)
+
+    // Test whether the size of UnionRDDSplits reduce in size after parent RDD is checkpointed.
+    // Current implementation of UnionRDD has transient reference to parent RDDs,
+    // so only the splits will reduce in serialized size, not the RDD.
+    testCheckpointing(_.union(otherRDD), false, true)
+    testParentCheckpointing(_.union(otherRDD), false, true)
   }
 
   test("CartesianRDD") {
-    testCheckpointing(_.cartesian(sc.makeRDD(5 to 6, 4)), 1000)
+    def otherRDD = sc.makeRDD(1 to 10, 1)
+    testCheckpointing(new CartesianRDD(sc, _, otherRDD))
+
+    // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed
+    // Current implementation of CoalescedRDDSplit has transient reference to parent RDD,
+    // so only the RDD will reduce in serialized size, not the splits.
+    testParentCheckpointing(new CartesianRDD(sc, _, otherRDD), true, false)
+
+    // Test that the CartesianRDD updates parent splits (CartesianRDD.s1/s2) after
+    // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits.
+    // Note that this test is very specific to the current implementation of CartesianRDD.
+    val ones = sc.makeRDD(1 to 100, 10).map(x => x)
+    ones.checkpoint // checkpoint that MappedRDD
+    val cartesian = new CartesianRDD(sc, ones, ones)
+    val splitBeforeCheckpoint =
+      serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit])
+    cartesian.count() // do the checkpointing
+    val splitAfterCheckpoint =
+      serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit])
+    assert(
+      (splitAfterCheckpoint.s1 != splitBeforeCheckpoint.s1) &&
+        (splitAfterCheckpoint.s2 != splitBeforeCheckpoint.s2),
+      "CartesianRDD.parents not updated after parent RDD checkpointed"
+    )
   }
 
   test("CoalescedRDD") {
     testCheckpointing(new CoalescedRDD(_, 2))
+
+    // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed
+    // Current implementation of CoalescedRDDSplit has transient reference to parent RDD,
+    // so only the RDD will reduce in serialized size, not the splits.
+    testParentCheckpointing(new CoalescedRDD(_, 2), true, false)
+
+    // Test that the CoalescedRDDSplit updates parent splits (CoalescedRDDSplit.parents) after
+    // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits.
+    // Note that this test is very specific to the current implementation of CoalescedRDDSplits
+    val ones = sc.makeRDD(1 to 100, 10).map(x => x)
+    ones.checkpoint // checkpoint that MappedRDD
+    val coalesced = new CoalescedRDD(ones, 2)
+    val splitBeforeCheckpoint =
+      serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit])
+    coalesced.count() // do the checkpointing
+    val splitAfterCheckpoint =
+      serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit])
+    assert(
+      splitAfterCheckpoint.parents.head != splitBeforeCheckpoint.parents.head,
+      "CoalescedRDDSplit.parents not updated after parent RDD checkpointed"
+    )
   }
 
   test("CoGroupedRDD") {
-    val rdd2 = sc.makeRDD(5 to 6, 4).map(x => (x % 2, 1))
-    testCheckpointing(rdd1 => rdd1.map(x => (x % 2, 1)).cogroup(rdd2))
-    testCheckpointing(rdd1 => rdd1.map(x => (x % 2, x)).join(rdd2))
+    val longLineageRDD1 = generateLongLineageRDDForCoGroupedRDD()
+    testCheckpointing(rdd => {
+      CheckpointSuite.cogroup(longLineageRDD1, rdd.map(x => (x % 2, 1)), partitioner)
+    }, false, true)
 
-    // Special test to make sure that the CoGroupSplit of CoGroupedRDD do not
-    // hold on to the splits of its parent RDDs, as the splits of parent RDDs
-    // may change while checkpointing. Rather the splits of parent RDDs must
-    // be fetched at the time of serialization to ensure the latest splits to
-    // be sent along with the task.
+    val longLineageRDD2 = generateLongLineageRDDForCoGroupedRDD()
+    testParentCheckpointing(rdd => {
+      CheckpointSuite.cogroup(
+        longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner)
+    }, false, true)
+  }
 
-    val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _)
+  test("ZippedRDD") {
+    testCheckpointing(
+      rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
+
+    // Test whether size of ZippedRDD reduce in size after parent RDD is checkpointed
+    // Current implementation of ZippedRDDSplit has transient references to parent RDDs,
+    // so only the RDD will reduce in serialized size, not the splits.
+    testParentCheckpointing(
+      rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
+
+  }
 
-    val ones = sc.parallelize(1 to 100, 1).map(x => (x,1))
-    val reduced = ones.reduceByKey(_ + _)
-    val seqOfCogrouped = new ArrayBuffer[RDD[(Int, Int)]]()
-    seqOfCogrouped += reduced.cogroup(ones).mapValues[Int](add)
-    for(i <- 1 to 10) {
-      seqOfCogrouped += seqOfCogrouped.last.cogroup(ones).mapValues(add)
-    }
-    val finalCogrouped = seqOfCogrouped.last
-    val intermediateCogrouped = seqOfCogrouped(5)
-
-    val bytesBeforeCheckpoint = Utils.serialize(finalCogrouped.splits)
-    intermediateCogrouped.checkpoint()
-    finalCogrouped.count()
-    sleep(intermediateCogrouped)
-    val bytesAfterCheckpoint = Utils.serialize(finalCogrouped.splits)
-    println("Before = " + bytesBeforeCheckpoint.size + ", after = " + bytesAfterCheckpoint.size)
-    assert(bytesAfterCheckpoint.size < bytesBeforeCheckpoint.size,
-      "CoGroupedSplits still holds on to the splits of its parent RDDs")
-  }
-  /*
   /**
-   * This test forces two ResultTasks of the same job to be launched before and after
-   * the checkpointing of job's RDD is completed.
+   * Test checkpointing of the final RDD generated by the given operation. By default,
+   * this method tests whether the size of serialized RDD has reduced after checkpointing or not.
+   * It can also test whether the size of serialized RDD splits has reduced after checkpointing or
+   * not, but this is not done by default as usually the splits do not refer to any RDD and
+   * therefore never store the lineage.
    */
-  test("Threading - ResultTasks") {
-    val op1 = (parCollection: RDD[Int]) => {
-      parCollection.map(x => { println("1st map running on " + x); Thread.sleep(500); (x % 2, x) })
+  def testCheckpointing[U: ClassManifest](
+      op: (RDD[Int]) => RDD[U],
+      testRDDSize: Boolean = true,
+      testRDDSplitSize: Boolean = false
+    ) {
+    // Generate the final RDD using given RDD operation
+    val baseRDD = generateLongLineageRDD
+    val operatedRDD = op(baseRDD)
+    val parentRDD = operatedRDD.dependencies.headOption.orNull
+    val rddType = operatedRDD.getClass.getSimpleName
+    val numSplits = operatedRDD.splits.length
+
+    // Find serialized sizes before and after the checkpoint
+    val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
+    operatedRDD.checkpoint()
+    val result = operatedRDD.collect()
+    val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+
+    // Test whether the checkpoint file has been created
+    assert(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get).collect() === result)
+
+    // Test whether dependencies have been changed from its earlier parent RDD
+    assert(operatedRDD.dependencies.head.rdd != parentRDD)
+
+    // Test whether the splits have been changed to the new Hadoop splits
+    assert(operatedRDD.splits.toList === operatedRDD.checkpointData.get.getSplits.toList)
+
+    // Test whether the number of splits is same as before
+    assert(operatedRDD.splits.length === numSplits)
+
+    // Test whether the data in the checkpointed RDD is same as original
+    assert(operatedRDD.collect() === result)
+
+    // Test whether serialized size of the RDD has reduced. If the RDD
+    // does not have any dependency to another RDD (e.g., ParallelCollection,
+    // ShuffleRDD with ShuffleDependency), it may not reduce in size after checkpointing.
+    if (testRDDSize) {
+      logInfo("Size of " + rddType +
+        "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]")
+      assert(
+        rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
+        "Size of " + rddType + " did not reduce after checkpointing " +
+          "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
+      )
     }
-    val op2 = (firstRDD: RDD[(Int, Int)]) => {
-      firstRDD.map(x => { println("2nd map running on " + x); Thread.sleep(500); x })
+
+    // Test whether serialized size of the splits has reduced. If the splits
+    // do not have any non-transient reference to another RDD or another RDD's splits, it
+    // does not refer to a lineage and therefore may not reduce in size after checkpointing.
+    // However, if the original splits before checkpointing do refer to a parent RDD, the splits
+    // must be forgotten after checkpointing (to remove all reference to parent RDDs) and
+    // replaced with the HadoopSplits of the checkpointed RDD.
+    if (testRDDSplitSize) {
+      logInfo("Size of " + rddType + " splits "
+        + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]")
+      assert(
+        splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint,
+        "Size of " + rddType + " splits did not reduce after checkpointing " +
+          "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]"
+      )
     }
-    testThreading(op1, op2)
   }
 
   /**
-   * This test forces two ShuffleMapTasks of the same job to be launched before and after
-   * the checkpointing of job's RDD is completed.
+   * Test whether checkpointing of the parent of the generated RDD also
+   * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent
+   * RDDs splits. So even if the parent RDD is checkpointed and its splits changed,
+   * this RDD will remember the splits and therefore potentially the whole lineage.
    */
-  test("Threading - ShuffleMapTasks") {
-    val op1 = (parCollection: RDD[Int]) => {
-      parCollection.map(x => { println("1st map running on " + x); Thread.sleep(500); (x % 2, x) })
+  def testParentCheckpointing[U: ClassManifest](
+      op: (RDD[Int]) => RDD[U],
+      testRDDSize: Boolean,
+      testRDDSplitSize: Boolean
+    ) {
+    // Generate the final RDD using given RDD operation
+    val baseRDD = generateLongLineageRDD
+    val operatedRDD = op(baseRDD)
+    val parentRDD = operatedRDD.dependencies.head.rdd
+    val rddType = operatedRDD.getClass.getSimpleName
+    val parentRDDType = parentRDD.getClass.getSimpleName
+
+    // Find serialized sizes before and after the checkpoint
+    val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
+    parentRDD.checkpoint()  // checkpoint the parent RDD, not the generated one
+    val result = operatedRDD.collect()
+    val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+
+    // Test whether the data in the checkpointed RDD is same as original
+    assert(operatedRDD.collect() === result)
+
+    // Test whether serialized size of the RDD has reduced because of its parent being
+    // checkpointed. If this RDD or its parent RDD do not have any dependency
+    // to another RDD (e.g., ParallelCollection, ShuffleRDD with ShuffleDependency), it may
+    // not reduce in size after checkpointing.
+    if (testRDDSize) {
+      assert(
+        rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
+        "Size of " + rddType + " did not reduce after parent checkpointing parent " + parentRDDType +
+          "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
+      )
     }
-    val op2 = (firstRDD: RDD[(Int, Int)]) => {
-      firstRDD.groupByKey(2).map(x => { println("2nd map running on " + x); Thread.sleep(500); x })
+
+    // Test whether serialized size of the splits has reduced because of its parent being
+    // checkpointed. If the splits do not have any non-transient reference to another RDD
+    // or another RDD's splits, it does not refer to a lineage and therefore may not reduce
+    // in size after checkpointing. However, if the splits do refer to the *splits* of a parent
+    // RDD, then these splits must update reference to the parent RDD splits as the parent RDD's
+    // splits must have changed after checkpointing.
+    if (testRDDSplitSize) {
+      assert(
+        splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint,
+        "Size of " + rddType + " splits did not reduce after checkpointing parent " + parentRDDType +
+          "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]"
+      )
     }
-    testThreading(op1, op2)
+
   }
-  */
 
-  def testCheckpointing[U: ClassManifest](op: (RDD[Int]) => RDD[U], sleepTime: Long = 500) {
-    val parCollection = sc.makeRDD(1 to 4, 4)
-    val operatedRDD = op(parCollection)
-    operatedRDD.checkpoint()
-    val parentRDD = operatedRDD.dependencies.head.rdd
-    val result = operatedRDD.collect()
-    sleep(operatedRDD)
-    //println(parentRDD + ", " + operatedRDD.dependencies.head.rdd )
-    assert(sc.objectFile[U](operatedRDD.checkpointFile).collect() === result)
-    assert(operatedRDD.dependencies.head.rdd != parentRDD)
-    assert(operatedRDD.collect() === result)
+  /**
+   * Generate an RDD with a long lineage of one-to-one dependencies.
+   */
+  def generateLongLineageRDD(): RDD[Int] = {
+    var rdd = sc.makeRDD(1 to 100, 4)
+    for (i <- 1 to 50) {
+      rdd = rdd.map(x => x + 1)
+    }
+    rdd
   }
-  /*
-  def testThreading[U: ClassManifest, V: ClassManifest](op1: (RDD[Int]) => RDD[U], op2: (RDD[U]) => RDD[V]) {
-
-    val parCollection = sc.makeRDD(1 to 2, 2)
-
-    // This is the RDD that is to be checkpointed
-    val firstRDD = op1(parCollection)
-    val parentRDD = firstRDD.dependencies.head.rdd
-    firstRDD.checkpoint()
-
-    // This the RDD that uses firstRDD. This is designed to launch a
-    // ShuffleMapTask that uses firstRDD.
-    val secondRDD = op2(firstRDD)
-
-    // Starting first job, to initiate the checkpointing
-    logInfo("\nLaunching 1st job to initiate checkpointing\n")
-    firstRDD.collect()
-
-    // Checkpointing has started but not completed yet
-    Thread.sleep(100)
-    assert(firstRDD.dependencies.head.rdd === parentRDD)
-
-    // Starting second job; first task of this job will be
-    // launched _before_ firstRDD is marked as checkpointed
-    // and the second task will be launched _after_ firstRDD
-    // is marked as checkpointed
-    logInfo("\nLaunching 2nd job that is designed to launch tasks " +
-      "before and after checkpointing is complete\n")
-    val result = secondRDD.collect()
-
-    // Check whether firstRDD has been successfully checkpointed
-    assert(firstRDD.dependencies.head.rdd != parentRDD)
-
-    logInfo("\nRecomputing 2nd job to verify the results of the previous computation\n")
-    // Check whether the result in the previous job was correct or not
-    val correctResult = secondRDD.collect()
-    assert(result === correctResult)
-  }
-  */
-  def sleep(rdd: RDD[_]) {
-    val startTime = System.currentTimeMillis()
-    val maxWaitTime = 5000
-    while(rdd.isCheckpointed == false && System.currentTimeMillis() < startTime + maxWaitTime) {
-      Thread.sleep(50)
+
+  /**
+   * Generate an RDD with a long lineage specifically for CoGroupedRDD.
+   * A CoGroupedRDD can have a long lineage only one of its parents have a long lineage
+   * and narrow dependency with this RDD. This method generate such an RDD by a sequence
+   * of cogroups and mapValues which creates a long lineage of narrow dependencies.
+   */
+  def generateLongLineageRDDForCoGroupedRDD() = {
+    val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _)
+
+    def ones: RDD[(Int, Int)] = sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _)
+
+    var cogrouped: RDD[(Int, (Seq[Int], Seq[Int]))] = ones.cogroup(ones)
+    for(i <- 1 to 10) {
+      cogrouped = cogrouped.mapValues(add).cogroup(ones)
     }
-    assert(rdd.isCheckpointed === true, "Waiting for checkpoint to complete took more than " + maxWaitTime + " ms")
+    cogrouped.mapValues(add)
+  }
+
+  /**
+   * Get serialized sizes of the RDD and its splits
+   */
+  def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
+    (Utils.serialize(rdd).size, Utils.serialize(rdd.splits).size)
+  }
+
+  /**
+   * Serialize and deserialize an object. This is useful to verify the objects
+   * contents after deserialization (e.g., the contents of an RDD split after
+   * it is sent to a slave along with a task)
+   */
+  def serializeDeserialize[T](obj: T): T = {
+    val bytes = Utils.serialize(obj)
+    Utils.deserialize[T](bytes)
   }
 }
+
+
+object CheckpointSuite {
+  // This is a custom cogroup function that does not use mapValues like
+  // the PairRDDFunctions.cogroup()
+  def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = {
+    //println("First = " + first + ", second = " + second)
+    new CoGroupedRDD[K](
+      Seq(first.asInstanceOf[RDD[(_, _)]], second.asInstanceOf[RDD[(_, _)]]),
+      part
+    ).asInstanceOf[RDD[(K, Seq[Seq[V]])]]
+  }
+
+}
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala
index 792c129be8d3faec1bc9c1d506ef887c87790a6e..d5048aeed7b06e5589fb8c91028ba90d1f4a1769 100644
--- a/streaming/src/main/scala/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/spark/streaming/DStream.scala
@@ -334,20 +334,22 @@ extends Serializable with Logging {
    * this method to save custom checkpoint data.
    */
   protected[streaming] def updateCheckpointData(currentTime: Time) {
-
     logInfo("Updating checkpoint data for time " + currentTime)
 
     // Get the checkpointed RDDs from the generated RDDs
-    val newRdds = generatedRDDs.filter(_._2.getCheckpointData() != null)
-                                         .map(x => (x._1, x._2.getCheckpointData()))
-    // Make a copy of the existing checkpoint data
+    val newRdds = generatedRDDs.filter(_._2.getCheckpointFile.isDefined)
+                               .map(x => (x._1, x._2.getCheckpointFile.get))
+
+    // Make a copy of the existing checkpoint data (checkpointed RDDs)
     val oldRdds = checkpointData.rdds.clone()
-    // If the new checkpoint has checkpoints then replace existing with the new one
+
+    // If the new checkpoint data has checkpoints then replace existing with the new one
     if (newRdds.size > 0) {
       checkpointData.rdds.clear()
       checkpointData.rdds ++= newRdds
     }
-    // Make dependencies update their checkpoint data
+
+    // Make parent DStreams update their checkpoint data
     dependencies.foreach(_.updateCheckpointData(currentTime))
 
     // TODO: remove this, this is just for debugging
@@ -381,9 +383,7 @@ extends Serializable with Logging {
     checkpointData.rdds.foreach {
       case(time, data) => {
         logInfo("Restoring checkpointed RDD for time " + time + " from file '" + data.toString + "'")
-        val rdd = ssc.sc.objectFile[T](data.toString)
-        // Set the checkpoint file name to identify this RDD as a checkpointed RDD by updateCheckpointData()
-        rdd.checkpointFile = data.toString
+        val rdd = ssc.sc.checkpointFile[T](data.toString)
         generatedRDDs += ((time, rdd))
       }
     }