diff --git a/core/src/main/scala/spark/HttpFileServer.scala b/core/src/main/scala/spark/HttpFileServer.scala new file mode 100644 index 0000000000000000000000000000000000000000..e6ad4dd28edd10fe4168833c247ebb529783fd87 --- /dev/null +++ b/core/src/main/scala/spark/HttpFileServer.scala @@ -0,0 +1,45 @@ +package spark + +import java.io.{File, PrintWriter} +import java.net.URL +import scala.collection.mutable.HashMap +import org.apache.hadoop.fs.FileUtil + +class HttpFileServer extends Logging { + + var baseDir : File = null + var fileDir : File = null + var jarDir : File = null + var httpServer : HttpServer = null + var serverUri : String = null + + def initialize() { + baseDir = Utils.createTempDir() + fileDir = new File(baseDir, "files") + jarDir = new File(baseDir, "jars") + fileDir.mkdir() + jarDir.mkdir() + logInfo("HTTP File server directory is " + baseDir) + httpServer = new HttpServer(fileDir) + httpServer.start() + serverUri = httpServer.uri + } + + def stop() { + httpServer.stop() + } + + def addFile(file: File) : String = { + return addFileToDir(file, fileDir) + } + + def addJar(file: File) : String = { + return addFileToDir(file, jarDir) + } + + def addFileToDir(file: File, dir: File) : String = { + Utils.copyFile(file, new File(dir, file.getName)) + return dir + "/" + file.getName + } + +} \ No newline at end of file diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 0dec44979f6030c69a2e3b90c4bd0a3e90a031fa..758c42fa61f17d0bdb37c550f475d3f3205939bb 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -2,14 +2,15 @@ package spark import java.io._ import java.util.concurrent.atomic.AtomicInteger +import java.net.{URI, URLClassLoader} import akka.actor.Actor import akka.actor.Actor._ -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.collection.generic.Growable -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileUtil, Path} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.mapred.SequenceFileInputFormat @@ -77,7 +78,14 @@ class SparkContext( true, isLocal) SparkEnv.set(env) - + + // Used to store a URL for each static file/jar together with the file's local timestamp + val addedFiles = HashMap[String, Long]() + val addedJars = HashMap[String, Long]() + + // Add each JAR given through the constructor + jars.foreach { addJar(_) } + // Create and start the scheduler private var taskScheduler: TaskScheduler = { // Regular expression used for local[N] master format @@ -91,13 +99,13 @@ class SparkContext( master match { case "local" => - new LocalScheduler(1, 0) + new LocalScheduler(1, 0, this) case LOCAL_N_REGEX(threads) => - new LocalScheduler(threads.toInt, 0) + new LocalScheduler(threads.toInt, 0, this) case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => - new LocalScheduler(threads.toInt, maxFailures.toInt) + new LocalScheduler(threads.toInt, maxFailures.toInt, this) case SPARK_REGEX(sparkUrl) => val scheduler = new ClusterScheduler(this) @@ -132,7 +140,7 @@ class SparkContext( taskScheduler.start() private var dagScheduler = new DAGScheduler(taskScheduler) - + // Methods for creating RDDs def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism ): RDD[T] = { @@ -321,7 +329,44 @@ class SparkContext( // Keep around a weak hash map of values to Cached versions? 
def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal) + + // Adds a file dependency to all Tasks executed in the future. + def addFile(path: String) { + val uri = new URI(path) + val key = uri.getScheme match { + case null | "file" => env.httpFileServer.addFile(new File(uri.getPath)) + case _ => path + } + addedFiles(key) = System.currentTimeMillis + + // Fetch the file locally in case the task is executed locally + val filename = new File(path.split("/").last) + Utils.fetchFile(path, new File(".")) + + logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) + } + def clearFiles() { + addedFiles.keySet.map(_.split("/").last).foreach { k => new File(k).delete() } + addedFiles.clear() + } + + // Adds a jar dependency to all Tasks executed in the future. + def addJar(path: String) { + val uri = new URI(path) + val key = uri.getScheme match { + case null | "file" => env.httpFileServer.addJar(new File(uri.getPath)) + case _ => path + } + addedJars(key) = System.currentTimeMillis + logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key)) + } + + def clearJars() { + addedJars.keySet.map(_.split("/").last).foreach { k => new File(k).delete() } + addedJars.clear() + } + // Stop the SparkContext def stop() { dagScheduler.stop() @@ -329,6 +374,9 @@ class SparkContext( taskScheduler = null // TODO: Cache.stop()? env.stop() + // Clean up locally linked files + clearFiles() + clearJars() SparkEnv.set(null) ShuffleMapTask.clearCache() logInfo("Successfully stopped SparkContext") diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index add8fcec51e65e174b0dc28b5b3e9357fa85b7d2..a95d1bc8ea8cb9a1f7fd2b6f00d0fe88d0bce81b 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -19,15 +19,17 @@ class SparkEnv ( val shuffleManager: ShuffleManager, val broadcastManager: BroadcastManager, val blockManager: BlockManager, - val connectionManager: ConnectionManager + val connectionManager: ConnectionManager, + val httpFileServer: HttpFileServer ) { /** No-parameter constructor for unit tests. 
*/ def this() = { - this(null, null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null) + this(null, null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null, null) } def stop() { + httpFileServer.stop() mapOutputTracker.stop() cacheTracker.stop() shuffleFetcher.stop() @@ -95,7 +97,11 @@ object SparkEnv { System.getProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher") val shuffleFetcher = Class.forName(shuffleFetcherClass).newInstance().asInstanceOf[ShuffleFetcher] - + + val httpFileServer = new HttpFileServer() + httpFileServer.initialize() + System.setProperty("spark.fileserver.uri", httpFileServer.serverUri) + /* if (System.getProperty("spark.stream.distributed", "false") == "true") { val blockManagerClass = classOf[spark.storage.BlockManager].asInstanceOf[Class[_]] @@ -126,6 +132,7 @@ object SparkEnv { shuffleManager, broadcastManager, blockManager, - connectionManager) + connectionManager, + httpFileServer) } } diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 5eda1011f9f7cb42b093341ca4dd02719fdcfe6c..07aa18e5404cc9411421316bb5ab13b13b324707 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -1,18 +1,19 @@ package spark import java.io._ -import java.net.InetAddress +import java.net.{InetAddress, URL, URI} +import java.util.{Locale, UUID} import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} - +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} import scala.collection.mutable.ArrayBuffer import scala.util.Random -import java.util.{Locale, UUID} import scala.io.Source /** * Various utility methods used by Spark. */ -object Utils { +object Utils extends Logging { /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -115,6 +116,54 @@ object Utils { val out = new FileOutputStream(dest) copyStream(in, out, true) } + + + + /* Download a file from a given URL to the local filesystem */ + def downloadFile(url: URL, localPath: String) { + val in = url.openStream() + val out = new FileOutputStream(localPath) + Utils.copyStream(in, out, true) + } + + /** + * Download a file requested by the executor. Supports fetching the file in a variety of ways, + * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. 
+ */ + def fetchFile(url: String, targetDir: File) { + val filename = url.split("/").last + val targetFile = new File(targetDir, filename) + val uri = new URI(url) + uri.getScheme match { + case "http" | "https" | "ftp" => + logInfo("Fetching " + url + " to " + targetFile) + val in = new URL(url).openStream() + val out = new FileOutputStream(targetFile) + Utils.copyStream(in, out, true) + case "file" | null => + // Remove the file if it already exists + targetFile.delete() + // Symlink the file locally + logInfo("Symlinking " + url + " to " + targetFile) + FileUtil.symLink(url, targetFile.toString) + case _ => + // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others + val uri = new URI(url) + val conf = new Configuration() + val fs = FileSystem.get(uri, conf) + val in = fs.open(new Path(uri)) + val out = new FileOutputStream(targetFile) + Utils.copyStream(in, out, true) + } + // Decompress the file if it's a .tar or .tar.gz + if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) { + logInfo("Untarring " + filename) + Utils.execute(Seq("tar", "-xzf", filename), targetDir) + } else if (filename.endsWith(".tar")) { + logInfo("Untarring " + filename) + Utils.execute(Seq("tar", "-xf", filename), targetDir) + } + } /** * Shuffle the elements of a collection into a random order, returning the diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index 1740a42a7eff69bffd4ccbe3cfe881d534a8793f..704336102019cf706c011911a508d33865fd70a6 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -65,38 +65,6 @@ class ExecutorRunner( } } - /** - * Download a file requested by the executor. Supports fetching the file in a variety of ways, - * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. 
- */ - def fetchFile(url: String, targetDir: File) { - val filename = url.split("/").last - val targetFile = new File(targetDir, filename) - if (url.startsWith("http://") || url.startsWith("https://") || url.startsWith("ftp://")) { - // Use the java.net library to fetch it - logInfo("Fetching " + url + " to " + targetFile) - val in = new URL(url).openStream() - val out = new FileOutputStream(targetFile) - Utils.copyStream(in, out, true) - } else { - // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others - val uri = new URI(url) - val conf = new Configuration() - val fs = FileSystem.get(uri, conf) - val in = fs.open(new Path(uri)) - val out = new FileOutputStream(targetFile) - Utils.copyStream(in, out, true) - } - // Decompress the file if it's a .tar or .tar.gz - if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) { - logInfo("Untarring " + filename) - Utils.execute(Seq("tar", "-xzf", filename), targetDir) - } else if (filename.endsWith(".tar")) { - logInfo("Untarring " + filename) - Utils.execute(Seq("tar", "-xf", filename), targetDir) - } - } - /** Replace variables such as {{SLAVEID}} and {{CORES}} in a command argument passed to us */ def substituteVariables(argument: String): String = argument match { case "{{SLAVEID}}" => workerId diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index dba209ac2726febf78e774ef1cc8b7fd6e7903a3..8f975c52d4ccc7416e8334838555e8126b67f0e9 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -1,10 +1,12 @@ package spark.executor import java.io.{File, FileOutputStream} -import java.net.{URL, URLClassLoader} +import java.net.{URI, URL, URLClassLoader} import java.util.concurrent._ -import scala.collection.mutable.ArrayBuffer +import org.apache.hadoop.fs.FileUtil + +import scala.collection.mutable.{ArrayBuffer, Map, HashMap} import spark.broadcast._ import spark.scheduler._ @@ -15,9 +17,13 @@ import java.nio.ByteBuffer * The Mesos executor for Spark. 
*/ class Executor extends Logging { - var classLoader: ClassLoader = null + var urlClassLoader : ExecutorURLClassLoader = null var threadPool: ExecutorService = null var env: SparkEnv = null + + val fileSet: HashMap[String, Long] = new HashMap[String, Long]() + val jarSet: HashMap[String, Long] = new HashMap[String, Long]() + val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0)) @@ -36,13 +42,14 @@ class Executor extends Logging { env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false) SparkEnv.set(env) - // Create our ClassLoader (using spark properties) and set it on this thread - classLoader = createClassLoader() - Thread.currentThread.setContextClassLoader(classLoader) - // Start worker thread pool threadPool = new ThreadPoolExecutor( 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable]) + + // Create our ClassLoader and set it on this thread + urlClassLoader = createClassLoader() + Thread.currentThread.setContextClassLoader(urlClassLoader) + } def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) { @@ -54,15 +61,16 @@ class Executor extends Logging { override def run() { SparkEnv.set(env) - Thread.currentThread.setContextClassLoader(classLoader) + Thread.currentThread.setContextClassLoader(urlClassLoader) val ser = SparkEnv.get.closureSerializer.newInstance() logInfo("Running task ID " + taskId) context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) try { SparkEnv.set(env) - Thread.currentThread.setContextClassLoader(classLoader) Accumulators.clear() - val task = ser.deserialize[Task[Any]](serializedTask, classLoader) + val task = ser.deserialize[Task[Any]](serializedTask, urlClassLoader) + task.downloadDependencies(fileSet, jarSet) + updateClassLoader() logInfo("Its generation is " + task.generation) env.mapOutputTracker.updateGeneration(task.generation) val value = task.run(taskId.toInt) @@ -96,25 +104,16 @@ class Executor extends Logging { * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes * created by the interpreter to the search path */ - private def createClassLoader(): ClassLoader = { - var loader = this.getClass.getClassLoader - - // If any JAR URIs are given through spark.jar.uris, fetch them to the - // current directory and put them all on the classpath. We assume that - // each URL has a unique file name so that no local filenames will clash - // in this process. This is guaranteed by ClusterScheduler. - val uris = System.getProperty("spark.jar.uris", "") - val localFiles = ArrayBuffer[String]() - for (uri <- uris.split(",").filter(_.size > 0)) { - val url = new URL(uri) - val filename = url.getPath.split("/").last - downloadFile(url, filename) - localFiles += filename - } - if (localFiles.size > 0) { - val urls = localFiles.map(f => new File(f).toURI.toURL).toArray - loader = new URLClassLoader(urls, loader) - } + private def createClassLoader(): ExecutorURLClassLoader = { + + var loader = this.getClass().getClassLoader() + + // For each of the jars in the jarSet, add them to the class loader. + // We assume each of the files has already been fetched. 
+ val urls = jarSet.keySet.map { uri => + new File(uri.split("/").last).toURI.toURL + }.toArray + loader = new URLClassLoader(urls, loader) // If the REPL is in use, add another ClassLoader that will read // new classes defined by the REPL as the user types code @@ -133,13 +132,25 @@ class Executor extends Logging { } } - return loader + return new ExecutorURLClassLoader(Array(), loader) } - // Download a file from a given URL to the local filesystem - private def downloadFile(url: URL, localPath: String) { - val in = url.openStream() - val out = new FileOutputStream(localPath) - Utils.copyStream(in, out, true) + def updateClassLoader() { + val currentURLs = urlClassLoader.getURLs() + val urlSet = jarSet.keySet.map { x => new File(x.split("/").last).toURI.toURL } + urlSet.filterNot(currentURLs.contains(_)).foreach { url => + logInfo("Adding " + url + " to the class loader.") + urlClassLoader.addURL(url) + } + } + + // The addURL method in URLClassLoader is protected. We subclass it to make it accessible. + class ExecutorURLClassLoader(urls : Array[URL], parent : ClassLoader) + extends URLClassLoader(urls, parent) { + override def addURL(url: URL) { + super.addURL(url) + } + } + } diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index a281ae94c5e5ad5cfc60ca6eab410524c2ae01c4..b9f0a0d6d0c7f294ea6eb955d8daa74990faa056 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -1,10 +1,10 @@ package spark.scheduler import java.io._ -import java.util.HashMap +import java.util.{HashMap => JHashMap} import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.collection.JavaConversions._ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream @@ -20,7 +20,9 @@ object ShuffleMapTask { // A simple map between the stage id to the serialized byte array of a task. // Served as a cache for task serialization because serialization can be // expensive on the master node if it needs to launch thousands of tasks. - val serializedInfoCache = new HashMap[Int, Array[Byte]] + val serializedInfoCache = new JHashMap[Int, Array[Byte]] + val fileSetCache = new JHashMap[Int, Array[Byte]] + val jarSetCache = new JHashMap[Int, Array[Byte]] def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_]): Array[Byte] = { synchronized { @@ -40,6 +42,23 @@ object ShuffleMapTask { } } + // Since both the JarSet and FileSet have the same format this is used for both. + def serializeFileSet(set : HashMap[String, Long], stageId: Int, cache : JHashMap[Int, Array[Byte]]) : Array[Byte] = { + val old = cache.get(stageId) + if (old != null) { + return old + } else { + val out = new ByteArrayOutputStream + val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) + objOut.writeObject(set.toArray) + objOut.close() + val bytes = out.toByteArray + cache.put(stageId, bytes) + return bytes + } + } + + def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_,_]) = { synchronized { val loader = Thread.currentThread.getContextClassLoader @@ -54,9 +73,19 @@ object ShuffleMapTask { } } + // Since both the JarSet and FileSet have the same format this is used for both. 
+ def deserializeFileSet(bytes: Array[Byte]) : HashMap[String, Long] = { + val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) + val objIn = new ObjectInputStream(in) + val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap + return (HashMap(set.toSeq: _*)) + } + def clearCache() { synchronized { serializedInfoCache.clear() + fileSetCache.clear() + jarSetCache.clear() } } } @@ -84,6 +113,14 @@ class ShuffleMapTask( val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) out.writeInt(bytes.length) out.write(bytes) + + val fileSetBytes = ShuffleMapTask.serializeFileSet(fileSet, stageId, ShuffleMapTask.fileSetCache) + out.writeInt(fileSetBytes.length) + out.write(fileSetBytes) + val jarSetBytes = ShuffleMapTask.serializeFileSet(jarSet, stageId, ShuffleMapTask.jarSetCache) + out.writeInt(jarSetBytes.length) + out.write(jarSetBytes) + out.writeInt(partition) out.writeLong(generation) out.writeObject(split) @@ -97,6 +134,17 @@ class ShuffleMapTask( val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes) rdd = rdd_ dep = dep_ + + val fileSetNumBytes = in.readInt() + val fileSetBytes = new Array[Byte](fileSetNumBytes) + in.readFully(fileSetBytes) + fileSet = ShuffleMapTask.deserializeFileSet(fileSetBytes) + + val jarSetNumBytes = in.readInt() + val jarSetBytes = new Array[Byte](jarSetNumBytes) + in.readFully(jarSetBytes) + jarSet = ShuffleMapTask.deserializeFileSet(jarSetBytes) + partition = in.readInt() generation = in.readLong() split = in.readObject().asInstanceOf[Split] @@ -110,7 +158,7 @@ class ShuffleMapTask( val bucketIterators = if (aggregator.mapSideCombine) { // Apply combiners (map-side aggregation) to the map output. - val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any]) + val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any]) for (elem <- rdd.iterator(split)) { val (k, v) = elem.asInstanceOf[(Any, Any)] val bucketId = partitioner.getPartition(k) diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala index f84d8d9c4f5c904ba959e4b6a5cbcd25fa531a1b..0d5b71b06c1b7383a40055f1d87223200c28df83 100644 --- a/core/src/main/scala/spark/scheduler/Task.scala +++ b/core/src/main/scala/spark/scheduler/Task.scala @@ -1,5 +1,10 @@ package spark.scheduler +import scala.collection.mutable.{HashMap} +import spark.HttpFileServer +import spark.Utils +import java.io.File + /** * A task to execute on a worker node. */ @@ -8,4 +13,30 @@ abstract class Task[T](val stageId: Int) extends Serializable { def preferredLocations: Seq[String] = Nil var generation: Long = -1 // Map output tracker generation. Will be set by TaskScheduler. + + // Stores jar and file dependencies for this task. 
+ var fileSet : HashMap[String, Long] = new HashMap[String, Long]() + var jarSet : HashMap[String, Long] = new HashMap[String, Long]() + + // Downloads all file dependencies from the Master file server + def downloadDependencies(currentFileSet : HashMap[String, Long], + currentJarSet : HashMap[String, Long]) { + + // Fetch missing file dependencies + fileSet.filter { case(k,v) => + !currentFileSet.contains(k) || currentFileSet(k) <= v + }.foreach { case (k,v) => + Utils.fetchFile(k, new File(System.getProperty("user.dir"))) + currentFileSet(k) = v + } + // Fetch missing jar dependencies + jarSet.filter { case(k,v) => + !currentJarSet.contains(k) || currentJarSet(k) <= v + }.foreach { case (k,v) => + Utils.fetchFile(k, new File(System.getProperty("user.dir"))) + currentJarSet(k) = v + } + + } + } diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 5b59479682f2b48cef2d93da39f75021b1552ca9..750231ac31e74c6f4e0a1080191d5f21451992a3 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -60,7 +60,6 @@ class ClusterScheduler(sc: SparkContext) def initialize(context: SchedulerBackend) { backend = context - createJarServer() } def newTaskId(): Long = nextTaskId.getAndIncrement() @@ -88,6 +87,10 @@ class ClusterScheduler(sc: SparkContext) def submitTasks(taskSet: TaskSet) { val tasks = taskSet.tasks + tasks.foreach { task => + task.fileSet ++= sc.addedFiles + task.jarSet ++= sc.addedJars + } logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") this.synchronized { val manager = new TaskSetManager(this, taskSet) @@ -235,32 +238,7 @@ class ClusterScheduler(sc: SparkContext) } override def defaultParallelism() = backend.defaultParallelism() - - // Create a server for all the JARs added by the user to SparkContext. - // We first copy the JARs to a temp directory for easier server setup. - private def createJarServer() { - val jarDir = Utils.createTempDir() - logInfo("Temp directory for JARs: " + jarDir) - val filenames = ArrayBuffer[String]() - // Copy each JAR to a unique filename in the jarDir - for ((path, index) <- sc.jars.zipWithIndex) { - val file = new File(path) - if (file.exists) { - val filename = index + "_" + file.getName - Utils.copyFile(file, new File(jarDir, filename)) - filenames += filename - } - } - // Create the server - jarServer = new HttpServer(jarDir) - jarServer.start() - // Build up the jar URI list - val serverUri = jarServer.uri - jarUris = filenames.map(f => serverUri + "/" + f).mkString(",") - System.setProperty("spark.jar.uris", jarUris) - logInfo("JAR server started at " + serverUri) - } - + // Check for speculatable tasks in all our active jobs. 
def checkSpeculatableTasks() { var shouldRevive = false diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index eb47988f0cdfb000936c17a13985dae317147e4d..65078b026e0c2b90a204284c9db9c9a35bbeece0 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -1,7 +1,10 @@ package spark.scheduler.local +import java.io.File +import java.net.URLClassLoader import java.util.concurrent.Executors import java.util.concurrent.atomic.AtomicInteger +import scala.collection.mutable.HashMap import spark._ import spark.scheduler._ @@ -11,15 +14,17 @@ import spark.scheduler._ * the scheduler also allows each task to fail up to maxFailures times, which is useful for * testing fault recovery. */ -class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with Logging { +class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends TaskScheduler with Logging { var attemptId = new AtomicInteger(0) var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory) val env = SparkEnv.get var listener: TaskSchedulerListener = null - + val fileSet: HashMap[String, Long] = new HashMap[String, Long]() + val jarSet: HashMap[String, Long] = new HashMap[String, Long]() + // TODO: Need to take into account stage priority in scheduling - override def start() {} + override def start() { } override def setListener(listener: TaskSchedulerListener) { this.listener = listener @@ -30,6 +35,8 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with val failCount = new Array[Int](tasks.size) def submitTask(task: Task[_], idInJob: Int) { + task.fileSet ++= sc.addedFiles + task.jarSet ++= sc.addedJars val myAttemptId = attemptId.getAndIncrement() threadPool.submit(new Runnable { def run() { @@ -42,6 +49,9 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with logInfo("Running task " + idInJob) // Set the Spark execution environment for the worker thread SparkEnv.set(env) + task.downloadDependencies(fileSet, jarSet) + // Create a new class loader for the downloaded JARs + Thread.currentThread.setContextClassLoader(createClassLoader()) try { // Serialize and deserialize the task so that accumulators are changed to thread-local ones; // this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
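For reference, the dependency handshake that the schedulers and Task.downloadDependencies implement can be summarized with a small self-contained Scala sketch (not part of the patch): the driver copies sc.addedFiles/sc.addedJars into every Task, and a worker fetches only the entries it has not seen or whose timestamp has changed. The URL and timestamps below are made-up illustrations.

import scala.collection.mutable.HashMap

object DependencySyncSketch {
  def main(args: Array[String]) {
    // What the driver attaches to a task (from sc.addedFiles): URL -> timestamp.
    val shipped = HashMap("http://driver:33333/lookup.txt" -> 2L)
    // What this worker has already fetched locally.
    val fetched = HashMap("http://driver:33333/lookup.txt" -> 1L)

    // Same filter as Task.downloadDependencies: fetch when the key is new or the
    // shipped timestamp is at least as recent as the local one.
    shipped.filter { case (k, v) => !fetched.contains(k) || fetched(k) <= v }
           .foreach { case (k, v) =>
             println("would fetch " + k)  // the patch calls Utils.fetchFile(k, workingDir) here
             fetched(k) = v
           }
  }
}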
@@ -81,9 +91,19 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with } } + override def stop() { threadPool.shutdownNow() } + private def createClassLoader() : ClassLoader = { + val currentLoader = Thread.currentThread.getContextClassLoader() + val urls = jarSet.keySet.map { uri => + new File(uri.split("/").last).toURI.toURL + }.toArray + logInfo("Creating ClassLoader with jars: " + urls.mkString) + return new URLClassLoader(urls, currentLoader) + } + override def defaultParallelism() = threads } diff --git a/core/src/test/resources/uncommons-maths-1.2.2.jar b/core/src/test/resources/uncommons-maths-1.2.2.jar new file mode 100644 index 0000000000000000000000000000000000000000..e126001c1c270aa1b149970d07c53dbe0d13514e Binary files /dev/null and b/core/src/test/resources/uncommons-maths-1.2.2.jar differ diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..500af1eb902bd61a1d492ff1b412738fe51cd99c --- /dev/null +++ b/core/src/test/scala/spark/FileServerSuite.scala @@ -0,0 +1,93 @@ +package spark + +import com.google.common.io.Files +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter +import java.io.{File, PrintWriter} +import SparkContext._ + +class FileServerSuite extends FunSuite with BeforeAndAfter { + + var sc: SparkContext = _ + var tmpFile : File = _ + var testJarFile : File = _ + + before { + // Create a sample text file + val tmpdir = new File(Files.createTempDir(), "test") + tmpdir.mkdir() + tmpFile = new File(tmpdir, "FileServerSuite.txt") + val pw = new PrintWriter(tmpFile) + pw.println("100") + pw.close() + } + + after { + if (sc != null) { + sc.stop() + sc = null + } + // Clean up downloaded file + if (tmpFile.exists) { + tmpFile.delete() + } + } + + test("Distributing files locally") { + sc = new SparkContext("local[4]", "test") + sc.addFile(tmpFile.toString) + val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) + val result = sc.parallelize(testData).reduceByKey { + val in = new java.io.BufferedReader(new java.io.FileReader(tmpFile)) + val fileVal = in.readLine().toInt + in.close() + _ * fileVal + _ * fileVal + }.collect + println(result) + assert(result.toSet === Set((1,200), (2,300), (3,500))) + } + + test ("Dynamically adding JARS locally") { + sc = new SparkContext("local[4]", "test") + val sampleJarFile = getClass().getClassLoader().getResource("uncommons-maths-1.2.2.jar").getFile() + sc.addJar(sampleJarFile) + val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0)) + val result = sc.parallelize(testData).reduceByKey { (x,y) => + val fac = Thread.currentThread.getContextClassLoader().loadClass("org.uncommons.maths.Maths").getDeclaredMethod("factorial", classOf[Int]) + val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt + val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt + a + b + }.collect() + assert(result.toSet === Set((1,2), (2,7), (3,121))) + } + + test("Distributing files on a standalone cluster") { + sc = new SparkContext("local-cluster[1,1,512]", "test") + sc.addFile(tmpFile.toString) + val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) + val result = sc.parallelize(testData).reduceByKey { + val in = new java.io.BufferedReader(new java.io.FileReader(tmpFile)) + val fileVal = in.readLine().toInt + in.close() + _ * fileVal + _ * fileVal + }.collect + println(result) + assert(result.toSet === 
Set((1,200), (2,300), (3,500))) + } + + + test ("Dynamically adding JARS on a standalone cluster") { + sc = new SparkContext("local-cluster[1,1,512]", "test") + val sampleJarFile = getClass().getClassLoader().getResource("uncommons-maths-1.2.2.jar").getFile() + sc.addJar(sampleJarFile) + val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0)) + val result = sc.parallelize(testData).reduceByKey { (x,y) => + val fac = Thread.currentThread.getContextClassLoader().loadClass("org.uncommons.maths.Maths").getDeclaredMethod("factorial", classOf[Int]) + val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt + val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt + a + b + }.collect() + assert(result.toSet === Set((1,2), (2,7), (3,121))) + } + +} \ No newline at end of file
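A minimal driver-side sketch of the API this patch introduces (SparkContext.addFile and addJar). The paths, application name, and object name are illustrative, and it assumes, as Utils.fetchFile and Task.downloadDependencies do, that shipped files land in each task's working directory and can therefore be opened by their plain file name.

import java.io.File
import scala.io.Source
import spark.SparkContext

object AddFileExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "addFileExample")

    // Ship a data file and an extra jar to every task (hypothetical paths).
    sc.addFile("/tmp/lookup.txt")
    sc.addJar("/tmp/extra-algos.jar")

    val total = sc.parallelize(1 to 10).map { i =>
      // The dependency was fetched into the task's working directory,
      // so it is readable under its plain name.
      val src = Source.fromFile(new File("lookup.txt"))
      val factor = src.getLines().next().trim.toInt
      src.close()
      i * factor
    }.reduce(_ + _)

    println("total = " + total)
    // stop() also calls clearFiles()/clearJars() to delete the local copies.
    sc.stop()
  }
}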