From 82329b0b2856fbe9c257dd615d4bbcf51f0bbace Mon Sep 17 00:00:00 2001
From: Matei Zaharia <matei@eecs.berkeley.edu>
Date: Thu, 19 May 2011 12:47:09 -0700
Subject: [PATCH] Updated scheduler to support running on just some partitions
 of the final RDD

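Scheduler.runJob and SparkContext.runJob now take an explicit sequence of
partition IDs, so that operations such as take() and first() can compute only
the splits they actually need instead of materializing the whole RDD. As a
further shortcut, when the final stage has no shuffle dependencies and only a
single output partition is requested, the DAGScheduler computes that partition
locally on the master instead of launching a task.

A rough sketch of how the new internal API can be used (runJob is
private[spark], so this only applies to code inside the spark package; the
helper below is illustrative and not part of this patch):

    // Count the elements of split 0 only. runJob returns one result per
    // requested partition, in the order the partitions were given.
    def countFirstSplit(sc: SparkContext, rdd: RDD[String]): Int = {
      val counts = sc.runJob(rdd, (it: Iterator[String]) => it.size, Seq(0))
      counts(0)
    }

Callers that want every partition can use the two-argument overload, which
simply forwards to runJob(rdd, func, 0 until rdd.splits.size).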
---
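Notes (not part of the commit message): take() and first() had been commented
out in RDD.scala with a TODO, since the old versions tried to compute
partitions directly on the master, which does not work once shuffle
dependencies are involved. They are re-enabled here on top of the new runJob:
take() submits one job per scanned partition, and a single-partition request
on a shuffle-free RDD is still evaluated locally as a fast path. For example,
given a SparkContext sc (the file name is hypothetical):

    // Normally only the first split of the file is read; since a text-file
    // RDD has no shuffle dependencies, the DAGScheduler computes that split
    // locally on the master instead of launching a task.
    val firstLines: Array[String] = sc.textFile("README.md").take(1)
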
 core/src/main/scala/spark/DAGScheduler.scala | 24 ++++++++++++++------
 core/src/main/scala/spark/RDD.scala          | 17 ++++++++------
 core/src/main/scala/spark/ResultTask.scala   |  5 ++--
 core/src/main/scala/spark/Scheduler.scala    |  9 ++++++--
 core/src/main/scala/spark/SparkContext.scala | 24 ++++++++------------
 5 files changed, 46 insertions(+), 33 deletions(-)

diff --git a/core/src/main/scala/spark/DAGScheduler.scala b/core/src/main/scala/spark/DAGScheduler.scala
index 048a0faf2f..6c00e96d46 100644
--- a/core/src/main/scala/spark/DAGScheduler.scala
+++ b/core/src/main/scala/spark/DAGScheduler.scala
@@ -116,9 +116,11 @@ private trait DAGScheduler extends Scheduler with Logging {
     missing.toList
   }
 
-  override def runJob[T, U](finalRdd: RDD[T], func: Iterator[T] => U)(implicit m: ClassManifest[U])
+  override def runJob[T, U](finalRdd: RDD[T], func: Iterator[T] => U, partitions: Seq[Int])
+                           (implicit m: ClassManifest[U])
       : Array[U] = {
-    val numOutputParts: Int = finalRdd.splits.size
+    val outputParts = partitions.toArray
+    val numOutputParts: Int = partitions.size
     val finalStage = newStage(finalRdd, None)
     val results = new Array[U](numOutputParts)
     val finished = new Array[Boolean](numOutputParts)
@@ -134,6 +136,13 @@ private trait DAGScheduler extends Scheduler with Logging {
     logInfo("Parents of final stage: " + finalStage.parents)
     logInfo("Missing parents: " + getMissingParentStages(finalStage))
 
+    // Optimization for first() and take() if the RDD has no shuffle dependencies
+    if (finalStage.parents.size == 0 && numOutputParts == 1) {
+      logInfo("Computing the requested partition locally")
+      val split = finalRdd.splits(outputParts(0))
+      return Array(func(finalRdd.iterator(split)))
+    }
+
     def submitStage(stage: Stage) {
       if (!waiting(stage) && !running(stage)) {
         val missing = getMissingParentStages(stage)
@@ -154,9 +163,10 @@ private trait DAGScheduler extends Scheduler with Logging {
       val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
       var tasks = ArrayBuffer[Task[_]]()
       if (stage == finalStage) {
-        for (p <- 0 until numOutputParts if (!finished(p))) {
-          val locs = getPreferredLocs(finalRdd, p)
-          tasks += new ResultTask(finalStage.id, finalRdd, func, p, locs)
+        for (id <- 0 until numOutputParts if (!finished(id))) {
+          val part = outputParts(id)
+          val locs = getPreferredLocs(finalRdd, part)
+          tasks += new ResultTask(finalStage.id, finalRdd, func, part, locs, id)
         }
       } else {
         for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
@@ -177,8 +187,8 @@ private trait DAGScheduler extends Scheduler with Logging {
         Accumulators.add(currentThread, evt.accumUpdates)
         evt.task match {
           case rt: ResultTask[_, _] =>
-            results(rt.partition) = evt.result.asInstanceOf[U]
-            finished(rt.partition) = true
+            results(rt.outputId) = evt.result.asInstanceOf[U]
+            finished(rt.outputId) = true
             numFinished += 1
             pendingTasks(finalStage) -= rt
           case smt: ShuffleMapTask =>
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 6accd5e356..45dcad54b4 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -123,17 +123,21 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) {
     "%s(%d)".format(getClass.getSimpleName, id)
   }
 
-  // TODO: Reimplement these to properly build any shuffle dependencies on
-  // the cluster rather than attempting to compute a partiton on the master
-  /*
+  // Take the first num elements of the RDD. This currently scans the partitions
+  // *one by one*, so it will be slow if a lot of partitions are required. In that
+  // case, use collect() to get the whole RDD instead.
   def take(num: Int): Array[T] = {
     if (num == 0)
       return new Array[T](0)
     val buf = new ArrayBuffer[T]
-    for (split <- splits; elem <- iterator(split)) {
-      buf += elem
-      if (buf.length == num)
+    var p = 0
+    while (buf.size < num && p < splits.size) {
+      val left = num - buf.size
+      val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, Array(p))
+      buf ++= res(0)
+      if (buf.size == num)
         return buf.toArray
+      p += 1
     }
     return buf.toArray
   }
@@ -142,7 +146,6 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) {
     case Array(t) => t
     case _ => throw new UnsupportedOperationException("empty collection")
   }
-  */
 }
 
 class MappedRDD[U: ClassManifest, T: ClassManifest](
diff --git a/core/src/main/scala/spark/ResultTask.scala b/core/src/main/scala/spark/ResultTask.scala
index 3b63896175..986e99b81e 100644
--- a/core/src/main/scala/spark/ResultTask.scala
+++ b/core/src/main/scala/spark/ResultTask.scala
@@ -1,6 +1,7 @@
 package spark
 
-class ResultTask[T, U](val stageId: Int, rdd: RDD[T], func: Iterator[T] => U, val partition: Int, locs: Seq[String])
+class ResultTask[T, U](val stageId: Int, rdd: RDD[T], func: Iterator[T] => U,
+                       val partition: Int, locs: Seq[String], val outputId: Int)
 extends Task[U] {
   val split = rdd.splits(partition)
 
@@ -11,4 +12,4 @@ extends Task[U] {
   override def preferredLocations: Seq[String] = locs
 
   override def toString = "ResultTask(" + stageId + ", " + partition + ")"
-}
\ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/Scheduler.scala b/core/src/main/scala/spark/Scheduler.scala
index fbcbb3e935..59c719938f 100644
--- a/core/src/main/scala/spark/Scheduler.scala
+++ b/core/src/main/scala/spark/Scheduler.scala
@@ -3,9 +3,14 @@ package spark
 // Scheduler trait, implemented by both NexusScheduler and LocalScheduler.
 private trait Scheduler {
   def start()
+
   def waitForRegister()
-  //def runTasks[T](tasks: Array[Task[T]])(implicit m: ClassManifest[T]): Array[T]
-  def runJob[T, U](rdd: RDD[T], func: Iterator[T] => U)(implicit m: ClassManifest[U]): Array[U]
+
+  // Run a function on some partitions of an RDD, returning an array of results.
+  def runJob[T, U](rdd: RDD[T], func: Iterator[T] => U, partitions: Seq[Int])
+                  (implicit m: ClassManifest[U]): Array[U]
+
   def stop()
+
   def numCores(): Int
 }
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index c1807de0ef..7aa1eb0a71 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -145,28 +145,22 @@ extends Logging {
       None
   }
 
-  // Run an array of spark.Task objects
-  private[spark] def runTaskObjects[T: ClassManifest](tasks: Seq[Task[T]])
-      : Array[T] = {
-    return null;
-    /*
-    logInfo("Running " + tasks.length + " tasks in parallel")
-    val start = System.nanoTime
-    val result = scheduler.runTasks(tasks.toArray)
-    logInfo("Tasks finished in " + (System.nanoTime - start) / 1e9 + " s")
-    return result
-    */
-  }
-
-  private[spark] def runJob[T, U](rdd: RDD[T], func: Iterator[T] => U)(implicit m: ClassManifest[U])
+  private[spark] def runJob[T, U](rdd: RDD[T], func: Iterator[T] => U, partitions: Seq[Int])
+                                 (implicit m: ClassManifest[U])
       : Array[U] = {
     logInfo("Starting job...")
     val start = System.nanoTime
-    val result = scheduler.runJob(rdd, func)
+    val result = scheduler.runJob(rdd, func, partitions)
     logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s")
     result
   }
 
+  private[spark] def runJob[T, U](rdd: RDD[T], func: Iterator[T] => U)
+                                 (implicit m: ClassManifest[U])
+      : Array[U] = {
+    runJob(rdd, func, 0 until rdd.splits.size)
+  }
+
   // Clean a closure to make it ready to serialized and send to tasks
   // (removes unreferenced variables in $outer's, updates REPL variables)
   private[spark] def clean[F <: AnyRef](f: F): F = {
-- 
GitLab