From eacb98e90075ca3082ad7c832b24719f322d9eb2 Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@cs.berkeley.edu>
Date: Thu, 13 Dec 2012 15:41:53 -0800
Subject: [PATCH] SPARK-635: Pass a TaskContext object to compute() interface
 and use that to close Hadoop input stream.

---
 core/src/main/scala/spark/CacheTracker.scala  | 30 ++++++++++---------
 .../main/scala/spark/PairRDDFunctions.scala   | 15 +++++-----
 .../main/scala/spark/ParallelCollection.scala | 17 ++++++-----
 core/src/main/scala/spark/RDD.scala           |  8 ++---
 core/src/main/scala/spark/TaskContext.scala   | 19 +++++++++++-
 .../scala/spark/api/java/JavaRDDLike.scala    | 25 ++++++++--------
 core/src/main/scala/spark/rdd/BlockRDD.scala  | 25 +++++++---------
 .../main/scala/spark/rdd/CartesianRDD.scala   | 17 +++++------
 .../main/scala/spark/rdd/CoGroupedRDD.scala   | 30 ++++++++-----------
 .../main/scala/spark/rdd/CoalescedRDD.scala   |  9 +++---
 .../main/scala/spark/rdd/FilteredRDD.scala    |  8 ++---
 .../main/scala/spark/rdd/FlatMappedRDD.scala  | 10 +++----
 .../src/main/scala/spark/rdd/GlommedRDD.scala |  8 ++---
 core/src/main/scala/spark/rdd/HadoopRDD.scala | 22 +++++++-------
 .../scala/spark/rdd/MapPartitionsRDD.scala    | 10 +++----
 .../spark/rdd/MapPartitionsWithSplitRDD.scala |  7 ++---
 core/src/main/scala/spark/rdd/MappedRDD.scala |  9 +++---
 .../main/scala/spark/rdd/NewHadoopRDD.scala   | 27 ++++++++---------
 core/src/main/scala/spark/rdd/PipedRDD.scala  | 11 +++----
 .../src/main/scala/spark/rdd/SampledRDD.scala | 15 +++++-----
 .../main/scala/spark/rdd/ShuffledRDD.scala    |  9 ++----
 core/src/main/scala/spark/rdd/UnionRDD.scala  | 22 +++++++-------
 core/src/main/scala/spark/rdd/ZippedRDD.scala | 19 ++++++------
 .../scala/spark/scheduler/DAGScheduler.scala  | 17 ++++++-----
 .../scala/spark/scheduler/ResultTask.scala    |  6 ++--
 .../spark/scheduler/ShuffleMapTask.scala      | 17 +++++++----
 26 files changed, 207 insertions(+), 205 deletions(-)

diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala
index c5db6ce63a..e9c545a2cf 100644
--- a/core/src/main/scala/spark/CacheTracker.scala
+++ b/core/src/main/scala/spark/CacheTracker.scala
@@ -1,5 +1,9 @@
 package spark
 
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+
 import akka.actor._
 import akka.dispatch._
 import akka.pattern.ask
@@ -8,10 +12,6 @@ import akka.util.Duration
 import akka.util.Timeout
 import akka.util.duration._
 
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-
 import spark.storage.BlockManager
 import spark.storage.StorageLevel
 
@@ -41,7 +41,7 @@ private[spark] class CacheTrackerActor extends Actor with Logging {
   private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L)
   private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L)
   private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host)
-  
+
   def receive = {
     case SlaveCacheStarted(host: String, size: Long) =>
       slaveCapacity.put(host, size)
@@ -92,14 +92,14 @@ private[spark] class CacheTrackerActor extends Actor with Logging {
 
 private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager)
   extends Logging {
- 
+
   // Tracker actor on the master, or remote reference to it on workers
   val ip: String = System.getProperty("spark.master.host", "localhost")
   val port: Int = System.getProperty("spark.master.port", "7077").toInt
   val actorName: String = "CacheTracker"
 
   val timeout = 10.seconds
-  
+
   var trackerActor: ActorRef = if (isMaster) {
     val actor = actorSystem.actorOf(Props[CacheTrackerActor], name = actorName)
     logInfo("Registered CacheTrackerActor actor")
@@ -132,7 +132,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b
       throw new SparkException("Error reply received from CacheTracker")
     }
   }
-  
+
   // Registers an RDD (on master only)
   def registerRDD(rddId: Int, numPartitions: Int) {
     registeredRddIds.synchronized {
@@ -143,7 +143,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b
       }
     }
   }
-  
+
   // For BlockManager.scala only
   def cacheLost(host: String) {
     communicate(MemoryCacheLost(host))
@@ -155,19 +155,21 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b
   def getCacheStatus(): Seq[(String, Long, Long)] = {
     askTracker(GetCacheStatus).asInstanceOf[Seq[(String, Long, Long)]]
   }
-  
+
   // For BlockManager.scala only
   def notifyFromBlockManager(t: AddedToCache) {
     communicate(t)
   }
-  
+
   // Get a snapshot of the currently known locations
   def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
     askTracker(GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]]
   }
