diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index e299a106ee6599cd610320c5dfe55864009590c6..a6f701b8803609062e00c33a2a889bc874bd5410 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -69,8 +69,8 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
           val elements = new ArrayBuffer[Any]
           logInfo("Computing partition " + split)
           elements ++= rdd.computeOrReadCheckpoint(split, context)
-          // Try to put this block in the blockManager
-          blockManager.put(key, elements, storageLevel, true)
+          // Persist the result, unless the task is running locally
+          if (!context.runningLocally) blockManager.put(key, elements, storageLevel, true)
           return elements.iterator.asInstanceOf[Iterator[T]]
         } finally {
           loading.synchronized {
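
Below is a standalone sketch of the policy this hunk introduces. CachePolicySketch, Ctx, and the mutable Map are hypothetical stand-ins for CacheManager, TaskContext, and BlockManager, not Spark code; the point is only that a locally-run task still computes its partition but skips the put, so the store is not populated as a side effect of local execution.

object CachePolicySketch {
  final case class Ctx(runningLocally: Boolean)

  val cache = scala.collection.mutable.Map.empty[String, Seq[Int]]

  def computeOrCache(key: String, ctx: Ctx)(compute: => Seq[Int]): Iterator[Int] = {
    val elements = compute
    // Persist only for cluster tasks; a driver-local run skips the put.
    if (!ctx.runningLocally) cache(key) = elements
    elements.iterator
  }

  def main(args: Array[String]): Unit = {
    computeOrCache("rdd_0_0", Ctx(runningLocally = true))(Seq(1, 2, 3))
    assert(!cache.contains("rdd_0_0")) // local run: nothing stored
    computeOrCache("rdd_0_0", Ctx(runningLocally = false))(Seq(1, 2, 3))
    assert(cache.contains("rdd_0_0"))  // cluster run: block stored
  }
}
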
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index b2dd668330a3acdc68a0f84160b09a1bcbbcd073..c2c358c7ad6062826d44f7003c928d6168b7f793 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -24,6 +24,7 @@ class TaskContext(
   val stageId: Int,
   val splitId: Int,
   val attemptId: Long,
+  val runningLocally: Boolean = false,  // true when the task runs on the driver via local job execution
   val taskMetrics: TaskMetrics = TaskMetrics.empty()
 ) extends Serializable {
 
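A quick usage sketch of the widened constructor (assuming only the signature above, with these Spark classes on the classpath). Note the new parameter sits before taskMetrics, so Scala call sites that rely on the defaults compile unchanged, while any caller passing taskMetrics positionally, such as the Java test at the bottom of this patch, must be updated:

object TaskContextUsageSketch {
  import org.apache.spark.TaskContext

  val clusterCtx = new TaskContext(0, 0, 0)                        // runningLocally = false
  val localCtx   = new TaskContext(0, 0, 0, runningLocally = true) // driver-local task
}
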
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 92add5b073ab0ccf99a8928d3774705d41307ed6..b739118e2f6e043768e67daa2237bb11a7fc2078 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -478,7 +478,7 @@ class DAGScheduler(
       SparkEnv.set(env)
       val rdd = job.finalStage.rdd
       val split = rdd.partitions(job.partitions(0))
-      val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
+      val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0, true)
       try {
         val result = job.func(taskContext, rdd.iterator(split, taskContext))
         job.listener.taskSucceeded(0, result)
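
The positional true above binds to the new runningLocally parameter. An equivalent form with named arguments (a sketch assuming the TaskContext signature from this patch) would be:

val taskContext = new TaskContext(job.finalStage.id, job.partitions(0),
  attemptId = 0, runningLocally = true) // taskMetrics keeps its default
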
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 8a869c9005fb61280adfec2550bc3f12253e12b4..591c1d498dc0555d66ccded041b7fcc05503403a 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -495,7 +495,7 @@ public class JavaAPISuite implements Serializable {
   @Test
   public void iterator() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    TaskContext context = new TaskContext(0, 0, 0, null);
+    TaskContext context = new TaskContext(0, 0, 0, false, null);
     Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
   }
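
Scala encodes default arguments as synthetic companion-object methods that javac never invokes implicitly, so the Java caller must now spell out all five constructor arguments, including the explicit false and the trailing null metrics. For reference, a fully named Scala rendering of the same call (a sketch assuming the signature introduced above):

new TaskContext(stageId = 0, splitId = 0, attemptId = 0L,
  runningLocally = false, taskMetrics = null)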