diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 5f14102c3c36670af045b5349be7cca06519c82f..29163e7f3054627cd5df711a7f0225ca46ae3fe3 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -34,6 +34,8 @@ import org.apache.spark.internal.Logging * * @param enabled enables or disables SSL; if it is set to false, the rest of the * settings are disregarded + * @param port the port where to bind the SSL server; if not defined, it will be + * based on the non-SSL port for the same service. * @param keyStore a path to the key-store file * @param keyStorePassword a password to access the key-store file * @param keyPassword a password to access the private key in the key-store @@ -47,6 +49,7 @@ import org.apache.spark.internal.Logging */ private[spark] case class SSLOptions( enabled: Boolean = false, + port: Option[Int] = None, keyStore: Option[File] = None, keyStorePassword: Option[String] = None, keyPassword: Option[String] = None, @@ -164,6 +167,11 @@ private[spark] object SSLOptions extends Logging { def parse(conf: SparkConf, ns: String, defaults: Option[SSLOptions] = None): SSLOptions = { val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = defaults.exists(_.enabled)) + val port = conf.getOption(s"$ns.port").map(_.toInt) + port.foreach { p => + require(p >= 0, "Port number must be a non-negative value.") + } + val keyStore = conf.getOption(s"$ns.keyStore").map(new File(_)) .orElse(defaults.flatMap(_.keyStore)) @@ -198,6 +206,7 @@ private[spark] object SSLOptions extends Logging { new SSLOptions( enabled, + port, keyStore, keyStorePassword, keyPassword, diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index f713619cd7ec400ea1af622082ca05ac13c716b6..7909821db954b4f0808509c1f9f6d3e2cedb9111 100644 --- 
a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -27,7 +27,7 @@ import scala.xml.Node import org.eclipse.jetty.client.api.Response import org.eclipse.jetty.proxy.ProxyServlet -import org.eclipse.jetty.server.{HttpConnectionFactory, Request, Server, ServerConnector} +import org.eclipse.jetty.server._ import org.eclipse.jetty.server.handler._ import org.eclipse.jetty.servlet._ import org.eclipse.jetty.servlets.gzip.GzipHandler @@ -279,109 +279,125 @@ private[spark] object JettyUtils extends Logging { addFilters(handlers, conf) - val gzipHandlers = handlers.map { h => - h.setVirtualHosts(Array("@" + SPARK_CONNECTOR_NAME)) - - val gzipHandler = new GzipHandler - gzipHandler.setHandler(h) - gzipHandler + // Start the server first, with no connectors. + val pool = new QueuedThreadPool + if (serverName.nonEmpty) { + pool.setName(serverName) } + pool.setDaemon(true) - // Bind to the given port, or throw a java.net.BindException if the port is occupied - def connect(currentPort: Int): ((Server, Option[Int]), Int) = { - val pool = new QueuedThreadPool - if (serverName.nonEmpty) { - pool.setName(serverName) - } - pool.setDaemon(true) - - val server = new Server(pool) - val connectors = new ArrayBuffer[ServerConnector]() - val collection = new ContextHandlerCollection - - // Create a connector on port currentPort to listen for HTTP requests - val httpConnector = new ServerConnector( - server, - null, - // Call this full constructor to set this, which forces daemon threads: - new ScheduledExecutorScheduler(s"$serverName-JettyScheduler", true), - null, - -1, - -1, - new HttpConnectionFactory()) - httpConnector.setPort(currentPort) - connectors += httpConnector - - val httpsConnector = sslOptions.createJettySslContextFactory() match { - case Some(factory) => - // If the new port wraps around, do not try a privileged port. 
- val securePort = - if (currentPort != 0) { - (currentPort + 400 - 1024) % (65536 - 1024) + 1024 - } else { - 0 - } - val scheme = "https" - // Create a connector on port securePort to listen for HTTPS requests - val connector = new ServerConnector(server, factory) - connector.setPort(securePort) - connector.setName(SPARK_CONNECTOR_NAME) - connectors += connector - - // redirect the HTTP requests to HTTPS port - httpConnector.setName(REDIRECT_CONNECTOR_NAME) - collection.addHandler(createRedirectHttpsHandler(securePort, scheme)) - Some(connector) + val server = new Server(pool) - case None => - // No SSL, so the HTTP connector becomes the official one where all contexts bind. - httpConnector.setName(SPARK_CONNECTOR_NAME) - None - } + val errorHandler = new ErrorHandler() + errorHandler.setShowStacks(true) + errorHandler.setServer(server) + server.addBean(errorHandler) + + val collection = new ContextHandlerCollection + server.setHandler(collection) + + // Executor used to create daemon threads for the Jetty connectors. + val serverExecutor = new ScheduledExecutorScheduler(s"$serverName-JettyScheduler", true) + + try { + server.start() // As each acceptor and each selector will use one thread, the number of threads should at // least be the number of acceptors and selectors plus 1. 
(See SPARK-13776) var minThreads = 1 - connectors.foreach { connector => + + def newConnector( + connectionFactories: Array[ConnectionFactory], + port: Int): (ServerConnector, Int) = { + val connector = new ServerConnector( + server, + null, + serverExecutor, + null, + -1, + -1, + connectionFactories: _*) + connector.setPort(port) + connector.start() + // Currently we only use "SelectChannelConnector" // Limit the max acceptor number to 8 so that we don't waste a lot of threads connector.setAcceptQueueSize(math.min(connector.getAcceptors, 8)) connector.setHost(hostName) // The number of selectors always equals to the number of acceptors minThreads += connector.getAcceptors * 2 + + (connector, connector.getLocalPort()) } - pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads)) - val errorHandler = new ErrorHandler() - errorHandler.setShowStacks(true) - errorHandler.setServer(server) - server.addBean(errorHandler) - - gzipHandlers.foreach(collection.addHandler) - server.setHandler(collection) - - server.setConnectors(connectors.toArray) - try { - server.start() - ((server, httpsConnector.map(_.getLocalPort())), httpConnector.getLocalPort) - } catch { - case e: Exception => - server.stop() - pool.stop() - throw e + // If SSL is configured, create the secure connector first. 
+ val securePort = sslOptions.createJettySslContextFactory().map { factory => + val securePort = sslOptions.port.getOrElse(if (port > 0) Utils.userPort(port, 400) else 0) + val secureServerName = if (serverName.nonEmpty) s"$serverName (HTTPS)" else serverName + val connectionFactories = AbstractConnectionFactory.getFactories(factory, + new HttpConnectionFactory()) + + def sslConnect(currentPort: Int): (ServerConnector, Int) = { + newConnector(connectionFactories, currentPort) + } + + val (connector, boundPort) = Utils.startServiceOnPort[ServerConnector](securePort, + sslConnect, conf, secureServerName) + connector.setName(SPARK_CONNECTOR_NAME) + server.addConnector(connector) + boundPort } - } - val ((server, securePort), boundPort) = Utils.startServiceOnPort(port, connect, conf, - serverName) - ServerInfo(server, boundPort, securePort, - server.getHandler().asInstanceOf[ContextHandlerCollection]) + // Bind the HTTP port. + def httpConnect(currentPort: Int): (ServerConnector, Int) = { + newConnector(Array(new HttpConnectionFactory()), currentPort) + } + + val (httpConnector, httpPort) = Utils.startServiceOnPort[ServerConnector](port, httpConnect, + conf, serverName) + + // If SSL is configured, then configure redirection in the HTTP connector. + securePort match { + case Some(p) => + httpConnector.setName(REDIRECT_CONNECTOR_NAME) + val redirector = createRedirectHttpsHandler(p, "https") + collection.addHandler(redirector) + redirector.start() + + case None => + httpConnector.setName(SPARK_CONNECTOR_NAME) + } + + server.addConnector(httpConnector) + + // Add all the known handlers now that connectors are configured. 
+ handlers.foreach { h => + h.setVirtualHosts(toVirtualHosts(SPARK_CONNECTOR_NAME)) + val gzipHandler = new GzipHandler() + gzipHandler.setHandler(h) + collection.addHandler(gzipHandler) + gzipHandler.start() + } + + pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads)) + ServerInfo(server, httpPort, securePort, collection) + } catch { + case e: Exception => + server.stop() + if (serverExecutor.isStarted()) { + serverExecutor.stop() + } + if (pool.isStarted()) { + pool.stop() + } + throw e + } } private def createRedirectHttpsHandler(securePort: Int, scheme: String): ContextHandler = { val redirectHandler: ContextHandler = new ContextHandler redirectHandler.setContextPath("/") - redirectHandler.setVirtualHosts(Array("@" + REDIRECT_CONNECTOR_NAME)) + redirectHandler.setVirtualHosts(toVirtualHosts(REDIRECT_CONNECTOR_NAME)) redirectHandler.setHandler(new AbstractHandler { override def handle( target: String, @@ -394,8 +410,7 @@ private[spark] object JettyUtils extends Logging { val httpsURI = createRedirectURI(scheme, baseRequest.getServerName, securePort, baseRequest.getRequestURI, baseRequest.getQueryString) response.setContentLength(0) - response.encodeRedirectURL(httpsURI) - response.sendRedirect(httpsURI) + response.sendRedirect(response.encodeRedirectURL(httpsURI)) baseRequest.setHandled(true) } }) @@ -456,6 +471,8 @@ private[spark] object JettyUtils extends Logging { new URI(scheme, authority, path, query, null).toString } + def toVirtualHosts(connectors: String*): Array[String] = connectors.map("@" + _).toArray + } private[spark] case class ServerInfo( @@ -465,7 +482,7 @@ private[spark] case class ServerInfo( private val rootHandler: ContextHandlerCollection) { def addHandler(handler: ContextHandler): Unit = { - handler.setVirtualHosts(Array("@" + JettyUtils.SPARK_CONNECTOR_NAME)) + handler.setVirtualHosts(JettyUtils.toVirtualHosts(JettyUtils.SPARK_CONNECTOR_NAME)) rootHandler.addHandler(handler) if (!handler.isStarted()) { handler.start() diff --git 
a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2c1d331b9ab18e44c82c1eda4c3e3a9de8a90ffb..c225e1a0cc1bf982e91300e9498d079ba101bedd 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2202,6 +2202,14 @@ private[spark] object Utils extends Logging { } } + /** + * Returns the user port to try when trying to bind a service. Handles wrapping and skipping + * privileged ports. + */ + def userPort(base: Int, offset: Int): Int = { + (base + offset - 1024) % (65536 - 1024) + 1024 + } + /** * Attempt to start a service on the given port, or fail after a number of attempts. * Each subsequent attempt uses 1 + the port used in the previous attempt (unless the port is 0). @@ -2229,8 +2237,7 @@ private[spark] object Utils extends Logging { val tryPort = if (startPort == 0) { startPort } else { - // If the new port wraps around, do not try a privilege port - ((startPort + offset - 1024) % (65536 - 1024)) + 1024 + userPort(startPort, offset) } try { val (service, port) = startService(tryPort) diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 2b8b1805bc83f334dae3f3862353f6815c2782e8..6fc7cea6ee94a78d506e133edf4aef1d4ef75417 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -103,6 +103,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val conf = new SparkConf conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.ui.enabled", "false") + conf.set("spark.ssl.ui.port", "4242") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") conf.set("spark.ssl.ui.keyStorePassword", "12345") @@ -118,6 +119,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val opts = 
SSLOptions.parse(conf, "spark.ssl.ui", defaults = Some(defaultOpts)) assert(opts.enabled === false) + assert(opts.port === Some(4242)) assert(opts.trustStore.isDefined === true) assert(opts.trustStore.get.getName === "truststore") assert(opts.trustStore.get.getAbsolutePath === trustStorePath) diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index aa67f49185e7fef87373f1eac298f55e35621001..f1be0f6de3ce26d5c08b69b6f8c2e1d5f1ba2dae 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -30,6 +30,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.LocalSparkContext._ +import org.apache.spark.util.Utils class UISuite extends SparkFunSuite { @@ -52,13 +53,16 @@ class UISuite extends SparkFunSuite { (conf, new SecurityManager(conf).getSSLOptions("ui")) } - private def sslEnabledConf(): (SparkConf, SSLOptions) = { + private def sslEnabledConf(sslPort: Option[Int] = None): (SparkConf, SSLOptions) = { val keyStoreFilePath = getTestResourcePath("spark.keystore") val conf = new SparkConf() .set("spark.ssl.ui.enabled", "true") .set("spark.ssl.ui.keyStore", keyStoreFilePath) .set("spark.ssl.ui.keyStorePassword", "123456") .set("spark.ssl.ui.keyPassword", "123456") + sslPort.foreach { p => + conf.set("spark.ssl.ui.port", p.toString) + } (conf, new SecurityManager(conf).getSSLOptions("ui")) } @@ -275,6 +279,28 @@ class UISuite extends SparkFunSuite { } } + test("specify both http and https ports separately") { + var socket: ServerSocket = null + var serverInfo: ServerInfo = null + try { + socket = new ServerSocket(0) + + // Make sure the SSL port lies way outside the "http + 400" range used as the default. 
+ val baseSslPort = Utils.userPort(socket.getLocalPort(), 10000) + val (conf, sslOptions) = sslEnabledConf(sslPort = Some(baseSslPort)) + + serverInfo = JettyUtils.startJettyServer("0.0.0.0", socket.getLocalPort() + 1, + sslOptions, Seq[ServletContextHandler](), conf, "server1") + + val notAllowed = Utils.userPort(serverInfo.boundPort, 400) + assert(serverInfo.securePort.isDefined) + assert(serverInfo.securePort.get != Utils.userPort(serverInfo.boundPort, 400)) + } finally { + stopServer(serverInfo) + closeSocket(socket) + } + } + def stopServer(info: ServerInfo): Unit = { if (info != null) info.stop() } diff --git a/docs/configuration.md b/docs/configuration.md index 7c040330db637115c97c103759eedb5d21009243..2eaaa21fe4b16381be74643d6ae205d994c3c4af 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1796,6 +1796,20 @@ Apart from these, the following properties are also available, and may be useful Configuration</a> for details on hierarchical SSL configuration for services. </td> </tr> + <tr> + <td><code>spark.ssl.[namespace].port</code></td> + <td>None</td> + <td> + The port on which the SSL service will listen. + + <br />The port must be defined within a namespace configuration; see + <a href="security.html#ssl-configuration">SSL Configuration</a> for the available + namespaces. + + <br />When not set, the SSL port will be derived from the non-SSL port for the + same service. A value of "0" will make the service bind to an ephemeral port. 
+ </td> + </tr> <tr> <td><code>spark.ssl.enabledAlgorithms</code></td> <td>Empty</td> diff --git a/docs/security.md b/docs/security.md index 67956930fe5d9f582e8bc4bed02006da13835591..42a09a9148d402b4109a9a4e244e13843cfd6ece 100644 --- a/docs/security.md +++ b/docs/security.md @@ -50,7 +50,7 @@ component-specific configuration namespaces used to override the default setting </tr> <tr> <td><code>spark.ssl.fs</code></td> - <td>HTTP file server and broadcast server</td> + <td>File download client (used to download jars and files from HTTPS-enabled servers).</td> </tr> <tr> <td><code>spark.ssl.ui</code></td>