From e72afdb817bcc8388aeb8b8d31628fd5fd67acf1 Mon Sep 17 00:00:00 2001
From: Matei Zaharia <matei@eecs.berkeley.edu>
Date: Fri, 6 Jul 2012 15:23:26 -0700
Subject: [PATCH] Some refactoring to make cluster scheduler pluggable.
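
This splits the Mesos-specific scheduling logic into a generic
ClusterScheduler plus a pluggable backend interface:

- ClusterScheduler (in the new spark.scheduler.cluster package) owns task
  sets, task ID bookkeeping, speculation and host liveness, and delegates
  cluster-manager work to a ClusterSchedulerContext (start / stop /
  reviveOffers / defaultParallelism). MesosScheduler becomes one such
  backend.
- TaskSetManager and TaskInfo move from spark.scheduler.mesos to
  spark.scheduler.cluster and now exchange plain Long task IDs,
  TaskDescriptions and the new spark.TaskState enum instead of Mesos
  protobufs.
- Executor moves to spark.executor and no longer implements the Mesos
  Executor interface; it reports status through the new ExecutorContext
  trait, with MesosExecutorRunner providing the Mesos glue and the
  executor entry point.
- Task attempt IDs widen from Int to Long; the Hadoop output paths map
  them back into 32 bits with a mod.
- TaskScheduler.waitForRegister() is removed, ExecutorRunner is renamed
  to ExecutorManager, and CoarseMesosScheduler is commented out until it
  is ported to the new interface.

A new backend only has to implement ClusterSchedulerContext and feed
resource offers and status updates back into the scheduler. A rough
sketch using the interfaces from this patch (MyClusterContext and its
hard-coded resources are made up for illustration):

    import spark.SparkContext
    import spark.scheduler.cluster.{ClusterScheduler,
      ClusterSchedulerContext, TaskDescription, WorkerOffer}

    class MyClusterContext(sched: ClusterScheduler)
      extends ClusterSchedulerContext {

      def start() { /* connect to the cluster manager */ }
      def stop() { /* disconnect and clean up */ }

      def reviveOffers() {
        // Re-offer whatever resources we currently hold; the scheduler
        // picks tasks for them in round-robin, locality-aware order.
        val offers = Seq(new WorkerOffer("slave-1", "host-1", 4))
        val tasks: Seq[Seq[TaskDescription]] = sched.resourceOffers(offers)
        // Launch tasks(i) on offers(i), then report state changes with
        // sched.statusUpdate(taskId, state, serializedData).
      }

      def defaultParallelism(): Int = 8 // e.g. total cores in the cluster
    }

    // Wiring, mirroring what SparkContext now does for Mesos:
    def makeScheduler(sc: SparkContext): ClusterScheduler = {
      val sched = new ClusterScheduler(sc)
      sched.initialize(new MyClusterContext(sched))
      sched
    }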

---
 .../main/scala/spark/PairRDDFunctions.scala   |  11 +-
 core/src/main/scala/spark/SparkContext.scala  |  13 +-
 core/src/main/scala/spark/TaskContext.scala   |   2 +-
 core/src/main/scala/spark/TaskState.scala     |  32 ++
 core/src/main/scala/spark/Utils.scala         |  10 +
 .../scala/spark/deploy/DeployMessage.scala    |   7 +-
 .../scala/spark/deploy/ExecutorState.scala    |   4 +-
 .../scala/spark/deploy/JobDescription.scala   |   1 -
 .../spark/deploy/client/TestClient.scala      |   4 +-
 .../spark/deploy/client/TestExecutor.scala    |   3 +
 .../scala/spark/deploy/master/JobState.scala  |   2 +
 ...utorRunner.scala => ExecutorManager.scala} |   9 +-
 .../scala/spark/deploy/worker/Worker.scala    |   9 +-
 .../spark/deploy/worker/WorkerArguments.scala |   4 +-
 .../scala/spark/{ => executor}/Executor.scala | 125 ++---
 .../spark/executor/ExecutorContext.scala      |  11 +
 .../spark/executor/MesosExecutorRunner.scala  |  68 +++
 .../scala/spark/scheduler/ResultTask.scala    |   2 +-
 .../spark/scheduler/ShuffleMapTask.scala      |   2 +-
 .../src/main/scala/spark/scheduler/Task.scala |   2 +-
 .../scala/spark/scheduler/TaskScheduler.scala |   5 +-
 .../scheduler/cluster/ClusterScheduler.scala  | 294 ++++++++++++
 .../cluster/ClusterSchedulerContext.scala     |  10 +
 .../scheduler/cluster/SlaveResources.scala    |   3 +
 .../scheduler/cluster/TaskDescription.scala   |   5 +
 .../{mesos => cluster}/TaskInfo.scala         |   6 +-
 .../{mesos => cluster}/TaskSetManager.scala   | 112 ++---
 .../spark/scheduler/cluster/WorkerOffer.scala |   7 +
 .../scheduler/local/LocalScheduler.scala      |   2 -
 .../mesos/CoarseMesosScheduler.scala          |  21 +-
 .../scheduler/mesos/MesosScheduler.scala      | 434 +++++-------------
 .../scala/spark/storage/BlockManager.scala    |   4 +-
 .../main/scala/spark/repl/SparkILoop.scala    |   1 -
 spark-executor                                |   2 +-
 34 files changed, 718 insertions(+), 509 deletions(-)
 create mode 100644 core/src/main/scala/spark/TaskState.scala
 rename core/src/main/scala/spark/deploy/worker/{ExecutorRunner.scala => ExecutorManager.scala} (95%)
 rename core/src/main/scala/spark/{ => executor}/Executor.scala (51%)
 create mode 100644 core/src/main/scala/spark/executor/ExecutorContext.scala
 create mode 100644 core/src/main/scala/spark/executor/MesosExecutorRunner.scala
 create mode 100644 core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
 create mode 100644 core/src/main/scala/spark/scheduler/cluster/ClusterSchedulerContext.scala
 create mode 100644 core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala
 create mode 100644 core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala
 rename core/src/main/scala/spark/scheduler/{mesos => cluster}/TaskInfo.scala (76%)
 rename core/src/main/scala/spark/scheduler/{mesos => cluster}/TaskSetManager.scala (83%)
 create mode 100644 core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala

diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 270447712b..ea24c7897d 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -307,9 +307,12 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
     val jobtrackerID = formatter.format(new Date())
     val stageId = self.id
     def writeShard(context: spark.TaskContext, iter: Iterator[(K,V)]): Int = {
+      // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
+      // around by taking a mod. We expect that no task will be attempted 2 billion times.
+      val attemptNumber = (context.attemptId % Int.MaxValue).toInt
       /* "reduce task" <split #> <attempt # = spark task #> */
       val attemptId = new TaskAttemptID(jobtrackerID,
-        stageId, false, context.splitId, context.attemptId)
+        stageId, false, context.splitId, attemptNumber)
       val hadoopContext = new TaskAttemptContext(wrappedConf.value, attemptId)
       val format = outputFormatClass.newInstance
       val committer = format.getOutputCommitter(hadoopContext)
@@ -371,7 +374,11 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
     writer.preSetup()
 
     def writeToFile(context: TaskContext, iter: Iterator[(K,V)]) {
-      writer.setup(context.stageId, context.splitId, context.attemptId)
+      // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
+      // around by taking a mod. We expect that no task will be attempted 2 billion times.
+      val attemptNumber = (context.attemptId % Int.MaxValue).toInt
+
+      writer.setup(context.stageId, context.splitId, attemptNumber)
       writer.open()
       
       var count = 0
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index cba70794e7..d35b2b1cac 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -41,8 +41,8 @@ import spark.scheduler.ShuffleMapTask
 import spark.scheduler.DAGScheduler
 import spark.scheduler.TaskScheduler
 import spark.scheduler.local.LocalScheduler
+import spark.scheduler.cluster.ClusterScheduler
 import spark.scheduler.mesos.MesosScheduler
-import spark.scheduler.mesos.CoarseMesosScheduler
 import spark.storage.BlockManagerMaster
 
 class SparkContext(
@@ -89,11 +89,17 @@ class SparkContext(
         new LocalScheduler(threads.toInt, maxFailures.toInt)
       case _ =>
         MesosNativeLibrary.load()
+        val sched = new ClusterScheduler(this)
+        val schedContext = new MesosScheduler(sched, this, master, frameworkName)
+        sched.initialize(schedContext)
+        sched
+        /*
         if (System.getProperty("spark.mesos.coarse", "false") == "true") {
           new CoarseMesosScheduler(this, master, frameworkName)
         } else {
           new MesosScheduler(this, master, frameworkName)
         }
+        */
     }
   }
   taskScheduler.start()
@@ -272,11 +278,6 @@ class SparkContext(
     logInfo("Successfully stopped SparkContext")
   }
 
-  // Wait for the scheduler to be registered with the cluster manager
-  def waitForRegister() {
-    taskScheduler.waitForRegister()
-  }
-
   // Get Spark's home location from either a value set through the constructor,
   // or the spark.home Java property, or the SPARK_HOME environment variable
   // (in that order of preference). If neither of these is set, return None.
diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala
index 7a6214aab6..c14377d17b 100644
--- a/core/src/main/scala/spark/TaskContext.scala
+++ b/core/src/main/scala/spark/TaskContext.scala
@@ -1,3 +1,3 @@
 package spark
 
-class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Int) extends Serializable
+class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable
diff --git a/core/src/main/scala/spark/TaskState.scala b/core/src/main/scala/spark/TaskState.scala
new file mode 100644
index 0000000000..9566b52432
--- /dev/null
+++ b/core/src/main/scala/spark/TaskState.scala
@@ -0,0 +1,32 @@
+package spark
+
+import org.apache.mesos.Protos.{TaskState => MesosTaskState}
+
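+/**
+ * States that a Spark task can be in, with conversions to and from the
+ * corresponding Mesos task states.
+ */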
+object TaskState
+  extends Enumeration("LAUNCHING", "RUNNING", "FINISHED", "FAILED", "KILLED", "LOST") {
+
+  val LAUNCHING, RUNNING, FINISHED, FAILED, KILLED, LOST = Value
+
+  type TaskState = Value
+
+  def isFinished(state: TaskState) = Seq(FINISHED, FAILED, KILLED, LOST).contains(state)
+
+  def toMesos(state: TaskState): MesosTaskState = state match {
+    case LAUNCHING => MesosTaskState.TASK_STARTING
+    case RUNNING => MesosTaskState.TASK_RUNNING
+    case FINISHED => MesosTaskState.TASK_FINISHED
+    case FAILED => MesosTaskState.TASK_FAILED
+    case KILLED => MesosTaskState.TASK_KILLED
+    case LOST => MesosTaskState.TASK_LOST
+  }
+
+  def fromMesos(mesosState: MesosTaskState): TaskState = mesosState match {
+    case MesosTaskState.TASK_STAGING => LAUNCHING
+    case MesosTaskState.TASK_STARTING => LAUNCHING
+    case MesosTaskState.TASK_RUNNING => RUNNING
+    case MesosTaskState.TASK_FINISHED => FINISHED
+    case MesosTaskState.TASK_FAILED => FAILED
+    case MesosTaskState.TASK_KILLED => KILLED
+    case MesosTaskState.TASK_LOST => LOST
+  }
+}
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 674ff9e298..5eda1011f9 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -13,6 +13,7 @@ import scala.io.Source
  * Various utility methods used by Spark.
  */
 object Utils {
+  /** Serialize an object using Java serialization */
   def serialize[T](o: T): Array[Byte] = {
     val bos = new ByteArrayOutputStream()
     val oos = new ObjectOutputStream(bos)
@@ -21,12 +22,14 @@ object Utils {
     return bos.toByteArray
   }
 
+  /** Deserialize an object using Java serialization */
   def deserialize[T](bytes: Array[Byte]): T = {
     val bis = new ByteArrayInputStream(bytes)
     val ois = new ObjectInputStream(bis)
     return ois.readObject.asInstanceOf[T]
   }
 
+  /** Deserialize an object using Java serialization and the given ClassLoader */
   def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
     val bis = new ByteArrayInputStream(bytes)
     val ois = new ObjectInputStream(bis) {
@@ -106,6 +109,13 @@ object Utils {
     }
   }
 
+  /** Copy a file on the local file system */
+  def copyFile(source: File, dest: File) {
+    val in = new FileInputStream(source)
+    val out = new FileOutputStream(dest)
+    copyStream(in, out, true)
+  }
+
   /**
    * Shuffle the elements of a collection into a random order, returning the
    * result in a new collection. Unlike scala.util.Random.shuffle, this method
diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala
index 14492ed552..cf5e42797b 100644
--- a/core/src/main/scala/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/spark/deploy/DeployMessage.scala
@@ -1,5 +1,7 @@
 package spark.deploy
 
+import spark.deploy.ExecutorState.ExecutorState
+
 sealed trait DeployMessage extends Serializable
 
 // Worker to Master
@@ -10,8 +12,7 @@ case class RegisterWorker(id: String, host: String, port: Int, cores: Int, memor
 case class ExecutorStateChanged(
     jobId: String,
     execId: Int,
-    state:
-    ExecutorState.Value,
+    state: ExecutorState,
     message: Option[String])
   extends DeployMessage
 
@@ -38,7 +39,7 @@ case class RegisterJob(jobDescription: JobDescription) extends DeployMessage
 
 case class RegisteredJob(jobId: String) extends DeployMessage
 case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int)
-case class ExecutorUpdated(id: Int, state: ExecutorState.Value, message: Option[String])
+case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String])
 case class JobKilled(message: String)
 
 // Internal message in Client
diff --git a/core/src/main/scala/spark/deploy/ExecutorState.scala b/core/src/main/scala/spark/deploy/ExecutorState.scala
index ea73f7be29..d6ff1c54ca 100644
--- a/core/src/main/scala/spark/deploy/ExecutorState.scala
+++ b/core/src/main/scala/spark/deploy/ExecutorState.scala
@@ -5,5 +5,7 @@ object ExecutorState
 
   val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST = Value
 
-  def isFinished(state: Value): Boolean = (state == KILLED || state == FAILED || state == LOST)
+  type ExecutorState = Value
+
+  def isFinished(state: ExecutorState): Boolean = Seq(KILLED, FAILED, LOST).contains(state)
 }
diff --git a/core/src/main/scala/spark/deploy/JobDescription.scala b/core/src/main/scala/spark/deploy/JobDescription.scala
index 3f91402a31..8ae77b1038 100644
--- a/core/src/main/scala/spark/deploy/JobDescription.scala
+++ b/core/src/main/scala/spark/deploy/JobDescription.scala
@@ -4,7 +4,6 @@ class JobDescription(
     val name: String,
     val cores: Int,
     val memoryPerSlave: Int,
-    val fileUrls: Seq[String],
     val command: Command)
   extends Serializable {
 
diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala
index b04f362997..df9a36c7fe 100644
--- a/core/src/main/scala/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/spark/deploy/client/TestClient.scala
@@ -24,8 +24,8 @@ object TestClient {
   def main(args: Array[String]) {
     val url = args(0)
     val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress(), 0)
-    val desc = new JobDescription("TestClient", 1, 512, Seq(),
-      Command("spark.deploy.client.TestExecutor", Seq(), Map()))
+    val desc = new JobDescription(
+      "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()))
     val listener = new TestListener
     val client = new Client(actorSystem, url, desc, listener)
     client.start()
diff --git a/core/src/main/scala/spark/deploy/client/TestExecutor.scala b/core/src/main/scala/spark/deploy/client/TestExecutor.scala
index 1a74cc03cf..2e40e10d18 100644
--- a/core/src/main/scala/spark/deploy/client/TestExecutor.scala
+++ b/core/src/main/scala/spark/deploy/client/TestExecutor.scala
@@ -3,5 +3,8 @@ package spark.deploy.client
 object TestExecutor {
   def main(args: Array[String]) {
     println("Hello world!")
+    while (true) {
+      Thread.sleep(1000)
+    }
   }
 }
diff --git a/core/src/main/scala/spark/deploy/master/JobState.scala b/core/src/main/scala/spark/deploy/master/JobState.scala
index 3a69a37aca..50b0c6f95b 100644
--- a/core/src/main/scala/spark/deploy/master/JobState.scala
+++ b/core/src/main/scala/spark/deploy/master/JobState.scala
@@ -1,5 +1,7 @@
 package spark.deploy.master
 
 object JobState extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED") {
+  type JobState = Value
+
   val WAITING, RUNNING, FINISHED, FAILED = Value
 }
diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorManager.scala
similarity index 95%
rename from core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
rename to core/src/main/scala/spark/deploy/worker/ExecutorManager.scala
index ec58f576e7..ce17799648 100644
--- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/spark/deploy/worker/ExecutorManager.scala
@@ -10,7 +10,10 @@ import org.apache.hadoop.conf.Configuration
 import scala.Some
 import spark.deploy.ExecutorStateChanged
 
-class ExecutorRunner(
+/**
+ * Manages the execution of one executor process.
+ */
+class ExecutorManager(
     jobId: String,
     execId: Int,
     jobDesc: JobDescription,
@@ -26,13 +29,13 @@ class ExecutorRunner(
   var process: Process = null
 
   def start() {
-    workerThread = new Thread("ExecutorRunner for " + fullId) {
+    workerThread = new Thread("ExecutorManager for " + fullId) {
       override def run() { fetchAndRunExecutor() }
     }
     workerThread.start()
   }
 
-  /** Stop this executor runner, including killing the process it launched */
+  /** Stop this executor manager, including killing the process it launched */
   def kill() {
     if (workerThread != null) {
       workerThread.interrupt()
diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala
index e5da181e9a..fba44ca9b5 100644
--- a/core/src/main/scala/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/spark/deploy/worker/Worker.scala
@@ -26,7 +26,7 @@ class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int, mas
   val workerId = generateWorkerId()
   var sparkHome: File = null
   var workDir: File = null
-  val executors = new HashMap[String, ExecutorRunner]
+  val executors = new HashMap[String, ExecutorManager]
   val finishedExecutors = new ArrayBuffer[String]
 
   var coresUsed = 0
@@ -104,9 +104,10 @@ class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int, mas
 
     case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_) =>
       logInfo("Asked to launch executor %s/%d for %s".format(jobId, execId, jobDesc.name))
-      val er = new ExecutorRunner(jobId, execId, jobDesc, cores_, memory_, self, sparkHome, workDir)
-      executors(jobId + "/" + execId) = er
-      er.start()
+      val manager = new ExecutorManager(
+        jobId, execId, jobDesc, cores_, memory_, self, sparkHome, workDir)
+      executors(jobId + "/" + execId) = manager
+      manager.start()
       master ! ExecutorStateChanged(jobId, execId, ExecutorState.LOADING, None)
 
     case ExecutorStateChanged(jobId, execId, state, message) =>
diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
index ab764aa877..3248d03697 100644
--- a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
@@ -83,6 +83,8 @@ class WorkerArguments(args: Array[String]) {
   def inferDefaultMemory(): Int = {
     val bean = ManagementFactory.getOperatingSystemMXBean
                                 .asInstanceOf[com.sun.management.OperatingSystemMXBean]
-    (bean.getTotalPhysicalMemorySize / 1024 / 1024).toInt
+    val totalMb = (bean.getTotalPhysicalMemorySize / 1024 / 1024).toInt
+    // Leave out 1 GB for the operating system, but never go below 512 MB
+    math.max(totalMb - 1024, 512)
   }
 }
\ No newline at end of file
diff --git a/core/src/main/scala/spark/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
similarity index 51%
rename from core/src/main/scala/spark/Executor.scala
rename to core/src/main/scala/spark/executor/Executor.scala
index 9ead0d2870..ad02b85254 100644
--- a/core/src/main/scala/spark/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -1,118 +1,90 @@
-package spark
+package spark.executor
 
 import java.io.{File, FileOutputStream}
-import java.net.{URI, URL, URLClassLoader}
+import java.net.{URL, URLClassLoader}
 import java.util.concurrent._
 
-import scala.actors.remote.RemoteActor
 import scala.collection.mutable.ArrayBuffer
 
-import com.google.protobuf.ByteString
-
-import org.apache.mesos._
-import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
-
 import spark.broadcast._
 import spark.scheduler._
+import spark._
+import java.nio.ByteBuffer
 
 /**
  * The Mesos executor for Spark.
  */
-class Executor extends org.apache.mesos.Executor with Logging {
+class Executor extends Logging {
   var classLoader: ClassLoader = null
   var threadPool: ExecutorService = null
   var env: SparkEnv = null
 
+  val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
+
   initLogging()
 
-  override def registered(
-      driver: ExecutorDriver,
-      executorInfo: ExecutorInfo,
-      frameworkInfo: FrameworkInfo,
-      slaveInfo: SlaveInfo) {
-    // Make sure the local hostname we report matches Mesos's name for this host
-    Utils.setCustomHostname(slaveInfo.getHostname())
-
-    // Read spark.* system properties from executor arg
-    val props = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
-    for ((key, value) <- props) {
+  def initialize(slaveHostname: String, properties: Seq[(String, String)]) {
+    // Make sure the local hostname we report matches the cluster scheduler's name for this host
+    Utils.setCustomHostname(slaveHostname)
+
+    // Set spark.* system properties from executor arg
+    for ((key, value) <- properties) {
       System.setProperty(key, value)
     }
 
-    // Make sure an appropriate class loader is set for remote actors
-    RemoteActor.classLoader = getClass.getClassLoader
-
     // Initialize Spark environment (using system properties read above)
-    env = SparkEnv.createFromSystemProperties(slaveInfo.getHostname(), 0, false, false)
+    env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false)
     SparkEnv.set(env)
     // Old stuff that isn't yet using env
     Broadcast.initialize(false)
-    
+
     // Create our ClassLoader (using spark properties) and set it on this thread
     classLoader = createClassLoader()
     Thread.currentThread.setContextClassLoader(classLoader)
-    
+
     // Start worker thread pool
     threadPool = new ThreadPoolExecutor(
-        1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
+      1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
   }
 
-  override def disconnected(d: ExecutorDriver) {}
-
-  override def reregistered(d: ExecutorDriver, s: SlaveInfo) {}
-  
-  override def launchTask(d: ExecutorDriver, task: MTaskInfo) {
-    threadPool.execute(new TaskRunner(task, d))
+  def launchTask(context: ExecutorContext, taskId: Long, serializedTask: ByteBuffer) {
+    threadPool.execute(new TaskRunner(context, taskId, serializedTask))
   }
 
-  class TaskRunner(info: MTaskInfo, d: ExecutorDriver)
-  extends Runnable {
+  class TaskRunner(context: ExecutorContext, taskId: Long, serializedTask: ByteBuffer)
+    extends Runnable {
+
     override def run() {
-      val tid = info.getTaskId.getValue
       SparkEnv.set(env)
       Thread.currentThread.setContextClassLoader(classLoader)
       val ser = SparkEnv.get.closureSerializer.newInstance()
-      logInfo("Running task ID " + tid)
-      d.sendStatusUpdate(TaskStatus.newBuilder()
-          .setTaskId(info.getTaskId)
-          .setState(TaskState.TASK_RUNNING)
-          .build())
+      logInfo("Running task ID " + taskId)
+      context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
       try {
         SparkEnv.set(env)
         Thread.currentThread.setContextClassLoader(classLoader)
-        Accumulators.clear
-        val task = ser.deserialize[Task[Any]](info.getData.asReadOnlyByteBuffer, classLoader)
+        Accumulators.clear()
+        val task = ser.deserialize[Task[Any]](serializedTask, classLoader)
         env.mapOutputTracker.updateGeneration(task.generation)
-        val value = task.run(tid.toInt)
+        val value = task.run(taskId)
         val accumUpdates = Accumulators.values
         val result = new TaskResult(value, accumUpdates)
-        d.sendStatusUpdate(TaskStatus.newBuilder()
-            .setTaskId(info.getTaskId)
-            .setState(TaskState.TASK_FINISHED)
-            .setData(ByteString.copyFrom(ser.serialize(result)))
-            .build())
-        logInfo("Finished task ID " + tid)
+        context.statusUpdate(taskId, TaskState.FINISHED, ser.serialize(result))
+        logInfo("Finished task ID " + taskId)
       } catch {
         case ffe: FetchFailedException => {
           val reason = ffe.toTaskEndReason
-          d.sendStatusUpdate(TaskStatus.newBuilder()
-              .setTaskId(info.getTaskId)
-              .setState(TaskState.TASK_FAILED)
-              .setData(ByteString.copyFrom(ser.serialize(reason)))
-              .build())
+          context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
         }
+
         case t: Throwable => {
           val reason = ExceptionFailure(t)
-          d.sendStatusUpdate(TaskStatus.newBuilder()
-              .setTaskId(info.getTaskId)
-              .setState(TaskState.TASK_FAILED)
-              .setData(ByteString.copyFrom(ser.serialize(reason)))
-              .build())
+          context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
 
           // TODO: Should we exit the whole executor here? On the one hand, the failed task may
           // have left some weird state around depending on when the exception was thrown, but on
           // the other hand, maybe we could detect that when future tasks fail and exit then.
-          logError("Exception in task ID " + tid, t)
+          logError("Exception in task ID " + taskId, t)
           //System.exit(1)
         }
       }
@@ -120,7 +92,7 @@ class Executor extends org.apache.mesos.Executor with Logging {
   }
 
   /**
-   * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes 
+   * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes
    * created by the interpreter to the search path
    */
   private def createClassLoader(): ClassLoader = {
@@ -129,7 +101,7 @@ class Executor extends org.apache.mesos.Executor with Logging {
     // If any JAR URIs are given through spark.jar.uris, fetch them to the
     // current directory and put them all on the classpath. We assume that
     // each URL has a unique file name so that no local filenames will clash
-    // in this process. This is guaranteed by MesosScheduler.
+    // in this process. This is guaranteed by ClusterScheduler.
     val uris = System.getProperty("spark.jar.uris", "")
     val localFiles = ArrayBuffer[String]()
     for (uri <- uris.split(",").filter(_.size > 0)) {
@@ -150,7 +122,8 @@ class Executor extends org.apache.mesos.Executor with Logging {
       logInfo("Using REPL class URI: " + classUri)
       loader = {
         try {
-          val klass = Class.forName("spark.repl.ExecutorClassLoader").asInstanceOf[Class[_ <: ClassLoader]]
+          val klass = Class.forName("spark.repl.ExecutorClassLoader")
+            .asInstanceOf[Class[_ <: ClassLoader]]
           val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
           constructor.newInstance(classUri, loader)
         } catch {
@@ -168,28 +141,4 @@ class Executor extends org.apache.mesos.Executor with Logging {
     val out = new FileOutputStream(localPath)
     Utils.copyStream(in, out, true)
   }
-
-  override def error(d: ExecutorDriver, message: String) {
-    logError("Error from Mesos: " + message)
-  }
-
-  override def killTask(d: ExecutorDriver, t: TaskID) {
-    logWarning("Mesos asked us to kill task " + t.getValue + "; ignoring (not yet implemented)")
-  }
-
-  override def shutdown(d: ExecutorDriver) {}
-
-  override def frameworkMessage(d: ExecutorDriver, data: Array[Byte]) {}
-}
-
-/**
- * Executor entry point.
- */
-object Executor extends Logging {
-  def main(args: Array[String]) {
-    MesosNativeLibrary.load()
-    // Create a new Executor and start it running
-    val exec = new Executor
-    new MesosExecutorDriver(exec).run()
-  }
 }
diff --git a/core/src/main/scala/spark/executor/ExecutorContext.scala b/core/src/main/scala/spark/executor/ExecutorContext.scala
new file mode 100644
index 0000000000..6b86d8d18a
--- /dev/null
+++ b/core/src/main/scala/spark/executor/ExecutorContext.scala
@@ -0,0 +1,11 @@
+package spark.executor
+
+import java.nio.ByteBuffer
+import spark.TaskState.TaskState
+
+/**
+ * Interface used by Executor to send back updates to the cluster scheduler.
+ */
+trait ExecutorContext {
+  def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer)
+}
diff --git a/core/src/main/scala/spark/executor/MesosExecutorRunner.scala b/core/src/main/scala/spark/executor/MesosExecutorRunner.scala
new file mode 100644
index 0000000000..7695cbdfd7
--- /dev/null
+++ b/core/src/main/scala/spark/executor/MesosExecutorRunner.scala
@@ -0,0 +1,68 @@
+package spark.executor
+
+import java.nio.ByteBuffer
+import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver}
+import org.apache.mesos.Protos.{TaskState => MesosTaskState, TaskStatus => MesosTaskStatus, _}
+import spark.TaskState.TaskState
+import com.google.protobuf.ByteString
+import spark.{Utils, Logging}
+
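+/**
+ * Mesos-specific wrapper around Executor: receives callbacks from the Mesos
+ * driver, hands launched tasks to the Executor, and forwards its status
+ * updates back to Mesos.
+ */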
+class MesosExecutorRunner(executor: Executor)
+  extends MesosExecutor
+  with ExecutorContext
+  with Logging {
+
+  var driver: ExecutorDriver = null
+
+  override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
+    val mesosTaskId = TaskID.newBuilder().setValue(taskId.toString).build()
+    driver.sendStatusUpdate(MesosTaskStatus.newBuilder()
+      .setTaskId(mesosTaskId)
+      .setState(spark.TaskState.toMesos(state))
+      .setData(ByteString.copyFrom(data))
+      .build())
+  }
+
+  override def registered(
+      driver: ExecutorDriver,
+      executorInfo: ExecutorInfo,
+      frameworkInfo: FrameworkInfo,
+      slaveInfo: SlaveInfo) {
+    this.driver = driver
+    val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
+    executor.initialize(slaveInfo.getHostname, properties)
+  }
+
+  override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) {
+    val taskId = taskInfo.getTaskId.getValue.toLong
+    executor.launchTask(this, taskId, taskInfo.getData.asReadOnlyByteBuffer)
+  }
+
+  override def error(d: ExecutorDriver, message: String) {
+    logError("Error from Mesos: " + message)
+  }
+
+  override def killTask(d: ExecutorDriver, t: TaskID) {
+    logWarning("Mesos asked us to kill task " + t.getValue + "; ignoring (not yet implemented)")
+  }
+
+  override def reregistered(d: ExecutorDriver, slaveInfo: SlaveInfo) {}
+
+  override def disconnected(d: ExecutorDriver) {}
+
+  override def frameworkMessage(d: ExecutorDriver, data: Array[Byte]) {}
+
+  override def shutdown(d: ExecutorDriver) {}
+}
+
+/**
+ * Entry point for Mesos executor.
+ */
+object MesosExecutorRunner {
+  def main(args: Array[String]) {
+    MesosNativeLibrary.load()
+    // Create a new Executor and start it running
+    val runner = new MesosExecutorRunner(new Executor)
+    new MesosExecutorDriver(runner).run()
+  }
+}
diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala
index d2fab55b5e..090ced9d76 100644
--- a/core/src/main/scala/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/spark/scheduler/ResultTask.scala
@@ -13,7 +13,7 @@ class ResultTask[T, U](
   
   val split = rdd.splits(partition)
 
-  override def run(attemptId: Int): U = {
+  override def run(attemptId: Long): U = {
     val context = new TaskContext(stageId, partition, attemptId)
     func(context, rdd.iterator(split))
   }
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 8c0e06f020..db89db903e 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -104,7 +104,7 @@ class ShuffleMapTask(
     split = in.readObject().asInstanceOf[Split]
   }
 
-  override def run(attemptId: Int): BlockManagerId = {
+  override def run(attemptId: Long): BlockManagerId = {
     val numOutputSplits = dep.partitioner.numPartitions
     val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
     val partitioner = dep.partitioner
diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala
index 42325956ba..f84d8d9c4f 100644
--- a/core/src/main/scala/spark/scheduler/Task.scala
+++ b/core/src/main/scala/spark/scheduler/Task.scala
@@ -4,7 +4,7 @@ package spark.scheduler
  * A task to execute on a worker node.
  */
 abstract class Task[T](val stageId: Int) extends Serializable {
-  def run(attemptId: Int): T
+  def run(attemptId: Long): T
   def preferredLocations: Seq[String] = Nil
 
   var generation: Long = -1   // Map output tracker generation. Will be set by TaskScheduler.
diff --git a/core/src/main/scala/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/spark/scheduler/TaskScheduler.scala
index cb7c375d97..c35633d53c 100644
--- a/core/src/main/scala/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/TaskScheduler.scala
@@ -1,7 +1,7 @@
 package spark.scheduler
 
 /**
- * Low-level task scheduler interface, implemented by both MesosScheduler and LocalScheduler.
+ * Low-level task scheduler interface, implemented by both ClusterScheduler and LocalScheduler.
  * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,
  * and are responsible for sending the tasks to the cluster, running them, retrying if there
  * are failures, and mitigating stragglers. They return events to the DAGScheduler through
@@ -10,9 +10,6 @@ package spark.scheduler
 trait TaskScheduler {
   def start(): Unit
 
-  // Wait for registration with Mesos.
-  def waitForRegister(): Unit
-
   // Disconnect from the cluster.
   def stop(): Unit
 
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
new file mode 100644
index 0000000000..c9b0c4e9b6
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -0,0 +1,294 @@
+package spark.scheduler.cluster
+
+import java.io.{File, FileInputStream, FileOutputStream}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+
+import spark._
+import spark.TaskState.TaskState
+import spark.scheduler._
+import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicLong
+
+/**
+ * The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call
+ * start(), then submit task sets through the runTasks method.
+ */
+class ClusterScheduler(sc: SparkContext)
+  extends TaskScheduler
+  with Logging {
+
+  // How often to check for speculative tasks
+  val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
+
+  val activeTaskSets = new HashMap[String, TaskSetManager]
+  var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager]
+
+  val taskIdToTaskSetId = new HashMap[Long, String]
+  val taskIdToSlaveId = new HashMap[Long, String]
+  val taskSetTaskIds = new HashMap[String, HashSet[Long]]
+
+  // Incrementing task IDs
+  val nextTaskId = new AtomicLong(0)
+
+  // Which hosts in the cluster are alive (contains hostnames)
+  val hostsAlive = new HashSet[String]
+
+  // Which slave IDs we have executors on
+  val slaveIdsWithExecutors = new HashSet[String]
+
+  val slaveIdToHost = new HashMap[String, String]
+
+  // JAR server, if any JARs were added by the user to the SparkContext
+  var jarServer: HttpServer = null
+
+  // URIs of JARs to pass to executor
+  var jarUris: String = ""
+
+  // Listener object to pass upcalls into
+  var listener: TaskSchedulerListener = null
+
+  var schedContext: ClusterSchedulerContext = null
+
+  val mapOutputTracker = SparkEnv.get.mapOutputTracker
+
+  override def setListener(listener: TaskSchedulerListener) {
+    this.listener = listener
+  }
+
+  def initialize(context: ClusterSchedulerContext) {
+    schedContext = context
+    createJarServer()
+  }
+
+  def newTaskId(): Long = nextTaskId.getAndIncrement()
+
+  override def start() {
+    schedContext.start()
+
+    if (System.getProperty("spark.speculation", "false") == "true") {
+      new Thread("ClusterScheduler speculation check") {
+        setDaemon(true)
+
+        override def run() {
+          while (true) {
+            try {
+              Thread.sleep(SPECULATION_INTERVAL)
+            } catch {
+              case e: InterruptedException => {}
+            }
+            checkSpeculatableTasks()
+          }
+        }
+      }.start()
+    }
+  }
+
+  def submitTasks(taskSet: TaskSet) {
+    val tasks = taskSet.tasks
+    logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
+    this.synchronized {
+      val manager = new TaskSetManager(this, taskSet)
+      activeTaskSets(taskSet.id) = manager
+      activeTaskSetsQueue += manager
+      taskSetTaskIds(taskSet.id) = new HashSet[Long]()
+    }
+    schedContext.reviveOffers()
+  }
+
+  def taskSetFinished(manager: TaskSetManager) {
+    this.synchronized {
+      activeTaskSets -= manager.taskSet.id
+      activeTaskSetsQueue -= manager
+      taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
+      taskIdToSlaveId --= taskSetTaskIds(manager.taskSet.id)
+      taskSetTaskIds.remove(manager.taskSet.id)
+    }
+  }
+
+
+  /**
+   * Called by cluster manager to offer resources on slaves. We respond by asking our active task
+   * sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so
+   * that tasks are balanced across the cluster.
+   */
+  def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = {
+    synchronized {
+      // Mark each slave as alive and remember its hostname
+      for (o <- offers) {
+        slaveIdToHost(o.slaveId) = o.hostname
+        hostsAlive += o.hostname
+      }
+      // Build a list of tasks to assign to each slave
+      val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
+      val availableCpus = offers.map(o => o.cores).toArray
+      var launchedTask = false
+      for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) {
+        do {
+          launchedTask = false
+          for (i <- 0 until offers.size) {
+            val sid = offers(i).slaveId
+            val host = offers(i).hostname
+            manager.slaveOffer(sid, host, availableCpus(i)) match {
+              case Some(task) =>
+                tasks(i) += task
+                val tid = task.taskId
+                taskIdToTaskSetId(tid) = manager.taskSet.id
+                taskSetTaskIds(manager.taskSet.id) += tid
+                taskIdToSlaveId(tid) = sid
+                slaveIdsWithExecutors += sid
+                availableCpus(i) -= 1
+                launchedTask = true
+
+              case None => {}
+            }
+          }
+        } while (launchedTask)
+      }
+      return tasks
+    }
+  }
+
+  def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+    var taskSetToUpdate: Option[TaskSetManager] = None
+    var failedHost: Option[String] = None
+    var taskFailed = false
+    synchronized {
+      try {
+        if (state == TaskState.LOST && taskIdToSlaveId.contains(tid)) {
+          // We lost the executor on this slave, so remember that it's gone
+          val slaveId = taskIdToSlaveId(tid)
+          val host = slaveIdToHost(slaveId)
+          if (hostsAlive.contains(host)) {
+            slaveIdsWithExecutors -= slaveId
+            hostsAlive -= host
+            activeTaskSetsQueue.foreach(_.hostLost(host))
+            failedHost = Some(host)
+          }
+        }
+        taskIdToTaskSetId.get(tid) match {
+          case Some(taskSetId) =>
+            if (activeTaskSets.contains(taskSetId)) {
+              taskSetToUpdate = Some(activeTaskSets(taskSetId))
+            }
+            if (TaskState.isFinished(state)) {
+              taskIdToTaskSetId.remove(tid)
+              if (taskSetTaskIds.contains(taskSetId)) {
+                taskSetTaskIds(taskSetId) -= tid
+              }
+              taskIdToSlaveId.remove(tid)
+            }
+            if (state == TaskState.FAILED) {
+              taskFailed = true
+            }
+          case None =>
+            logInfo("Ignoring update from TID " + tid + " because its task set is gone")
+        }
+      } catch {
+        case e: Exception => logError("Exception in statusUpdate", e)
+      }
+    }
+    // Update the task set and DAGScheduler without holding a lock on this, because that can deadlock
+    if (taskSetToUpdate != None) {
+      taskSetToUpdate.get.statusUpdate(tid, state, serializedData)
+    }
+    if (failedHost != None) {
+      listener.hostLost(failedHost.get)
+      schedContext.reviveOffers()
+    }
+    if (taskFailed) {
+      // Also revive offers if a task had failed for some reason other than host lost
+      schedContext.reviveOffers()
+    }
+  }
+
+  def error(message: String) {
+    synchronized {
+      if (activeTaskSets.size > 0) {
+        // Have each task set throw a SparkException with the error
+        for ((taskSetId, manager) <- activeTaskSets) {
+          try {
+            manager.error(message)
+          } catch {
+            case e: Exception => logError("Exception in error callback", e)
+          }
+        }
+      } else {
+        // No task sets are active but we still got an error. Just exit since this
+        // must mean the error is during registration.
+        // It might be good to do something smarter here in the future.
+        logError("Exiting due to error from cluster scheduler: " + message)
+        System.exit(1)
+      }
+    }
+  }
+
+  override def stop() {
+    if (schedContext != null) {
+      schedContext.stop()
+    }
+    if (jarServer != null) {
+      jarServer.stop()
+    }
+  }
+
+  override def defaultParallelism() = schedContext.defaultParallelism()
+
+  // Create a server for all the JARs added by the user to SparkContext.
+  // We first copy the JARs to a temp directory for easier server setup.
+  private def createJarServer() {
+    val jarDir = Utils.createTempDir()
+    logInfo("Temp directory for JARs: " + jarDir)
+    val filenames = ArrayBuffer[String]()
+    // Copy each JAR to a unique filename in the jarDir
+    for ((path, index) <- sc.jars.zipWithIndex) {
+      val file = new File(path)
+      if (file.exists) {
+        val filename = index + "_" + file.getName
+        Utils.copyFile(file, new File(jarDir, filename))
+        filenames += filename
+      }
+    }
+    // Create the server
+    jarServer = new HttpServer(jarDir)
+    jarServer.start()
+    // Build up the jar URI list
+    val serverUri = jarServer.uri
+    jarUris = filenames.map(f => serverUri + "/" + f).mkString(",")
+    System.setProperty("spark.jar.uris", jarUris)
+    logInfo("JAR server started at " + serverUri)
+  }
+
+  // Check for speculatable tasks in all our active jobs.
+  def checkSpeculatableTasks() {
+    var shouldRevive = false
+    synchronized {
+      for (ts <- activeTaskSetsQueue) {
+        shouldRevive |= ts.checkSpeculatableTasks()
+      }
+    }
+    if (shouldRevive) {
+      schedContext.reviveOffers()
+    }
+  }
+
+  def slaveLost(slaveId: String) {
+    var failedHost: Option[String] = None
+    synchronized {
+      val host = slaveIdToHost(slaveId)
+      if (hostsAlive.contains(host)) {
+        slaveIdsWithExecutors -= slaveId
+        hostsAlive -= host
+        activeTaskSetsQueue.foreach(_.hostLost(host))
+        failedHost = Some(host)
+      }
+    }
+    if (failedHost != None) {
+      listener.hostLost(failedHost.get)
+      schedContext.reviveOffers()
+    }
+  }
+}
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterSchedulerContext.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterSchedulerContext.scala
new file mode 100644
index 0000000000..6b9687ac25
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterSchedulerContext.scala
@@ -0,0 +1,10 @@
+package spark.scheduler.cluster
+
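+/**
+ * Pluggable interface that ClusterScheduler uses to talk to a particular
+ * cluster manager (MesosScheduler is the one implementation in this patch).
+ * Implementations feed resources and task state changes back through
+ * ClusterScheduler.resourceOffers() and statusUpdate().
+ */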
+trait ClusterSchedulerContext {
+  def start(): Unit
+  def stop(): Unit
+  def reviveOffers(): Unit
+  def defaultParallelism(): Int
+
+  // TODO: Probably want to add a killTask too
+}
diff --git a/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala b/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala
new file mode 100644
index 0000000000..e15d577a8b
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala
@@ -0,0 +1,3 @@
+package spark.scheduler.cluster
+
+class SlaveResources(val slaveId: String, val hostname: String, val coresFree: Int) {}
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala
new file mode 100644
index 0000000000..fad62f96aa
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala
@@ -0,0 +1,5 @@
+package spark.scheduler.cluster
+
+import java.nio.ByteBuffer
+
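+/**
+ * A task to launch on a worker, as returned by ClusterScheduler.resourceOffers():
+ * a cluster-unique task ID, a human-readable name, and the serialized Task.
+ */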
+class TaskDescription(val taskId: Long, val name: String, val serializedTask: ByteBuffer) {}
diff --git a/core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
similarity index 76%
rename from core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala
rename to core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
index af2f80ea66..0fc1d8ed30 100644
--- a/core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
@@ -1,9 +1,9 @@
-package spark.scheduler.mesos
+package spark.scheduler.cluster
 
 /**
- * Information about a running task attempt.
+ * Information about a running task attempt inside a TaskSet.
  */
-class TaskInfo(val taskId: String, val index: Int, val launchTime: Long, val host: String) {
+class TaskInfo(val taskId: Long, val index: Int, val launchTime: Long, val host: String) {
   var finishTime: Long = 0
   var failed = false
 
diff --git a/core/src/main/scala/spark/scheduler/mesos/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
similarity index 83%
rename from core/src/main/scala/spark/scheduler/mesos/TaskSetManager.scala
rename to core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index a8bebf8e50..75b67a0eb4 100644
--- a/core/src/main/scala/spark/scheduler/mesos/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -1,4 +1,4 @@
-package spark.scheduler.mesos
+package spark.scheduler.cluster
 
 import java.util.Arrays
 import java.util.{HashMap => JHashMap}
@@ -9,22 +9,19 @@ import scala.collection.mutable.HashSet
 import scala.math.max
 import scala.math.min
 
-import com.google.protobuf.ByteString
-
-import org.apache.mesos._
-import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
-
 import spark._
 import spark.scheduler._
+import spark.TaskState.TaskState
+import java.nio.ByteBuffer
 
 /**
- * Schedules the tasks within a single TaskSet in the MesosScheduler.
+ * Schedules the tasks within a single TaskSet in the ClusterScheduler.
  */
 class TaskSetManager(
-    sched: MesosScheduler, 
-    val taskSet: TaskSet)
+    sched: ClusterScheduler,
+    val taskSet: TaskSet)
   extends Logging {
-  
+
   // Maximum time to wait to run a task in a preferred location (in ms)
   val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
 
@@ -68,12 +65,12 @@ class TaskSetManager(
   // List containing all pending tasks (also used as a stack, as above)
   val allPendingTasks = new ArrayBuffer[Int]
 
-  // Tasks that can be specualted. Since these will be a small fraction of total
-  // tasks, we'll just hold them in a HaskSet.
+  // Tasks that can be speculated. Since these will be a small fraction of total
+  // tasks, we'll just hold them in a HashSet.
   val speculatableTasks = new HashSet[Int]
 
   // Task index, start and finish time for each task attempt (indexed by task ID)
-  val taskInfos = new HashMap[String, TaskInfo]
+  val taskInfos = new HashMap[Long, TaskInfo]
 
   // Did the job fail?
   var failed = false
@@ -140,12 +137,13 @@ class TaskSetManager(
   // attempt running on this host, in case the host is slow. In addition, if localOnly is set, the
   // task must have a preference for this host (or no preferred locations at all).
   def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
-    speculatableTasks.retain(index => !finished(index))  // Remove finished tasks from set
-    val localTask = speculatableTasks.find { index =>
-      val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
-      val attemptLocs = taskAttempts(index).map(_.host)
-      (locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host)
-    }
+    speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
+    val localTask = speculatableTasks.find {
+      index =>
+        val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
+        val attemptLocs = taskAttempts(index).map(_.host)
+        (locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host)
+    }
     if (localTask != None) {
       speculatableTasks -= localTask.get
       return localTask
@@ -190,11 +188,11 @@ class TaskSetManager(
   }
 
   // Respond to an offer of a single slave from the scheduler by finding a task
-  def slaveOffer(slaveId: String, host: String, availableCpus: Double): Option[MTaskInfo] = {
+  def slaveOffer(slaveId: String, host: String, availableCpus: Double): Option[TaskDescription] = {
     if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
       val time = System.currentTimeMillis
       var localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
-      
+
       findTask(host, localOnly) match {
         case Some(index) => {
           // Found a task; do some bookkeeping and return a Mesos task for it
@@ -204,38 +202,23 @@ class TaskSetManager(
           val preferred = isPreferredLocation(task, host)
           val prefStr = if (preferred) "preferred" else "non-preferred"
           logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
-              taskSet.id, index, taskId.getValue, slaveId, host, prefStr))
+            taskSet.id, index, taskId, slaveId, host, prefStr))
           // Do various bookkeeping
           copiesRunning(index) += 1
-          val info = new TaskInfo(taskId.getValue, index, time, host)
-          taskInfos(taskId.getValue) = info
+          val info = new TaskInfo(taskId, index, time, host)
+          taskInfos(taskId) = info
           taskAttempts(index) = info :: taskAttempts(index)
           if (preferred) {
             lastPreferredLaunchTime = time
           }
-          // Create and return the Mesos task object
-          val cpuRes = Resource.newBuilder()
-            .setName("cpus")
-            .setType(Value.Type.SCALAR)
-            .setScalar(Value.Scalar.newBuilder().setValue(CPUS_PER_TASK).build())
-            .build()
-
+          // Serialize and return the task
           val startTime = System.currentTimeMillis
           val serializedTask = ser.serialize(task)
           val timeTaken = System.currentTimeMillis - startTime
-
           logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
             taskSet.id, index, serializedTask.limit, timeTaken))
-
           val taskName = "task %s:%d".format(taskSet.id, index)
-          return Some(MTaskInfo.newBuilder()
-              .setTaskId(taskId)
-              .setSlaveId(SlaveID.newBuilder().setValue(slaveId))
-              .setExecutor(sched.executorInfo)
-              .setName(taskName)
-              .addResources(cpuRes)
-              .setData(ByteString.copyFrom(serializedTask))
-              .build())
+          return Some(new TaskDescription(taskId, taskName, serializedTask))
         }
         case _ =>
       }
@@ -243,32 +226,30 @@ class TaskSetManager(
     return None
   }
 
-  def statusUpdate(status: TaskStatus) {
-    status.getState match {
-      case TaskState.TASK_FINISHED =>
-        taskFinished(status)
-      case TaskState.TASK_LOST =>
-        taskLost(status)
-      case TaskState.TASK_FAILED =>
-        taskLost(status)
-      case TaskState.TASK_KILLED =>
-        taskLost(status)
+  def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+    state match {
+      case TaskState.FINISHED =>
+        taskFinished(tid, state, serializedData)
+      case TaskState.LOST =>
+        taskLost(tid, state, serializedData)
+      case TaskState.FAILED =>
+        taskLost(tid, state, serializedData)
+      case TaskState.KILLED =>
+        taskLost(tid, state, serializedData)
       case _ =>
     }
   }
 
-  def taskFinished(status: TaskStatus) {
-    val tid = status.getTaskId.getValue
+  def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) {
     val info = taskInfos(tid)
     val index = info.index
     info.markSuccessful()
     if (!finished(index)) {
       tasksFinished += 1
       logInfo("Finished TID %s in %d ms (progress: %d/%d)".format(
-          tid, info.duration, tasksFinished, numTasks))
+        tid, info.duration, tasksFinished, numTasks))
       // Deserialize task result and pass it to the scheduler
-      val result = ser.deserialize[TaskResult[_]](
-        status.getData.asReadOnlyByteBuffer, getClass.getClassLoader)
+      val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
       sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates)
       // Mark finished and stop if we've finished all the tasks
       finished(index) = true
@@ -281,8 +262,7 @@ class TaskSetManager(
     }
   }
 
-  def taskLost(status: TaskStatus) {
-    val tid = status.getTaskId.getValue
+  def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) {
     val info = taskInfos(tid)
     val index = info.index
     info.markFailed()
@@ -291,9 +271,8 @@ class TaskSetManager(
       copiesRunning(index) -= 1
       // Check if the problem is a map output fetch failure. In that case, this
       // task will never succeed on any node, so tell the scheduler about it.
-      if (status.getData != null && status.getData.size > 0) {
-        val reason = ser.deserialize[TaskEndReason](
-          status.getData.asReadOnlyByteBuffer, getClass.getClassLoader)
+      if (serializedData != null && serializedData.limit() > 0) {
+        val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader)
         reason match {
           case fetchFailed: FetchFailed =>
             logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
@@ -334,11 +313,11 @@ class TaskSetManager(
       // On non-fetch failures, re-enqueue the task as pending for a max number of retries
       addPendingTask(index)
       // Count failed attempts only on FAILED and LOST state (not on KILLED)
-      if (status.getState == TaskState.TASK_FAILED || status.getState == TaskState.TASK_LOST) {
+      if (state == TaskState.FAILED || state == TaskState.LOST) {
         numFailures(index) += 1
         if (numFailures(index) > MAX_TASK_FAILURES) {
           logError("Task %s:%d failed more than %d times; aborting job".format(
-              taskSet.id, index, MAX_TASK_FAILURES))
+            taskSet.id, index, MAX_TASK_FAILURES))
           abort("Task %d failed more than %d times".format(index, MAX_TASK_FAILURES))
         }
       }
@@ -389,7 +368,7 @@ class TaskSetManager(
 
   /**
    * Check for tasks to be speculated and return true if there are any. This is called periodically
-   * by the MesosScheduler.
+   * by the ClusterScheduler.
    *
    * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
    * we don't scan the whole task set. It might also help to make this sorted by launch time.
@@ -414,8 +393,9 @@ class TaskSetManager(
       for ((tid, info) <- taskInfos) {
         val index = info.index
         if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
-            !speculatableTasks.contains(index)) {
-          logInfo("Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
+          !speculatableTasks.contains(index)) {
+          logInfo(
+            "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
               taskSet.id, index, info.host, threshold))
           speculatableTasks += index
           foundTasks = true
diff --git a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
new file mode 100644
index 0000000000..1e83f103e7
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
@@ -0,0 +1,7 @@
+package spark.scheduler.cluster
+
+/**
+ * Represents free resources available on a worker node.
+ */
+class WorkerOffer(val slaveId: String, val hostname: String, val cores: Int) {
+}
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index 1a47f3fddf..eb47988f0c 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -20,8 +20,6 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with
   // TODO: Need to take into account stage priority in scheduling
 
   override def start() {}
-  
-  override def waitForRegister() {}
 
   override def setListener(listener: TaskSchedulerListener) { 
     this.listener = listener
diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala
index 2eee36264a..51fb3dc72f 100644
--- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala
@@ -26,10 +26,13 @@ import com.google.protobuf.ByteString
 
 import org.apache.mesos.{Scheduler => MScheduler}
 import org.apache.mesos._
-import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
+import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, TaskState => MesosTaskState, _}
 
 import spark._
 import spark.scheduler._
+import spark.scheduler.cluster.{TaskSetManager, ClusterScheduler}
+
+/*
 
 sealed trait CoarseMesosSchedulerMessage
 case class RegisterSlave(slaveId: String, host: String) extends CoarseMesosSchedulerMessage
@@ -50,7 +53,7 @@ class CoarseMesosScheduler(
     sc: SparkContext,
     master: String,
     frameworkName: String)
-  extends MesosScheduler(sc, master, frameworkName) {
+  extends ClusterScheduler(sc, master, frameworkName) {
 
   val actorSystem = sc.env.actorSystem
   val actorName = "CoarseMesosScheduler"
@@ -161,7 +164,7 @@ class CoarseMesosScheduler(
               taskIdToSlaveId -= tid
               taskIdsOnSlave(slaveId) -= tid
             }
-            if (status.getState == TaskState.TASK_FAILED) {
+            if (status.getState == MesosTaskState.TASK_FAILED) {
               taskFailed = true
             }
           case None =>
@@ -205,7 +208,7 @@ class CoarseMesosScheduler(
         // TODO: Maybe call our statusUpdate() instead to clean our internal data structures
         activeTaskSets(taskSetId).statusUpdate(TaskStatus.newBuilder()
           .setTaskId(TaskID.newBuilder().setValue(tid).build())
-          .setState(TaskState.TASK_LOST)
+          .setState(MesosTaskState.TASK_LOST)
           .build())
       }
       // Also report the loss to the DAGScheduler
@@ -283,7 +286,7 @@ class CoarseMesosScheduler(
 class WorkerTask(slaveId: String, host: String) extends Task[Unit](-1) {
   generation = 0
 
-  def run(id: Int) {
+  def run(id: Long) {
     val env = SparkEnv.get
     val classLoader = Thread.currentThread.getContextClassLoader
     val actor = env.actorSystem.actorOf(
@@ -323,7 +326,7 @@ class WorkerActor(slaveId: String, host: String, env: SparkEnv, classLoader: Cla
         val result = new TaskResult(value, accumUpdates)
         masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder()
             .setTaskId(desc.getTaskId)
-            .setState(TaskState.TASK_FINISHED)
+            .setState(MesosTaskState.TASK_FINISHED)
             .setData(ByteString.copyFrom(Utils.serialize(result)))
             .build())
         logInfo("Finished task ID " + tid)
@@ -332,7 +335,7 @@ class WorkerActor(slaveId: String, host: String, env: SparkEnv, classLoader: Cla
           val reason = ffe.toTaskEndReason
           masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder()
               .setTaskId(desc.getTaskId)
-              .setState(TaskState.TASK_FAILED)
+              .setState(MesosTaskState.TASK_FAILED)
               .setData(ByteString.copyFrom(Utils.serialize(reason)))
               .build())
         }
@@ -340,7 +343,7 @@ class WorkerActor(slaveId: String, host: String, env: SparkEnv, classLoader: Cla
           val reason = ExceptionFailure(t)
           masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder()
               .setTaskId(desc.getTaskId)
-              .setState(TaskState.TASK_FAILED)
+              .setState(MesosTaskState.TASK_FAILED)
               .setData(ByteString.copyFrom(Utils.serialize(reason)))
               .build())
 
@@ -364,3 +367,5 @@ class WorkerActor(slaveId: String, host: String, env: SparkEnv, classLoader: Cla
       threadPool.execute(new TaskRunner(task))    
   }
 }
+
+*/
\ No newline at end of file
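
The import rename above (TaskState => MesosTaskState) makes room for the new Spark-level spark.TaskState used elsewhere in this patch (TaskState.FAILED, TaskState.LOST, TaskState.fromMesos). Its exact body is not shown in this section; a sketch consistent with those call sites, where any state beyond FAILED and LOST is an assumption:

    import org.apache.mesos.Protos.{TaskState => MesosTaskState}

    // Sketch of a cluster-manager-agnostic task state enumeration.
    object TaskState extends Enumeration {
      type TaskState = Value
      val LAUNCHING, RUNNING, FINISHED, FAILED, KILLED, LOST = Value

      // Map the Mesos protobuf states onto Spark-level states.
      def fromMesos(state: MesosTaskState): TaskState = state match {
        case MesosTaskState.TASK_STAGING  => LAUNCHING
        case MesosTaskState.TASK_STARTING => LAUNCHING
        case MesosTaskState.TASK_RUNNING  => RUNNING
        case MesosTaskState.TASK_FINISHED => FINISHED
        case MesosTaskState.TASK_FAILED   => FAILED
        case MesosTaskState.TASK_KILLED   => KILLED
        case MesosTaskState.TASK_LOST     => LOST
      }
    }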
diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosScheduler.scala b/core/src/main/scala/spark/scheduler/mesos/MesosScheduler.scala
index 9113348976..8131d84fdf 100644
--- a/core/src/main/scala/spark/scheduler/mesos/MesosScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/MesosScheduler.scala
@@ -1,36 +1,26 @@
 package spark.scheduler.mesos
 
-import java.io.{File, FileInputStream, FileOutputStream}
-import java.util.{ArrayList => JArrayList}
-import java.util.{List => JList}
-import java.util.{HashMap => JHashMap}
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-import scala.collection.mutable.Map
-import scala.collection.mutable.PriorityQueue
-import scala.collection.JavaConversions._
-import scala.math.Ordering
-
 import com.google.protobuf.ByteString
 
 import org.apache.mesos.{Scheduler => MScheduler}
 import org.apache.mesos._
-import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
+import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
 
-import spark._
-import spark.scheduler._
+import java.io.File
+import java.util.{ArrayList => JArrayList, List => JList}
+import java.util.Collections
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.collection.JavaConversions._
+import spark.{SparkException, Utils, Logging, SparkContext}
+import spark.scheduler.cluster._
+import spark.TaskState
 
-/**
- * The main TaskScheduler implementation, which runs tasks on Mesos. Clients should first call
- * start(), then submit task sets through the runTasks method.
- */
 class MesosScheduler(
+    scheduler: ClusterScheduler,
     sc: SparkContext,
     master: String,
     frameworkName: String)
-  extends TaskScheduler
+  extends ClusterSchedulerContext
   with MScheduler
   with Logging {
 
@@ -52,86 +42,40 @@ class MesosScheduler(
     }
   }
 
-  // How often to check for speculative tasks
-  val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
-
   // Lock used to wait for scheduler to be registered
   var isRegistered = false
   val registeredLock = new Object()
 
-  val activeTaskSets = new HashMap[String, TaskSetManager]
-  var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager]
-
-  val taskIdToTaskSetId = new HashMap[String, String]
-  val taskIdToSlaveId = new HashMap[String, String]
-  val taskSetTaskIds = new HashMap[String, HashSet[String]]
-
-  // Incrementing Mesos task IDs
-  var nextTaskId = 0
-
   // Driver for talking to Mesos
   var driver: SchedulerDriver = null
 
-  // Which hosts in the cluster are alive (contains hostnames)
-  val hostsAlive = new HashSet[String]
-
   // Which slave IDs we have executors on
   val slaveIdsWithExecutors = new HashSet[String]
+  val taskIdToSlaveId = new HashMap[Long, String]
 
-  val slaveIdToHost = new HashMap[String, String]
-
-  // JAR server, if any JARs were added by the user to the SparkContext
-  var jarServer: HttpServer = null
-
-  // URIs of JARs to pass to executor
-  var jarUris: String = ""
-  
-  // Create an ExecutorInfo for our tasks
-  val executorInfo = createExecutorInfo()
+  // An ExecutorInfo for our tasks
+  var executorInfo: ExecutorInfo = null
 
-  // Listener object to pass upcalls into
-  var listener: TaskSchedulerListener = null
-
-  val mapOutputTracker = SparkEnv.get.mapOutputTracker
-
-  override def setListener(listener: TaskSchedulerListener) { 
-    this.listener = listener
-  }
-
-  def newTaskId(): TaskID = {
-    val id = TaskID.newBuilder().setValue("" + nextTaskId).build()
-    nextTaskId += 1
-    return id
-  }
-  
   override def start() {
-    new Thread("MesosScheduler driver") {
-      setDaemon(true)
-      override def run() {
-        val sched = MesosScheduler.this
-        val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(frameworkName).build()
-        driver = new MesosSchedulerDriver(sched, fwInfo, master)
-        try {
-          val ret = driver.run()
-          logInfo("driver.run() returned with code " + ret)
-        } catch {
-          case e: Exception => logError("driver.run() failed", e)
-        }
-      }
-    }.start()
-    if (System.getProperty("spark.speculation", "false") == "true") {
-      new Thread("MesosScheduler speculation check") {
+    synchronized {
+      new Thread("MesosScheduler driver") {
         setDaemon(true)
+
         override def run() {
-          waitForRegister()
-          while (true) {
-            try {
-              Thread.sleep(SPECULATION_INTERVAL)
-            } catch { case e: InterruptedException => {} }
-            checkSpeculatableTasks()
+          val sched = MesosScheduler.this
+          val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(frameworkName).build()
+          driver = new MesosSchedulerDriver(sched, fwInfo, master)
+          try {
+            val ret = driver.run()
+            logInfo("driver.run() returned with code " + ret)
+          } catch {
+            case e: Exception => logError("driver.run() failed", e)
           }
         }
       }.start()
+
+      executorInfo = createExecutorInfo()
+      waitForRegister()
     }
   }
 
@@ -141,11 +85,7 @@ class MesosScheduler(
         path
       case None =>
         throw new SparkException("Spark home is not set; set it through the spark.home system " +
-            "property, the SPARK_HOME environment variable or the SparkContext constructor")
-    }
-    // If the user added JARs to the SparkContext, create an HTTP server to ship them to executors
-    if (sc.jars.size > 0) {
-      createJarServer()
+          "property, the SPARK_HOME environment variable or the SparkContext constructor")
     }
     val execScript = new File(sparkHome, "spark-executor").getCanonicalPath
     val environment = Environment.newBuilder()
@@ -173,30 +113,27 @@ class MesosScheduler(
       .addResources(memory)
       .build()
   }
-  
-  def submitTasks(taskSet: TaskSet) {
-    val tasks = taskSet.tasks
-    logInfo("Adding task set " + taskSet.id + " with " + tasks.size + " tasks")
-    waitForRegister()
-    this.synchronized {
-      val manager = new TaskSetManager(this, taskSet)
-      activeTaskSets(taskSet.id) = manager
-      activeTaskSetsQueue += manager
-      taskSetTaskIds(taskSet.id) = new HashSet()
-    }
-    reviveOffers()
-  }
-  
-  def taskSetFinished(manager: TaskSetManager) {
-    this.synchronized {
-      activeTaskSets -= manager.taskSet.id
-      activeTaskSetsQueue -= manager
-      taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
-      taskIdToSlaveId --= taskSetTaskIds(manager.taskSet.id)
-      taskSetTaskIds.remove(manager.taskSet.id)
+
+  /**
+   * Create and serialize the executor argument to pass to Mesos. Our executor arg is an array
+   * containing all the spark.* system properties in the form of (String, String) pairs.
+   */
+  private def createExecArg(): Array[Byte] = {
+    val props = new HashMap[String, String]
+    val iterator = System.getProperties.entrySet.iterator
+    while (iterator.hasNext) {
+      val entry = iterator.next
+      val (key, value) = (entry.getKey.toString, entry.getValue.toString)
+      if (key.startsWith("spark.")) {
+        props(key) = value
+      }
     }
+    // Serialize the map as an array of (String, String) pairs
+    return Utils.serialize(props.toArray)
   }
 
+  override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
+
   override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) {
     logInfo("Registered as framework ID " + frameworkId.getValue)
     registeredLock.synchronized {
@@ -204,8 +141,8 @@ class MesosScheduler(
       registeredLock.notifyAll()
     }
   }
-  
-  override def waitForRegister() {
+
+  def waitForRegister() {
     registeredLock.synchronized {
       while (!isRegistered) {
         registeredLock.wait()
@@ -218,229 +155,128 @@ class MesosScheduler(
   override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {}
 
   /**
-   * Method called by Mesos to offer resources on slaves. We resond by asking our active task sets 
+   * Method called by Mesos to offer resources on slaves. We respond by asking our active task sets
    * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that
    * tasks are balanced across the cluster.
    */
   override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) {
     synchronized {
-      // Mark each slave as alive and remember its hostname
-      for (o <- offers) {
-        slaveIdToHost(o.getSlaveId.getValue) = o.getHostname
-        hostsAlive += o.getHostname
-      }
-      // Build a list of tasks to assign to each slave
-      val tasks = offers.map(o => new JArrayList[MTaskInfo])
-      val availableCpus = offers.map(o => getResource(o.getResourcesList(), "cpus"))
-      val enoughMem = offers.map(o => {
-        val mem = getResource(o.getResourcesList(), "mem")
+      // Build a big list of the offerable workers, and remember their indices so that we can
+      // figure out which Offer to reply to for each worker
+      val offerableIndices = new ArrayBuffer[Int]
+      val offerableWorkers = new ArrayBuffer[WorkerOffer]
+
+      def enoughMemory(o: Offer) = {
+        val mem = getResource(o.getResourcesList, "mem")
         val slaveId = o.getSlaveId.getValue
         mem >= EXECUTOR_MEMORY || slaveIdsWithExecutors.contains(slaveId)
-      })
-      var launchedTask = false
-      for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) {
-        do {
-          launchedTask = false
-          for (i <- 0 until offers.size if enoughMem(i)) {
-            val sid = offers(i).getSlaveId.getValue
-            val host = offers(i).getHostname
-            manager.slaveOffer(sid, host, availableCpus(i)) match {
-              case Some(task) => 
-                tasks(i).add(task)
-                val tid = task.getTaskId.getValue
-                taskIdToTaskSetId(tid) = manager.taskSet.id
-                taskSetTaskIds(manager.taskSet.id) += tid
-                taskIdToSlaveId(tid) = sid
-                slaveIdsWithExecutors += sid
-                availableCpus(i) -= getResource(task.getResourcesList(), "cpus")
-                launchedTask = true
-                
-              case None => {}
-            }
+      }
+
+      for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) {
+        offerableIndices += index
+        offerableWorkers += new WorkerOffer(
+          offer.getSlaveId.getValue,
+          offer.getHostname,
+          getResource(offer.getResourcesList, "cpus").toInt)
+      }
+
+      // Call into the ClusterScheduler
+      val taskLists = scheduler.resourceOffers(offerableWorkers)
+
+      // Build a list of Mesos tasks for each slave
+      val mesosTasks = offers.map(o => Collections.emptyList[MesosTaskInfo]())
+      for ((taskList, index) <- taskLists.zipWithIndex) {
+        if (!taskList.isEmpty) {
+          val offerNum = offerableIndices(index)
+          mesosTasks(offerNum) = new JArrayList[MesosTaskInfo](taskList.size)
+          for (taskDesc <- taskList) {
+            taskIdToSlaveId(taskDesc.taskId) = offers(offerNum).getSlaveId.getValue
+            mesosTasks(offerNum).add(createMesosTask(taskDesc, offers(offerNum).getSlaveId))
           }
-        } while (launchedTask)
+        }
       }
+
+      // Reply to the offers
       val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout?
       for (i <- 0 until offers.size) {
-        d.launchTasks(offers(i).getId(), tasks(i), filters)
+        d.launchTasks(offers(i).getId, mesosTasks(i), filters)
       }
     }
   }
 
-  // Helper function to pull out a resource from a Mesos Resources protobuf
+  /** Helper function to pull out a resource from a Mesos Resources protobuf */
   def getResource(res: JList[Resource], name: String): Double = {
     for (r <- res if r.getName == name) {
       return r.getScalar.getValue
     }
-    
+    // If we reached here, no resource with the required name was present
     throw new IllegalArgumentException("No resource called " + name + " in " + res)
   }
 
-  // Check whether a Mesos task state represents a finished task
-  def isFinished(state: TaskState) = {
-    state == TaskState.TASK_FINISHED ||
-    state == TaskState.TASK_FAILED ||
-    state == TaskState.TASK_KILLED ||
-    state == TaskState.TASK_LOST
+  /** Turn a Spark TaskDescription into a Mesos task */
+  def createMesosTask(task: TaskDescription, slaveId: SlaveID): MesosTaskInfo = {
+    val taskId = TaskID.newBuilder().setValue(task.taskId.toString).build()
+    val cpuResource = Resource.newBuilder()
+      .setName("cpus")
+      .setType(Value.Type.SCALAR)
+      .setScalar(Value.Scalar.newBuilder().setValue(1).build())
+      .build()
+    return MesosTaskInfo.newBuilder()
+      .setTaskId(taskId)
+      .setSlaveId(slaveId)
+      .setExecutor(executorInfo)
+      .setName(task.name)
+      .addResources(cpuResource)
+      .setData(ByteString.copyFrom(task.serializedTask))
+      .build()
+  }
+
+  /** Check whether a Mesos task state represents a finished task */
+  def isFinished(state: MesosTaskState) = {
+    state == MesosTaskState.TASK_FINISHED ||
+      state == MesosTaskState.TASK_FAILED ||
+      state == MesosTaskState.TASK_KILLED ||
+      state == MesosTaskState.TASK_LOST
   }
 
   override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
-    val tid = status.getTaskId.getValue
-    var taskSetToUpdate: Option[TaskSetManager] = None
-    var failedHost: Option[String] = None
-    var taskFailed = false
+    val tid = status.getTaskId.getValue.toLong
+    val state = TaskState.fromMesos(status.getState)
     synchronized {
-      try {
-        if (status.getState == TaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) {
-          // We lost the executor on this slave, so remember that it's gone
-          val slaveId = taskIdToSlaveId(tid)
-          val host = slaveIdToHost(slaveId)
-          if (hostsAlive.contains(host)) {
-            slaveIdsWithExecutors -= slaveId
-            hostsAlive -= host
-            activeTaskSetsQueue.foreach(_.hostLost(host))
-            failedHost = Some(host)
-          }
-        }
-        taskIdToTaskSetId.get(tid) match {
-          case Some(taskSetId) =>
-            if (activeTaskSets.contains(taskSetId)) {
-              //activeTaskSets(taskSetId).statusUpdate(status)
-              taskSetToUpdate = Some(activeTaskSets(taskSetId))
-            }
-            if (isFinished(status.getState)) {
-              taskIdToTaskSetId.remove(tid)
-              if (taskSetTaskIds.contains(taskSetId)) {
-                taskSetTaskIds(taskSetId) -= tid
-              }
-              taskIdToSlaveId.remove(tid)
-            }
-            if (status.getState == TaskState.TASK_FAILED) {
-              taskFailed = true
-            }
-          case None =>
-            logInfo("Ignoring update from TID " + tid + " because its task set is gone")
-        }
-      } catch {
-        case e: Exception => logError("Exception in statusUpdate", e)
+      if (status.getState == MesosTaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) {
+        // We lost the executor on this slave, so remember that it's gone
+        slaveIdsWithExecutors -= taskIdToSlaveId(tid)
+      }
+      if (isFinished(status.getState)) {
+        taskIdToSlaveId.remove(tid)
       }
     }
-    // Update the task set and DAGScheduler without holding a lock on this, because that can deadlock
-    if (taskSetToUpdate != None) {
-      taskSetToUpdate.get.statusUpdate(status)
-    }
-    if (failedHost != None) {
-      listener.hostLost(failedHost.get)
-      reviveOffers()
-    }
-    if (taskFailed) {
-      // Also revive offers if a task had failed for some reason other than host lost
-      reviveOffers()
-    }
+    scheduler.statusUpdate(tid, state, status.getData.asReadOnlyByteBuffer)
   }
 
   override def error(d: SchedulerDriver, message: String) {
     logError("Mesos error: " + message)
-    synchronized {
-      if (activeTaskSets.size > 0) {
-        // Have each task set throw a SparkException with the error
-        for ((taskSetId, manager) <- activeTaskSets) {
-          try {
-            manager.error(message)
-          } catch {
-            case e: Exception => logError("Exception in error callback", e)
-          }
-        }
-      } else {
-        // No task sets are active but we still got an error. Just exit since this
-        // must mean the error is during registration.
-        // It might be good to do something smarter here in the future.
-        System.exit(1)
-      }
-    }
+    scheduler.error(message)
   }
 
   override def stop() {
     if (driver != null) {
       driver.stop()
     }
-    if (jarServer != null) {
-      jarServer.stop()
-    }
-  }
-
-  // TODO: query Mesos for number of cores
-  override def defaultParallelism() =
-    System.getProperty("spark.default.parallelism", "8").toInt
-
-  // Create a server for all the JARs added by the user to SparkContext.
-  // We first copy the JARs to a temp directory for easier server setup.
-  private def createJarServer() {
-    val jarDir = Utils.createTempDir()
-    logInfo("Temp directory for JARs: " + jarDir)
-    val filenames = ArrayBuffer[String]()
-    // Copy each JAR to a unique filename in the jarDir
-    for ((path, index) <- sc.jars.zipWithIndex) {
-      val file = new File(path)
-      if (file.exists) {
-        val filename = index + "_" + file.getName
-        copyFile(file, new File(jarDir, filename))
-        filenames += filename
-      }
-    }
-    // Create the server
-    jarServer = new HttpServer(jarDir)
-    jarServer.start()
-    // Build up the jar URI list
-    val serverUri = jarServer.uri
-    jarUris = filenames.map(f => serverUri + "/" + f).mkString(",")
-    logInfo("JAR server started at " + serverUri)
-  }
-
-  // Copy a file on the local file system
-  private def copyFile(source: File, dest: File) {
-    val in = new FileInputStream(source)
-    val out = new FileOutputStream(dest)
-    Utils.copyStream(in, out, true)
   }
 
-  // Create and serialize the executor argument to pass to Mesos.
-  // Our executor arg is an array containing all the spark.* system properties
-  // in the form of (String, String) pairs.
-  private def createExecArg(): Array[Byte] = {
-    val props = new HashMap[String, String]
-    val iter = System.getProperties.entrySet.iterator
-    while (iter.hasNext) {
-      val entry = iter.next
-      val (key, value) = (entry.getKey.toString, entry.getValue.toString)
-      if (key.startsWith("spark.")) {
-        props(key) = value
-      }
-    }
-    // Set spark.jar.uris to our JAR URIs, regardless of system property
-    props("spark.jar.uris") = jarUris
-    // Serialize the map as an array of (String, String) pairs
-    return Utils.serialize(props.toArray)
+  override def reviveOffers() {
+    driver.reviveOffers()
   }
 
   override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {}
 
-  override def slaveLost(d: SchedulerDriver, s: SlaveID) {
-    var failedHost: Option[String] = None
+  override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) {
+    logInfo("Mesos slave lost: " + slaveId.getValue)
     synchronized {
-      val slaveId = s.getValue
-      val host = slaveIdToHost(slaveId)
-      if (hostsAlive.contains(host)) {
-        slaveIdsWithExecutors -= slaveId
-        hostsAlive -= host
-        activeTaskSetsQueue.foreach(_.hostLost(host))
-        failedHost = Some(host)
-      }
-    }
-    if (failedHost != None) {
-      listener.hostLost(failedHost.get)
-      reviveOffers()
+      slaveIdsWithExecutors -= slaveId.getValue
     }
+    scheduler.slaveLost(slaveId.getValue)
   }
 
   override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) {
@@ -448,22 +284,6 @@ class MesosScheduler(
     slaveLost(d, s)
   }
 
-  override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
-
-  // Check for speculatable tasks in all our active jobs.
-  def checkSpeculatableTasks() {
-    var shouldRevive = false
-    synchronized {
-      for (ts <- activeTaskSetsQueue) {
-        shouldRevive |= ts.checkSpeculatableTasks()
-      }
-    }
-    if (shouldRevive) {
-      reviveOffers()
-    }
-  }
-
-  def reviveOffers() {
-    driver.reviveOffers()
-  }
+  // TODO: query Mesos for number of cores
+  override def defaultParallelism() = System.getProperty("spark.default.parallelism", "8").toInt
 }
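
The net effect of this file's rewrite: MesosScheduler no longer owns task-set state, it implements a narrow backend interface and delegates bookkeeping to ClusterScheduler. Reconstructed from the overrides above (the actual ClusterSchedulerContext.scala is not shown in this section), the seam looks roughly like:

    package spark.scheduler.cluster

    // Approximate backend interface: ClusterScheduler keeps the
    // TaskSetManager bookkeeping and calls one of these to talk to the
    // underlying cluster manager. Method set inferred from the overrides
    // in MesosScheduler above.
    trait ClusterSchedulerContext {
      def start(): Unit
      def stop(): Unit
      def reviveOffers(): Unit        // ask the backend to re-offer resources
      def defaultParallelism(): Int   // default partition count hint
    }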
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 0a807f0582..15131960d6 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -61,7 +61,7 @@ class BlockLocker(numLockers: Int) {
   private val hashLocker = Array.fill(numLockers)(new Object())
   
   def getLock(blockId: String): Object = {
-    return hashLocker(Math.abs(blockId.hashCode % numLockers))
+    return hashLocker(math.abs(blockId.hashCode % numLockers))
   }
 }
 
@@ -312,7 +312,7 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging
     // wait for and gather all the remote blocks
     for ((cmId, future) <- remoteBlockFutures) {
       var count = 0
-      val oneBlockId = remoteBlockIdsPerLocation(new BlockManagerId(cmId.host, cmId.port)).first
+      val oneBlockId = remoteBlockIdsPerLocation(new BlockManagerId(cmId.host, cmId.port)).head
       future() match {
         case Some(message) => {
           val bufferMessage = message.asInstanceOf[BufferMessage]
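
On the getLock change above: swapping Math.abs for math.abs is cosmetic, but the abs itself matters because String.hashCode can be negative and % keeps the sign of the dividend, so the stripe index could otherwise go negative. A defensive variant that also survives the Int.MinValue corner case (where math.abs still returns a negative number) would be:

    // Sketch: non-negative lock-stripe index for any hashCode value.
    def lockIndex(blockId: String, numLockers: Int): Int = {
      val m = blockId.hashCode % numLockers
      if (m < 0) m + numLockers else m
    }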
diff --git a/repl/src/main/scala/spark/repl/SparkILoop.scala b/repl/src/main/scala/spark/repl/SparkILoop.scala
index 935790a091..faece8baa4 100644
--- a/repl/src/main/scala/spark/repl/SparkILoop.scala
+++ b/repl/src/main/scala/spark/repl/SparkILoop.scala
@@ -819,7 +819,6 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
         spark.repl.Main.interp.out.println("Creating SparkContext...");
         spark.repl.Main.interp.out.flush();
         @transient val sc = spark.repl.Main.interp.createSparkContext();
-        sc.waitForRegister();
         spark.repl.Main.interp.out.println("Spark context available as sc.");
         spark.repl.Main.interp.out.flush();
         """)
diff --git a/spark-executor b/spark-executor
index 0f9b9b1ece..2d6934f7da 100755
--- a/spark-executor
+++ b/spark-executor
@@ -1,4 +1,4 @@
 #!/bin/sh
 FWDIR="`dirname $0`"
 echo "Running spark-executor with framework dir = $FWDIR"
-exec $FWDIR/run spark.Executor
+exec $FWDIR/run spark.executor.MesosExecutorRunner
-- 
GitLab