diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala
index ff2ed4cdfc4195aef3de1668a0f0cb76c125450c..7613b338e64c07bc9a0b7b8ce7798a9ecac27ce1 100644
--- a/core/src/main/scala/spark/RDDCheckpointData.scala
+++ b/core/src/main/scala/spark/RDDCheckpointData.scala
@@ -5,45 +5,41 @@ import rdd.CoalescedRDD
 import scheduler.{ResultTask, ShuffleMapTask}
 
 /**
- * This class contains all the information of the regarding RDD checkpointing.
+ * Enumeration to manage state transitions of an RDD through checkpointing
+ * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ]
  */
+private[spark] object CheckpointState extends Enumeration {
+  type CheckpointState = Value
+  val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value
+}
 
+/**
+ * This class contains all the information of the regarding RDD checkpointing.
+ */
 private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
 extends Logging with Serializable {
 
-  /**
-   * This class manages the state transition of an RDD through checkpointing
-   * [ Not checkpointed --> marked for checkpointing --> checkpointing in progress --> checkpointed ]
-   */
-  class CheckpointState extends Serializable {
-    var state = 0
+  import CheckpointState._
 
-    def mark()    { if (state == 0) state = 1 }
-    def start()   { assert(state == 1); state = 2 }
-    def finish()  { assert(state == 2); state = 3 }
-
-    def isMarked() =     { state == 1 }
-    def isInProgress =   { state == 2 }
-    def isCheckpointed = { state == 3 }
-  }
-
-  val cpState = new CheckpointState()
+  var cpState = Initialized
   @transient var cpFile: Option[String] = None
   @transient var cpRDD: Option[RDD[T]] = None
   @transient var cpRDDSplits: Seq[Split] = Nil
 
   // Mark the RDD for checkpointing
-  def markForCheckpoint() = {
-    RDDCheckpointData.synchronized { cpState.mark() }
+  def markForCheckpoint() {
+    RDDCheckpointData.synchronized {
+      if (cpState == Initialized) cpState = MarkedForCheckpoint
+    }
   }
 
   // Is the RDD already checkpointed
-  def isCheckpointed() = {
-    RDDCheckpointData.synchronized { cpState.isCheckpointed }
+  def isCheckpointed(): Boolean = {
+    RDDCheckpointData.synchronized { cpState == Checkpointed }
   }
 
-  // Get the file to which this RDD was checkpointed to as a Option
-  def getCheckpointFile() = {
+  // Get the file to which this RDD was checkpointed to as an Option
+  def getCheckpointFile(): Option[String] = {
     RDDCheckpointData.synchronized { cpFile }
   }
 
@@ -52,8 +48,8 @@ extends Logging with Serializable {
     // If it is marked for checkpointing AND checkpointing is not already in progress,
     // then set it to be in progress, else return
     RDDCheckpointData.synchronized {
-      if (cpState.isMarked && !cpState.isInProgress) {
-        cpState.start()
+      if (cpState == MarkedForCheckpoint) {
+        cpState = CheckpointingInProgress
       } else {
         return
       }
@@ -87,7 +83,7 @@ extends Logging with Serializable {
       cpRDD = Some(newRDD)
       cpRDDSplits = newRDD.splits
       rdd.changeDependencies(newRDD)
-      cpState.finish()
+      cpState = Checkpointed
       RDDCheckpointData.checkpointCompleted()
       logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id)
     }