diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index bf0837c0660cb2090a86a4aa75ec34bcd8b41b32..9e7791fbb46b053273e001f0d9c5bcc680d8c9a1 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -379,29 +379,34 @@ class DAGScheduler(
    * We run the operation in a separate thread just in case it takes a bunch of time, so that we
    * don't block the DAGScheduler event loop or other concurrent jobs.
    */
-  private def runLocally(job: ActiveJob) {
+  protected def runLocally(job: ActiveJob) {
     logInfo("Computing the requested partition locally")
     new Thread("Local computation of job " + job.runId) {
       override def run() {
-        try {
-          SparkEnv.set(env)
-          val rdd = job.finalStage.rdd
-          val split = rdd.partitions(job.partitions(0))
-          val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
-          try {
-            val result = job.func(taskContext, rdd.iterator(split, taskContext))
-            job.listener.taskSucceeded(0, result)
-          } finally {
-            taskContext.executeOnCompleteCallbacks()
-          }
-        } catch {
-          case e: Exception =>
-            job.listener.jobFailed(e)
-        }
+        runLocallyWithinThread(job)
       }
     }.start()
   }
 
+  // Broken out for easier testing in DAGSchedulerSuite.
+  protected def runLocallyWithinThread(job: ActiveJob) {
+    try {
+      SparkEnv.set(env)
+      val rdd = job.finalStage.rdd
+      val split = rdd.partitions(job.partitions(0))
+      val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
+      try {
+        val result = job.func(taskContext, rdd.iterator(split, taskContext))
+        job.listener.taskSucceeded(0, result)
+      } finally {
+        taskContext.executeOnCompleteCallbacks()
+      }
+    } catch {
+      case e: Exception =>
+        job.listener.jobFailed(e)
+    }
+  }
+
   /** Submits stage, but first recursively submits any missing parents. */
   private def submitStage(stage: Stage) {
     logDebug("submitStage(" + stage + ")")
diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
index 07cccc7ce087418ce90c4b1b9b5a9cb9256b4bf5..29b565ecad4afa4843f7368724869f14cbb4e529 100644
--- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
@@ -90,7 +90,12 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter {
     cacheLocations.clear()
     results.clear()
     mapOutputTracker = new MapOutputTracker()
-    scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null)
+    scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) {
+      override def runLocally(job: ActiveJob) {
+        // don't bother with the thread while unit testing
+        runLocallyWithinThread(job)
+      }
+    }
   }
 
   after {
@@ -203,8 +208,6 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter {
       override def toString = "DAGSchedulerSuite Local RDD"
     }
     runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener))
-    // this shouldn't be needed, but i haven't stubbed out runLocally yet
-    Thread.sleep(500)
    assert(results === Map(0 -> 42))
   }
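
The change above extracts a testing seam: `runLocally` keeps only the thread creation, while the job-execution logic moves into a `protected` hook, `runLocallyWithinThread`, that the test subclass calls directly. This lets `DAGSchedulerSuite` run the local job synchronously and drop the flaky `Thread.sleep(500)`. For illustration, here is a minimal, self-contained sketch of the same pattern; the `Worker`, `runAsync`, and `doWork` names are hypothetical and not part of the Spark API:

```scala
// Production class: the public entry point owns only the threading;
// the protected hook owns the logic, giving tests a seam to override.
class Worker {
  def runAsync(task: () => Int): Unit = {
    new Thread("worker") {
      override def run(): Unit = doWork(task)
    }.start()
  }

  // Broken out so tests can bypass the thread.
  protected def doWork(task: () => Int): Unit = {
    println("result: " + task())
  }
}

object WorkerDemo extends App {
  // In a test, override the async entry point to stay on the calling
  // thread, so assertions run deterministically with no sleeps.
  val syncWorker = new Worker {
    override def runAsync(task: () => Int): Unit = doWork(task)
  }
  syncWorker.runAsync(() => 42)  // prints "result: 42" before returning
}
```

The design choice mirrors the diff: rather than stubbing out threading inside the test, the subclass replaces the asynchronous entry point with a direct call to the extracted method, which is why the `Thread.sleep(500)` in the last hunk becomes unnecessary.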