From 72eed2b95edb3b0b213517c815e09c3886b11669 Mon Sep 17 00:00:00 2001 From: Tathagata Das <tathagata.das1565@gmail.com> Date: Mon, 17 Dec 2012 18:52:43 -0800 Subject: [PATCH] Converted CheckpointState in RDDCheckpointData to use scala Enumeration. --- .../main/scala/spark/RDDCheckpointData.scala | 48 +++++++++---------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index ff2ed4cdfc..7613b338e6 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -5,45 +5,41 @@ import rdd.CoalescedRDD import scheduler.{ResultTask, ShuffleMapTask} /** - * This class contains all the information of the regarding RDD checkpointing. + * Enumeration to manage state transitions of an RDD through checkpointing + * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ] */ +private[spark] object CheckpointState extends Enumeration { + type CheckpointState = Value + val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value +} +/** + * This class contains all the information of the regarding RDD checkpointing. + */ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) extends Logging with Serializable { - /** - * This class manages the state transition of an RDD through checkpointing - * [ Not checkpointed --> marked for checkpointing --> checkpointing in progress --> checkpointed ] - */ - class CheckpointState extends Serializable { - var state = 0 + import CheckpointState._ - def mark() { if (state == 0) state = 1 } - def start() { assert(state == 1); state = 2 } - def finish() { assert(state == 2); state = 3 } - - def isMarked() = { state == 1 } - def isInProgress = { state == 2 } - def isCheckpointed = { state == 3 } - } - - val cpState = new CheckpointState() + var cpState = Initialized @transient var cpFile: Option[String] = None @transient var cpRDD: Option[RDD[T]] = None @transient var cpRDDSplits: Seq[Split] = Nil // Mark the RDD for checkpointing - def markForCheckpoint() = { - RDDCheckpointData.synchronized { cpState.mark() } + def markForCheckpoint() { + RDDCheckpointData.synchronized { + if (cpState == Initialized) cpState = MarkedForCheckpoint + } } // Is the RDD already checkpointed - def isCheckpointed() = { - RDDCheckpointData.synchronized { cpState.isCheckpointed } + def isCheckpointed(): Boolean = { + RDDCheckpointData.synchronized { cpState == Checkpointed } } - // Get the file to which this RDD was checkpointed to as a Option - def getCheckpointFile() = { + // Get the file to which this RDD was checkpointed to as an Option + def getCheckpointFile(): Option[String] = { RDDCheckpointData.synchronized { cpFile } } @@ -52,8 +48,8 @@ extends Logging with Serializable { // If it is marked for checkpointing AND checkpointing is not already in progress, // then set it to be in progress, else return RDDCheckpointData.synchronized { - if (cpState.isMarked && !cpState.isInProgress) { - cpState.start() + if (cpState == MarkedForCheckpoint) { + cpState = CheckpointingInProgress } else { return } @@ -87,7 +83,7 @@ extends Logging with Serializable { cpRDD = Some(newRDD) cpRDDSplits = newRDD.splits rdd.changeDependencies(newRDD) - cpState.finish() + cpState = Checkpointed RDDCheckpointData.checkpointCompleted() logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) } -- GitLab