From 9d8666cac84fc4fc867f6a5e80097dbe5cb65301 Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@apache.org>
Date: Fri, 25 Jul 2014 18:45:02 -0700
Subject: [PATCH] Part of [SPARK-2456] Removed some HashMaps from DAGScheduler
 by storing information in Stage.

This is part of the scheduler cleanup/refactoring effort to make the scheduler code easier to maintain.

@kayousterhout @markhamstra please take a look ...

Author: Reynold Xin <rxin@apache.org>

Closes #1561 from rxin/dagSchedulerHashMaps and squashes the following commits:

1c44e15 [Reynold Xin] Clear pending tasks in submitMissingTasks.
620a0d1 [Reynold Xin] Use filterKeys.
5b54404 [Reynold Xin] Code review feedback.
c1e9a1c [Reynold Xin] Removed some HashMaps from DAGScheduler by storing information in Stage.
---
 .../apache/spark/scheduler/DAGScheduler.scala | 143 +++++++-----------
 .../org/apache/spark/scheduler/Stage.scala    |  19 ++-
 .../spark/scheduler/DAGSchedulerSuite.scala   |   4 -
 3 files changed, 69 insertions(+), 97 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 00b8af27a7..dc6142ab79 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -85,12 +85,9 @@ class DAGScheduler(
   private val nextStageId = new AtomicInteger(0)
 
   private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]]
-  private[scheduler] val stageIdToJobIds = new HashMap[Int, HashSet[Int]]
   private[scheduler] val stageIdToStage = new HashMap[Int, Stage]
   private[scheduler] val shuffleToMapStage = new HashMap[Int, Stage]
   private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob]
-  private[scheduler] val resultStageToJob = new HashMap[Stage, ActiveJob]
-  private[scheduler] val stageToInfos = new HashMap[Stage, StageInfo]
 
   // Stages we need to run whose parents aren't done
   private[scheduler] val waitingStages = new HashSet[Stage]
@@ -101,9 +98,6 @@ class DAGScheduler(
   // Stages that must be resubmitted due to fetch failures
   private[scheduler] val failedStages = new HashSet[Stage]
 
-  // Missing tasks from each stage
-  private[scheduler] val pendingTasks = new HashMap[Stage, HashSet[Task[_]]]
-
   private[scheduler] val activeJobs = new HashSet[ActiveJob]
 
   // Contains the locations that each RDD's partitions are cached on
@@ -223,7 +217,6 @@ class DAGScheduler(
       new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
     stageIdToStage(id) = stage
     updateJobIdStageIdMaps(jobId, stage)
-    stageToInfos(stage) = StageInfo.fromStage(stage)
     stage
   }
 
@@ -315,13 +308,12 @@ class DAGScheduler(
    */
  private def updateJobIdStageIdMaps(jobId: Int, stage: Stage) {
    def updateJobIdStageIdMapsList(stages: List[Stage]) {
-      if (!stages.isEmpty) {
+      if (stages.nonEmpty) {
        val s = stages.head
-        stageIdToJobIds.getOrElseUpdate(s.id, new HashSet[Int]()) += jobId
+        s.jobIds += jobId
        jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id
-        val parents = getParentStages(s.rdd, jobId)
-        val parentsWithoutThisJobId = parents.filter(p =>
-          !stageIdToJobIds.get(p.id).exists(_.contains(jobId)))
+        val parents: List[Stage] = getParentStages(s.rdd, jobId)
+        val parentsWithoutThisJobId = parents.filter { ! _.jobIds.contains(jobId) }
        updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail)
      }
    }
@@ -333,16 +325,15 @@ class DAGScheduler(
    * handle cancelling tasks or notifying the SparkListener about finished jobs/stages/tasks.
    *
    * @param job The job whose state to cleanup.
-   * @param resultStage Specifies the result stage for the job; if set to None, this method
-   *                    searches resultStagesToJob to find and cleanup the appropriate result stage.
    */
-  private def cleanupStateForJobAndIndependentStages(job: ActiveJob, resultStage: Option[Stage]) {
+  private def cleanupStateForJobAndIndependentStages(job: ActiveJob) {
     val registeredStages = jobIdToStageIds.get(job.jobId)
     if (registeredStages.isEmpty || registeredStages.get.isEmpty) {
       logError("No stages registered for job " + job.jobId)
     } else {
-      stageIdToJobIds.filterKeys(stageId => registeredStages.get.contains(stageId)).foreach {
-        case (stageId, jobSet) =>
+      stageIdToStage.filterKeys(stageId => registeredStages.get.contains(stageId)).foreach {
+        case (stageId, stage) =>
+          val jobSet = stage.jobIds
           if (!jobSet.contains(job.jobId)) {
             logError(
               "Job %d not registered for stage %d even though that stage was registered for the job"
@@ -355,14 +346,9 @@ class DAGScheduler(
             logDebug("Removing running stage %d".format(stageId))
             runningStages -= stage
           }
-          stageToInfos -= stage
           for ((k, v) <- shuffleToMapStage.find(_._2 == stage)) {
             shuffleToMapStage.remove(k)
           }
-          if (pendingTasks.contains(stage) && !pendingTasks(stage).isEmpty) {
-            logDebug("Removing pending status for stage %d".format(stageId))
-          }
-          pendingTasks -= stage
           if (waitingStages.contains(stage)) {
             logDebug("Removing stage %d from waiting set.".format(stageId))
             waitingStages -= stage
@@ -374,7 +360,6 @@ class DAGScheduler(
           }
           // data structures based on StageId
           stageIdToStage -= stageId
-          stageIdToJobIds -= stageId
 
           ShuffleMapTask.removeStage(stageId)
           ResultTask.removeStage(stageId)
@@ -393,19 +378,7 @@ class DAGScheduler(
     jobIdToStageIds -= job.jobId
     jobIdToActiveJob -= job.jobId
     activeJobs -= job
-
-    if (resultStage.isEmpty) {
-      // Clean up result stages.
-      val resultStagesForJob = resultStageToJob.keySet.filter(
-        stage => resultStageToJob(stage).jobId == job.jobId)
-      if (resultStagesForJob.size != 1) {
-        logWarning(
-          s"${resultStagesForJob.size} result stages for job ${job.jobId} (expect exactly 1)")
-      }
-      resultStageToJob --= resultStagesForJob
-    } else {
-      resultStageToJob -= resultStage.get
-    }
+    job.finalStage.resultOfJob = None
   }
 
   /**
@@ -591,9 +564,10 @@ class DAGScheduler(
         job.listener.jobFailed(exception)
       } finally {
         val s = job.finalStage
-        stageIdToJobIds -= s.id    // clean up data structures that were populated for a local job,
-        stageIdToStage -= s.id     // but that won't get cleaned up via the normal paths through
-        stageToInfos -= s          // completion events or stage abort
+        // clean up data structures that were populated for a local job,
+        // but that won't get cleaned up via the normal paths through
+        // completion events or stage abort
+        stageIdToStage -= s.id
         jobIdToStageIds -= job.jobId
         listenerBus.post(SparkListenerJobEnd(job.jobId, jobResult))
       }
@@ -605,12 +579,8 @@ class DAGScheduler(
   // That should take care of at least part of the priority inversion problem with
   // cross-job dependencies.
  private def activeJobForStage(stage: Stage): Option[Int] = {
-    if (stageIdToJobIds.contains(stage.id)) {
-      val jobsThatUseStage: Array[Int] = stageIdToJobIds(stage.id).toArray.sorted
-      jobsThatUseStage.find(jobIdToActiveJob.contains)
-    } else {
-      None
-    }
+    val jobsThatUseStage: Array[Int] = stage.jobIds.toArray.sorted
+    jobsThatUseStage.find(jobIdToActiveJob.contains)
   }
 
   private[scheduler] def handleJobGroupCancelled(groupId: String) {
@@ -642,9 +612,8 @@ class DAGScheduler(
       // is in the process of getting stopped.
       val stageFailedMessage = "Stage cancelled because SparkContext was shut down"
       runningStages.foreach { stage =>
-        val info = stageToInfos(stage)
-        info.stageFailed(stageFailedMessage)
-        listenerBus.post(SparkListenerStageCompleted(info))
+        stage.info.stageFailed(stageFailedMessage)
+        listenerBus.post(SparkListenerStageCompleted(stage.info))
       }
       listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
     }
@@ -690,7 +659,7 @@ class DAGScheduler(
     } else {
       jobIdToActiveJob(jobId) = job
       activeJobs += job
-      resultStageToJob(finalStage) = job
+      finalStage.resultOfJob = Some(job)
       listenerBus.post(
         SparkListenerJobStart(job.jobId, jobIdToStageIds(jobId).toArray, properties))
       submitStage(finalStage)
@@ -727,8 +696,7 @@ class DAGScheduler(
   private def submitMissingTasks(stage: Stage, jobId: Int) {
     logDebug("submitMissingTasks(" + stage + ")")
     // Get our pending tasks and remember them in our pendingTasks entry
-    val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
-    myPending.clear()
+    stage.pendingTasks.clear()
     var tasks = ArrayBuffer[Task[_]]()
     if (stage.isShuffleMap) {
       for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
@@ -737,7 +705,7 @@ class DAGScheduler(
       }
     } else {
       // This is a final stage; figure out its job's missing partitions
-      val job = resultStageToJob(stage)
+      val job = stage.resultOfJob.get
       for (id <- 0 until job.numPartitions if !job.finished(id)) {
         val partition = job.partitions(id)
         val locs = getPreferredLocs(stage.rdd, partition)
@@ -758,7 +726,7 @@ class DAGScheduler(
     // serializable. If tasks are not serializable, a SparkListenerStageCompleted event
     // will be posted, which should always come after a corresponding SparkListenerStageSubmitted
     // event.
-    listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties))
+    listenerBus.post(SparkListenerStageSubmitted(stage.info, properties))
 
     // Preemptively serialize a task to make sure it can be serialized. We are catching this
     // exception here because it would be fairly hard to catch the non-serializable exception
@@ -778,11 +746,11 @@ class DAGScheduler(
       }
 
       logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
-      myPending ++= tasks
-      logDebug("New pending tasks: " + myPending)
+      stage.pendingTasks ++= tasks
+      logDebug("New pending tasks: " + stage.pendingTasks)
       taskScheduler.submitTasks(
         new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
-      stageToInfos(stage).submissionTime = Some(clock.getTime())
+      stage.info.submissionTime = Some(clock.getTime())
     } else {
       logDebug("Stage " + stage + " is actually done; %b %d %d".format(
         stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
@@ -807,13 +775,13 @@ class DAGScheduler(
     val stage = stageIdToStage(task.stageId)
 
     def markStageAsFinished(stage: Stage) = {
-      val serviceTime = stageToInfos(stage).submissionTime match {
+      val serviceTime = stage.info.submissionTime match {
         case Some(t) => "%.03f".format((clock.getTime() - t) / 1000.0)
         case _ => "Unknown"
       }
       logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
-      stageToInfos(stage).completionTime = Some(clock.getTime())
-      listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage)))
+      stage.info.completionTime = Some(clock.getTime())
+      listenerBus.post(SparkListenerStageCompleted(stage.info))
       runningStages -= stage
     }
     event.reason match {
@@ -822,10 +790,10 @@ class DAGScheduler(
           // TODO: fail the stage if the accumulator update fails...
           Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
         }
-        pendingTasks(stage) -= task
+        stage.pendingTasks -= task
         task match {
           case rt: ResultTask[_, _] =>
-            resultStageToJob.get(stage) match {
+            stage.resultOfJob match {
               case Some(job) =>
                 if (!job.finished(rt.outputId)) {
                   job.finished(rt.outputId) = true
@@ -833,7 +801,7 @@ class DAGScheduler(
                   // If the whole job has finished, remove it
                   if (job.numFinished == job.numPartitions) {
                     markStageAsFinished(stage)
-                    cleanupStateForJobAndIndependentStages(job, Some(stage))
+                    cleanupStateForJobAndIndependentStages(job)
                     listenerBus.post(SparkListenerJobEnd(job.jobId, JobSucceeded))
                   }
 
@@ -860,7 +828,7 @@ class DAGScheduler(
               } else {
                 stage.addOutputLoc(smt.partitionId, status)
               }
-              if (runningStages.contains(stage) && pendingTasks(stage).isEmpty) {
+              if (runningStages.contains(stage) && stage.pendingTasks.isEmpty) {
                 markStageAsFinished(stage)
                 logInfo("looking for newly runnable stages")
                 logInfo("running: " + runningStages)
@@ -909,7 +877,7 @@ class DAGScheduler(
 
       case Resubmitted =>
         logInfo("Resubmitted " + task + ", so marking it as still running")
-        pendingTasks(stage) += task
+        stage.pendingTasks += task
 
       case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
         // Mark the stage that the reducer was in as unrunnable
@@ -994,13 +962,14 @@ class DAGScheduler(
   }
 
   private[scheduler] def handleStageCancellation(stageId: Int) {
-    if (stageIdToJobIds.contains(stageId)) {
-      val jobsThatUseStage: Array[Int] = stageIdToJobIds(stageId).toArray
-      jobsThatUseStage.foreach(jobId => {
-        handleJobCancellation(jobId, "because Stage %s was cancelled".format(stageId))
-      })
-    } else {
-      logInfo("No active jobs to kill for Stage " + stageId)
+    stageIdToStage.get(stageId) match {
+      case Some(stage) =>
+        val jobsThatUseStage: Array[Int] = stage.jobIds.toArray
+        jobsThatUseStage.foreach { jobId =>
+          handleJobCancellation(jobId, s"because Stage $stageId was cancelled")
+        }
+      case None =>
+        logInfo("No active jobs to kill for Stage " + stageId)
     }
     submitWaitingStages()
   }
@@ -1009,8 +978,8 @@ class DAGScheduler(
     if (!jobIdToStageIds.contains(jobId)) {
       logDebug("Trying to cancel unregistered job " + jobId)
     } else {
-      failJobAndIndependentStages(jobIdToActiveJob(jobId),
-        "Job %d cancelled %s".format(jobId, reason), None)
+      failJobAndIndependentStages(
+        jobIdToActiveJob(jobId), "Job %d cancelled %s".format(jobId, reason))
     }
     submitWaitingStages()
   }
@@ -1024,26 +993,21 @@ class DAGScheduler(
       // Skip all the actions if the stage has been removed.
       return
     }
-    val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
-    stageToInfos(failedStage).completionTime = Some(clock.getTime())
-    for (resultStage <- dependentStages) {
-      val job = resultStageToJob(resultStage)
-      failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason",
-        Some(resultStage))
+    val dependentJobs: Seq[ActiveJob] =
+      activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq
+    failedStage.info.completionTime = Some(clock.getTime())
+    for (job <- dependentJobs) {
+      failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason")
     }
-    if (dependentStages.isEmpty) {
+    if (dependentJobs.isEmpty) {
       logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
     }
   }
 
   /**
    * Fails a job and all stages that are only used by that job, and cleans up relevant state.
-   *
-   * @param resultStage The result stage for the job, if known. Used to cleanup state for the job
-   *                    slightly more efficiently than when not specified.
    */
-  private def failJobAndIndependentStages(job: ActiveJob, failureReason: String,
-      resultStage: Option[Stage]) {
+  private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) {
     val error = new SparkException(failureReason)
     var ableToCancelStages = true
 
@@ -1057,7 +1021,7 @@ class DAGScheduler(
         logError("No stages registered for job " + job.jobId)
       }
       stages.foreach { stageId =>
-        val jobsForStage = stageIdToJobIds.get(stageId)
+        val jobsForStage: Option[HashSet[Int]] = stageIdToStage.get(stageId).map(_.jobIds)
         if (jobsForStage.isEmpty || !jobsForStage.get.contains(job.jobId)) {
           logError(
             "Job %d not registered for stage %d even though that stage was registered for the job"
@@ -1071,9 +1035,8 @@ class DAGScheduler(
           if (runningStages.contains(stage)) {
             try { // cancelTasks will fail if a SchedulerBackend does not implement killTask
               taskScheduler.cancelTasks(stageId, shouldInterruptThread)
-              val stageInfo = stageToInfos(stage)
-              stageInfo.stageFailed(failureReason)
-              listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage)))
+              stage.info.stageFailed(failureReason)
+              listenerBus.post(SparkListenerStageCompleted(stage.info))
             } catch {
               case e: UnsupportedOperationException =>
                 logInfo(s"Could not cancel tasks for stage $stageId", e)
@@ -1086,7 +1049,7 @@ class DAGScheduler(
 
     if (ableToCancelStages) {
       job.listener.jobFailed(error)
-      cleanupStateForJobAndIndependentStages(job, resultStage)
+      cleanupStateForJobAndIndependentStages(job)
       listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
     }
   }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 798cbc598d..800905413d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.scheduler
 
+import scala.collection.mutable.HashSet
+
 import org.apache.spark._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.BlockManagerId
@@ -56,8 +58,22 @@ private[spark] class Stage(
   val numPartitions = rdd.partitions.size
   val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
   var numAvailableOutputs = 0
+
+  /** Set of jobs that this stage belongs to. */
+  val jobIds = new HashSet[Int]
+
+  /** If this stage is a job's final stage (consisting only of ResultTasks), the ActiveJob. */
+  var resultOfJob: Option[ActiveJob] = None
+  var pendingTasks = new HashSet[Task[_]]
+
   private var nextAttemptId = 0
 
+  val name = callSite.shortForm
+  val details = callSite.longForm
+
+  /** Pointer to the [[StageInfo]] object, whose mutable fields the DAGScheduler updates. */
+  var info: StageInfo = StageInfo.fromStage(this)
+
   def isAvailable: Boolean = {
     if (!isShuffleMap) {
       true
@@ -108,9 +124,6 @@ private[spark] class Stage(
 
   def attemptId: Int = nextAttemptId
 
-  val name = callSite.shortForm
-  val details = callSite.longForm
-
   override def toString = "Stage " + id
 
   override def hashCode(): Int = id
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 44dd1e092a..9021662bcf 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -686,15 +686,11 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
     BlockManagerId("exec-" + host, host, 12345, 0)
 
   private def assertDataStructuresEmpty = {
-    assert(scheduler.pendingTasks.isEmpty)
     assert(scheduler.activeJobs.isEmpty)
     assert(scheduler.failedStages.isEmpty)
     assert(scheduler.jobIdToActiveJob.isEmpty)
     assert(scheduler.jobIdToStageIds.isEmpty)
-    assert(scheduler.stageIdToJobIds.isEmpty)
     assert(scheduler.stageIdToStage.isEmpty)
-    assert(scheduler.stageToInfos.isEmpty)
-    assert(scheduler.resultStageToJob.isEmpty)
     assert(scheduler.runningStages.isEmpty)
     assert(scheduler.shuffleToMapStage.isEmpty)
     assert(scheduler.waitingStages.isEmpty)
-- 
GitLab
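
The shape of the refactoring, as a minimal standalone sketch (the OldScheduler, OldStage, and NewStage names and their trimmed-down fields are illustrative stand-ins, not the real Spark classes): per-stage bookkeeping moves out of scheduler-owned HashMaps keyed by stage or stage id and onto the Stage object itself, so the state is allocated together with the stage and every cleanup path shrinks from purging several maps in lockstep to dropping the stage.

import scala.collection.mutable.{HashMap, HashSet}

// Illustrative stand-ins, not the real Spark classes.
case class StageInfo(stageId: Int)
case class ActiveJob(jobId: Int)

// Before: the scheduler owns one map per piece of per-stage state.
class OldStage(val id: Int)

class OldScheduler {
  val stageIdToJobIds = new HashMap[Int, HashSet[Int]]
  val stageToInfos = new HashMap[OldStage, StageInfo]
  val resultStageToJob = new HashMap[OldStage, ActiveJob]

  // Every cleanup path has to remember to clear every map, or state leaks.
  def removeStage(stage: OldStage): Unit = {
    stageIdToJobIds -= stage.id
    stageToInfos -= stage
    resultStageToJob -= stage
  }
}

// After: the same state lives on the stage, so it is created with the stage
// and becomes unreachable as soon as the stage itself does.
class NewStage(val id: Int) {
  val jobIds = new HashSet[Int]               // replaces stageIdToJobIds
  var resultOfJob: Option[ActiveJob] = None   // replaces resultStageToJob
  var info: StageInfo = StageInfo(id)         // replaces stageToInfos
}

This is also why cleanupStateForJobAndIndependentStages and failJobAndIndependentStages no longer need a resultStage parameter in the patch above: the job's final stage already carries that link in resultOfJob.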