diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index e299a106ee6599cd610320c5dfe55864009590c6..68b99ca1253d5f3b64395f55bedc322c9dbb81b2 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -66,10 +66,12 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
         }
         try {
           // If we got here, we have to load the split
-          val elements = new ArrayBuffer[Any]
           logInfo("Computing partition " + split)
-          elements ++= rdd.computeOrReadCheckpoint(split, context)
-          // Try to put this block in the blockManager
+          val computedValues = rdd.computeOrReadCheckpoint(split, context)
+          // Persist the result, so long as the task is not running locally
+          if (context.runningLocally) { return computedValues }
+          val elements = new ArrayBuffer[Any]
+          elements ++= computedValues
           blockManager.put(key, elements, storageLevel, true)
           return elements.iterator.asInstanceOf[Iterator[T]]
         } finally {
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index b2dd668330a3acdc68a0f84160b09a1bcbbcd073..c2c358c7ad6062826d44f7003c928d6168b7f793 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -24,6 +24,7 @@ class TaskContext(
   val stageId: Int,
   val splitId: Int,
   val attemptId: Long,
+  val runningLocally: Boolean = false,
   val taskMetrics: TaskMetrics = TaskMetrics.empty()
 ) extends Serializable {
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index cfcabca0b7fc6475e1cdefa47eec69971c7fe7b2..3e3f04f0876150154d99568f8676308699f14b26 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -478,7 +478,8 @@ class DAGScheduler(
       SparkEnv.set(env)
       val rdd = job.finalStage.rdd
       val split = rdd.partitions(job.partitions(0))
-      val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
+      val taskContext =
+        new TaskContext(job.finalStage.id, job.partitions(0), 0, runningLocally = true)
       try {
         val result = job.func(taskContext, rdd.iterator(split, taskContext))
         job.listener.taskSucceeded(0, result)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 2b007cbe824bc1eca40c3cdd3c9e054567d5f1ff..ca44ebb18951f66e9f3983ac59a25f2e0192f0e9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -93,7 +93,7 @@ private[spark] class ResultTask[T, U](
   }
 
   override def run(attemptId: Long): U = {
-    val context = new TaskContext(stageId, partition, attemptId)
+    val context = new TaskContext(stageId, partition, attemptId, runningLocally = false)
     metrics = Some(context.taskMetrics)
     try {
       func(context, rdd.iterator(split, context))
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 764775fedea2ecadd5995a45d0604e68d289c0fd..d23df0dd2b0f198680f7ef7c1857845c10c405b5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -132,7 +132,7 @@ private[spark] class ShuffleMapTask(
   override def run(attemptId: Long): MapStatus = {
     val numOutputSplits = dep.partitioner.numPartitions
 
-    val taskContext = new TaskContext(stageId, partition, attemptId)
+    val taskContext = new TaskContext(stageId, partition, attemptId, runningLocally = false)
     metrics = Some(taskContext.taskMetrics)
 
     val blockManager = SparkEnv.get.blockManager
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..214f8244d5091471802476a4d3ceb7ff10fd7848
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.mock.EasyMockSugar
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.{BlockManager, StorageLevel}
+
+// TODO: Test the CacheManager's thread-safety aspects
+class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar {
+  var sc : SparkContext = _
+  var blockManager: BlockManager = _
+  var cacheManager: CacheManager = _
+  var split: Partition = _
+  /** An RDD which returns the values [1, 2, 3, 4]. */
+  var rdd: RDD[Int] = _
+
+  before {
+    sc = new SparkContext("local", "test")
+    blockManager = mock[BlockManager]
+    cacheManager = new CacheManager(blockManager)
+    split = new Partition { override def index: Int = 0 }
+    rdd = new RDD(sc, Nil) {
+      override def getPartitions: Array[Partition] = Array(split)
+      override val getDependencies = List[Dependency[_]]()
+      override def compute(split: Partition, context: TaskContext) = Array(1, 2, 3, 4).iterator
+    }
+  }
+
+  after {
+    sc.stop()
+  }
+
+  test("get uncached rdd") {
+    expecting {
+      blockManager.get("rdd_0_0").andReturn(None)
+      blockManager.put("rdd_0_0", ArrayBuffer[Any](1, 2, 3, 4), StorageLevel.MEMORY_ONLY, true).
+        andReturn(0)
+    }
+
+    whenExecuting(blockManager) {
+      val context = new TaskContext(0, 0, 0, runningLocally = false, null)
+      val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
+      assert(value.toList === List(1, 2, 3, 4))
+    }
+  }
+
+  test("get cached rdd") {
+    expecting {
+      blockManager.get("rdd_0_0").andReturn(Some(ArrayBuffer(5, 6, 7).iterator))
+    }
+
+    whenExecuting(blockManager) {
+      val context = new TaskContext(0, 0, 0, runningLocally = false, null)
+      val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
+      assert(value.toList === List(5, 6, 7))
+    }
+  }
+
+  test("get uncached local rdd") {
+    expecting {
+      // Local computation should not persist the resulting value, so don't expect a put().
+      blockManager.get("rdd_0_0").andReturn(None)
+    }
+
+    whenExecuting(blockManager) {
+      val context = new TaskContext(0, 0, 0, runningLocally = true, null)
+      val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
+      assert(value.toList === List(1, 2, 3, 4))
+    }
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 8a869c9005fb61280adfec2550bc3f12253e12b4..591c1d498dc0555d66ccded041b7fcc05503403a 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -495,7 +495,7 @@ public class JavaAPISuite implements Serializable {
   @Test
   public void iterator() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    TaskContext context = new TaskContext(0, 0, 0, null);
+    TaskContext context = new TaskContext(0, 0, 0, false, null);
     Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
   }