From 2881a2d1d1a650a91df2c6a01275eba14a43b42a Mon Sep 17 00:00:00 2001 From: Hossein <hossein@databricks.com> Date: Sun, 30 Oct 2016 16:17:23 -0700 Subject: [PATCH] [SPARK-17919] Make timeout to RBackend configurable in SparkR ## What changes were proposed in this pull request? This patch makes RBackend connection timeout configurable by user. ## How was this patch tested? N/A Author: Hossein <hossein@databricks.com> Closes #15471 from falaki/SPARK-17919. --- R/pkg/R/backend.R | 20 ++++++++-- R/pkg/R/client.R | 2 +- R/pkg/R/sparkR.R | 8 +++- R/pkg/inst/worker/daemon.R | 4 +- R/pkg/inst/worker/worker.R | 7 +++- .../org/apache/spark/api/r/RBackend.scala | 15 ++++++- .../apache/spark/api/r/RBackendHandler.scala | 39 +++++++++++++++++-- .../org/apache/spark/api/r/RRunner.scala | 3 ++ .../apache/spark/api/r/SparkRDefaults.scala | 30 ++++++++++++++ .../org/apache/spark/deploy/RRunner.scala | 7 +++- docs/configuration.md | 15 +++++++ 11 files changed, 134 insertions(+), 16 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R index 03e70bb2cb..0a789e6c37 100644 --- a/R/pkg/R/backend.R +++ b/R/pkg/R/backend.R @@ -108,13 +108,27 @@ invokeJava <- function(isStatic, objId, methodName, ...) { conn <- get(".sparkRCon", .sparkREnv) writeBin(requestMessage, conn) - # TODO: check the status code to output error information returnStatus <- readInt(conn) + handleErrors(returnStatus, conn) + + # Backend will send +1 as keep alive value to prevent various connection timeouts + # on very long running jobs. See spark.r.heartBeatInterval + while (returnStatus == 1) { + returnStatus <- readInt(conn) + handleErrors(returnStatus, conn) + } + + readObject(conn) +} + +# Helper function to check for returned errors and print appropriate error message to user +handleErrors <- function(returnStatus, conn) { if (length(returnStatus) == 0) { stop("No status is returned. Java SparkR backend might have failed.") } - if (returnStatus != 0) { + + # 0 is success and +1 is reserved for heartbeats. Other negative values indicate errors. + if (returnStatus < 0) { stop(readString(conn)) } - readObject(conn) } diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 2d341d836c..9d82814211 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -19,7 +19,7 @@ # Creates a SparkR client connection object # if one doesn't already exist -connectBackend <- function(hostname, port, timeout = 6000) { +connectBackend <- function(hostname, port, timeout) { if (exists(".sparkRcon", envir = .sparkREnv)) { if (isOpen(.sparkREnv[[".sparkRCon"]])) { cat("SparkRBackend client connection already exists\n") diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index cc6d591bb2..6b4a2f2fdc 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -154,6 +154,7 @@ sparkR.sparkContext <- function( packages <- processSparkPackages(sparkPackages) existingPort <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "") + connectionTimeout <- as.numeric(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000")) if (existingPort != "") { if (length(packages) != 0) { warning(paste("sparkPackages has no effect when using spark-submit or sparkR shell", @@ -187,6 +188,7 @@ sparkR.sparkContext <- function( backendPort <- readInt(f) monitorPort <- readInt(f) rLibPath <- readString(f) + connectionTimeout <- readInt(f) close(f) file.remove(path) if (length(backendPort) == 0 || backendPort == 0 || @@ -194,7 +196,9 @@ sparkR.sparkContext <- function( length(rLibPath) != 1) { stop("JVM failed to launch") } - assign(".monitorConn", socketConnection(port = monitorPort), envir = .sparkREnv) + assign(".monitorConn", + socketConnection(port = monitorPort, timeout = connectionTimeout), + envir = .sparkREnv) assign(".backendLaunched", 1, envir = .sparkREnv) if (rLibPath != "") { assign(".libPath", rLibPath, envir = .sparkREnv) @@ -204,7 +208,7 @@ sparkR.sparkContext <- function( .sparkREnv$backendPort <- backendPort tryCatch({ - connectBackend("localhost", backendPort) + connectBackend("localhost", backendPort, timeout = connectionTimeout) }, error = function(err) { stop("Failed to connect JVM\n") diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index b92e6be995..3a318b71ea 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -18,6 +18,7 @@ # Worker daemon rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +connectionTimeout <- as.integer(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000")) dirs <- strsplit(rLibDir, ",")[[1]] script <- file.path(dirs[[1]], "SparkR", "worker", "worker.R") @@ -26,7 +27,8 @@ script <- file.path(dirs[[1]], "SparkR", "worker", "worker.R") suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) -inputCon <- socketConnection(port = port, open = "rb", blocking = TRUE, timeout = 3600) +inputCon <- socketConnection( + port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout) while (TRUE) { ready <- socketSelect(list(inputCon)) diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index cfe41ded20..03e7450147 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -90,6 +90,7 @@ bootTime <- currentTimeSecs() bootElap <- elapsedSecs() rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +connectionTimeout <- as.integer(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000")) dirs <- strsplit(rLibDir, ",")[[1]] # Set libPaths to include SparkR package as loadNamespace needs this # TODO: Figure out if we can avoid this by not loading any objects that require @@ -98,8 +99,10 @@ dirs <- strsplit(rLibDir, ",")[[1]] suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) -inputCon <- socketConnection(port = port, blocking = TRUE, open = "rb") -outputCon <- socketConnection(port = port, blocking = TRUE, open = "wb") +inputCon <- socketConnection( + port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout) +outputCon <- socketConnection( + port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout) # read the index of the current partition inside the RDD partition <- SparkR:::readInt(inputCon) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 41d0a85ee3..550746c552 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -22,12 +22,13 @@ import java.net.{InetAddress, InetSocketAddress, ServerSocket} import java.util.concurrent.TimeUnit import io.netty.bootstrap.ServerBootstrap -import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup} +import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption, EventLoopGroup} import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.handler.codec.LengthFieldBasedFrameDecoder import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder} +import io.netty.handler.timeout.ReadTimeoutHandler import org.apache.spark.SparkConf import org.apache.spark.internal.Logging @@ -43,7 +44,10 @@ private[spark] class RBackend { def init(): Int = { val conf = new SparkConf() - bossGroup = new NioEventLoopGroup(conf.getInt("spark.r.numRBackendThreads", 2)) + val backendConnectionTimeout = conf.getInt( + "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) + bossGroup = new NioEventLoopGroup( + conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS)) val workerGroup = bossGroup val handler = new RBackendHandler(this) @@ -63,6 +67,7 @@ private[spark] class RBackend { // initialBytesToStrip = 4, i.e. strip out the length field itself new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) .addLast("decoder", new ByteArrayDecoder()) + .addLast("readTimeoutHandler", new ReadTimeoutHandler(backendConnectionTimeout)) .addLast("handler", handler) } }) @@ -110,6 +115,11 @@ private[spark] object RBackend extends Logging { val boundPort = sparkRBackend.init() val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) val listenPort = serverSocket.getLocalPort() + // Connection timeout is set by socket client. To make it configurable we will pass the + // timeout value to client inside the temp file + val conf = new SparkConf() + val backendConnectionTimeout = conf.getInt( + "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) // tell the R process via temporary file val path = args(0) @@ -118,6 +128,7 @@ private[spark] object RBackend extends Logging { dos.writeInt(boundPort) dos.writeInt(listenPort) SerDe.writeString(dos, RUtils.rPackages.getOrElse("")) + dos.writeInt(backendConnectionTimeout) dos.close() f.renameTo(new File(path)) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 1422ef888f..9f5afa29d6 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -18,16 +18,19 @@ package org.apache.spark.api.r import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.util.concurrent.TimeUnit import scala.collection.mutable.HashMap import scala.language.existentials import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import io.netty.channel.ChannelHandler.Sharable +import io.netty.handler.timeout.ReadTimeoutException import org.apache.spark.api.r.SerDe._ import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils +import org.apache.spark.SparkConf +import org.apache.spark.util.{ThreadUtils, Utils} /** * Handler for RBackend @@ -83,7 +86,29 @@ private[r] class RBackendHandler(server: RBackend) writeString(dos, s"Error: unknown method $methodName") } } else { + // To avoid timeouts when reading results in SparkR driver, we will be regularly sending + // heartbeat responses. We use special code +1 to signal the client that backend is + // alive and it should continue blocking for result. + val execService = ThreadUtils.newDaemonSingleThreadScheduledExecutor("SparkRKeepAliveThread") + val pingRunner = new Runnable { + override def run(): Unit = { + val pingBaos = new ByteArrayOutputStream() + val pingDaos = new DataOutputStream(pingBaos) + writeInt(pingDaos, +1) + ctx.write(pingBaos.toByteArray) + } + } + val conf = new SparkConf() + val heartBeatInterval = conf.getInt( + "spark.r.heartBeatInterval", SparkRDefaults.DEFAULT_HEARTBEAT_INTERVAL) + val backendConnectionTimeout = conf.getInt( + "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) + val interval = Math.min(heartBeatInterval, backendConnectionTimeout - 1) + + execService.scheduleAtFixedRate(pingRunner, interval, interval, TimeUnit.SECONDS) handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) + execService.shutdown() + execService.awaitTermination(1, TimeUnit.SECONDS) } val reply = bos.toByteArray @@ -95,9 +120,15 @@ private[r] class RBackendHandler(server: RBackend) } override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - // Close the connection when an exception is raised. - cause.printStackTrace() - ctx.close() + cause match { + case timeout: ReadTimeoutException => + // Do nothing. We don't want to timeout on read + logWarning("Ignoring read timeout in RBackendHandler") + case _ => + // Close the connection when an exception is raised. + cause.printStackTrace() + ctx.close() + } } def handleMethodCall( diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index 496fdf851f..7ef64723d9 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -333,6 +333,8 @@ private[r] object RRunner { var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript") rCommand = sparkConf.get("spark.r.command", rCommand) + val rConnectionTimeout = sparkConf.getInt( + "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) val rOptions = "--vanilla" val rLibDir = RUtils.sparkRPackagePath(isDriver = false) val rExecScript = rLibDir(0) + "/SparkR/worker/" + script @@ -344,6 +346,7 @@ private[r] object RRunner { pb.environment().put("R_TESTS", "") pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(",")) pb.environment().put("SPARKR_WORKER_PORT", port.toString) + pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString) pb.redirectErrorStream(true) // redirect stderr into stdout val proc = pb.start() val errThread = startStdoutThread(proc) diff --git a/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala b/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala new file mode 100644 index 0000000000..af67cbbce4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +private[spark] object SparkRDefaults { + + // Default value for spark.r.backendConnectionTimeout config + val DEFAULT_CONNECTION_TIMEOUT: Int = 6000 + + // Default value for spark.r.heartBeatInterval config + val DEFAULT_HEARTBEAT_INTERVAL: Int = 100 + + // Default value for spark.r.numRBackendThreads config + val DEFAULT_NUM_RBACKEND_THREADS = 2 +} diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index d0466830b2..6eb53a8252 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.spark.{SparkException, SparkUserAppException} -import org.apache.spark.api.r.{RBackend, RUtils} +import org.apache.spark.api.r.{RBackend, RUtils, SparkRDefaults} import org.apache.spark.util.RedirectThread /** @@ -51,6 +51,10 @@ object RRunner { cmd } + // Connection timeout set by R process on its connection to RBackend in seconds. + val backendConnectionTimeout = sys.props.getOrElse( + "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT.toString) + // Check if the file path exists. // If not, change directory to current working directory for YARN cluster mode val rF = new File(rFile) @@ -81,6 +85,7 @@ object RRunner { val builder = new ProcessBuilder((Seq(rCommand, rFileNormalized) ++ otherArgs).asJava) val env = builder.environment() env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString) + env.put("SPARKR_BACKEND_CONNECTION_TIMEOUT", backendConnectionTimeout) val rPackageDir = RUtils.sparkRPackagePath(isDriver = true) // Put the R package directories into an env variable of comma-separated paths env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(",")) diff --git a/docs/configuration.md b/docs/configuration.md index 6600cb6c0a..780fc94908 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1890,6 +1890,21 @@ showDF(properties, numRows = 200, truncate = FALSE) <code>spark.r.shell.command</code> is used for sparkR shell while <code>spark.r.driver.command</code> is used for running R script. </td> </tr> +<tr> + <td><code>spark.r.backendConnectionTimeout</code></td> + <td>6000</td> + <td> + Connection timeout set by R process on its connection to RBackend in seconds. + </td> +</tr> +<tr> + <td><code>spark.r.heartBeatInterval</code></td> + <td>100</td> + <td> + Interval for heartbeats sents from SparkR backend to R process to prevent connection timeout. + </td> +</tr> + </table> #### Deploy -- GitLab