From 4f076e105ee30edcb1941216c79d017c5175d9b8 Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@cs.berkeley.edu>
Date: Thu, 13 Dec 2012 16:41:15 -0800
Subject: [PATCH] SPARK-635: Pass a TaskContext object to compute() interface
 and use that to close Hadoop input stream. Incorporated Matei's command.

---
 core/src/main/scala/spark/CacheTracker.scala          |  5 ++---
 core/src/main/scala/spark/RDD.scala                   |  8 ++++----
 core/src/main/scala/spark/TaskContext.scala           |  4 ++--
 core/src/main/scala/spark/rdd/BlockRDD.scala          |  2 +-
 core/src/main/scala/spark/rdd/CartesianRDD.scala      |  6 +++---
 core/src/main/scala/spark/rdd/CoGroupedRDD.scala      |  4 ++--
 core/src/main/scala/spark/rdd/CoalescedRDD.scala      |  4 ++--
 core/src/main/scala/spark/rdd/FilteredRDD.scala       |  3 +--
 core/src/main/scala/spark/rdd/FlatMappedRDD.scala     |  4 ++--
 core/src/main/scala/spark/rdd/GlommedRDD.scala        |  4 ++--
 core/src/main/scala/spark/rdd/HadoopRDD.scala         |  4 ++--
 core/src/main/scala/spark/rdd/MapPartitionsRDD.scala  |  3 +--
 .../scala/spark/rdd/MapPartitionsWithSplitRDD.scala   |  4 ++--
 core/src/main/scala/spark/rdd/MappedRDD.scala         |  3 +--
 core/src/main/scala/spark/rdd/NewHadoopRDD.scala      | 11 ++++++-----
 core/src/main/scala/spark/rdd/PipedRDD.scala          |  4 ++--
 core/src/main/scala/spark/rdd/SampledRDD.scala        |  6 +++---
 core/src/main/scala/spark/rdd/ShuffledRDD.scala       |  2 +-
 core/src/main/scala/spark/rdd/UnionRDD.scala          |  6 +++---
 core/src/main/scala/spark/rdd/ZippedRDD.scala         |  8 ++++----
 20 files changed, 46 insertions(+), 49 deletions(-)

diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala
index e9c545a2cf..3d79078733 100644
--- a/core/src/main/scala/spark/CacheTracker.scala
+++ b/core/src/main/scala/spark/CacheTracker.scala
@@ -167,8 +167,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b
   }
 
   // Gets or computes an RDD split
