diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
index aeb3f0062df3beab53179aecea90ada162d29ecf..4b5e0efdde92d927571a5f2c62a592d45628f818 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
@@ -23,6 +23,7 @@ import java.nio.ByteBuffer
 
 import scala.collection.JavaConversions._
 import scala.collection.mutable.{HashMap, ListBuffer, Map}
+import scala.util.{Try, Success, Failure}
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs._
@@ -378,7 +379,7 @@ trait ClientBase extends Logging {
   }
 }
 
-object ClientBase {
+object ClientBase extends Logging {
   val SPARK_JAR: String = "__spark__.jar"
   val APP_JAR: String = "__app__.jar"
   val LOG4J_PROP: String = "log4j.properties"
@@ -388,37 +389,47 @@ object ClientBase {
 
   def getSparkJar = sys.env.get("SPARK_JAR").getOrElse(SparkContext.jarOfClass(this.getClass).head)
 
-  // Based on code from org.apache.hadoop.mapreduce.v2.util.MRApps
-  def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) {
-    val classpathEntries = Option(conf.getStrings(
-      YarnConfiguration.YARN_APPLICATION_CLASSPATH)).getOrElse(
-        getDefaultYarnApplicationClasspath())
-    if (classpathEntries != null) {
-      for (c <- classpathEntries) {
-        YarnSparkHadoopUtil.addToEnvironment(env, Environment.CLASSPATH.name, c.trim,
-          File.pathSeparator)
-      }
+  def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) = {
+    val classPathElementsToAdd = getYarnAppClasspath(conf) ++ getMRAppClasspath(conf)
+    for (c <- classPathElementsToAdd.flatten) {
+      YarnSparkHadoopUtil.addToEnvironment(
+        env,
+        Environment.CLASSPATH.name,
+        c.trim,
+        File.pathSeparator)
     }
+    classPathElementsToAdd
+  }
 
-    val mrClasspathEntries = Option(conf.getStrings(
-      "mapreduce.application.classpath")).getOrElse(
-        getDefaultMRApplicationClasspath())
-    if (mrClasspathEntries != null) {
-      for (c <- mrClasspathEntries) {
-        YarnSparkHadoopUtil.addToEnvironment(env, Environment.CLASSPATH.name, c.trim,
-          File.pathSeparator)
-      }
-    }
+  private def getYarnAppClasspath(conf: Configuration): Option[Seq[String]] =
+    Option(conf.getStrings(YarnConfiguration.YARN_APPLICATION_CLASSPATH)) match {
+      case Some(s) => Some(s.toSeq)
+      case None => getDefaultYarnApplicationClasspath
   }
 
-  def getDefaultYarnApplicationClasspath(): Array[String] = {
-    try {
-      val field = classOf[MRJobConfig].getField("DEFAULT_YARN_APPLICATION_CLASSPATH")
-      field.get(null).asInstanceOf[Array[String]]
-    } catch {
-      case err: NoSuchFieldError => null
-      case err: NoSuchFieldException => null
+  private def getMRAppClasspath(conf: Configuration): Option[Seq[String]] =
+    Option(conf.getStrings("mapreduce.application.classpath")) match {
+      case Some(s) => Some(s.toSeq)
+      case None => getDefaultMRApplicationClasspath
+  }
+
+  def getDefaultYarnApplicationClasspath: Option[Seq[String]] = {
+    val triedDefault = Try[Seq[String]] {
+      val field = classOf[YarnConfiguration].getField("DEFAULT_YARN_APPLICATION_CLASSPATH")
+      val value = field.get(null).asInstanceOf[Array[String]]
+      value.toSeq
+    } recoverWith {
+      case e: NoSuchFieldException => Success(Seq.empty[String])
     }
+
+    triedDefault match {
+      case f: Failure[_] =>
+        logError("Unable to obtain the default YARN Application classpath.", f.exception)
+      case s: Success[_] =>
+        logDebug(s"Using the default YARN application classpath: ${s.get.mkString(",")}")
+    }
+
+    triedDefault.toOption
   }
 
   /**
@@ -426,20 +437,30 @@ object ClientBase {
    * classpath. In Hadoop 2.0, it's an array of Strings, and in 2.2+ it's a String.
    * So we need to use reflection to retrieve it.
    */
-  def getDefaultMRApplicationClasspath(): Array[String] = {
-    try {
+  def getDefaultMRApplicationClasspath: Option[Seq[String]] = {
+    val triedDefault = Try[Seq[String]] {
       val field = classOf[MRJobConfig].getField("DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH")
-      if (field.getType == classOf[String]) {
-        StringUtils.getStrings(field.get(null).asInstanceOf[String])
+      val value = if (field.getType == classOf[String]) {
+        StringUtils.getStrings(field.get(null).asInstanceOf[String]).toArray
       } else {
         field.get(null).asInstanceOf[Array[String]]
       }
-    } catch {
-      case err: NoSuchFieldError => null
-      case err: NoSuchFieldException => null
+      value.toSeq
+    } recoverWith {
+      case e: NoSuchFieldException => Success(Seq.empty[String])
     }
+
+    triedDefault match {
+      case f: Failure[_] =>
+        logError("Unable to obtain the default MR Application classpath.", f.exception)
+      case s: Success[_] =>
+        logDebug(s"Using the default MR application classpath: ${s.get.mkString(",")}")
+    }
+
+    triedDefault.toOption
   }
 
+
   /**
    * Returns the java command line argument for setting up log4j. If there is a log4j.properties
    * in the given local resources, it is used, otherwise the SPARK_LOG4J_CONF environment variable
diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..608c6e92624c6d99ae92d479475691b8ff5e95df
--- /dev/null
+++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.URI
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.mapreduce.MRJobConfig
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers._
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.{ HashMap => MutableHashMap }
+import scala.util.Try
+
+
+class ClientBaseSuite extends FunSuite {
+
+  test("default Yarn application classpath") {
+    ClientBase.getDefaultYarnApplicationClasspath should be(Some(Fixtures.knownDefYarnAppCP))
+  }
+
+  test("default MR application classpath") {
+    ClientBase.getDefaultMRApplicationClasspath should be(Some(Fixtures.knownDefMRAppCP))
+  }
+
+  test("resultant classpath for an application that defines a classpath for YARN") {
+    withAppConf(Fixtures.mapYARNAppConf) { conf =>
+      val env = newEnv
+      ClientBase.populateHadoopClasspath(conf, env)
+      classpath(env) should be(
+        flatten(Fixtures.knownYARNAppCP, ClientBase.getDefaultMRApplicationClasspath))
+    }
+  }
+
+  test("resultant classpath for an application that defines a classpath for MR") {
+    withAppConf(Fixtures.mapMRAppConf) { conf =>
+      val env = newEnv
+      ClientBase.populateHadoopClasspath(conf, env)
+      classpath(env) should be(
+        flatten(ClientBase.getDefaultYarnApplicationClasspath, Fixtures.knownMRAppCP))
+    }
+  }
+
+  test("resultant classpath for an application that defines both classpaths, YARN and MR") {
+    withAppConf(Fixtures.mapAppConf) { conf =>
+      val env = newEnv
+      ClientBase.populateHadoopClasspath(conf, env)
+      classpath(env) should be(flatten(Fixtures.knownYARNAppCP, Fixtures.knownMRAppCP))
+    }
+  }
+
+  object Fixtures {
+
+    val knownDefYarnAppCP: Seq[String] =
+      getFieldValue[Array[String], Seq[String]](classOf[YarnConfiguration],
+                                                "DEFAULT_YARN_APPLICATION_CLASSPATH",
+                                                Seq[String]())(a => a.toSeq)
+
+
+    val knownDefMRAppCP: Seq[String] =
+      getFieldValue[String, Seq[String]](classOf[MRJobConfig],
+                                         "DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH",
+                                         Seq[String]())(a => a.split(","))
+
+    val knownYARNAppCP = Some(Seq("/known/yarn/path"))
+
+    val knownMRAppCP = Some(Seq("/known/mr/path"))
+
+    val mapMRAppConf =
+      Map("mapreduce.application.classpath" -> knownMRAppCP.map(_.mkString(":")).get)
+
+    val mapYARNAppConf =
+      Map(YarnConfiguration.YARN_APPLICATION_CLASSPATH -> knownYARNAppCP.map(_.mkString(":")).get)
+
+    val mapAppConf = mapYARNAppConf ++ mapMRAppConf
+  }
+
+  def withAppConf(m: Map[String, String] = Map())(testCode: (Configuration) => Any) {
+    val conf = new Configuration
+    m.foreach { case (k, v) => conf.set(k, v, "ClientBaseSpec") }
+    testCode(conf)
+  }
+
+  def newEnv = MutableHashMap[String, String]()
+
+  def classpath(env: MutableHashMap[String, String]) = env(Environment.CLASSPATH.name).split(":|;")
+
+  def flatten(a: Option[Seq[String]], b: Option[Seq[String]]) = (a ++ b).flatten.toArray
+
+  def getFieldValue[A, B](clazz: Class[_], field: String, defaults: => B)(mapTo: A => B): B =
+    Try(clazz.getField(field)).map(_.get(null).asInstanceOf[A]).toOption.map(mapTo).getOrElse(defaults)
+
+}
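
Note on the approach (illustration only, not part of the patch): the Try/recoverWith
pattern shared by getDefaultYarnApplicationClasspath and getDefaultMRApplicationClasspath
can be exercised in isolation. A minimal sketch follows, assuming hadoop-yarn-api is on
the classpath; the object and method names (ReflectiveDefaultSketch, defaultClasspath)
are hypothetical and exist only for this example.

import org.apache.hadoop.yarn.conf.YarnConfiguration

import scala.util.{Try, Success}

object ReflectiveDefaultSketch {

  // Reflectively read a public static String[] field. A missing field is an
  // expected condition (the constant moved between Hadoop versions), so it is
  // recovered to an empty Seq; any other reflection failure remains a Failure
  // and surfaces to the caller as None via toOption.
  def defaultClasspath(clazz: Class[_], fieldName: String): Option[Seq[String]] = {
    val tried = Try[Seq[String]] {
      clazz.getField(fieldName).get(null).asInstanceOf[Array[String]].toSeq
    } recoverWith {
      case _: NoSuchFieldException => Success(Seq.empty[String])
    }
    tried.toOption
  }

  def main(args: Array[String]): Unit = {
    // Field present in Hadoop 2.x: prints Some(List(...)) with the default entries.
    println(defaultClasspath(classOf[YarnConfiguration], "DEFAULT_YARN_APPLICATION_CLASSPATH"))
    // Field absent: recovered to Some(List()) rather than the null the old code returned.
    println(defaultClasspath(classOf[YarnConfiguration], "NO_SUCH_FIELD"))
  }
}

Callers such as populateHadoopClasspath can then fold the Option away with flatten,
which is what removes the null checks (and the NPE risk) from the old implementation.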