-  
+
   // Gets or computes an RDD split
-  def getOrCompute[T](rdd: RDD[T], split: Split, storageLevel: StorageLevel): Iterator[T] = {
+  def getOrCompute[T](
+    rdd: RDD[T], split: Split, taskContext: TaskContext, storageLevel: StorageLevel)
+  : Iterator[T] = {
     val key = "rdd_%d_%d".format(rdd.id, split.index)
     logInfo("Cache key is " + key)
     blockManager.get(key) match {
@@ -209,7 +211,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b
         // TODO: also register a listener for when it unloads
         logInfo("Computing partition " + split)
         val elements = new ArrayBuffer[Any]
-        elements ++= rdd.compute(split)
+        elements ++= rdd.compute(split, taskContext)
         try {
           // Try to put this block in the blockManager
           blockManager.put(key, elements, storageLevel, true)
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index e5bb639cfd..08ae06e865 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -35,11 +35,11 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
   with Serializable {
 
   /**
-   * Generic function to combine the elements for each key using a custom set of aggregation 
+   * Generic function to combine the elements for each key using a custom set of aggregation
    * functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C
    * Note that V and C can be different -- for example, one might group an RDD of type
    * (Int, Int) into an RDD of type (Int, Seq[Int]). Users provide three functions:
-   * 
+   *
    * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list)
    * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list)
    * - `mergeCombiners`, to combine two C's into a single one.
@@ -118,7 +118,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
   /** Count the number of elements for each key, and return the result to the master as a Map. */
   def countByKey(): Map[K, Long] = self.map(_._1).countByValue()
 
-  /** 
+  /**
    * (Experimental) Approximate version of countByKey that can return a partial result if it does
    * not finish within a timeout.
    */
@@ -224,7 +224,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
     }
   }
 
-  /** 
+  /**
    * Simplified version of combineByKey that hash-partitions the resulting RDD using the default
    * parallelism level.
    */
@@ -628,7 +628,8 @@ class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)]
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
   override val partitioner = prev.partitioner
-  override def compute(split: Split) = prev.iterator(split).map{case (k, v) => (k, f(v))}
+  override def compute(split: Split, taskContext: TaskContext) =
+    prev.iterator(split, taskContext).map{case (k, v) => (k, f(v))}
 }
 
 private[spark]
@@ -639,8 +640,8 @@ class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U]
   override val dependencies = List(new OneToOneDependency(prev))
   override val partitioner = prev.partitioner
 
-  override def compute(split: Split) = {
-    prev.iterator(split).flatMap { case (k, v) => f(v).map(x => (k, x)) }
+  override def compute(split: Split, taskContext: TaskContext) = {
+    prev.iterator(split, taskContext).flatMap { case (k, v) => f(v).map(x => (k, x)) }
   }
 }
 
diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala
index 9b57ae3b4f..a27f766e31 100644
--- a/core/src/main/scala/spark/ParallelCollection.scala
+++ b/core/src/main/scala/spark/ParallelCollection.scala
@@ -8,8 +8,8 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest](
     val slice: Int,
     values: Seq[T])
   extends Split with Serializable {
-  
-  def iterator(): Iterator[T] = values.iterator
+
+  def iterator: Iterator[T] = values.iterator
 
   override def hashCode(): Int = (41 * (41 + rddId) + slice).toInt
 
@@ -22,7 +22,7 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest](
 }
 
 private[spark] class ParallelCollection[T: ClassManifest](
-    sc: SparkContext, 
+    sc: SparkContext,
     @transient data: Seq[T],
     numSlices: Int)
   extends RDD[T](sc) {
@@ -38,17 +38,18 @@ private[spark] class ParallelCollection[T: ClassManifest](
 
   override def splits = splits_.asInstanceOf[Array[Split]]
 
-  override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator
-  
+  override def compute(s: Split, taskContext: TaskContext) =
+    s.asInstanceOf[ParallelCollectionSplit[T]].iterator
+
   override def preferredLocations(s: Split): Seq[String] = Nil
-  
+
   override val dependencies: List[Dependency[_]] = Nil
 }
 
 private object ParallelCollection {
   /**
    * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range
-   * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes 
+   * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
    * it efficient to run Spark over RDDs representing large sets of numbers.
    */
   def slice[T: ClassManifest](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
@@ -58,7 +59,7 @@ private object ParallelCollection {
     seq match {
       case r: Range.Inclusive => {
         val sign = if (r.step < 0) {
-          -1 
+          -1
         } else {
           1
         }
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 6270e018b3..c53eab67e5 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -81,7 +81,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
   def splits: Array[Split]
 
   /** Function for computing a given partition. */
-  def compute(split: Split): Iterator[T]
+  def compute(split: Split, taskContext: TaskContext): Iterator[T]
 
   /** How this RDD depends on any parent RDDs. */
   @transient val dependencies: List[Dependency[_]]
@@ -155,11 +155,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
    * This should ''not'' be called by users directly, but is available for implementors of custom
    * subclasses of RDD.
    */
-  final def iterator(split: Split): Iterator[T] = {
+  final def iterator(split: Split, taskContext: TaskContext): Iterator[T] = {
     if (storageLevel != StorageLevel.NONE) {
-      SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel)
+      SparkEnv.get.cacheTracker.getOrCompute[T](this, split, taskContext, storageLevel)
     } else {
-      compute(split)
+      compute(split, taskContext)
     }
   }
 
diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala
index c14377d17b..b352db8167 100644
--- a/core/src/main/scala/spark/TaskContext.scala
+++ b/core/src/main/scala/spark/TaskContext.scala
@@ -1,3 +1,20 @@
 package spark
 
-class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable
+import scala.collection.mutable.ArrayBuffer
+
+
+class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable {
+
+  @transient
+  val onCompleteCallbacks = new ArrayBuffer[Unit => Unit]
+
+  // Add a callback function to be executed on task completion. An example use
+  // is for HadoopRDD to register a callback to close the input stream.
+  def registerOnCompleteCallback(f: Unit => Unit) {
+    onCompleteCallbacks += f
+  }
+
+  def executeOnCompleteCallbacks() {
+    onCompleteCallbacks.foreach{_()}
+  }
+}
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index 482eb9281a..81d3a94466 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -1,16 +1,15 @@
 package spark.api.java
 
-import spark.{SparkContext, Split, RDD}
+import java.util.{List => JList}
+import scala.Tuple2
+import scala.collection.JavaConversions._
+
+import spark.{SparkContext, Split, RDD, TaskContext}
 import spark.api.java.JavaPairRDD._
 import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
 import spark.partial.{PartialResult, BoundedDouble}
 import spark.storage.StorageLevel
 
-import java.util.{List => JList}
-
-import scala.collection.JavaConversions._
-import java.{util, lang}
-import scala.Tuple2
 
 trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
   def wrapRDD(rdd: RDD[T]): This
@@ -24,7 +23,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
 
   /** The [[spark.SparkContext]] that this RDD was created on. */
   def context: SparkContext = rdd.context
-  
+
   /** A unique ID for this RDD (within its SparkContext). */
   def id: Int = rdd.id
 
@@ -36,7 +35,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
    * This should ''not'' be called by users directly, but is available for implementors of custom
    * subclasses of RDD.
    */
-  def iterator(split: Split): java.util.Iterator[T] = asJavaIterator(rdd.iterator(split))
+  def iterator(split: Split, taskContext: TaskContext): java.util.Iterator[T] =
+    asJavaIterator(rdd.iterator(split, taskContext))
 
   // Transformations (return a new RDD)
 
@@ -99,7 +99,6 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
     JavaRDD.fromRDD(rdd.mapPartitions(fn)(f.elementType()))(f.elementType())
   }
 
-  
   /**
    * Return a new RDD by applying a function to each partition of this RDD.
    */
@@ -183,7 +182,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
   }
 
   // Actions (launch a job to return a value to the user program)
-    
+
   /**
    * Applies a function f to all elements of this RDD.
    */
@@ -200,7 +199,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
     val arr: java.util.Collection[T] = rdd.collect().toSeq
     new java.util.ArrayList(arr)
   }
-  
+
   /**
    * Reduces the elements of this RDD using the specified associative binary operator.
    */
@@ -208,7 +207,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
 
   /**
    * Aggregate the elements of each partition, and then the results for all the partitions, using a
-   * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to 
+   * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
    * modify t1 and return it as its result value to avoid object allocation; however, it should not
    * modify t2.
    */
@@ -251,7 +250,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
    * combine step happens locally on the master, equivalent to running a single reduce task.
    */
   def countByValue(): java.util.Map[T, java.lang.Long] =
-    mapAsJavaMap(rdd.countByValue().map((x => (x._1, new lang.Long(x._2)))))
+    mapAsJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2)))))
 
   /**
    * (Experimental) Approximate version of countByValue().
diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala
index cb73976aed..8209c36871 100644
--- a/core/src/main/scala/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/spark/rdd/BlockRDD.scala
@@ -2,11 +2,8 @@ package spark.rdd
 
 import scala.collection.mutable.HashMap
 
-import spark.Dependency
-import spark.RDD
-import spark.SparkContext
-import spark.SparkEnv
-import spark.Split
+import spark.{Dependency, RDD, SparkContext, SparkEnv, Split, TaskContext}
+
 
 private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split {
   val index = idx
@@ -19,29 +16,29 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St
   @transient
   val splits_ = (0 until blockIds.size).map(i => {
     new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split]
-  }).toArray 
-  
-  @transient 
+  }).toArray
+
+  @transient
   lazy val locations_  = {
-    val blockManager = SparkEnv.get.blockManager 
+    val blockManager = SparkEnv.get.blockManager
     /*val locations = blockIds.map(id => blockManager.getLocations(id))*/
-    val locations = blockManager.getLocations(blockIds) 
+    val locations = blockManager.getLocations(blockIds)
     HashMap(blockIds.zip(locations):_*)
   }
 
   override def splits = splits_
 