-  def getOrCompute[T](
-    rdd: RDD[T], split: Split, taskContext: TaskContext, storageLevel: StorageLevel)
+  def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
   : Iterator[T] = {
     val key = "rdd_%d_%d".format(rdd.id, split.index)
     logInfo("Cache key is " + key)
@@ -211,7 +210,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b
         // TODO: also register a listener for when it unloads
         logInfo("Computing partition " + split)
         val elements = new ArrayBuffer[Any]
-        elements ++= rdd.compute(split, taskContext)
+        elements ++= rdd.compute(split, context)
         try {
           // Try to put this block in the blockManager
           blockManager.put(key, elements, storageLevel, true)
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index c53eab67e5..bb4c13c494 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -81,7 +81,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
   def splits: Array[Split]
 
   /** Function for computing a given partition. */
-  def compute(split: Split, taskContext: TaskContext): Iterator[T]
+  def compute(split: Split, context: TaskContext): Iterator[T]
 
   /** How this RDD depends on any parent RDDs. */
   @transient val dependencies: List[Dependency[_]]
@@ -155,11 +155,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
    * This should ''not'' be called by users directly, but is available for implementors of custom
    * subclasses of RDD.
    */
-  final def iterator(split: Split, taskContext: TaskContext): Iterator[T] = {
+  final def iterator(split: Split, context: TaskContext): Iterator[T] = {
     if (storageLevel != StorageLevel.NONE) {
-      SparkEnv.get.cacheTracker.getOrCompute[T](this, split, taskContext, storageLevel)
+      SparkEnv.get.cacheTracker.getOrCompute[T](this, split, context, storageLevel)
     } else {
-      compute(split, taskContext)
+      compute(split, context)
     }
   }
 
diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala
index b352db8167..d2746b26b3 100644
--- a/core/src/main/scala/spark/TaskContext.scala
+++ b/core/src/main/scala/spark/TaskContext.scala
@@ -6,11 +6,11 @@ import scala.collection.mutable.ArrayBuffer
 class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable {
 
   @transient
-  val onCompleteCallbacks = new ArrayBuffer[Unit => Unit]
+  val onCompleteCallbacks = new ArrayBuffer[() => Unit]
 
   // Add a callback function to be executed on task completion. An example use
   // is for HadoopRDD to register a callback to close the input stream.
-  def registerOnCompleteCallback(f: Unit => Unit) {
+  def addOnCompleteCallback(f: () => Unit) {
     onCompleteCallbacks += f
   }
 
diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala
index 8209c36871..f98528a183 100644
--- a/core/src/main/scala/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/spark/rdd/BlockRDD.scala
@@ -28,7 +28,7 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St
 
   override def splits = splits_
 
-  override def compute(split: Split, taskContext: TaskContext): Iterator[T] = {
+  override def compute(split: Split, context: TaskContext): Iterator[T] = {
     val blockManager = SparkEnv.get.blockManager
     val blockId = split.asInstanceOf[BlockRDDSplit].blockId
     blockManager.get(blockId) match {
diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala
index 6bc0938ce2..4a7e5f3d06 100644
--- a/core/src/main/scala/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala
@@ -36,10 +36,10 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
     rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
   }
 
-  override def compute(split: Split, taskContext: TaskContext) = {
+  override def compute(split: Split, context: TaskContext) = {
     val currSplit = split.asInstanceOf[CartesianSplit]
-    for (x <- rdd1.iterator(currSplit.s1, taskContext);
-      y <- rdd2.iterator(currSplit.s2, taskContext)) yield (x, y)
+    for (x <- rdd1.iterator(currSplit.s1, context);
+      y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
   }
 
   override val dependencies = List(
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index 6037681cfd..de0d9fad88 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -68,7 +68,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
 
   override def preferredLocations(s: Split) = Nil
 
-  override def compute(s: Split, taskContext: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
+  override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
     val split = s.asInstanceOf[CoGroupSplit]
     val numRdds = split.deps.size
     val map = new HashMap[K, Seq[ArrayBuffer[Any]]]
@@ -78,7 +78,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
     for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
       case NarrowCoGroupSplitDep(rdd, itsSplit) => {
         // Read them from the parent
-        for ((k, v) <- rdd.iterator(itsSplit, taskContext)) {
+        for ((k, v) <- rdd.iterator(itsSplit, context)) {
           getSeq(k.asInstanceOf[K])(depNum) += v
         }
       }
diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
index 06ffc9c42c..1affe0e0ef 100644
--- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
@@ -31,9 +31,9 @@ class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int)
 
   override def splits = splits_
 
-  override def compute(split: Split, taskContext: TaskContext): Iterator[T] = {
+  override def compute(split: Split, context: TaskContext): Iterator[T] = {
     split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap {
-      parentSplit => prev.iterator(parentSplit, taskContext)
+      parentSplit => prev.iterator(parentSplit, context)
     }
   }
 
diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala
index 14a80d82c7..b148da28de 100644
--- a/core/src/main/scala/spark/rdd/FilteredRDD.scala
+++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala
@@ -7,6 +7,5 @@ private[spark]
 class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) {
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
-  override def compute(split: Split, taskContext: TaskContext) =
-    prev.iterator(split, taskContext).filter(f)
+  override def compute(split: Split, context: TaskContext) = prev.iterator(split, context).filter(f)
 }
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
index 64f8c51d6d..785662b2da 100644
--- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
@@ -11,6 +11,6 @@ class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
 
-  override def compute(split: Split, taskContext: TaskContext) =
-    prev.iterator(split, taskContext).flatMap(f)
+  override def compute(split: Split, context: TaskContext) =
+    prev.iterator(split, context).flatMap(f)
 }
diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala
index d6b1b27d3e..fac8ffb4cb 100644
--- a/core/src/main/scala/spark/rdd/GlommedRDD.scala
+++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala
@@ -7,6 +7,6 @@ private[spark]
 class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) {
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
-  override def compute(split: Split, taskContext: TaskContext) =
-    Array(prev.iterator(split, taskContext).toArray).iterator
+  override def compute(split: Split, context: TaskContext) =
+    Array(prev.iterator(split, context).toArray).iterator
 }
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala
index c6c035a096..ab163f569b 100644
--- a/core/src/main/scala/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala
@@ -66,7 +66,7 @@ class HadoopRDD[K, V](
 
   override def splits = splits_
 
-  override def compute(theSplit: Split, taskContext: TaskContext) = new Iterator[(K, V)] {
+  override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
     val split = theSplit.asInstanceOf[HadoopSplit]
     var reader: RecordReader[K, V] = null
 
@@ -75,7 +75,7 @@ class HadoopRDD[K, V](
     reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)
 
     // Register an on-task-completion callback to close the input stream.
-    taskContext.registerOnCompleteCallback(Unit => reader.close())
+    context.addOnCompleteCallback(() => reader.close())
 
     val key: K = reader.createKey()
     val value: V = reader.createValue()
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
index 715c240060..c764505345 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
@@ -14,6 +14,5 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
 
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
-  override def compute(split: Split, taskContext: TaskContext) =
-    f(prev.iterator(split, taskContext))
+  override def compute(split: Split, context: TaskContext) = f(prev.iterator(split, context))
 }
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
index 39f3c7b5f7..3d9888bd34 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
@@ -17,6 +17,6 @@ class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
   override val partitioner = if (preservesPartitioning) prev.partitioner else None
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
-  override def compute(split: Split, taskContext: TaskContext) =
-    f(split.index, prev.iterator(split, taskContext))
+  override def compute(split: Split, context: TaskContext) =
+    f(split.index, prev.iterator(split, context))
 }
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala
index d82ab3f671..70fa8f4497 100644
--- a/core/src/main/scala/spark/rdd/MappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/MappedRDD.scala
@@ -10,6 +10,5 @@ class MappedRDD[U: ClassManifest, T: ClassManifest](
 
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
-  override def compute(split: Split, taskContext: TaskContext) =
-    prev.iterator(split, taskContext).map(f)
+  override def compute(split: Split, context: TaskContext) = prev.iterator(split, context).map(f)
 }
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
index 61f4cbbe94..197ed5ea17 100644
--- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
@@ -53,17 +53,18 @@ class NewHadoopRDD[K, V](
 
   override def splits = splits_
 
-  override def compute(theSplit: Split, taskContext: TaskContext) = new Iterator[(K, V)] {
+  override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
     val split = theSplit.asInstanceOf[NewHadoopSplit]
     val conf = confBroadcast.value.value
     val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0)
-    val context = newTaskAttemptContext(conf, attemptId)
+    val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
     val format = inputFormatClass.newInstance
-    val reader = format.createRecordReader(split.serializableHadoopSplit.value, context)
-    reader.initialize(split.serializableHadoopSplit.value, context)
+    val reader = format.createRecordReader(
+      split.serializableHadoopSplit.value, hadoopAttemptContext)
+    reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
 
     // Register an on-task-completion callback to close the input stream.
-    taskContext.registerOnCompleteCallback(Unit => reader.close())
+    context.addOnCompleteCallback(() => reader.close())
 
     var havePair = false
     var finished = false
diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala
index b34c7ea5b9..336e193217 100644
--- a/core/src/main/scala/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/spark/rdd/PipedRDD.scala
@@ -29,7 +29,7 @@ class PipedRDD[T: ClassManifest](
 
   override val dependencies = List(new OneToOneDependency(parent))
 
-  override def compute(split: Split, taskContext: TaskContext): Iterator[String] = {
+  override def compute(split: Split, context: TaskContext): Iterator[String] = {
     val pb = new ProcessBuilder(command)
     // Add the environmental variables to the process.
     val currentEnvVars = pb.environment()
@@ -52,7 +52,7 @@ class PipedRDD[T: ClassManifest](
       override def run() {
         SparkEnv.set(env)
         val out = new PrintWriter(proc.getOutputStream)
-        for (elem <- parent.iterator(split, taskContext)) {
+        for (elem <- parent.iterator(split, context)) {
           out.println(elem)
         }
         out.close()
diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala
index 07a1487f3a..6e4797aabb 100644
--- a/core/src/main/scala/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/spark/rdd/SampledRDD.scala
@@ -32,13 +32,13 @@ class SampledRDD[T: ClassManifest](
   override def preferredLocations(split: Split) =
     prev.preferredLocations(split.asInstanceOf[SampledRDDSplit].prev)
 
-  override def compute(splitIn: Split, taskContext: TaskContext) = {
+  override def compute(splitIn: Split, context: TaskContext) = {
     val split = splitIn.asInstanceOf[SampledRDDSplit]
     if (withReplacement) {
       // For large datasets, the expected number of occurrences of each element in a sample with
       // replacement is Poisson(frac). We use that to get a count for each element.
       val poisson = new Poisson(frac, new DRand(split.seed))
-      prev.iterator(split.prev, taskContext).flatMap { element =>
+      prev.iterator(split.prev, context).flatMap { element =>
         val count = poisson.nextInt()
         if (count == 0) {
           Iterator.empty  // Avoid object allocation when we return 0 items, which is quite often
@@ -48,7 +48,7 @@ class SampledRDD[T: ClassManifest](
       }
     } else { // Sampling without replacement
       val rand = new Random(split.seed)
-      prev.iterator(split.prev, taskContext).filter(x => (rand.nextDouble <= frac))
+      prev.iterator(split.prev, context).filter(x => (rand.nextDouble <= frac))
     }
   }
 }
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index c736e92117..f832633646 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -31,7 +31,7 @@ class ShuffledRDD[K, V](
   val dep = new ShuffleDependency(parent, part)
   override val dependencies = List(dep)
 
-  override def compute(split: Split, taskContext: TaskContext): Iterator[(K, V)] = {
+  override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = {
     SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index)
   }
 }
diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala
index 4b9cab8774..a08473f7be 100644
--- a/core/src/main/scala/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/spark/rdd/UnionRDD.scala
@@ -12,7 +12,7 @@ private[spark] class UnionSplit[T: ClassManifest](
   extends Split
   with Serializable {
 
-  def iterator(taskContext: TaskContext) = rdd.iterator(split, taskContext)
+  def iterator(context: TaskContext) = rdd.iterator(split, context)
   def preferredLocations() = rdd.preferredLocations(split)
   override val index: Int = idx
 }
@@ -47,8 +47,8 @@ class UnionRDD[T: ClassManifest](
     deps.toList
   }
 
-  override def compute(s: Split, taskContext: TaskContext): Iterator[T] =
-    s.asInstanceOf[UnionSplit[T]].iterator(taskContext)
+  override def compute(s: Split, context: TaskContext): Iterator[T] =
+    s.asInstanceOf[UnionSplit[T]].iterator(context)
 
   override def preferredLocations(s: Split): Seq[String] =
     s.asInstanceOf[UnionSplit[T]].preferredLocations()
diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala
index b987ca5fdf..92d667ff1e 100644
--- a/core/src/main/scala/spark/rdd/ZippedRDD.scala
+++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala
@@ -12,8 +12,8 @@ private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest](
   extends Split
   with Serializable {
 
-  def iterator(taskContext: TaskContext): Iterator[(T, U)] =
-    rdd1.iterator(split1, taskContext).zip(rdd2.iterator(split2, taskContext))
+  def iterator(context: TaskContext): Iterator[(T, U)] =
+    rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context))
 
   def preferredLocations(): Seq[String] =
     rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
@@ -45,8 +45,8 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
   @transient
   override val dependencies = List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))
 
-  override def compute(s: Split, taskContext: TaskContext): Iterator[(T, U)] =
-    s.asInstanceOf[ZippedSplit[T, U]].iterator(taskContext)
+  override def compute(s: Split, context: TaskContext): Iterator[(T, U)] =
+    s.asInstanceOf[ZippedSplit[T, U]].iterator(context)
 
   override def preferredLocations(s: Split): Seq[String] =
     s.asInstanceOf[ZippedSplit[T, U]].preferredLocations()
-- 
GitLab