diff --git a/README.md b/README.md index 28ad1e4604d1497d8302b48e31bb57bfca2a9797..456b8060ef3280ab077cbcb3f7dcc17a8f67a6bf 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ described below. When developing a Spark application, specify the Hadoop version by adding the "hadoop-client" artifact to your project's dependencies. For example, if you're -using Hadoop 1.0.1 and build your application using SBT, add this entry to +using Hadoop 1.2.1 and build your application using SBT, add this entry to `libraryDependencies`: "org.apache.hadoop" % "hadoop-client" % "1.2.1" diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e8ff4da475a9822bd96c13876a3b70384b4ddc0a..10db2fa7e7c3b34b5c8882ed61ada1bb33e15049 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -24,7 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.Map import scala.collection.generic.Growable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap @@ -144,6 +144,14 @@ class SparkContext( executorEnvs ++= environment } + // Set SPARK_USER for user who is running SparkContext. + val sparkUser = Option { + Option(System.getProperty("user.name")).getOrElse(System.getenv("SPARK_USER")) + }.getOrElse { + SparkContext.SPARK_UNKNOWN_USER + } + executorEnvs("SPARK_USER") = sparkUser + // Create and start the scheduler private[spark] var taskScheduler: TaskScheduler = { // Regular expression used for local[N] master format @@ -255,8 +263,10 @@ class SparkContext( conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) } // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" - for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) { - conf.set(key.substring("spark.hadoop.".length), System.getProperty(key)) + Utils.getSystemProperties.foreach { case (key, value) => + if (key.startsWith("spark.hadoop.")) { + conf.set(key.substring("spark.hadoop.".length), value) + } } val bufferSize = System.getProperty("spark.buffer.size", "65536") conf.set("io.file.buffer.size", bufferSize) @@ -270,6 +280,12 @@ class SparkContext( override protected def childValue(parent: Properties): Properties = new Properties(parent) } + private[spark] def getLocalProperties(): Properties = localProperties.get() + + private[spark] def setLocalProperties(props: Properties) { + localProperties.set(props) + } + def initLocalProperties() { localProperties.set(new Properties()) } @@ -291,7 +307,7 @@ class SparkContext( /** Set a human readable description of the current job. 
*/ @deprecated("use setJobGroup", "0.8.1") def setJobDescription(value: String) { - setJobGroup("", value) + setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value) } /** @@ -589,7 +605,8 @@ class SparkContext( val uri = new URI(path) val key = uri.getScheme match { case null | "file" => env.httpFileServer.addFile(new File(uri.getPath)) - case _ => path + case "local" => "file:" + uri.getPath + case _ => path } addedFiles(key) = System.currentTimeMillis @@ -791,11 +808,10 @@ class SparkContext( val cleanedFunc = clean(func) logInfo("Starting job: " + callSite) val start = System.nanoTime - val result = dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, + dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, resultHandler, localProperties.get) logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") rdd.doCheckpoint() - result } /** @@ -977,6 +993,8 @@ object SparkContext { private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id" + private[spark] val SPARK_UNKNOWN_USER = "<unknown>" + implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 def zero(initialValue: Double) = 0.0 diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 6bc846aa92eb4440bd0c61a20d9f7a2aaf2467b7..fc1537f7963c44b4adb8e2e72c8786b92e6aaf70 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -17,16 +17,39 @@ package org.apache.spark.deploy +import java.security.PrivilegedExceptionAction + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.security.UserGroupInformation -import org.apache.spark.SparkException +import org.apache.spark.{SparkContext, SparkException} /** * Contains util methods to interact with Hadoop from Spark. */ private[spark] class SparkHadoopUtil { + val conf = newConfiguration() + UserGroupInformation.setConfiguration(conf) + + def runAsUser(user: String)(func: () => Unit) { + // if we are already running as the user intended there is no reason to do the doAs. It + // will actually break secure HDFS access as it doesn't fill in the credentials. Also if + // the user is UNKNOWN then we shouldn't be creating a remote unknown user + // (this is actually the path spark on yarn takes) since SPARK_USER is initialized only + // in SparkContext. + val currentUser = Option(System.getProperty("user.name")). + getOrElse(SparkContext.SPARK_UNKNOWN_USER) + if (user != SparkContext.SPARK_UNKNOWN_USER && currentUser != user) { + val ugi = UserGroupInformation.createRemoteUser(user) + ugi.doAs(new PrivilegedExceptionAction[Unit] { + def run: Unit = func() + }) + } else { + func() + } + } /** * Return an appropriate (subclass) of Configuration. 
Creating config can initializes some Hadoop @@ -42,9 +65,9 @@ class SparkHadoopUtil { def isYarnMode(): Boolean = { false } } - + object SparkHadoopUtil { - private val hadoop = { + private val hadoop = { val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) if (yarnMode) { try { @@ -56,7 +79,7 @@ object SparkHadoopUtil { new SparkHadoopUtil } } - + def get: SparkHadoopUtil = { hadoop } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 8fabc956659015c4df1796a918779a954b81705d..fff9cb60c78498b2643af10a311c63b3b85607bb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -104,7 +104,7 @@ private[spark] class ExecutorRunner( // SPARK-698: do not call the run.cmd script, as process.destroy() // fails to kill a process tree on Windows Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++ - command.arguments.map(substituteVariables) + (command.arguments ++ Seq(appId)).map(substituteVariables) } /** diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 80ff4c59cb484f33e3148d95cc0c9796d71fdaab..caee6b01ab1fa57ae5edb33ab5716cc88fbacf29 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -111,7 +111,7 @@ private[spark] object CoarseGrainedExecutorBackend { def main(args: Array[String]) { if (args.length < 4) { - //the reason we allow the last frameworkId argument is to make it easy to kill rogue executors + //the reason we allow the last appid argument is to make it easy to kill rogue executors System.err.println( "Usage: CoarseGrainedExecutorBackend <driverUrl> <executorId> <hostname> <cores> " + "[<appid>]") diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index b773346df305d598a54ebe17be18b7626f4d3025..5c9bb9db1ce9e9269f394f045a67980826db3c69 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -25,8 +25,9 @@ import java.util.concurrent._ import scala.collection.JavaConversions._ import scala.collection.mutable.HashMap -import org.apache.spark.scheduler._ import org.apache.spark._ +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.scheduler._ import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} import org.apache.spark.util.Utils @@ -129,6 +130,8 @@ private[spark] class Executor( // Maintains the list of running tasks. 
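
The `runAsUser` helper added above wraps a unit of work in Hadoop's `UserGroupInformation.doAs` so that HDFS calls are attributed to the submitting user rather than the process owner, skipping the wrapper when the two already match. A minimal, self-contained sketch of that pattern (the names `runAs` and `doWork` are illustrative, not Spark APIs):

    import java.security.PrivilegedExceptionAction
    import org.apache.hadoop.security.UserGroupInformation

    // Illustrative only: execute a block as a named remote user, mirroring the
    // runAsUser helper in the hunk above.
    def runAs(user: String)(doWork: () => Unit) {
      val ugi = UserGroupInformation.createRemoteUser(user)
      ugi.doAs(new PrivilegedExceptionAction[Unit] {
        def run(): Unit = doWork()
      })
    }

    // Usage sketch, e.g. from an executor thread:
    runAs("alice") { () =>
      println(UserGroupInformation.getCurrentUser().getShortUserName())
    }
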
private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] + val sparkUser = Option(System.getenv("SPARK_USER")).getOrElse(SparkContext.SPARK_UNKNOWN_USER) + def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) { val tr = new TaskRunner(context, taskId, serializedTask) runningTasks.put(taskId, tr) @@ -176,7 +179,7 @@ private[spark] class Executor( } } - override def run() { + override def run(): Unit = SparkHadoopUtil.get.runAsUser(sparkUser) { () => val startTime = System.currentTimeMillis() SparkEnv.set(env) Thread.currentThread.setContextClassLoader(replClassLoader) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 32901a508f53b34d44c01fc631e9927680bd312a..47e958b5e6f4bfdb4380ce81d6fb206bb9d04f87 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -132,6 +132,8 @@ class HadoopRDD[K, V]( override def getPartitions: Array[Partition] = { val jobConf = getJobConf() + // add the credentials here as this can be called before SparkContext initialized + SparkHadoopUtil.get.addCredentials(jobConf) val inputFormat = getInputFormat(jobConf) if (inputFormat.isInstanceOf[Configurable]) { inputFormat.asInstanceOf[Configurable].setConf(jobConf) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala index c7d12952153644a401daaa154afeb0fa1cdd42fc..c5d7ca0481c8071c9b6aca74d9ad32e5019922b9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala @@ -25,6 +25,8 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet +import akka.util.duration._ + import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler.SchedulingMode.SchedulingMode @@ -127,21 +129,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext, isLocal: Boolean = f backend.start() if (!isLocal && System.getProperty("spark.speculation", "false").toBoolean) { - new Thread("TaskScheduler speculation check") { - setDaemon(true) - - override def run() { - logInfo("Starting speculative execution thread") - while (true) { - try { - Thread.sleep(SPECULATION_INTERVAL) - } catch { - case e: InterruptedException => {} - } - checkSpeculatableTasks() - } - } - }.start() + logInfo("Starting speculative execution thread") + + sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds, + SPECULATION_INTERVAL milliseconds) { + checkSpeculatableTasks() + } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 4cef0825dd6c0aab711df8a58700bd37fb91c0e0..d0b21e896e812c483ecc4bb17a70f1b000f75a83 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -417,15 +417,14 @@ class DAGScheduler( case ExecutorLost(execId) => handleExecutorLost(execId) - case begin: BeginEvent => - listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo)) + case BeginEvent(task, taskInfo) => + listenerBus.post(SparkListenerTaskStart(task, taskInfo)) - case gettingResult: GettingResultEvent => - 
listenerBus.post(SparkListenerTaskGettingResult(gettingResult.task, gettingResult.taskInfo)) + case GettingResultEvent(task, taskInfo) => + listenerBus.post(SparkListenerTaskGettingResult(task, taskInfo)) - case completion: CompletionEvent => - listenerBus.post(SparkListenerTaskEnd( - completion.task, completion.reason, completion.taskInfo, completion.taskMetrics)) + case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) => + listenerBus.post(SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics)) handleTaskCompletion(completion) case TaskSetFailed(taskSet, reason) => diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 12b0d74fb5346867bdafd27bb3dd877bf277b489..60927831a159a7d4e2b092b97775ac3005ed1c1e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -1,280 +1,384 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler - -import java.io.PrintWriter -import java.io.File -import java.io.FileNotFoundException -import java.text.SimpleDateFormat -import java.util.{Date, Properties} -import java.util.concurrent.LinkedBlockingQueue - -import scala.collection.mutable.{HashMap, ListBuffer} - -import org.apache.spark._ -import org.apache.spark.rdd.RDD -import org.apache.spark.executor.TaskMetrics - -/** - * A logger class to record runtime information for jobs in Spark. This class outputs one log file - * per Spark job with information such as RDD graph, tasks start/stop, shuffle information. - * - * @param logDirName The base directory for the log files. - */ -class JobLogger(val logDirName: String) extends SparkListener with Logging { - - private val logDir = Option(System.getenv("SPARK_LOG_DIR")).getOrElse("/tmp/spark") - - private val jobIDToPrintWriter = new HashMap[Int, PrintWriter] - private val stageIDToJobID = new HashMap[Int, Int] - private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]] - private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") - private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents] - - createLogDir() - def this() = this(String.valueOf(System.currentTimeMillis())) - - // The following 5 functions are used only in testing. 
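
The DAGScheduler hunk above replaces typed-binding matches such as `case begin: BeginEvent => ... begin.task` with direct case-class extraction, keeping an `@`-binder only where the whole event is still needed afterwards. A small self-contained sketch of the idiom (the `Event`, `Begin`, and `Completion` types below are simplified stand-ins, not the Spark event classes):

    sealed trait Event
    case class Begin(taskId: Long, info: String) extends Event
    case class Completion(taskId: Long, reason: String, info: String) extends Event

    def describe(event: Event): String = event match {
      case Begin(taskId, info) =>
        // fields bound positionally, no .task / .taskInfo accessors needed
        "started " + taskId + " (" + info + ")"
      case completion @ Completion(taskId, reason, _) =>
        // the @-binder keeps the whole value when it must be passed along intact
        "finished " + taskId + ": " + reason + " -> " + completion
    }
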
- private[scheduler] def getLogDir = logDir - private[scheduler] def getJobIDtoPrintWriter = jobIDToPrintWriter - private[scheduler] def getStageIDToJobID = stageIDToJobID - private[scheduler] def getJobIDToStages = jobIDToStages - private[scheduler] def getEventQueue = eventQueue - - // Create a folder for log files, the folder's name is the creation time of the jobLogger - protected def createLogDir() { - val dir = new File(logDir + "/" + logDirName + "/") - if (!dir.exists() && !dir.mkdirs()) { - logError("Error creating log directory: " + logDir + "/" + logDirName + "/") - } - } - - // Create a log file for one job, the file name is the jobID - protected def createLogWriter(jobID: Int) { - try { - val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID) - jobIDToPrintWriter += (jobID -> fileWriter) - } catch { - case e: FileNotFoundException => e.printStackTrace() - } - } - - // Close log file, and clean the stage relationship in stageIDToJobID - protected def closeLogWriter(jobID: Int) = - jobIDToPrintWriter.get(jobID).foreach { fileWriter => - fileWriter.close() - jobIDToStages.get(jobID).foreach(_.foreach{ stage => - stageIDToJobID -= stage.id - }) - jobIDToPrintWriter -= jobID - jobIDToStages -= jobID - } - - // Write log information to log file, withTime parameter controls whether to recored - // time stamp for the information - protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) { - var writeInfo = info - if (withTime) { - val date = new Date(System.currentTimeMillis()) - writeInfo = DATE_FORMAT.format(date) + ": " +info - } - jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo)) - } - - protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) = - stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime)) - - protected def buildJobDep(jobID: Int, stage: Stage) { - if (stage.jobId == jobID) { - jobIDToStages.get(jobID) match { - case Some(stageList) => stageList += stage - case None => val stageList = new ListBuffer[Stage] - stageList += stage - jobIDToStages += (jobID -> stageList) - } - stageIDToJobID += (stage.id -> jobID) - stage.parents.foreach(buildJobDep(jobID, _)) - } - } - - protected def recordStageDep(jobID: Int) { - def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = { - var rddList = new ListBuffer[RDD[_]] - rddList += rdd - rdd.dependencies.foreach { - case shufDep: ShuffleDependency[_, _] => - case dep: Dependency[_] => rddList ++= getRddsInStage(dep.rdd) - } - rddList - } - jobIDToStages.get(jobID).foreach {_.foreach { stage => - var depRddDesc: String = "" - getRddsInStage(stage.rdd).foreach { rdd => - depRddDesc += rdd.id + "," - } - var depStageDesc: String = "" - stage.parents.foreach { stage => - depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")" - } - jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" + - depRddDesc.substring(0, depRddDesc.length - 1) + ")" + - " STAGE_DEP=" + depStageDesc, false) - } - } - } - - // Generate indents and convert to String - protected def indentString(indent: Int) = { - val sb = new StringBuilder() - for (i <- 1 to indent) { - sb.append(" ") - } - sb.toString() - } - - protected def getRddName(rdd: RDD[_]) = { - var rddName = rdd.getClass.getName - if (rdd.name != null) { - rddName = rdd.name - } - rddName - } - - protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) { - val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")" - jobLogInfo(jobID, 
indentString(indent) + rddInfo, false) - rdd.dependencies.foreach { - case shufDep: ShuffleDependency[_, _] => - val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId - jobLogInfo(jobID, indentString(indent + 1) + depInfo, false) - case dep: Dependency[_] => recordRddInStageGraph(jobID, dep.rdd, indent + 1) - } - } - - protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) { - val stageInfo = if (stage.isShuffleMap) { - "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId - } else { - "STAGE_ID=" + stage.id + " RESULT_STAGE" - } - if (stage.jobId == jobID) { - jobLogInfo(jobID, indentString(indent) + stageInfo, false) - recordRddInStageGraph(jobID, stage.rdd, indent) - stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2)) - } else { - jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false) - } - } - - // Record task metrics into job log files - protected def recordTaskMetrics(stageID: Int, status: String, - taskInfo: TaskInfo, taskMetrics: TaskMetrics) { - val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID + - " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime + - " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname - val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime - val readMetrics = taskMetrics.shuffleReadMetrics match { - case Some(metrics) => - " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime + - " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched + - " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched + - " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched + - " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime + - " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime + - " REMOTE_BYTES_READ=" + metrics.remoteBytesRead - case None => "" - } - val writeMetrics = taskMetrics.shuffleWriteMetrics match { - case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten - case None => "" - } - stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics) - } - - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { - stageLogInfo(stageSubmitted.stage.stageId,"STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format( - stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks)) - } - - override def onStageCompleted(stageCompleted: StageCompleted) { - stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format( - stageCompleted.stage.stageId)) - } - - override def onTaskStart(taskStart: SparkListenerTaskStart) { } - - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { - val task = taskEnd.task - val taskInfo = taskEnd.taskInfo - var taskStatus = "" - task match { - case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK" - case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK" - } - taskEnd.reason match { - case Success => taskStatus += " STATUS=SUCCESS" - recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics) - case Resubmitted => - taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId + - " STAGE_ID=" + task.stageId - stageLogInfo(task.stageId, taskStatus) - case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => - taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" + - task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + - mapId + " REDUCE_ID=" + reduceId - stageLogInfo(task.stageId, taskStatus) - case OtherFailure(message) => - taskStatus += " 
STATUS=FAILURE TID=" + taskInfo.taskId + - " STAGE_ID=" + task.stageId + " INFO=" + message - stageLogInfo(task.stageId, taskStatus) - case _ => - } - } - - override def onJobEnd(jobEnd: SparkListenerJobEnd) { - val job = jobEnd.job - var info = "JOB_ID=" + job.jobId - jobEnd.jobResult match { - case JobSucceeded => info += " STATUS=SUCCESS" - case JobFailed(exception, _) => - info += " STATUS=FAILED REASON=" - exception.getMessage.split("\\s+").foreach(info += _ + "_") - case _ => - } - jobLogInfo(job.jobId, info.substring(0, info.length - 1).toUpperCase) - closeLogWriter(job.jobId) - } - - protected def recordJobProperties(jobID: Int, properties: Properties) { - if(properties != null) { - val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "") - jobLogInfo(jobID, description, false) - } - } - - override def onJobStart(jobStart: SparkListenerJobStart) { - val job = jobStart.job - val properties = jobStart.properties - createLogWriter(job.jobId) - recordJobProperties(job.jobId, properties) - buildJobDep(job.jobId, job.finalStage) - recordStageDep(job.jobId) - recordStageDepGraph(job.jobId, job.finalStage) - jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED") - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.io.{IOException, File, FileNotFoundException, PrintWriter} +import java.text.SimpleDateFormat +import java.util.{Date, Properties} +import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.mutable.{HashMap, HashSet, ListBuffer} + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.StorageLevel + +/** + * A logger class to record runtime information for jobs in Spark. This class outputs one log file + * for each Spark job, containing RDD graph, tasks start/stop, shuffle information. + * JobLogger is a subclass of SparkListener, use addSparkListener to add JobLogger to a SparkContext + * after the SparkContext is created. + * Note that each JobLogger only works for one SparkContext + * @param logDirName The base directory for the log files. 
+ */ + +class JobLogger(val user: String, val logDirName: String) + extends SparkListener with Logging { + + def this() = this(System.getProperty("user.name", "<unknown>"), + String.valueOf(System.currentTimeMillis())) + + private val logDir = + if (System.getenv("SPARK_LOG_DIR") != null) + System.getenv("SPARK_LOG_DIR") + else + "/tmp/spark-%s".format(user) + + private val jobIDToPrintWriter = new HashMap[Int, PrintWriter] + private val stageIDToJobID = new HashMap[Int, Int] + private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]] + private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents] + + createLogDir() + + // The following 5 functions are used only in testing. + private[scheduler] def getLogDir = logDir + private[scheduler] def getJobIDtoPrintWriter = jobIDToPrintWriter + private[scheduler] def getStageIDToJobID = stageIDToJobID + private[scheduler] def getJobIDToStages = jobIDToStages + private[scheduler] def getEventQueue = eventQueue + + /** Create a folder for log files, the folder's name is the creation time of jobLogger */ + protected def createLogDir() { + val dir = new File(logDir + "/" + logDirName + "/") + if (dir.exists()) { + return + } + if (dir.mkdirs() == false) { + // JobLogger should throw a exception rather than continue to construct this object. + throw new IOException("create log directory error:" + logDir + "/" + logDirName + "/") + } + } + + /** + * Create a log file for one job + * @param jobID ID of the job + * @exception FileNotFoundException Fail to create log file + */ + protected def createLogWriter(jobID: Int) { + try { + val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID) + jobIDToPrintWriter += (jobID -> fileWriter) + } catch { + case e: FileNotFoundException => e.printStackTrace() + } + } + + /** + * Close log file, and clean the stage relationship in stageIDToJobID + * @param jobID ID of the job + */ + protected def closeLogWriter(jobID: Int) { + jobIDToPrintWriter.get(jobID).foreach { fileWriter => + fileWriter.close() + jobIDToStages.get(jobID).foreach(_.foreach{ stage => + stageIDToJobID -= stage.id + }) + jobIDToPrintWriter -= jobID + jobIDToStages -= jobID + } + } + + /** + * Write info into log file + * @param jobID ID of the job + * @param info Info to be recorded + * @param withTime Controls whether to record time stamp before the info, default is true + */ + protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) { + var writeInfo = info + if (withTime) { + val date = new Date(System.currentTimeMillis()) + writeInfo = DATE_FORMAT.format(date) + ": " +info + } + jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo)) + } + + /** + * Write info into log file + * @param stageID ID of the stage + * @param info Info to be recorded + * @param withTime Controls whether to record time stamp before the info, default is true + */ + protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) { + stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime)) + } + + /** + * Build stage dependency for a job + * @param jobID ID of the job + * @param stage Root stage of the job + */ + protected def buildJobDep(jobID: Int, stage: Stage) { + if (stage.jobId == jobID) { + jobIDToStages.get(jobID) match { + case Some(stageList) => stageList += stage + case None => val stageList = new ListBuffer[Stage] + stageList += stage + jobIDToStages += (jobID -> stageList) + } + stageIDToJobID += 
(stage.id -> jobID) + stage.parents.foreach(buildJobDep(jobID, _)) + } + } + + /** + * Record stage dependency and RDD dependency for a stage + * @param jobID Job ID of the stage + */ + protected def recordStageDep(jobID: Int) { + def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = { + var rddList = new ListBuffer[RDD[_]] + rddList += rdd + rdd.dependencies.foreach { + case shufDep: ShuffleDependency[_, _] => + case dep: Dependency[_] => rddList ++= getRddsInStage(dep.rdd) + } + rddList + } + jobIDToStages.get(jobID).foreach {_.foreach { stage => + var depRddDesc: String = "" + getRddsInStage(stage.rdd).foreach { rdd => + depRddDesc += rdd.id + "," + } + var depStageDesc: String = "" + stage.parents.foreach { stage => + depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")" + } + jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" + + depRddDesc.substring(0, depRddDesc.length - 1) + ")" + + " STAGE_DEP=" + depStageDesc, false) + } + } + } + + /** + * Generate indents and convert to String + * @param indent Number of indents + * @return string of indents + */ + protected def indentString(indent: Int): String = { + val sb = new StringBuilder() + for (i <- 1 to indent) { + sb.append(" ") + } + sb.toString() + } + + /** + * Get RDD's name + * @param rdd Input RDD + * @return String of RDD's name + */ + protected def getRddName(rdd: RDD[_]): String = { + var rddName = rdd.getClass.getSimpleName + if (rdd.name != null) { + rddName = rdd.name + } + rddName + } + + /** + * Record RDD dependency graph in a stage + * @param jobID Job ID of the stage + * @param rdd Root RDD of the stage + * @param indent Indent number before info + */ + protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) { + val rddInfo = + if (rdd.getStorageLevel != StorageLevel.NONE) { + "RDD_ID=" + rdd.id + " " + getRddName(rdd) + " CACHED" + " " + + rdd.origin + " " + rdd.generator + } else { + "RDD_ID=" + rdd.id + " " + getRddName(rdd) + " NONE" + " " + + rdd.origin + " " + rdd.generator + } + jobLogInfo(jobID, indentString(indent) + rddInfo, false) + rdd.dependencies.foreach { + case shufDep: ShuffleDependency[_, _] => + val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId + jobLogInfo(jobID, indentString(indent + 1) + depInfo, false) + case dep: Dependency[_] => recordRddInStageGraph(jobID, dep.rdd, indent + 1) + } + } + + /** + * Record stage dependency graph of a job + * @param jobID Job ID of the stage + * @param stage Root stage of the job + * @param indent Indent number before info, default is 0 + */ + protected def recordStageDepGraph(jobID: Int, stage: Stage, idSet: HashSet[Int], indent: Int = 0) { + val stageInfo = if (stage.isShuffleMap) { + "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId + } else { + "STAGE_ID=" + stage.id + " RESULT_STAGE" + } + if (stage.jobId == jobID) { + jobLogInfo(jobID, indentString(indent) + stageInfo, false) + if (!idSet.contains(stage.id)) { + idSet += stage.id + recordRddInStageGraph(jobID, stage.rdd, indent) + stage.parents.foreach(recordStageDepGraph(jobID, _, idSet, indent + 2)) + } + } else { + jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false) + } + } + + /** + * Record task metrics into job log files, including execution info and shuffle metrics + * @param stageID Stage ID of the task + * @param status Status info of the task + * @param taskInfo Task description info + * @param taskMetrics Task running metrics + */ + protected def recordTaskMetrics(stageID: Int, status: 
String, + taskInfo: TaskInfo, taskMetrics: TaskMetrics) { + val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID + + " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime + + " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname + val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime + val readMetrics = taskMetrics.shuffleReadMetrics match { + case Some(metrics) => + " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime + + " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched + + " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched + + " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched + + " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime + + " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime + + " REMOTE_BYTES_READ=" + metrics.remoteBytesRead + case None => "" + } + val writeMetrics = taskMetrics.shuffleWriteMetrics match { + case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten + case None => "" + } + stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics) + } + + /** + * When stage is submitted, record stage submit info + * @param stageSubmitted Stage submitted event + */ + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { + stageLogInfo(stageSubmitted.stage.stageId,"STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format( + stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks)) + } + + /** + * When stage is completed, record stage completion status + * @param stageCompleted Stage completed event + */ + override def onStageCompleted(stageCompleted: StageCompleted) { + stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format( + stageCompleted.stage.stageId)) + } + + override def onTaskStart(taskStart: SparkListenerTaskStart) { } + + /** + * When task ends, record task completion status and metrics + * @param taskEnd Task end event + */ + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + val task = taskEnd.task + val taskInfo = taskEnd.taskInfo + var taskStatus = "" + task match { + case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK" + case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK" + } + taskEnd.reason match { + case Success => taskStatus += " STATUS=SUCCESS" + recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics) + case Resubmitted => + taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId + + " STAGE_ID=" + task.stageId + stageLogInfo(task.stageId, taskStatus) + case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => + taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" + + task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + + mapId + " REDUCE_ID=" + reduceId + stageLogInfo(task.stageId, taskStatus) + case OtherFailure(message) => + taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId + + " STAGE_ID=" + task.stageId + " INFO=" + message + stageLogInfo(task.stageId, taskStatus) + case _ => + } + } + + /** + * When job ends, recording job completion status and close log file + * @param jobEnd Job end event + */ + override def onJobEnd(jobEnd: SparkListenerJobEnd) { + val job = jobEnd.job + var info = "JOB_ID=" + job.jobId + jobEnd.jobResult match { + case JobSucceeded => info += " STATUS=SUCCESS" + case JobFailed(exception, _) => + info += " STATUS=FAILED REASON=" + exception.getMessage.split("\\s+").foreach(info += _ + "_") + case _ => + } + jobLogInfo(job.jobId, info.substring(0, info.length - 
1).toUpperCase) + closeLogWriter(job.jobId) + } + + /** + * Record job properties into job log file + * @param jobID ID of the job + * @param properties Properties of the job + */ + protected def recordJobProperties(jobID: Int, properties: Properties) { + if(properties != null) { + val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "") + jobLogInfo(jobID, description, false) + } + } + + /** + * When job starts, record job property and stage graph + * @param jobStart Job start event + */ + override def onJobStart(jobStart: SparkListenerJobStart) { + val job = jobStart.job + val properties = jobStart.properties + createLogWriter(job.jobId) + recordJobProperties(job.jobId, properties) + buildJobDep(job.jobId, job.finalStage) + recordStageDep(job.jobId) + recordStageDepGraph(job.jobId, job.finalStage, new HashSet[Int]) + jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED") + } +} + diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 24d97da6eb6e0fdef25c0b78f5d6c427fc28cb76..1dc71a04282e52940e9297f284b480a9d046dad7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -146,26 +146,26 @@ private[spark] class ShuffleMapTask( metrics = Some(context.taskMetrics) val blockManager = SparkEnv.get.blockManager - var shuffle: ShuffleBlocks = null - var buckets: ShuffleWriterGroup = null + val shuffleBlockManager = blockManager.shuffleBlockManager + var shuffle: ShuffleWriterGroup = null + var success = false try { // Obtain all the block writers for shuffle blocks. val ser = SparkEnv.get.serializerManager.get(dep.serializerClass) - shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser) - buckets = shuffle.acquireWriters(partitionId) + shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser) // Write the map output to its associated buckets. for (elem <- rdd.iterator(split, context)) { val pair = elem.asInstanceOf[Product2[Any, Any]] val bucketId = dep.partitioner.getPartition(pair._1) - buckets.writers(bucketId).write(pair) + shuffle.writers(bucketId).write(pair) } // Commit the writes. Get the size of each bucket block (total block size). var totalBytes = 0L var totalTime = 0L - val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter => + val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter => writer.commit() val size = writer.fileSegment().length totalBytes += size @@ -179,19 +179,20 @@ private[spark] class ShuffleMapTask( shuffleMetrics.shuffleWriteTime = totalTime metrics.get.shuffleWriteMetrics = Some(shuffleMetrics) + success = true new MapStatus(blockManager.blockManagerId, compressedSizes) } catch { case e: Exception => // If there is an exception from running the task, revert the partial writes // and throw the exception upstream to Spark. - if (buckets != null) { - buckets.writers.foreach(_.revertPartialWrites()) + if (shuffle != null) { + shuffle.writers.foreach(_.revertPartialWrites()) } throw e } finally { // Release the writers back to the shuffle block manager. 
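
The ShuffleMapTask changes above swap the acquire/release-by-fileId API for a per-map-task writer group obtained via `forMapTask`, with a `success` flag deciding on release whether block offsets are recorded (the release step follows just below). A simplified sketch of that write protocol, where `Writer` and `writeMapOutput` are illustrative stand-ins rather than Spark's `BlockObjectWriter`:

    // Simplified sketch of the commit/revert/release protocol used above.
    trait Writer {
      def write(value: Any): Unit
      def commit(): Long
      def revertPartialWrites(): Unit
      def close(): Unit
    }

    def writeMapOutput(writers: Array[Writer],
                       records: Iterator[(Any, Any)],
                       bucketOf: Any => Int): Array[Long] = {
      var success = false
      try {
        records.foreach { case (k, v) => writers(bucketOf(k)).write((k, v)) }
        val sizes = writers.map(_.commit())   // commit each bucket, collect block sizes
        success = true
        sizes
      } catch {
        case e: Exception =>
          writers.foreach(_.revertPartialWrites()) // drop partial output on failure
          throw e
      } finally {
        writers.foreach(_.close())
        // releaseWriters(success) happens here in the real task: offsets are only
        // recorded into the file group when all of the writes succeeded.
      }
    }
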
- if (shuffle != null && buckets != null) { - buckets.writers.foreach(_.close()) - shuffle.releaseWriters(buckets) + if (shuffle != null && shuffle.writers != null) { + shuffle.writers.foreach(_.close()) + shuffle.releaseWriters(success) } // Execute the callbacks on task completion. context.executeOnCompleteCallbacks() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index f5548fc2dab128d7715d4a781890c7e52f0b7059..3bb715e7d0bfb26a80af873df9f8c2e5d1d651c9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -88,8 +88,14 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac case StatusUpdate(executorId, taskId, state, data) => scheduler.statusUpdate(taskId, state, data.value) if (TaskState.isFinished(state)) { - freeCores(executorId) += 1 - makeOffers(executorId) + if (executorActor.contains(executorId)) { + freeCores(executorId) += 1 + makeOffers(executorId) + } else { + // Ignoring the update since we don't know about the executor. + val msg = "Ignored task status update (%d state %s) from unknown executor %s with ID %s" + logWarning(msg.format(taskId, state, sender, executorId)) + } } case ReviveOffers => @@ -176,7 +182,9 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME) } - private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + private val timeout = { + Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + } def stopExecutors() { try { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index 40fdfcddb13acd0498e82e8a4c115986488140c2..cec02e945c31f310a01f09fd0608a18eeeaa27c5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -33,6 +33,10 @@ private[spark] class SimrSchedulerBackend( val tmpPath = new Path(driverFilePath + "_tmp") val filePath = new Path(driverFilePath) + val uiFilePath = driverFilePath + "_ui" + val tmpUiPath = new Path(uiFilePath + "_tmp") + val uiPath = new Path(uiFilePath) + val maxCores = System.getProperty("spark.simr.executor.cores", "1").toInt override def start() { @@ -47,6 +51,8 @@ private[spark] class SimrSchedulerBackend( logInfo("Writing to HDFS file: " + driverFilePath) logInfo("Writing Akka address: " + driverUrl) + logInfo("Writing to HDFS file: " + uiFilePath) + logInfo("Writing Spark UI Address: " + sc.ui.appUIAddress) // Create temporary file to prevent race condition where executors get empty driverUrl file val temp = fs.create(tmpPath, true) @@ -56,6 +62,12 @@ private[spark] class SimrSchedulerBackend( // "Atomic" rename fs.rename(tmpPath, filePath) + + // Write Spark UI Address to file + val uiTemp = fs.create(tmpUiPath, true) + uiTemp.writeUTF(sc.ui.appUIAddress) + uiTemp.close() + fs.rename(tmpUiPath, uiPath) } override def stop() { diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala 
b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 55b25f145ae0dc50466510f34277ceba9c32006c..e748c2275d589c6702576b72ac49038e86033c93 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -27,13 +27,17 @@ import com.twitter.chill.{EmptyScalaKryoInstantiator, AllScalaRegistrar} import org.apache.spark.{SerializableWritable, Logging} import org.apache.spark.broadcast.HttpBroadcast -import org.apache.spark.storage.{GetBlock,GotBlock, PutBlock, StorageLevel, TestBlockId} +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.storage._ /** - * A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]]. + * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. */ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging { - private val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024 + + private val bufferSize = { + System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024 + } def newKryoOutput() = new KryoOutput(bufferSize) @@ -42,21 +46,11 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging val kryo = instantiator.newKryo() val classLoader = Thread.currentThread.getContextClassLoader - val blockId = TestBlockId("1") - // Register some commonly used classes - val toRegister: Seq[AnyRef] = Seq( - ByteBuffer.allocate(1), - StorageLevel.MEMORY_ONLY, - PutBlock(blockId, ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY), - GotBlock(blockId, ByteBuffer.allocate(1)), - GetBlock(blockId), - 1 to 10, - 1 until 10, - 1L to 10L, - 1L until 10L - ) - - for (obj <- toRegister) kryo.register(obj.getClass) + // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops. + // Do this before we invoke the user registrator so the user registrator can override this. + kryo.setReferences(System.getProperty("spark.kryo.referenceTracking", "true").toBoolean) + + for (cls <- KryoSerializer.toRegister) kryo.register(cls) // Allow sending SerializableWritable kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) @@ -78,10 +72,6 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging new AllScalaRegistrar().apply(kryo) kryo.setClassLoader(classLoader) - - // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops - kryo.setReferences(System.getProperty("spark.kryo.referenceTracking", "true").toBoolean) - kryo } @@ -165,3 +155,21 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ trait KryoRegistrator { def registerClasses(kryo: Kryo) } + +private[serializer] object KryoSerializer { + // Commonly used classes. 
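
The serializer above now registers its commonly used classes from a companion-object list and sets `spark.kryo.referenceTracking` before invoking the user registrator, so a user registrator can override it. A sketch of the user-facing side, assuming a hypothetical `MyRecord` class and a registrator compiled into the default package:

    import com.esotericsoftware.kryo.Kryo
    import org.apache.spark.serializer.KryoRegistrator

    // Hypothetical user class and registrator; the registrator is the hook the
    // serializer above calls after its built-in registrations.
    case class MyRecord(id: Long, payload: Array[Byte])

    class MyRegistrator extends KryoRegistrator {
      override def registerClasses(kryo: Kryo) {
        kryo.register(classOf[MyRecord])
      }
    }

    // In this version of Spark these are system properties, set before the
    // SparkContext is created:
    //   System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    //   System.setProperty("spark.kryo.registrator", "MyRegistrator")
    //   System.setProperty("spark.kryo.referenceTracking", "false")  // only if object graphs have no cycles
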
+ private val toRegister: Seq[Class[_]] = Seq( + ByteBuffer.allocate(1).getClass, + classOf[StorageLevel], + classOf[PutBlock], + classOf[GotBlock], + classOf[GetBlock], + classOf[MapStatus], + classOf[BlockManagerId], + classOf[Array[Byte]], + (1 to 10).getClass, + (1 until 10).getClass, + (1L to 10L).getClass, + (1L until 10L).getClass + ) +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala index dbe0bda61589ca6a2d8e138113e48e8acfdaa392..c8f397609a0b473fc6c8557b3edf7f155509b2bb 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala @@ -19,9 +19,7 @@ package org.apache.spark.storage import java.util.concurrent.ConcurrentHashMap -private[storage] trait BlockInfo { - def level: StorageLevel - def tellMaster: Boolean +private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { // To save space, 'pending' and 'failed' are encoded as special sizes: @volatile var size: Long = BlockInfo.BLOCK_PENDING private def pending: Boolean = size == BlockInfo.BLOCK_PENDING @@ -81,17 +79,3 @@ private object BlockInfo { private val BLOCK_PENDING: Long = -1L private val BLOCK_FAILED: Long = -2L } - -// All shuffle blocks have the same `level` and `tellMaster` properties, -// so we can save space by not storing them in each instance: -private[storage] class ShuffleBlockInfo extends BlockInfo { - // These need to be defined using 'def' instead of 'val' in order for - // the compiler to eliminate the fields: - def level: StorageLevel = StorageLevel.DISK_ONLY - def tellMaster: Boolean = false -} - -private[storage] class BlockInfoImpl(val level: StorageLevel, val tellMaster: Boolean) - extends BlockInfo { - // Intentionally left blank -} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 76d537f8e838a75ec9f6296e143e13011a46408d..a34c95b6f07b67e63fc8099ba79d695845a1b097 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.{InputStream, OutputStream} +import java.io.{File, InputStream, OutputStream} import java.nio.{ByteBuffer, MappedByteBuffer} import scala.collection.mutable.{HashMap, ArrayBuffer} @@ -47,7 +47,7 @@ private[spark] class BlockManager( extends Logging { val shuffleBlockManager = new ShuffleBlockManager(this) - val diskBlockManager = new DiskBlockManager( + val diskBlockManager = new DiskBlockManager(shuffleBlockManager, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] @@ -462,20 +462,10 @@ private[spark] class BlockManager( * This is currently used for writing shuffle files out. Callers should handle error * cases. 
*/ - def getDiskWriter(blockId: BlockId, filename: String, serializer: Serializer, bufferSize: Int) + def getDiskWriter(blockId: BlockId, file: File, serializer: Serializer, bufferSize: Int) : BlockObjectWriter = { val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) - val file = diskBlockManager.createBlockFile(blockId, filename, allowAppending = true) - val writer = new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream) - writer.registerCloseEventHandler(() => { - if (shuffleBlockManager.consolidateShuffleFiles) { - diskBlockManager.mapBlockToFileSegment(blockId, writer.fileSegment()) - } - val myInfo = new ShuffleBlockInfo() - blockInfo.put(blockId, myInfo) - myInfo.markReady(writer.fileSegment().length) - }) - writer + new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream) } /** @@ -505,7 +495,7 @@ private[spark] class BlockManager( // to be dropped right after it got put into memory. Note, however, that other threads will // not be able to get() this block until we call markReady on its BlockInfo. val myInfo = { - val tinfo = new BlockInfoImpl(level, tellMaster) + val tinfo = new BlockInfo(level, tellMaster) // Do atomically ! val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 32d2dd06943a0952f7a6763397cbb81000b17933..469e68fed74bb4effb3aa9efa11157859d047084 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -34,20 +34,12 @@ import org.apache.spark.serializer.{SerializationStream, Serializer} */ abstract class BlockObjectWriter(val blockId: BlockId) { - var closeEventHandler: () => Unit = _ - def open(): BlockObjectWriter - def close() { - closeEventHandler() - } + def close() def isOpen: Boolean - def registerCloseEventHandler(handler: () => Unit) { - closeEventHandler = handler - } - /** * Flush the partial writes and commit them as a single atomic block. Return the * number of bytes written for this commit. @@ -78,11 +70,11 @@ abstract class BlockObjectWriter(val blockId: BlockId) { /** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. 
*/ class DiskBlockObjectWriter( - blockId: BlockId, - file: File, - serializer: Serializer, - bufferSize: Int, - compressStream: OutputStream => OutputStream) + blockId: BlockId, + file: File, + serializer: Serializer, + bufferSize: Int, + compressStream: OutputStream => OutputStream) extends BlockObjectWriter(blockId) with Logging { @@ -111,8 +103,8 @@ class DiskBlockObjectWriter( private var fos: FileOutputStream = null private var ts: TimeTrackingOutputStream = null private var objOut: SerializationStream = null - private var initialPosition = 0L - private var lastValidPosition = 0L + private val initialPosition = file.length() + private var lastValidPosition = initialPosition private var initialized = false private var _timeWriting = 0L @@ -120,7 +112,6 @@ class DiskBlockObjectWriter( fos = new FileOutputStream(file, true) ts = new TimeTrackingOutputStream(fos) channel = fos.getChannel() - initialPosition = channel.position lastValidPosition = initialPosition bs = compressStream(new FastBufferedOutputStream(ts, bufferSize)) objOut = serializer.newInstance().serializeStream(bs) @@ -147,8 +138,6 @@ class DiskBlockObjectWriter( ts = null objOut = null } - // Invoke the close callback handler. - super.close() } override def isOpen: Boolean = objOut != null diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index bcb58ad9467e6c8ff6fcf611ec570edaebb5c735..fcd2e9798295596b48966c4b4b2529f4730a7ea3 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -20,12 +20,11 @@ package org.apache.spark.storage import java.io.File import java.text.SimpleDateFormat import java.util.{Date, Random} -import java.util.concurrent.ConcurrentHashMap import org.apache.spark.Logging import org.apache.spark.executor.ExecutorExitCode import org.apache.spark.network.netty.{PathResolver, ShuffleSender} -import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils} +import org.apache.spark.util.Utils /** * Creates and maintains the logical mapping between logical blocks and physical on-disk @@ -35,7 +34,8 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH * * @param rootDirs The directories to use for storing block files. Data will be hashed among these. */ -private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver with Logging { +private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootDirs: String) + extends PathResolver with Logging { private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt @@ -47,54 +47,23 @@ private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver wit private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) private var shuffleSender : ShuffleSender = null - // Stores only Blocks which have been specifically mapped to segments of files - // (rather than the default, which maps a Block to a whole file). - // This keeps our bookkeeping down, since the file system itself tracks the standalone Blocks. 
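
With the close-event handler removed, `DiskBlockObjectWriter` above derives its starting offset from `file.length()` at construction and advances `lastValidPosition` on each commit, which is all it needs to describe its output as a `FileSegment(file, offset, length)`. A minimal sketch of that bookkeeping (`AppendSegment` is illustrative, not a Spark class):

    import java.io.File

    // Illustrative bookkeeping only: a writer appending to a shared file records
    // where it started and where its last successful commit ended.
    class AppendSegment(file: File) {
      private val initialPosition = file.length()      // appends start at the current end
      private var lastValidPosition = initialPosition  // advanced by each commit

      /** Record a commit that ended at currentPosition; returns bytes committed. */
      def commit(currentPosition: Long): Long = {
        val written = currentPosition - lastValidPosition
        lastValidPosition = currentPosition
        written
      }

      /** (offset, length) of everything committed so far, i.e. the file segment. */
      def segment: (Long, Long) = (initialPosition, lastValidPosition - initialPosition)
    }
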
- private val blockToFileSegmentMap = new TimeStampedHashMap[BlockId, FileSegment] - - val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DISK_BLOCK_MANAGER, this.cleanup) - addShutdownHook() - /** - * Creates a logical mapping from the given BlockId to a segment of a file. - * This will cause any accesses of the logical BlockId to be directed to the specified - * physical location. - */ - def mapBlockToFileSegment(blockId: BlockId, fileSegment: FileSegment) { - blockToFileSegmentMap.put(blockId, fileSegment) - } - /** * Returns the phyiscal file segment in which the given BlockId is located. * If the BlockId has been mapped to a specific FileSegment, that will be returned. * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly. */ def getBlockLocation(blockId: BlockId): FileSegment = { - if (blockToFileSegmentMap.internalMap.containsKey(blockId)) { - blockToFileSegmentMap.get(blockId).get + if (blockId.isShuffle && shuffleManager.consolidateShuffleFiles) { + shuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId]) } else { val file = getFile(blockId.name) new FileSegment(file, 0, file.length()) } } - /** - * Simply returns a File to place the given Block into. This does not physically create the file. - * If filename is given, that file will be used. Otherwise, we will use the BlockId to get - * a unique filename. - */ - def createBlockFile(blockId: BlockId, filename: String = "", allowAppending: Boolean): File = { - val actualFilename = if (filename == "") blockId.name else filename - val file = getFile(actualFilename) - if (!allowAppending && file.exists()) { - throw new IllegalStateException( - "Attempted to create file that already exists: " + actualFilename) - } - file - } - - private def getFile(filename: String): File = { + def getFile(filename: String): File = { // Figure out which local directory it hashes to, and which subdirectory in that val hash = Utils.nonNegativeHash(filename) val dirId = hash % localDirs.length @@ -119,6 +88,8 @@ private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver wit new File(subDir, filename) } + def getFile(blockId: BlockId): File = getFile(blockId.name) + private def createLocalDirs(): Array[File] = { logDebug("Creating local directories at root dirs '" + rootDirs + "'") val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") @@ -151,10 +122,6 @@ private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver wit } } - private def cleanup(cleanupTime: Long) { - blockToFileSegmentMap.clearOldValues(cleanupTime) - } - private def addShutdownHook() { localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir)) Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index a3c496f9e05c517f198510095471ab6623b40d22..5a1e7b44440fdac533ae6256ba61c33d70552b7d 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -44,7 +44,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage val bytes = _bytes.duplicate() logDebug("Attempting to put block " + blockId) val startTime = System.currentTimeMillis - val file = diskManager.createBlockFile(blockId, allowAppending = false) + val file = diskManager.getFile(blockId) val channel = new FileOutputStream(file).getChannel() while 
(bytes.remaining > 0) { channel.write(bytes) @@ -64,7 +64,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage logDebug("Attempting to write values for block " + blockId) val startTime = System.currentTimeMillis - val file = diskManager.createBlockFile(blockId, allowAppending = false) + val file = diskManager.getFile(blockId) val outputStream = new FileOutputStream(file) blockManager.dataSerializeStream(blockId, outputStream, values.iterator) val length = file.length diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 066e45a12b8c7a8e9784a42eba63c373e1b44378..2f1b049ce4839f631c737d2d6cc0f9947ac2c93c 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -17,33 +17,45 @@ package org.apache.spark.storage +import java.io.File import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger +import scala.collection.JavaConversions._ + import org.apache.spark.serializer.Serializer +import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap} +import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} +import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup -private[spark] -class ShuffleWriterGroup(val id: Int, val fileId: Int, val writers: Array[BlockObjectWriter]) +/** A group of writers for a ShuffleMapTask, one writer per reducer. */ +private[spark] trait ShuffleWriterGroup { + val writers: Array[BlockObjectWriter] -private[spark] -trait ShuffleBlocks { - def acquireWriters(mapId: Int): ShuffleWriterGroup - def releaseWriters(group: ShuffleWriterGroup) + /** @param success Indicates all writes were successful. If false, no blocks will be recorded. */ + def releaseWriters(success: Boolean) } /** - * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one writer - * per reducer. + * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one file + * per reducer (this set of files is called a ShuffleFileGroup). * * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer - * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle files, - * it releases them for another task. + * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle + * files, it releases them for another task. * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple: * - shuffleId: The unique id given to the entire shuffle stage. * - bucketId: The id of the output partition (i.e., reducer id) * - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a * time owns a particular fileId, and this id is returned to a pool when the task finishes. + * Each shuffle file is then mapped to a FileSegment, which is a 3-tuple (file, offset, length) + * that specifies where in a given file the actual block data is located. + * + * Shuffle file metadata is stored in a space-efficient manner. Rather than simply mapping + * ShuffleBlockIds directly to FileSegments, each ShuffleFileGroup maintains a list of offsets for + * each block stored in each file. 
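One way to recover a block's (offset, length) from such a per-file list of consecutive offsets is to subtract each offset from the one that follows it, falling back to the end of the file for the last block. A standalone sketch of that arithmetic, using an ordinary `IndexedSeq` in place of the patch's `PrimitiveVector`:

    // blockOffsets(i) is where the i-th map output written to this reducer file begins,
    // ordered by position in the file.
    def segmentFor(blockOffsets: IndexedSeq[Long], fileLength: Long, index: Int): (Long, Long) = {
      val offset = blockOffsets(index)
      val length =
        if (index + 1 < blockOffsets.length) blockOffsets(index + 1) - offset
        else fileLength - offset
      (offset, length)
    }

    segmentFor(IndexedSeq(0L, 100L, 250L), fileLength = 400L, index = 1)  // (100, 150)
    segmentFor(IndexedSeq(0L, 100L, 250L), fileLength = 400L, index = 2)  // (250, 150)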
In order to find the location of a shuffle block, we search the + * files within a ShuffleFileGroups associated with the block's reducer. */ private[spark] class ShuffleBlockManager(blockManager: BlockManager) { @@ -52,45 +64,152 @@ class ShuffleBlockManager(blockManager: BlockManager) { val consolidateShuffleFiles = System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean - var nextFileId = new AtomicInteger(0) - val unusedFileIds = new ConcurrentLinkedQueue[java.lang.Integer]() + private val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 + + /** + * Contains all the state related to a particular shuffle. This includes a pool of unused + * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle. + */ + private class ShuffleState() { + val nextFileId = new AtomicInteger(0) + val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() + val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() + } + + type ShuffleId = Int + private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState] + + private + val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup) - def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer) = { - new ShuffleBlocks { - // Get a group of writers for a map task. - override def acquireWriters(mapId: Int): ShuffleWriterGroup = { - val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 - val fileId = getUnusedFileId() - val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => + def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = { + new ShuffleWriterGroup { + shuffleStates.putIfAbsent(shuffleId, new ShuffleState()) + private val shuffleState = shuffleStates(shuffleId) + private var fileGroup: ShuffleFileGroup = null + + val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) { + fileGroup = getUnusedFileGroup() + Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) - if (consolidateShuffleFiles) { - val filename = physicalFileName(shuffleId, bucketId, fileId) - blockManager.getDiskWriter(blockId, filename, serializer, bufferSize) - } else { - blockManager.getDiskWriter(blockId, blockId.name, serializer, bufferSize) + blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize) + } + } else { + Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => + val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) + val blockFile = blockManager.diskBlockManager.getFile(blockId) + blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize) + } + } + + override def releaseWriters(success: Boolean) { + if (consolidateShuffleFiles) { + if (success) { + val offsets = writers.map(_.fileSegment().offset) + fileGroup.recordMapOutput(mapId, offsets) } + recycleFileGroup(fileGroup) } - new ShuffleWriterGroup(mapId, fileId, writers) } - override def releaseWriters(group: ShuffleWriterGroup) { - recycleFileId(group.fileId) + private def getUnusedFileGroup(): ShuffleFileGroup = { + val fileGroup = shuffleState.unusedFileGroups.poll() + if (fileGroup != null) fileGroup else newFileGroup() + } + + private def newFileGroup(): ShuffleFileGroup = { + val fileId = shuffleState.nextFileId.getAndIncrement() + val files = Array.tabulate[File](numBuckets) { bucketId => + val filename = physicalFileName(shuffleId, bucketId, fileId) + 
blockManager.diskBlockManager.getFile(filename) + } + val fileGroup = new ShuffleFileGroup(fileId, shuffleId, files) + shuffleState.allFileGroups.add(fileGroup) + fileGroup } - } - } - private def getUnusedFileId(): Int = { - val fileId = unusedFileIds.poll() - if (fileId == null) nextFileId.getAndIncrement() else fileId + private def recycleFileGroup(group: ShuffleFileGroup) { + shuffleState.unusedFileGroups.add(group) + } + } } - private def recycleFileId(fileId: Int) { - if (consolidateShuffleFiles) { - unusedFileIds.add(fileId) + /** + * Returns the physical file segment in which the given BlockId is located. + * This function should only be called if shuffle file consolidation is enabled, as it is + * an error condition if we don't find the expected block. + */ + def getBlockLocation(id: ShuffleBlockId): FileSegment = { + // Search all file groups associated with this shuffle. + val shuffleState = shuffleStates(id.shuffleId) + for (fileGroup <- shuffleState.allFileGroups) { + val segment = fileGroup.getFileSegmentFor(id.mapId, id.reduceId) + if (segment.isDefined) { return segment.get } } + throw new IllegalStateException("Failed to find shuffle block: " + id) } private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = { "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId) } + + private def cleanup(cleanupTime: Long) { + shuffleStates.clearOldValues(cleanupTime) + } +} + +private[spark] +object ShuffleBlockManager { + /** + * A group of shuffle files, one per reducer. + * A particular mapper will be assigned a single ShuffleFileGroup to write its output to. + */ + private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) { + /** + * Stores the absolute index of each mapId in the files of this group. For instance, + * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0. + */ + private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]() + + /** + * Stores consecutive offsets of blocks into each reducer file, ordered by position in the file. + * This ordering allows us to compute block lengths by examining the following block offset. + * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every + * reducer. + */ + private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { + new PrimitiveVector[Long]() + } + + def numBlocks = mapIdToIndex.size + + def apply(bucketId: Int) = files(bucketId) + + def recordMapOutput(mapId: Int, offsets: Array[Long]) { + mapIdToIndex(mapId) = numBlocks + for (i <- 0 until offsets.length) { + blockOffsetsByReducer(i) += offsets(i) + } + } + + /** Returns the FileSegment associated with the given map task, or None if no entry exists. 
*/ + def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = { + val file = files(reducerId) + val blockOffsets = blockOffsetsByReducer(reducerId) + val index = mapIdToIndex.getOrElse(mapId, -1) + if (index >= 0) { + val offset = blockOffsets(index) + val length = + if (index + 1 < numBlocks) { + blockOffsets(index + 1) - offset + } else { + file.length() - offset + } + assert(length >= 0) + Some(new FileSegment(file, offset, length)) + } else { + None + } + } + } } diff --git a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala index 7dcadc380542cbe93413a339d8573f835b633546..1e4db4f66bd2c42cc3a6db41857469043a07e647 100644 --- a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala +++ b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala @@ -38,19 +38,19 @@ object StoragePerfTester { val blockManager = sc.env.blockManager def writeOutputBytes(mapId: Int, total: AtomicLong) = { - val shuffle = blockManager.shuffleBlockManager.forShuffle(1, numOutputSplits, + val shuffle = blockManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits, new KryoSerializer()) - val buckets = shuffle.acquireWriters(mapId) + val writers = shuffle.writers for (i <- 1 to recordsPerMap) { - buckets.writers(i % numOutputSplits).write(writeData) + writers(i % numOutputSplits).write(writeData) } - buckets.writers.map {w => + writers.map {w => w.commit() total.addAndGet(w.fileSegment().length) w.close() } - shuffle.releaseWriters(buckets) + shuffle.releaseWriters(true) } val start = System.currentTimeMillis() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 35b5d5fd59534938d632bd271d883d75fd5fc499..c1c7aa70e6c92385766eff5b0bb864de4a7a66ef 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -152,6 +152,22 @@ private[spark] class StagePage(parent: JobProgressUI) { else metrics.map(m => parent.formatDuration(m.executorRunTime)).getOrElse("") val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L) + var shuffleReadSortable: String = "" + var shuffleReadReadable: String = "" + if (shuffleRead) { + shuffleReadSortable = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => s.remoteBytesRead}.toString() + shuffleReadReadable = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => + Utils.bytesToString(s.remoteBytesRead)}.getOrElse("") + } + + var shuffleWriteSortable: String = "" + var shuffleWriteReadable: String = "" + if (shuffleWrite) { + shuffleWriteSortable = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleBytesWritten}.toString() + shuffleWriteReadable = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => + Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("") + } + <tr> <td>{info.index}</td> <td>{info.taskId}</td> @@ -166,14 +182,17 @@ private[spark] class StagePage(parent: JobProgressUI) { {if (gcTime > 0) parent.formatDuration(gcTime) else ""} </td> {if (shuffleRead) { - <td>{metrics.flatMap{m => m.shuffleReadMetrics}.map{s => - Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")}</td> + <td sorttable_customkey={shuffleReadSortable}> + {shuffleReadReadable} + </td> }} {if (shuffleWrite) { - <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => - parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")}</td> - <td>{metrics.flatMap{m => 
m.shuffleWriteMetrics}.map{s => - Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")}</td> + <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => + parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")} + </td> + <td sorttable_customkey={shuffleWriteSortable}> + {shuffleWriteReadable} + </td> }} <td>{exception.map(e => <span> diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index d7d0441c388fab65ee5e6978d481964c6bdf703b..9ad6de3c6d8de79c758f1d0764b1a171bd56012c 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -79,11 +79,14 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr case None => "Unknown" } - val shuffleRead = listener.stageIdToShuffleRead.getOrElse(s.stageId, 0L) match { + val shuffleReadSortable = listener.stageIdToShuffleRead.getOrElse(s.stageId, 0L) + val shuffleRead = shuffleReadSortable match { case 0 => "" case b => Utils.bytesToString(b) } - val shuffleWrite = listener.stageIdToShuffleWrite.getOrElse(s.stageId, 0L) match { + + val shuffleWriteSortable = listener.stageIdToShuffleWrite.getOrElse(s.stageId, 0L) + val shuffleWrite = shuffleWriteSortable match { case 0 => "" case b => Utils.bytesToString(b) } @@ -119,8 +122,8 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr <td class="progress-cell"> {makeProgressBar(startedTasks, completedTasks, failedTasks, totalTasks)} </td> - <td>{shuffleRead}</td> - <td>{shuffleWrite}</td> + <td sorttable_customekey={shuffleReadSortable.toString}>{shuffleRead}</td> + <td sorttable_customekey={shuffleWriteSortable.toString}>{shuffleWrite}</td> </tr> } } diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 3f963727d98ddd3aca90c8bb327e410dceb6f546..67a7f87a5ca6e40bdb254ebee8c61b6e459c856e 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -59,7 +59,7 @@ object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext "ShuffleMapTask", "BlockManager", "DiskBlockManager", "BroadcastVars") { val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK, - SHUFFLE_MAP_TASK, BLOCK_MANAGER, DISK_BLOCK_MANAGER, BROADCAST_VARS = Value + SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value type MetadataCleanerType = Value diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index fd2811e44c9b118a057725b29043df88d1becc3f..fe932d8ede2f3a480eb5b0f2ce2eddce853d8a89 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,13 +18,12 @@ package org.apache.spark.util import java.io._ -import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket} +import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address} import java.util.{Locale, Random, UUID} -import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor} -import java.util.regex.Pattern +import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor} import scala.collection.Map -import scala.collection.mutable.{ArrayBuffer, HashMap} +import 
scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ import scala.io.Source @@ -36,7 +35,7 @@ import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} import org.apache.spark.deploy.SparkHadoopUtil import java.nio.ByteBuffer -import org.apache.spark.{SparkEnv, SparkException, Logging} +import org.apache.spark.{SparkException, Logging} /** @@ -148,7 +147,7 @@ private[spark] object Utils extends Logging { return buf } - private val shutdownDeletePaths = new collection.mutable.HashSet[String]() + private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() // Register the path to be deleted via shutdown hook def registerShutdownDeleteDir(file: File) { @@ -818,4 +817,10 @@ private[spark] object Utils extends Logging { // Nothing else to guard against ? hashAbs } + + /** Returns a copy of the system properties that is thread-safe to iterator over. */ + def getSystemProperties(): Map[String, String] = { + return System.getProperties().clone() + .asInstanceOf[java.util.Properties].toMap[String, String] + } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala new file mode 100644 index 0000000000000000000000000000000000000000..a1a452315d1437d35ff674b496371224edd98dea --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + + +/** + * A simple, fixed-size bit set implementation. This implementation is fast because it avoids + * safety/bound checking. + */ +class BitSet(numBits: Int) { + + private[this] val words = new Array[Long](bit2words(numBits)) + private[this] val numWords = words.length + + /** + * Sets the bit at the specified index to true. + * @param index the bit index + */ + def set(index: Int) { + val bitmask = 1L << (index & 0x3f) // mod 64 and shift + words(index >> 6) |= bitmask // div by 64 and mask + } + + /** + * Return the value of the bit with the specified index. The value is true if the bit with + * the index is currently set in this BitSet; otherwise, the result is false. + * + * @param index the bit index + * @return the value of the bit with the specified index + */ + def get(index: Int): Boolean = { + val bitmask = 1L << (index & 0x3f) // mod 64 and shift + (words(index >> 6) & bitmask) != 0 // div by 64 and mask + } + + /** Return the number of bits set to true in this BitSet. 
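The set/get arithmetic above stores bit `i` in word `i >> 6` under mask `1L << (i & 0x3f)`. A quick sanity check of that mapping, independent of the class itself:

    // Maps a bit index to its (word index, in-word bit position) pair, as BitSet.set/get do.
    def locate(index: Int): (Int, Int) = (index >> 6, index & 0x3f)

    locate(0)    // (0, 0)   -> first bit of words(0)
    locate(63)   // (0, 63)  -> last bit of words(0)
    locate(64)   // (1, 0)   -> first bit of words(1)
    locate(90)   // (1, 26)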
*/ + def cardinality(): Int = { + var sum = 0 + var i = 0 + while (i < numWords) { + sum += java.lang.Long.bitCount(words(i)) + i += 1 + } + sum + } + + /** + * Returns the index of the first bit that is set to true that occurs on or after the + * specified starting index. If no such bit exists then -1 is returned. + * + * To iterate over the true bits in a BitSet, use the following loop: + * + * for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i+1)) { + * // operate on index i here + * } + * + * @param fromIndex the index to start checking from (inclusive) + * @return the index of the next set bit, or -1 if there is no such bit + */ + def nextSetBit(fromIndex: Int): Int = { + var wordIndex = fromIndex >> 6 + if (wordIndex >= numWords) { + return -1 + } + + // Try to find the next set bit in the current word + val subIndex = fromIndex & 0x3f + var word = words(wordIndex) >> subIndex + if (word != 0) { + return (wordIndex << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word) + } + + // Find the next set bit in the rest of the words + wordIndex += 1 + while (wordIndex < numWords) { + word = words(wordIndex) + if (word != 0) { + return (wordIndex << 6) + java.lang.Long.numberOfTrailingZeros(word) + } + wordIndex += 1 + } + + -1 + } + + /** Return the number of longs it would take to hold numBits. */ + private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1 +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala new file mode 100644 index 0000000000000000000000000000000000000000..80545c9688aa603cd3cd84f263ad0f4e54fa0b97 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + + +/** + * A fast hash map implementation for nullable keys. This hash map supports insertions and updates, + * but not deletions. This map is about 5X faster than java.util.HashMap, while using much less + * space overhead. + * + * Under the hood, it uses our OpenHashSet implementation. + */ +private[spark] +class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V: ClassManifest]( + initialCapacity: Int) + extends Iterable[(K, V)] + with Serializable { + + def this() = this(64) + + protected var _keySet = new OpenHashSet[K](initialCapacity) + + // Init in constructor (instead of in declaration) to work around a Scala compiler specialization + // bug that would generate two arrays (one for Object and one for specialized T). 
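As the `nextSetBit` comment above suggests, iterating the true bits is a short loop; in Scala it might look like the sketch below (assuming the `BitSet` added in this patch):

    // Collect the indices of all set bits by repeatedly asking for the next one.
    def setBits(bs: BitSet): Seq[Int] = {
      val out = scala.collection.mutable.ArrayBuffer[Int]()
      var i = bs.nextSetBit(0)
      while (i >= 0) {
        out += i
        i = bs.nextSetBit(i + 1)
      }
      out
    }

    val bs = new BitSet(100)
    Seq(0, 9, 1, 10, 90, 96).foreach(bs.set)
    setBits(bs)  // ArrayBuffer(0, 1, 9, 10, 90, 96)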
+ private var _values: Array[V] = _ + _values = new Array[V](_keySet.capacity) + + @transient private var _oldValues: Array[V] = null + + // Treat the null key differently so we can use nulls in "data" to represent empty items. + private var haveNullValue = false + private var nullValue: V = null.asInstanceOf[V] + + override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size + + /** Get the value for a given key */ + def apply(k: K): V = { + if (k == null) { + nullValue + } else { + val pos = _keySet.getPos(k) + if (pos < 0) { + null.asInstanceOf[V] + } else { + _values(pos) + } + } + } + + /** Set the value for a key */ + def update(k: K, v: V) { + if (k == null) { + haveNullValue = true + nullValue = v + } else { + val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK + _values(pos) = v + _keySet.rehashIfNeeded(k, grow, move) + _oldValues = null + } + } + + /** + * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise, + * set its value to mergeValue(oldValue). + * + * @return the newly updated value. + */ + def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = { + if (k == null) { + if (haveNullValue) { + nullValue = mergeValue(nullValue) + } else { + haveNullValue = true + nullValue = defaultValue + } + nullValue + } else { + val pos = _keySet.addWithoutResize(k) + if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) { + val newValue = defaultValue + _values(pos & OpenHashSet.POSITION_MASK) = newValue + _keySet.rehashIfNeeded(k, grow, move) + newValue + } else { + _values(pos) = mergeValue(_values(pos)) + _values(pos) + } + } + } + + override def iterator = new Iterator[(K, V)] { + var pos = -1 + var nextPair: (K, V) = computeNextPair() + + /** Get the next value we should return from next(), or null if we're finished iterating */ + def computeNextPair(): (K, V) = { + if (pos == -1) { // Treat position -1 as looking at the null value + if (haveNullValue) { + pos += 1 + return (null.asInstanceOf[K], nullValue) + } + pos += 1 + } + pos = _keySet.nextPos(pos) + if (pos >= 0) { + val ret = (_keySet.getValue(pos), _values(pos)) + pos += 1 + ret + } else { + null + } + } + + def hasNext = nextPair != null + + def next() = { + val pair = nextPair + nextPair = computeNextPair() + pair + } + } + + // The following member variables are declared as protected instead of private for the + // specialization to work (specialized class extends the non-specialized one and needs access + // to the "private" variables). + // They also should have been val's. We use var's because there is a Scala compiler bug that + // would throw illegal access error at runtime if they are declared as val's. + protected var grow = (newCapacity: Int) => { + _oldValues = _values + _values = new Array[V](newCapacity) + } + + protected var move = (oldPos: Int, newPos: Int) => { + _values(newPos) = _oldValues(oldPos) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala new file mode 100644 index 0000000000000000000000000000000000000000..4592e4f939e5c570b0a19abccd26430e35e2143c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
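The null key is kept in a dedicated field rather than in the backing `OpenHashSet`, so it can be read, written and merged like any other key. A brief hypothetical usage sketch of the map defined above:

    val m = new OpenHashMap[String, Int]
    m("a") = 1
    m(null) = -1          // null key lives in its own slot, not in the OpenHashSet
    m.size                // 2
    m("a")                // 1
    m(null)               // -1
    m.changeValue("a", 10, _ + 100)   // existing key: merge -> 101
    m.changeValue("b", 10, _ + 100)   // new key: default -> 10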
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + + +/** + * A simple, fast hash set optimized for non-null insertion-only use case, where keys are never + * removed. + * + * The underlying implementation uses Scala compiler's specialization to generate optimized + * storage for two primitive types (Long and Int). It is much faster than Java's standard HashSet + * while incurring much less memory overhead. This can serve as building blocks for higher level + * data structures such as an optimized HashMap. + * + * This OpenHashSet is designed to serve as building blocks for higher level data structures + * such as an optimized hash map. Compared with standard hash set implementations, this class + * provides its various callbacks interfaces (e.g. allocateFunc, moveFunc) and interfaces to + * retrieve the position of a key in the underlying array. + * + * It uses quadratic probing with a power-of-2 hash table size, which is guaranteed + * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing). + */ +private[spark] +class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( + initialCapacity: Int, + loadFactor: Double) + extends Serializable { + + require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") + require(initialCapacity >= 1, "Invalid initial capacity") + require(loadFactor < 1.0, "Load factor must be less than 1.0") + require(loadFactor > 0.0, "Load factor must be greater than 0.0") + + import OpenHashSet._ + + def this(initialCapacity: Int) = this(initialCapacity, 0.7) + + def this() = this(64) + + // The following member variables are declared as protected instead of private for the + // specialization to work (specialized class extends the non-specialized one and needs access + // to the "private" variables). + + protected val hasher: Hasher[T] = { + // It would've been more natural to write the following using pattern matching. But Scala 2.9.x + // compiler has a bug when specialization is used together with this pattern matching, and + // throws: + // scala.tools.nsc.symtab.Types$TypeError: type mismatch; + // found : scala.reflect.AnyValManifest[Long] + // required: scala.reflect.ClassManifest[Int] + // at scala.tools.nsc.typechecker.Contexts$Context.error(Contexts.scala:298) + // at scala.tools.nsc.typechecker.Infer$Inferencer.error(Infer.scala:207) + // ... 
+ val mt = classManifest[T] + if (mt == ClassManifest.Long) { + (new LongHasher).asInstanceOf[Hasher[T]] + } else if (mt == ClassManifest.Int) { + (new IntHasher).asInstanceOf[Hasher[T]] + } else { + new Hasher[T] + } + } + + protected var _capacity = nextPowerOf2(initialCapacity) + protected var _mask = _capacity - 1 + protected var _size = 0 + + protected var _bitset = new BitSet(_capacity) + + // Init of the array in constructor (instead of in declaration) to work around a Scala compiler + // specialization bug that would generate two arrays (one for Object and one for specialized T). + protected var _data: Array[T] = _ + _data = new Array[T](_capacity) + + /** Number of elements in the set. */ + def size: Int = _size + + /** The capacity of the set (i.e. size of the underlying array). */ + def capacity: Int = _capacity + + /** Return true if this set contains the specified element. */ + def contains(k: T): Boolean = getPos(k) != INVALID_POS + + /** + * Add an element to the set. If the set is over capacity after the insertion, grow the set + * and rehash all elements. + */ + def add(k: T) { + addWithoutResize(k) + rehashIfNeeded(k, grow, move) + } + + /** + * Add an element to the set. This one differs from add in that it doesn't trigger rehashing. + * The caller is responsible for calling rehashIfNeeded. + * + * Use (retval & POSITION_MASK) to get the actual position, and + * (retval & EXISTENCE_MASK) != 0 for prior existence. + * + * @return The position where the key is placed, plus the highest order bit is set if the key + * exists previously. + */ + def addWithoutResize(k: T): Int = putInto(_bitset, _data, k) + + /** + * Rehash the set if it is overloaded. + * @param k A parameter unused in the function, but to force the Scala compiler to specialize + * this method. + * @param allocateFunc Callback invoked when we are allocating a new, larger array. + * @param moveFunc Callback invoked when we move the key from one position (in the old data array) + * to a new position (in the new data array). + */ + def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) { + if (_size > loadFactor * _capacity) { + rehash(k, allocateFunc, moveFunc) + } + } + + /** + * Return the position of the element in the underlying array, or INVALID_POS if it is not found. + */ + def getPos(k: T): Int = { + var pos = hashcode(hasher.hash(k)) & _mask + var i = 1 + while (true) { + if (!_bitset.get(pos)) { + return INVALID_POS + } else if (k == _data(pos)) { + return pos + } else { + val delta = i + pos = (pos + delta) & _mask + i += 1 + } + } + // Never reached here + INVALID_POS + } + + /** Return the value at the specified position. */ + def getValue(pos: Int): T = _data(pos) + + /** + * Return the next position with an element stored, starting from the given position inclusively. + */ + def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos) + + /** + * Put an entry into the set. Return the position where the key is placed. In addition, the + * highest bit in the returned position is set if the key exists prior to this put. + * + * This function assumes the data array has at least one empty slot. + */ + private def putInto(bitset: BitSet, data: Array[T], k: T): Int = { + val mask = data.length - 1 + var pos = hashcode(hasher.hash(k)) & mask + var i = 1 + while (true) { + if (!bitset.get(pos)) { + // This is a new key. + data(pos) = k + bitset.set(pos) + _size += 1 + return pos | NONEXISTENCE_MASK + } else if (data(pos) == k) { + // Found an existing key. 
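The probe loop above advances by 1, then 2, then 3, ... slots per step, so the cumulative offsets from the starting position are the triangular numbers, which for a power-of-two table size visit every slot. A small sketch of that probe sequence under the same masking scheme:

    // First `steps` probe positions for a key whose hash lands at `start`,
    // in a table of power-of-two size `capacity` (mask = capacity - 1).
    def probeSequence(start: Int, capacity: Int, steps: Int): Seq[Int] = {
      val mask = capacity - 1
      val out = scala.collection.mutable.ArrayBuffer(start & mask)
      var pos = start & mask
      var i = 1
      while (out.length < steps) {
        pos = (pos + i) & mask   // the delta grows by one each probe, as in getPos/putInto
        out += pos
        i += 1
      }
      out
    }

    probeSequence(start = 5, capacity = 8, steps = 8)  // visits all 8 slots: 5, 6, 0, 3, 7, 4, 2, 1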
+ return pos + } else { + val delta = i + pos = (pos + delta) & mask + i += 1 + } + } + // Never reached here + assert(INVALID_POS != INVALID_POS) + INVALID_POS + } + + /** + * Double the table's size and re-hash everything. We are not really using k, but it is declared + * so Scala compiler can specialize this method (which leads to calling the specialized version + * of putInto). + * + * @param k A parameter unused in the function, but to force the Scala compiler to specialize + * this method. + * @param allocateFunc Callback invoked when we are allocating a new, larger array. + * @param moveFunc Callback invoked when we move the key from one position (in the old data array) + * to a new position (in the new data array). + */ + private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) { + val newCapacity = _capacity * 2 + require(newCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") + + allocateFunc(newCapacity) + val newData = new Array[T](newCapacity) + val newBitset = new BitSet(newCapacity) + var pos = 0 + _size = 0 + while (pos < _capacity) { + if (_bitset.get(pos)) { + val newPos = putInto(newBitset, newData, _data(pos)) + moveFunc(pos, newPos & POSITION_MASK) + } + pos += 1 + } + _bitset = newBitset + _data = newData + _capacity = newCapacity + _mask = newCapacity - 1 + } + + /** + * Re-hash a value to deal better with hash functions that don't differ + * in the lower bits, similar to java.util.HashMap + */ + private def hashcode(h: Int): Int = { + val r = h ^ (h >>> 20) ^ (h >>> 12) + r ^ (r >>> 7) ^ (r >>> 4) + } + + private def nextPowerOf2(n: Int): Int = { + val highBit = Integer.highestOneBit(n) + if (highBit == n) n else highBit << 1 + } +} + + +private[spark] +object OpenHashSet { + + val INVALID_POS = -1 + val NONEXISTENCE_MASK = 0x80000000 + val POSITION_MASK = 0xEFFFFFF + + /** + * A set of specialized hash function implementation to avoid boxing hash code computation + * in the specialized implementation of OpenHashSet. + */ + sealed class Hasher[@specialized(Long, Int) T] { + def hash(o: T): Int = o.hashCode() + } + + class LongHasher extends Hasher[Long] { + override def hash(o: Long): Int = (o ^ (o >>> 32)).toInt + } + + class IntHasher extends Hasher[Int] { + override def hash(o: Int): Int = o + } + + private def grow1(newSize: Int) {} + private def move1(oldPos: Int, newPos: Int) { } + + private val grow = grow1 _ + private val move = move1 _ +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala new file mode 100644 index 0000000000000000000000000000000000000000..d76143e45aa58117f8d64fdef20b4c9de28c01ee --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + + +/** + * A fast hash map implementation for primitive, non-null keys. This hash map supports + * insertions and updates, but not deletions. This map is about an order of magnitude + * faster than java.util.HashMap, while using much less space overhead. + * + * Under the hood, it uses our OpenHashSet implementation. + */ +private[spark] +class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest, + @specialized(Long, Int, Double) V: ClassManifest]( + initialCapacity: Int) + extends Iterable[(K, V)] + with Serializable { + + def this() = this(64) + + require(classManifest[K] == classManifest[Long] || classManifest[K] == classManifest[Int]) + + // Init in constructor (instead of in declaration) to work around a Scala compiler specialization + // bug that would generate two arrays (one for Object and one for specialized T). + protected var _keySet: OpenHashSet[K] = _ + private var _values: Array[V] = _ + _keySet = new OpenHashSet[K](initialCapacity) + _values = new Array[V](_keySet.capacity) + + private var _oldValues: Array[V] = null + + override def size = _keySet.size + + /** Get the value for a given key */ + def apply(k: K): V = { + val pos = _keySet.getPos(k) + _values(pos) + } + + /** Get the value for a given key, or returns elseValue if it doesn't exist. */ + def getOrElse(k: K, elseValue: V): V = { + val pos = _keySet.getPos(k) + if (pos >= 0) _values(pos) else elseValue + } + + /** Set the value for a key */ + def update(k: K, v: V) { + val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK + _values(pos) = v + _keySet.rehashIfNeeded(k, grow, move) + _oldValues = null + } + + /** + * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise, + * set its value to mergeValue(oldValue). + * + * @return the newly updated value. + */ + def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = { + val pos = _keySet.addWithoutResize(k) + if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) { + val newValue = defaultValue + _values(pos & OpenHashSet.POSITION_MASK) = newValue + _keySet.rehashIfNeeded(k, grow, move) + newValue + } else { + _values(pos) = mergeValue(_values(pos)) + _values(pos) + } + } + + override def iterator = new Iterator[(K, V)] { + var pos = 0 + var nextPair: (K, V) = computeNextPair() + + /** Get the next value we should return from next(), or null if we're finished iterating */ + def computeNextPair(): (K, V) = { + pos = _keySet.nextPos(pos) + if (pos >= 0) { + val ret = (_keySet.getValue(pos), _values(pos)) + pos += 1 + ret + } else { + null + } + } + + def hasNext = nextPair != null + + def next() = { + val pair = nextPair + nextPair = computeNextPair() + pair + } + } + + // The following member variables are declared as protected instead of private for the + // specialization to work (specialized class extends the unspecialized one and needs access + // to the "private" variables). + // They also should have been val's. 
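`changeValue` performs insert-or-update in a single probe: the default is evaluated only when the key is new, otherwise the merge function is applied to the stored value. A hypothetical usage sketch, assuming access to the map defined above:

    // Count occurrences of Long keys without a separate contains/get round trip.
    val counts = new PrimitiveKeyOpenHashMap[Long, Int]
    Seq(1L, 2L, 1L, 3L, 1L).foreach { k =>
      counts.changeValue(k, 1, _ + 1)   // default 1 on first sight, else increment
    }
    counts(1L)  // 3
    counts(2L)  // 1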
We use var's because there is a Scala compiler bug that + // would throw illegal access error at runtime if they are declared as val's. + protected var grow = (newCapacity: Int) => { + _oldValues = _values + _values = new Array[V](newCapacity) + } + + protected var move = (oldPos: Int, newPos: Int) => { + _values(newPos) = _oldValues(oldPos) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala new file mode 100644 index 0000000000000000000000000000000000000000..369519c5595de8b12a646bea9b26d1704d42f32d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +/** Provides a simple, non-threadsafe, array-backed vector that can store primitives. */ +private[spark] +class PrimitiveVector[@specialized(Long, Int, Double) V: ClassManifest](initialSize: Int = 64) { + private var numElements = 0 + private var array: Array[V] = _ + + // NB: This must be separate from the declaration, otherwise the specialized parent class + // will get its own array with the same initial size. TODO: Figure out why... + array = new Array[V](initialSize) + + def apply(index: Int): V = { + require(index < numElements) + array(index) + } + + def +=(value: V) { + if (numElements == array.length) { resize(array.length * 2) } + array(numElements) = value + numElements += 1 + } + + def length = numElements + + def getUnderlyingArray = array + + /** Resizes the array, dropping elements if the total length decreases. 
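`PrimitiveVector` doubles its backing array when full, so appends are amortized constant time; the shuffle code above keeps one such vector of offsets per reducer file. A brief usage sketch:

    val offsets = new PrimitiveVector[Long]()
    offsets += 0L
    offsets += 100L
    offsets += 250L
    offsets.length      // 3
    offsets(1)          // 100
    // The backing array may be larger than `length`; only the first `length`
    // entries are meaningful.
    offsets.getUnderlyingArray.length >= offsets.length  // true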
*/ + def resize(newLength: Int) { + val newArray = new Array[V](newLength) + array.copyToArray(newArray) + array = newArray + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala new file mode 100644 index 0000000000000000000000000000000000000000..8f0954122b322dbbe3a4504e53eb8fda06e40bb5 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -0,0 +1,19 @@ +package org.apache.spark.deploy.worker + +import java.io.File +import org.scalatest.FunSuite +import org.apache.spark.deploy.{ExecutorState, Command, ApplicationDescription} + +class ExecutorRunnerTest extends FunSuite { + test("command includes appId") { + def f(s:String) = new File(s) + val sparkHome = sys.env("SPARK_HOME") + val appDesc = new ApplicationDescription("app name", 8, 500, Command("foo", Seq(),Map()), + sparkHome, "appUiUrl") + val appId = "12345-worker321-9876" + val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome), + f("ooga"), ExecutorState.RUNNING) + + assert(er.buildCommandSeq().last === appId) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala index 8406093246dac26994aecf2ee56f558a8faa2b51..984881861c9a985a3dc92950c0f7759a37fa949a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala @@ -65,7 +65,7 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers val rootStageInfo = new StageInfo(rootStage) joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStageInfo, null)) - joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName) + joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getSimpleName) parentRdd.setName("MyRDD") joblogger.getRddNameTest(parentRdd) should be ("MyRDD") joblogger.createLogWriterTest(jobID) @@ -91,8 +91,10 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers sc.addSparkListener(joblogger) val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } rdd.reduceByKey(_+_).collect() + + val user = System.getProperty("user.name", SparkContext.SPARK_UNKNOWN_USER) - joblogger.getLogDir should be ("/tmp/spark") + joblogger.getLogDir should be ("/tmp/spark-%s".format(user)) joblogger.getJobIDtoPrintWriter.size should be (1) joblogger.getStageIDToJobID.size should be (2) joblogger.getStageIDToJobID.get(0) should be (Some(0)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index d9a1d6d087ece89c66b165311a589b0abbe6b57c..f3e592bf5cabf4ef6861eee6fa9ffb0169ec0f23 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -84,7 +84,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc i } - val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)} + val d = sc.parallelize(0 to 1e4.toInt, 64).map{i => w(i)} d.count() assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) listener.stageInfos.size should be (1) diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala 
b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..0b9056344c1dd6d4b1b77bf7b2afb2b22a64e84c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -0,0 +1,84 @@ +package org.apache.spark.storage + +import java.io.{FileWriter, File} + +import scala.collection.mutable + +import com.google.common.io.Files +import org.scalatest.{BeforeAndAfterEach, FunSuite} + +class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach { + + val rootDir0 = Files.createTempDir() + rootDir0.deleteOnExit() + val rootDir1 = Files.createTempDir() + rootDir1.deleteOnExit() + val rootDirs = rootDir0.getName + "," + rootDir1.getName + println("Created root dirs: " + rootDirs) + + val shuffleBlockManager = new ShuffleBlockManager(null) { + var idToSegmentMap = mutable.Map[ShuffleBlockId, FileSegment]() + override def getBlockLocation(id: ShuffleBlockId) = idToSegmentMap(id) + } + + var diskBlockManager: DiskBlockManager = _ + + override def beforeEach() { + diskBlockManager = new DiskBlockManager(shuffleBlockManager, rootDirs) + shuffleBlockManager.idToSegmentMap.clear() + } + + test("basic block creation") { + val blockId = new TestBlockId("test") + assertSegmentEquals(blockId, blockId.name, 0, 0) + + val newFile = diskBlockManager.getFile(blockId) + writeToFile(newFile, 10) + assertSegmentEquals(blockId, blockId.name, 0, 10) + + newFile.delete() + } + + test("block appending") { + val blockId = new TestBlockId("test") + val newFile = diskBlockManager.getFile(blockId) + writeToFile(newFile, 15) + assertSegmentEquals(blockId, blockId.name, 0, 15) + val newFile2 = diskBlockManager.getFile(blockId) + assert(newFile === newFile2) + writeToFile(newFile2, 12) + assertSegmentEquals(blockId, blockId.name, 0, 27) + newFile.delete() + } + + test("block remapping") { + val filename = "test" + val blockId0 = new ShuffleBlockId(1, 2, 3) + val newFile = diskBlockManager.getFile(filename) + writeToFile(newFile, 15) + shuffleBlockManager.idToSegmentMap(blockId0) = new FileSegment(newFile, 0, 15) + assertSegmentEquals(blockId0, filename, 0, 15) + + val blockId1 = new ShuffleBlockId(1, 2, 4) + val newFile2 = diskBlockManager.getFile(filename) + writeToFile(newFile2, 12) + shuffleBlockManager.idToSegmentMap(blockId1) = new FileSegment(newFile, 15, 12) + assertSegmentEquals(blockId1, filename, 15, 12) + + assert(newFile === newFile2) + newFile.delete() + } + + def assertSegmentEquals(blockId: BlockId, filename: String, offset: Int, length: Int) { + val segment = diskBlockManager.getBlockLocation(blockId) + assert(segment.file.getName === filename) + assert(segment.offset === offset) + assert(segment.length === length) + } + + def writeToFile(file: File, numBytes: Int) { + val writer = new FileWriter(file, true) + for (i <- 0 until numBytes) writer.write(i) + writer.close() + } +} diff --git a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..0f1ab3d20eea4456385f26df2c724595450e6234 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import org.scalatest.FunSuite + + +class BitSetSuite extends FunSuite { + + test("basic set and get") { + val setBits = Seq(0, 9, 1, 10, 90, 96) + val bitset = new BitSet(100) + + for (i <- 0 until 100) { + assert(!bitset.get(i)) + } + + setBits.foreach(i => bitset.set(i)) + + for (i <- 0 until 100) { + if (setBits.contains(i)) { + assert(bitset.get(i)) + } else { + assert(!bitset.get(i)) + } + } + assert(bitset.cardinality() === setBits.size) + } + + test("100% full bit set") { + val bitset = new BitSet(10000) + for (i <- 0 until 10000) { + assert(!bitset.get(i)) + bitset.set(i) + } + for (i <- 0 until 10000) { + assert(bitset.get(i)) + } + assert(bitset.cardinality() === 10000) + } + + test("nextSetBit") { + val setBits = Seq(0, 9, 1, 10, 90, 96) + val bitset = new BitSet(100) + setBits.foreach(i => bitset.set(i)) + + assert(bitset.nextSetBit(0) === 0) + assert(bitset.nextSetBit(1) === 1) + assert(bitset.nextSetBit(2) === 9) + assert(bitset.nextSetBit(9) === 9) + assert(bitset.nextSetBit(10) === 10) + assert(bitset.nextSetBit(11) === 90) + assert(bitset.nextSetBit(80) === 90) + assert(bitset.nextSetBit(91) === 96) + assert(bitset.nextSetBit(96) === 96) + assert(bitset.nextSetBit(97) === -1) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..ca3f684668d605e868d491fffb5b4f0bcc4a23a1 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -0,0 +1,148 @@ +package org.apache.spark.util.collection + +import scala.collection.mutable.HashSet +import org.scalatest.FunSuite + +class OpenHashMapSuite extends FunSuite { + + test("initialization") { + val goodMap1 = new OpenHashMap[String, Int](1) + assert(goodMap1.size === 0) + val goodMap2 = new OpenHashMap[String, Int](255) + assert(goodMap2.size === 0) + val goodMap3 = new OpenHashMap[String, String](256) + assert(goodMap3.size === 0) + intercept[IllegalArgumentException] { + new OpenHashMap[String, Int](1 << 30) // Invalid map size: bigger than 2^29 + } + intercept[IllegalArgumentException] { + new OpenHashMap[String, Int](-1) + } + intercept[IllegalArgumentException] { + new OpenHashMap[String, String](0) + } + } + + test("primitive value") { + val map = new OpenHashMap[String, Int] + + for (i <- 1 to 1000) { + map(i.toString) = i + assert(map(i.toString) === i) + } + + assert(map.size === 1000) + assert(map(null) === 0) + + map(null) = -1 + assert(map.size === 1001) + assert(map(null) === -1) + + for (i <- 1 to 1000) { + assert(map(i.toString) === i) + } + + // Test iterator + val set = new HashSet[(String, Int)] + for ((k, v) <- map) { + set.add((k, v)) + } + val expected = (1 to 1000).map(x => (x.toString, x)) :+ (null.asInstanceOf[String], -1) + assert(set === expected.toSet) + } + + 
test("non-primitive value") { + val map = new OpenHashMap[String, String] + + for (i <- 1 to 1000) { + map(i.toString) = i.toString + assert(map(i.toString) === i.toString) + } + + assert(map.size === 1000) + assert(map(null) === null) + + map(null) = "-1" + assert(map.size === 1001) + assert(map(null) === "-1") + + for (i <- 1 to 1000) { + assert(map(i.toString) === i.toString) + } + + // Test iterator + val set = new HashSet[(String, String)] + for ((k, v) <- map) { + set.add((k, v)) + } + val expected = (1 to 1000).map(_.toString).map(x => (x, x)) :+ (null.asInstanceOf[String], "-1") + assert(set === expected.toSet) + } + + test("null keys") { + val map = new OpenHashMap[String, String]() + for (i <- 1 to 100) { + map(i.toString) = i.toString + } + assert(map.size === 100) + assert(map(null) === null) + map(null) = "hello" + assert(map.size === 101) + assert(map(null) === "hello") + } + + test("null values") { + val map = new OpenHashMap[String, String]() + for (i <- 1 to 100) { + map(i.toString) = null + } + assert(map.size === 100) + assert(map("1") === null) + assert(map(null) === null) + assert(map.size === 100) + map(null) = null + assert(map.size === 101) + assert(map(null) === null) + } + + test("changeValue") { + val map = new OpenHashMap[String, String]() + for (i <- 1 to 100) { + map(i.toString) = i.toString + } + assert(map.size === 100) + for (i <- 1 to 100) { + val res = map.changeValue(i.toString, { assert(false); "" }, v => { + assert(v === i.toString) + v + "!" + }) + assert(res === i + "!") + } + // Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a + // bug where changeValue would return the wrong result when the map grew on that insert + for (i <- 101 to 400) { + val res = map.changeValue(i.toString, { i + "!" }, v => { assert(false); v }) + assert(res === i + "!") + } + assert(map.size === 400) + assert(map(null) === null) + map.changeValue(null, { "null!" }, v => { assert(false); v }) + assert(map.size === 401) + map.changeValue(null, { assert(false); "" }, v => { + assert(v === "null!") + "null!!" 
+ }) + assert(map.size === 401) + } + + test("inserting in capacity-1 map") { + val map = new OpenHashMap[String, String](1) + for (i <- 1 to 100) { + map(i.toString) = i.toString + } + assert(map.size === 100) + for (i <- 1 to 100) { + assert(map(i.toString) === i.toString) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..4e11e8a628b44e3dffa1b076263cfc3696eea438 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -0,0 +1,145 @@ +package org.apache.spark.util.collection + +import org.scalatest.FunSuite + + +class OpenHashSetSuite extends FunSuite { + + test("primitive int") { + val set = new OpenHashSet[Int] + assert(set.size === 0) + assert(!set.contains(10)) + assert(!set.contains(50)) + assert(!set.contains(999)) + assert(!set.contains(10000)) + + set.add(10) + assert(set.contains(10)) + assert(!set.contains(50)) + assert(!set.contains(999)) + assert(!set.contains(10000)) + + set.add(50) + assert(set.size === 2) + assert(set.contains(10)) + assert(set.contains(50)) + assert(!set.contains(999)) + assert(!set.contains(10000)) + + set.add(999) + assert(set.size === 3) + assert(set.contains(10)) + assert(set.contains(50)) + assert(set.contains(999)) + assert(!set.contains(10000)) + + set.add(50) + assert(set.size === 3) + assert(set.contains(10)) + assert(set.contains(50)) + assert(set.contains(999)) + assert(!set.contains(10000)) + } + + test("primitive long") { + val set = new OpenHashSet[Long] + assert(set.size === 0) + assert(!set.contains(10L)) + assert(!set.contains(50L)) + assert(!set.contains(999L)) + assert(!set.contains(10000L)) + + set.add(10L) + assert(set.size === 1) + assert(set.contains(10L)) + assert(!set.contains(50L)) + assert(!set.contains(999L)) + assert(!set.contains(10000L)) + + set.add(50L) + assert(set.size === 2) + assert(set.contains(10L)) + assert(set.contains(50L)) + assert(!set.contains(999L)) + assert(!set.contains(10000L)) + + set.add(999L) + assert(set.size === 3) + assert(set.contains(10L)) + assert(set.contains(50L)) + assert(set.contains(999L)) + assert(!set.contains(10000L)) + + set.add(50L) + assert(set.size === 3) + assert(set.contains(10L)) + assert(set.contains(50L)) + assert(set.contains(999L)) + assert(!set.contains(10000L)) + } + + test("non-primitive") { + val set = new OpenHashSet[String] + assert(set.size === 0) + assert(!set.contains(10.toString)) + assert(!set.contains(50.toString)) + assert(!set.contains(999.toString)) + assert(!set.contains(10000.toString)) + + set.add(10.toString) + assert(set.size === 1) + assert(set.contains(10.toString)) + assert(!set.contains(50.toString)) + assert(!set.contains(999.toString)) + assert(!set.contains(10000.toString)) + + set.add(50.toString) + assert(set.size === 2) + assert(set.contains(10.toString)) + assert(set.contains(50.toString)) + assert(!set.contains(999.toString)) + assert(!set.contains(10000.toString)) + + set.add(999.toString) + assert(set.size === 3) + assert(set.contains(10.toString)) + assert(set.contains(50.toString)) + assert(set.contains(999.toString)) + assert(!set.contains(10000.toString)) + + set.add(50.toString) + assert(set.size === 3) + assert(set.contains(10.toString)) + assert(set.contains(50.toString)) + assert(set.contains(999.toString)) + assert(!set.contains(10000.toString)) + } + + test("non-primitive set growth") { + val set = new 
OpenHashSet[String] + for (i <- 1 to 1000) { + set.add(i.toString) + } + assert(set.size === 1000) + assert(set.capacity > 1000) + for (i <- 1 to 100) { + set.add(i.toString) + } + assert(set.size === 1000) + assert(set.capacity > 1000) + } + + test("primitive set growth") { + val set = new OpenHashSet[Long] + for (i <- 1 to 1000) { + set.add(i.toLong) + } + assert(set.size === 1000) + assert(set.capacity > 1000) + for (i <- 1 to 100) { + set.add(i.toLong) + } + assert(set.size === 1000) + assert(set.capacity > 1000) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..dfd6aed2c4bccf7f1d9a25690ce0c6be41097678 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala @@ -0,0 +1,90 @@ +package org.apache.spark.util.collection + +import scala.collection.mutable.HashSet +import org.scalatest.FunSuite + +class PrimitiveKeyOpenHashSetSuite extends FunSuite { + + test("initialization") { + val goodMap1 = new PrimitiveKeyOpenHashMap[Int, Int](1) + assert(goodMap1.size === 0) + val goodMap2 = new PrimitiveKeyOpenHashMap[Int, Int](255) + assert(goodMap2.size === 0) + val goodMap3 = new PrimitiveKeyOpenHashMap[Int, Int](256) + assert(goodMap3.size === 0) + intercept[IllegalArgumentException] { + new PrimitiveKeyOpenHashMap[Int, Int](1 << 30) // Invalid map size: bigger than 2^29 + } + intercept[IllegalArgumentException] { + new PrimitiveKeyOpenHashMap[Int, Int](-1) + } + intercept[IllegalArgumentException] { + new PrimitiveKeyOpenHashMap[Int, Int](0) + } + } + + test("basic operations") { + val longBase = 1000000L + val map = new PrimitiveKeyOpenHashMap[Long, Int] + + for (i <- 1 to 1000) { + map(i + longBase) = i + assert(map(i + longBase) === i) + } + + assert(map.size === 1000) + + for (i <- 1 to 1000) { + assert(map(i + longBase) === i) + } + + // Test iterator + val set = new HashSet[(Long, Int)] + for ((k, v) <- map) { + set.add((k, v)) + } + assert(set === (1 to 1000).map(x => (x + longBase, x)).toSet) + } + + test("null values") { + val map = new PrimitiveKeyOpenHashMap[Long, String]() + for (i <- 1 to 100) { + map(i.toLong) = null + } + assert(map.size === 100) + assert(map(1.toLong) === null) + } + + test("changeValue") { + val map = new PrimitiveKeyOpenHashMap[Long, String]() + for (i <- 1 to 100) { + map(i.toLong) = i.toString + } + assert(map.size === 100) + for (i <- 1 to 100) { + val res = map.changeValue(i.toLong, { assert(false); "" }, v => { + assert(v === i.toString) + v + "!" + }) + assert(res === i + "!") + } + // Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a + // bug where changeValue would return the wrong result when the map grew on that insert + for (i <- 101 to 400) { + val res = map.changeValue(i.toLong, { i + "!" 
}, v => { assert(false); v }) + assert(res === i + "!") + } + assert(map.size === 400) + } + + test("inserting in capacity-1 map") { + val map = new PrimitiveKeyOpenHashMap[Long, String](1) + for (i <- 1 to 100) { + map(i.toLong) = i.toString + } + assert(map.size === 100) + for (i <- 1 to 100) { + assert(map(i.toLong) === i.toString) + } + } +} diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index f679cad713769d45abfc9f10936fa39c3d724d8b..5927f736f3579bd7a2a505999257d2ea9abbac01 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -13,7 +13,7 @@ object in your main program (called the _driver program_). Specifically, to run on a cluster, the SparkContext can connect to several types of _cluster managers_ (either Spark's own standalone cluster manager or Mesos/YARN), which allocate resources across applications. Once connected, Spark acquires *executors* on nodes in the cluster, which are -worker processes that run computations and store data for your application. +worker processes that run computations and store data for your application. Next, it sends your application code (defined by JAR or Python files passed to SparkContext) to the executors. Finally, SparkContext sends *tasks* for the executors to run. @@ -57,6 +57,18 @@ which takes a list of JAR files (Java/Scala) or .egg and .zip libraries (Python) worker nodes. You can also dynamically add new files to be sent to executors with `SparkContext.addJar` and `addFile`. +## URIs for addJar / addFile + +- **file:** - Absolute paths and `file:/` URIs are served by the driver's HTTP file server, and every executor + pulls the file from the driver HTTP server +- **hdfs:**, **http:**, **https:**, **ftp:** - these pull down files and JARs from the URI as expected +- **local:** - a URI starting with local:/ is expected to exist as a local file on each worker node. This + means that no network IO will be incurred, and works well for large files/JARs that are pushed to each worker, + or shared via NFS, GlusterFS, etc. + +Note that JARs and files are copied to the working directory for each SparkContext on the executor nodes. +Over time this can use up a significant amount of space and will need to be cleaned up. + # Monitoring Each driver program has a web UI, typically on port 4040, that displays information about running diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md index 1e5575d6570eaae02471ae055eb87e4fa57f9365..156a727026790b9e45232fdf1c44c0906921efe3 100644 --- a/docs/ec2-scripts.md +++ b/docs/ec2-scripts.md @@ -98,7 +98,7 @@ permissions on your private key file, you can run `launch` with the `bin/hadoop` script in that directory. Note that the data in this HDFS goes away when you stop and restart a machine. - There is also a *persistent HDFS* instance in - `/root/presistent-hdfs` that will keep data across cluster restarts. + `/root/persistent-hdfs` that will keep data across cluster restarts. Typically each node has relatively little space of persistent data (about 3 GB), but you can use the `--ebs-vol-size` option to `spark-ec2` to attach a persistent EBS volume to each node for diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 2898af0bed8c0d6d6de9be6a2c10472994e2a222..6fd1d0d150306668f34de2376852884a442b0b97 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -21,6 +21,7 @@ The assembled JAR will be something like this: # Preparations - Building a YARN-enabled assembly (see above). +- The assembled jar can be installed into HDFS or used locally. 
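The URI handling documented in the `cluster-overview.md` addition above can be exercised directly from a driver program. A brief sketch with illustrative paths (none of these files are part of this patch):

```scala
import org.apache.spark.SparkContext

val sc = new SparkContext("spark://master:7077", "uri-demo")

// Plain path / file:/ URI -- served by the driver's HTTP file server,
// downloaded by each executor on first use.
sc.addFile("/opt/data/lookup.txt")

// hdfs:, http:, https:, ftp: -- each executor fetches the file from that URI itself.
sc.addJar("hdfs:///user/me/libs/helper.jar")

// local:/ -- assumed to already exist at this path on every worker node; no network IO.
sc.addJar("local:/opt/spark-libs/native-wrappers.jar")
```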
- Your application code must be packaged into a separate JAR file. If you want to test out the YARN deployment mode, you can use the current Spark examples. A `spark-examples_{{site.SCALA_VERSION}}-{{site.SPARK_VERSION}}` file can be generated by running `sbt/sbt assembly`. NOTE: since the documentation you're reading is for Spark version {{site.SPARK_VERSION}}, we are assuming here that you have downloaded Spark {{site.SPARK_VERSION}} or checked it out of source control. If you are using a different version of Spark, the version numbers in the jar generated by the sbt package command will obviously be different. diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 65868b76b91fcfdd895ba7a14bbf2895407471e4..11892324286a5dbc070dad4bad7be8322d3c17db 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -72,12 +72,12 @@ def parse_args(): parser.add_option("-a", "--ami", help="Amazon Machine Image ID to use") parser.add_option("-v", "--spark-version", default="0.8.0", help="Version of Spark to use: 'X.Y.Z' or a specific git hash") - parser.add_option("--spark-git-repo", - default="https://github.com/mesos/spark", + parser.add_option("--spark-git-repo", + default="https://github.com/apache/incubator-spark", help="Github repo from which to checkout supplied commit hash") parser.add_option("--hadoop-major-version", default="1", help="Major version of Hadoop (default: 1)") - parser.add_option("-D", metavar="[ADDRESS:]PORT", dest="proxy_port", + parser.add_option("-D", metavar="[ADDRESS:]PORT", dest="proxy_port", help="Use SSH dynamic port forwarding to create a SOCKS proxy at " + "the given local address (for use with login)") parser.add_option("--resume", action="store_true", default=False, @@ -101,6 +101,8 @@ def parse_args(): help="The SSH user you want to connect as (default: root)") parser.add_option("--delete-groups", action="store_true", default=False, help="When destroying a cluster, delete the security groups that were created") + parser.add_option("--use-existing-master", action="store_true", default=False, + help="Launch fresh slaves, but use an existing stopped master if possible") (opts, args) = parser.parse_args() if len(args) != 2: @@ -191,7 +193,7 @@ def get_spark_ami(opts): instance_type = "pvm" print >> stderr,\ "Don't recognize %s, assuming type is pvm" % opts.instance_type - + ami_path = "%s/%s/%s" % (AMI_PREFIX, opts.region, instance_type) try: ami = urllib2.urlopen(ami_path).read().strip() @@ -215,6 +217,7 @@ def launch_cluster(conn, opts, cluster_name): master_group.authorize(src_group=slave_group) master_group.authorize('tcp', 22, 22, '0.0.0.0/0') master_group.authorize('tcp', 8080, 8081, '0.0.0.0/0') + master_group.authorize('tcp', 19999, 19999, '0.0.0.0/0') master_group.authorize('tcp', 50030, 50030, '0.0.0.0/0') master_group.authorize('tcp', 50070, 50070, '0.0.0.0/0') master_group.authorize('tcp', 60070, 60070, '0.0.0.0/0') @@ -232,9 +235,9 @@ def launch_cluster(conn, opts, cluster_name): slave_group.authorize('tcp', 60075, 60075, '0.0.0.0/0') # Check if instances are already running in our groups - active_nodes = get_existing_cluster(conn, opts, cluster_name, - die_on_error=False) - if any(active_nodes): + existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name, + die_on_error=False) + if existing_slaves or (existing_masters and not opts.use_existing_master): print >> stderr, ("ERROR: There are already instances running in " + "group %s or %s" % (master_group.name, slave_group.name)) sys.exit(1) @@ -335,21 +338,28 @@ def launch_cluster(conn, 
opts, cluster_name): zone, slave_res.id) i += 1 - # Launch masters - master_type = opts.master_instance_type - if master_type == "": - master_type = opts.instance_type - if opts.zone == 'all': - opts.zone = random.choice(conn.get_all_zones()).name - master_res = image.run(key_name = opts.key_pair, - security_groups = [master_group], - instance_type = master_type, - placement = opts.zone, - min_count = 1, - max_count = 1, - block_device_map = block_map) - master_nodes = master_res.instances - print "Launched master in %s, regid = %s" % (zone, master_res.id) + # Launch or resume masters + if existing_masters: + print "Starting master..." + for inst in existing_masters: + if inst.state not in ["shutting-down", "terminated"]: + inst.start() + master_nodes = existing_masters + else: + master_type = opts.master_instance_type + if master_type == "": + master_type = opts.instance_type + if opts.zone == 'all': + opts.zone = random.choice(conn.get_all_zones()).name + master_res = image.run(key_name = opts.key_pair, + security_groups = [master_group], + instance_type = master_type, + placement = opts.zone, + min_count = 1, + max_count = 1, + block_device_map = block_map) + master_nodes = master_res.instances + print "Launched master in %s, regid = %s" % (zone, master_res.id) # Return all the instances return (master_nodes, slave_nodes) @@ -403,8 +413,8 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): print slave.public_dns_name ssh_write(slave.public_dns_name, opts, ['tar', 'x'], dot_ssh_tar) - modules = ['spark', 'shark', 'ephemeral-hdfs', 'persistent-hdfs', - 'mapreduce', 'spark-standalone'] + modules = ['spark', 'shark', 'ephemeral-hdfs', 'persistent-hdfs', + 'mapreduce', 'spark-standalone', 'tachyon'] if opts.hadoop_major_version == "1": modules = filter(lambda x: x != "mapreduce", modules) @@ -668,12 +678,12 @@ def real_main(): print "Terminating slaves..." for inst in slave_nodes: inst.terminate() - + # Delete security groups as well if opts.delete_groups: print "Deleting security groups (this will take some time)..." group_names = [cluster_name + "-master", cluster_name + "-slaves"] - + attempt = 1; while attempt <= 3: print "Attempt %d" % attempt @@ -731,6 +741,7 @@ def real_main(): cluster_name + "?\nDATA ON EPHEMERAL DISKS WILL BE LOST, " + "BUT THE CLUSTER WILL KEEP USING SPACE ON\n" + "AMAZON EBS IF IT IS EBS-BACKED!!\n" + + "All data on spot-instance slaves will be lost.\n" + "Stop cluster " + cluster_name + " (y/N): ") if response == "y": (master_nodes, slave_nodes) = get_existing_cluster( @@ -742,7 +753,10 @@ def real_main(): print "Stopping slaves..." 
for inst in slave_nodes: if inst.state not in ["shutting-down", "terminated"]: - inst.stop() + if inst.spot_instance_request_id: + inst.terminate() + else: + inst.stop() elif action == "start": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) diff --git a/pom.xml b/pom.xml index 53ac82efd0247ebf45accfdc465d3fd0377c9901..edcc3b35cda084a91b1eb13a59472c8f54c67e22 100644 --- a/pom.xml +++ b/pom.xml @@ -385,6 +385,12 @@ <version>3.1</version> <scope>test</scope> </dependency> + <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-all</artifactId> + <version>1.8.5</version> + <scope>test</scope> + </dependency> <dependency> <groupId>org.scalacheck</groupId> <artifactId>scalacheck_2.9.3</artifactId> diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 45fd30a7c836408813b473df32db3197508c3be0..96232718f8293ff13892cf502ecf5840bc748f55 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -170,7 +170,8 @@ object SparkBuild extends Build { "org.scalatest" %% "scalatest" % "1.9.1" % "test", "org.scalacheck" %% "scalacheck" % "1.10.0" % "test", "com.novocode" % "junit-interface" % "0.9" % "test", - "org.easymock" % "easymock" % "3.1" % "test" + "org.easymock" % "easymock" % "3.1" % "test", + "org.mockito" % "mockito-all" % "1.8.5" % "test" ), /* Workaround for issue #206 (fixed after SBT 0.11.0) */ watchTransitiveSources <<= Defaults.inDependencies[Task[Seq[File]]](watchSources.task, @@ -265,7 +266,7 @@ object SparkBuild extends Build { def toolsSettings = sharedSettings ++ Seq( name := "spark-tools" - ) + ) ++ assemblySettings ++ extraAssemblySettings def bagelSettings = sharedSettings ++ Seq( name := "spark-bagel" diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala index e6e35c9b5df9e85c8d2778cedd23b23e2abe134b..870e12de341dd13159ffc3e9df9934d17bf12648 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -878,14 +878,21 @@ class SparkIMain(val settings: Settings, protected val out: PrintWriter) extends (message, false) } } + + // Get a copy of the local properties from SparkContext, and set it later in the thread + // that triggers the execution. This is to make sure the caller of this function can pass + // the right thread local (inheritable) properties down into Spark. 
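A hedged sketch of the behaviour this comment is working around: local properties live in an (inheritable) thread-local, so a worker thread created before the caller set a property -- such as the interpreter's line-execution thread -- does not see it unless the properties are copied over explicitly, which is what the change below does.

```scala
import java.util.concurrent.Executors
import org.apache.spark.SparkContext

val sc = new SparkContext("local", "local-props-demo")

val pool = Executors.newSingleThreadExecutor()
pool.submit(new Runnable { def run() {} }).get()  // force the worker thread to be created now

sc.setLocalProperty("someKey", "someValue")       // set only in the calling thread

pool.submit(new Runnable {
  def run() {
    // The pooled thread predates the property, so nothing was inherited: prints null.
    println(sc.getLocalProperty("someKey"))
  }
}).get()

pool.shutdown()
sc.stop()
```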
+ val sc = org.apache.spark.repl.Main.interp.sparkContext + val props = if (sc != null) sc.getLocalProperties() else null try { val execution = lineManager.set(originalLine) { // MATEI: set the right SparkEnv for our SparkContext, because // this execution will happen in a separate thread - val sc = org.apache.spark.repl.Main.interp.sparkContext - if (sc != null && sc.env != null) + if (sc != null && sc.env != null) { SparkEnv.set(sc.env) + sc.setLocalProperties(props) + } // Execute the line lineRep call "$export" } diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 8f9b632c0eea67ace92595a389cbe6755902b161..6e4504d4d5f41a09e06547e14067295b54250d85 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -21,12 +21,14 @@ import java.io._ import java.net.URLClassLoader import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ -import org.scalatest.FunSuite import com.google.common.io.Files +import org.scalatest.FunSuite +import org.apache.spark.SparkContext + class ReplSuite extends FunSuite { + def runInterpreter(master: String, input: String): String = { val in = new BufferedReader(new StringReader(input + "\n")) val out = new StringWriter() @@ -64,6 +66,35 @@ class ReplSuite extends FunSuite { "Interpreter output contained '" + message + "':\n" + output) } + test("propagation of local properties") { + // A mock ILoop that doesn't install the SIGINT handler. + class ILoop(out: PrintWriter) extends SparkILoop(None, out, None) { + settings = new scala.tools.nsc.Settings + settings.usejavacp.value = true + org.apache.spark.repl.Main.interp = this + override def createInterpreter() { + intp = new SparkILoopInterpreter + intp.setContextClassLoader() + } + } + + val out = new StringWriter() + val interp = new ILoop(new PrintWriter(out)) + interp.sparkContext = new SparkContext("local", "repl-test") + interp.createInterpreter() + interp.intp.initialize() + interp.sparkContext.setLocalProperty("someKey", "someValue") + + // Make sure the value we set in the caller to interpret is propagated in the thread that + // interprets the command. + interp.interpret("org.apache.spark.repl.Main.interp.sparkContext.getLocalProperty(\"someKey\")") + assert(out.toString.contains("someValue")) + + interp.sparkContext.stop() + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + } + test ("simple foreach with accumulator") { val output = runInterpreter("local", """ val accum = sc.accumulator(0) diff --git a/spark-class b/spark-class index fb9d1a4f8eaaf5e29632969a9682dd162d3f88d2..bbeca7f245692977adf97656d26ea05c17276812 100755 --- a/spark-class +++ b/spark-class @@ -110,8 +110,21 @@ if [ ! -f "$FWDIR/RELEASE" ]; then fi fi +TOOLS_DIR="$FWDIR"/tools +SPARK_TOOLS_JAR="" +if [ -e "$TOOLS_DIR"/target/scala-$SCALA_VERSION/*assembly*[0-9Tg].jar ]; then + # Use the JAR from the SBT build + export SPARK_TOOLS_JAR=`ls "$TOOLS_DIR"/target/scala-$SCALA_VERSION/*assembly*[0-9Tg].jar` +fi +if [ -e "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar ]; then + # Use the JAR from the Maven build + # TODO: this also needs to become an assembly! 
+ export SPARK_TOOLS_JAR=`ls "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar` +fi + # Compute classpath using external script CLASSPATH=`$FWDIR/bin/compute-classpath.sh` +CLASSPATH="$SPARK_TOOLS_JAR:$CLASSPATH" export CLASSPATH if [ "$SPARK_PRINT_LAUNCH_COMMAND" == "1" ]; then diff --git a/spark-class2.cmd b/spark-class2.cmd index d4d853e8ad930931e7ddf2295c6c1448846a29e0..3869d0761bfaa8e7ba0e3688b1ec23f8e8a56d87 100644 --- a/spark-class2.cmd +++ b/spark-class2.cmd @@ -65,10 +65,17 @@ if "%FOUND_JAR%"=="0" ( ) :skip_build_test +set TOOLS_DIR=%FWDIR%tools +set SPARK_TOOLS_JAR= +for %%d in ("%TOOLS_DIR%\target\scala-%SCALA_VERSION%\spark-tools*assembly*.jar") do ( + set SPARK_TOOLS_JAR=%%d +) + rem Compute classpath using external script set DONT_PRINT_CLASSPATH=1 call "%FWDIR%bin\compute-classpath.cmd" set DONT_PRINT_CLASSPATH=0 +set CLASSPATH=%SPARK_TOOLS_JAR%;%CLASSPATH% rem Figure out where java is. set RUNNER=java diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala index 8d3ac0fc65ad5999a41c1637d7f8ed919855e00f..a82862c8029b2879c907216666f78e74d298894e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala @@ -232,11 +232,11 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log logInfo("Data handler stopped") } - def += (obj: T) { + def += (obj: T): Unit = synchronized { currentBuffer += obj } - private def updateCurrentBuffer(time: Long) { + private def updateCurrentBuffer(time: Long): Unit = synchronized { try { val newBlockBuffer = currentBuffer currentBuffer = new ArrayBuffer[T] diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index c29b75ece69f05211be5dc977b0f2e860282cca0..a559db468a771826e818e0aaccc0bf5b5e9b5913 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -23,15 +23,15 @@ import akka.actor.IOManager import akka.actor.Props import akka.util.ByteString -import dstream.SparkFlumeEvent +import org.apache.spark.streaming.dstream.{NetworkReceiver, SparkFlumeEvent} import java.net.{InetSocketAddress, SocketException, Socket, ServerSocket} import java.io.{File, BufferedWriter, OutputStreamWriter} -import java.util.concurrent.{TimeUnit, ArrayBlockingQueue} +import java.util.concurrent.{Executors, TimeUnit, ArrayBlockingQueue} import collection.mutable.{SynchronizedBuffer, ArrayBuffer} import util.ManualClock import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.receivers.Receiver -import org.apache.spark.Logging +import org.apache.spark.{SparkContext, Logging} import scala.util.Random import org.apache.commons.io.FileUtils import org.scalatest.BeforeAndAfter @@ -44,6 +44,7 @@ import java.nio.ByteBuffer import collection.JavaConversions._ import java.nio.charset.Charset import com.google.common.io.Files +import java.util.concurrent.atomic.AtomicInteger class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { @@ -61,7 +62,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { System.clearProperty("spark.hostPort") } - test("socket input stream") { // Start the server val testServer = new TestServer() @@ 
-275,10 +275,49 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { kafka.serializer.StringDecoder, kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK) } + + test("multi-thread receiver") { + // set up the test receiver + val numThreads = 10 + val numRecordsPerThread = 1000 + val numTotalRecords = numThreads * numRecordsPerThread + val testReceiver = new MultiThreadTestReceiver(numThreads, numRecordsPerThread) + MultiThreadTestReceiver.haveAllThreadsFinished = false + + // set up the network stream using the test receiver + val ssc = new StreamingContext(master, framework, batchDuration) + val networkStream = ssc.networkStream[Int](testReceiver) + val countStream = networkStream.count + val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]] + val outputStream = new TestOutputStream(countStream, outputBuffer) + def output = outputBuffer.flatMap(x => x) + ssc.registerOutputStream(outputStream) + ssc.start() + + // Let the data from the receiver be received + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val startTime = System.currentTimeMillis() + while((!MultiThreadTestReceiver.haveAllThreadsFinished || output.sum < numTotalRecords) && + System.currentTimeMillis() - startTime < 5000) { + Thread.sleep(100) + clock.addToTime(batchDuration.milliseconds) + } + Thread.sleep(1000) + logInfo("Stopping context") + ssc.stop() + + // Verify whether data received was as expected + logInfo("--------------------------------") + logInfo("output.size = " + outputBuffer.size) + logInfo("output") + outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("--------------------------------") + assert(output.sum === numTotalRecords) + } } -/** This is server to test the network input stream */ +/** This is a server to test the network input stream */ class TestServer() extends Logging { val queue = new ArrayBlockingQueue[String](100) @@ -340,6 +379,7 @@ object TestServer { } } +/** This is an actor for testing actor input stream */ class TestActor(port: Int) extends Actor with Receiver { def bytesToString(byteString: ByteString) = byteString.utf8String @@ -351,3 +391,36 @@ class TestActor(port: Int) extends Actor with Receiver { pushBlock(bytesToString(bytes)) } } + +/** This is a receiver to test multiple threads inserting data using block generator */ +class MultiThreadTestReceiver(numThreads: Int, numRecordsPerThread: Int) + extends NetworkReceiver[Int] { + lazy val executorPool = Executors.newFixedThreadPool(numThreads) + lazy val blockGenerator = new BlockGenerator(StorageLevel.MEMORY_ONLY) + lazy val finishCount = new AtomicInteger(0) + + protected def onStart() { + blockGenerator.start() + (1 to numThreads).map(threadId => { + val runnable = new Runnable { + def run() { + (1 to numRecordsPerThread).foreach(i => + blockGenerator += (threadId * numRecordsPerThread + i) ) + if (finishCount.incrementAndGet == numThreads) { + MultiThreadTestReceiver.haveAllThreadsFinished = true + } + logInfo("Finished thread " + threadId) + } + } + executorPool.submit(runnable) + }) + } + + protected def onStop() { + executorPool.shutdown() + } +} + +object MultiThreadTestReceiver { + var haveAllThreadsFinished = false +} diff --git a/yarn/pom.xml b/yarn/pom.xml index 3bc619df07bb9d8988c5dca6ace832adfda79dee..8a065c6d7d1d7f458789f3ab50f0054c8a947dfb 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -61,6 +61,16 @@ <groupId>org.apache.avro</groupId> <artifactId>avro-ipc</artifactId> </dependency> + <dependency> + 
<groupId>org.scalatest</groupId> + <artifactId>scalatest_2.9.3</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-all</artifactId> + <scope>test</scope> + </dependency> </dependencies> <build> @@ -106,6 +116,46 @@ </execution> </executions> </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-antrun-plugin</artifactId> + <executions> + <execution> + <phase>test</phase> + <goals> + <goal>run</goal> + </goals> + <configuration> + <exportAntProperties>true</exportAntProperties> + <tasks> + <property name="spark.classpath" refid="maven.test.classpath" /> + <property environment="env" /> + <fail message="Please set the SCALA_HOME (or SCALA_LIBRARY_PATH if scala is on the path) environment variables and retry."> + <condition> + <not> + <or> + <isset property="env.SCALA_HOME" /> + <isset property="env.SCALA_LIBRARY_PATH" /> + </or> + </not> + </condition> + </fail> + </tasks> + </configuration> + </execution> + </executions> + </plugin> + <plugin> + <groupId>org.scalatest</groupId> + <artifactId>scalatest-maven-plugin</artifactId> + <configuration> + <environmentVariables> + <SPARK_HOME>${basedir}/..</SPARK_HOME> + <SPARK_TESTING>1</SPARK_TESTING> + <SPARK_CLASSPATH>${spark.classpath}</SPARK_CLASSPATH> + </environmentVariables> + </configuration> + </plugin> </plugins> </build> </project> diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index c1a87d33738bede69e4d85cfac13075351898ece..4302ef4cda2619c446608c8cdf47b45d66ecf563 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -349,7 +349,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e try { val preserveFiles = System.getProperty("spark.yarn.preserve.staging.files", "false").toBoolean if (!preserveFiles) { - stagingDirPath = new Path(System.getenv("SPARK_YARN_JAR_PATH")).getParent() + stagingDirPath = new Path(System.getenv("SPARK_YARN_STAGING_DIR")) if (stagingDirPath == null) { logError("Staging directory is null") return diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 1a380ae714534bc06146d62c0921a404bc64ce5e..4e0e060ddc29b982947f9dc0669e0c1f742a6e80 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -17,26 +17,31 @@ package org.apache.spark.deploy.yarn -import java.net.{InetSocketAddress, URI} +import java.net.{InetAddress, InetSocketAddress, UnknownHostException, URI} import java.nio.ByteBuffer + import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{FileContext, FileStatus, FileSystem, Path, FileUtil} +import org.apache.hadoop.fs.permission.FsPermission; import org.apache.hadoop.mapred.Master import org.apache.hadoop.net.NetUtils import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.client.YarnClientImpl import 
org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC +import org.apache.hadoop.yarn.util.{Apps, Records} + import scala.collection.mutable.HashMap +import scala.collection.mutable.Map import scala.collection.JavaConversions._ + import org.apache.spark.Logging import org.apache.spark.util.Utils -import org.apache.hadoop.yarn.util.{Apps, Records, ConverterUtils} -import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.spark.deploy.SparkHadoopUtil class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl with Logging { @@ -46,13 +51,14 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl var rpc: YarnRPC = YarnRPC.create(conf) val yarnConf: YarnConfiguration = new YarnConfiguration(conf) val credentials = UserGroupInformation.getCurrentUser().getCredentials() - private var distFiles = None: Option[String] - private var distFilesTimeStamps = None: Option[String] - private var distFilesFileSizes = None: Option[String] - private var distArchives = None: Option[String] - private var distArchivesTimeStamps = None: Option[String] - private var distArchivesFileSizes = None: Option[String] - + private val SPARK_STAGING: String = ".sparkStaging" + private val distCacheMgr = new ClientDistributedCacheManager() + + // staging directory is private! -> rwx-------- + val STAGING_DIR_PERMISSION: FsPermission = FsPermission.createImmutable(0700:Short) + // app files are world-wide readable and owner writable -> rw-r--r-- + val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644:Short) + def run() { init(yarnConf) start() @@ -63,8 +69,9 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl verifyClusterResources(newApp) val appContext = createApplicationSubmissionContext(appId) - val localResources = prepareLocalResources(appId, ".sparkStaging") - val env = setupLaunchEnv(localResources) + val appStagingDir = getAppStagingDir(appId) + val localResources = prepareLocalResources(appStagingDir) + val env = setupLaunchEnv(localResources, appStagingDir) val amContainer = createContainerLaunchContext(newApp, localResources, env) appContext.setQueue(args.amQueue) @@ -76,7 +83,10 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl monitorApplication(appId) System.exit(0) } - + + def getAppStagingDir(appId: ApplicationId): String = { + SPARK_STAGING + Path.SEPARATOR + appId.toString() + Path.SEPARATOR + } def logClusterResourceDetails() { val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics @@ -116,73 +126,73 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl return appContext } + /* + * see if two file systems are the same or not. 
+ */ + private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = { + val srcUri = srcFs.getUri() + val dstUri = destFs.getUri() + if (srcUri.getScheme() == null) { + return false + } + if (!srcUri.getScheme().equals(dstUri.getScheme())) { + return false + } + var srcHost = srcUri.getHost() + var dstHost = dstUri.getHost() + if ((srcHost != null) && (dstHost != null)) { + try { + srcHost = InetAddress.getByName(srcHost).getCanonicalHostName(); + dstHost = InetAddress.getByName(dstHost).getCanonicalHostName(); + } catch { + case e: UnknownHostException => + return false + } + if (!srcHost.equals(dstHost)) { + return false + } + } else if (srcHost == null && dstHost != null) { + return false + } else if (srcHost != null && dstHost == null) { + return false + } + //check for ports + if (srcUri.getPort() != dstUri.getPort()) { + return false + } + return true; + } + /** - * Copy the local file into HDFS and configure to be distributed with the - * job via the distributed cache. - * If a fragment is specified the file will be referenced as that fragment. + * Copy the file into HDFS if needed. */ - private def copyLocalFile( + private def copyRemoteFile( dstDir: Path, - resourceType: LocalResourceType, originalPath: Path, replication: Short, - localResources: HashMap[String,LocalResource], - fragment: String, - appMasterOnly: Boolean = false): Unit = { + setPerms: Boolean = false): Path = { val fs = FileSystem.get(conf) - val newPath = new Path(dstDir, originalPath.getName()) - logInfo("Uploading " + originalPath + " to " + newPath) - fs.copyFromLocalFile(false, true, originalPath, newPath) - fs.setReplication(newPath, replication); - val destStatus = fs.getFileStatus(newPath) - - val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] - amJarRsrc.setType(resourceType) - amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION) - amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(newPath)) - amJarRsrc.setTimestamp(destStatus.getModificationTime()) - amJarRsrc.setSize(destStatus.getLen()) - var pathURI: URI = new URI(newPath.toString() + "#" + originalPath.getName()); - if ((fragment == null) || (fragment.isEmpty())){ - localResources(originalPath.getName()) = amJarRsrc - } else { - localResources(fragment) = amJarRsrc - pathURI = new URI(newPath.toString() + "#" + fragment); - } - val distPath = pathURI.toString() - if (appMasterOnly == true) return - if (resourceType == LocalResourceType.FILE) { - distFiles match { - case Some(path) => - distFilesFileSizes = Some(distFilesFileSizes.get + "," + - destStatus.getLen().toString()) - distFilesTimeStamps = Some(distFilesTimeStamps.get + "," + - destStatus.getModificationTime().toString()) - distFiles = Some(path + "," + distPath) - case _ => - distFilesFileSizes = Some(destStatus.getLen().toString()) - distFilesTimeStamps = Some(destStatus.getModificationTime().toString()) - distFiles = Some(distPath) - } - } else { - distArchives match { - case Some(path) => - distArchivesTimeStamps = Some(distArchivesTimeStamps.get + "," + - destStatus.getModificationTime().toString()) - distArchivesFileSizes = Some(distArchivesFileSizes.get + "," + - destStatus.getLen().toString()) - distArchives = Some(path + "," + distPath) - case _ => - distArchivesTimeStamps = Some(destStatus.getModificationTime().toString()) - distArchivesFileSizes = Some(destStatus.getLen().toString()) - distArchives = Some(distPath) - } - } + val remoteFs = originalPath.getFileSystem(conf); + var newPath = originalPath + if (! 
compareFs(remoteFs, fs)) { + newPath = new Path(dstDir, originalPath.getName()) + logInfo("Uploading " + originalPath + " to " + newPath) + FileUtil.copy(remoteFs, originalPath, fs, newPath, false, conf); + fs.setReplication(newPath, replication); + if (setPerms) fs.setPermission(newPath, new FsPermission(APP_FILE_PERMISSION)) + } + // resolve any symlinks in the URI path so using a "current" symlink + // to point to a specific version shows the specific version + // in the distributed cache configuration + val qualPath = fs.makeQualified(newPath) + val fc = FileContext.getFileContext(qualPath.toUri(), conf) + val destPath = fc.resolvePath(qualPath) + destPath } - def prepareLocalResources(appId: ApplicationId, sparkStagingDir: String): HashMap[String, LocalResource] = { + def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = { logInfo("Preparing Local resources") - // Upload Spark and the application JAR to the remote file system + // Upload Spark and the application JAR to the remote file system if necessary // Add them as local resources to the AM val fs = FileSystem.get(conf) @@ -193,9 +203,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl System.exit(1) } } - - val pathSuffix = sparkStagingDir + "/" + appId.toString() + "/" - val dst = new Path(fs.getHomeDirectory(), pathSuffix) + val dst = new Path(fs.getHomeDirectory(), appStagingDir) val replication = System.getProperty("spark.yarn.submit.file.replication", "3").toShort if (UserGroupInformation.isSecurityEnabled()) { @@ -203,55 +211,65 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl dstFs.addDelegationTokens(delegTokenRenewer, credentials); } val localResources = HashMap[String, LocalResource]() + FileSystem.mkdirs(fs, dst, new FsPermission(STAGING_DIR_PERMISSION)) + + val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() + + if (System.getenv("SPARK_JAR") == null || args.userJar == null) { + logError("Error: You must set SPARK_JAR environment variable and specify a user jar!") + System.exit(1) + } - Map("spark.jar" -> System.getenv("SPARK_JAR"), "app.jar" -> args.userJar, "log4j.properties" -> System.getenv("SPARK_LOG4J_CONF")) + Map(Client.SPARK_JAR -> System.getenv("SPARK_JAR"), Client.APP_JAR -> args.userJar, + Client.LOG4J_PROP -> System.getenv("SPARK_LOG4J_CONF")) .foreach { case(destName, _localPath) => val localPath: String = if (_localPath != null) _localPath.trim() else "" if (! 
localPath.isEmpty()) { - val src = new Path(localPath) - val newPath = new Path(dst, destName) - logInfo("Uploading " + src + " to " + newPath) - fs.copyFromLocalFile(false, true, src, newPath) - fs.setReplication(newPath, replication); - val destStatus = fs.getFileStatus(newPath) - - val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] - amJarRsrc.setType(LocalResourceType.FILE) - amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION) - amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(newPath)) - amJarRsrc.setTimestamp(destStatus.getModificationTime()) - amJarRsrc.setSize(destStatus.getLen()) - localResources(destName) = amJarRsrc + var localURI = new URI(localPath) + // if not specified assume these are in the local filesystem to keep behavior like Hadoop + if (localURI.getScheme() == null) { + localURI = new URI(FileSystem.getLocal(conf).makeQualified(new Path(localPath)).toString()) + } + val setPermissions = if (destName.equals(Client.APP_JAR)) true else false + val destPath = copyRemoteFile(dst, new Path(localURI), replication, setPermissions) + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, + destName, statCache) } } // handle any add jars if ((args.addJars != null) && (!args.addJars.isEmpty())){ args.addJars.split(',').foreach { case file: String => - val tmpURI = new URI(file) - val tmp = new Path(tmpURI) - copyLocalFile(dst, LocalResourceType.FILE, tmp, replication, localResources, - tmpURI.getFragment(), true) + val localURI = new URI(file.trim()) + val localPath = new Path(localURI) + val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) + val destPath = copyRemoteFile(dst, localPath, replication) + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, + linkname, statCache, true) } } // handle any distributed cache files if ((args.files != null) && (!args.files.isEmpty())){ args.files.split(',').foreach { case file: String => - val tmpURI = new URI(file) - val tmp = new Path(tmpURI) - copyLocalFile(dst, LocalResourceType.FILE, tmp, replication, localResources, - tmpURI.getFragment()) + val localURI = new URI(file.trim()) + val localPath = new Path(localURI) + val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) + val destPath = copyRemoteFile(dst, localPath, replication) + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, + linkname, statCache) } } // handle any distributed cache archives if ((args.archives != null) && (!args.archives.isEmpty())) { args.archives.split(',').foreach { case file:String => - val tmpURI = new URI(file) - val tmp = new Path(tmpURI) - copyLocalFile(dst, LocalResourceType.ARCHIVE, tmp, replication, - localResources, tmpURI.getFragment()) + val localURI = new URI(file.trim()) + val localPath = new Path(localURI) + val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) + val destPath = copyRemoteFile(dst, localPath, replication) + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, + linkname, statCache) } } @@ -259,44 +277,21 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl return localResources } - def setupLaunchEnv(localResources: HashMap[String, LocalResource]): HashMap[String, String] = { + def setupLaunchEnv( + localResources: HashMap[String, LocalResource], + stagingDir: String): HashMap[String, String] = { logInfo("Setting up the launch environment") - val 
log4jConfLocalRes = localResources.getOrElse("log4j.properties", null) + val log4jConfLocalRes = localResources.getOrElse(Client.LOG4J_PROP, null) val env = new HashMap[String, String]() Client.populateClasspath(yarnConf, log4jConfLocalRes != null, env) env("SPARK_YARN_MODE") = "true" - env("SPARK_YARN_JAR_PATH") = - localResources("spark.jar").getResource().getScheme.toString() + "://" + - localResources("spark.jar").getResource().getFile().toString() - env("SPARK_YARN_JAR_TIMESTAMP") = localResources("spark.jar").getTimestamp().toString() - env("SPARK_YARN_JAR_SIZE") = localResources("spark.jar").getSize().toString() - - env("SPARK_YARN_USERJAR_PATH") = - localResources("app.jar").getResource().getScheme.toString() + "://" + - localResources("app.jar").getResource().getFile().toString() - env("SPARK_YARN_USERJAR_TIMESTAMP") = localResources("app.jar").getTimestamp().toString() - env("SPARK_YARN_USERJAR_SIZE") = localResources("app.jar").getSize().toString() - - if (log4jConfLocalRes != null) { - env("SPARK_YARN_LOG4J_PATH") = - log4jConfLocalRes.getResource().getScheme.toString() + "://" + log4jConfLocalRes.getResource().getFile().toString() - env("SPARK_YARN_LOG4J_TIMESTAMP") = log4jConfLocalRes.getTimestamp().toString() - env("SPARK_YARN_LOG4J_SIZE") = log4jConfLocalRes.getSize().toString() - } + env("SPARK_YARN_STAGING_DIR") = stagingDir // set the environment variables to be passed on to the Workers - if (distFiles != None) { - env("SPARK_YARN_CACHE_FILES") = distFiles.get - env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = distFilesTimeStamps.get - env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = distFilesFileSizes.get - } - if (distArchives != None) { - env("SPARK_YARN_CACHE_ARCHIVES") = distArchives.get - env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = distArchivesTimeStamps.get - env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") = distArchivesFileSizes.get - } + distCacheMgr.setDistFilesEnv(env) + distCacheMgr.setDistArchivesEnv(env) // allow users to specify some environment variables Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV")) @@ -365,6 +360,11 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl javaCommand = Environment.JAVA_HOME.$() + "/bin/java" } + if (args.userClass == null) { + logError("Error: You must specify a user class!") + System.exit(1) + } + val commands = List[String](javaCommand + " -server " + JAVA_OPTS + @@ -432,6 +432,10 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl } object Client { + val SPARK_JAR: String = "spark.jar" + val APP_JAR: String = "app.jar" + val LOG4J_PROP: String = "log4j.properties" + def main(argStrings: Array[String]) { // Set an env variable indicating we are running in YARN mode. 
// Note that anything with SPARK prefix gets propagated to all (remote) processes @@ -453,22 +457,22 @@ object Client { // If log4j present, ensure ours overrides all others if (addLog4j) { Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + - Path.SEPARATOR + "log4j.properties") + Path.SEPARATOR + LOG4J_PROP) } // normally the users app.jar is last in case conflicts with spark jars val userClasspathFirst = System.getProperty("spark.yarn.user.classpath.first", "false") .toBoolean if (userClasspathFirst) { Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + - Path.SEPARATOR + "app.jar") + Path.SEPARATOR + APP_JAR) } Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + - Path.SEPARATOR + "spark.jar") + Path.SEPARATOR + SPARK_JAR) Client.populateHadoopClasspath(conf, env) if (!userClasspathFirst) { Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + - Path.SEPARATOR + "app.jar") + Path.SEPARATOR + APP_JAR) } Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + Path.SEPARATOR + "*") diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala new file mode 100644 index 0000000000000000000000000000000000000000..07686fefd7c067bb7753de6eebfce6937a214981 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.net.URI; + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.permission.FsAction +import org.apache.hadoop.yarn.api.records.LocalResource +import org.apache.hadoop.yarn.api.records.LocalResourceVisibility +import org.apache.hadoop.yarn.api.records.LocalResourceType +import org.apache.hadoop.yarn.util.{Records, ConverterUtils} + +import org.apache.spark.Logging + +import scala.collection.mutable.HashMap +import scala.collection.mutable.LinkedHashMap +import scala.collection.mutable.Map + + +/** Client side methods to setup the Hadoop distributed cache */ +class ClientDistributedCacheManager() extends Logging { + private val distCacheFiles: Map[String, Tuple3[String, String, String]] = + LinkedHashMap[String, Tuple3[String, String, String]]() + private val distCacheArchives: Map[String, Tuple3[String, String, String]] = + LinkedHashMap[String, Tuple3[String, String, String]]() + + + /** + * Add a resource to the list of distributed cache resources. 
This list can + * be sent to the ApplicationMaster and possibly the workers so that it can + * be downloaded into the Hadoop distributed cache for use by this application. + * Adds the LocalResource to the localResources HashMap passed in and saves + * the stats of the resources to they can be sent to the workers and verified. + * + * @param fs FileSystem + * @param conf Configuration + * @param destPath path to the resource + * @param localResources localResource hashMap to insert the resource into + * @param resourceType LocalResourceType + * @param link link presented in the distributed cache to the destination + * @param statCache cache to store the file/directory stats + * @param appMasterOnly Whether to only add the resource to the app master + */ + def addResource( + fs: FileSystem, + conf: Configuration, + destPath: Path, + localResources: HashMap[String, LocalResource], + resourceType: LocalResourceType, + link: String, + statCache: Map[URI, FileStatus], + appMasterOnly: Boolean = false) = { + val destStatus = fs.getFileStatus(destPath) + val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] + amJarRsrc.setType(resourceType) + val visibility = getVisibility(conf, destPath.toUri(), statCache) + amJarRsrc.setVisibility(visibility) + amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(destPath)) + amJarRsrc.setTimestamp(destStatus.getModificationTime()) + amJarRsrc.setSize(destStatus.getLen()) + if (link == null || link.isEmpty()) throw new Exception("You must specify a valid link name") + localResources(link) = amJarRsrc + + if (appMasterOnly == false) { + val uri = destPath.toUri() + val pathURI = new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), null, link) + if (resourceType == LocalResourceType.FILE) { + distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(), + destStatus.getModificationTime().toString(), visibility.name()) + } else { + distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(), + destStatus.getModificationTime().toString(), visibility.name()) + } + } + } + + /** + * Adds the necessary cache file env variables to the env passed in + * @param env + */ + def setDistFilesEnv(env: Map[String, String]) = { + val (keys, tupleValues) = distCacheFiles.unzip + val (sizes, timeStamps, visibilities) = tupleValues.unzip3 + + if (keys.size > 0) { + env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = + timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = + sizes.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_VISIBILITIES") = + visibilities.reduceLeft[String] { (acc,n) => acc + "," + n } + } + } + + /** + * Adds the necessary cache archive env variables to the env passed in + * @param env + */ + def setDistArchivesEnv(env: Map[String, String]) = { + val (keys, tupleValues) = distCacheArchives.unzip + val (sizes, timeStamps, visibilities) = tupleValues.unzip3 + + if (keys.size > 0) { + env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = + timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") = + sizes.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") = + visibilities.reduceLeft[String] { (acc,n) => acc + "," + n } + } + } + + /** + * Returns the local resource 
visibility depending on the cache file permissions + * @param conf + * @param uri + * @param statCache + * @return LocalResourceVisibility + */ + def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): + LocalResourceVisibility = { + if (isPublic(conf, uri, statCache)) { + return LocalResourceVisibility.PUBLIC + } + return LocalResourceVisibility.PRIVATE + } + + /** + * Returns a boolean to denote whether a cache file is visible to all(public) + * or not + * @param conf + * @param uri + * @param statCache + * @return true if the path in the uri is visible to all, false otherwise + */ + def isPublic(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): Boolean = { + val fs = FileSystem.get(uri, conf) + val current = new Path(uri.getPath()) + //the leaf level file should be readable by others + if (!checkPermissionOfOther(fs, current, FsAction.READ, statCache)) { + return false + } + return ancestorsHaveExecutePermissions(fs, current.getParent(), statCache) + } + + /** + * Returns true if all ancestors of the specified path have the 'execute' + * permission set for all users (i.e. that other users can traverse + * the directory heirarchy to the given path) + * @param fs + * @param path + * @param statCache + * @return true if all ancestors have the 'execute' permission set for all users + */ + def ancestorsHaveExecutePermissions(fs: FileSystem, path: Path, + statCache: Map[URI, FileStatus]): Boolean = { + var current = path + while (current != null) { + //the subdirs in the path should have execute permissions for others + if (!checkPermissionOfOther(fs, current, FsAction.EXECUTE, statCache)) { + return false + } + current = current.getParent() + } + return true + } + + /** + * Checks for a given path whether the Other permissions on it + * imply the permission in the passed FsAction + * @param fs + * @param path + * @param action + * @param statCache + * @return true if the path in the uri is visible to all, false otherwise + */ + def checkPermissionOfOther(fs: FileSystem, path: Path, + action: FsAction, statCache: Map[URI, FileStatus]): Boolean = { + val status = getFileStatus(fs, path.toUri(), statCache); + val perms = status.getPermission() + val otherAction = perms.getOtherAction() + if (otherAction.implies(action)) { + return true; + } + return false + } + + /** + * Checks to see if the given uri exists in the cache, if it does it + * returns the existing FileStatus, otherwise it stats the uri, stores + * it in the cache, and returns the FileStatus. 
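For reference, the permission test that `isPublic` and `checkPermissionOfOther` rely on can be sketched on its own with Hadoop's `FsPermission`/`FsAction` types (the permission values below are illustrative, not taken from this patch):

```scala
import org.apache.hadoop.fs.permission.{FsAction, FsPermission}

// rw-r--r--: the leaf file is readable by "other".
val appFile = new FsPermission(FsAction.READ_WRITE, FsAction.READ, FsAction.READ)
appFile.getOtherAction.implies(FsAction.READ)        // true

// rwxr-x--x: "other" can traverse this directory.
val openDir = new FsPermission(FsAction.ALL, FsAction.READ_EXECUTE, FsAction.EXECUTE)
openDir.getOtherAction.implies(FsAction.EXECUTE)     // true

// rwx------: a single non-traversable ancestor is enough to make the
// whole resource PRIVATE rather than PUBLIC.
val stagingDir = new FsPermission(FsAction.ALL, FsAction.NONE, FsAction.NONE)
stagingDir.getOtherAction.implies(FsAction.EXECUTE)  // false
```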
+ * @param fs + * @param uri + * @param statCache + * @return FileStatus + */ + def getFileStatus(fs: FileSystem, uri: URI, statCache: Map[URI, FileStatus]): FileStatus = { + val stat = statCache.get(uri) match { + case Some(existstat) => existstat + case None => + val newStat = fs.getFileStatus(new Path(uri)) + statCache.put(uri, newStat) + newStat + } + return stat + } +} diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala index ba352daac485d0381a2df87f391b8435a315bc61..7a66532254c74f2393095da68f123727ce5e22f7 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala @@ -142,11 +142,12 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S rtype: LocalResourceType, localResources: HashMap[String, LocalResource], timestamp: String, - size: String) = { + size: String, + vis: String) = { val uri = new URI(file) val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] amJarRsrc.setType(rtype) - amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION) + amJarRsrc.setVisibility(LocalResourceVisibility.valueOf(vis)) amJarRsrc.setResource(ConverterUtils.getYarnUrlFromURI(uri)) amJarRsrc.setTimestamp(timestamp.toLong) amJarRsrc.setSize(size.toLong) @@ -158,44 +159,14 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S logInfo("Preparing Local resources") val localResources = HashMap[String, LocalResource]() - // Spark JAR - val sparkJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] - sparkJarResource.setType(LocalResourceType.FILE) - sparkJarResource.setVisibility(LocalResourceVisibility.APPLICATION) - sparkJarResource.setResource(ConverterUtils.getYarnUrlFromURI( - new URI(System.getenv("SPARK_YARN_JAR_PATH")))) - sparkJarResource.setTimestamp(System.getenv("SPARK_YARN_JAR_TIMESTAMP").toLong) - sparkJarResource.setSize(System.getenv("SPARK_YARN_JAR_SIZE").toLong) - localResources("spark.jar") = sparkJarResource - // User JAR - val userJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] - userJarResource.setType(LocalResourceType.FILE) - userJarResource.setVisibility(LocalResourceVisibility.APPLICATION) - userJarResource.setResource(ConverterUtils.getYarnUrlFromURI( - new URI(System.getenv("SPARK_YARN_USERJAR_PATH")))) - userJarResource.setTimestamp(System.getenv("SPARK_YARN_USERJAR_TIMESTAMP").toLong) - userJarResource.setSize(System.getenv("SPARK_YARN_USERJAR_SIZE").toLong) - localResources("app.jar") = userJarResource - - // Log4j conf - if available - if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) { - val log4jConfResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] - log4jConfResource.setType(LocalResourceType.FILE) - log4jConfResource.setVisibility(LocalResourceVisibility.APPLICATION) - log4jConfResource.setResource(ConverterUtils.getYarnUrlFromURI( - new URI(System.getenv("SPARK_YARN_LOG4J_PATH")))) - log4jConfResource.setTimestamp(System.getenv("SPARK_YARN_LOG4J_TIMESTAMP").toLong) - log4jConfResource.setSize(System.getenv("SPARK_YARN_LOG4J_SIZE").toLong) - localResources("log4j.properties") = log4jConfResource - } - if (System.getenv("SPARK_YARN_CACHE_FILES") != null) { val timeStamps = System.getenv("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',') val fileSizes = 
System.getenv("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',') val distFiles = System.getenv("SPARK_YARN_CACHE_FILES").split(',') + val visibilities = System.getenv("SPARK_YARN_CACHE_FILES_VISIBILITIES").split(',') for( i <- 0 to distFiles.length - 1) { setupDistributedCache(distFiles(i), LocalResourceType.FILE, localResources, timeStamps(i), - fileSizes(i)) + fileSizes(i), visibilities(i)) } } @@ -203,9 +174,10 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S val timeStamps = System.getenv("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS").split(',') val fileSizes = System.getenv("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES").split(',') val distArchives = System.getenv("SPARK_YARN_CACHE_ARCHIVES").split(',') + val visibilities = System.getenv("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES").split(',') for( i <- 0 to distArchives.length - 1) { setupDistributedCache(distArchives(i), LocalResourceType.ARCHIVE, localResources, - timeStamps(i), fileSizes(i)) + timeStamps(i), fileSizes(i), visibilities(i)) } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..c0a2af0c6faf35cc88f3f770de4e9fb3fe87749c --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.yarn + +import java.net.URI; + +import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar +import org.mockito.Mockito.when + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.permission.FsAction +import org.apache.hadoop.yarn.api.records.LocalResource +import org.apache.hadoop.yarn.api.records.LocalResourceVisibility +import org.apache.hadoop.yarn.api.records.LocalResourceType +import org.apache.hadoop.yarn.util.{Records, ConverterUtils} + +import scala.collection.mutable.HashMap +import scala.collection.mutable.Map + + +class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { + + class MockClientDistributedCacheManager extends ClientDistributedCacheManager { + override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): + LocalResourceVisibility = { + return LocalResourceVisibility.PRIVATE + } + } + + test("test getFileStatus empty") { + val distMgr = new ClientDistributedCacheManager() + val fs = mock[FileSystem] + val uri = new URI("/tmp/testing") + when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus()) + val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() + val stat = distMgr.getFileStatus(fs, uri, statCache) + assert(stat.getPath() === null) + } + + test("test getFileStatus cached") { + val distMgr = new ClientDistributedCacheManager() + val fs = mock[FileSystem] + val uri = new URI("/tmp/testing") + val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner", + null, new Path("/tmp/testing")) + when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus()) + val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus](uri -> realFileStatus) + val stat = distMgr.getFileStatus(fs, uri, statCache) + assert(stat.getPath().toString() === "/tmp/testing") + } + + test("test addResource") { + val distMgr = new MockClientDistributedCacheManager() + val fs = mock[FileSystem] + val conf = new Configuration() + val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") + val localResources = HashMap[String, LocalResource]() + val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() + when(fs.getFileStatus(destPath)).thenReturn(new FileStatus()) + + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link", + statCache, false) + val resource = localResources("link") + assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) + assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath) + assert(resource.getTimestamp() === 0) + assert(resource.getSize() === 0) + assert(resource.getType() === LocalResourceType.FILE) + + val env = new HashMap[String, String]() + distMgr.setDistFilesEnv(env) + assert(env("SPARK_YARN_CACHE_FILES") === "file:/foo.invalid.com:8080/tmp/testing#link") + assert(env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === "0") + assert(env("SPARK_YARN_CACHE_FILES_FILE_SIZES") === "0") + assert(env("SPARK_YARN_CACHE_FILES_VISIBILITIES") === LocalResourceVisibility.PRIVATE.name()) + + distMgr.setDistArchivesEnv(env) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES") === None) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === None) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === None) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None) + + //add another one and verify both there and order 
correct + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + null, new Path("/tmp/testing2")) + val destPath2 = new Path("file:///foo.invalid.com:8080/tmp/testing2") + when(fs.getFileStatus(destPath2)).thenReturn(realFileStatus) + distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2", + statCache, false) + val resource2 = localResources("link2") + assert(resource2.getVisibility() === LocalResourceVisibility.PRIVATE) + assert(ConverterUtils.getPathFromYarnURL(resource2.getResource()) === destPath2) + assert(resource2.getTimestamp() === 10) + assert(resource2.getSize() === 20) + assert(resource2.getType() === LocalResourceType.FILE) + + val env2 = new HashMap[String, String]() + distMgr.setDistFilesEnv(env2) + val timestamps = env2("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',') + val files = env2("SPARK_YARN_CACHE_FILES").split(',') + val sizes = env2("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',') + val visibilities = env2("SPARK_YARN_CACHE_FILES_VISIBILITIES") .split(',') + assert(files(0) === "file:/foo.invalid.com:8080/tmp/testing#link") + assert(timestamps(0) === "0") + assert(sizes(0) === "0") + assert(visibilities(0) === LocalResourceVisibility.PRIVATE.name()) + + assert(files(1) === "file:/foo.invalid.com:8080/tmp/testing2#link2") + assert(timestamps(1) === "10") + assert(sizes(1) === "20") + assert(visibilities(1) === LocalResourceVisibility.PRIVATE.name()) + } + + test("test addResource link null") { + val distMgr = new MockClientDistributedCacheManager() + val fs = mock[FileSystem] + val conf = new Configuration() + val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") + val localResources = HashMap[String, LocalResource]() + val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() + when(fs.getFileStatus(destPath)).thenReturn(new FileStatus()) + + intercept[Exception] { + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null, + statCache, false) + } + assert(localResources.get("link") === None) + assert(localResources.size === 0) + } + + test("test addResource appmaster only") { + val distMgr = new MockClientDistributedCacheManager() + val fs = mock[FileSystem] + val conf = new Configuration() + val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") + val localResources = HashMap[String, LocalResource]() + val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + null, new Path("/tmp/testing")) + when(fs.getFileStatus(destPath)).thenReturn(realFileStatus) + + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", + statCache, true) + val resource = localResources("link") + assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) + assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath) + assert(resource.getTimestamp() === 10) + assert(resource.getSize() === 20) + assert(resource.getType() === LocalResourceType.ARCHIVE) + + val env = new HashMap[String, String]() + distMgr.setDistFilesEnv(env) + assert(env.get("SPARK_YARN_CACHE_FILES") === None) + assert(env.get("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === None) + assert(env.get("SPARK_YARN_CACHE_FILES_FILE_SIZES") === None) + assert(env.get("SPARK_YARN_CACHE_FILES_VISIBILITIES") === None) + + distMgr.setDistArchivesEnv(env) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES") === None) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") 
=== None) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === None) + assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None) + } + + test("test addResource archive") { + val distMgr = new MockClientDistributedCacheManager() + val fs = mock[FileSystem] + val conf = new Configuration() + val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") + val localResources = HashMap[String, LocalResource]() + val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + null, new Path("/tmp/testing")) + when(fs.getFileStatus(destPath)).thenReturn(realFileStatus) + + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", + statCache, false) + val resource = localResources("link") + assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) + assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath) + assert(resource.getTimestamp() === 10) + assert(resource.getSize() === 20) + assert(resource.getType() === LocalResourceType.ARCHIVE) + + val env = new HashMap[String, String]() + + distMgr.setDistArchivesEnv(env) + assert(env("SPARK_YARN_CACHE_ARCHIVES") === "file:/foo.invalid.com:8080/tmp/testing#link") + assert(env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === "10") + assert(env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === "20") + assert(env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === LocalResourceVisibility.PRIVATE.name()) + + distMgr.setDistFilesEnv(env) + assert(env.get("SPARK_YARN_CACHE_FILES") === None) + assert(env.get("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === None) + assert(env.get("SPARK_YARN_CACHE_FILES_FILE_SIZES") === None) + assert(env.get("SPARK_YARN_CACHE_FILES_VISIBILITIES") === None) + } + + +}
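
A minimal, self-contained sketch (not part of the patch itself) of the visibility rule the new ClientDistributedCacheManager methods encode: a cache file is only marked PUBLIC when the file is readable by "other" and every ancestor directory is executable by "other"; anything else degrades to PRIVATE. The stat cache is deliberately omitted here, and the names VisibilitySketch, visibilityOf, otherImplies, and ancestorsTraversable are illustrative only; the Hadoop and YARN calls are the same ones the patch uses.

    import java.net.URI

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.{FileSystem, Path}
    import org.apache.hadoop.fs.permission.FsAction
    import org.apache.hadoop.yarn.api.records.LocalResourceVisibility

    object VisibilitySketch {

      // True if the "other" permission bits on `path` imply `action`.
      private def otherImplies(fs: FileSystem, path: Path, action: FsAction): Boolean =
        fs.getFileStatus(path).getPermission().getOtherAction().implies(action)

      // True if every ancestor directory up to the root is traversable (o+x) by other users.
      private def ancestorsTraversable(fs: FileSystem, dir: Path): Boolean = {
        var current = dir
        while (current != null) {
          if (!otherImplies(fs, current, FsAction.EXECUTE)) {
            return false
          }
          current = current.getParent()
        }
        true
      }

      // PUBLIC only when the file is world-readable and all of its ancestors are
      // world-executable; any other combination falls back to PRIVATE.
      def visibilityOf(conf: Configuration, uri: URI): LocalResourceVisibility = {
        val fs = FileSystem.get(uri, conf)
        val leaf = new Path(uri.getPath())
        if (otherImplies(fs, leaf, FsAction.READ) && ancestorsTraversable(fs, leaf.getParent())) {
          LocalResourceVisibility.PUBLIC
        } else {
          LocalResourceVisibility.PRIVATE
        }
      }
    }

For example, visibilityOf(new Configuration(), new URI("hdfs:///user/foo/app.jar")) yields PUBLIC only when /, /user, and /user/foo all carry o+x and app.jar carries o+r; the resulting visibility string is what Client passes through SPARK_YARN_CACHE_FILES_VISIBILITIES and WorkerRunnable feeds into LocalResourceVisibility.valueOf().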