Skip to content
Snippets Groups Projects
Commit 1418d18a authored by Aaron Davidson's avatar Aaron Davidson
Browse files

SPARK-821: Don't cache results when action run locally on driver

Caching the results of local actions (e.g., rdd.first()) causes the driver to
store entire partitions in its own memory, which may be highly constrained.
This patch simply makes the CacheManager avoid caching the result of all locally-run computations.
parent 714e7f9e
No related branches found
No related tags found
No related merge requests found
@@ -69,8 +69,8 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
       val elements = new ArrayBuffer[Any]
       logInfo("Computing partition " + split)
       elements ++= rdd.computeOrReadCheckpoint(split, context)
-      // Try to put this block in the blockManager
-      blockManager.put(key, elements, storageLevel, true)
+      // Persist the result, so long as the task is not running locally
+      if (!context.runningLocally) blockManager.put(key, elements, storageLevel, true)
       return elements.iterator.asInstanceOf[Iterator[T]]
     } finally {
       loading.synchronized {
...
@@ -24,6 +24,7 @@ class TaskContext(
   val stageId: Int,
   val splitId: Int,
   val attemptId: Long,
+  val runningLocally: Boolean = false,
   val taskMetrics: TaskMetrics = TaskMetrics.empty()
 ) extends Serializable {
...
@@ -478,7 +478,7 @@ class DAGScheduler(
     SparkEnv.set(env)
     val rdd = job.finalStage.rdd
     val split = rdd.partitions(job.partitions(0))
-    val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
+    val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0, true)
     try {
       val result = job.func(taskContext, rdd.iterator(split, taskContext))
       job.listener.taskSucceeded(0, result)
...
@@ -495,7 +495,7 @@ public class JavaAPISuite implements Serializable {
   @Test
   public void iterator() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    TaskContext context = new TaskContext(0, 0, 0, null);
+    TaskContext context = new TaskContext(0, 0, 0, false, null);
     Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
   }
...
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment