Commit ddcb9d31 authored by Patrick Wendell

Merge pull request #895 from ilikerps/821

SPARK-821: Don't cache results when action run locally on driver
parents 699c331f 3a04e76c
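
In short: before this patch, CacheManager.getOrCompute materialized every computed partition into an ArrayBuffer and handed it to the BlockManager, even when the DAGScheduler executed the job locally on the driver thread. The new runningLocally flag on TaskContext lets the cache manager hand the computed iterator straight back in that case. A minimal sketch of the user-visible effect, assuming a local SparkContext of this era (the identifiers sc and data are illustrative, not part of the patch):

import org.apache.spark.SparkContext
import org.apache.spark.storage.StorageLevel

object LocalActionSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "sketch")
    val data = sc.parallelize(1 to 100, 4).persist(StorageLevel.MEMORY_ONLY)

    // first() is a candidate for the scheduler's local-execution fast path
    // (it submits its job with allowLocal = true). With this patch, the
    // partition computed for it is returned directly and is NOT stored in
    // the BlockManager as a side effect of the local run.
    println(data.first())

    // A regular distributed action still runs through ResultTask, whose
    // TaskContext has runningLocally = false, so persist() behaves as before.
    println(data.count())

    sc.stop()
  }
}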
@@ -66,10 +66,12 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
      }
      try {
        // If we got here, we have to load the split
-        val elements = new ArrayBuffer[Any]
        logInfo("Computing partition " + split)
-        elements ++= rdd.computeOrReadCheckpoint(split, context)
-        // Try to put this block in the blockManager
+        val computedValues = rdd.computeOrReadCheckpoint(split, context)
+        // Persist the result, so long as the task is not running locally
+        if (context.runningLocally) { return computedValues }
+        val elements = new ArrayBuffer[Any]
+        elements ++= computedValues
        blockManager.put(key, elements, storageLevel, true)
        return elements.iterator.asInstanceOf[Iterator[T]]
      } finally {
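
Isolated from the diff noise, the control flow this hunk introduces is: compute lazily, bail out with the raw iterator when running locally, and only otherwise unroll the values into a buffer and cache them. A standalone sketch of that pattern (SketchCache, Ctx, and the in-memory store are illustrative stand-ins, not Spark code):

import scala.collection.mutable.ArrayBuffer

object SketchCache {
  case class Ctx(runningLocally: Boolean)

  val store = scala.collection.mutable.Map[String, ArrayBuffer[Any]]()

  def getOrCompute[T](key: String, ctx: Ctx)(compute: => Iterator[T]): Iterator[T] = {
    store.get(key) match {
      case Some(cached) => cached.iterator.asInstanceOf[Iterator[T]]
      case None =>
        val computedValues = compute
        // Skip caching for a driver-local run: unrolling the iterator into a
        // buffer costs driver memory for a block no executor will ever read.
        if (ctx.runningLocally) return computedValues
        val elements = new ArrayBuffer[Any]
        elements ++= computedValues
        store(key) = elements
        elements.iterator.asInstanceOf[Iterator[T]]
    }
  }

  def main(args: Array[String]): Unit = {
    println(getOrCompute("k", Ctx(runningLocally = true))(Iterator(1, 2, 3)).toList)
    println(store.contains("k")) // false: the local path skipped the cache
    println(getOrCompute("k", Ctx(runningLocally = false))(Iterator(1, 2, 3)).toList)
    println(store.contains("k")) // true: the non-local path cached the block
  }
}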
@@ -24,6 +24,7 @@ class TaskContext(
  val stageId: Int,
  val splitId: Int,
  val attemptId: Long,
+  val runningLocally: Boolean = false,
  val taskMetrics: TaskMetrics = TaskMetrics.empty()
) extends Serializable {
@@ -478,7 +478,8 @@ class DAGScheduler(
    SparkEnv.set(env)
    val rdd = job.finalStage.rdd
    val split = rdd.partitions(job.partitions(0))
-    val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
+    val taskContext =
+      new TaskContext(job.finalStage.id, job.partitions(0), 0, runningLocally = true)
    try {
      val result = job.func(taskContext, rdd.iterator(split, taskContext))
      job.listener.taskSucceeded(0, result)
@@ -93,7 +93,7 @@ private[spark] class ResultTask[T, U](
  }
  override def run(attemptId: Long): U = {
-    val context = new TaskContext(stageId, partition, attemptId)
+    val context = new TaskContext(stageId, partition, attemptId, runningLocally = false)
    metrics = Some(context.taskMetrics)
    try {
      func(context, rdd.iterator(split, context))
@@ -132,7 +132,7 @@ private[spark] class ShuffleMapTask(
  override def run(attemptId: Long): MapStatus = {
    val numOutputSplits = dep.partitioner.numPartitions
-    val taskContext = new TaskContext(stageId, partition, attemptId)
+    val taskContext = new TaskContext(stageId, partition, attemptId, runningLocally = false)
    metrics = Some(taskContext.taskMetrics)
    val blockManager = SparkEnv.get.blockManager
CacheManagerSuite.scala (new file, shown in full):

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark

import scala.collection.mutable.ArrayBuffer

import org.scalatest.{BeforeAndAfter, FunSuite}
import org.scalatest.mock.EasyMockSugar

import org.apache.spark.rdd.RDD
import org.apache.spark.storage.{BlockManager, StorageLevel}

// TODO: Test the CacheManager's thread-safety aspects
class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar {
  var sc: SparkContext = _
  var blockManager: BlockManager = _
  var cacheManager: CacheManager = _
  var split: Partition = _
  /** An RDD which returns the values [1, 2, 3, 4]. */
  var rdd: RDD[Int] = _

  before {
    sc = new SparkContext("local", "test")
    blockManager = mock[BlockManager]
    cacheManager = new CacheManager(blockManager)
    split = new Partition { override def index: Int = 0 }
    rdd = new RDD(sc, Nil) {
      override def getPartitions: Array[Partition] = Array(split)
      override val getDependencies = List[Dependency[_]]()
      override def compute(split: Partition, context: TaskContext) = Array(1, 2, 3, 4).iterator
    }
  }

  after {
    sc.stop()
  }

  test("get uncached rdd") {
    expecting {
      blockManager.get("rdd_0_0").andReturn(None)
      blockManager.put("rdd_0_0", ArrayBuffer[Any](1, 2, 3, 4), StorageLevel.MEMORY_ONLY, true).
        andReturn(0)
    }

    whenExecuting(blockManager) {
      val context = new TaskContext(0, 0, 0, runningLocally = false, null)
      val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
      assert(value.toList === List(1, 2, 3, 4))
    }
  }

  test("get cached rdd") {
    expecting {
      blockManager.get("rdd_0_0").andReturn(Some(ArrayBuffer(5, 6, 7).iterator))
    }

    whenExecuting(blockManager) {
      val context = new TaskContext(0, 0, 0, runningLocally = false, null)
      val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
      assert(value.toList === List(5, 6, 7))
    }
  }

  test("get uncached local rdd") {
    expecting {
      // Local computation should not persist the resulting value, so don't expect a put().
      blockManager.get("rdd_0_0").andReturn(None)
    }

    whenExecuting(blockManager) {
      val context = new TaskContext(0, 0, 0, runningLocally = true, null)
      val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
      assert(value.toList === List(1, 2, 3, 4))
    }
  }
}
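
The third test relies on EasyMock's verification semantics: mocks created this way are strict, so any call that was not recorded inside expecting, in particular an unexpected blockManager.put(...), fails the test. A self-contained sketch of that pattern, using the same ScalaTest EasyMockSugar mixin as the suite above (KeyValueStore and the suite name are illustrative, not Spark code):

import org.scalatest.FunSuite
import org.scalatest.mock.EasyMockSugar

trait KeyValueStore {
  def get(key: String): Option[Any]
  def put(key: String, value: Any): Unit
}

class NoWriteBackSketchSuite extends FunSuite with EasyMockSugar {
  test("a cache miss that must not write back") {
    val store = mock[KeyValueStore]
    expecting {
      store.get("k").andReturn(None)
      // Deliberately no expectation for store.put(...): if the code under
      // test called it, the strict mock would fail with an unexpected call.
    }
    whenExecuting(store) {
      assert(store.get("k") === None)
    }
  }
}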
@@ -495,7 +495,7 @@ public class JavaAPISuite implements Serializable {
  @Test
  public void iterator() {
    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    TaskContext context = new TaskContext(0, 0, 0, null);
+    TaskContext context = new TaskContext(0, 0, 0, false, null);
    Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
  }