From 1418d18af43229b442d3ed747fdb8088d4fa5b6f Mon Sep 17 00:00:00 2001
From: Aaron Davidson <aaron@databricks.com>
Date: Thu, 5 Sep 2013 15:28:14 -0700
Subject: [PATCH] SPARK-821: Don't cache results when action run locally on driver

Caching the results of local actions (e.g., rdd.first()) causes the driver to
store entire partitions in its own memory, which may be highly constrained.
This patch simply makes the CacheManager avoid caching the result of all
locally-run computations.
---
 core/src/main/scala/org/apache/spark/CacheManager.scala      | 4 ++--
 core/src/main/scala/org/apache/spark/TaskContext.scala       | 1 +
 .../main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 2 +-
 core/src/test/scala/org/apache/spark/JavaAPISuite.java       | 2 +-
 4 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index e299a106ee..a6f701b880 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -69,8 +69,8 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
         val elements = new ArrayBuffer[Any]
         logInfo("Computing partition " + split)
         elements ++= rdd.computeOrReadCheckpoint(split, context)
-        // Try to put this block in the blockManager
-        blockManager.put(key, elements, storageLevel, true)
+        // Persist the result, so long as the task is not running locally
+        if (!context.runningLocally) blockManager.put(key, elements, storageLevel, true)
         return elements.iterator.asInstanceOf[Iterator[T]]
       } finally {
         loading.synchronized {
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index b2dd668330..c2c358c7ad 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -24,6 +24,7 @@ class TaskContext(
   val stageId: Int,
   val splitId: Int,
   val attemptId: Long,
+  val runningLocally: Boolean = false,
   val taskMetrics: TaskMetrics = TaskMetrics.empty()
 ) extends Serializable {
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 92add5b073..b739118e2f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -478,7 +478,7 @@ class DAGScheduler(
       SparkEnv.set(env)
       val rdd = job.finalStage.rdd
       val split = rdd.partitions(job.partitions(0))
-      val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
+      val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0, true)
       try {
         val result = job.func(taskContext, rdd.iterator(split, taskContext))
         job.listener.taskSucceeded(0, result)
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 8a869c9005..591c1d498d 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -495,7 +495,7 @@ public class JavaAPISuite implements Serializable {
   @Test
   public void iterator() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    TaskContext context = new TaskContext(0, 0, 0, null);
+    TaskContext context = new TaskContext(0, 0, 0, false, null);
     Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
   }
 
--
GitLab
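
For readers skimming the diff, the following is a minimal, self-contained Scala
sketch of the one decision the patch adds: skip the blockManager.put when the
computing task is running locally on the driver. SimpleTaskContext,
SimpleBlockStore, SimpleCacheManager, and LocalCachingSketch are hypothetical
stand-ins for illustration only, not Spark's real TaskContext, BlockManager, or
CacheManager APIs.

    import scala.collection.mutable

    // Hypothetical stand-in for TaskContext: only carries the new flag.
    case class SimpleTaskContext(stageId: Int, splitId: Int, attemptId: Long,
                                 runningLocally: Boolean = false)

    // Hypothetical stand-in for the driver/executor block store.
    class SimpleBlockStore {
      private val blocks = mutable.Map.empty[String, Seq[Any]]
      def put(key: String, values: Seq[Any]): Unit = blocks(key) = values
      def contains(key: String): Boolean = blocks.contains(key)
    }

    // Hypothetical stand-in for CacheManager.getOrCompute after the patch:
    // compute the partition, but persist it only when NOT running locally.
    class SimpleCacheManager(store: SimpleBlockStore) {
      def getOrCompute(key: String, context: SimpleTaskContext)
                      (compute: => Seq[Any]): Iterator[Any] = {
        val elements = compute
        if (!context.runningLocally) store.put(key, elements)
        elements.iterator
      }
    }

    object LocalCachingSketch extends App {
      val store = new SimpleBlockStore
      val cm = new SimpleCacheManager(store)

      // A local action (e.g. rdd.first()) runs with runningLocally = true,
      // so its partition is never stored in the driver-side block store.
      cm.getOrCompute("rdd_0_0", SimpleTaskContext(0, 0, 0, runningLocally = true))(Seq(1, 2, 3))
      println(s"cached after local run?   ${store.contains("rdd_0_0")}") // false

      // A normal cluster-side task caches as before.
      cm.getOrCompute("rdd_0_1", SimpleTaskContext(0, 1, 0))(Seq(4, 5, 6))
      println(s"cached after cluster run? ${store.contains("rdd_0_1")}") // true
    }

Running the object prints false for the locally computed partition and true for
the cluster-side one, which is the asymmetry the CacheManager change encodes.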