diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index e299a106ee6599cd610320c5dfe55864009590c6..a6f701b8803609062e00c33a2a889bc874bd5410 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -69,8 +69,8 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
           val elements = new ArrayBuffer[Any]
           logInfo("Computing partition " + split)
           elements ++= rdd.computeOrReadCheckpoint(split, context)
-          // Try to put this block in the blockManager
-          blockManager.put(key, elements, storageLevel, true)
+          // Persist the result, so long as the task is not running locally
+          if (!context.runningLocally) blockManager.put(key, elements, storageLevel, true)
           return elements.iterator.asInstanceOf[Iterator[T]]
         } finally {
           loading.synchronized {
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index b2dd668330a3acdc68a0f84160b09a1bcbbcd073..c2c358c7ad6062826d44f7003c928d6168b7f793 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -24,6 +24,7 @@ class TaskContext(
   val stageId: Int,
   val splitId: Int,
   val attemptId: Long,
+  val runningLocally: Boolean = false,
   val taskMetrics: TaskMetrics = TaskMetrics.empty()
 ) extends Serializable {
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 92add5b073ab0ccf99a8928d3774705d41307ed6..b739118e2f6e043768e67daa2237bb11a7fc2078 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -478,7 +478,7 @@ class DAGScheduler(
       SparkEnv.set(env)
       val rdd = job.finalStage.rdd
       val split = rdd.partitions(job.partitions(0))
-      val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
+      val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0, true)
       try {
         val result = job.func(taskContext, rdd.iterator(split, taskContext))
         job.listener.taskSucceeded(0, result)
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 8a869c9005fb61280adfec2550bc3f12253e12b4..591c1d498dc0555d66ccded041b7fcc05503403a 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -495,7 +495,7 @@ public class JavaAPISuite implements Serializable {
   @Test
   public void iterator() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    TaskContext context = new TaskContext(0, 0, 0, null);
+    TaskContext context = new TaskContext(0, 0, 0, false, null);
     Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
   }
 
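
For reference, a minimal, hypothetical sketch (not part of the patch) of how the new flag is exercised. The TaskContext constructor arguments mirror the updated signature in the hunk above; rdd and split are assumed to already be in scope as an RDD and one of its Partitions.

    // Context for a job the DAGScheduler runs locally on the driver;
    // runningLocally defaults to false for regular cluster tasks.
    val localContext = new TaskContext(stageId = 0, splitId = 0, attemptId = 0L, runningLocally = true)

    // CacheManager.getOrCompute still computes the partition for this context,
    // but skips blockManager.put, so nothing is cached for the local run.
    val iter = rdd.iterator(split, localContext)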