diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index e299a106ee6599cd610320c5dfe55864009590c6..68b99ca1253d5f3b64395f55bedc322c9dbb81b2 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -66,10 +66,12 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
         }
         try {
           // If we got here, we have to load the split
-          val elements = new ArrayBuffer[Any]
           logInfo("Computing partition " + split)
-          elements ++= rdd.computeOrReadCheckpoint(split, context)
-          // Try to put this block in the blockManager
+          val computedValues = rdd.computeOrReadCheckpoint(split, context)
+          // Persist the result, so long as the task is not running locally
+          if (context.runningLocally) { return computedValues }
+          val elements = new ArrayBuffer[Any]
+          elements ++= computedValues
           blockManager.put(key, elements, storageLevel, true)
           return elements.iterator.asInstanceOf[Iterator[T]]
         } finally {
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index b2dd668330a3acdc68a0f84160b09a1bcbbcd073..c2c358c7ad6062826d44f7003c928d6168b7f793 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -24,6 +24,7 @@ class TaskContext(
   val stageId: Int,
   val splitId: Int,
   val attemptId: Long,
+  val runningLocally: Boolean = false,
   val taskMetrics: TaskMetrics = TaskMetrics.empty()
 ) extends Serializable {
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index cfcabca0b7fc6475e1cdefa47eec69971c7fe7b2..3e3f04f0876150154d99568f8676308699f14b26 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -478,7 +478,8 @@ class DAGScheduler(
           SparkEnv.set(env)
           val rdd = job.finalStage.rdd
           val split = rdd.partitions(job.partitions(0))
-          val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
+          val taskContext =
+            new TaskContext(job.finalStage.id, job.partitions(0), 0, runningLocally = true)
           try {
             val result = job.func(taskContext, rdd.iterator(split, taskContext))
             job.listener.taskSucceeded(0, result)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 2b007cbe824bc1eca40c3cdd3c9e054567d5f1ff..ca44ebb18951f66e9f3983ac59a25f2e0192f0e9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -93,7 +93,7 @@ private[spark] class ResultTask[T, U](
   }
 
   override def run(attemptId: Long): U = {
-    val context = new TaskContext(stageId, partition, attemptId)
+    val context = new TaskContext(stageId, partition, attemptId, runningLocally = false)
     metrics = Some(context.taskMetrics)
     try {
       func(context, rdd.iterator(split, context))
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 764775fedea2ecadd5995a45d0604e68d289c0fd..d23df0dd2b0f198680f7ef7c1857845c10c405b5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -132,7 +132,7 @@ private[spark] class ShuffleMapTask(
 
   override def run(attemptId: Long): MapStatus = {
     val numOutputSplits = dep.partitioner.numPartitions
-    val taskContext = new TaskContext(stageId, partition, attemptId)
+    val taskContext = new TaskContext(stageId, partition, attemptId, runningLocally = false)
     metrics = Some(taskContext.taskMetrics)
 
     val blockManager = SparkEnv.get.blockManager
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..214f8244d5091471802476a4d3ceb7ff10fd7848
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.mock.EasyMockSugar
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.{BlockManager, StorageLevel}
+
+// TODO: Test the CacheManager's thread-safety aspects
+class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar {
+  var sc : SparkContext = _
+  var blockManager: BlockManager = _
+  var cacheManager: CacheManager = _
+  var split: Partition = _
+  /** An RDD which returns the values [1, 2, 3, 4]. */
+  var rdd: RDD[Int] = _
+
+  before {
+    sc = new SparkContext("local", "test")
+    blockManager = mock[BlockManager]
+    cacheManager = new CacheManager(blockManager)
+    split = new Partition { override def index: Int = 0 }
+    rdd = new RDD(sc, Nil) {
+      override def getPartitions: Array[Partition] = Array(split)
+      override val getDependencies = List[Dependency[_]]()
+      override def compute(split: Partition, context: TaskContext) = Array(1, 2, 3, 4).iterator
+    }
+  }
+
+  after {
+    sc.stop()
+  }
+
+  test("get uncached rdd") {
+    expecting {
+      blockManager.get("rdd_0_0").andReturn(None)
+      blockManager.put("rdd_0_0", ArrayBuffer[Any](1, 2, 3, 4), StorageLevel.MEMORY_ONLY, true).
+        andReturn(0)
+    }
+
+    whenExecuting(blockManager) {
+      val context = new TaskContext(0, 0, 0, runningLocally = false, null)
+      val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
+      assert(value.toList === List(1, 2, 3, 4))
+    }
+  }
+
+  test("get cached rdd") {
+    expecting {
+      blockManager.get("rdd_0_0").andReturn(Some(ArrayBuffer(5, 6, 7).iterator))
+    }
+
+    whenExecuting(blockManager) {
+      val context = new TaskContext(0, 0, 0, runningLocally = false, null)
+      val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
+      assert(value.toList === List(5, 6, 7))
+    }
+  }
+
+  test("get uncached local rdd") {
+    expecting {
+      // Local computation should not persist the resulting value, so don't expect a put().
+      blockManager.get("rdd_0_0").andReturn(None)
+    }
+
+    whenExecuting(blockManager) {
+      val context = new TaskContext(0, 0, 0, runningLocally = true, null)
+      val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
+      assert(value.toList === List(1, 2, 3, 4))
+    }
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 8a869c9005fb61280adfec2550bc3f12253e12b4..591c1d498dc0555d66ccded041b7fcc05503403a 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -495,7 +495,7 @@ public class JavaAPISuite implements Serializable {
   @Test
   public void iterator() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    TaskContext context = new TaskContext(0, 0, 0, null);
+    TaskContext context = new TaskContext(0, 0, 0, false, null);
     Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
   }
 
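
Illustrative sketch (not part of the patch above): a minimal, self-contained Scala stand-in for the patched CacheManager.getOrCompute control flow, showing how the new runningLocally flag skips the blockManager.put call. The names ContextStub, BlockManagerStub and RunningLocallySketch are hypothetical and exist only for this example.

import scala.collection.mutable.ArrayBuffer

object RunningLocallySketch {
  // Stand-in for TaskContext, carrying only the flag added by this patch.
  case class ContextStub(runningLocally: Boolean)

  // Stand-in for BlockManager: records which keys were cached via put().
  class BlockManagerStub {
    val cached = ArrayBuffer[String]()
    def put(key: String, values: ArrayBuffer[Any]): Unit = cached += key
  }

  // Mirrors the patched control flow: a locally-run task returns the computed
  // values directly instead of materializing them into the block store.
  def getOrCompute(key: String, compute: () => Iterator[Any],
      context: ContextStub, blockManager: BlockManagerStub): Iterator[Any] = {
    val computedValues = compute()
    if (context.runningLocally) return computedValues
    val elements = new ArrayBuffer[Any]
    elements ++= computedValues
    blockManager.put(key, elements)
    elements.iterator
  }

  def main(args: Array[String]): Unit = {
    val bm = new BlockManagerStub
    val values = () => Iterator[Any](1, 2, 3, 4)
    getOrCompute("rdd_0_0", values, ContextStub(runningLocally = true), bm).size
    getOrCompute("rdd_0_0", values, ContextStub(runningLocally = false), bm).size
    println(bm.cached) // ArrayBuffer(rdd_0_0): only the non-local run hit put()
  }
}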