diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index ff2ed4cdfc4195aef3de1668a0f0cb76c125450c..7613b338e64c07bc9a0b7b8ce7798a9ecac27ce1 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -5,45 +5,41 @@ import rdd.CoalescedRDD import scheduler.{ResultTask, ShuffleMapTask} /** - * This class contains all the information of the regarding RDD checkpointing. + * Enumeration to manage state transitions of an RDD through checkpointing + * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ] */ +private[spark] object CheckpointState extends Enumeration { + type CheckpointState = Value + val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value +} +/** + * This class contains all the information regarding RDD checkpointing. + */ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) extends Logging with Serializable { - /** - * This class manages the state transition of an RDD through checkpointing - * [ Not checkpointed --> marked for checkpointing --> checkpointing in progress --> checkpointed ] - */ - class CheckpointState extends Serializable { - var state = 0 + import CheckpointState._ - def mark() { if (state == 0) state = 1 } - def start() { assert(state == 1); state = 2 } - def finish() { assert(state == 2); state = 3 } - - def isMarked() = { state == 1 } - def isInProgress = { state == 2 } - def isCheckpointed = { state == 3 } - } - - val cpState = new CheckpointState() + var cpState = Initialized @transient var cpFile: Option[String] = None @transient var cpRDD: Option[RDD[T]] = None @transient var cpRDDSplits: Seq[Split] = Nil // Mark the RDD for checkpointing - def markForCheckpoint() = { - RDDCheckpointData.synchronized { cpState.mark() } + def markForCheckpoint() { + RDDCheckpointData.synchronized { + if (cpState == Initialized) cpState = 
MarkedForCheckpoint + } } // Is the RDD already checkpointed - def isCheckpointed() = { - RDDCheckpointData.synchronized { cpState.isCheckpointed } + def isCheckpointed(): Boolean = { + RDDCheckpointData.synchronized { cpState == Checkpointed } } - // Get the file to which this RDD was checkpointed to as a Option - def getCheckpointFile() = { + // Get the file to which this RDD was checkpointed to as an Option + def getCheckpointFile(): Option[String] = { RDDCheckpointData.synchronized { cpFile } } @@ -52,8 +48,8 @@ extends Logging with Serializable { // If it is marked for checkpointing AND checkpointing is not already in progress, // then set it to be in progress, else return RDDCheckpointData.synchronized { - if (cpState.isMarked && !cpState.isInProgress) { - cpState.start() + if (cpState == MarkedForCheckpoint) { + cpState = CheckpointingInProgress } else { return } @@ -87,7 +83,7 @@ extends Logging with Serializable { cpRDD = Some(newRDD) cpRDDSplits = newRDD.splits rdd.changeDependencies(newRDD) - cpState.finish() + cpState = Checkpointed RDDCheckpointData.checkpointCompleted() logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) }