-  override def compute(split: Split): Iterator[T] = {
-    val blockManager = SparkEnv.get.blockManager 
+  override def compute(split: Split, taskContext: TaskContext): Iterator[T] = {
+    val blockManager = SparkEnv.get.blockManager
     val blockId = split.asInstanceOf[BlockRDDSplit].blockId
     blockManager.get(blockId) match {
       case Some(block) => block.asInstanceOf[Iterator[T]]
-      case None => 
+      case None =>
         throw new Exception("Could not compute split, block " + blockId + " not found")
     }
   }
 
-  override def preferredLocations(split: Split) = 
+  override def preferredLocations(split: Split) =
     locations_(split.asInstanceOf[BlockRDDSplit].blockId)
 
   override val dependencies: List[Dependency[_]] = Nil
diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala
index 7c354b6b2e..6bc0938ce2 100644
--- a/core/src/main/scala/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala
@@ -1,9 +1,7 @@
 package spark.rdd
 
-import spark.NarrowDependency
-import spark.RDD
-import spark.SparkContext
-import spark.Split
+import spark.{NarrowDependency, RDD, SparkContext, Split, TaskContext}
+
 
 private[spark]
 class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable {
@@ -17,9 +15,9 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
     rdd2: RDD[U])
   extends RDD[Pair[T, U]](sc)
   with Serializable {
-  
+
   val numSplitsInRdd2 = rdd2.splits.size
-  
+
   @transient
   val splits_ = {
     // create the cross product split
@@ -38,11 +36,12 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
     rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
   }
 
-  override def compute(split: Split) = {
+  override def compute(split: Split, taskContext: TaskContext) = {
     val currSplit = split.asInstanceOf[CartesianSplit]
-    for (x <- rdd1.iterator(currSplit.s1); y <- rdd2.iterator(currSplit.s2)) yield (x, y)
+    for (x <- rdd1.iterator(currSplit.s1, taskContext);
+      y <- rdd2.iterator(currSplit.s2, taskContext)) yield (x, y)
   }
-  
+
   override val dependencies = List(
     new NarrowDependency(rdd1) {
       def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2)
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index 50bec9e63b..6037681cfd 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -3,21 +3,15 @@ package spark.rdd
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
 
-import spark.Aggregator
-import spark.Dependency
-import spark.Logging
-import spark.OneToOneDependency
-import spark.Partitioner
-import spark.RDD
-import spark.ShuffleDependency
-import spark.SparkEnv
-import spark.Split
+import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext}
+import spark.{Dependency, OneToOneDependency, ShuffleDependency}
+
 
 private[spark] sealed trait CoGroupSplitDep extends Serializable
 private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep
 private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
 
-private[spark] 
+private[spark]
 class CoGroupSplit(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Split with Serializable {
   override val index: Int = idx
   override def hashCode(): Int = idx
@@ -32,9 +26,9 @@ private[spark] class CoGroupAggregator
 
 class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
   extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging {
-  
+
   val aggr = new CoGroupAggregator
-  
+
   @transient
   override val dependencies = {
     val deps = new ArrayBuffer[Dependency[_]]
@@ -50,7 +44,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
     }
     deps.toList
   }
-  
+
   @transient
   val splits_ : Array[Split] = {
     val firstRdd = rdds.head
@@ -69,12 +63,12 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
   }
 
   override def splits = splits_
-  
+
   override val partitioner = Some(part)
-  
+
   override def preferredLocations(s: Split) = Nil
-  
-  override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = {
+
+  override def compute(s: Split, taskContext: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
     val split = s.asInstanceOf[CoGroupSplit]
     val numRdds = split.deps.size
     val map = new HashMap[K, Seq[ArrayBuffer[Any]]]
@@ -84,7 +78,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
     for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
       case NarrowCoGroupSplitDep(rdd, itsSplit) => {
         // Read them from the parent
-        for ((k, v) <- rdd.iterator(itsSplit)) {
+        for ((k, v) <- rdd.iterator(itsSplit, taskContext)) {
           getSeq(k.asInstanceOf[K])(depNum) += v
         }
       }
diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
index 0967f4f5df..06ffc9c42c 100644
--- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
@@ -1,8 +1,7 @@
 package spark.rdd
 
-import spark.NarrowDependency
-import spark.RDD
-import spark.Split
+import spark.{NarrowDependency, RDD, Split, TaskContext}
+
 
 private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split
 
@@ -32,9 +31,9 @@ class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int)
 
   override def splits = splits_
 
-  override def compute(split: Split): Iterator[T] = {
+  override def compute(split: Split, taskContext: TaskContext): Iterator[T] = {
     split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap {
-      parentSplit => prev.iterator(parentSplit)
+      parentSplit => prev.iterator(parentSplit, taskContext)
     }
   }
 
diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala
index dfe9dc73f3..14a80d82c7 100644
--- a/core/src/main/scala/spark/rdd/FilteredRDD.scala
+++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala
@@ -1,12 +1,12 @@
 package spark.rdd
 
-import spark.OneToOneDependency
-import spark.RDD
-import spark.Split
+import spark.{OneToOneDependency, RDD, Split, TaskContext}
+
 
 private[spark]
 class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) {
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
-  override def compute(split: Split) = prev.iterator(split).filter(f)
+  override def compute(split: Split, taskContext: TaskContext) =
+    prev.iterator(split, taskContext).filter(f)
 }
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
index 3534dc8057..64f8c51d6d 100644
--- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
@@ -1,16 +1,16 @@
 package spark.rdd
 
-import spark.OneToOneDependency
-import spark.RDD
-import spark.Split
+import spark.{OneToOneDependency, RDD, Split, TaskContext}
 
 private[spark]
 class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
     prev: RDD[T],
     f: T => TraversableOnce[U])
   extends RDD[U](prev.context) {
-  
+
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
-  override def compute(split: Split) = prev.iterator(split).flatMap(f)
+
+  override def compute(split: Split, taskContext: TaskContext) =
+    prev.iterator(split, taskContext).flatMap(f)
 }
diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala
index e30564f2da..d6b1b27d3e 100644
--- a/core/src/main/scala/spark/rdd/GlommedRDD.scala
+++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala
@@ -1,12 +1,12 @@
 package spark.rdd
 
-import spark.OneToOneDependency
-import spark.RDD
-import spark.Split
+import spark.{OneToOneDependency, RDD, Split, TaskContext}
+
 
 private[spark]
 class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) {
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
-  override def compute(split: Split) = Array(prev.iterator(split).toArray).iterator
+  override def compute(split: Split, taskContext: TaskContext) =
+    Array(prev.iterator(split, taskContext).toArray).iterator
 }
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala
index bf29a1f075..c6c035a096 100644
--- a/core/src/main/scala/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala
@@ -15,19 +15,16 @@ import org.apache.hadoop.mapred.RecordReader
 import org.apache.hadoop.mapred.Reporter
 import org.apache.hadoop.util.ReflectionUtils
 
-import spark.Dependency
-import spark.RDD
-import spark.SerializableWritable
-import spark.SparkContext
-import spark.Split
+import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskContext}
 
-/** 
+
+/**
  * A Spark split class that wraps around a Hadoop InputSplit.
  */
 private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit)
   extends Split
   with Serializable {
-  
+
   val inputSplit = new SerializableWritable[InputSplit](s)
 
   override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt
@@ -47,10 +44,10 @@ class HadoopRDD[K, V](
     valueClass: Class[V],
     minSplits: Int)
   extends RDD[(K, V)](sc) {
-  
+
   // A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it
   val confBroadcast = sc.broadcast(new SerializableWritable(conf))
-  
+
   @transient
   val splits_ : Array[Split] = {
     val inputFormat = createInputFormat(conf)
@@ -69,7 +66,7 @@ class HadoopRDD[K, V](
 
   override def splits = splits_
 
-  override def compute(theSplit: Split) = new Iterator[(K, V)] {
+  override def compute(theSplit: Split, taskContext: TaskContext) = new Iterator[(K, V)] {
     val split = theSplit.asInstanceOf[HadoopSplit]
     var reader: RecordReader[K, V] = null
 
@@ -77,6 +74,9 @@ class HadoopRDD[K, V](
     val fmt = createInputFormat(conf)
     reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)
 
+    // Register an on-task-completion callback to close the input stream.
+    taskContext.registerOnCompleteCallback(Unit => reader.close())
+
     val key: K = reader.createKey()
     val value: V = reader.createValue()
     var gotNext = false
@@ -115,6 +115,6 @@ class HadoopRDD[K, V](
     val hadoopSplit = split.asInstanceOf[HadoopSplit]
     hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost")
   }
-  
+
   override val dependencies: List[Dependency[_]] = Nil
 }
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
index a904ef62c3..715c240060 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
@@ -1,8 +1,7 @@
 package spark.rdd
 
-import spark.OneToOneDependency
-import spark.RDD
-import spark.Split
+import spark.{OneToOneDependency, RDD, Split, TaskContext}
+
 
 private[spark]
 class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
@@ -12,8 +11,9 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
   extends RDD[U](prev.context) {
 
   override val partitioner = if (preservesPartitioning) prev.partitioner else None
-  
+
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
-  override def compute(split: Split) = f(prev.iterator(split))
+  override def compute(split: Split, taskContext: TaskContext) =
+    f(prev.iterator(split, taskContext))
 }
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
index 14e390c43b..39f3c7b5f7 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
@@ -1,8 +1,6 @@
 package spark.rdd
 
-import spark.OneToOneDependency
-import spark.RDD
-import spark.Split
+import spark.{OneToOneDependency, RDD, Split, TaskContext}
 
 /**
  * A variant of the MapPartitionsRDD that passes the split index into the
@@ -19,5 +17,6 @@ class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
   override val partitioner = if (preservesPartitioning) prev.partitioner else None
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
-  override def compute(split: Split) = f(split.index, prev.iterator(split))
+  override def compute(split: Split, taskContext: TaskContext) =
+    f(split.index, prev.iterator(split, taskContext))
 }
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala
index 59bedad8ef..d82ab3f671 100644
--- a/core/src/main/scala/spark/rdd/MappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/MappedRDD.scala
@@ -1,16 +1,15 @@
 package spark.rdd
 
-import spark.OneToOneDependency
-import spark.RDD
-import spark.Split
+import spark.{OneToOneDependency, RDD, Split, TaskContext}
 
 private[spark]
 class MappedRDD[U: ClassManifest, T: ClassManifest](
     prev: RDD[T],
     f: T => U)
   extends RDD[U](prev.context) {
-  
+
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
-  override def compute(split: Split) = prev.iterator(split).map(f)
+  override def compute(split: Split, taskContext: TaskContext) =
+    prev.iterator(split, taskContext).map(f)
 }
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
index 7a1a0fb87d..61f4cbbe94 100644
--- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
@@ -1,22 +1,19 @@
 package spark.rdd
 
+import java.text.SimpleDateFormat
+import java.util.Date
+
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.io.Writable
 import org.apache.hadoop.mapreduce._
 
-import java.util.Date
-import java.text.SimpleDateFormat
+import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskContext}
 
-import spark.Dependency
-import spark.RDD
-import spark.SerializableWritable
-import spark.SparkContext
-import spark.Split
 
-private[spark] 
+private[spark]
 class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable)
   extends Split {
-  
+
   val serializableHadoopSplit = new SerializableWritable(rawSplit)
 
   override def hashCode(): Int = (41 * (41 + rddId) + index)
@@ -29,7 +26,7 @@ class NewHadoopRDD[K, V](
     @transient conf: Configuration)
   extends RDD[(K, V)](sc)
   with HadoopMapReduceUtil {
-  
+
   // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
   val confBroadcast = sc.broadcast(new SerializableWritable(conf))
   // private val serializableConf = new SerializableWritable(conf)
@@ -56,7 +53,7 @@ class NewHadoopRDD[K, V](
 
   override def splits = splits_
 
-  override def compute(theSplit: Split) = new Iterator[(K, V)] {
+  override def compute(theSplit: Split, taskContext: TaskContext) = new Iterator[(K, V)] {
     val split = theSplit.asInstanceOf[NewHadoopSplit]
     val conf = confBroadcast.value.value
     val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0)
@@ -64,7 +61,10 @@ class NewHadoopRDD[K, V](
     val format = inputFormatClass.newInstance
     val reader = format.createRecordReader(split.serializableHadoopSplit.value, context)
     reader.initialize(split.serializableHadoopSplit.value, context)
-   
+
+    // Register an on-task-completion callback to close the input stream.
+    taskContext.registerOnCompleteCallback(Unit => reader.close())
+
     var havePair = false
     var finished = false
 
@@ -72,9 +72,6 @@ class NewHadoopRDD[K, V](
       if (!finished && !havePair) {
         finished = !reader.nextKeyValue
         havePair = !finished
-        if (finished) {
-          reader.close()
-        }
       }
       !finished
     }
diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala
index 98ea0c92d6..b34c7ea5b9 100644
--- a/core/src/main/scala/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/spark/rdd/PipedRDD.scala
@@ -8,10 +8,7 @@ import scala.collection.JavaConversions._
 import scala.collection.mutable.ArrayBuffer
 import scala.io.Source
 
-import spark.OneToOneDependency
-import spark.RDD
-import spark.SparkEnv
-import spark.Split
+import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext}
 
 
 /**
@@ -32,12 +29,12 @@ class PipedRDD[T: ClassManifest](
 
   override val dependencies = List(new OneToOneDependency(parent))
 
-  override def compute(split: Split): Iterator[String] = {
+  override def compute(split: Split, taskContext: TaskContext): Iterator[String] = {
     val pb = new ProcessBuilder(command)
     // Add the environmental variables to the process.
     val currentEnvVars = pb.environment()
     envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) }
-    
+
     val proc = pb.start()
     val env = SparkEnv.get
 
@@ -55,7 +52,7 @@ class PipedRDD[T: ClassManifest](
       override def run() {
         SparkEnv.set(env)
         val out = new PrintWriter(proc.getOutputStream)
-        for (elem <- parent.iterator(split)) {
+        for (elem <- parent.iterator(split, taskContext)) {
           out.println(elem)
         }
         out.close()
diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala
index 87a5268f27..07a1487f3a 100644
--- a/core/src/main/scala/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/spark/rdd/SampledRDD.scala
@@ -4,9 +4,8 @@ import java.util.Random
 import cern.jet.random.Poisson
 import cern.jet.random.engine.DRand
 
-import spark.RDD
-import spark.OneToOneDependency
-import spark.Split
+import spark.{OneToOneDependency, RDD, Split, TaskContext}
+
 
 private[spark]
 class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable {
@@ -15,7 +14,7 @@ class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Seriali
 
 class SampledRDD[T: ClassManifest](
     prev: RDD[T],
-    withReplacement: Boolean, 
+    withReplacement: Boolean,
     frac: Double,
     seed: Int)
   extends RDD[T](prev.context) {
@@ -29,17 +28,17 @@ class SampledRDD[T: ClassManifest](
   override def splits = splits_.asInstanceOf[Array[Split]]
 
   override val dependencies = List(new OneToOneDependency(prev))
-  
+
   override def preferredLocations(split: Split) =
     prev.preferredLocations(split.asInstanceOf[SampledRDDSplit].prev)
 
-  override def compute(splitIn: Split) = {
+  override def compute(splitIn: Split, taskContext: TaskContext) = {
     val split = splitIn.asInstanceOf[SampledRDDSplit]
     if (withReplacement) {
       // For large datasets, the expected number of occurrences of each element in a sample with
       // replacement is Poisson(frac). We use that to get a count for each element.
       val poisson = new Poisson(frac, new DRand(split.seed))
-      prev.iterator(split.prev).flatMap { element =>
+      prev.iterator(split.prev, taskContext).flatMap { element =>
         val count = poisson.nextInt()
         if (count == 0) {
           Iterator.empty  // Avoid object allocation when we return 0 items, which is quite often
@@ -49,7 +48,7 @@ class SampledRDD[T: ClassManifest](
       }
     } else { // Sampling without replacement
       val rand = new Random(split.seed)
-      prev.iterator(split.prev).filter(x => (rand.nextDouble <= frac))
+      prev.iterator(split.prev, taskContext).filter(x => (rand.nextDouble <= frac))
     }
   }
 }
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index 145e419c53..c736e92117 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -1,10 +1,7 @@
 package spark.rdd
 
-import spark.Partitioner
-import spark.RDD
-import spark.ShuffleDependency
-import spark.SparkEnv
-import spark.Split
+import spark.{OneToOneDependency, Partitioner, RDD, SparkEnv, ShuffleDependency, Split, TaskContext}
+
 
 private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
   override val index = idx
@@ -34,7 +31,7 @@ class ShuffledRDD[K, V](
   val dep = new ShuffleDependency(parent, part)
   override val dependencies = List(dep)
 
-  override def compute(split: Split): Iterator[(K, V)] = {
+  override def compute(split: Split, taskContext: TaskContext): Iterator[(K, V)] = {
     SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index)
   }
 }
diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala
index f0b9225f7c..4b9cab8774 100644
--- a/core/src/main/scala/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/spark/rdd/UnionRDD.scala
@@ -2,20 +2,17 @@ package spark.rdd
 
 import scala.collection.mutable.ArrayBuffer
 
-import spark.Dependency
-import spark.RangeDependency
-import spark.RDD
-import spark.SparkContext
-import spark.Split
+import spark.{Dependency, RangeDependency, RDD, SparkContext, Split, TaskContext}
+
 
 private[spark] class UnionSplit[T: ClassManifest](
-    idx: Int, 
+    idx: Int,
     rdd: RDD[T],
     split: Split)
   extends Split
   with Serializable {
-  
-  def iterator() = rdd.iterator(split)
+
+  def iterator(taskContext: TaskContext) = rdd.iterator(split, taskContext)
   def preferredLocations() = rdd.preferredLocations(split)
   override val index: Int = idx
 }
@@ -25,7 +22,7 @@ class UnionRDD[T: ClassManifest](
     @transient rdds: Seq[RDD[T]])
   extends RDD[T](sc)
   with Serializable {
-  
+
   @transient
   val splits_ : Array[Split] = {
     val array = new Array[Split](rdds.map(_.splits.size).sum)
@@ -44,13 +41,14 @@ class UnionRDD[T: ClassManifest](
     val deps = new ArrayBuffer[Dependency[_]]
     var pos = 0
     for (rdd <- rdds) {
-      deps += new RangeDependency(rdd, 0, pos, rdd.splits.size) 
+      deps += new RangeDependency(rdd, 0, pos, rdd.splits.size)
       pos += rdd.splits.size
     }
     deps.toList
   }
-  
-  override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator()
+
+  override def compute(s: Split, taskContext: TaskContext): Iterator[T] =
+    s.asInstanceOf[UnionSplit[T]].iterator(taskContext)
 
   override def preferredLocations(s: Split): Seq[String] =
     s.asInstanceOf[UnionSplit[T]].preferredLocations()
diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala
index 80f0150c45..b987ca5fdf 100644
--- a/core/src/main/scala/spark/rdd/ZippedRDD.scala
+++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala
@@ -1,21 +1,19 @@
 package spark.rdd
 
-import spark.Dependency
-import spark.OneToOneDependency
-import spark.RDD
-import spark.SparkContext
-import spark.Split
+import spark.{OneToOneDependency, RDD, SparkContext, Split, TaskContext}
+
 
 private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest](
-    idx: Int, 
+    idx: Int,
     rdd1: RDD[T],
     rdd2: RDD[U],
     split1: Split,
     split2: Split)
   extends Split
   with Serializable {
-  
-  def iterator(): Iterator[(T, U)] = rdd1.iterator(split1).zip(rdd2.iterator(split2))
+
+  def iterator(taskContext: TaskContext): Iterator[(T, U)] =
+    rdd1.iterator(split1, taskContext).zip(rdd2.iterator(split2, taskContext))
 
   def preferredLocations(): Seq[String] =
     rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
@@ -46,8 +44,9 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
 
   @transient
   override val dependencies = List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))
-  
-  override def compute(s: Split): Iterator[(T, U)] = s.asInstanceOf[ZippedSplit[T, U]].iterator()
+
+  override def compute(s: Split, taskContext: TaskContext): Iterator[(T, U)] =
+    s.asInstanceOf[ZippedSplit[T, U]].iterator(taskContext)
 
   override def preferredLocations(s: Split): Seq[String] =
     s.asInstanceOf[ZippedSplit[T, U]].preferredLocations()
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index 5c71207d43..29757b1178 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -16,8 +16,8 @@ import spark.storage.BlockManagerMaster
 import spark.storage.BlockManagerId
 
 /**
- * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for 
- * each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal 
+ * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for
+ * each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal
  * schedule to run the job. Subclasses only need to implement the code to send a task to the cluster
  * and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
  */
@@ -73,7 +73,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
 
   val deadHosts = new HashSet[String]  // TODO: The code currently assumes these can't come back;
                                        // that's not going to be a realistic assumption in general
-  
+
   val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
   val running = new HashSet[Stage] // Stages we are running right now
   val failed = new HashSet[Stage]  // Stages that must be resubmitted due to fetch failures
@@ -94,7 +94,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
   def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
     cacheLocs(rdd.id)
   }
-  
+
   def updateCacheLocs() {
     cacheLocs = cacheTracker.getLocationsSnapshot()
   }
@@ -326,7 +326,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
           val rdd = job.finalStage.rdd
           val split = rdd.splits(job.partitions(0))
           val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
-          val result = job.func(taskContext, rdd.iterator(split))
+          val result = job.func(taskContext, rdd.iterator(split, taskContext))
+          taskContext.executeOnCompleteCallbacks()
           job.listener.taskSucceeded(0, result)
         } catch {
           case e: Exception =>
@@ -353,7 +354,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
       }
     }
   }
-  
+
   def submitMissingTasks(stage: Stage) {
     logDebug("submitMissingTasks(" + stage + ")")
     // Get our pending tasks and remember them in our pendingTasks entry
@@ -395,7 +396,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
     val task = event.task
     val stage = idToStage(task.stageId)
     event.reason match {
-      case Success =>  
+      case Success =>
         logInfo("Completed " + task)
         if (event.accumUpdates != null) {
           Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
@@ -519,7 +520,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
       updateCacheLocs()
     }
   }
-  
+
   /**
    * Aborts all jobs depending on a particular Stage. This is called in response to a task set
    * being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala
index 2ebd4075a2..e492279b4e 100644
--- a/core/src/main/scala/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/spark/scheduler/ResultTask.scala
@@ -10,12 +10,14 @@ private[spark] class ResultTask[T, U](
     @transient locs: Seq[String],
     val outputId: Int)
   extends Task[U](stageId) {
-  
+
   val split = rdd.splits(partition)
 
   override def run(attemptId: Long): U = {
     val context = new TaskContext(stageId, partition, attemptId)
-    func(context, rdd.iterator(split))
+    val result = func(context, rdd.iterator(split, context))
+    context.executeOnCompleteCallbacks()
+    result
   }
 
   override def preferredLocations: Seq[String] = locs
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 60105c42b6..bd1911fce2 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -70,19 +70,19 @@ private[spark] object ShuffleMapTask {
 
 private[spark] class ShuffleMapTask(
     stageId: Int,
-    var rdd: RDD[_], 
+    var rdd: RDD[_],
     var dep: ShuffleDependency[_,_],
-    var partition: Int, 
+    var partition: Int,
     @transient var locs: Seq[String])
   extends Task[MapStatus](stageId)
   with Externalizable
   with Logging {
 
   def this() = this(0, null, null, 0, null)
-  
+
   var split = if (rdd == null) {
-    null 
-  } else { 
+    null
+  } else {
     rdd.splits(partition)
   }
 
@@ -113,9 +113,11 @@ private[spark] class ShuffleMapTask(
     val numOutputSplits = dep.partitioner.numPartitions
     val partitioner = dep.partitioner
 
+    val taskContext = new TaskContext(stageId, partition, attemptId)
+
     // Partition the map output.
     val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
-    for (elem <- rdd.iterator(split)) {
+    for (elem <- rdd.iterator(split, taskContext)) {
       val pair = elem.asInstanceOf[(Any, Any)]
       val bucketId = partitioner.getPartition(pair._1)
       buckets(bucketId) += pair
@@ -133,6 +135,9 @@ private[spark] class ShuffleMapTask(
       compressedSizes(i) = MapOutputTracker.compressSize(size)
     }
 
+    // Execute the callbacks on task completion.
+    taskContext.executeOnCompleteCallbacks()
+
     return new MapStatus(blockManager.blockManagerId, compressedSizes)
   }
 
-- 
GitLab