diff --git a/bagel/pom.xml b/bagel/pom.xml index be2e3580917d873e802dbc48051d2689591c0120..b83a0ef6c0f375e9faab544b44e4731786ff9194 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -102,5 +102,42 @@ </plugins> </build> </profile> + <profile> + <id>hadoop2-yarn</id> + <dependencies> + <dependency> + <groupId>org.spark-project</groupId> + <artifactId>spark-core</artifactId> + <version>${project.version}</version> + <classifier>hadoop2-yarn</classifier> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-client</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-api</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-common</artifactId> + <scope>provided</scope> + </dependency> + </dependencies> + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <configuration> + <classifier>hadoop2-yarn</classifier> + </configuration> + </plugin> + </plugins> + </build> + </profile> </profiles> </project> diff --git a/core/pom.xml b/core/pom.xml index 08717860a77356a42d4660928efeea911d3beee9..da26d674ec02e782316bd302d8f0718242605c31 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -279,5 +279,72 @@ </plugins> </build> </profile> + <profile> + <id>hadoop2-yarn</id> + <dependencies> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-client</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-api</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-common</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-client</artifactId> + <scope>provided</scope> + </dependency> + </dependencies> + <build> + <plugins> + <plugin> + <groupId>org.codehaus.mojo</groupId> + <artifactId>build-helper-maven-plugin</artifactId> + <executions> + <execution> + <id>add-source</id> + <phase>generate-sources</phase> + <goals> + <goal>add-source</goal> + </goals> + <configuration> + <sources> + <source>src/main/scala</source> + <source>src/hadoop2-yarn/scala</source> + </sources> + </configuration> + </execution> + <execution> + <id>add-scala-test-sources</id> + <phase>generate-test-sources</phase> + <goals> + <goal>add-test-source</goal> + </goals> + <configuration> + <sources> + <source>src/test/scala</source> + </sources> + </configuration> + </execution> + </executions> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <configuration> + <classifier>hadoop2-yarn</classifier> + </configuration> + </plugin> + </plugins> + </build> + </profile> </profiles> </project> diff --git a/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala index ca9f7219de7399218634d1833377e3e8719b7886..f286f2cf9c5122935fe39e77861f44304566ac79 100644 --- a/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala +++ b/core/src/hadoop1/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala @@ -4,4 +4,7 @@ trait HadoopMapRedUtil { def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContext(conf, jobId) def newTaskAttemptContext(conf: JobConf, 
attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContext(conf, attemptId) + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier, + jobId, isMap, taskId, attemptId) } diff --git a/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala index de7b0f81e38016cc03eafad530afa9d38c35135e..264d421d14a18964117ec8970b05fb8ffc501f03 100644 --- a/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala +++ b/core/src/hadoop1/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala @@ -6,4 +6,7 @@ trait HadoopMapReduceUtil { def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContext(conf, jobId) def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContext(conf, attemptId) + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier, + jobId, isMap, taskId, attemptId) } diff --git a/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala new file mode 100644 index 0000000000000000000000000000000000000000..a0fb4fe25d188b02f1ac62245eaaa0de86b5b46c --- /dev/null +++ b/core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala @@ -0,0 +1,23 @@ +package spark.deploy +import org.apache.hadoop.conf.Configuration + + +/** + * Contains util methods to interact with Hadoop from spark. + */ +object SparkHadoopUtil { + + def getUserNameFromEnvironment(): String = { + // defaulting to -D ... + System.getProperty("user.name") + } + + def runAsUser(func: (Product) => Unit, args: Product) { + + // Add support, if exists - for now, simply run func ! + func(args) + } + + // Return an appropriate (subclass) of Configuration. 
Creating a config can initialize some Hadoop subsystems
+  def newConfiguration(): Configuration = new Configuration()
+}
diff --git a/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
new file mode 100644
index 0000000000000000000000000000000000000000..875c0a220bd36a911ffb5e12a3c99658276df41f
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala
@@ -0,0 +1,13 @@
+
+package org.apache.hadoop.mapred
+
+import org.apache.hadoop.mapreduce.TaskType
+
+trait HadoopMapRedUtil {
+  def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)
+
+  def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+
+  def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) =
+    new TaskAttemptID(jtIdentifier, jobId, if (isMap) TaskType.MAP else TaskType.REDUCE, taskId, attemptId)
+}
diff --git a/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
new file mode 100644
index 0000000000000000000000000000000000000000..8bc6fb6dea18687a2df0dfbb032c9ba0f600d349
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala
@@ -0,0 +1,13 @@
+package org.apache.hadoop.mapreduce
+
+import org.apache.hadoop.conf.Configuration
+import task.{TaskAttemptContextImpl, JobContextImpl}
+
+trait HadoopMapReduceUtil {
+  def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)
+
+  def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+
+  def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) =
+    new TaskAttemptID(jtIdentifier, jobId, if (isMap) TaskType.MAP else TaskType.REDUCE, taskId, attemptId)
+}
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala
new file mode 100644
index 0000000000000000000000000000000000000000..ab1ab9d8a7a8e8530938f6a1ddc4c860cffb1af0
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/SparkHadoopUtil.scala
@@ -0,0 +1,63 @@
+package spark.deploy
+
+import collection.mutable.HashMap
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import java.security.PrivilegedExceptionAction
+
+/**
+ * Contains util methods to interact with Hadoop from Spark.
+ */
+object SparkHadoopUtil {
+
+  val yarnConf = newConfiguration()
+
+  def getUserNameFromEnvironment(): String = {
+    // defaulting to env if -D is not present ...
+    val retval = System.getProperty(Environment.USER.name, System.getenv(Environment.USER.name))
+
+    // If nothing is found, default to the user we are running as
+    if (retval == null) System.getProperty("user.name") else retval
+  }
+
+  def runAsUser(func: (Product) => Unit, args: Product) {
+    runAsUser(func, args, getUserNameFromEnvironment())
+  }
+
+  def runAsUser(func: (Product) => Unit, args: Product, user: String) {
+
+    // println("running as user " + jobUserName)
+
+    UserGroupInformation.setConfiguration(yarnConf)
+    val appMasterUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(user)
+    appMasterUgi.doAs(new PrivilegedExceptionAction[AnyRef] {
+      def run: AnyRef = {
+        func(args)
+        // no return value ...
+        null
+      }
+    })
+  }
+
+  // Note that all params which start with SPARK are propagated all the way through, so if in yarn mode, this MUST be set to true.
+  def isYarnMode(): Boolean = {
+    val yarnMode = System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))
+    java.lang.Boolean.valueOf(yarnMode)
+  }
+
+  // Set an env variable indicating we are running in YARN mode.
+  // Note that anything with the SPARK prefix gets propagated to all (remote) processes
+  def setYarnMode() {
+    System.setProperty("SPARK_YARN_MODE", "true")
+  }
+
+  def setYarnMode(env: HashMap[String, String]) {
+    env("SPARK_YARN_MODE") = "true"
+  }
+
+  // Return an appropriate (subclass of) Configuration. Creating a config can initialize some Hadoop subsystems.
+  // Always create a new config, don't reuse yarnConf.
+  def newConfiguration(): Configuration = new YarnConfiguration(new Configuration())
+}
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala
new file mode 100644
index 0000000000000000000000000000000000000000..ae719267e8062d90009fd0658c9fcc161c4fb645
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMaster.scala
@@ -0,0 +1,342 @@
+package spark.deploy.yarn
+
+import java.net.Socket
+import java.util.concurrent.CopyOnWriteArrayList
+import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
+import scala.collection.JavaConversions._
+import spark.{SparkContext, Logging, Utils}
+import org.apache.hadoop.security.UserGroupInformation
+import java.security.PrivilegedExceptionAction
+
+class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging {
+
+  def this(args: ApplicationMasterArguments) = this(args, new Configuration())
+
+  private var rpc: YarnRPC = YarnRPC.create(conf)
+  private var resourceManager: AMRMProtocol = null
+  private var appAttemptId: ApplicationAttemptId = null
+  private var userThread: Thread = null
+  private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+
+  private var yarnAllocator: YarnAllocationHandler = null
+
+  def run() {
+
+    // Initialization
+    val jobUserName = Utils.getUserNameFromEnvironment()
+    logInfo("running as user " + jobUserName)
+
+    // run as user ...
+ UserGroupInformation.setConfiguration(yarnConf) + val appMasterUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(jobUserName) + appMasterUgi.doAs(new PrivilegedExceptionAction[AnyRef] { + def run: AnyRef = { + runImpl() + return null + } + }) + } + + private def runImpl() { + + appAttemptId = getApplicationAttemptId() + resourceManager = registerWithResourceManager() + val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster() + + // Compute number of threads for akka + val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory() + + if (minimumMemory > 0) { + val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD + val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0) + + if (numCore > 0) { + // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406 + // TODO: Uncomment when hadoop is on a version which has this fixed. + // args.workerCores = numCore + } + } + + // Workaround until hadoop moves to something which has + // https://issues.apache.org/jira/browse/HADOOP-8406 + // ignore result + // This does not, unfortunately, always work reliably ... but alleviates the bug a lot of times + // Hence args.workerCores = numCore disabled above. Any better option ? + // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf) + + ApplicationMaster.register(this) + // Start the user's JAR + userThread = startUserClass() + + // This a bit hacky, but we need to wait until the spark.driver.port property has + // been set by the Thread executing the user class. + waitForSparkMaster() + + // Allocate all containers + allocateWorkers() + + // Wait for the user class to Finish + userThread.join() + + // Finish the ApplicationMaster + finishApplicationMaster() + // TODO: Exit based on success/failure + System.exit(0) + } + + private def getApplicationAttemptId(): ApplicationAttemptId = { + val envs = System.getenv() + val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV) + val containerId = ConverterUtils.toContainerId(containerIdString) + val appAttemptId = containerId.getApplicationAttemptId() + logInfo("ApplicationAttemptId: " + appAttemptId) + return appAttemptId + } + + private def registerWithResourceManager(): AMRMProtocol = { + val rmAddress = NetUtils.createSocketAddr(yarnConf.get( + YarnConfiguration.RM_SCHEDULER_ADDRESS, + YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS)) + logInfo("Connecting to ResourceManager at " + rmAddress) + return rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol] + } + + private def registerApplicationMaster(): RegisterApplicationMasterResponse = { + logInfo("Registering the ApplicationMaster") + val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest]) + .asInstanceOf[RegisterApplicationMasterRequest] + appMasterRequest.setApplicationAttemptId(appAttemptId) + // Setting this to master host,port - so that the ApplicationReport at client has some sensible info. + // Users can then monitor stderr/stdout on that node if required. + appMasterRequest.setHost(Utils.localHostName()) + appMasterRequest.setRpcPort(0) + // What do we provide here ? Might make sense to expose something sensible later ? 
+ appMasterRequest.setTrackingUrl("") + return resourceManager.registerApplicationMaster(appMasterRequest) + } + + private def waitForSparkMaster() { + logInfo("Waiting for spark driver to be reachable.") + var driverUp = false + while(!driverUp) { + val driverHost = System.getProperty("spark.driver.host") + val driverPort = System.getProperty("spark.driver.port") + try { + val socket = new Socket(driverHost, driverPort.toInt) + socket.close() + logInfo("Master now available: " + driverHost + ":" + driverPort) + driverUp = true + } catch { + case e: Exception => + logError("Failed to connect to driver at " + driverHost + ":" + driverPort) + Thread.sleep(100) + } + } + } + + private def startUserClass(): Thread = { + logInfo("Starting the user JAR in a separate Thread") + val mainMethod = Class.forName(args.userClass, false, Thread.currentThread.getContextClassLoader) + .getMethod("main", classOf[Array[String]]) + val t = new Thread { + override def run() { + var mainArgs: Array[String] = null + var startIndex = 0 + + // I am sure there is a better 'scala' way to do this .... but I am just trying to get things to work right now ! + if (args.userArgs.isEmpty || args.userArgs.get(0) != "yarn-standalone") { + // ensure that first param is ALWAYS "yarn-standalone" + mainArgs = new Array[String](args.userArgs.size() + 1) + mainArgs.update(0, "yarn-standalone") + startIndex = 1 + } + else { + mainArgs = new Array[String](args.userArgs.size()) + } + + args.userArgs.copyToArray(mainArgs, startIndex, args.userArgs.size()) + + mainMethod.invoke(null, mainArgs) + } + } + t.start() + return t + } + + private def allocateWorkers() { + logInfo("Waiting for spark context initialization") + + try { + var sparkContext: SparkContext = null + ApplicationMaster.sparkContextRef.synchronized { + var count = 0 + while (ApplicationMaster.sparkContextRef.get() == null) { + logInfo("Waiting for spark context initialization ... " + count) + count = count + 1 + ApplicationMaster.sparkContextRef.wait(10000L) + } + sparkContext = ApplicationMaster.sparkContextRef.get() + assert(sparkContext != null) + this.yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args, sparkContext.preferredNodeLocationData) + } + + + logInfo("Allocating " + args.numWorkers + " workers.") + // Wait until all containers have finished + // TODO: This is a bit ugly. Can we make it nicer? + // TODO: Handle container failure + while(yarnAllocator.getNumWorkersRunning < args.numWorkers && + // If user thread exists, then quit ! + userThread.isAlive) { + + this.yarnAllocator.allocateContainers(math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0)) + ApplicationMaster.incrementAllocatorLoop(1) + Thread.sleep(100) + } + } finally { + // in case of exceptions, etc - ensure that count is atleast ALLOCATOR_LOOP_WAIT_COUNT : + // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks + ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT) + } + logInfo("All workers have launched.") + + // Launch a progress reporter thread, else app will get killed after expiration (def: 10mins) timeout + if (userThread.isAlive){ + // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse. + + val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) + // must be <= timeoutInterval/ 2. + // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM. 
+ // so atleast 1 minute or timeoutInterval / 10 - whichever is higher. + val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L)) + launchReporterThread(interval) + } + } + + // TODO: We might want to extend this to allocate more containers in case they die ! + private def launchReporterThread(_sleepTime: Long): Thread = { + val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime + + val t = new Thread { + override def run() { + while (userThread.isAlive){ + val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning + if (missingWorkerCount > 0) { + logInfo("Allocating " + missingWorkerCount + " containers to make up for (potentially ?) lost containers") + yarnAllocator.allocateContainers(missingWorkerCount) + } + else sendProgress() + Thread.sleep(sleepTime) + } + } + } + // setting to daemon status, though this is usually not a good idea. + t.setDaemon(true) + t.start() + logInfo("Started progress reporter thread - sleep time : " + sleepTime) + return t + } + + private def sendProgress() { + logDebug("Sending progress") + // simulated with an allocate request with no nodes requested ... + yarnAllocator.allocateContainers(0) + } + + /* + def printContainers(containers: List[Container]) = { + for (container <- containers) { + logInfo("Launching shell command on a new container." + + ", containerId=" + container.getId() + + ", containerNode=" + container.getNodeId().getHost() + + ":" + container.getNodeId().getPort() + + ", containerNodeURI=" + container.getNodeHttpAddress() + + ", containerState" + container.getState() + + ", containerResourceMemory" + + container.getResource().getMemory()) + } + } + */ + + def finishApplicationMaster() { + val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) + .asInstanceOf[FinishApplicationMasterRequest] + finishReq.setAppAttemptId(appAttemptId) + // TODO: Check if the application has failed or succeeded + finishReq.setFinishApplicationStatus(FinalApplicationStatus.SUCCEEDED) + resourceManager.finishApplicationMaster(finishReq) + } + +} + +object ApplicationMaster { + // number of times to wait for the allocator loop to complete. + // each loop iteration waits for 100ms, so maximum of 3 seconds. + // This is to ensure that we have reasonable number of containers before we start + // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be optimal as more + // containers are available. Might need to handle this better. + private val ALLOCATOR_LOOP_WAIT_COUNT = 30 + def incrementAllocatorLoop(by: Int) { + val count = yarnAllocatorLoop.getAndAdd(by) + if (count >= ALLOCATOR_LOOP_WAIT_COUNT){ + yarnAllocatorLoop.synchronized { + // to wake threads off wait ... + yarnAllocatorLoop.notifyAll() + } + } + } + + private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]() + + def register(master: ApplicationMaster) { + applicationMasters.add(master) + } + + val sparkContextRef: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null) + val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0) + + def sparkContextInitialized(sc: SparkContext): Boolean = { + var modified = false + sparkContextRef.synchronized { + modified = sparkContextRef.compareAndSet(null, sc) + sparkContextRef.notifyAll() + } + + // Add a shutdown hook - as a best case effort in case users do not call sc.stop or do System.exit + // Should not really have to do this, but it helps yarn to evict resources earlier. 
+    // not to mention, prevent the Client from declaring failure even though we exited properly.
+    if (modified) {
+      Runtime.getRuntime().addShutdownHook(new Thread with Logging {
+        // This is not just to log, but also to ensure that the log system is initialized for this instance when we actually are 'run'
+        logInfo("Adding shutdown hook for context " + sc)
+        override def run() {
+          logInfo("Invoking sc stop from shutdown hook")
+          sc.stop()
+          // best case ...
+          for (master <- applicationMasters) master.finishApplicationMaster
+        }
+      } )
+    }
+
+    // Wait for initialization to complete and at least 'some' nodes to get allocated
+    yarnAllocatorLoop.synchronized {
+      while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT){
+        yarnAllocatorLoop.wait(1000L)
+      }
+    }
+    modified
+  }
+
+  def main(argStrings: Array[String]) {
+    val args = new ApplicationMasterArguments(argStrings)
+    new ApplicationMaster(args).run()
+  }
+}
diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala
new file mode 100644
index 0000000000000000000000000000000000000000..dc89125d8184a2ee5cc367b0bdf8fc853217c1e2
--- /dev/null
+++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ApplicationMasterArguments.scala
@@ -0,0 +1,78 @@
+package spark.deploy.yarn
+
+import spark.util.IntParam
+import collection.mutable.ArrayBuffer
+
+class ApplicationMasterArguments(val args: Array[String]) {
+  var userJar: String = null
+  var userClass: String = null
+  var userArgs: Seq[String] = Seq[String]()
+  var workerMemory = 1024
+  var workerCores = 1
+  var numWorkers = 2
+
+  parseArgs(args.toList)
+
+  private def parseArgs(inputArgs: List[String]): Unit = {
+    val userArgsBuffer = new ArrayBuffer[String]()
+
+    var args = inputArgs
+
+    while (! args.isEmpty) {
+
+      args match {
+        case ("--jar") :: value :: tail =>
+          userJar = value
+          args = tail
+
+        case ("--class") :: value :: tail =>
+          userClass = value
+          args = tail
+
+        case ("--args") :: value :: tail =>
+          userArgsBuffer += value
+          args = tail
+
+        case ("--num-workers") :: IntParam(value) :: tail =>
+          numWorkers = value
+          args = tail
+
+        case ("--worker-memory") :: IntParam(value) :: tail =>
+          workerMemory = value
+          args = tail
+
+        case ("--worker-cores") :: IntParam(value) :: tail =>
+          workerCores = value
+          args = tail
+
+        case Nil =>
+          if (userJar == null || userClass == null) {
+            printUsageAndExit(1)
+          }
+
+        case _ =>
+          printUsageAndExit(1, args)
+      }
+    }
+
+    userArgs = userArgsBuffer.readOnly
+  }
+
+  def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+    if (unknownParam != null) {
+      System.err.println("Unknown/unsupported param " + unknownParam)
+    }
+    System.err.println(
+      "Usage: spark.deploy.yarn.ApplicationMaster [options] \n" +
+      "Options:\n" +
+      "  --jar JAR_PATH       Path to your application's JAR file (required)\n" +
+      "  --class CLASS_NAME   Name of your application's main class (required)\n" +
+      "  --args ARGS          Arguments to be passed to your application's main class.\n" +
+      "                       Multiple invocations are possible, each will be passed in order.\n" +
+      "                       Note that the first argument will ALWAYS be yarn-standalone: it will be added if missing.\n" +
+      "  --num-workers NUM    Number of workers to start (Default: 2)\n" +
+      "  --worker-cores NUM   Number of cores for the workers (Default: 1)\n" +
+      "  --worker-memory MEM  Memory per Worker (e.g. 
1000M, 2G) (Default: 1G)\n") + System.exit(exitCode) + } +} diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala new file mode 100644 index 0000000000000000000000000000000000000000..7a881e26dfffc423f7ab60d719926a9581db4dd9 --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala @@ -0,0 +1,272 @@ +package spark.deploy.yarn + +import java.net.{InetSocketAddress, URI} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.net.NetUtils +import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.client.YarnClientImpl +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.ipc.YarnRPC +import scala.collection.mutable.HashMap +import scala.collection.JavaConversions._ +import spark.{Logging, Utils} +import org.apache.hadoop.yarn.util.{Apps, Records, ConverterUtils} +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment +import spark.deploy.SparkHadoopUtil + +class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl with Logging { + + def this(args: ClientArguments) = this(new Configuration(), args) + + var rpc: YarnRPC = YarnRPC.create(conf) + val yarnConf: YarnConfiguration = new YarnConfiguration(conf) + + def run() { + init(yarnConf) + start() + logClusterResourceDetails() + + val newApp = super.getNewApplication() + val appId = newApp.getApplicationId() + + verifyClusterResources(newApp) + val appContext = createApplicationSubmissionContext(appId) + val localResources = prepareLocalResources(appId, "spark") + val env = setupLaunchEnv(localResources) + val amContainer = createContainerLaunchContext(newApp, localResources, env) + + appContext.setQueue(args.amQueue) + appContext.setAMContainerSpec(amContainer) + appContext.setUser(args.amUser) + + submitApp(appContext) + + monitorApplication(appId) + System.exit(0) + } + + + def logClusterResourceDetails() { + val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics + logInfo("Got Cluster metric info from ASM, numNodeManagers=" + clusterMetrics.getNumNodeManagers) + + val queueInfo: QueueInfo = super.getQueueInfo(args.amQueue) + logInfo("Queue info .. queueName=" + queueInfo.getQueueName + ", queueCurrentCapacity=" + queueInfo.getCurrentCapacity + + ", queueMaxCapacity=" + queueInfo.getMaximumCapacity + ", queueApplicationCount=" + queueInfo.getApplications.size + + ", queueChildQueueCount=" + queueInfo.getChildQueues.size) + } + + + def verifyClusterResources(app: GetNewApplicationResponse) = { + val maxMem = app.getMaximumResourceCapability().getMemory() + logInfo("Max mem capabililty of resources in this cluster " + maxMem) + + // If the cluster does not have enough memory resources, exit. 
+ val requestedMem = (args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + args.numWorkers * args.workerMemory + if (requestedMem > maxMem) { + logError("Cluster cannot satisfy memory resource request of " + requestedMem) + System.exit(1) + } + } + + def createApplicationSubmissionContext(appId: ApplicationId): ApplicationSubmissionContext = { + logInfo("Setting up application submission context for ASM") + val appContext = Records.newRecord(classOf[ApplicationSubmissionContext]) + appContext.setApplicationId(appId) + appContext.setApplicationName("Spark") + return appContext + } + + def prepareLocalResources(appId: ApplicationId, appName: String): HashMap[String, LocalResource] = { + logInfo("Preparing Local resources") + val locaResources = HashMap[String, LocalResource]() + // Upload Spark and the application JAR to the remote file system + // Add them as local resources to the AM + val fs = FileSystem.get(conf) + Map("spark.jar" -> System.getenv("SPARK_JAR"), "app.jar" -> args.userJar, "log4j.properties" -> System.getenv("SPARK_LOG4J_CONF")) + .foreach { case(destName, _localPath) => + val localPath: String = if (_localPath != null) _localPath.trim() else "" + if (! localPath.isEmpty()) { + val src = new Path(localPath) + val pathSuffix = appName + "/" + appId.getId() + destName + val dst = new Path(fs.getHomeDirectory(), pathSuffix) + logInfo("Uploading " + src + " to " + dst) + fs.copyFromLocalFile(false, true, src, dst) + val destStatus = fs.getFileStatus(dst) + + val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] + amJarRsrc.setType(LocalResourceType.FILE) + amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION) + amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(dst)) + amJarRsrc.setTimestamp(destStatus.getModificationTime()) + amJarRsrc.setSize(destStatus.getLen()) + locaResources(destName) = amJarRsrc + } + } + return locaResources + } + + def setupLaunchEnv(localResources: HashMap[String, LocalResource]): HashMap[String, String] = { + logInfo("Setting up the launch environment") + val log4jConfLocalRes = localResources.getOrElse("log4j.properties", null) + + val env = new HashMap[String, String]() + Apps.addToEnvironment(env, Environment.USER.name, args.amUser) + + // If log4j present, ensure ours overrides all others + if (log4jConfLocalRes != null) Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./") + + Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*") + Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH") + Client.populateHadoopClasspath(yarnConf, env) + SparkHadoopUtil.setYarnMode(env) + env("SPARK_YARN_JAR_PATH") = + localResources("spark.jar").getResource().getScheme.toString() + "://" + + localResources("spark.jar").getResource().getFile().toString() + env("SPARK_YARN_JAR_TIMESTAMP") = localResources("spark.jar").getTimestamp().toString() + env("SPARK_YARN_JAR_SIZE") = localResources("spark.jar").getSize().toString() + + env("SPARK_YARN_USERJAR_PATH") = + localResources("app.jar").getResource().getScheme.toString() + "://" + + localResources("app.jar").getResource().getFile().toString() + env("SPARK_YARN_USERJAR_TIMESTAMP") = localResources("app.jar").getTimestamp().toString() + env("SPARK_YARN_USERJAR_SIZE") = localResources("app.jar").getSize().toString() + + if (log4jConfLocalRes != null) { + env("SPARK_YARN_LOG4J_PATH") = + log4jConfLocalRes.getResource().getScheme.toString() + "://" + log4jConfLocalRes.getResource().getFile().toString() + env("SPARK_YARN_LOG4J_TIMESTAMP") = 
log4jConfLocalRes.getTimestamp().toString() + env("SPARK_YARN_LOG4J_SIZE") = log4jConfLocalRes.getSize().toString() + } + + // Add each SPARK-* key to the environment + System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v } + return env + } + + def userArgsToString(clientArgs: ClientArguments): String = { + val prefix = " --args " + val args = clientArgs.userArgs + val retval = new StringBuilder() + for (arg <- args){ + retval.append(prefix).append(" '").append(arg).append("' ") + } + + retval.toString + } + + def createContainerLaunchContext(newApp: GetNewApplicationResponse, + localResources: HashMap[String, LocalResource], + env: HashMap[String, String]): ContainerLaunchContext = { + logInfo("Setting up container launch context") + val amContainer = Records.newRecord(classOf[ContainerLaunchContext]) + amContainer.setLocalResources(localResources) + amContainer.setEnvironment(env) + + val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory() + + var amMemory = ((args.amMemory / minResMemory) * minResMemory) + + (if (0 != (args.amMemory % minResMemory)) minResMemory else 0) - YarnAllocationHandler.MEMORY_OVERHEAD + + // Extra options for the JVM + var JAVA_OPTS = "" + + // Add Xmx for am memory + JAVA_OPTS += "-Xmx" + amMemory + "m " + + // Commenting it out for now - so that people can refer to the properties if required. Remove it once cpuset version is pushed out. + // The context is, default gc for server class machines end up using all cores to do gc - hence if there are multiple containers in same + // node, spark gc effects all other containers performance (which can also be other spark containers) + // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in multi-tenant environments. Not sure how default java gc behaves if it is + // limited to subset of cores on a node. 
+ if (env.isDefinedAt("SPARK_USE_CONC_INCR_GC") && java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC"))) { + // In our expts, using (default) throughput collector has severe perf ramnifications in multi-tenant machines + JAVA_OPTS += " -XX:+UseConcMarkSweepGC " + JAVA_OPTS += " -XX:+CMSIncrementalMode " + JAVA_OPTS += " -XX:+CMSIncrementalPacing " + JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 " + JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 " + } + if (env.isDefinedAt("SPARK_JAVA_OPTS")) { + JAVA_OPTS += env("SPARK_JAVA_OPTS") + " " + } + + // Command for the ApplicationMaster + val commands = List[String]("java " + + " -server " + + JAVA_OPTS + + " spark.deploy.yarn.ApplicationMaster" + + " --class " + args.userClass + + " --jar " + args.userJar + + userArgsToString(args) + + " --worker-memory " + args.workerMemory + + " --worker-cores " + args.workerCores + + " --num-workers " + args.numWorkers + + " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" + + " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") + logInfo("Command for the ApplicationMaster: " + commands(0)) + amContainer.setCommands(commands) + + val capability = Records.newRecord(classOf[Resource]).asInstanceOf[Resource] + // Memory for the ApplicationMaster + capability.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + amContainer.setResource(capability) + + return amContainer + } + + def submitApp(appContext: ApplicationSubmissionContext) = { + // Submit the application to the applications manager + logInfo("Submitting application to ASM") + super.submitApplication(appContext) + } + + def monitorApplication(appId: ApplicationId): Boolean = { + while(true) { + Thread.sleep(1000) + val report = super.getApplicationReport(appId) + + logInfo("Application report from ASM: \n" + + "\t application identifier: " + appId.toString() + "\n" + + "\t appId: " + appId.getId() + "\n" + + "\t clientToken: " + report.getClientToken() + "\n" + + "\t appDiagnostics: " + report.getDiagnostics() + "\n" + + "\t appMasterHost: " + report.getHost() + "\n" + + "\t appQueue: " + report.getQueue() + "\n" + + "\t appMasterRpcPort: " + report.getRpcPort() + "\n" + + "\t appStartTime: " + report.getStartTime() + "\n" + + "\t yarnAppState: " + report.getYarnApplicationState() + "\n" + + "\t distributedFinalState: " + report.getFinalApplicationStatus() + "\n" + + "\t appTrackingUrl: " + report.getTrackingUrl() + "\n" + + "\t appUser: " + report.getUser() + ) + + val state = report.getYarnApplicationState() + val dsStatus = report.getFinalApplicationStatus() + if (state == YarnApplicationState.FINISHED || + state == YarnApplicationState.FAILED || + state == YarnApplicationState.KILLED) { + return true + } + } + return true + } +} + +object Client { + def main(argStrings: Array[String]) { + val args = new ClientArguments(argStrings) + SparkHadoopUtil.setYarnMode() + new Client(args).run + } + + // Based on code from org.apache.hadoop.mapreduce.v2.util.MRApps + def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) { + for (c <- conf.getStrings(YarnConfiguration.YARN_APPLICATION_CLASSPATH)) { + Apps.addToEnvironment(env, Environment.CLASSPATH.name, c.trim) + } + } +} diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala new file mode 100644 index 0000000000000000000000000000000000000000..2e69fe3fb0566130238b7d76b423d448d29724c2 --- /dev/null +++ 
b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/ClientArguments.scala
@@ -0,0 +1,106 @@
+package spark.deploy.yarn
+
+import spark.util.MemoryParam
+import spark.util.IntParam
+import collection.mutable.{ArrayBuffer, HashMap}
+import spark.scheduler.{InputFormatInfo, SplitInfo}
+
+// TODO: Add code and support for ensuring that yarn resource 'asks' are location aware !
+class ClientArguments(val args: Array[String]) {
+  var userJar: String = null
+  var userClass: String = null
+  var userArgs: Seq[String] = Seq[String]()
+  var workerMemory = 1024
+  var workerCores = 1
+  var numWorkers = 2
+  var amUser = System.getProperty("user.name")
+  var amQueue = System.getProperty("QUEUE", "default")
+  var amMemory: Int = 512
+  // TODO
+  var inputFormatInfo: List[InputFormatInfo] = null
+
+  parseArgs(args.toList)
+
+  private def parseArgs(inputArgs: List[String]): Unit = {
+    val userArgsBuffer: ArrayBuffer[String] = new ArrayBuffer[String]()
+    val inputFormatMap: HashMap[String, InputFormatInfo] = new HashMap[String, InputFormatInfo]()
+
+    var args = inputArgs
+
+    while (! args.isEmpty) {
+
+      args match {
+        case ("--jar") :: value :: tail =>
+          userJar = value
+          args = tail
+
+        case ("--class") :: value :: tail =>
+          userClass = value
+          args = tail
+
+        case ("--args") :: value :: tail =>
+          userArgsBuffer += value
+          args = tail
+
+        case ("--master-memory") :: MemoryParam(value) :: tail =>
+          amMemory = value
+          args = tail
+
+        case ("--num-workers") :: IntParam(value) :: tail =>
+          numWorkers = value
+          args = tail
+
+        case ("--worker-memory") :: MemoryParam(value) :: tail =>
+          workerMemory = value
+          args = tail
+
+        case ("--worker-cores") :: IntParam(value) :: tail =>
+          workerCores = value
+          args = tail
+
+        case ("--user") :: value :: tail =>
+          amUser = value
+          args = tail
+
+        case ("--queue") :: value :: tail =>
+          amQueue = value
+          args = tail
+
+        case Nil =>
+          if (userJar == null || userClass == null) {
+            printUsageAndExit(1)
+          }
+
+        case _ =>
+          printUsageAndExit(1, args)
+      }
+    }
+
+    userArgs = userArgsBuffer.readOnly
+    inputFormatInfo = inputFormatMap.values.toList
+  }
+
+
+  def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+    if (unknownParam != null) {
+      System.err.println("Unknown/unsupported param " + unknownParam)
+    }
+    System.err.println(
+      "Usage: spark.deploy.yarn.Client [options] \n" +
+      "Options:\n" +
+      "  --jar JAR_PATH       Path to your application's JAR file (required)\n" +
+      "  --class CLASS_NAME   Name of your application's main class (required)\n" +
+      "  --args ARGS          Arguments to be passed to your application's main class.\n" +
+      "                       Multiple invocations are possible, each will be passed in order.\n" +
+      "                       Note that the first argument will ALWAYS be yarn-standalone: it will be added if missing.\n" +
+      "  --num-workers NUM    Number of workers to start (Default: 2)\n" +
+      "  --worker-cores NUM   Number of cores for the workers (Default: 1). This is unused right now.\n" +
+      "  --master-memory MEM  Memory for the Master (e.g. 1000M, 2G) (Default: 512 MB)\n" +
+      "  --worker-memory MEM  Memory per Worker (e.g. 
1000M, 2G) (Default: 1G)\n" + + " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" + + " --user USERNAME Run the ApplicationMaster (and slaves) as a different user\n" + ) + System.exit(exitCode) + } + +} diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala new file mode 100644 index 0000000000000000000000000000000000000000..a2bf0af762417e2d9d3085abef1983a6caa39448 --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/WorkerRunnable.scala @@ -0,0 +1,171 @@ +package spark.deploy.yarn + +import java.net.URI + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.net.NetUtils +import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.ipc.YarnRPC +import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records} +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment + +import scala.collection.JavaConversions._ +import scala.collection.mutable.HashMap + +import spark.{Logging, Utils} + +class WorkerRunnable(container: Container, conf: Configuration, masterAddress: String, + slaveId: String, hostname: String, workerMemory: Int, workerCores: Int) + extends Runnable with Logging { + + var rpc: YarnRPC = YarnRPC.create(conf) + var cm: ContainerManager = null + val yarnConf: YarnConfiguration = new YarnConfiguration(conf) + + def run = { + logInfo("Starting Worker Container") + cm = connectToCM + startContainer + } + + def startContainer = { + logInfo("Setting up ContainerLaunchContext") + + val ctx = Records.newRecord(classOf[ContainerLaunchContext]) + .asInstanceOf[ContainerLaunchContext] + + ctx.setContainerId(container.getId()) + ctx.setResource(container.getResource()) + val localResources = prepareLocalResources + ctx.setLocalResources(localResources) + + val env = prepareEnvironment + ctx.setEnvironment(env) + + // Extra options for the JVM + var JAVA_OPTS = "" + // Set the JVM memory + val workerMemoryString = workerMemory + "m" + JAVA_OPTS += "-Xms" + workerMemoryString + " -Xmx" + workerMemoryString + " " + if (env.isDefinedAt("SPARK_JAVA_OPTS")) { + JAVA_OPTS += env("SPARK_JAVA_OPTS") + " " + } + // Commenting it out for now - so that people can refer to the properties if required. Remove it once cpuset version is pushed out. + // The context is, default gc for server class machines end up using all cores to do gc - hence if there are multiple containers in same + // node, spark gc effects all other containers performance (which can also be other spark containers) + // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in multi-tenant environments. Not sure how default java gc behaves if it is + // limited to subset of cores on a node. +/* + else { + // If no java_opts specified, default to using -XX:+CMSIncrementalMode + // It might be possible that other modes/config is being done in SPARK_JAVA_OPTS, so we dont want to mess with it. 
+ // In our expts, using (default) throughput collector has severe perf ramnifications in multi-tennent machines + // The options are based on + // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use%20the%20Concurrent%20Low%20Pause%20Collector|outline + JAVA_OPTS += " -XX:+UseConcMarkSweepGC " + JAVA_OPTS += " -XX:+CMSIncrementalMode " + JAVA_OPTS += " -XX:+CMSIncrementalPacing " + JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 " + JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 " + } +*/ + + ctx.setUser(UserGroupInformation.getCurrentUser().getShortUserName()) + val commands = List[String]("java " + + " -server " + + // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling. + // Not killing the task leaves various aspects of the worker and (to some extent) the jvm in an inconsistent state. + // TODO: If the OOM is not recoverable by rescheduling it on different node, then do 'something' to fail job ... akin to blacklisting trackers in mapred ? + " -XX:OnOutOfMemoryError='kill %p' " + + JAVA_OPTS + + " spark.executor.StandaloneExecutorBackend " + + masterAddress + " " + + slaveId + " " + + hostname + " " + + workerCores + + " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" + + " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") + logInfo("Setting up worker with commands: " + commands) + ctx.setCommands(commands) + + // Send the start request to the ContainerManager + val startReq = Records.newRecord(classOf[StartContainerRequest]) + .asInstanceOf[StartContainerRequest] + startReq.setContainerLaunchContext(ctx) + cm.startContainer(startReq) + } + + + def prepareLocalResources: HashMap[String, LocalResource] = { + logInfo("Preparing Local resources") + val locaResources = HashMap[String, LocalResource]() + + // Spark JAR + val sparkJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] + sparkJarResource.setType(LocalResourceType.FILE) + sparkJarResource.setVisibility(LocalResourceVisibility.APPLICATION) + sparkJarResource.setResource(ConverterUtils.getYarnUrlFromURI( + new URI(System.getenv("SPARK_YARN_JAR_PATH")))) + sparkJarResource.setTimestamp(System.getenv("SPARK_YARN_JAR_TIMESTAMP").toLong) + sparkJarResource.setSize(System.getenv("SPARK_YARN_JAR_SIZE").toLong) + locaResources("spark.jar") = sparkJarResource + // User JAR + val userJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] + userJarResource.setType(LocalResourceType.FILE) + userJarResource.setVisibility(LocalResourceVisibility.APPLICATION) + userJarResource.setResource(ConverterUtils.getYarnUrlFromURI( + new URI(System.getenv("SPARK_YARN_USERJAR_PATH")))) + userJarResource.setTimestamp(System.getenv("SPARK_YARN_USERJAR_TIMESTAMP").toLong) + userJarResource.setSize(System.getenv("SPARK_YARN_USERJAR_SIZE").toLong) + locaResources("app.jar") = userJarResource + + // Log4j conf - if available + if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) { + val log4jConfResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] + log4jConfResource.setType(LocalResourceType.FILE) + log4jConfResource.setVisibility(LocalResourceVisibility.APPLICATION) + log4jConfResource.setResource(ConverterUtils.getYarnUrlFromURI( + new URI(System.getenv("SPARK_YARN_LOG4J_PATH")))) + log4jConfResource.setTimestamp(System.getenv("SPARK_YARN_LOG4J_TIMESTAMP").toLong) + log4jConfResource.setSize(System.getenv("SPARK_YARN_LOG4J_SIZE").toLong) + locaResources("log4j.properties") = 
log4jConfResource + } + + + logInfo("Prepared Local resources " + locaResources) + return locaResources + } + + def prepareEnvironment: HashMap[String, String] = { + val env = new HashMap[String, String]() + // should we add this ? + Apps.addToEnvironment(env, Environment.USER.name, Utils.getUserNameFromEnvironment()) + + // If log4j present, ensure ours overrides all others + if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) { + // Which is correct ? + Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./log4j.properties") + Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./") + } + + Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*") + Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH") + Client.populateHadoopClasspath(yarnConf, env) + + System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v } + return env + } + + def connectToCM: ContainerManager = { + val cmHostPortStr = container.getNodeId().getHost() + ":" + container.getNodeId().getPort() + val cmAddress = NetUtils.createSocketAddr(cmHostPortStr) + logInfo("Connecting to ContainerManager at " + cmHostPortStr) + return rpc.getProxy(classOf[ContainerManager], cmAddress, conf).asInstanceOf[ContainerManager] + } + +} diff --git a/core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala new file mode 100644 index 0000000000000000000000000000000000000000..61dd72a6518ca9bebbb5b30c4c6d88ed540173be --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/deploy/yarn/YarnAllocationHandler.scala @@ -0,0 +1,547 @@ +package spark.deploy.yarn + +import spark.{Logging, Utils} +import spark.scheduler.SplitInfo +import scala.collection +import org.apache.hadoop.yarn.api.records.{AMResponse, ApplicationAttemptId, ContainerId, Priority, Resource, ResourceRequest, ContainerStatus, Container} +import spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend} +import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse} +import org.apache.hadoop.yarn.util.{RackResolver, Records} +import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap} +import java.util.concurrent.atomic.AtomicInteger +import org.apache.hadoop.yarn.api.AMRMProtocol +import collection.JavaConversions._ +import collection.mutable.{ArrayBuffer, HashMap, HashSet} +import org.apache.hadoop.conf.Configuration +import java.util.{Collections, Set => JSet} +import java.lang.{Boolean => JBoolean} + +object AllocationType extends Enumeration ("HOST", "RACK", "ANY") { + type AllocationType = Value + val HOST, RACK, ANY = Value +} + +// too many params ? refactor it 'somehow' ? +// needs to be mt-safe +// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive : should make it +// more proactive and decoupled. +// Note that right now, we assume all node asks as uniform in terms of capabilities and priority +// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for more info +// on how we are requesting for containers. +private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceManager: AMRMProtocol, + val appAttemptId: ApplicationAttemptId, + val maxWorkers: Int, val workerMemory: Int, val workerCores: Int, + val preferredHostToCount: Map[String, Int], + val preferredRackToCount: Map[String, Int]) + extends Logging { + + + // These three are locked on allocatedHostToContainersMap. 
Complementary data structures + // allocatedHostToContainersMap : containers which are running : host, Set<containerid> + // allocatedContainerToHostMap: container to host mapping + private val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]]() + private val allocatedContainerToHostMap = new HashMap[ContainerId, String]() + // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an allocated node) + // As with the two data structures above, tightly coupled with them, and to be locked on allocatedHostToContainersMap + private val allocatedRackCount = new HashMap[String, Int]() + + // containers which have been released. + private val releasedContainerList = new CopyOnWriteArrayList[ContainerId]() + // containers to be released in next request to RM + private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean] + + private val numWorkersRunning = new AtomicInteger() + // Used to generate a unique id per worker + private val workerIdCounter = new AtomicInteger() + private val lastResponseId = new AtomicInteger() + + def getNumWorkersRunning: Int = numWorkersRunning.intValue + + + def isResourceConstraintSatisfied(container: Container): Boolean = { + container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + } + + def allocateContainers(workersToRequest: Int) { + // We need to send the request only once from what I understand ... but for now, not modifying this much. + + // Keep polling the Resource Manager for containers + val amResp = allocateWorkerResources(workersToRequest).getAMResponse + + val _allocatedContainers = amResp.getAllocatedContainers() + if (_allocatedContainers.size > 0) { + + + logDebug("Allocated " + _allocatedContainers.size + " containers, current count " + + numWorkersRunning.get() + ", to-be-released " + releasedContainerList + + ", pendingReleaseContainers : " + pendingReleaseContainers) + logDebug("Cluster Resources: " + amResp.getAvailableResources) + + val hostToContainers = new HashMap[String, ArrayBuffer[Container]]() + + // ignore if not satisfying constraints { + for (container <- _allocatedContainers) { + if (isResourceConstraintSatisfied(container)) { + // allocatedContainers += container + + val host = container.getNodeId.getHost + val containers = hostToContainers.getOrElseUpdate(host, new ArrayBuffer[Container]()) + + containers += container + } + // Add all ignored containers to released list + else releasedContainerList.add(container.getId()) + } + + // Find the appropriate containers to use + // Slightly non trivial groupBy I guess ... 
+ val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]() + val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]() + val offRackContainers = new HashMap[String, ArrayBuffer[Container]]() + + for (candidateHost <- hostToContainers.keySet) + { + val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0) + val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost) + + var remainingContainers = hostToContainers.get(candidateHost).getOrElse(null) + assert(remainingContainers != null) + + if (requiredHostCount >= remainingContainers.size){ + // Since we got <= required containers, add all to dataLocalContainers + dataLocalContainers.put(candidateHost, remainingContainers) + // all consumed + remainingContainers = null + } + else if (requiredHostCount > 0) { + // container list has more containers than we need for data locality. + // Split into two : data local container count of (remainingContainers.size - requiredHostCount) + // and rest as remainingContainer + val (dataLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredHostCount) + dataLocalContainers.put(candidateHost, dataLocal) + // remainingContainers = remaining + + // yarn has nasty habit of allocating a tonne of containers on a host - discourage this : + // add remaining to release list. If we have insufficient containers, next allocation cycle + // will reallocate (but wont treat it as data local) + for (container <- remaining) releasedContainerList.add(container.getId()) + remainingContainers = null + } + + // now rack local + if (remainingContainers != null){ + val rack = YarnAllocationHandler.lookupRack(conf, candidateHost) + + if (rack != null){ + val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0) + val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) - + rackLocalContainers.get(rack).getOrElse(List()).size + + + if (requiredRackCount >= remainingContainers.size){ + // Add all to dataLocalContainers + dataLocalContainers.put(rack, remainingContainers) + // all consumed + remainingContainers = null + } + else if (requiredRackCount > 0) { + // container list has more containers than we need for data locality. + // Split into two : data local container count of (remainingContainers.size - requiredRackCount) + // and rest as remainingContainer + val (rackLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredRackCount) + val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, new ArrayBuffer[Container]()) + + existingRackLocal ++= rackLocal + remainingContainers = remaining + } + } + } + + // If still not consumed, then it is off rack host - add to that list. + if (remainingContainers != null){ + offRackContainers.put(candidateHost, remainingContainers) + } + } + + // Now that we have split the containers into various groups, go through them in order : + // first host local, then rack local and then off rack (everything else). + // Note that the list we create below tries to ensure that not all containers end up within a host + // if there are sufficiently large number of hosts/containers. 
+ + val allocatedContainers = new ArrayBuffer[Container](_allocatedContainers.size) + allocatedContainers ++= ClusterScheduler.prioritizeContainers(dataLocalContainers) + allocatedContainers ++= ClusterScheduler.prioritizeContainers(rackLocalContainers) + allocatedContainers ++= ClusterScheduler.prioritizeContainers(offRackContainers) + + // Run each of the allocated containers + for (container <- allocatedContainers) { + val numWorkersRunningNow = numWorkersRunning.incrementAndGet() + val workerHostname = container.getNodeId.getHost + val containerId = container.getId + + assert (container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)) + + if (numWorkersRunningNow > maxWorkers) { + logInfo("Ignoring container " + containerId + " at host " + workerHostname + + " .. we already have required number of containers") + releasedContainerList.add(containerId) + // reset counter back to old value. + numWorkersRunning.decrementAndGet() + } + else { + // deallocate + allocate can result in reusing id's wrongly - so use a different counter (workerIdCounter) + val workerId = workerIdCounter.incrementAndGet().toString + val driverUrl = "akka://spark@%s:%s/user/%s".format( + System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), + StandaloneSchedulerBackend.ACTOR_NAME) + + logInfo("launching container on " + containerId + " host " + workerHostname) + // just to be safe, simply remove it from pendingReleaseContainers. Should not be there, but .. + pendingReleaseContainers.remove(containerId) + + val rack = YarnAllocationHandler.lookupRack(conf, workerHostname) + allocatedHostToContainersMap.synchronized { + val containerSet = allocatedHostToContainersMap.getOrElseUpdate(workerHostname, new HashSet[ContainerId]()) + + containerSet += containerId + allocatedContainerToHostMap.put(containerId, workerHostname) + if (rack != null) allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1) + } + + new Thread( + new WorkerRunnable(container, conf, driverUrl, workerId, + workerHostname, workerMemory, workerCores) + ).start() + } + } + logDebug("After allocated " + allocatedContainers.size + " containers (orig : " + + _allocatedContainers.size + "), current count " + numWorkersRunning.get() + + ", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers) + } + + + val completedContainers = amResp.getCompletedContainersStatuses() + if (completedContainers.size > 0){ + logDebug("Completed " + completedContainers.size + " containers, current count " + numWorkersRunning.get() + + ", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers) + + for (completedContainer <- completedContainers){ + val containerId = completedContainer.getContainerId + + // Was this released by us ? If yes, then simply remove from containerSet and move on. + if (pendingReleaseContainers.containsKey(containerId)) { + pendingReleaseContainers.remove(containerId) + } + else { + // simply decrement count - next iteration of ReporterThread will take care of allocating ! + numWorkersRunning.decrementAndGet() + logInfo("Container completed ? 
nodeId: " + containerId + ", state " + completedContainer.getState + + " httpaddress: " + completedContainer.getDiagnostics) + } + + allocatedHostToContainersMap.synchronized { + if (allocatedContainerToHostMap.containsKey(containerId)) { + val host = allocatedContainerToHostMap.get(containerId).getOrElse(null) + assert (host != null) + + val containerSet = allocatedHostToContainersMap.get(host).getOrElse(null) + assert (containerSet != null) + + containerSet -= containerId + if (containerSet.isEmpty) allocatedHostToContainersMap.remove(host) + else allocatedHostToContainersMap.update(host, containerSet) + + allocatedContainerToHostMap -= containerId + + // doing this within locked context, sigh ... move to outside ? + val rack = YarnAllocationHandler.lookupRack(conf, host) + if (rack != null) { + val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1 + if (rackCount > 0) allocatedRackCount.put(rack, rackCount) + else allocatedRackCount.remove(rack) + } + } + } + } + logDebug("After completed " + completedContainers.size + " containers, current count " + + numWorkersRunning.get() + ", to-be-released " + releasedContainerList + + ", pendingReleaseContainers : " + pendingReleaseContainers) + } + } + + def createRackResourceRequests(hostContainers: List[ResourceRequest]): List[ResourceRequest] = { + // First generate modified racks and new set of hosts under it : then issue requests + val rackToCounts = new HashMap[String, Int]() + + // Within this lock - used to read/write to the rack related maps too. + for (container <- hostContainers) { + val candidateHost = container.getHostName + val candidateNumContainers = container.getNumContainers + assert(YarnAllocationHandler.ANY_HOST != candidateHost) + + val rack = YarnAllocationHandler.lookupRack(conf, candidateHost) + if (rack != null) { + var count = rackToCounts.getOrElse(rack, 0) + count += candidateNumContainers + rackToCounts.put(rack, count) + } + } + + val requestedContainers: ArrayBuffer[ResourceRequest] = + new ArrayBuffer[ResourceRequest](rackToCounts.size) + for ((rack, count) <- rackToCounts){ + requestedContainers += + createResourceRequest(AllocationType.RACK, rack, count, YarnAllocationHandler.PRIORITY) + } + + requestedContainers.toList + } + + def allocatedContainersOnHost(host: String): Int = { + var retval = 0 + allocatedHostToContainersMap.synchronized { + retval = allocatedHostToContainersMap.getOrElse(host, Set()).size + } + retval + } + + def allocatedContainersOnRack(rack: String): Int = { + var retval = 0 + allocatedHostToContainersMap.synchronized { + retval = allocatedRackCount.getOrElse(rack, 0) + } + retval + } + + private def allocateWorkerResources(numWorkers: Int): AllocateResponse = { + + var resourceRequests: List[ResourceRequest] = null + + // default. + if (numWorkers <= 0 || preferredHostToCount.isEmpty) { + logDebug("numWorkers: " + numWorkers + ", host preferences ? " + preferredHostToCount.isEmpty) + resourceRequests = List( + createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY)) + } + else { + // request for all hosts in preferred nodes and for numWorkers - + // candidates.size, request by default allocation policy. 
+ val hostContainerRequests: ArrayBuffer[ResourceRequest] = + new ArrayBuffer[ResourceRequest](preferredHostToCount.size) + for ((candidateHost, candidateCount) <- preferredHostToCount) { + val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost) + + if (requiredCount > 0) { + hostContainerRequests += + createResourceRequest(AllocationType.HOST, candidateHost, requiredCount, YarnAllocationHandler.PRIORITY) + } + } + val rackContainerRequests: List[ResourceRequest] = createRackResourceRequests(hostContainerRequests.toList) + + val anyContainerRequests: ResourceRequest = + createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY) + + val containerRequests: ArrayBuffer[ResourceRequest] = + new ArrayBuffer[ResourceRequest](hostContainerRequests.size() + rackContainerRequests.size() + 1) + + containerRequests ++= hostContainerRequests + containerRequests ++= rackContainerRequests + containerRequests += anyContainerRequests + + resourceRequests = containerRequests.toList + } + + val req = Records.newRecord(classOf[AllocateRequest]) + req.setResponseId(lastResponseId.incrementAndGet) + req.setApplicationAttemptId(appAttemptId) + + req.addAllAsks(resourceRequests) + + val releasedContainerList = createReleasedContainerList() + req.addAllReleases(releasedContainerList) + + + + if (numWorkers > 0) { + logInfo("Allocating " + numWorkers + " worker containers with " + (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + " of memory each.") + } + else { + logDebug("Empty allocation req .. release : " + releasedContainerList) + } + + for (req <- resourceRequests) { + logInfo("rsrcRequest ... host : " + req.getHostName + ", numContainers : " + req.getNumContainers + + ", p = " + req.getPriority().getPriority + ", capability: " + req.getCapability) + } + resourceManager.allocate(req) + } + + + private def createResourceRequest(requestType: AllocationType.AllocationType, + resource:String, numWorkers: Int, priority: Int): ResourceRequest = { + + // If hostname specified, we need atleast two requests - node local and rack local. + // There must be a third request - which is ANY : that will be specially handled. + requestType match { + case AllocationType.HOST => { + assert (YarnAllocationHandler.ANY_HOST != resource) + + val hostname = resource + val nodeLocal = createResourceRequestImpl(hostname, numWorkers, priority) + + // add to host->rack mapping + YarnAllocationHandler.populateRackInfo(conf, hostname) + + nodeLocal + } + + case AllocationType.RACK => { + val rack = resource + createResourceRequestImpl(rack, numWorkers, priority) + } + + case AllocationType.ANY => { + createResourceRequestImpl(YarnAllocationHandler.ANY_HOST, numWorkers, priority) + } + + case _ => throw new IllegalArgumentException("Unexpected/unsupported request type .. " + requestType) + } + } + + private def createResourceRequestImpl(hostname:String, numWorkers: Int, priority: Int): ResourceRequest = { + + val rsrcRequest = Records.newRecord(classOf[ResourceRequest]) + val memCapability = Records.newRecord(classOf[Resource]) + // There probably is some overhead here, let's reserve a bit more memory. 
+    memCapability.setMemory(workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+    rsrcRequest.setCapability(memCapability)
+
+    val pri = Records.newRecord(classOf[Priority])
+    pri.setPriority(priority)
+    rsrcRequest.setPriority(pri)
+
+    rsrcRequest.setHostName(hostname)
+
+    rsrcRequest.setNumContainers(java.lang.Math.max(numWorkers, 0))
+    rsrcRequest
+  }
+
+  def createReleasedContainerList(): ArrayBuffer[ContainerId] = {
+
+    val retval = new ArrayBuffer[ContainerId](1)
+    // Iterate over the COW list ...
+    for (container <- releasedContainerList.iterator()) {
+      retval += container
+    }
+    // Remove them from the original list.
+    if (! retval.isEmpty) {
+      releasedContainerList.removeAll(retval)
+      for (v <- retval) pendingReleaseContainers.put(v, true)
+      logInfo("Releasing " + retval.size + " containers. pendingReleaseContainers : " +
+        pendingReleaseContainers)
+    }
+
+    retval
+  }
+}
+
+object YarnAllocationHandler {
+
+  val ANY_HOST = "*"
+  // All requests are issued with the same priority: we do not (yet) have any distinction between
+  // request types (like map/reduce in Hadoop, for example).
+  val PRIORITY = 1
+
+  // Additional memory overhead - in MB.
+  val MEMORY_OVERHEAD = 384
+
+  // Host to rack map - saved from allocation requests.
+  // We are expecting this not to change.
+  // Note that it is possible for this to change: the RM will indicate that to us via the update
+  // response to allocate. But we are punting on handling that for now.
+  private val hostToRack = new ConcurrentHashMap[String, String]()
+  private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]()
+
+  def newAllocator(conf: Configuration,
+    resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId,
+    args: ApplicationMasterArguments,
+    map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = {
+
+    val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)
+
+    new YarnAllocationHandler(conf, resourceManager, appAttemptId, args.numWorkers,
+      args.workerMemory, args.workerCores, hostToCount, rackToCount)
+  }
+
+  def newAllocator(conf: Configuration,
+    resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId,
+    maxWorkers: Int, workerMemory: Int, workerCores: Int,
+    map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = {
+
+    val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)
+
+    new YarnAllocationHandler(conf, resourceManager, appAttemptId, maxWorkers,
+      workerMemory, workerCores, hostToCount, rackToCount)
+  }
+
+  // A simple method to copy the split info map.
+  private def generateNodeToWeight(conf: Configuration, input: collection.Map[String, collection.Set[SplitInfo]]) :
+    // host to count, rack to count
+    (Map[String, Int], Map[String, Int]) = {
+
+    if (input == null) return (Map[String, Int](), Map[String, Int]())
+
+    val hostToCount = new HashMap[String, Int]
+    val rackToCount = new HashMap[String, Int]
+
+    for ((host, splits) <- input) {
+      val hostCount = hostToCount.getOrElse(host, 0)
+      hostToCount.put(host, hostCount + splits.size)
+
+      val rack = lookupRack(conf, host)
+      if (rack != null) {
+        val rackCount = rackToCount.getOrElse(rack, 0)
+        rackToCount.put(rack, rackCount + splits.size)
+      }
+    }
+
+    (hostToCount.toMap, rackToCount.toMap)
+  }
+
+  def lookupRack(conf: Configuration, host: String): String = {
+    if (!
hostToRack.contains(host)) populateRackInfo(conf, host) + hostToRack.get(host) + } + + def fetchCachedHostsForRack(rack: String): Option[Set[String]] = { + val set = rackToHostSet.get(rack) + if (set == null) return None + + // No better way to get a Set[String] from JSet ? + val convertedSet: collection.mutable.Set[String] = set + Some(convertedSet.toSet) + } + + def populateRackInfo(conf: Configuration, hostname: String) { + Utils.checkHost(hostname) + + if (!hostToRack.containsKey(hostname)) { + // If there are repeated failures to resolve, all to an ignore list ? + val rackInfo = RackResolver.resolve(conf, hostname) + if (rackInfo != null && rackInfo.getNetworkLocation != null) { + val rack = rackInfo.getNetworkLocation + hostToRack.put(hostname, rack) + if (! rackToHostSet.containsKey(rack)) { + rackToHostSet.putIfAbsent(rack, Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]())) + } + rackToHostSet.get(rack).add(hostname) + + // Since RackResolver caches, we are disabling this for now ... + } /* else { + // right ? Else we will keep calling rack resolver in case we cant resolve rack info ... + hostToRack.put(hostname, null) + } */ + } + } +} diff --git a/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala b/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala new file mode 100644 index 0000000000000000000000000000000000000000..ed732d36bfc7c6e1d110d6544676f0278835319c --- /dev/null +++ b/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala @@ -0,0 +1,42 @@ +package spark.scheduler.cluster + +import spark._ +import spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler} +import org.apache.hadoop.conf.Configuration + +/** + * + * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of ApplicationMaster, etc is done + */ +private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) { + + def this(sc: SparkContext) = this(sc, new Configuration()) + + // Nothing else for now ... initialize application master : which needs sparkContext to determine how to allocate + // Note that only the first creation of SparkContext influences (and ideally, there must be only one SparkContext, right ?) + // Subsequent creations are ignored - since nodes are already allocated by then. 
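For context, a minimal usage sketch only: together with the "yarn-standalone" master value handled in the SparkContext change later in this diff, this scheduler is loaded reflectively. The object and application names below are illustrative, and the assumption (based on the ApplicationMaster hook used here) is that the driver code runs inside the YARN ApplicationMaster.

import spark.SparkContext

object YarnStandaloneExample {
  def main(args: Array[String]) {
    // "yarn-standalone" triggers the reflective creation of YarnClusterScheduler.
    val sc = new SparkContext("yarn-standalone", "YarnStandaloneExample")
    println(sc.parallelize(1 to 1000).count())
    sc.stop()
  }
}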
+ + + // By default, rack is unknown + override def getRackForHost(hostPort: String): Option[String] = { + val host = Utils.parseHostPort(hostPort)._1 + val retval = YarnAllocationHandler.lookupRack(conf, host) + if (retval != null) Some(retval) else None + } + + // By default, if rack is unknown, return nothing + override def getCachedHostsForRack(rack: String): Option[Set[String]] = { + if (rack == None || rack == null) return None + + YarnAllocationHandler.fetchCachedHostsForRack(rack) + } + + override def postStartHook() { + val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc) + if (sparkContextInitialized){ + // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt + Thread.sleep(3000L) + } + logInfo("YarnClusterScheduler.postStartHook done") + } +} diff --git a/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala b/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala index 35300cea5883521b98b44baa792b5c1638c671d0..a0652d7fc78a9ead97b2a2fec8fbd986f29e25aa 100644 --- a/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala +++ b/core/src/hadoop2/scala/org/apache/hadoop/mapred/HadoopMapRedUtil.scala @@ -4,4 +4,7 @@ trait HadoopMapRedUtil { def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContextImpl(conf, jobId) def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier, + jobId, isMap, taskId, attemptId) } diff --git a/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala b/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala index 7afdbff3205c36b7c920ffa20851dfde06db3eec..7fdbe322fdf0ba50d1f62de502247f4991fbf16c 100644 --- a/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala +++ b/core/src/hadoop2/scala/org/apache/hadoop/mapreduce/HadoopMapReduceUtil.scala @@ -7,4 +7,7 @@ trait HadoopMapReduceUtil { def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContextImpl(conf, jobId) def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + + def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier, + jobId, isMap, taskId, attemptId) } diff --git a/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala b/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala new file mode 100644 index 0000000000000000000000000000000000000000..a0fb4fe25d188b02f1ac62245eaaa0de86b5b46c --- /dev/null +++ b/core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala @@ -0,0 +1,23 @@ +package spark.deploy +import org.apache.hadoop.conf.Configuration + + +/** + * Contains util methods to interact with Hadoop from spark. + */ +object SparkHadoopUtil { + + def getUserNameFromEnvironment(): String = { + // defaulting to -D ... + System.getProperty("user.name") + } + + def runAsUser(func: (Product) => Unit, args: Product) { + + // Add support, if exists - for now, simply run func ! + func(args) + } + + // Return an appropriate (subclass) of Configuration. 
Creating config can initializes some hadoop subsystems + def newConfiguration(): Configuration = new Configuration() +} diff --git a/core/src/main/scala/spark/ClosureCleaner.scala b/core/src/main/scala/spark/ClosureCleaner.scala index 98525b99c84201ba2b9728b3b07a18ffd2c4041e..50d6a1c5c9fb93559bbe470498c72952f2993085 100644 --- a/core/src/main/scala/spark/ClosureCleaner.scala +++ b/core/src/main/scala/spark/ClosureCleaner.scala @@ -8,12 +8,20 @@ import scala.collection.mutable.Set import org.objectweb.asm.{ClassReader, MethodVisitor, Type} import org.objectweb.asm.commons.EmptyVisitor import org.objectweb.asm.Opcodes._ +import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream} private[spark] object ClosureCleaner extends Logging { // Get an ASM class reader for a given class from the JAR that loaded it private def getClassReader(cls: Class[_]): ClassReader = { - new ClassReader(cls.getResourceAsStream( - cls.getName.replaceFirst("^.*\\.", "") + ".class")) + // Copy data over, before delegating to ClassReader - else we can run out of open file handles. + val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" + val resourceStream = cls.getResourceAsStream(className) + // todo: Fixme - continuing with earlier behavior ... + if (resourceStream == null) return new ClassReader(resourceStream) + + val baos = new ByteArrayOutputStream(128) + Utils.copyStream(resourceStream, baos, true) + new ClassReader(new ByteArrayInputStream(baos.toByteArray)) } // Check whether a class represents a Scala closure diff --git a/core/src/main/scala/spark/FetchFailedException.scala b/core/src/main/scala/spark/FetchFailedException.scala index a953081d245fab139e944e7217b1d00c7ef37c2b..40b0193f1994fb9099a13d3ac4db1e701b085920 100644 --- a/core/src/main/scala/spark/FetchFailedException.scala +++ b/core/src/main/scala/spark/FetchFailedException.scala @@ -3,18 +3,25 @@ package spark import spark.storage.BlockManagerId private[spark] class FetchFailedException( - val bmAddress: BlockManagerId, - val shuffleId: Int, - val mapId: Int, - val reduceId: Int, + taskEndReason: TaskEndReason, + message: String, cause: Throwable) extends Exception { - - override def getMessage(): String = - "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId) + + def this (bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int, cause: Throwable) = + this(FetchFailed(bmAddress, shuffleId, mapId, reduceId), + "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId), + cause) + + def this (shuffleId: Int, reduceId: Int, cause: Throwable) = + this(FetchFailed(null, shuffleId, -1, reduceId), + "Unable to fetch locations from master: %d %d".format(shuffleId, reduceId), cause) + + override def getMessage(): String = message + override def getCause(): Throwable = cause - def toTaskEndReason: TaskEndReason = - FetchFailed(bmAddress, shuffleId, mapId, reduceId) + def toTaskEndReason: TaskEndReason = taskEndReason + } diff --git a/core/src/main/scala/spark/HadoopWriter.scala b/core/src/main/scala/spark/HadoopWriter.scala index afcf9f6db4a2be26ebe334b098c09a861b98a37d..5e8396edb92235de51e2ff6c33d29027ca670c9e 100644 --- a/core/src/main/scala/spark/HadoopWriter.scala +++ b/core/src/main/scala/spark/HadoopWriter.scala @@ -2,14 +2,10 @@ package org.apache.hadoop.mapred import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path -import org.apache.hadoop.util.ReflectionUtils -import org.apache.hadoop.io.NullWritable -import 
org.apache.hadoop.io.Text import java.text.SimpleDateFormat import java.text.NumberFormat import java.io.IOException -import java.net.URI import java.util.Date import spark.Logging @@ -24,7 +20,7 @@ import spark.SerializableWritable * a filename to write to, etc, exactly like in a Hadoop MapReduce job. */ class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRedUtil with Serializable { - + private val now = new Date() private val conf = new SerializableWritable(jobConf) @@ -106,6 +102,12 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRe } } + def commitJob() { + // always ? Or if cmtr.needsTaskCommit ? + val cmtr = getOutputCommitter() + cmtr.commitJob(getJobContext()) + } + def cleanup() { getOutputCommitter().cleanupJob(getJobContext()) } diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala index 7c1c1bb1440bcbd263947481ee8a0d762d1587be..0fc8c314630bc9bc5f9223ae7948665b0c7b1e32 100644 --- a/core/src/main/scala/spark/Logging.scala +++ b/core/src/main/scala/spark/Logging.scala @@ -68,6 +68,10 @@ trait Logging { if (log.isErrorEnabled) log.error(msg, throwable) } + protected def isTraceEnabled(): Boolean = { + log.isTraceEnabled + } + // Method for ensuring that logging is initialized, to avoid having multiple // threads do it concurrently (as SLF4J initialization is not thread safe). protected def initLogging() { log } diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 866d630a6d27b8b41b2965c97146d8087d9f450b..fde597ffd1a30fd1891e016b9decc87abf474118 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -1,7 +1,6 @@ package spark import java.io._ -import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.mutable.HashMap @@ -12,8 +11,7 @@ import akka.dispatch._ import akka.pattern.ask import akka.remote._ import akka.util.Duration -import akka.util.Timeout -import akka.util.duration._ + import spark.scheduler.MapStatus import spark.storage.BlockManagerId @@ -40,10 +38,12 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac private[spark] class MapOutputTracker extends Logging { + private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + // Set to the MapOutputTrackerActor living on the driver var trackerActor: ActorRef = _ - var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] + private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] // Incremented every time a fetch fails so that client nodes know to clear // their cache of map output locations if this happens. @@ -52,7 +52,7 @@ private[spark] class MapOutputTracker extends Logging { // Cache a serialized version of the output statuses for each shuffle to send them out faster var cacheGeneration = generation - val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] + private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup) @@ -60,7 +60,6 @@ private[spark] class MapOutputTracker extends Logging { // throw a SparkException if this fails. 
def askTracker(message: Any): Any = { try { - val timeout = 10.seconds val future = trackerActor.ask(message)(timeout) return Await.result(future, timeout) } catch { @@ -77,10 +76,9 @@ private[spark] class MapOutputTracker extends Logging { } def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.get(shuffleId) != None) { + if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } - mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)) } def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { @@ -101,8 +99,9 @@ private[spark] class MapOutputTracker extends Logging { } def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - var array = mapStatuses(shuffleId) - if (array != null) { + var arrayOpt = mapStatuses.get(shuffleId) + if (arrayOpt.isDefined && arrayOpt.get != null) { + var array = arrayOpt.get array.synchronized { if (array(mapId) != null && array(mapId).location == bmAddress) { array(mapId) = null @@ -115,13 +114,14 @@ private[spark] class MapOutputTracker extends Logging { } // Remembers which map output locations are currently being fetched on a worker - val fetching = new HashSet[Int] + private val fetching = new HashSet[Int] // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") + var fetchedStatuses: Array[MapStatus] = null fetching.synchronized { if (fetching.contains(shuffleId)) { // Someone else is fetching it; wait for them to be done @@ -132,31 +132,48 @@ private[spark] class MapOutputTracker extends Logging { case e: InterruptedException => } } - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, mapStatuses(shuffleId)) - } else { + } + + // Either while we waited the fetch happened successfully, or + // someone fetched it in between the get and the fetching.synchronized. + fetchedStatuses = mapStatuses.get(shuffleId).orNull + if (fetchedStatuses == null) { + // We have to do the fetch, get others to wait for us. 
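The getServerStatuses changes around this point implement a fetch-once pattern: a single caller performs the expensive fetch for a given shuffle id while concurrent callers wait on the monitor and then re-read the cache. A self-contained sketch of that pattern under illustrative names (not Spark API):

import scala.collection.mutable.{HashMap, HashSet}

class FetchOnce[K, V](fetch: K => V) {
  private val cache = new HashMap[K, V]()
  private val fetching = new HashSet[K]()

  def get(key: K): V = {
    val cached = fetching.synchronized {
      // Wait while another caller is fetching this key (loop handles spurious wakeups).
      while (fetching.contains(key)) fetching.wait()
      val existing = cache.get(key)
      if (existing.isEmpty) fetching += key  // we won the race; we will do the fetch
      existing
    }
    if (cached.isDefined) return cached.get  // fetched by someone else while we waited
    try {
      val value = fetch(key)
      fetching.synchronized { cache.put(key, value) }
      value
    } finally {
      fetching.synchronized {
        fetching -= key
        fetching.notifyAll()
      }
    }
  }
}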
fetching += shuffleId } } - // We won the race to fetch the output locs; do so - logInfo("Doing the fetch; tracker actor = " + trackerActor) - val host = System.getProperty("spark.hostname", Utils.localHostName) - // This try-finally prevents hangs due to timeouts: - var fetchedStatuses: Array[MapStatus] = null - try { - val fetchedBytes = - askTracker(GetMapOutputStatuses(shuffleId, host)).asInstanceOf[Array[Byte]] - fetchedStatuses = deserializeStatuses(fetchedBytes) - logInfo("Got the output locations") - mapStatuses.put(shuffleId, fetchedStatuses) - } finally { - fetching.synchronized { - fetching -= shuffleId - fetching.notifyAll() + + if (fetchedStatuses == null) { + // We won the race to fetch the output locs; do so + logInfo("Doing the fetch; tracker actor = " + trackerActor) + val hostPort = Utils.localHostPort() + // This try-finally prevents hangs due to timeouts: + try { + val fetchedBytes = + askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]] + fetchedStatuses = deserializeStatuses(fetchedBytes) + logInfo("Got the output locations") + mapStatuses.put(shuffleId, fetchedStatuses) + } finally { + fetching.synchronized { + fetching -= shuffleId + fetching.notifyAll() + } + } + } + if (fetchedStatuses != null) { + fetchedStatuses.synchronized { + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) } } - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) + else{ + throw new FetchFailedException(null, shuffleId, -1, reduceId, + new Exception("Missing all output locations for shuffle " + shuffleId)) + } } else { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) + statuses.synchronized { + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) + } } } @@ -194,7 +211,8 @@ private[spark] class MapOutputTracker extends Logging { generationLock.synchronized { if (newGen > generation) { logInfo("Updating generation to " + newGen + " and clearing cache") - mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] + // mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] + mapStatuses.clear() generation = newGen } } @@ -232,10 +250,13 @@ private[spark] class MapOutputTracker extends Logging { // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will // generally be pretty compressible because many map outputs will be on the same hostname. - def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = { + private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = { val out = new ByteArrayOutputStream val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) - objOut.writeObject(statuses) + // Since statuses can be modified in parallel, sync on it + statuses.synchronized { + objOut.writeObject(statuses) + } objOut.close() out.toByteArray } @@ -243,7 +264,10 @@ private[spark] class MapOutputTracker extends Logging { // Opposite of serializeStatuses. def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = { val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes))) - objIn.readObject().asInstanceOf[Array[MapStatus]] + objIn.readObject(). + // // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present + // comment this out - nulls could be due to missing location ? 
+ asInstanceOf[Array[MapStatus]] // .filter( _ != null ) } } @@ -253,14 +277,11 @@ private[spark] object MapOutputTracker { // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If // any of the statuses is null (indicating a missing location due to a failed mapper), // throw a FetchFailedException. - def convertMapStatuses( + private def convertMapStatuses( shuffleId: Int, reduceId: Int, statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { - if (statuses == null) { - throw new FetchFailedException(null, shuffleId, -1, reduceId, - new Exception("Missing all output locations for shuffle " + shuffleId)) - } + assert (statuses != null) statuses.map { status => if (status == null) { diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 07efba9e8d26efa2e9a737f28d90e70a665f51d6..67fd1c1a8f68d57809ec8ca88fe7279de2d02a74 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -545,8 +545,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( // around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" <split #> <attempt # = spark task #> */ - val attemptId = new TaskAttemptID(jobtrackerID, - stageId, false, context.splitId, attemptNumber) + val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber) val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) val format = outputFormatClass.newInstance val committer = format.getOutputCommitter(hadoopContext) @@ -565,11 +564,12 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * however we're only going to use this local OutputCommitter for * setupJob/commitJob, so we just use a dummy "map" task. 
*/ - val jobAttemptId = new TaskAttemptID(jobtrackerID, stageId, true, 0, 0) + val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, true, 0, 0) val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) jobCommitter.setupJob(jobTaskContext) val count = self.context.runJob(self, writeShard _).sum + jobCommitter.commitJob(jobTaskContext) jobCommitter.cleanupJob(jobTaskContext) } @@ -637,6 +637,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( } self.context.runJob(self, writeToFile _) + writer.commitJob() writer.cleanup() } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 4957a54c1b8af5c199a3f7ffd235c1ff492e0033..5f5ec0b0f4d2992571fa06247dc3cdc13ac17c42 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -31,13 +31,13 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor import org.apache.hadoop.mapreduce.{Job => NewHadoopJob} import org.apache.mesos.MesosNativeLibrary -import spark.deploy.LocalSparkCluster +import spark.deploy.{SparkHadoopUtil, LocalSparkCluster} import spark.partial.ApproximateEvaluator import spark.partial.PartialResult import spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD} import spark.scheduler._ import spark.scheduler.local.LocalScheduler -import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} +import spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import spark.storage.BlockManagerUI import spark.util.{MetadataCleaner, TimeStampedHashMap} @@ -59,7 +59,10 @@ class SparkContext( val appName: String, val sparkHome: String = null, val jars: Seq[String] = Nil, - val environment: Map[String, String] = Map()) + val environment: Map[String, String] = Map(), + // This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc) too. + // This is typically generated from InputFormatInfo.computePreferredLocations .. 
host, set of data-local splits on host + val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map()) extends Logging { // Ensure logging is initialized before we spawn any threads @@ -67,7 +70,7 @@ class SparkContext( // Set Spark driver host and port system properties if (System.getProperty("spark.driver.host") == null) { - System.setProperty("spark.driver.host", Utils.localIpAddress) + System.setProperty("spark.driver.host", Utils.localHostName()) } if (System.getProperty("spark.driver.port") == null) { System.setProperty("spark.driver.port", "0") @@ -99,7 +102,9 @@ class SparkContext( // Add each JAR given through the constructor - jars.foreach { addJar(_) } + if (jars != null) { + jars.foreach { addJar(_) } + } // Environment variables to pass to our executors private[spark] val executorEnvs = HashMap[String, String]() @@ -111,7 +116,9 @@ class SparkContext( executorEnvs(key) = value } } - executorEnvs ++= environment + if (environment != null) { + executorEnvs ++= environment + } // Create and start the scheduler private var taskScheduler: TaskScheduler = { @@ -164,6 +171,22 @@ class SparkContext( } scheduler + case "yarn-standalone" => + val scheduler = try { + val clazz = Class.forName("spark.scheduler.cluster.YarnClusterScheduler") + val cons = clazz.getConstructor(classOf[SparkContext]) + cons.newInstance(this).asInstanceOf[ClusterScheduler] + } catch { + // TODO: Enumerate the exact reasons why it can fail + // But irrespective of it, it means we cannot proceed ! + case th: Throwable => { + throw new SparkException("YARN mode not available ?", th) + } + } + val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem) + scheduler.initialize(backend) + scheduler + case _ => if (MESOS_REGEX.findFirstIn(master).isEmpty) { logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master)) @@ -183,12 +206,12 @@ class SparkContext( } taskScheduler.start() - private var dagScheduler = new DAGScheduler(taskScheduler) + @volatile private var dagScheduler = new DAGScheduler(taskScheduler) dagScheduler.start() /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration = { - val conf = new Configuration() + val conf = SparkHadoopUtil.newConfiguration() // Explicitly check for S3 environment variables if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) { conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) @@ -207,6 +230,9 @@ class SparkContext( private[spark] var checkpointDir: Option[String] = None + // Post init + taskScheduler.postStartHook() + // Methods for creating RDDs /** Distribute a local Scala collection to form an RDD. */ @@ -471,7 +497,7 @@ class SparkContext( */ def getExecutorMemoryStatus: Map[String, (Long, Long)] = { env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => - (blockManagerId.ip + ":" + blockManagerId.port, mem) + (blockManagerId.host + ":" + blockManagerId.port, mem) } } @@ -527,10 +553,13 @@ class SparkContext( /** Shut down the SparkContext. */ def stop() { - if (dagScheduler != null) { + // Do this only if not stopped already - best case effort. + // prevent NPE if stopped more than once. 
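The stop() rewrite that follows copies the @volatile dagScheduler reference into a local, nulls the field, and then works only on the copy, so a repeated stop() sees null instead of hitting an NPE. A standalone sketch of that pattern (illustrative class, not Spark code); as the comment notes, this is best-effort rather than a strict guarantee under fully concurrent calls:

class StoppableService {
  @volatile private var worker: Thread = new Thread()

  def stop() {
    // Copy, then null out the field; only a caller that observed a non-null copy shuts down.
    val workerCopy = worker
    worker = null
    if (workerCopy != null) {
      workerCopy.interrupt()
    }
  }
}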
+ val dagSchedulerCopy = dagScheduler + dagScheduler = null + if (dagSchedulerCopy != null) { metadataCleaner.cancel() - dagScheduler.stop() - dagScheduler = null + dagSchedulerCopy.stop() taskScheduler = null // TODO: Cache.stop()? env.stop() @@ -546,6 +575,7 @@ class SparkContext( } } + /** * Get Spark's home location from either a value set through the constructor, * or the spark.home Java property, or the SPARK_HOME environment variable @@ -685,7 +715,7 @@ class SparkContext( */ def setCheckpointDir(dir: String, useExisting: Boolean = false) { val path = new Path(dir) - val fs = path.getFileSystem(new Configuration()) + val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration()) if (!useExisting) { if (fs.exists(path)) { throw new Exception("Checkpoint directory '" + path + "' already exists.") diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 7157fd26883d3a3f7b29fb71fc272886a92ecfd5..ffb40bab3a3c464c66ca8c71b53d5affceb1d481 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -72,6 +72,16 @@ object SparkEnv extends Logging { System.setProperty("spark.driver.port", boundPort.toString) } + // set only if unset until now. + if (System.getProperty("spark.hostPort", null) == null) { + if (!isDriver){ + // unexpected + Utils.logErrorWithStack("Unexpected NOT to have spark.hostPort set") + } + Utils.checkHost(hostname) + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + } + val classLoader = Thread.currentThread.getContextClassLoader // Create an instance of the class named by the given Java system property, or by @@ -88,9 +98,10 @@ object SparkEnv extends Logging { logInfo("Registering " + name) actorSystem.actorOf(Props(newActor), name = name) } else { - val driverIp: String = System.getProperty("spark.driver.host", "localhost") + val driverHost: String = System.getProperty("spark.driver.host", "localhost") val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt - val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, name) + Utils.checkHost(driverHost, "Expected hostname") + val url = "akka://spark@%s:%s/user/%s".format(driverHost, driverPort, name) logInfo("Connecting to " + name + ": " + url) actorSystem.actorFor(url) } diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 81daacf958b5a03d3135f4587f2d6e411b0c6a3c..9f48cbe490a2703fe279c602d831f6b26794c556 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -1,18 +1,18 @@ package spark import java.io._ -import java.net._ +import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket} import java.util.{Locale, Random, UUID} -import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} -import org.apache.hadoop.conf.Configuration +import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor} import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.collection.JavaConversions._ import scala.io.Source import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder -import scala.Some import spark.serializer.SerializerInstance +import spark.deploy.SparkHadoopUtil +import java.util.regex.Pattern /** * Various utility methods used by Spark. 
@@ -68,6 +68,41 @@ private object Utils extends Logging { return buf } + + private val shutdownDeletePaths = new collection.mutable.HashSet[String]() + + // Register the path to be deleted via shutdown hook + def registerShutdownDeleteDir(file: File) { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths += absolutePath + } + } + + // Is the path already registered to be deleted via a shutdown hook ? + def hasShutdownDeleteDir(file: File): Boolean = { + val absolutePath = file.getAbsolutePath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths.contains(absolutePath) + } + } + + // Note: if file is child of some registered path, while not equal to it, then return true; else false + // This is to ensure that two shutdown hooks do not try to delete each others paths - resulting in IOException + // and incomplete cleanup + def hasRootAsShutdownDeleteDir(file: File): Boolean = { + + val absolutePath = file.getAbsolutePath() + + val retval = shutdownDeletePaths.synchronized { + shutdownDeletePaths.find(path => ! absolutePath.equals(path) && absolutePath.startsWith(path) ).isDefined + } + + if (retval) logInfo("path = " + file + ", already present as root for deletion.") + + retval + } + /** Create a temporary directory inside the given parent directory */ def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = { var attempts = 0 @@ -86,10 +121,14 @@ private object Utils extends Logging { } } catch { case e: IOException => ; } } + + registerShutdownDeleteDir(dir) + // Add a shutdown hook to delete the temp dir when the JVM exits Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) { override def run() { - Utils.deleteRecursively(dir) + // Attempt to delete if some patch which is parent of this is not already registered. + if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir) } }) return dir @@ -168,7 +207,7 @@ private object Utils extends Logging { case _ => // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others val uri = new URI(url) - val conf = new Configuration() + val conf = SparkHadoopUtil.newConfiguration() val fs = FileSystem.get(uri, conf) val in = fs.open(new Path(uri)) val out = new FileOutputStream(tempFile) @@ -227,8 +266,10 @@ private object Utils extends Logging { /** * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4). + * Note, this is typically not used from within core spark. */ lazy val localIpAddress: String = findLocalIpAddress() + lazy val localIpAddressHostname: String = getAddressHostName(localIpAddress) private def findLocalIpAddress(): String = { val defaultIpOverride = System.getenv("SPARK_LOCAL_IP") @@ -266,6 +307,8 @@ private object Utils extends Logging { * hostname it reports to the master. */ def setCustomHostname(hostname: String) { + // DEBUG code + Utils.checkHost(hostname) customHostname = Some(hostname) } @@ -273,7 +316,97 @@ private object Utils extends Logging { * Get the local machine's hostname. 
*/ def localHostName(): String = { - customHostname.getOrElse(InetAddress.getLocalHost.getHostName) + customHostname.getOrElse(localIpAddressHostname) + } + + def getAddressHostName(address: String): String = { + InetAddress.getByName(address).getHostName + } + + + + def localHostPort(): String = { + val retval = System.getProperty("spark.hostPort", null) + if (retval == null) { + logErrorWithStack("spark.hostPort not set but invoking localHostPort") + return localHostName() + } + + retval + } + + /* + // Used by DEBUG code : remove when all testing done + private val ipPattern = Pattern.compile("^[0-9]+(\\.[0-9]+)*$") + def checkHost(host: String, message: String = "") { + // Currently catches only ipv4 pattern, this is just a debugging tool - not rigourous ! + // if (host.matches("^[0-9]+(\\.[0-9]+)*$")) { + if (ipPattern.matcher(host).matches()) { + Utils.logErrorWithStack("Unexpected to have host " + host + " which matches IP pattern. Message " + message) + } + if (Utils.parseHostPort(host)._2 != 0){ + Utils.logErrorWithStack("Unexpected to have host " + host + " which has port in it. Message " + message) + } + } + + // Used by DEBUG code : remove when all testing done + def checkHostPort(hostPort: String, message: String = "") { + val (host, port) = Utils.parseHostPort(hostPort) + checkHost(host) + if (port <= 0){ + Utils.logErrorWithStack("Unexpected to have port " + port + " which is not valid in " + hostPort + ". Message " + message) + } + } + */ + + // Once testing is complete in various modes, replace with this ? + def checkHost(host: String, message: String = "") {} + def checkHostPort(hostPort: String, message: String = "") {} + + def getUserNameFromEnvironment(): String = { + SparkHadoopUtil.getUserNameFromEnvironment + } + + // Used by DEBUG code : remove when all testing done + def logErrorWithStack(msg: String) { + try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } } + // temp code for debug + System.exit(-1) + } + + // Typically, this will be of order of number of nodes in cluster + // If not, we should change it to LRUCache or something. + private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]() + def parseHostPort(hostPort: String): (String, Int) = { + { + // Check cache first. + var cached = hostPortParseResults.get(hostPort) + if (cached != null) return cached + } + + val indx: Int = hostPort.lastIndexOf(':') + // This is potentially broken - when dealing with ipv6 addresses for example, sigh ... but then hadoop does not support ipv6 right now. + // For now, we assume that if port exists, then it is valid - not check if it is an int > 0 + if (-1 == indx) { + val retval = (hostPort, 0) + hostPortParseResults.put(hostPort, retval) + return retval + } + + val retval = (hostPort.substring(0, indx).trim(), hostPort.substring(indx + 1).trim().toInt) + hostPortParseResults.putIfAbsent(hostPort, retval) + hostPortParseResults.get(hostPort) + } + + def addIfNoPort(hostPort: String, port: Int): String = { + if (port <= 0) throw new IllegalArgumentException("Invalid port specified " + port) + + // This is potentially broken - when dealing with ipv6 addresses for example, sigh ... but then hadoop does not support ipv6 right now. 
+ // For now, we assume that if port exists, then it is valid - not check if it is an int > 0 + val indx: Int = hostPort.lastIndexOf(':') + if (-1 != indx) return hostPort + + hostPort + ":" + port } private[spark] val daemonThreadFactory: ThreadFactory = diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 9b4d54ab4e0461364d643bb857b267ed4e05bed6..807119ca8c08a1430d3db582a16f02b148812983 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -277,6 +277,8 @@ private class BytesToString extends spark.api.java.function.Function[Array[Byte] */ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) extends AccumulatorParam[JList[Array[Byte]]] { + + Utils.checkHost(serverHost, "Expected hostname") override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala index 8a3e64e4c22fa60ef4bf05a02a14ff399aea2632..51274acb1ed3b8e167e30d874514ca3dfb9ca22b 100644 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/spark/deploy/DeployMessage.scala @@ -4,6 +4,7 @@ import spark.deploy.ExecutorState.ExecutorState import spark.deploy.master.{WorkerInfo, ApplicationInfo} import spark.deploy.worker.ExecutorRunner import scala.collection.immutable.List +import spark.Utils private[spark] sealed trait DeployMessage extends Serializable @@ -19,7 +20,10 @@ case class RegisterWorker( memory: Int, webUiPort: Int, publicAddress: String) - extends DeployMessage + extends DeployMessage { + Utils.checkHost(host, "Required hostname") + assert (port > 0) +} private[spark] case class ExecutorStateChanged( @@ -58,7 +62,9 @@ private[spark] case class RegisteredApplication(appId: String) extends DeployMessage private[spark] -case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) +case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) { + Utils.checkHostPort(hostPort, "Required hostport") +} private[spark] case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String], @@ -81,6 +87,9 @@ private[spark] case class MasterState(host: String, port: Int, workers: Array[WorkerInfo], activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) { + Utils.checkHost(host, "Required hostname") + assert (port > 0) + def uri = "spark://" + host + ":" + port } @@ -92,4 +101,8 @@ private[spark] case object RequestWorkerState private[spark] case class WorkerState(host: String, port: Int, workerId: String, executors: List[ExecutorRunner], finishedExecutors: List[ExecutorRunner], masterUrl: String, cores: Int, memory: Int, - coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) + coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) { + + Utils.checkHost(host, "Required hostname") + assert (port > 0) +} diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala index 38a6ebfc242c163a1f5be26c2c46dbacfe45da12..71a641a9efea2d10de69cbf864a6fc3bfacf31f6 100644 --- a/core/src/main/scala/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala @@ -12,6 +12,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol { def write(obj: WorkerInfo) = JsObject( "id" -> JsString(obj.id), "host" -> 
JsString(obj.host), + "port" -> JsNumber(obj.port), "webuiaddress" -> JsString(obj.webUiAddress), "cores" -> JsNumber(obj.cores), "coresused" -> JsNumber(obj.coresUsed), diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala index 22319a96caef7ff80f97259bccd8381b0f6514bd..55bb61b0ccf88918a691684871c7627d064d99f6 100644 --- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala @@ -18,7 +18,7 @@ import scala.collection.mutable.ArrayBuffer private[spark] class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging { - private val localIpAddress = Utils.localIpAddress + private val localHostname = Utils.localHostName() private val masterActorSystems = ArrayBuffer[ActorSystem]() private val workerActorSystems = ArrayBuffer[ActorSystem]() @@ -26,13 +26,13 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") /* Start the Master */ - val (masterSystem, masterPort) = Master.startSystemAndActor(localIpAddress, 0, 0) + val (masterSystem, masterPort) = Master.startSystemAndActor(localHostname, 0, 0) masterActorSystems += masterSystem - val masterUrl = "spark://" + localIpAddress + ":" + masterPort + val masterUrl = "spark://" + localHostname + ":" + masterPort /* Start the Workers */ for (workerNum <- 1 to numWorkers) { - val (workerSystem, _) = Worker.startSystemAndActor(localIpAddress, 0, 0, coresPerWorker, + val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker, memoryPerWorker, masterUrl, null, Some(workerNum)) workerActorSystems += workerSystem } diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala index 2fc5e657f96926a1363e4cef96c4837e855cac3a..4af44f9c164c5ff7052d58ec6aefa40dc7dca9e2 100644 --- a/core/src/main/scala/spark/deploy/client/Client.scala +++ b/core/src/main/scala/spark/deploy/client/Client.scala @@ -3,6 +3,7 @@ package spark.deploy.client import spark.deploy._ import akka.actor._ import akka.pattern.ask +import akka.util.Duration import akka.util.duration._ import akka.pattern.AskTimeoutException import spark.{SparkException, Logging} @@ -59,10 +60,10 @@ private[spark] class Client( markDisconnected() context.stop(self) - case ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) => + case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => val fullId = appId + "/" + id - logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, host, cores)) - listener.executorAdded(fullId, workerId, host, cores, memory) + logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores)) + listener.executorAdded(fullId, workerId, hostPort, cores, memory) case ExecutorUpdated(id, state, message, exitStatus) => val fullId = appId + "/" + id @@ -112,7 +113,7 @@ private[spark] class Client( def stop() { if (actor != null) { try { - val timeout = 5.seconds + val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") val future = actor.ask(StopClient)(timeout) Await.result(future, timeout) } catch { diff --git a/core/src/main/scala/spark/deploy/client/ClientListener.scala b/core/src/main/scala/spark/deploy/client/ClientListener.scala index 
b7008321df564976d37ca9f428d7d47920a30bf3..e8c4083f9dfef8ec8acf0a3d3b5931cbd768f66c 100644 --- a/core/src/main/scala/spark/deploy/client/ClientListener.scala +++ b/core/src/main/scala/spark/deploy/client/ClientListener.scala @@ -12,7 +12,7 @@ private[spark] trait ClientListener { def disconnected(): Unit - def executorAdded(fullId: String, workerId: String, host: String, cores: Int, memory: Int): Unit + def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit } diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala index dc004b59ca5ac247d4e7c8125b775a1f1698e7ea..ad92532b5849a0f983ab5780711a7fe5d008bcb5 100644 --- a/core/src/main/scala/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/spark/deploy/client/TestClient.scala @@ -16,7 +16,7 @@ private[spark] object TestClient { System.exit(0) } - def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) {} + def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {} def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {} } diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 71b9d0801d59407559fdb44e82f282842fe9b733..160afe5239d984d1e641ffd257e695991f06c114 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -15,7 +15,7 @@ import spark.{Logging, SparkException, Utils} import spark.util.AkkaUtils -private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging { +private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging { val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000 @@ -35,9 +35,11 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor var firstApp: Option[ApplicationInfo] = None + Utils.checkHost(host, "Expected hostname") + val masterPublicAddress = { val envVar = System.getenv("SPARK_PUBLIC_DNS") - if (envVar != null) envVar else ip + if (envVar != null) envVar else host } // As a temporary workaround before better ways of configuring memory, we allow users to set @@ -46,7 +48,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean override def preStart() { - logInfo("Starting Spark master at spark://" + ip + ":" + port) + logInfo("Starting Spark master at spark://" + host + ":" + port) // Listen for remote client disconnection events, since they don't go through Akka's watch() context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) startWebUi() @@ -145,7 +147,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor } case RequestMasterState => { - sender ! MasterState(ip, port, workers.toArray, apps.toArray, completedApps.toArray) + sender ! 
MasterState(host, port, workers.toArray, apps.toArray, completedApps.toArray) } } @@ -211,13 +213,13 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) worker.actor ! LaunchExecutor(exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome) - exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) + exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.hostPort, exec.cores, exec.memory) } def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int, publicAddress: String): WorkerInfo = { // There may be one or more refs to dead workers on this same node (w/ different ID's), remove them. - workers.filter(w => (w.host == host) && (w.state == WorkerState.DEAD)).foreach(workers -= _) + workers.filter(w => (w.host == host && w.port == port) && (w.state == WorkerState.DEAD)).foreach(workers -= _) val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress) workers += worker idToWorker(worker.id) = worker @@ -307,7 +309,7 @@ private[spark] object Master { def main(argStrings: Array[String]) { val args = new MasterArguments(argStrings) - val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort) + val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort) actorSystem.awaitTermination() } diff --git a/core/src/main/scala/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/spark/deploy/master/MasterArguments.scala index 4ceab3fc036da556e359e3e6db24f6c7b79a8a3d..3d28ecabb4d0e95ce843ab434b45c297dc17de45 100644 --- a/core/src/main/scala/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/spark/deploy/master/MasterArguments.scala @@ -7,13 +7,13 @@ import spark.Utils * Command-line parser for the master. 
*/ private[spark] class MasterArguments(args: Array[String]) { - var ip = Utils.localHostName() + var host = Utils.localHostName() var port = 7077 var webUiPort = 8080 // Check for settings in environment variables - if (System.getenv("SPARK_MASTER_IP") != null) { - ip = System.getenv("SPARK_MASTER_IP") + if (System.getenv("SPARK_MASTER_HOST") != null) { + host = System.getenv("SPARK_MASTER_HOST") } if (System.getenv("SPARK_MASTER_PORT") != null) { port = System.getenv("SPARK_MASTER_PORT").toInt @@ -26,7 +26,13 @@ private[spark] class MasterArguments(args: Array[String]) { def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => - ip = value + Utils.checkHost(value, "ip no longer supported, please use hostname " + value) + host = value + parse(tail) + + case ("--host" | "-h") :: value :: tail => + Utils.checkHost(value, "Please use hostname " + value) + host = value parse(tail) case ("--port" | "-p") :: IntParam(value) :: tail => @@ -54,7 +60,8 @@ private[spark] class MasterArguments(args: Array[String]) { "Usage: Master [options]\n" + "\n" + "Options:\n" + - " -i IP, --ip IP IP address or DNS name to listen on\n" + + " -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) \n" + + " -h HOST, --host HOST Hostname to listen on\n" + " -p PORT, --port PORT Port to listen on (default: 7077)\n" + " --webui-port PORT Port for web UI (default: 8080)") System.exit(exitCode) diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index 54faa375fbd468e4ea05ce3fda9a9e41dfcea166..a4e21c81308e5b34c57eaa6dc1f604f47524d487 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -3,7 +3,7 @@ package spark.deploy.master import akka.actor.{ActorRef, ActorSystem} import akka.dispatch.Await import akka.pattern.ask -import akka.util.Timeout +import akka.util.{Duration, Timeout} import akka.util.duration._ import cc.spray.Directives import cc.spray.directives._ @@ -22,7 +22,7 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct val RESOURCE_DIR = "spark/deploy/master/webui" val STATIC_RESOURCE_DIR = "spark/deploy/static" - implicit val timeout = Timeout(10 seconds) + implicit val timeout = Timeout(Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")) val handler = { get { diff --git a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala index 23df1bb463288721e38379023b3fd01c4c8632d8..0c08c5f417238c88c9a07c56dba778cc9e1a287f 100644 --- a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala @@ -2,6 +2,7 @@ package spark.deploy.master import akka.actor.ActorRef import scala.collection.mutable +import spark.Utils private[spark] class WorkerInfo( val id: String, @@ -13,6 +14,9 @@ private[spark] class WorkerInfo( val webUiPort: Int, val publicAddress: String) { + Utils.checkHost(host, "Expected hostname") + assert (port > 0) + var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info var state: WorkerState.Value = WorkerState.ALIVE var coresUsed = 0 @@ -23,6 +27,11 @@ private[spark] class WorkerInfo( def coresFree: Int = cores - coresUsed def memoryFree: Int = memory - memoryUsed + def hostPort: String = { + assert (port > 0) + host + ":" + port + } + def addExecutor(exec: ExecutorInfo) { 
executors(exec.fullId) = exec coresUsed += exec.cores diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index de11771c8e62d7cdb0b629a105bb8d1ca21e544c..04a774658e4260ab212bdad06e68b8d3dcdf320f 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -21,11 +21,13 @@ private[spark] class ExecutorRunner( val memory: Int, val worker: ActorRef, val workerId: String, - val hostname: String, + val hostPort: String, val sparkHome: File, val workDir: File) extends Logging { + Utils.checkHostPort(hostPort, "Expected hostport") + val fullId = appId + "/" + execId var workerThread: Thread = null var process: Process = null @@ -68,7 +70,7 @@ private[spark] class ExecutorRunner( /** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */ def substituteVariables(argument: String): String = argument match { case "{{EXECUTOR_ID}}" => execId.toString - case "{{HOSTNAME}}" => hostname + case "{{HOSTNAME}}" => Utils.parseHostPort(hostPort)._1 case "{{CORES}}" => cores.toString case other => other } diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 8919d1261cc8631c620e54dd4b196c85412269e5..1a7da0f7bf91143d931e2d07585a2335adaf3bc7 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -16,7 +16,7 @@ import spark.deploy.master.Master import java.io.File private[spark] class Worker( - ip: String, + host: String, port: Int, webUiPort: Int, cores: Int, @@ -25,6 +25,9 @@ private[spark] class Worker( workDirPath: String = null) extends Actor with Logging { + Utils.checkHost(host, "Expected hostname") + assert (port > 0) + val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs // Send a heartbeat every (heartbeat timeout) / 4 milliseconds @@ -39,7 +42,7 @@ private[spark] class Worker( val finishedExecutors = new HashMap[String, ExecutorRunner] val publicAddress = { val envVar = System.getenv("SPARK_PUBLIC_DNS") - if (envVar != null) envVar else ip + if (envVar != null) envVar else host } var coresUsed = 0 @@ -51,10 +54,11 @@ private[spark] class Worker( def createWorkDir() { workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work")) try { - if (!workDir.exists() && !workDir.mkdirs()) { + if ( (workDir.exists() && !workDir.isDirectory) || (!workDir.exists() && !workDir.mkdirs()) ) { logError("Failed to create work directory " + workDir) System.exit(1) } + assert (workDir.isDirectory) } catch { case e: Exception => logError("Failed to create work directory " + workDir, e) @@ -64,7 +68,7 @@ private[spark] class Worker( override def preStart() { logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( - ip, port, cores, Utils.memoryMegabytesToString(memory))) + host, port, cores, Utils.memoryMegabytesToString(memory))) sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse(".")) logInfo("Spark home: " + sparkHome) createWorkDir() @@ -75,7 +79,7 @@ private[spark] class Worker( def connectToMaster() { logInfo("Connecting to master " + masterUrl) master = context.actorFor(Master.toAkkaUrl(masterUrl)) - master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress) + master ! 
RegisterWorker(workerId, host, port, cores, memory, webUiPort, publicAddress) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) context.watch(master) // Doesn't work with remote actors, but useful for testing } @@ -106,7 +110,7 @@ private[spark] class Worker( case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) => logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) val manager = new ExecutorRunner( - appId, execId, appDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir) + appId, execId, appDesc, cores_, memory_, self, workerId, host + ":" + port, new File(execSparkHome_), workDir) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ @@ -141,7 +145,7 @@ private[spark] class Worker( masterDisconnected() case RequestWorkerState => { - sender ! WorkerState(ip, port, workerId, executors.values.toList, + sender ! WorkerState(host, port, workerId, executors.values.toList, finishedExecutors.values.toList, masterUrl, cores, memory, coresUsed, memoryUsed, masterWebUiUrl) } @@ -156,7 +160,7 @@ private[spark] class Worker( } def generateWorkerId(): String = { - "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), ip, port) + "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), host, port) } override def postStop() { @@ -167,7 +171,7 @@ private[spark] class Worker( private[spark] object Worker { def main(argStrings: Array[String]) { val args = new WorkerArguments(argStrings) - val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort, args.cores, + val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, args.memory, args.master, args.workDir) actorSystem.awaitTermination() } diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala index 08f02bad80d7f47b2f019745216538eda2640223..2b96611ee3372426edb58e1a0bb34a68be352bb1 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala @@ -9,7 +9,7 @@ import java.lang.management.ManagementFactory * Command-line parser for the master. */ private[spark] class WorkerArguments(args: Array[String]) { - var ip = Utils.localHostName() + var host = Utils.localHostName() var port = 0 var webUiPort = 8081 var cores = inferDefaultCores() @@ -38,7 +38,13 @@ private[spark] class WorkerArguments(args: Array[String]) { def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => - ip = value + Utils.checkHost(value, "ip no longer supported, please use hostname " + value) + host = value + parse(tail) + + case ("--host" | "-h") :: value :: tail => + Utils.checkHost(value, "Please use hostname " + value) + host = value parse(tail) case ("--port" | "-p") :: IntParam(value) :: tail => @@ -93,7 +99,8 @@ private[spark] class WorkerArguments(args: Array[String]) { " -c CORES, --cores CORES Number of cores to use\n" + " -m MEM, --memory MEM Amount of memory to use (e.g. 
1000M, 2G)\n" + " -d DIR, --work-dir DIR Directory to run apps in (default: SPARK_HOME/work)\n" + - " -i IP, --ip IP IP address or DNS name to listen on\n" + + " -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h)\n" + + " -h HOST, --host HOST Hostname to listen on\n" + " -p PORT, --port PORT Port to listen on (default: random)\n" + " --webui-port PORT Port for web UI (default: 8081)") System.exit(exitCode) diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala index c834f87d5011183fae7164f96016fbba2755930c..3235c50d1bd31ae87a972a084a03f94bfec829f4 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala @@ -3,7 +3,7 @@ package spark.deploy.worker import akka.actor.{ActorRef, ActorSystem} import akka.dispatch.Await import akka.pattern.ask -import akka.util.Timeout +import akka.util.{Duration, Timeout} import akka.util.duration._ import cc.spray.Directives import cc.spray.typeconversion.TwirlSupport._ @@ -22,7 +22,7 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef, workDir: File) val RESOURCE_DIR = "spark/deploy/worker/webui" val STATIC_RESOURCE_DIR = "spark/deploy/static" - implicit val timeout = Timeout(10 seconds) + implicit val timeout = Timeout(Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")) val handler = { get { diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 3e7407b58d8e6dacc52e405f1407c5b713087667..344face5e698689c301dade2b026ed8f0186aa5f 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -17,7 +17,7 @@ import java.nio.ByteBuffer * The Mesos executor for Spark. */ private[spark] class Executor(executorId: String, slaveHostname: String, properties: Seq[(String, String)]) extends Logging { - + // Application dependencies (added through SparkContext) that we've fetched so far on this node. // Each map holds the master's timestamp for the version of that file or JAR we got. private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() @@ -27,6 +27,11 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert initLogging() + // No ip or host:port - just hostname + Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname") + // must not have port specified. 
+ assert (0 == Utils.parseHostPort(slaveHostname)._2) + // Make sure the local hostname we report matches the cluster scheduler's name for this host Utils.setCustomHostname(slaveHostname) diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index 1047f71c6ae0dc13eaf873049fbfeceaa4d1ff69..ebe2ac68d8b2527656003aa91229365507d3352a 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -12,23 +12,27 @@ import spark.scheduler.cluster.RegisteredExecutor import spark.scheduler.cluster.LaunchTask import spark.scheduler.cluster.RegisterExecutorFailed import spark.scheduler.cluster.RegisterExecutor +import spark.Utils +import spark.deploy.SparkHadoopUtil private[spark] class StandaloneExecutorBackend( driverUrl: String, executorId: String, - hostname: String, + hostPort: String, cores: Int) extends Actor with ExecutorBackend with Logging { + Utils.checkHostPort(hostPort, "Expected hostport") + var executor: Executor = null var driver: ActorRef = null override def preStart() { logInfo("Connecting to driver: " + driverUrl) driver = context.actorFor(driverUrl) - driver ! RegisterExecutor(executorId, hostname, cores) + driver ! RegisterExecutor(executorId, hostPort, cores) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) context.watch(driver) // Doesn't work with remote actors, but useful for testing } @@ -36,7 +40,8 @@ private[spark] class StandaloneExecutorBackend( override def receive = { case RegisteredExecutor(sparkProperties) => logInfo("Successfully registered with driver") - executor = new Executor(executorId, hostname, sparkProperties) + // Make this host instead of hostPort ? 
+ executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties) case RegisterExecutorFailed(message) => logError("Slave registration failed: " + message) @@ -63,11 +68,30 @@ private[spark] class StandaloneExecutorBackend( private[spark] object StandaloneExecutorBackend { def run(driverUrl: String, executorId: String, hostname: String, cores: Int) { + SparkHadoopUtil.runAsUser(run0, Tuple4[Any, Any, Any, Any] (driverUrl, executorId, hostname, cores)) + } + + // This will be run 'as' the user + def run0(args: Product) { + assert(4 == args.productArity) + runImpl(args.productElement(0).asInstanceOf[String], + args.productElement(1).asInstanceOf[String], + args.productElement(2).asInstanceOf[String], + args.productElement(3).asInstanceOf[Int]) + } + + private def runImpl(driverUrl: String, executorId: String, hostname: String, cores: Int) { + // Debug code + Utils.checkHost(hostname) + // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // before getting started with all our system properties, etc val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0) + // set it + val sparkHostPort = hostname + ":" + boundPort + System.setProperty("spark.hostPort", sparkHostPort) val actor = actorSystem.actorOf( - Props(new StandaloneExecutorBackend(driverUrl, executorId, hostname, cores)), + Props(new StandaloneExecutorBackend(driverUrl, executorId, sparkHostPort, cores)), name = "Executor") actorSystem.awaitTermination() } diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index d1451bc2124c581eff01dfae5277612ea5c995c7..00a0433a441dc22ade387ca6c20bc54e2114740e 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -13,7 +13,7 @@ import java.net._ private[spark] abstract class Connection(val channel: SocketChannel, val selector: Selector, - val remoteConnectionManagerId: ConnectionManagerId) extends Logging { + val socketRemoteConnectionManagerId: ConnectionManagerId) extends Logging { def this(channel_ : SocketChannel, selector_ : Selector) = { this(channel_, selector_, ConnectionManagerId.fromSocketAddress( @@ -32,16 +32,43 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, var onKeyInterestChangeCallback: (Connection, Int) => Unit = null val remoteAddress = getRemoteAddress() + + // Read channels typically do not register for write and write does not for read + // Now, we do have write registering for read too (temporarily), but this is to detect + // channel close NOT to actually read/consume data on it ! + // How does this work if/when we move to SSL ? + + // What is the interest to register with selector for when we want this connection to be selected + def registerInterest() + // What is the interest to register with selector for when we want this connection to be de-selected + // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack, it will be + // SelectionKey.OP_READ (until we fix it properly) + def unregisterInterest() + + // On receiving a read event, should we change the interest for this channel or not ? + // Will be true for ReceivingConnection, false for SendingConnection. + def changeInterestForRead(): Boolean + + // On receiving a write event, should we change the interest for this channel or not ? + // Will be false for ReceivingConnection, true for SendingConnection. 
+ // Actually, for now, should not get triggered for ReceivingConnection + def changeInterestForWrite(): Boolean + + def getRemoteConnectionManagerId(): ConnectionManagerId = { + socketRemoteConnectionManagerId + } def key() = channel.keyFor(selector) def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] - def read() { + // Returns whether we have to register for further reads or not. + def read(): Boolean = { throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString) } - - def write() { + + // Returns whether we have to register for further writes or not. + def write(): Boolean = { throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString) } @@ -64,7 +91,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, if (onExceptionCallback != null) { onExceptionCallback(this, e) } else { - logError("Error in connection to " + remoteConnectionManagerId + + logError("Error in connection to " + getRemoteConnectionManagerId() + " and OnExceptionCallback not registered", e) } } @@ -73,7 +100,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, if (onCloseCallback != null) { onCloseCallback(this) } else { - logWarning("Connection to " + remoteConnectionManagerId + + logWarning("Connection to " + getRemoteConnectionManagerId() + " closed and OnExceptionCallback not registered") } @@ -81,7 +108,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, def changeConnectionKeyInterest(ops: Int) { if (onKeyInterestChangeCallback != null) { - onKeyInterestChangeCallback(this, ops) + onKeyInterestChangeCallback(this, ops) } else { throw new Exception("OnKeyInterestChangeCallback not registered") } @@ -122,7 +149,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { messages.synchronized{ /*messages += message*/ messages.enqueue(message) - logDebug("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]") + logDebug("Added [" + message + "] to outbox for sending to [" + getRemoteConnectionManagerId() + "]") } } @@ -149,9 +176,9 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } return chunk } else { - /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/ + /*logInfo("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "]")*/ message.finishTime = System.currentTimeMillis - logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + + logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "] in " + message.timeTaken ) } } @@ -170,15 +197,15 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { messages.enqueue(message) nextMessageToBeUsed = nextMessageToBeUsed + 1 if (!message.started) { - logDebug("Starting to send [" + message + "] to [" + remoteConnectionManagerId + "]") + logDebug("Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]") message.started = true message.startTime = System.currentTimeMillis } - logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]") + logTrace("Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]") return chunk } else { message.finishTime = System.currentTimeMillis - logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + + logDebug("Finished sending [" + message + "] to [" + 
getRemoteConnectionManagerId() + "] in " + message.timeTaken ) } } @@ -187,26 +214,39 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } } - val outbox = new Outbox(1) + private val outbox = new Outbox(1) val currentBuffers = new ArrayBuffer[ByteBuffer]() /*channel.socket.setSendBufferSize(256 * 1024)*/ - override def getRemoteAddress() = address + override def getRemoteAddress() = address + val DEFAULT_INTEREST = SelectionKey.OP_READ + + override def registerInterest() { + // Registering read too - does not really help in most cases, but for some + // it does - so let us keep it for now. + changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST) + } + + override def unregisterInterest() { + changeConnectionKeyInterest(DEFAULT_INTEREST) + } + def send(message: Message) { outbox.synchronized { outbox.addMessage(message) if (channel.isConnected) { - changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ) + registerInterest() } } } + // MUST be called within the selector loop def connect() { try{ - channel.connect(address) channel.register(selector, SelectionKey.OP_CONNECT) + channel.connect(address) logInfo("Initiating connection to [" + address + "]") } catch { case e: Exception => { @@ -216,20 +256,33 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } } - def finishConnect() { + def finishConnect(force: Boolean): Boolean = { try { - channel.finishConnect - changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ) + // Typically, this should finish immediately since it was triggered by a connect + // selection - though need not necessarily always complete successfully. + val connected = channel.finishConnect + if (!force && !connected) { + logInfo("finish connect failed [" + address + "], " + outbox.messages.size + " messages pending") + return false + } + + // Fallback to previous behavior - assume finishConnect completed + // This will happen only when finishConnect failed for some repeated number of times (10 or so) + // Is highly unlikely unless there was an unclean close of socket, etc + registerInterest() logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending") + return true } catch { case e: Exception => { logWarning("Error finishing connection to " + address, e) callOnExceptionCallback(e) + // ignore + return true } } } - override def write() { + override def write(): Boolean = { try{ while(true) { if (currentBuffers.size == 0) { @@ -239,8 +292,9 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { currentBuffers ++= chunk.buffers } case None => { - changeConnectionKeyInterest(SelectionKey.OP_READ) - return + // changeConnectionKeyInterest(0) + /*key.interestOps(0)*/ + return false } } } @@ -254,38 +308,53 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { currentBuffers -= buffer } if (writtenBytes < remainingBytes) { - return + // re-register for write. + return true } } } } catch { case e: Exception => { - logWarning("Error writing in connection to " + remoteConnectionManagerId, e) + logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e) callOnExceptionCallback(e) close() + return false } } + // should not happen - to keep scala compiler happy + return true } - override def read() { + // This is a hack to determine if remote socket was closed or not. 
+ // SendingConnection DOES NOT expect to receive any data - if it does, it is an error + // For a bunch of cases, read will return -1 in case remote socket is closed : hence we + // register for reads to determine that. + override def read(): Boolean = { // We don't expect the other side to send anything; so, we just read to detect an error or EOF. try { val length = channel.read(ByteBuffer.allocate(1)) if (length == -1) { // EOF close() } else if (length > 0) { - logWarning("Unexpected data read from SendingConnection to " + remoteConnectionManagerId) + logWarning("Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId()) } } catch { case e: Exception => - logError("Exception while reading SendingConnection to " + remoteConnectionManagerId, e) + logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), e) callOnExceptionCallback(e) close() } + + false } + + override def changeInterestForRead(): Boolean = false + + override def changeInterestForWrite(): Boolean = true } +// Must be created within selector loop - else deadlock private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) extends Connection(channel_, selector_) { @@ -298,13 +367,13 @@ extends Connection(channel_, selector_) { val newMessage = Message.create(header).asInstanceOf[BufferMessage] newMessage.started = true newMessage.startTime = System.currentTimeMillis - logDebug("Starting to receive [" + newMessage + "] from [" + remoteConnectionManagerId + "]") + logDebug("Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") messages += ((newMessage.id, newMessage)) newMessage } val message = messages.getOrElseUpdate(header.id, createNewMessage) - logTrace("Receiving chunk of [" + message + "] from [" + remoteConnectionManagerId + "]") + logTrace("Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]") message.getChunkForReceiving(header.chunkSize) } @@ -316,7 +385,27 @@ extends Connection(channel_, selector_) { messages -= message.id } } - + + @volatile private var inferredRemoteManagerId: ConnectionManagerId = null + override def getRemoteConnectionManagerId(): ConnectionManagerId = { + val currId = inferredRemoteManagerId + if (currId != null) currId else super.getRemoteConnectionManagerId() + } + + // The reciever's remote address is the local socket on remote side : which is NOT the connection manager id of the receiver. + // We infer that from the messages we receive on the receiver socket. + private def processConnectionManagerId(header: MessageChunkHeader) { + val currId = inferredRemoteManagerId + if (header.address == null || currId != null) return + + val managerId = ConnectionManagerId.fromSocketAddress(header.address) + + if (managerId != null) { + inferredRemoteManagerId = managerId + } + } + + val inbox = new Inbox() val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE) var onReceiveCallback: (Connection , Message) => Unit = null @@ -324,17 +413,18 @@ extends Connection(channel_, selector_) { channel.register(selector, SelectionKey.OP_READ) - override def read() { + override def read(): Boolean = { try { while (true) { if (currentChunk == null) { val headerBytesRead = channel.read(headerBuffer) if (headerBytesRead == -1) { close() - return + return false } if (headerBuffer.remaining > 0) { - return + // re-register for read event ... 
+ return true } headerBuffer.flip if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) { @@ -342,6 +432,9 @@ extends Connection(channel_, selector_) { } val header = MessageChunkHeader.create(headerBuffer) headerBuffer.clear() + + processConnectionManagerId(header) + header.typ match { case Message.BUFFER_MESSAGE => { if (header.totalSize == 0) { @@ -349,7 +442,8 @@ extends Connection(channel_, selector_) { onReceiveCallback(this, Message.create(header)) } currentChunk = null - return + // re-register for read event ... + return true } else { currentChunk = inbox.getChunk(header).orNull } @@ -362,10 +456,11 @@ extends Connection(channel_, selector_) { val bytesRead = channel.read(currentChunk.buffer) if (bytesRead == 0) { - return + // re-register for read event ... + return true } else if (bytesRead == -1) { close() - return + return false } /*logDebug("Read " + bytesRead + " bytes for the buffer")*/ @@ -376,7 +471,7 @@ extends Connection(channel_, selector_) { if (bufferMessage.isCompletelyReceived) { bufferMessage.flip bufferMessage.finishTime = System.currentTimeMillis - logDebug("Finished receiving [" + bufferMessage + "] from [" + remoteConnectionManagerId + "] in " + bufferMessage.timeTaken) + logDebug("Finished receiving [" + bufferMessage + "] from [" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken) if (onReceiveCallback != null) { onReceiveCallback(this, bufferMessage) } @@ -387,12 +482,31 @@ extends Connection(channel_, selector_) { } } catch { case e: Exception => { - logWarning("Error reading from connection to " + remoteConnectionManagerId, e) + logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) callOnExceptionCallback(e) close() + return false } } + // should not happen - to keep scala compiler happy + return true } def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback} + + override def changeInterestForRead(): Boolean = true + + override def changeInterestForWrite(): Boolean = { + throw new IllegalStateException("Unexpected invocation right now") + } + + override def registerInterest() { + // Registering read too - does not really help in most cases, but for some + // it does - so let us keep it for now. 
+ changeConnectionKeyInterest(SelectionKey.OP_READ) + } + + override def unregisterInterest() { + changeConnectionKeyInterest(0) + } } diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index b6ec664d7e81bb581f1be553b3b1dfaafca61865..0c6bdb155972fcd16e5921c3a9f40c68411e5e53 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -6,12 +6,12 @@ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ import java.net._ -import java.util.concurrent.Executors +import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor} +import scala.collection.mutable.HashSet import scala.collection.mutable.HashMap import scala.collection.mutable.SynchronizedMap import scala.collection.mutable.SynchronizedQueue -import scala.collection.mutable.Queue import scala.collection.mutable.ArrayBuffer import akka.dispatch.{Await, Promise, ExecutionContext, Future} @@ -19,6 +19,10 @@ import akka.util.Duration import akka.util.duration._ private[spark] case class ConnectionManagerId(host: String, port: Int) { + // DEBUG code + Utils.checkHost(host) + assert (port > 0) + def toSocketAddress() = new InetSocketAddress(host, port) } @@ -42,19 +46,37 @@ private[spark] class ConnectionManager(port: Int) extends Logging { def markDone() { completionHandler(this) } } - val selector = SelectorProvider.provider.openSelector() - val handleMessageExecutor = Executors.newFixedThreadPool(System.getProperty("spark.core.connection.handler.threads","20").toInt) - val serverChannel = ServerSocketChannel.open() - val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] - val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] - val messageStatuses = new HashMap[Int, MessageStatus] - val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] - val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] - val sendMessageRequests = new Queue[(Message, SendingConnection)] + private val selector = SelectorProvider.provider.openSelector() + + private val handleMessageExecutor = new ThreadPoolExecutor( + System.getProperty("spark.core.connection.handler.threads.min","20").toInt, + System.getProperty("spark.core.connection.handler.threads.max","60").toInt, + System.getProperty("spark.core.connection.handler.threads.keepalive","60").toInt, TimeUnit.SECONDS, + new LinkedBlockingDeque[Runnable]()) + + private val handleReadWriteExecutor = new ThreadPoolExecutor( + System.getProperty("spark.core.connection.io.threads.min","4").toInt, + System.getProperty("spark.core.connection.io.threads.max","32").toInt, + System.getProperty("spark.core.connection.io.threads.keepalive","60").toInt, TimeUnit.SECONDS, + new LinkedBlockingDeque[Runnable]()) + + // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : which should be executed asap + private val handleConnectExecutor = new ThreadPoolExecutor( + System.getProperty("spark.core.connection.connect.threads.min","1").toInt, + System.getProperty("spark.core.connection.connect.threads.max","8").toInt, + System.getProperty("spark.core.connection.connect.threads.keepalive","60").toInt, TimeUnit.SECONDS, + new LinkedBlockingDeque[Runnable]()) + + private val 
serverChannel = ServerSocketChannel.open() + private val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] + private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] + private val messageStatuses = new HashMap[Int, MessageStatus] + private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] + private val registerRequests = new SynchronizedQueue[SendingConnection] implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool()) - var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null + private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null serverChannel.configureBlocking(false) serverChannel.socket.setReuseAddress(true) @@ -65,46 +87,139 @@ private[spark] class ConnectionManager(port: Int) extends Logging { val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id) - - val selectorThread = new Thread("connection-manager-thread") { + + private val selectorThread = new Thread("connection-manager-thread") { override def run() = ConnectionManager.this.run() } selectorThread.setDaemon(true) selectorThread.start() - private def run() { - try { - while(!selectorThread.isInterrupted) { - for ((connectionManagerId, sendingConnection) <- connectionRequests) { - sendingConnection.connect() - addConnection(sendingConnection) - connectionRequests -= connectionManagerId + private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() + + private def triggerWrite(key: SelectionKey) { + val conn = connectionsByKey.getOrElse(key, null) + if (conn == null) return + + writeRunnableStarted.synchronized { + // So that we do not trigger more write events while processing this one. + // The write method will re-register when done. + if (conn.changeInterestForWrite()) conn.unregisterInterest() + if (writeRunnableStarted.contains(key)) { + // key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE) + return + } + + writeRunnableStarted += key + } + handleReadWriteExecutor.execute(new Runnable { + override def run() { + var register: Boolean = false + try { + register = conn.write() + } finally { + writeRunnableStarted.synchronized { + writeRunnableStarted -= key + if (register && conn.changeInterestForWrite()) { + conn.registerInterest() + } + } } - sendMessageRequests.synchronized { - while (!sendMessageRequests.isEmpty) { - val (message, connection) = sendMessageRequests.dequeue - connection.send(message) + } + } ) + } + + private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() + + private def triggerRead(key: SelectionKey) { + val conn = connectionsByKey.getOrElse(key, null) + if (conn == null) return + + readRunnableStarted.synchronized { + // So that we do not trigger more read events while processing this one. + // The read method will re-register when done. 
+ if (conn.changeInterestForRead())conn.unregisterInterest() + if (readRunnableStarted.contains(key)) { + return + } + + readRunnableStarted += key + } + handleReadWriteExecutor.execute(new Runnable { + override def run() { + var register: Boolean = false + try { + register = conn.read() + } finally { + readRunnableStarted.synchronized { + readRunnableStarted -= key + if (register && conn.changeInterestForRead()) { + conn.registerInterest() + } } } + } + } ) + } + + private def triggerConnect(key: SelectionKey) { + val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection] + if (conn == null) return + + // prevent other events from being triggered + // Since we are still trying to connect, we do not need to do the additional steps in triggerWrite + conn.changeConnectionKeyInterest(0) + + handleConnectExecutor.execute(new Runnable { + override def run() { + + var tries: Int = 10 + while (tries >= 0) { + if (conn.finishConnect(false)) return + // Sleep ? + Thread.sleep(1) + tries -= 1 + } - while (!keyInterestChangeRequests.isEmpty) { + // fallback to previous behavior : we should not really come here since this method was + // triggered since channel became connectable : but at times, the first finishConnect need not + // succeed : hence the loop to retry a few 'times'. + conn.finishConnect(true) + } + } ) + } + + def run() { + try { + while(!selectorThread.isInterrupted) { + while (! registerRequests.isEmpty) { + val conn: SendingConnection = registerRequests.dequeue + addListeners(conn) + conn.connect() + addConnection(conn) + } + + while(!keyInterestChangeRequests.isEmpty) { val (key, ops) = keyInterestChangeRequests.dequeue - val connection = connectionsByKey(key) - val lastOps = key.interestOps() - key.interestOps(ops) - - def intToOpStr(op: Int): String = { - val opStrs = ArrayBuffer[String]() - if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" - if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" - if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" - if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" - if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " + val connection = connectionsByKey.getOrElse(key, null) + if (connection != null) { + val lastOps = key.interestOps() + key.interestOps(ops) + + // hot loop - prevent materialization of string if trace not enabled. 
+ if (isTraceEnabled()) { + def intToOpStr(op: Int): String = { + val opStrs = ArrayBuffer[String]() + if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" + if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" + if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" + if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" + if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " + } + + logTrace("Changed key for connection to [" + connection.getRemoteConnectionManagerId() + + "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") + } } - - logTrace("Changed key for connection to [" + connection.remoteConnectionManagerId + - "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") - } val selectedKeysCount = selector.select() @@ -123,12 +238,15 @@ private[spark] class ConnectionManager(port: Int) extends Logging { if (key.isValid) { if (key.isAcceptable) { acceptConnection(key) - } else if (key.isConnectable) { - connectionsByKey(key).asInstanceOf[SendingConnection].finishConnect() - } else if (key.isReadable) { - connectionsByKey(key).read() - } else if (key.isWritable) { - connectionsByKey(key).write() + } else + if (key.isConnectable) { + triggerConnect(key) + } else + if (key.isReadable) { + triggerRead(key) + } else + if (key.isWritable) { + triggerWrite(key) } } } @@ -138,94 +256,116 @@ private[spark] class ConnectionManager(port: Int) extends Logging { } } - private def acceptConnection(key: SelectionKey) { + def acceptConnection(key: SelectionKey) { val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] - val newChannel = serverChannel.accept() - val newConnection = new ReceivingConnection(newChannel, selector) - newConnection.onReceive(receiveMessage) - newConnection.onClose(removeConnection) - addConnection(newConnection) - logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]") - } - private def addConnection(connection: Connection) { - connectionsByKey += ((connection.key, connection)) - if (connection.isInstanceOf[SendingConnection]) { - val sendingConnection = connection.asInstanceOf[SendingConnection] - connectionsById += ((sendingConnection.remoteConnectionManagerId, sendingConnection)) + var newChannel = serverChannel.accept() + + // accept them all in a tight loop. 
non blocking accept with no processing, should be fine + while (newChannel != null) { + try { + val newConnection = new ReceivingConnection(newChannel, selector) + newConnection.onReceive(receiveMessage) + addListeners(newConnection) + addConnection(newConnection) + logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]") + } catch { + // might happen in case of issues with registering with selector + case e: Exception => logError("Error in accept loop", e) + } + + newChannel = serverChannel.accept() } + } + + private def addListeners(connection: Connection) { connection.onKeyInterestChange(changeConnectionKeyInterest) connection.onException(handleConnectionError) connection.onClose(removeConnection) } - private def removeConnection(connection: Connection) { + def addConnection(connection: Connection) { + connectionsByKey += ((connection.key, connection)) + } + + def removeConnection(connection: Connection) { connectionsByKey -= connection.key - if (connection.isInstanceOf[SendingConnection]) { - val sendingConnection = connection.asInstanceOf[SendingConnection] - val sendingConnectionManagerId = sendingConnection.remoteConnectionManagerId - logInfo("Removing SendingConnection to " + sendingConnectionManagerId) - - connectionsById -= sendingConnectionManagerId - - messageStatuses.synchronized { - messageStatuses - .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => { - logInfo("Notifying " + status) - status.synchronized { - status.attempted = true - status.acked = false - status.markDone() - } + + try { + if (connection.isInstanceOf[SendingConnection]) { + val sendingConnection = connection.asInstanceOf[SendingConnection] + val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() + logInfo("Removing SendingConnection to " + sendingConnectionManagerId) + + connectionsById -= sendingConnectionManagerId + + messageStatuses.synchronized { + messageStatuses + .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => { + logInfo("Notifying " + status) + status.synchronized { + status.attempted = true + status.acked = false + status.markDone() + } + }) + + messageStatuses.retain((i, status) => { + status.connectionManagerId != sendingConnectionManagerId }) + } + } else if (connection.isInstanceOf[ReceivingConnection]) { + val receivingConnection = connection.asInstanceOf[ReceivingConnection] + val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId() + logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId) + + val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId) + if (! 
sendingConnectionOpt.isDefined) { + logError("Corresponding SendingConnectionManagerId not found") + return + } - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } - } else if (connection.isInstanceOf[ReceivingConnection]) { - val receivingConnection = connection.asInstanceOf[ReceivingConnection] - val remoteConnectionManagerId = receivingConnection.remoteConnectionManagerId - logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId) - - val sendingConnectionManagerId = connectionsById.keys.find(_.host == remoteConnectionManagerId.host).orNull - if (sendingConnectionManagerId == null) { - logError("Corresponding SendingConnectionManagerId not found") - return - } - logInfo("Corresponding SendingConnectionManagerId is " + sendingConnectionManagerId) - - val sendingConnection = connectionsById(sendingConnectionManagerId) - sendingConnection.close() - connectionsById -= sendingConnectionManagerId - - messageStatuses.synchronized { - for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { - logInfo("Notifying " + s) - s.synchronized { - s.attempted = true - s.acked = false - s.markDone() + val sendingConnection = sendingConnectionOpt.get + connectionsById -= remoteConnectionManagerId + sendingConnection.close() + + val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId() + + assert (sendingConnectionManagerId == remoteConnectionManagerId) + + messageStatuses.synchronized { + for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { + logInfo("Notifying " + s) + s.synchronized { + s.attempted = true + s.acked = false + s.markDone() + } } - } - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) + messageStatuses.retain((i, status) => { + status.connectionManagerId != sendingConnectionManagerId + }) + } } + } finally { + // So that the selection keys can be removed. + wakeupSelector() } } - private def handleConnectionError(connection: Connection, e: Exception) { - logInfo("Handling connection error on connection to " + connection.remoteConnectionManagerId) + def handleConnectionError(connection: Connection, e: Exception) { + logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId()) removeConnection(connection) } - private def changeConnectionKeyInterest(connection: Connection, ops: Int) { - keyInterestChangeRequests += ((connection.key, ops)) + def changeConnectionKeyInterest(connection: Connection, ops: Int) { + keyInterestChangeRequests += ((connection.key, ops)) + // so that registerations happen ! 
+ wakeupSelector() } - private def receiveMessage(connection: Connection, message: Message) { + def receiveMessage(connection: Connection, message: Message) { val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress) logDebug("Received [" + message + "] from [" + connectionManagerId + "]") val runnable = new Runnable() { @@ -293,18 +433,22 @@ private[spark] class ConnectionManager(port: Int) extends Logging { private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { def startNewConnection(): SendingConnection = { val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) - val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, - new SendingConnection(inetSocketAddress, selector, connectionManagerId)) - newConnection + val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId) + registerRequests.enqueue(newConnection) + + newConnection } - val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress) - val connection = connectionsById.getOrElse(lookupKey, startNewConnection()) + // I removed the lookupKey stuff as part of merge ... should I re-add it ? We did not find it useful in our test-env ... + // If we do re-add it, we should consistently use it everywhere I guess ? + val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection()) message.senderAddress = id.toSocketAddress() logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") - /*connection.send(message)*/ - sendMessageRequests.synchronized { - sendMessageRequests += ((message, connection)) - } + connection.send(message) + + wakeupSelector() + } + + private def wakeupSelector() { selector.wakeup() } @@ -337,6 +481,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging { logWarning("All connections not cleaned up") } handleMessageExecutor.shutdown() + handleReadWriteExecutor.shutdown() + handleConnectExecutor.shutdown() logInfo("ConnectionManager stopped") } } diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala index 525751b5bf915e814aaa3960a5404d542605b434..34fac9e77699b003430420efb7170e81ec511b25 100644 --- a/core/src/main/scala/spark/network/Message.scala +++ b/core/src/main/scala/spark/network/Message.scala @@ -17,7 +17,8 @@ private[spark] class MessageChunkHeader( val other: Int, val address: InetSocketAddress) { lazy val buffer = { - val ip = address.getAddress.getAddress() + // No need to change this, at 'use' time, we do a reverse lookup of the hostname. Refer to network.Connection + val ip = address.getAddress.getAddress() val port = address.getPort() ByteBuffer. allocate(MessageChunkHeader.HEADER_SIZE). 
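The ConnectionManager hunks above replace the single fixed-size handler pool with three ThreadPoolExecutors whose sizes come from system properties (spark.core.connection.handler.threads.*, spark.core.connection.io.threads.* and spark.core.connection.connect.threads.*), and dispatch read, write and connect selector events onto them. Below is a minimal, standalone sketch of that construction pattern, assuming only the property names shown in the diff; the object and helper names are illustrative and not part of the patch.

import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}

object ConnectionPoolsSketch {
  // Build a pool whose core size, max size and keep-alive come from system properties,
  // mirroring the three pools added to ConnectionManager in the hunk above.
  // The prefix and defaults are taken from the diff; this helper is not part of the patch.
  private def pool(prefix: String, minDefault: Int, maxDefault: Int): ThreadPoolExecutor =
    new ThreadPoolExecutor(
      System.getProperty(prefix + ".min", minDefault.toString).toInt,
      System.getProperty(prefix + ".max", maxDefault.toString).toInt,
      System.getProperty(prefix + ".keepalive", "60").toLong, TimeUnit.SECONDS,
      new LinkedBlockingDeque[Runnable]())

  def main(args: Array[String]) {
    // Example: tune the read/write pool before the ConnectionManager is constructed.
    System.setProperty("spark.core.connection.io.threads.max", "16")
    val ioPool = pool("spark.core.connection.io.threads", 4, 32)
    ioPool.execute(new Runnable {
      override def run() { println("I/O task handled on " + Thread.currentThread.getName) }
    })
    ioPool.shutdown()
  }
}

One detail worth noting about this pattern: because the work queue is an unbounded LinkedBlockingDeque, a ThreadPoolExecutor never grows beyond its core size, so the ".max" properties only take effect if a bounded queue were used instead.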
diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala index 5db77eb142453b70129117fee9ba1ce671a9f19c..43ee39c993a3c0b254fe35bbbf9bcd8fb167e10c 100644 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -8,6 +8,7 @@ import org.apache.hadoop.util.ReflectionUtils import org.apache.hadoop.fs.Path import java.io.{File, IOException, EOFException} import java.text.NumberFormat +import spark.deploy.SparkHadoopUtil private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} @@ -21,13 +22,20 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: Stri @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) override def getPartitions: Array[Partition] = { - val dirContents = fs.listStatus(new Path(checkpointPath)) - val partitionFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted - val numPartitions = partitionFiles.size - if (numPartitions > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || - ! partitionFiles(numPartitions-1).endsWith(CheckpointRDD.splitIdToFile(numPartitions-1)))) { - throw new SparkException("Invalid checkpoint directory: " + checkpointPath) - } + val cpath = new Path(checkpointPath) + val numPartitions = + // listStatus can throw exception if path does not exist. + if (fs.exists(cpath)) { + val dirContents = fs.listStatus(cpath) + val partitionFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted + val numPart = partitionFiles.size + if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || + ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) { + throw new SparkException("Invalid checkpoint directory: " + checkpointPath) + } + numPart + } else 0 + Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i)) } @@ -58,7 +66,7 @@ private[spark] object CheckpointRDD extends Logging { def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) { val outputDir = new Path(path) - val fs = outputDir.getFileSystem(new Configuration()) + val fs = outputDir.getFileSystem(SparkHadoopUtil.newConfiguration()) val finalOutputName = splitIdToFile(ctx.splitId) val finalOutputPath = new Path(outputDir, finalOutputName) @@ -83,6 +91,7 @@ private[spark] object CheckpointRDD extends Logging { if (!fs.rename(tempOutputPath, finalOutputPath)) { if (!fs.exists(finalOutputPath)) { + logInfo("Deleting tempOutputPath " + tempOutputPath) fs.delete(tempOutputPath, false) throw new IOException("Checkpoint failed: failed to save output of task: " + ctx.attemptId + " and final output path does not exist") @@ -95,7 +104,7 @@ private[spark] object CheckpointRDD extends Logging { } def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = { - val fs = path.getFileSystem(new Configuration()) + val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration()) val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt val fileInputStream = fs.open(path, bufferSize) val serializer = SparkEnv.get.serializer.newInstance() @@ -117,7 +126,7 @@ private[spark] object CheckpointRDD extends Logging { val sc = new SparkContext(cluster, "CheckpointRDD Test") val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000) val path = new Path(hdfsPath, "temp") - val fs = path.getFileSystem(new Configuration()) + val fs = 
path.getFileSystem(SparkHadoopUtil.newConfiguration()) sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _) val cpRDD = new CheckpointRDD[Int](sc, path.toString) assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same") diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index bdd974590af59d7dba4435ae2b94afab3a9a03eb..901d01ef3084f628839a4af8dcab2251b92cd953 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -57,7 +57,7 @@ class NewHadoopRDD[K, V]( override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[NewHadoopPartition] val conf = confBroadcast.value.value - val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0) + val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0) val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) val format = inputFormatClass.newInstance if (format.isInstanceOf[Configurable]) { diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index c54dce51d783969e09f6b924db09aefc1352e6a4..1440b93e65cf90cb8f4e2b28c93c467359ed7beb 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -50,6 +50,11 @@ class DAGScheduler( eventQueue.put(ExecutorLost(execId)) } + // Called by TaskScheduler when a host is added + override def executorGained(execId: String, hostPort: String) { + eventQueue.put(ExecutorGained(execId, hostPort)) + } + // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures. override def taskSetFailed(taskSet: TaskSet, reason: String) { eventQueue.put(TaskSetFailed(taskSet, reason)) @@ -113,7 +118,7 @@ class DAGScheduler( if (!cacheLocs.contains(rdd.id)) { val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map { - locations => locations.map(_.ip).toList + locations => locations.map(_.hostPort).toList }.toArray } cacheLocs(rdd.id) @@ -293,6 +298,9 @@ class DAGScheduler( submitStage(finalStage) } + case ExecutorGained(execId, hostPort) => + handleExecutorGained(execId, hostPort) + case ExecutorLost(execId) => handleExecutorLost(execId) @@ -630,6 +638,14 @@ class DAGScheduler( "(generation " + currentGeneration + ")") } } + + private def handleExecutorGained(execId: String, hostPort: String) { + // remove from failedGeneration(execId) ? + if (failedGeneration.contains(execId)) { + logInfo("Host gained which was in lost list earlier: " + hostPort) + failedGeneration -= execId + } + } /** * Aborts all jobs depending on a particular Stage. 
This is called in response to a task set diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala index ed0b9bf178a89c80f5709efd09d288c18fc7bbac..b46bb863f04d88b449995d15ab3678f148754489 100644 --- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala @@ -32,6 +32,10 @@ private[spark] case class CompletionEvent( taskMetrics: TaskMetrics) extends DAGSchedulerEvent +private[spark] case class ExecutorGained(execId: String, hostPort: String) extends DAGSchedulerEvent { + Utils.checkHostPort(hostPort, "Required hostport") +} + private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent diff --git a/core/src/main/scala/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/spark/scheduler/InputFormatInfo.scala new file mode 100644 index 0000000000000000000000000000000000000000..287f731787f1b184d4e2781dbd28e1f4463f5866 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/InputFormatInfo.scala @@ -0,0 +1,156 @@ +package spark.scheduler + +import spark.Logging +import scala.collection.immutable.Set +import org.apache.hadoop.mapred.{FileInputFormat, JobConf} +import org.apache.hadoop.util.ReflectionUtils +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.conf.Configuration +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.JavaConversions._ + + +/** + * Parses and holds information about inputFormat (and files) specified as a parameter. + */ +class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Class[_], + val path: String) extends Logging { + + var mapreduceInputFormat: Boolean = false + var mapredInputFormat: Boolean = false + + validate() + + override def toString(): String = { + "InputFormatInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz + ", path : " + path + } + + override def hashCode(): Int = { + var hashCode = inputFormatClazz.hashCode + hashCode = hashCode * 31 + path.hashCode + hashCode + } + + // Since we are not doing canonicalization of path, this can be wrong : like relative vs absolute path + // .. which is fine, this is best case effort to remove duplicates - right ? + override def equals(other: Any): Boolean = other match { + case that: InputFormatInfo => { + // not checking config - that should be fine, right ? + this.inputFormatClazz == that.inputFormatClazz && + this.path == that.path + } + case _ => false + } + + private def validate() { + logDebug("validate InputFormatInfo : " + inputFormatClazz + ", path " + path) + + try { + if (classOf[org.apache.hadoop.mapreduce.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) { + logDebug("inputformat is from mapreduce package") + mapreduceInputFormat = true + } + else if (classOf[org.apache.hadoop.mapred.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) { + logDebug("inputformat is from mapred package") + mapredInputFormat = true + } + else { + throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + + " is NOT a supported input format ? 
does not implement either of the supported hadoop api's") + } + } + catch { + case e: ClassNotFoundException => { + throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + " cannot be found ?", e) + } + } + } + + + // This method does not expect failures, since validate has already passed ... + private def prefLocsFromMapreduceInputFormat(): Set[SplitInfo] = { + val conf = new JobConf(configuration) + FileInputFormat.setInputPaths(conf, path) + + val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] = + ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], conf).asInstanceOf[ + org.apache.hadoop.mapreduce.InputFormat[_, _]] + val job = new Job(conf) + + val retval = new ArrayBuffer[SplitInfo]() + val list = instance.getSplits(job) + for (split <- list) { + retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, split) + } + + return retval.toSet + } + + // This method does not expect failures, since validate has already passed ... + private def prefLocsFromMapredInputFormat(): Set[SplitInfo] = { + val jobConf = new JobConf(configuration) + FileInputFormat.setInputPaths(jobConf, path) + + val instance: org.apache.hadoop.mapred.InputFormat[_, _] = + ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], jobConf).asInstanceOf[ + org.apache.hadoop.mapred.InputFormat[_, _]] + + val retval = new ArrayBuffer[SplitInfo]() + instance.getSplits(jobConf, jobConf.getNumMapTasks()).foreach( + elem => retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, elem) + ) + + return retval.toSet + } + + private def findPreferredLocations(): Set[SplitInfo] = { + logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat + + ", inputFormatClazz : " + inputFormatClazz) + if (mapreduceInputFormat) { + return prefLocsFromMapreduceInputFormat() + } + else { + assert(mapredInputFormat) + return prefLocsFromMapredInputFormat() + } + } +} + + + + +object InputFormatInfo { + /** + Computes the preferred locations based on input(s) and returned a location to block map. + Typical use of this method for allocation would follow some algo like this + (which is what we currently do in YARN branch) : + a) For each host, count number of splits hosted on that host. + b) Decrement the currently allocated containers on that host. + c) Compute rack info for each host and update rack -> count map based on (b). + d) Allocate nodes based on (c) + e) On the allocation result, ensure that we dont allocate "too many" jobs on a single node + (even if data locality on that is very high) : this is to prevent fragility of job if a single + (or small set of) hosts go down. + + go to (a) until required nodes are allocated. + + If a node 'dies', follow same procedure. + + PS: I know the wording here is weird, hopefully it makes some sense ! 
+ */ + def computePreferredLocations(formats: Seq[InputFormatInfo]): HashMap[String, HashSet[SplitInfo]] = { + + val nodeToSplit = new HashMap[String, HashSet[SplitInfo]] + for (inputSplit <- formats) { + val splits = inputSplit.findPreferredLocations() + + for (split <- splits){ + val location = split.hostLocation + val set = nodeToSplit.getOrElseUpdate(location, new HashSet[SplitInfo]) + set += split + } + } + + nodeToSplit + } +} diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala index beb21a76fe5c8247b3373e05cb09355141d49ed8..89dc6640b2fe25a642381eee5090441fc3a29f60 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -70,6 +70,14 @@ private[spark] class ResultTask[T, U]( rdd.partitions(partition) } + // data locality is on a per host basis, not hyper specific to container (host:port). Unique on set of hosts. + val preferredLocs: Seq[String] = if (locs == null) Nil else locs.map(loc => Utils.parseHostPort(loc)._1).toSet.toSeq + + { + // DEBUG code + preferredLocs.foreach (host => Utils.checkHost(host, "preferredLocs : " + preferredLocs)) + } + override def run(attemptId: Long): U = { val context = new TaskContext(stageId, partition, attemptId) metrics = Some(context.taskMetrics) @@ -80,7 +88,7 @@ private[spark] class ResultTask[T, U]( } } - override def preferredLocations: Seq[String] = locs + override def preferredLocations: Seq[String] = preferredLocs override def toString = "ResultTask(" + stageId + ", " + partition + ")" diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 36d087a4d009c8e6a4fdf85bfc5d6e6e6e3926bf..7dc6da45733916056ab2f2a2f4ede8a0862b266d 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -77,13 +77,21 @@ private[spark] class ShuffleMapTask( var rdd: RDD[_], var dep: ShuffleDependency[_,_], var partition: Int, - @transient var locs: Seq[String]) + @transient private var locs: Seq[String]) extends Task[MapStatus](stageId) with Externalizable with Logging { protected def this() = this(0, null, null, 0, null) + // data locality is on a per host basis, not hyper specific to container (host:port). Unique on set of hosts. + private val preferredLocs: Seq[String] = if (locs == null) Nil else locs.map(loc => Utils.parseHostPort(loc)._1).toSet.toSeq + + { + // DEBUG code + preferredLocs.foreach (host => Utils.checkHost(host, "preferredLocs : " + preferredLocs)) + } + var split = if (rdd == null) { null } else { @@ -154,7 +162,7 @@ private[spark] class ShuffleMapTask( } } - override def preferredLocations: Seq[String] = locs + override def preferredLocations: Seq[String] = preferredLocs override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition) } diff --git a/core/src/main/scala/spark/scheduler/SplitInfo.scala b/core/src/main/scala/spark/scheduler/SplitInfo.scala new file mode 100644 index 0000000000000000000000000000000000000000..6abfb7a1f7150122523370abc03fa6546d5e96a1 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/SplitInfo.scala @@ -0,0 +1,61 @@ +package spark.scheduler + +import collection.mutable.ArrayBuffer + +// information about a specific split instance : handles both split instances. +// So that we do not need to worry about the differences. 
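The grouping that computePreferredLocations performs above can be illustrated with a minimal, self-contained Scala sketch; Split and groupByHost below are hypothetical stand-ins for SplitInfo and the real method, not part of the patch:

  import scala.collection.mutable.{HashMap, HashSet}

  // stand-in for SplitInfo: one (host, path) pair per replica location
  case class Split(host: String, path: String)

  def groupByHost(splits: Seq[Split]): HashMap[String, HashSet[Split]] = {
    val nodeToSplit = new HashMap[String, HashSet[Split]]
    for (split <- splits) {
      // record each split under the host that holds it, creating the set lazily
      nodeToSplit.getOrElseUpdate(split.host, new HashSet[Split]) += split
    }
    nodeToSplit
  }

  // groupByHost(Seq(Split("h1", "/a"), Split("h2", "/a"), Split("h1", "/b")))
  // => "h1" -> {Split(h1,/a), Split(h1,/b)}, "h2" -> {Split(h2,/a)}

An allocator can then walk this host-to-splits map in the host/rack counting loop the comment above describes.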
+class SplitInfo(val inputFormatClazz: Class[_], val hostLocation: String, val path: String, + val length: Long, val underlyingSplit: Any) { + override def toString(): String = { + "SplitInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz + + ", hostLocation : " + hostLocation + ", path : " + path + + ", length : " + length + ", underlyingSplit " + underlyingSplit + } + + override def hashCode(): Int = { + var hashCode = inputFormatClazz.hashCode + hashCode = hashCode * 31 + hostLocation.hashCode + hashCode = hashCode * 31 + path.hashCode + // ignore overflow ? It is hashcode anyway ! + hashCode = hashCode * 31 + (length & 0x7fffffff).toInt + hashCode + } + + // This is practically useless since most of the Split impl's dont seem to implement equals :-( + // So unless there is identity equality between underlyingSplits, it will always fail even if it + // is pointing to same block. + override def equals(other: Any): Boolean = other match { + case that: SplitInfo => { + this.hostLocation == that.hostLocation && + this.inputFormatClazz == that.inputFormatClazz && + this.path == that.path && + this.length == that.length && + // other split specific checks (like start for FileSplit) + this.underlyingSplit == that.underlyingSplit + } + case _ => false + } +} + +object SplitInfo { + + def toSplitInfo(inputFormatClazz: Class[_], path: String, + mapredSplit: org.apache.hadoop.mapred.InputSplit): Seq[SplitInfo] = { + val retval = new ArrayBuffer[SplitInfo]() + val length = mapredSplit.getLength + for (host <- mapredSplit.getLocations) { + retval += new SplitInfo(inputFormatClazz, host, path, length, mapredSplit) + } + retval + } + + def toSplitInfo(inputFormatClazz: Class[_], path: String, + mapreduceSplit: org.apache.hadoop.mapreduce.InputSplit): Seq[SplitInfo] = { + val retval = new ArrayBuffer[SplitInfo]() + val length = mapreduceSplit.getLength + for (host <- mapreduceSplit.getLocations) { + retval += new SplitInfo(inputFormatClazz, host, path, length, mapreduceSplit) + } + retval + } +} diff --git a/core/src/main/scala/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/spark/scheduler/TaskScheduler.scala index d549b184b088d65472302b8523d5d609efe36a6b..7787b547620f6bea6cbb8f1d7a626638b7b6edb8 100644 --- a/core/src/main/scala/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/spark/scheduler/TaskScheduler.scala @@ -10,6 +10,10 @@ package spark.scheduler private[spark] trait TaskScheduler { def start(): Unit + // Invoked after system has successfully initialized (typically in spark context). + // Yarn uses this to bootstrap allocation of resources based on preferred locations, wait for slave registerations, etc. + def postStartHook() { } + // Disconnect from the cluster. def stop(): Unit diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala index 771518dddfacaaf2916a1f6cd834983725cbc533..b75d3736cf5d0b064babcc100e3c63ac0f31c5c4 100644 --- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala +++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala @@ -14,6 +14,9 @@ private[spark] trait TaskSchedulerListener { def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any], taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit + // A node was added to the cluster. + def executorGained(execId: String, hostPort: String): Unit + // A node was lost from the cluster. 
 def executorLost(execId: String): Unit
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
index 26fdef101bb278ecd363b1c281995d0b8f3256fc..a9d9c5e44c72b6526e8223707e3acf2f41d88cdc 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -1,6 +1,6 @@
 package spark.scheduler.cluster

-import java.io.{File, FileInputStream, FileOutputStream}
+import java.lang.{Boolean => JBoolean}

 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
@@ -25,6 +25,35 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
   val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
   // Threshold above which we warn user initial TaskSet may be starved
   val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong
+  // How often to revive offers in case there are pending tasks, i.e. how often to try to get tasks
+  // scheduled when nodes are available. The default of 0 disables it, preserving existing behavior.
+  // Note that this is required because of delay scheduling due to data locality waits, etc.
+  // TODO: rename property ?
+  val TASK_REVIVAL_INTERVAL = System.getProperty("spark.tasks.revive.interval", "0").toLong
+
+  /*
+   This property controls how aggressively we modulate waiting for host-local task scheduling.
+   To elaborate, there is currently a time limit (3 sec by default) to ensure that Spark waits for host locality of tasks
+   before scheduling on other nodes. We have modified this in the YARN branch such that offers to a task set happen in
+   prioritized order: host-local, rack-local and then others.
+   But once all available host-local (and no-pref) tasks are scheduled, instead of waiting for 3 sec before scheduling
+   to other nodes (which degrades performance for time-sensitive tasks and on larger clusters), we can modulate that to
+   also allow rack-local nodes or any node. The default is still HOST_LOCAL, so previous behavior is maintained. This
+   allows tuning the tension between pulling RDD data off node and scheduling computation asap.
+
+   TODO: rename property ? The value is one of
+   - HOST_LOCAL (default, no change w.r.t. current behavior),
+   - RACK_LOCAL and
+   - ANY
+
+   Note that this property makes more sense when used in conjunction with spark.tasks.revive.interval > 0, else it is
+   not very effective.
+
+   Additional note: for non-trivial clusters, there is a 4x - 5x reduction in running time (in some of our experiments)
+   based on whether it is left at the default HOST_LOCAL, or set to RACK_LOCAL (if the cluster is rack aware) or ANY.
+   If the cluster is rack aware, then setting it to RACK_LOCAL gives the best tradeoff and a 3x - 4x performance
+   improvement while minimizing IO impact. Also, it brings down the variance in running time drastically.
+ */ + val TASK_SCHEDULING_AGGRESSION = TaskLocality.parse(System.getProperty("spark.tasks.schedule.aggression", "HOST_LOCAL")) val activeTaskSets = new HashMap[String, TaskSetManager] var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] @@ -33,9 +62,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val taskIdToExecutorId = new HashMap[Long, String] val taskSetTaskIds = new HashMap[String, HashSet[Long]] - var hasReceivedTask = false - var hasLaunchedTask = false - val starvationTimer = new Timer(true) + @volatile private var hasReceivedTask = false + @volatile private var hasLaunchedTask = false + private val starvationTimer = new Timer(true) // Incrementing Mesos task IDs val nextTaskId = new AtomicLong(0) @@ -43,11 +72,16 @@ private[spark] class ClusterScheduler(val sc: SparkContext) // Which executor IDs we have executors on val activeExecutorIds = new HashSet[String] + // TODO: We might want to remove this and merge it with execId datastructures - but later. + // Which hosts in the cluster are alive (contains hostPort's) - used for hyper local and local task locality. + private val hostPortsAlive = new HashSet[String] + private val hostToAliveHostPorts = new HashMap[String, HashSet[String]] + // The set of executors we have on each host; this is used to compute hostsAlive, which // in turn is used to decide when we can attain data locality on a given host - val executorsByHost = new HashMap[String, HashSet[String]] + val executorsByHostPort = new HashMap[String, HashSet[String]] - val executorIdToHost = new HashMap[String, String] + val executorIdToHostPort = new HashMap[String, String] // JAR server, if any JARs were added by the user to the SparkContext var jarServer: HttpServer = null @@ -75,11 +109,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext) override def start() { backend.start() - if (System.getProperty("spark.speculation", "false") == "true") { + if (JBoolean.getBoolean("spark.speculation")) { new Thread("ClusterScheduler speculation check") { setDaemon(true) override def run() { + logInfo("Starting speculative execution thread") while (true) { try { Thread.sleep(SPECULATION_INTERVAL) @@ -91,6 +126,27 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } }.start() } + + + // Change to always run with some default if TASK_REVIVAL_INTERVAL <= 0 ? + if (TASK_REVIVAL_INTERVAL > 0) { + new Thread("ClusterScheduler task offer revival check") { + setDaemon(true) + + override def run() { + logInfo("Starting speculative task offer revival thread") + while (true) { + try { + Thread.sleep(TASK_REVIVAL_INTERVAL) + } catch { + case e: InterruptedException => {} + } + + if (hasPendingTasks()) backend.reviveOffers() + } + } + }.start() + } } override def submitTasks(taskSet: TaskSet) { @@ -139,22 +195,92 @@ private[spark] class ClusterScheduler(val sc: SparkContext) SparkEnv.set(sc.env) // Mark each slave as alive and remember its hostname for (o <- offers) { - executorIdToHost(o.executorId) = o.hostname - if (!executorsByHost.contains(o.hostname)) { - executorsByHost(o.hostname) = new HashSet() + // DEBUG Code + Utils.checkHostPort(o.hostPort) + + executorIdToHostPort(o.executorId) = o.hostPort + if (! 
executorsByHostPort.contains(o.hostPort)) { + executorsByHostPort(o.hostPort) = new HashSet[String]() } + + hostPortsAlive += o.hostPort + hostToAliveHostPorts.getOrElseUpdate(Utils.parseHostPort(o.hostPort)._1, new HashSet[String]).add(o.hostPort) + executorGained(o.executorId, o.hostPort) } // Build a list of tasks to assign to each slave val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) val availableCpus = offers.map(o => o.cores).toArray var launchedTask = false + + for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) { + + // Split offers based on host local, rack local and off-rack tasks. + val hostLocalOffers = new HashMap[String, ArrayBuffer[Int]]() + val rackLocalOffers = new HashMap[String, ArrayBuffer[Int]]() + val otherOffers = new HashMap[String, ArrayBuffer[Int]]() + + for (i <- 0 until offers.size) { + val hostPort = offers(i).hostPort + // DEBUG code + Utils.checkHostPort(hostPort) + val host = Utils.parseHostPort(hostPort)._1 + val numHostLocalTasks = math.max(0, math.min(manager.numPendingTasksForHost(hostPort), availableCpus(i))) + if (numHostLocalTasks > 0){ + val list = hostLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int]) + for (j <- 0 until numHostLocalTasks) list += i + } + + val numRackLocalTasks = math.max(0, + // Remove host local tasks (which are also rack local btw !) from this + math.min(manager.numRackLocalPendingTasksForHost(hostPort) - numHostLocalTasks, availableCpus(i))) + if (numRackLocalTasks > 0){ + val list = rackLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int]) + for (j <- 0 until numRackLocalTasks) list += i + } + if (numHostLocalTasks <= 0 && numRackLocalTasks <= 0){ + // add to others list - spread even this across cluster. + val list = otherOffers.getOrElseUpdate(host, new ArrayBuffer[Int]) + list += i + } + } + + val offersPriorityList = new ArrayBuffer[Int]( + hostLocalOffers.size + rackLocalOffers.size + otherOffers.size) + // First host local, then rack, then others + val numHostLocalOffers = { + val hostLocalPriorityList = ClusterScheduler.prioritizeContainers(hostLocalOffers) + offersPriorityList ++= hostLocalPriorityList + hostLocalPriorityList.size + } + val numRackLocalOffers = { + val rackLocalPriorityList = ClusterScheduler.prioritizeContainers(rackLocalOffers) + offersPriorityList ++= rackLocalPriorityList + rackLocalPriorityList.size + } + offersPriorityList ++= ClusterScheduler.prioritizeContainers(otherOffers) + + var lastLoop = false + val lastLoopIndex = TASK_SCHEDULING_AGGRESSION match { + case TaskLocality.HOST_LOCAL => numHostLocalOffers + case TaskLocality.RACK_LOCAL => numRackLocalOffers + numHostLocalOffers + case TaskLocality.ANY => offersPriorityList.size + } + do { launchedTask = false - for (i <- 0 until offers.size) { + var loopCount = 0 + for (i <- offersPriorityList) { val execId = offers(i).executorId - val host = offers(i).hostname - manager.slaveOffer(execId, host, availableCpus(i)) match { + val hostPort = offers(i).hostPort + + // If last loop and within the lastLoopIndex, expand scope - else use null (which will use default/existing) + val overrideLocality = if (lastLoop && loopCount < lastLoopIndex) TASK_SCHEDULING_AGGRESSION else null + + // If last loop, override waiting for host locality - we scheduled all local tasks already and there might be more available ... 
+ loopCount += 1 + + manager.slaveOffer(execId, hostPort, availableCpus(i), overrideLocality) match { case Some(task) => tasks(i) += task val tid = task.taskId @@ -162,15 +288,31 @@ private[spark] class ClusterScheduler(val sc: SparkContext) taskSetTaskIds(manager.taskSet.id) += tid taskIdToExecutorId(tid) = execId activeExecutorIds += execId - executorsByHost(host) += execId + executorsByHostPort(hostPort) += execId availableCpus(i) -= 1 launchedTask = true - + case None => {} } } + // Loop once more - when lastLoop = true, then we try to schedule task on all nodes irrespective of + // data locality (we still go in order of priority : but that would not change anything since + // if data local tasks had been available, we would have scheduled them already) + if (lastLoop) { + // prevent more looping + launchedTask = false + } else if (!lastLoop && !launchedTask) { + // Do this only if TASK_SCHEDULING_AGGRESSION != HOST_LOCAL + if (TASK_SCHEDULING_AGGRESSION != TaskLocality.HOST_LOCAL) { + // fudge launchedTask to ensure we loop once more + launchedTask = true + // dont loop anymore + lastLoop = true + } + } } while (launchedTask) } + if (tasks.size > 0) { hasLaunchedTask = true } @@ -256,10 +398,15 @@ private[spark] class ClusterScheduler(val sc: SparkContext) if (jarServer != null) { jarServer.stop() } + + // sleeping for an arbitrary 5 seconds : to ensure that messages are sent out. + // TODO: Do something better ! + Thread.sleep(5000L) } override def defaultParallelism() = backend.defaultParallelism() + // Check for speculatable tasks in all our active jobs. def checkSpeculatableTasks() { var shouldRevive = false @@ -273,12 +420,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } + // Check for pending tasks in all our active jobs. 
+ def hasPendingTasks(): Boolean = { + synchronized { + activeTaskSetsQueue.exists( _.hasPendingTasks() ) + } + } + def executorLost(executorId: String, reason: ExecutorLossReason) { var failedExecutor: Option[String] = None + synchronized { if (activeExecutorIds.contains(executorId)) { - val host = executorIdToHost(executorId) - logError("Lost executor %s on %s: %s".format(executorId, host, reason)) + val hostPort = executorIdToHostPort(executorId) + logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason)) removeExecutor(executorId) failedExecutor = Some(executorId) } else { @@ -296,19 +451,95 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } - /** Get a list of hosts that currently have executors */ - def hostsAlive: scala.collection.Set[String] = executorsByHost.keySet - /** Remove an executor from all our data structures and mark it as lost */ private def removeExecutor(executorId: String) { activeExecutorIds -= executorId - val host = executorIdToHost(executorId) - val execs = executorsByHost.getOrElse(host, new HashSet) + val hostPort = executorIdToHostPort(executorId) + if (hostPortsAlive.contains(hostPort)) { + // DEBUG Code + Utils.checkHostPort(hostPort) + + hostPortsAlive -= hostPort + hostToAliveHostPorts.getOrElseUpdate(Utils.parseHostPort(hostPort)._1, new HashSet[String]).remove(hostPort) + } + + val execs = executorsByHostPort.getOrElse(hostPort, new HashSet) execs -= executorId if (execs.isEmpty) { - executorsByHost -= host + executorsByHostPort -= hostPort } - executorIdToHost -= executorId - activeTaskSetsQueue.foreach(_.executorLost(executorId, host)) + executorIdToHostPort -= executorId + activeTaskSetsQueue.foreach(_.executorLost(executorId, hostPort)) + } + + def executorGained(execId: String, hostPort: String) { + listener.executorGained(execId, hostPort) + } + + def getExecutorsAliveOnHost(host: String): Option[Set[String]] = { + val retval = hostToAliveHostPorts.get(host) + if (retval.isDefined) { + return Some(retval.get.toSet) + } + + None + } + + // By default, rack is unknown + def getRackForHost(value: String): Option[String] = None + + // By default, (cached) hosts for rack is unknown + def getCachedHostsForRack(rack: String): Option[Set[String]] = None +} + + +object ClusterScheduler { + + // Used to 'spray' available containers across the available set to ensure too many containers on same host + // are not used up. Used in yarn mode and in task scheduling (when there are multiple containers available + // to execute a task) + // For example: yarn can returns more containers than we would have requested under ANY, this method + // prioritizes how to use the allocated containers. + // flatten the map such that the array buffer entries are spread out across the returned value. + // given <host, list[container]> == <h1, [c1 .. c5]>, <h2, [c1 .. c3]>, <h3, [c1, c2]>, <h4, c1>, <h5, c1>, i + // the return value would be something like : h1c1, h2c1, h3c1, h4c1, h5c1, h1c2, h2c2, h3c2, h1c3, h2c3, h1c4, h1c5 + // We then 'use' the containers in this order (consuming only the top K from this list where + // K = number to be user). This is to ensure that if we have multiple eligible allocations, + // they dont end up allocating all containers on a small number of hosts - increasing probability of + // multiple container failure when a host goes down. 
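The round-robin flattening described in this comment can be sketched standalone as follows; interleave and the sample hosts are hypothetical names, and this is only an approximation of prioritizeContainers, not the patch's implementation:

  // interleave per-host container lists, visiting hosts with more containers first
  def interleave[K, T](map: Map[K, Seq[T]]): List[T] = {
    val keys = map.keys.toList.sortBy(k => -map(k).size)
    val maxLen = if (map.isEmpty) 0 else map.values.map(_.size).max
    // take the i-th container from every host that still has one, for i = 0, 1, ...
    (0 until maxLen).flatMap(i => keys.flatMap(k => map(k).lift(i))).toList
  }

  // interleave(Map("h1" -> Seq("h1c1", "h1c2", "h1c3"), "h2" -> Seq("h2c1")))
  // => List(h1c1, h2c1, h1c2, h1c3)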
+ // Note, there is bias for keys with higher number of entries in value to be picked first (by design) + // Also note that invocation of this method is expected to have containers of same 'type' + // (host-local, rack-local, off-rack) and not across types : so that reordering is simply better from + // the available list - everything else being same. + // That is, we we first consume data local, then rack local and finally off rack nodes. So the + // prioritization from this method applies to within each category + def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = { + val _keyList = new ArrayBuffer[K](map.size) + _keyList ++= map.keys + + // order keyList based on population of value in map + val keyList = _keyList.sortWith( + (left, right) => map.get(left).getOrElse(Set()).size > map.get(right).getOrElse(Set()).size + ) + + val retval = new ArrayBuffer[T](keyList.size * 2) + var index = 0 + var found = true + + while (found){ + found = false + for (key <- keyList) { + val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null) + assert(containerList != null) + // Get the index'th entry for this host - if present + if (index < containerList.size){ + retval += containerList.apply(index) + found = true + } + } + index += 1 + } + + retval.toList } } diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index bb289c9cf391b4d1dcc6f94113117b9468c4a578..0b8922d139490a0fd73e66a48a94e12cc06972cf 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -57,9 +57,9 @@ private[spark] class SparkDeploySchedulerBackend( } } - override def executorAdded(executorId: String, workerId: String, host: String, cores: Int, memory: Int) { - logInfo("Granted executor ID %s on host %s with %d cores, %s RAM".format( - executorId, host, cores, Utils.memoryMegabytesToString(memory))) + override def executorAdded(executorId: String, workerId: String, hostPort: String, cores: Int, memory: Int) { + logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format( + executorId, hostPort, cores, Utils.memoryMegabytesToString(memory))) } override def executorRemoved(executorId: String, message: String, exitStatus: Option[Int]) { diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala index d7660678248b2d9f028d8364bd3caa8cbd47660c..333529484460e3dae9865d9f105edcca766e81a7 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala @@ -3,6 +3,7 @@ package spark.scheduler.cluster import spark.TaskState.TaskState import java.nio.ByteBuffer import spark.util.SerializableBuffer +import spark.Utils private[spark] sealed trait StandaloneClusterMessage extends Serializable @@ -19,8 +20,10 @@ case class RegisterExecutorFailed(message: String) extends StandaloneClusterMess // Executors to driver private[spark] -case class RegisterExecutor(executorId: String, host: String, cores: Int) - extends StandaloneClusterMessage +case class RegisterExecutor(executorId: String, hostPort: String, cores: Int) + extends StandaloneClusterMessage { + Utils.checkHostPort(hostPort, "Expected host port") +} private[spark] case class StatusUpdate(executorId: 
String, taskId: Long, state: TaskState, data: SerializableBuffer) diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 7a428e3361976b3a8d78b8a8c1c1a7f771399c1a..004592a54043857fe1102bcb6b4161823fd0fc14 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -5,8 +5,9 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import akka.actor._ import akka.util.duration._ import akka.pattern.ask +import akka.util.Duration -import spark.{SparkException, Logging, TaskState} +import spark.{Utils, SparkException, Logging, TaskState} import akka.dispatch.Await import java.util.concurrent.atomic.AtomicInteger import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent} @@ -24,12 +25,12 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor var totalCoreCount = new AtomicInteger(0) class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor { - val executorActor = new HashMap[String, ActorRef] - val executorAddress = new HashMap[String, Address] - val executorHost = new HashMap[String, String] - val freeCores = new HashMap[String, Int] - val actorToExecutorId = new HashMap[ActorRef, String] - val addressToExecutorId = new HashMap[Address, String] + private val executorActor = new HashMap[String, ActorRef] + private val executorAddress = new HashMap[String, Address] + private val executorHostPort = new HashMap[String, String] + private val freeCores = new HashMap[String, Int] + private val actorToExecutorId = new HashMap[ActorRef, String] + private val addressToExecutorId = new HashMap[Address, String] override def preStart() { // Listen for remote client disconnection events, since they don't go through Akka's watch() @@ -37,7 +38,8 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } def receive = { - case RegisterExecutor(executorId, host, cores) => + case RegisterExecutor(executorId, hostPort, cores) => + Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorActor.contains(executorId)) { sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) } else { @@ -45,7 +47,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor sender ! 
RegisteredExecutor(sparkProperties) context.watch(sender) executorActor(executorId) = sender - executorHost(executorId) = host + executorHostPort(executorId) = hostPort freeCores(executorId) = cores executorAddress(executorId) = sender.path.address actorToExecutorId(sender) = executorId @@ -85,13 +87,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor // Make fake resource offers on all executors def makeOffers() { launchTasks(scheduler.resourceOffers( - executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))})) + executorHostPort.toArray.map {case (id, hostPort) => new WorkerOffer(id, hostPort, freeCores(id))})) } // Make fake resource offers on just one executor def makeOffers(executorId: String) { launchTasks(scheduler.resourceOffers( - Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId))))) + Seq(new WorkerOffer(executorId, executorHostPort(executorId), freeCores(executorId))))) } // Launch tasks returned by a set of resource offers @@ -110,9 +112,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor actorToExecutorId -= executorActor(executorId) addressToExecutorId -= executorAddress(executorId) executorActor -= executorId - executorHost -= executorId + executorHostPort -= executorId freeCores -= executorId - executorHost -= executorId + executorHostPort -= executorId totalCoreCount.addAndGet(-numCores) scheduler.executorLost(executorId, SlaveLost(reason)) } @@ -128,7 +130,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor while (iterator.hasNext) { val entry = iterator.next val (key, value) = (entry.getKey.toString, entry.getValue.toString) - if (key.startsWith("spark.")) { + if (key.startsWith("spark.") && !key.equals("spark.hostPort")) { properties += ((key, value)) } } @@ -136,10 +138,11 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME) } + private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + override def stop() { try { if (driverActor != null) { - val timeout = 5.seconds val future = driverActor.ask(StopDriver)(timeout) Await.result(future, timeout) } @@ -159,7 +162,6 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor // Called by subclasses when notified of a lost worker def removeExecutor(executorId: String, reason: String) { try { - val timeout = 5.seconds val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout) Await.result(future, timeout) } catch { diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala index dfe3c5a85bc25f47a85a0da31a4649b27ef2c3be..718f26bfbd74e96f38d4fa5a7b3418022bb8ca34 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala @@ -1,5 +1,7 @@ package spark.scheduler.cluster +import spark.Utils + /** * Information about a running task attempt inside a TaskSet. 
*/ @@ -9,8 +11,11 @@ class TaskInfo( val index: Int, val launchTime: Long, val executorId: String, - val host: String, - val preferred: Boolean) { + val hostPort: String, + val taskLocality: TaskLocality.TaskLocality) { + + Utils.checkHostPort(hostPort, "Expected hostport") + var finishTime: Long = 0 var failed = false diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index c9f2c488048ca2628387165ac498d7346da45627..27e713e2c43dd5d0629f37cf22cb46d47a012ed1 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -1,7 +1,6 @@ package spark.scheduler.cluster -import java.util.Arrays -import java.util.{HashMap => JHashMap} +import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap @@ -14,6 +13,36 @@ import spark.scheduler._ import spark.TaskState.TaskState import java.nio.ByteBuffer +private[spark] object TaskLocality extends Enumeration("HOST_LOCAL", "RACK_LOCAL", "ANY") with Logging { + + val HOST_LOCAL, RACK_LOCAL, ANY = Value + + type TaskLocality = Value + + def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = { + + constraint match { + case TaskLocality.HOST_LOCAL => condition == TaskLocality.HOST_LOCAL + case TaskLocality.RACK_LOCAL => condition == TaskLocality.HOST_LOCAL || condition == TaskLocality.RACK_LOCAL + // For anything else, allow + case _ => true + } + } + + def parse(str: String): TaskLocality = { + // better way to do this ? + try { + TaskLocality.withName(str) + } catch { + case nEx: NoSuchElementException => { + logWarning("Invalid task locality specified '" + str + "', defaulting to HOST_LOCAL"); + // default to preserve earlier behavior + HOST_LOCAL + } + } + } +} + /** * Schedules the tasks within a single TaskSet in the ClusterScheduler. */ @@ -47,14 +76,22 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe // Last time when we launched a preferred task (for delay scheduling) var lastPreferredLaunchTime = System.currentTimeMillis - // List of pending tasks for each node. These collections are actually + // List of pending tasks for each node (hyper local to container). These collections are actually // treated as stacks, in which new tasks are added to the end of the // ArrayBuffer and removed from the end. This makes it faster to detect // tasks that repeatedly fail because whenever a task failed, it is put // back at the head of the stack. They are also only cleaned up lazily; // when a task is launched, it remains in all the pending lists except // the one that it was launched from, but gets removed from them later. - val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] + private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]] + + // List of pending tasks for each node. + // Essentially, similar to pendingTasksForHostPort, except at host level + private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] + + // List of pending tasks for each node based on rack locality. 
+ // Essentially, similar to pendingTasksForHost, except at rack level + private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]] // List containing pending tasks with no locality preferences val pendingTasksWithNoPrefs = new ArrayBuffer[Int] @@ -96,26 +133,117 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe addPendingTask(i) } + private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler, rackLocal: Boolean = false): ArrayBuffer[String] = { + // DEBUG code + _taskPreferredLocations.foreach(h => Utils.checkHost(h, "taskPreferredLocation " + _taskPreferredLocations)) + + val taskPreferredLocations = if (! rackLocal) _taskPreferredLocations else { + // Expand set to include all 'seen' rack local hosts. + // This works since container allocation/management happens within master - so any rack locality information is updated in msater. + // Best case effort, and maybe sort of kludge for now ... rework it later ? + val hosts = new HashSet[String] + _taskPreferredLocations.foreach(h => { + val rackOpt = scheduler.getRackForHost(h) + if (rackOpt.isDefined) { + val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get) + if (hostsOpt.isDefined) { + hosts ++= hostsOpt.get + } + } + + // Ensure that irrespective of what scheduler says, host is always added ! + hosts += h + }) + + hosts + } + + val retval = new ArrayBuffer[String] + scheduler.synchronized { + for (prefLocation <- taskPreferredLocations) { + val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(prefLocation) + if (aliveLocationsOpt.isDefined) { + retval ++= aliveLocationsOpt.get + } + } + } + + retval + } + // Add a task to all the pending-task lists that it should be on. private def addPendingTask(index: Int) { - val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive - if (locations.size == 0) { + // We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate + // hostPort <-> host conversion). But not doing it for simplicity sake. If this becomes a performance issue, modify it. + val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched) + val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, true) + + if (rackLocalLocations.size == 0) { + // Current impl ensures this. 
+ assert (hostLocalLocations.size == 0) pendingTasksWithNoPrefs += index } else { - for (host <- locations) { - val list = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer()) + + // host locality + for (hostPort <- hostLocalLocations) { + // DEBUG Code + Utils.checkHostPort(hostPort) + + val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer()) + hostPortList += index + + val host = Utils.parseHostPort(hostPort)._1 + val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer()) + hostList += index + } + + // rack locality + for (rackLocalHostPort <- rackLocalLocations) { + // DEBUG Code + Utils.checkHostPort(rackLocalHostPort) + + val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1 + val list = pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer()) list += index } } + allPendingTasks += index } + // Return the pending tasks list for a given host port (hyper local), or an empty list if + // there is no map entry for that host + private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = { + // DEBUG Code + Utils.checkHostPort(hostPort) + pendingTasksForHostPort.getOrElse(hostPort, ArrayBuffer()) + } + // Return the pending tasks list for a given host, or an empty list if // there is no map entry for that host - private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = { + private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = { + val host = Utils.parseHostPort(hostPort)._1 pendingTasksForHost.getOrElse(host, ArrayBuffer()) } + // Return the pending tasks (rack level) list for a given host, or an empty list if + // there is no map entry for that host + private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = { + val host = Utils.parseHostPort(hostPort)._1 + pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer()) + } + + // Number of pending tasks for a given host (which would be data local) + def numPendingTasksForHost(hostPort: String): Int = { + getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) + } + + // Number of pending rack local tasks for a given host + def numRackLocalPendingTasksForHost(hostPort: String): Int = { + getRackLocalPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) ) + } + + // Dequeue a pending task from the given list and return its index. // Return None if the list is empty. // This method also cleans up any tasks in the list that have already @@ -132,26 +260,49 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe } // Return a speculative task for a given host if any are available. The task should not have an - // attempt running on this host, in case the host is slow. In addition, if localOnly is set, the - // task must have a preference for this host (or no preferred locations at all). - private def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = { - val hostsAlive = sched.hostsAlive + // attempt running on this host, in case the host is slow. In addition, if locality is set, the + // task must have a preference for this host/rack/no preferred locations at all. 
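For reference, the locality relaxation these lookups rely on (host-local accepted under a rack-local constraint, everything accepted under ANY) mirrors TaskLocality.isAllowed; a minimal standalone sketch, with Locality as a hypothetical stand-in:

  object Locality extends Enumeration {
    val HOST_LOCAL, RACK_LOCAL, ANY = Value

    // a task whose best locality is `condition` may run under `constraint`
    // only if it is at least as local as the constraint demands
    def isAllowed(constraint: Value, condition: Value): Boolean = constraint match {
      case HOST_LOCAL => condition == HOST_LOCAL
      case RACK_LOCAL => condition == HOST_LOCAL || condition == RACK_LOCAL
      case _          => true  // ANY accepts host-local, rack-local and off-rack tasks
    }
  }

  // Locality.isAllowed(Locality.RACK_LOCAL, Locality.HOST_LOCAL)  => true
  // Locality.isAllowed(Locality.HOST_LOCAL, Locality.RACK_LOCAL)  => false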
+ private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = { + + assert (TaskLocality.isAllowed(locality, TaskLocality.HOST_LOCAL)) speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set - val localTask = speculatableTasks.find { - index => - val locations = tasks(index).preferredLocations.toSet & hostsAlive - val attemptLocs = taskAttempts(index).map(_.host) - (locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host) + + if (speculatableTasks.size > 0) { + val localTask = speculatableTasks.find { + index => + val locations = findPreferredLocations(tasks(index).preferredLocations, sched) + val attemptLocs = taskAttempts(index).map(_.hostPort) + (locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort) + } + + if (localTask != None) { + speculatableTasks -= localTask.get + return localTask } - if (localTask != None) { - speculatableTasks -= localTask.get - return localTask - } - if (!localOnly && speculatableTasks.size > 0) { - val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.host).contains(host)) - if (nonLocalTask != None) { - speculatableTasks -= nonLocalTask.get - return nonLocalTask + + // check for rack locality + if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { + val rackTask = speculatableTasks.find { + index => + val locations = findPreferredLocations(tasks(index).preferredLocations, sched, true) + val attemptLocs = taskAttempts(index).map(_.hostPort) + locations.contains(hostPort) && !attemptLocs.contains(hostPort) + } + + if (rackTask != None) { + speculatableTasks -= rackTask.get + return rackTask + } + } + + // Any task ... + if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { + // Check for attemptLocs also ? + val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.hostPort).contains(hostPort)) + if (nonLocalTask != None) { + speculatableTasks -= nonLocalTask.get + return nonLocalTask + } } } return None @@ -159,59 +310,103 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe // Dequeue a pending task for a given node and return its index. // If localOnly is set to false, allow non-local tasks as well. - private def findTask(host: String, localOnly: Boolean): Option[Int] = { - val localTask = findTaskFromList(getPendingTasksForHost(host)) + private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = { + val localTask = findTaskFromList(getPendingTasksForHost(hostPort)) if (localTask != None) { return localTask } + + if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { + val rackLocalTask = findTaskFromList(getRackLocalPendingTasksForHost(hostPort)) + if (rackLocalTask != None) { + return rackLocalTask + } + } + + // Look for no pref tasks AFTER rack local tasks - this has side effect that we will get to failed tasks later rather than sooner. + // TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down). 
val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs) if (noPrefTask != None) { return noPrefTask } - if (!localOnly) { + + if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { val nonLocalTask = findTaskFromList(allPendingTasks) if (nonLocalTask != None) { return nonLocalTask } } + // Finally, if all else has failed, find a speculative task - return findSpeculativeTask(host, localOnly) + return findSpeculativeTask(hostPort, locality) } // Does a host count as a preferred location for a task? This is true if // either the task has preferred locations and this host is one, or it has // no preferred locations (in which we still count the launch as preferred). - private def isPreferredLocation(task: Task[_], host: String): Boolean = { + private def isPreferredLocation(task: Task[_], hostPort: String): Boolean = { val locs = task.preferredLocations - return (locs.contains(host) || locs.isEmpty) + // DEBUG code + locs.foreach(h => Utils.checkHost(h, "preferredLocation " + locs)) + + if (locs.contains(hostPort) || locs.isEmpty) return true + + val host = Utils.parseHostPort(hostPort)._1 + locs.contains(host) + } + + // Does a host count as a rack local preferred location for a task? (assumes host is NOT preferred location). + // This is true if either the task has preferred locations and this host is one, or it has + // no preferred locations (in which we still count the launch as preferred). + def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = { + + val locs = task.preferredLocations + + // DEBUG code + locs.foreach(h => Utils.checkHost(h, "preferredLocation " + locs)) + + val preferredRacks = new HashSet[String]() + for (preferredHost <- locs) { + val rack = sched.getRackForHost(preferredHost) + if (None != rack) preferredRacks += rack.get + } + + if (preferredRacks.isEmpty) return false + + val hostRack = sched.getRackForHost(hostPort) + + return None != hostRack && preferredRacks.contains(hostRack.get) } // Respond to an offer of a single slave from the scheduler by finding a task - def slaveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = { + def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = { + if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { - val time = System.currentTimeMillis - val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT) + // If explicitly specified, use that + val locality = if (overrideLocality != null) overrideLocality else { + // expand only if we have waited for more than LOCALITY_WAIT for a host local task ... 
+ val time = System.currentTimeMillis + if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.HOST_LOCAL else TaskLocality.ANY + } - findTask(host, localOnly) match { + findTask(hostPort, locality) match { case Some(index) => { // Found a task; do some bookkeeping and return a Mesos task for it val task = tasks(index) val taskId = sched.newTaskId() // Figure out whether this should count as a preferred launch - val preferred = isPreferredLocation(task, host) - val prefStr = if (preferred) { - "preferred" - } else { - "non-preferred, not one of " + task.preferredLocations.mkString(", ") - } - logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format( - taskSet.id, index, taskId, execId, host, prefStr)) + val taskLocality = if (isPreferredLocation(task, hostPort)) TaskLocality.HOST_LOCAL else + if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else TaskLocality.ANY + val prefStr = taskLocality.toString + logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( + taskSet.id, index, taskId, execId, hostPort, prefStr)) // Do various bookkeeping copiesRunning(index) += 1 - val info = new TaskInfo(taskId, index, time, execId, host, preferred) + val time = System.currentTimeMillis + val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality) taskInfos(taskId) = info taskAttempts(index) = info :: taskAttempts(index) - if (preferred) { + if (TaskLocality.HOST_LOCAL == taskLocality) { lastPreferredLaunchTime = time } // Serialize and return the task @@ -355,17 +550,15 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe sched.taskSetFinished(this) } - def executorLost(execId: String, hostname: String) { + def executorLost(execId: String, hostPort: String) { logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) - val newHostsAlive = sched.hostsAlive // If some task has preferred locations only on hostname, and there are no more executors there, // put it in the no-prefs list to avoid the wait from delay scheduling - if (!newHostsAlive.contains(hostname)) { - for (index <- getPendingTasksForHost(hostname)) { - val newLocs = tasks(index).preferredLocations.toSet & newHostsAlive - if (newLocs.isEmpty) { - pendingTasksWithNoPrefs += index - } + for (index <- getPendingTasksForHostPort(hostPort)) { + val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, true) + if (newLocs.isEmpty) { + assert (findPreferredLocations(tasks(index).preferredLocations, sched).isEmpty) + pendingTasksWithNoPrefs += index } } // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage @@ -419,7 +612,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe !speculatableTasks.contains(index)) { logInfo( "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( - taskSet.id, index, info.host, threshold)) + taskSet.id, index, info.hostPort, threshold)) speculatableTasks += index foundTasks = true } @@ -427,4 +620,8 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe } return foundTasks } + + def hasPendingTasks(): Boolean = { + numTasks > 0 && tasksFinished < numTasks + } } diff --git a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala index 3c3afcbb14d3f7f677be8aae91128e35eef01c7f..c47824315ccbd6c61d2fafa27c7288cbc2d548e1 100644 --- a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala +++ 
b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala @@ -4,5 +4,5 @@ package spark.scheduler.cluster * Represents free resources available on an executor. */ private[spark] -class WorkerOffer(val executorId: String, val hostname: String, val cores: Int) { +class WorkerOffer(val executorId: String, val hostPort: String, val cores: Int) { } diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index 9e1bde3fbe44e7196f5e3bb46b4edd71729e8f4b..f060a940a9269bc2e5dfd8a57111fcbe2e719c7a 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -7,7 +7,7 @@ import scala.collection.mutable.HashMap import spark._ import spark.executor.ExecutorURLClassLoader import spark.scheduler._ -import spark.scheduler.cluster.TaskInfo +import spark.scheduler.cluster.{TaskLocality, TaskInfo} /** * A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally @@ -53,7 +53,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon def runTask(task: Task[_], idInJob: Int, attemptId: Int) { logInfo("Running " + task) - val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local", true) + val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local:1", TaskLocality.HOST_LOCAL) // Set the Spark execution environment for the worker thread SparkEnv.set(env) try { diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 210061e9726b4108e4ebe1c00c138f2ed3d4e85b..6e861ac734604c3455a39946027ea6ffaef11eab 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -37,17 +37,27 @@ class BlockManager( maxMemory: Long) extends Logging { - class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { - var pending: Boolean = true - var size: Long = -1L - var failed: Boolean = false + private class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { + @volatile var pending: Boolean = true + @volatile var size: Long = -1L + @volatile var initThread: Thread = null + @volatile var failed = false + + setInitThread() + + private def setInitThread() { + // Set current thread as init thread - waitForReady will not block this thread + // (in case there is non trivial initialization which ends up calling waitForReady as part of + // initialization itself) + this.initThread = Thread.currentThread() + } /** * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing). * Return true if the block is available, false otherwise. */ def waitForReady(): Boolean = { - if (pending) { + if (initThread != Thread.currentThread() && pending) { synchronized { while (pending) this.wait() } @@ -57,19 +67,26 @@ class BlockManager( /** Mark this BlockInfo as ready (i.e. 
block is finished writing) */ def markReady(sizeInBytes: Long) { + assert (pending) + size = sizeInBytes + initThread = null + failed = false + initThread = null + pending = false synchronized { - pending = false - failed = false - size = sizeInBytes this.notifyAll() } } /** Mark this BlockInfo as ready but failed */ def markFailure() { + assert (pending) + size = 0 + initThread = null + failed = true + initThread = null + pending = false synchronized { - failed = true - pending = false this.notifyAll() } } @@ -101,7 +118,7 @@ class BlockManager( val heartBeatFrequency = BlockManager.getHeartBeatFrequencyFromSystemProperties - val host = System.getProperty("spark.hostname", Utils.localHostName()) + val hostPort = Utils.localHostPort() val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)), name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) @@ -212,9 +229,12 @@ class BlockManager( * Tell the master about the current storage status of a block. This will send a block update * message reflecting the current status, *not* the desired storage level in its block info. * For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk. + * + * droppedMemorySize exists to account for when block is dropped from memory to disk (so it is still valid). + * This ensures that update in master will compensate for the increase in memory on slave. */ - def reportBlockStatus(blockId: String, info: BlockInfo) { - val needReregister = !tryToReportBlockStatus(blockId, info) + def reportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L) { + val needReregister = !tryToReportBlockStatus(blockId, info, droppedMemorySize) if (needReregister) { logInfo("Got told to reregister updating block " + blockId) // Reregistering will report our new block for free. @@ -228,7 +248,7 @@ class BlockManager( * which will be true if the block was successfully recorded and false if * the slave needs to re-register. 
*/ - private def tryToReportBlockStatus(blockId: String, info: BlockInfo): Boolean = { + private def tryToReportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = { val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized { info.level match { case null => @@ -237,7 +257,7 @@ class BlockManager( val inMem = level.useMemory && memoryStore.contains(blockId) val onDisk = level.useDisk && diskStore.contains(blockId) val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication) - val memSize = if (inMem) memoryStore.getSize(blockId) else 0L + val memSize = if (inMem) memoryStore.getSize(blockId) else droppedMemorySize val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L (storageLevel, memSize, diskSize, info.tellMaster) } @@ -257,7 +277,7 @@ class BlockManager( def getLocations(blockId: String): Seq[String] = { val startTimeMs = System.currentTimeMillis var managers = master.getLocations(blockId) - val locations = managers.map(_.ip) + val locations = managers.map(_.hostPort) logDebug("Got block locations in " + Utils.getUsedTimeMs(startTimeMs)) return locations } @@ -267,7 +287,7 @@ class BlockManager( */ def getLocations(blockIds: Array[String]): Array[Seq[String]] = { val startTimeMs = System.currentTimeMillis - val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray + val locations = master.getLocations(blockIds).map(_.map(_.hostPort).toSeq).toArray logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) return locations } @@ -339,6 +359,8 @@ class BlockManager( case Some(bytes) => // Put a copy of the block back in memory before returning it. Note that we can't // put the ByteBuffer returned by the disk store as that's a memory-mapped file. + // The use of rewind assumes this. + assert (0 == bytes.position()) val copyForMemory = ByteBuffer.allocate(bytes.limit) copyForMemory.put(bytes) memoryStore.putBytes(blockId, copyForMemory, level) @@ -411,6 +433,7 @@ class BlockManager( // Read it as a byte buffer into memory first, then return it diskStore.getBytes(blockId) match { case Some(bytes) => + assert (0 == bytes.position()) if (level.useMemory) { if (level.deserialized) { memoryStore.putBytes(blockId, bytes, level) @@ -450,7 +473,7 @@ class BlockManager( for (loc <- locations) { logDebug("Getting remote block " + blockId + " from " + loc) val data = BlockManagerWorker.syncGetBlock( - GetBlock(blockId), ConnectionManagerId(loc.ip, loc.port)) + GetBlock(blockId), ConnectionManagerId(loc.host, loc.port)) if (data != null) { return Some(dataDeserialize(blockId, data)) } @@ -501,17 +524,26 @@ class BlockManager( throw new IllegalArgumentException("Storage level is null or invalid") } - val oldBlock = blockInfo.get(blockId).orNull - if (oldBlock != null && oldBlock.waitForReady()) { - logWarning("Block " + blockId + " already exists on this machine; not re-adding it") - return oldBlock.size - } - // Remember the block's storage level so that we can correctly drop it to disk if it needs // to be dropped right after it got put into memory. Note, however, that other threads will // not be able to get() this block until we call markReady on its BlockInfo. - val myInfo = new BlockInfo(level, tellMaster) - blockInfo.put(blockId, myInfo) + val myInfo = { + val tinfo = new BlockInfo(level, tellMaster) + // Do atomically ! 
+ val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo) + + if (oldBlockOpt.isDefined) { + if (oldBlockOpt.get.waitForReady()) { + logWarning("Block " + blockId + " already exists on this machine; not re-adding it") + return oldBlockOpt.get.size + } + + // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ? + oldBlockOpt.get + } else { + tinfo + } + } val startTimeMs = System.currentTimeMillis @@ -531,6 +563,7 @@ class BlockManager( logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") + var marked = false try { if (level.useMemory) { // Save it just to memory first, even if it also has useDisk set to true; we will later @@ -555,20 +588,20 @@ class BlockManager( // Now that the block is in either the memory or disk store, let other threads read it, // and tell the master about it. + marked = true myInfo.markReady(size) if (tellMaster) { reportBlockStatus(blockId, myInfo) } - } catch { + } finally { // If we failed at putting the block to memory/disk, notify other possible readers // that it has failed, and then remove it from the block info map. - case e: Exception => { + if (! marked) { // Note that the remove must happen before markFailure otherwise another thread // could've inserted a new BlockInfo before we remove it. blockInfo.remove(blockId) myInfo.markFailure() - logWarning("Putting block " + blockId + " failed", e) - throw e + logWarning("Putting block " + blockId + " failed") } } } @@ -611,16 +644,26 @@ class BlockManager( throw new IllegalArgumentException("Storage level is null or invalid") } - if (blockInfo.contains(blockId)) { - logWarning("Block " + blockId + " already exists on this machine; not re-adding it") - return - } - // Remember the block's storage level so that we can correctly drop it to disk if it needs // to be dropped right after it got put into memory. Note, however, that other threads will // not be able to get() this block until we call markReady on its BlockInfo. - val myInfo = new BlockInfo(level, tellMaster) - blockInfo.put(blockId, myInfo) + val myInfo = { + val tinfo = new BlockInfo(level, tellMaster) + // Do atomically ! + val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo) + + if (oldBlockOpt.isDefined) { + if (oldBlockOpt.get.waitForReady()) { + logWarning("Block " + blockId + " already exists on this machine; not re-adding it") + return + } + + // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ? + oldBlockOpt.get + } else { + tinfo + } + } val startTimeMs = System.currentTimeMillis @@ -639,6 +682,7 @@ class BlockManager( logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") + var marked = false try { if (level.useMemory) { // Store it only in memory at first, even if useDisk is also set to true @@ -649,22 +693,24 @@ class BlockManager( diskStore.putBytes(blockId, bytes, level) } + // assert (0 == bytes.position(), "" + bytes) + // Now that the block is in either the memory or disk store, let other threads read it, // and tell the master about it. + marked = true myInfo.markReady(bytes.limit) if (tellMaster) { reportBlockStatus(blockId, myInfo) } - } catch { + } finally { // If we failed at putting the block to memory/disk, notify other possible readers // that it has failed, and then remove it from the block info map. - case e: Exception => { + if (! 
marked) { // Note that the remove must happen before markFailure otherwise another thread // could've inserted a new BlockInfo before we remove it. blockInfo.remove(blockId) myInfo.markFailure() - logWarning("Putting block " + blockId + " failed", e) - throw e + logWarning("Putting block " + blockId + " failed") } } } @@ -698,7 +744,7 @@ class BlockManager( logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is " + data.limit() + " Bytes. To node: " + peer) if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel), - new ConnectionManagerId(peer.ip, peer.port))) { + new ConnectionManagerId(peer.host, peer.port))) { logError("Failed to call syncPutBlock to " + peer) } logDebug("Replicated BlockId " + blockId + " once used " + @@ -730,6 +776,14 @@ class BlockManager( val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { + // required ? As of now, this will be invoked only for blocks which are ready + // But in case this changes in future, adding for consistency sake. + if (! info.waitForReady() ) { + // If we get here, the block write failed. + logWarning("Block " + blockId + " was marked as failure. Nothing to drop") + return + } + val level = info.level if (level.useDisk && !diskStore.contains(blockId)) { logInfo("Writing block " + blockId + " to disk") @@ -740,12 +794,13 @@ class BlockManager( diskStore.putBytes(blockId, bytes, level) } } + val droppedMemorySize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L val blockWasRemoved = memoryStore.remove(blockId) if (!blockWasRemoved) { logWarning("Block " + blockId + " could not be dropped from memory as it does not exist") } if (info.tellMaster) { - reportBlockStatus(blockId, info) + reportBlockStatus(blockId, info, droppedMemorySize) } if (!level.useDisk) { // The block is completely gone from this node; forget it so we can put() it again later. @@ -938,8 +993,8 @@ class BlockFetcherIterator( def sendRequest(req: FetchRequest) { logDebug("Sending request for %d blocks (%s) from %s".format( - req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip)) - val cmId = new ConnectionManagerId(req.address.ip, req.address.port) + req.blocks.size, Utils.memoryBytesToString(req.size), req.address.hostPort)) + val cmId = new ConnectionManagerId(req.address.host, req.address.port) val blockMessageArray = new BlockMessageArray(req.blocks.map { case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId)) }) @@ -1037,7 +1092,7 @@ class BlockFetcherIterator( logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") //an iterator that will read fetched blocks off the queue as they arrive. - var resultsGotten = 0 + @volatile private var resultsGotten = 0 def hasNext: Boolean = resultsGotten < totalBlocks @@ -1047,7 +1102,7 @@ class BlockFetcherIterator( val result = results.take() val stopFetchWait = System.currentTimeMillis() _fetchWaitTime += (stopFetchWait - startFetchWait) - bytesInFlight -= result.size + if (! 
result.failed) bytesInFlight -= result.size while (!fetchRequests.isEmpty && (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { sendRequest(fetchRequests.dequeue()) diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala index f2f1e77d41a65cb810fd58904626dccb7b5caaff..f4a2181490f9fe9aceb67e2bc7c7aeb2198b54ea 100644 --- a/core/src/main/scala/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -2,6 +2,7 @@ package spark.storage import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.util.concurrent.ConcurrentHashMap +import spark.Utils /** * This class represent an unique identifier for a BlockManager. @@ -13,7 +14,7 @@ import java.util.concurrent.ConcurrentHashMap */ private[spark] class BlockManagerId private ( private var executorId_ : String, - private var ip_ : String, + private var host_ : String, private var port_ : Int ) extends Externalizable { @@ -21,32 +22,45 @@ private[spark] class BlockManagerId private ( def executorId: String = executorId_ - def ip: String = ip_ + if (null != host_){ + Utils.checkHost(host_, "Expected hostname") + assert (port_ > 0) + } + + def hostPort: String = { + // DEBUG code + Utils.checkHost(host) + assert (port > 0) + + host + ":" + port + } + + def host: String = host_ def port: Int = port_ override def writeExternal(out: ObjectOutput) { out.writeUTF(executorId_) - out.writeUTF(ip_) + out.writeUTF(host_) out.writeInt(port_) } override def readExternal(in: ObjectInput) { executorId_ = in.readUTF() - ip_ = in.readUTF() + host_ = in.readUTF() port_ = in.readInt() } @throws(classOf[IOException]) private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) - override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, ip, port) + override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, host, port) - override def hashCode: Int = (executorId.hashCode * 41 + ip.hashCode) * 41 + port + override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port override def equals(that: Any) = that match { case id: BlockManagerId => - executorId == id.executorId && port == id.port && ip == id.ip + executorId == id.executorId && port == id.port && host == id.host case _ => false } @@ -55,8 +69,8 @@ private[spark] class BlockManagerId private ( private[spark] object BlockManagerId { - def apply(execId: String, ip: String, port: Int) = - getCachedBlockManagerId(new BlockManagerId(execId, ip, port)) + def apply(execId: String, host: String, port: Int) = + getCachedBlockManagerId(new BlockManagerId(execId, host, port)) def apply(in: ObjectInput) = { val obj = new BlockManagerId() @@ -67,11 +81,7 @@ private[spark] object BlockManagerId { val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { - if (blockManagerIdCache.containsKey(id)) { - blockManagerIdCache.get(id) - } else { - blockManagerIdCache.put(id, id) - id - } + blockManagerIdCache.putIfAbsent(id, id) + blockManagerIdCache.get(id) } } diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 036fdc3480119307f9c01094bfd47fd5e75c06e2..6fae62d373ba97c3bce412de31808451085df80a 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ 
b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -22,7 +22,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" - val timeout = 10.seconds + val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") /** Remove a dead executor from the driver actor. This is only called on the driver side. */ def removeExecutor(execId: String) { diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala index 1a6a6cfd3fb55d145d790c0ccca0835eb9ce0d46..9b64f95df80731cd8188edb5a2a159410ccda153 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -333,8 +333,8 @@ object BlockManagerMasterActor { // Mapping from block id to its status. private val _blocks = new JHashMap[String, BlockStatus] - logInfo("Registering block manager %s:%d with %s RAM".format( - blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem))) + logInfo("Registering block manager %s with %s RAM".format( + blockManagerId.hostPort, Utils.memoryBytesToString(maxMem))) def updateLastSeenMs() { _lastSeenMs = System.currentTimeMillis() @@ -359,13 +359,13 @@ object BlockManagerMasterActor { _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize)) if (storageLevel.useMemory) { _remainingMem -= memSize - logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), + logInfo("Added %s in memory on %s (size: %s, free: %s)".format( + blockId, blockManagerId.hostPort, Utils.memoryBytesToString(memSize), Utils.memoryBytesToString(_remainingMem))) } if (storageLevel.useDisk) { - logInfo("Added %s on disk on %s:%d (size: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) + logInfo("Added %s on disk on %s (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.memoryBytesToString(diskSize))) } } else if (_blocks.containsKey(blockId)) { // If isValid is not true, drop the block. 
@@ -373,13 +373,13 @@ object BlockManagerMasterActor { _blocks.remove(blockId) if (blockStatus.storageLevel.useMemory) { _remainingMem += blockStatus.memSize - logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), + logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( + blockId, blockManagerId.hostPort, Utils.memoryBytesToString(memSize), Utils.memoryBytesToString(_remainingMem))) } if (blockStatus.storageLevel.useDisk) { - logInfo("Removed %s on %s:%d on disk (size: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) + logInfo("Removed %s on %s on disk (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.memoryBytesToString(diskSize))) } } } diff --git a/core/src/main/scala/spark/storage/BlockManagerUI.scala b/core/src/main/scala/spark/storage/BlockManagerUI.scala index 9e6721ec17169a8ca33753393329049467b105dd..07da5720440cd060c1a1ec67e83fe503e8d5e090 100644 --- a/core/src/main/scala/spark/storage/BlockManagerUI.scala +++ b/core/src/main/scala/spark/storage/BlockManagerUI.scala @@ -1,7 +1,7 @@ package spark.storage import akka.actor.{ActorRef, ActorSystem} -import akka.util.Timeout +import akka.util.Duration import akka.util.duration._ import cc.spray.typeconversion.TwirlSupport._ import cc.spray.Directives @@ -19,7 +19,7 @@ class BlockManagerUI(val actorSystem: ActorSystem, blockManagerMaster: ActorRef, val STATIC_RESOURCE_DIR = "spark/deploy/static" - implicit val timeout = Timeout(10 seconds) + implicit val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") /** Start a HTTP server to run the Web interface */ def start() { diff --git a/core/src/main/scala/spark/storage/BlockMessageArray.scala b/core/src/main/scala/spark/storage/BlockMessageArray.scala index a25decb123d8651a7359d81c83d2b84d58334e9e..ee0c5ff9a2d1cdd97f8158f766e1949b73ca5a83 100644 --- a/core/src/main/scala/spark/storage/BlockMessageArray.scala +++ b/core/src/main/scala/spark/storage/BlockMessageArray.scala @@ -115,6 +115,7 @@ private[spark] object BlockMessageArray { val newBuffer = ByteBuffer.allocate(totalSize) newBuffer.clear() bufferMessage.buffers.foreach(buffer => { + assert (0 == buffer.position()) newBuffer.put(buffer) buffer.rewind() }) diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index ddbf8821ad15aa172cb248ef3aed45edb3e38f33..c9553d2e0fa9e2613a56696f87145acfef66771f 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -20,6 +20,9 @@ import spark.Utils private class DiskStore(blockManager: BlockManager, rootDirs: String) extends BlockStore(blockManager) { + private val mapMode = MapMode.READ_ONLY + private var mapOpenMode = "r" + val MAX_DIR_CREATION_ATTEMPTS: Int = 10 val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt @@ -35,7 +38,10 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) getFile(blockId).length() } - override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) { + // So that we do not modify the input offsets ! 
+ // duplicate does not copy buffer, so inexpensive + val bytes = _bytes.duplicate() logDebug("Attempting to put block " + blockId) val startTime = System.currentTimeMillis val file = createFile(blockId) @@ -49,6 +55,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) blockId, Utils.memoryBytesToString(bytes.limit), (finishTime - startTime))) } + private def getFileBytes(file: File): ByteBuffer = { + val length = file.length() + val channel = new RandomAccessFile(file, mapOpenMode).getChannel() + val buffer = try { + channel.map(mapMode, 0, length) + } finally { + channel.close() + } + + buffer + } + override def putValues( blockId: String, values: ArrayBuffer[Any], @@ -70,9 +88,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) if (returnValues) { // Return a byte buffer for the contents of the file - val channel = new RandomAccessFile(file, "r").getChannel() - val buffer = channel.map(MapMode.READ_ONLY, 0, length) - channel.close() + val buffer = getFileBytes(file) PutResult(length, Right(buffer)) } else { PutResult(length, null) @@ -81,10 +97,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) override def getBytes(blockId: String): Option[ByteBuffer] = { val file = getFile(blockId) - val length = file.length().toInt - val channel = new RandomAccessFile(file, "r").getChannel() - val bytes = channel.map(MapMode.READ_ONLY, 0, length) - channel.close() + val bytes = getFileBytes(file) Some(bytes) } @@ -96,7 +109,6 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) val file = getFile(blockId) if (file.exists()) { file.delete() - true } else { false } @@ -175,11 +187,12 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } private def addShutdownHook() { + localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir) ) Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { override def run() { logDebug("Shutdown hook called") try { - localDirs.foreach(localDir => Utils.deleteRecursively(localDir)) + localDirs.foreach(localDir => if (! Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)) } catch { case t: Throwable => logError("Exception while deleting local spark dirs", t) } diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala index 949588476c20150b1dd5c73f4303dbf85d2ad518..eba5ee507ff0c1d8e4fcaa7aef3eda716736a4ba 100644 --- a/core/src/main/scala/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/spark/storage/MemoryStore.scala @@ -31,7 +31,9 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) { + // Work on a duplicate - since the original input might be used elsewhere. 
+ val bytes = _bytes.duplicate() bytes.rewind() if (level.deserialized) { val values = blockManager.dataDeserialize(blockId, bytes) diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index 3b5a77ab228bb2df833cd04fa6e22e8ba93dd6dd..cc0c354e7e9ea204135c77daa91e1179f63f61ac 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -123,11 +123,7 @@ object StorageLevel { val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() private[spark] def getCachedStorageLevel(level: StorageLevel): StorageLevel = { - if (storageLevelCache.containsKey(level)) { - storageLevelCache.get(level) - } else { - storageLevelCache.put(level, level) - level - } + storageLevelCache.putIfAbsent(level, level) + storageLevelCache.get(level) } } diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index 3e805b78314c78267e86fa2c6f9a0e359637f624..9fb7e001badcbf5156a7755fa047b69764a8f586 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -11,7 +11,7 @@ import cc.spray.{SprayCanRootService, HttpService} import cc.spray.can.server.HttpServer import cc.spray.io.pipelines.MessageHandlerDispatch.SingletonHandler import akka.dispatch.Await -import spark.SparkException +import spark.{Utils, SparkException} import java.util.concurrent.TimeoutException /** @@ -31,7 +31,10 @@ private[spark] object AkkaUtils { val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt val akkaTimeout = System.getProperty("spark.akka.timeout", "20").toInt val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt - val lifecycleEvents = System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean + val lifecycleEvents = if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off" + // 10 seconds is the default akka timeout, but in a cluster, we need higher by default. + val akkaWriteTimeout = System.getProperty("spark.akka.writeTimeout", "30").toInt + val akkaConf = ConfigFactory.parseString(""" akka.daemonic = on akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"] @@ -45,8 +48,9 @@ private[spark] object AkkaUtils { akka.remote.netty.execution-pool-size = %d akka.actor.default-dispatcher.throughput = %d akka.remote.log-remote-lifecycle-events = %s + akka.remote.netty.write-timeout = %ds """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize, - if (lifecycleEvents) "on" else "off")) + lifecycleEvents, akkaWriteTimeout)) val actorSystem = ActorSystem(name, akkaConf, getClass.getClassLoader) @@ -60,8 +64,9 @@ private[spark] object AkkaUtils { /** * Creates a Spray HTTP server bound to a given IP and port with a given Spray Route object to * handle requests. Returns the bound port or throws a SparkException on failure. + * TODO: Not changing ip to host here - is it required ? 
*/ - def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route, + def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route, name: String = "HttpServer"): ActorRef = { val ioWorker = new IoWorker(actorSystem).start() val httpService = actorSystem.actorOf(Props(new HttpService(route))) diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala index 188f8910da8d54f1b43a5a8d24fa08348a6d95cb..92dfaa6e6f3d1c8491df5ee933d634aa4199f2bc 100644 --- a/core/src/main/scala/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala @@ -3,6 +3,7 @@ package spark.util import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConversions import scala.collection.mutable.Map +import spark.scheduler.MapStatus /** * This is a custom implementation of scala.collection.mutable.Map which stores the insertion @@ -42,6 +43,13 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging { this } + // Should we return previous value directly or as Option ? + def putIfAbsent(key: A, value: B): Option[B] = { + val prev = internalMap.putIfAbsent(key, (value, currentTime)) + if (prev != null) Some(prev._1) else None + } + + override def -= (key: A): this.type = { internalMap.remove(key) this diff --git a/core/src/main/twirl/spark/deploy/master/index.scala.html b/core/src/main/twirl/spark/deploy/master/index.scala.html index ac51a39a5199d4cd96636f602ba0a9e7ce8b23c0..b9b9f08810930a0e452e523e680bb7c44897be65 100644 --- a/core/src/main/twirl/spark/deploy/master/index.scala.html +++ b/core/src/main/twirl/spark/deploy/master/index.scala.html @@ -2,7 +2,7 @@ @import spark.deploy.master._ @import spark.Utils -@spark.common.html.layout(title = "Spark Master on " + state.host) { +@spark.common.html.layout(title = "Spark Master on " + state.host + ":" + state.port) { <!-- Cluster Details --> <div class="row"> diff --git a/core/src/main/twirl/spark/deploy/worker/index.scala.html b/core/src/main/twirl/spark/deploy/worker/index.scala.html index c39f769a7387f96113c3f088a88fd6d1ac19351e..0e66af9284762abebd4acc8dbb0b8edd3aa58e9a 100644 --- a/core/src/main/twirl/spark/deploy/worker/index.scala.html +++ b/core/src/main/twirl/spark/deploy/worker/index.scala.html @@ -1,7 +1,7 @@ @(worker: spark.deploy.WorkerState) @import spark.Utils -@spark.common.html.layout(title = "Spark Worker on " + worker.host) { +@spark.common.html.layout(title = "Spark Worker on " + worker.host + ":" + worker.port) { <!-- Worker Details --> <div class="row"> diff --git a/core/src/main/twirl/spark/storage/worker_table.scala.html b/core/src/main/twirl/spark/storage/worker_table.scala.html index d54b8de4cc81394149cceac0d52b965483e95a7c..cd72a688c1ffbd71e93f66e18b143d248b814ca3 100644 --- a/core/src/main/twirl/spark/storage/worker_table.scala.html +++ b/core/src/main/twirl/spark/storage/worker_table.scala.html @@ -12,7 +12,7 @@ <tbody> @for(status <- workersStatusList) { <tr> - <td>@(status.blockManagerId.ip + ":" + status.blockManagerId.port)</td> + <td>@(status.blockManagerId.host + ":" + status.blockManagerId.port)</td> <td> @(Utils.memoryBytesToString(status.memUsed(prefix))) (@(Utils.memoryBytesToString(status.memRemaining)) Total Available) diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index 4104b33c8b6815ddebbe50ea595e633fb0cba46e..c9b4707def82be102423d0f18c3c2394c7d7e4ad 100644 --- 
a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -153,7 +153,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter val blockManager = SparkEnv.get.blockManager blockManager.master.getLocations(blockId).foreach(id => { val bytes = BlockManagerWorker.syncGetBlock( - GetBlock(blockId), ConnectionManagerId(id.ip, id.port)) + GetBlock(blockId), ConnectionManagerId(id.host, id.port)) val deserialized = blockManager.dataDeserialize(blockId, bytes).asInstanceOf[Iterator[Int]].toList assert(deserialized === (1 to 100).toList) }) diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala index 6da58a0f6e5fe3beeb6117382960e3c57ab679e9..c0f8986de8dd56a00e903276859c24391b3de2a8 100644 --- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala @@ -271,7 +271,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont // have the 2nd attempt pass complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) // we can see both result blocks now - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.ip) === Array("hostA", "hostB")) + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) } diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index c2957e6cb42eb7863280f9b324e244578ba145f8..26424bbe52a9b8e5402080e83b3deca229617eed 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -5,18 +5,25 @@ title: Launching Spark on YARN Experimental support for running over a [YARN (Hadoop NextGen)](http://hadoop.apache.org/docs/r2.0.2-alpha/hadoop-yarn/hadoop-yarn-site/YARN.html) -cluster was added to Spark in version 0.6.0. Because YARN depends on version -2.0 of the Hadoop libraries, this currently requires checking out a separate -branch of Spark, called `yarn`, which you can do as follows: +cluster was added to Spark in version 0.6.0. This was merged into master as part of the 0.7 effort. +To build Spark core with YARN support, use the hadoop2-yarn profile. +For example: mvn -Phadoop2-yarn clean install - git clone git://github.com/mesos/spark - cd spark - git checkout -b yarn --track origin/yarn +# Building the Spark core consolidated jar + +Currently, only sbt can build a consolidated jar containing the entire Spark code, which is required for launching jobs on YARN. +Doing this via sbt is (right now) a manual process of enabling it in project/SparkBuild.scala. +Please comment out the + HADOOP_VERSION, HADOOP_MAJOR_VERSION and HADOOP_YARN +variables declared before the line 'For Hadoop 2 YARN support'. +Next, uncomment the subsequent three declarations of these variables, which enable Hadoop YARN support. + +Currently, adding support for a Maven assembly is a TODO. # Preparations -- In order to distribute Spark within the cluster, it must be packaged into a single JAR file. This can be done by running `sbt/sbt assembly` +- Build the Spark core assembled jar (see above). - Your application code must be packaged into a separate JAR file. If you want to test out the YARN deployment mode, you can use the current Spark examples. A `spark-examples_{{site.SCALA_VERSION}}-{{site.SPARK_VERSION}}` file can be generated by running `sbt/sbt package`.
NOTE: since the documentation you're reading is for Spark version {{site.SPARK_VERSION}}, we are assuming here that you have downloaded Spark {{site.SPARK_VERSION}} or checked it out of source control. If you are using a different version of Spark, the version numbers in the jar generated by the sbt package command will obviously be different. @@ -30,8 +37,11 @@ The command to launch the YARN Client is as follows: --class <APP_MAIN_CLASS> \ --args <APP_MAIN_ARGUMENTS> \ --num-workers <NUMBER_OF_WORKER_MACHINES> \ + --master-memory <MEMORY_FOR_MASTER> \ --worker-memory <MEMORY_PER_WORKER> \ - --worker-cores <CORES_PER_WORKER> + --worker-cores <CORES_PER_WORKER> \ + --user <hadoop_user> \ + --queue <queue_name> For example: @@ -40,8 +50,9 @@ For example: --class spark.examples.SparkPi \ --args standalone \ --num-workers 3 \ + --master-memory 4g \ --worker-memory 2g \ - --worker-cores 2 + --worker-cores 1 The above starts a YARN Client programs which periodically polls the Application Master for status updates and displays them in the console. The client will exit once your application has finished running. @@ -49,3 +60,5 @@ The above starts a YARN Client programs which periodically polls the Application - When your application instantiates a Spark context it must use a special "standalone" master url. This starts the scheduler without forcing it to connect to a cluster. A good way to handle this is to pass "standalone" as an argument to your program, as shown in the example above. - YARN does not support requesting container resources based on the number of cores. Thus the numbers of cores given via command line arguments cannot be guaranteed. +- Currently, we have not yet integrated with Hadoop security. If --user is present, the specified hadoop_user will be used to run the tasks on the cluster. If unspecified, the current user will be used (which should be valid in the cluster). + Once Hadoop security support is added, and if the Hadoop cluster is secured, additional restrictions will apply via the delegation tokens that are passed.
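The "standalone" master URL mentioned in the notes above can be illustrated with a minimal driver. The sketch below is hypothetical (the object name and the job body are placeholders, not part of this patch); it simply forwards the first program argument as the master, the same way the spark.examples.SparkPi invocation above passes --args standalone.

```scala
import spark.SparkContext
import spark.SparkContext._

// Hypothetical driver used only for illustration.
object YarnAppSketch {
  def main(args: Array[String]) {
    // When launched through the YARN Client, args(0) is expected to be "standalone",
    // which starts the scheduler without connecting to a standalone cluster master.
    val sc = new SparkContext(args(0), "YarnAppSketch",
      System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
    // Trivial word count to show the context is usable.
    val counts = sc.parallelize(Seq("yarn", "spark", "yarn"))
      .map(word => (word, 1))
      .reduceByKey(_ + _)
    counts.collect().foreach(println)
    sc.stop()
  }
}
```

Passing the master as a program argument rather than hard-coding it is what lets the same application jar run both locally and under the YARN Client.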
diff --git a/examples/pom.xml b/examples/pom.xml index 270777e29c02774476e439d7d80c3b2fcda981e6..c42d2bcdb9ed55bcae303fa8844a8827e8a07ba1 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -118,5 +118,48 @@ </plugins> </build> </profile> + <profile> + <id>hadoop2-yarn</id> + <dependencies> + <dependency> + <groupId>org.spark-project</groupId> + <artifactId>spark-core</artifactId> + <version>${project.version}</version> + <classifier>hadoop2-yarn</classifier> + </dependency> + <dependency> + <groupId>org.spark-project</groupId> + <artifactId>spark-streaming</artifactId> + <version>${project.version}</version> + <classifier>hadoop2-yarn</classifier> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-client</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-api</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-common</artifactId> + <scope>provided</scope> + </dependency> + </dependencies> + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <configuration> + <classifier>hadoop2-yarn</classifier> + </configuration> + </plugin> + </plugins> + </build> + </profile> </profiles> </project> diff --git a/examples/src/main/scala/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/spark/examples/SparkHdfsLR.scala index 0f42f405a058cf9f2a218004b47c66d2330810d6..3d080a02577a38cdd2eeb1cd30a5c5e7b6fed629 100644 --- a/examples/src/main/scala/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/spark/examples/SparkHdfsLR.scala @@ -4,6 +4,8 @@ import java.util.Random import scala.math.exp import spark.util.Vector import spark._ +import spark.deploy.SparkHadoopUtil +import spark.scheduler.InputFormatInfo /** * Logistic regression based classification. 
@@ -32,9 +34,13 @@ object SparkHdfsLR { System.err.println("Usage: SparkHdfsLR <master> <file> <iters>") System.exit(1) } + val inputPath = args(1) + val conf = SparkHadoopUtil.newConfiguration() val sc = new SparkContext(args(0), "SparkHdfsLR", - System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) - val lines = sc.textFile(args(1)) + System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")), Map(), + InputFormatInfo.computePreferredLocations( + Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath)))) + val lines = sc.textFile(inputPath) val points = lines.map(parsePoint _).cache() val ITERATIONS = args(2).toInt diff --git a/pom.xml b/pom.xml index 1174b475d325297fc4030acf1d0b627758ee020d..c3323ffad0f71dbe9b84efac787b6179bdd34c7b 100644 --- a/pom.xml +++ b/pom.xml @@ -417,8 +417,9 @@ <configuration> <reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory> <junitxml>.</junitxml> - <filereports>WDF TestSuite.txt</filereports> + <filereports>${project.build.directory}/SparkTestSuite.txt</filereports> <argLine>-Xms64m -Xmx1024m</argLine> + <stderr/> </configuration> <executions> <execution> @@ -558,5 +559,66 @@ </dependencies> </dependencyManagement> </profile> + + <profile> + <id>hadoop2-yarn</id> + <properties> + <hadoop.major.version>2</hadoop.major.version> + <!-- 0.23.* is same as 2.0.* - except hardened to run production jobs --> + <!-- <yarn.version>0.23.7</yarn.version> --> + <yarn.version>2.0.2-alpha</yarn.version> + </properties> + + <repositories> + <repository> + <id>maven-root</id> + <name>Maven root repository</name> + <url>http://repo1.maven.org/maven2/</url> + <releases> + <enabled>true</enabled> + </releases> + <snapshots> + <enabled>false</enabled> + </snapshots> + </repository> + </repositories> + + <dependencyManagement> + <dependencies> + <!-- TODO: check versions, bringover from yarn branch ! --> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-client</artifactId> + <version>${yarn.version}</version> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-api</artifactId> + <version>${yarn.version}</version> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-common</artifactId> + <version>${yarn.version}</version> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-client</artifactId> + <version>${yarn.version}</version> + </dependency> + <!-- Specify Avro version because Kafka also has it as a dependency --> + <dependency> + <groupId>org.apache.avro</groupId> + <artifactId>avro</artifactId> + <version>1.7.1.cloudera.2</version> + </dependency> + <dependency> + <groupId>org.apache.avro</groupId> + <artifactId>avro-ipc</artifactId> + <version>1.7.1.cloudera.2</version> + </dependency> + </dependencies> + </dependencyManagement> + </profile> </profiles> </project> diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f0b371b2cf682c6e4b15f97053ed0646f396c38d..b3f410bfa63fcb52f742357d65e74ffffa211ad4 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -1,3 +1,4 @@ + import sbt._ import sbt.Classpaths.publishTask import Keys._ @@ -10,12 +11,19 @@ import twirl.sbt.TwirlPlugin._ object SparkBuild extends Build { // Hadoop version to build against. For example, "0.20.2", "0.20.205.0", or // "1.0.4" for Apache releases, or "0.20.2-cdh3u5" for Cloudera Hadoop. 
- val HADOOP_VERSION = "1.0.4" - val HADOOP_MAJOR_VERSION = "1" + //val HADOOP_VERSION = "1.0.4" + //val HADOOP_MAJOR_VERSION = "1" // For Hadoop 2 versions such as "2.0.0-mr1-cdh4.1.1", set the HADOOP_MAJOR_VERSION to "2" //val HADOOP_VERSION = "2.0.0-mr1-cdh4.1.1" //val HADOOP_MAJOR_VERSION = "2" + //val HADOOP_YARN = false + + // For Hadoop 2 YARN support + // val HADOOP_VERSION = "0.23.7" + val HADOOP_VERSION = "2.0.2-alpha" + val HADOOP_MAJOR_VERSION = "2" + val HADOOP_YARN = true lazy val root = Project("root", file("."), settings = rootSettings) aggregate(core, repl, examples, bagel, streaming) @@ -40,7 +48,6 @@ object SparkBuild extends Build { scalacOptions := Seq("-unchecked", "-optimize", "-deprecation"), unmanagedJars in Compile <<= baseDirectory map { base => (base / "lib" ** "*.jar").classpath }, retrieveManaged := true, - retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", transitiveClassifiers in Scope.GlobalScope := Seq("sources"), testListeners <<= target.map(t => Seq(new eu.henkelmann.sbt.JUnitXmlTestsListener(t.getAbsolutePath))), @@ -99,6 +106,7 @@ object SparkBuild extends Build { */ libraryDependencies ++= Seq( + "io.netty" % "netty" % "3.5.3.Final", "org.eclipse.jetty" % "jetty-server" % "7.6.8.v20121106", "org.scalatest" %% "scalatest" % "1.9.1" % "test", "org.scalacheck" %% "scalacheck" % "1.10.0" % "test", @@ -130,12 +138,13 @@ object SparkBuild extends Build { ), libraryDependencies ++= Seq( + "io.netty" % "netty" % "3.5.3.Final", "com.google.guava" % "guava" % "11.0.1", "log4j" % "log4j" % "1.2.16", "org.slf4j" % "slf4j-api" % slf4jVersion, "org.slf4j" % "slf4j-log4j12" % slf4jVersion, + "commons-daemon" % "commons-daemon" % "1.0.10", "com.ning" % "compress-lzf" % "0.8.4", - "org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION excludeAll( ExclusionRule(organization = "org.codehaus.jackson") ), "asm" % "asm-all" % "3.3.1", "com.google.protobuf" % "protobuf-java" % "2.4.1", "de.javakaffee" % "kryo-serializers" % "0.22", @@ -148,8 +157,32 @@ object SparkBuild extends Build { "cc.spray" % "spray-server" % "1.0-M2.1", "cc.spray" % "spray-json_2.9.2" % "1.1.1", "org.apache.mesos" % "mesos" % "0.9.0-incubating" - ) ++ (if (HADOOP_MAJOR_VERSION == "2") Some("org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION) else None).toSeq, - unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / ("src/hadoop" + HADOOP_MAJOR_VERSION + "/scala") } + ) ++ ( + if (HADOOP_MAJOR_VERSION == "2") { + if (HADOOP_YARN) { + Seq( + // Exclude rule required for all ? 
+ "org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION excludeAll( ExclusionRule(organization = "org.codehaus.jackson") ), + "org.apache.hadoop" % "hadoop-yarn-api" % HADOOP_VERSION excludeAll( ExclusionRule(organization = "org.codehaus.jackson") ), + "org.apache.hadoop" % "hadoop-yarn-common" % HADOOP_VERSION excludeAll( ExclusionRule(organization = "org.codehaus.jackson") ), + "org.apache.hadoop" % "hadoop-yarn-client" % HADOOP_VERSION excludeAll( ExclusionRule(organization = "org.codehaus.jackson") ) + ) + } else { + Seq( + "org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION excludeAll( ExclusionRule(organization = "org.codehaus.jackson") ), + "org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION excludeAll( ExclusionRule(organization = "org.codehaus.jackson") ) + ) + } + } else { + Seq("org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION excludeAll( ExclusionRule(organization = "org.codehaus.jackson") ) ) + }), + unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / + ( if (HADOOP_YARN && HADOOP_MAJOR_VERSION == "2") { + "src/hadoop2-yarn/scala" + } else { + "src/hadoop" + HADOOP_MAJOR_VERSION + "/scala" + } ) + } ) ++ assemblySettings ++ extraAssemblySettings ++ Twirl.settings def rootSettings = sharedSettings ++ Seq( @@ -181,6 +214,7 @@ object SparkBuild extends Build { def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq( mergeStrategy in assembly := { case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard + case m if m.toLowerCase.matches("meta-inf/.*\\.sf$") => MergeStrategy.discard case "reference.conf" => MergeStrategy.concat case _ => MergeStrategy.first } diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml index fe526a7616c7c8a8c22f45d41e7ba1cb1e6ca477..7a7280313edb7ab9f1bc5bb218198cf78305e799 100644 --- a/repl-bin/pom.xml +++ b/repl-bin/pom.xml @@ -153,6 +153,61 @@ </dependency> </dependencies> </profile> + <profile> + <id>hadoop2-yarn</id> + <properties> + <classifier>hadoop2-yarn</classifier> + </properties> + <dependencies> + <dependency> + <groupId>org.spark-project</groupId> + <artifactId>spark-core</artifactId> + <version>${project.version}</version> + <classifier>hadoop2-yarn</classifier> + </dependency> + <dependency> + <groupId>org.spark-project</groupId> + <artifactId>spark-bagel</artifactId> + <version>${project.version}</version> + <classifier>hadoop2-yarn</classifier> + <scope>runtime</scope> + </dependency> + <dependency> + <groupId>org.spark-project</groupId> + <artifactId>spark-examples</artifactId> + <version>${project.version}</version> + <classifier>hadoop2-yarn</classifier> + <scope>runtime</scope> + </dependency> + <dependency> + <groupId>org.spark-project</groupId> + <artifactId>spark-repl</artifactId> + <version>${project.version}</version> + <classifier>hadoop2-yarn</classifier> + <scope>runtime</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-client</artifactId> + <scope>runtime</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-api</artifactId> + <scope>runtime</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-common</artifactId> + <scope>runtime</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-client</artifactId> + <scope>runtime</scope> + </dependency> + </dependencies> + </profile> <profile> <id>deb</id> <build> diff --git a/repl/pom.xml b/repl/pom.xml index 
0b5e400c3d14656f9ae1934d81d33594d31fc43b..038da5d9881c6fef98fc7cbe76137b66b0031123 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -187,5 +187,76 @@ </plugins> </build> </profile> + <profile> + <id>hadoop2-yarn</id> + <properties> + <classifier>hadoop2-yarn</classifier> + </properties> + <dependencies> + <dependency> + <groupId>org.spark-project</groupId> + <artifactId>spark-core</artifactId> + <version>${project.version}</version> + <classifier>hadoop2-yarn</classifier> + </dependency> + <dependency> + <groupId>org.spark-project</groupId> + <artifactId>spark-bagel</artifactId> + <version>${project.version}</version> + <classifier>hadoop2-yarn</classifier> + <scope>runtime</scope> + </dependency> + <dependency> + <groupId>org.spark-project</groupId> + <artifactId>spark-examples</artifactId> + <version>${project.version}</version> + <classifier>hadoop2-yarn</classifier> + <scope>runtime</scope> + </dependency> + <dependency> + <groupId>org.spark-project</groupId> + <artifactId>spark-streaming</artifactId> + <version>${project.version}</version> + <classifier>hadoop2-yarn</classifier> + <scope>runtime</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-client</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-api</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-common</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.avro</groupId> + <artifactId>avro</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.avro</groupId> + <artifactId>avro-ipc</artifactId> + <scope>provided</scope> + </dependency> + </dependencies> + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <configuration> + <classifier>hadoop2-yarn</classifier> + </configuration> + </plugin> + </plugins> + </build> + </profile> </profiles> </project> diff --git a/streaming/pom.xml b/streaming/pom.xml index b0d0cd0ff35ee2e0c5bd63b03f0083b31ceff089..08ff3e2ae12f48480d8d0526e6131254d127031e 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -149,5 +149,42 @@ </plugins> </build> </profile> + <profile> + <id>hadoop2-yarn</id> + <dependencies> + <dependency> + <groupId>org.spark-project</groupId> + <artifactId>spark-core</artifactId> + <version>${project.version}</version> + <classifier>hadoop2-yarn</classifier> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-client</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-api</artifactId> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-common</artifactId> + <scope>provided</scope> + </dependency> + </dependencies> + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <configuration> + <classifier>hadoop2-yarn</classifier> + </configuration> + </plugin> + </plugins> + </build> + </profile> </profiles> </project>