diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index b3bb5f911dbd7e850e40be71b31d4fda67543be5..334a5b10142aa23a259601f608fc45f6c5146906 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -38,7 +38,7 @@ class WorkerWebUI(
   extends WebUI(worker.securityMgr, requestedPort, worker.conf, name = "WorkerUI")
   with Logging {
 
-  private[ui] val timeout = RpcUtils.askTimeout(worker.conf)
+  private[ui] val timeout = RpcUtils.askRpcTimeout(worker.conf)
 
   initialize()
 
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
index 69181edb9ad44e1fdcf4ba0dc36366328af3b8b3..6ae47894598be862c14cc552370a743d4759d355 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
@@ -17,8 +17,7 @@
 
 package org.apache.spark.rpc
 
-import scala.concurrent.{Await, Future}
-import scala.concurrent.duration.FiniteDuration
+import scala.concurrent.Future
 import scala.reflect.ClassTag
 
 import org.apache.spark.util.RpcUtils
@@ -32,7 +31,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
 
   private[this] val maxRetries = RpcUtils.numRetries(conf)
   private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf)
-  private[this] val defaultAskTimeout = RpcUtils.askTimeout(conf)
+  private[this] val defaultAskTimeout = RpcUtils.askRpcTimeout(conf)
 
   /**
    * return the address for the [[RpcEndpointRef]]
@@ -52,7 +51,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
    *
    * This method only sends the message once and never retries.
    */
-  def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T]
+  def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T]
 
   /**
    * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a [[Future]] to
@@ -91,7 +90,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
    * @tparam T type of the reply message
    * @return the reply message from the corresponding [[RpcEndpoint]]
    */
-  def askWithRetry[T: ClassTag](message: Any, timeout: FiniteDuration): T = {
+  def askWithRetry[T: ClassTag](message: Any, timeout: RpcTimeout): T = {
     // TODO: Consider removing multiple attempts
     var attempts = 0
     var lastException: Exception = null
@@ -99,7 +98,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
       attempts += 1
       try {
         val future = ask[T](message, timeout)
-        val result = Await.result(future, timeout)
+        val result = timeout.awaitResult(future)
         if (result == null) {
           throw new SparkException("Actor returned null")
         }
@@ -110,10 +109,14 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf)
           lastException = e
           logWarning(s"Error sending message [message = $message] in $attempts attempts", e)
       }
-      Thread.sleep(retryWaitMs)
+
+      if (attempts < maxRetries) {
+        Thread.sleep(retryWaitMs)
+      }
     }
 
     throw new SparkException(
       s"Error sending message [message = $message]", lastException)
   }
+
 }
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 3b6938ec639c3f91eaed405d7fc3000e6867d590..1709bdf560b6f737b0adbce9d87b5d8868dec14d 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -18,8 +18,10 @@
 package org.apache.spark.rpc
 
 import java.net.URI
+import java.util.concurrent.TimeoutException
 
-import scala.concurrent.{Await, Future}
+import scala.concurrent.{Awaitable, Await, Future}
+import scala.concurrent.duration._
 import scala.language.postfixOps
 
 import org.apache.spark.{SecurityManager, SparkConf}
@@ -66,7 +68,7 @@ private[spark] object RpcEnv {
  */
 private[spark] abstract class RpcEnv(conf: SparkConf) {
 
-  private[spark] val defaultLookupTimeout = RpcUtils.lookupTimeout(conf)
+  private[spark] val defaultLookupTimeout = RpcUtils.lookupRpcTimeout(conf)
 
   /**
    * Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement
@@ -94,7 +96,7 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
    * Retrieve the [[RpcEndpointRef]] represented by `uri`. This is a blocking action.
    */
   def setupEndpointRefByURI(uri: String): RpcEndpointRef = {
-    Await.result(asyncSetupEndpointRefByURI(uri), defaultLookupTimeout)
+    defaultLookupTimeout.awaitResult(asyncSetupEndpointRefByURI(uri))
   }
 
   /**
@@ -184,3 +186,107 @@ private[spark] object RpcAddress {
     RpcAddress(host, port)
   }
 }
+
+
+/**
+ * An exception thrown if RpcTimeout modifies a [[TimeoutException]].
+ */
+private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException)
+  extends TimeoutException(message) { initCause(cause) }
+
+
+/**
+ * Associates a timeout with a description so that when a TimeoutException occurs, additional
+ * context about the timeout can be amended to the exception message.
+ * @param duration timeout duration in seconds
+ * @param timeoutProp the configuration property that controls this timeout
+ */
+private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: String)
+  extends Serializable {
+
+  /** Amends the standard message of TimeoutException to include the description */
+  private def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = {
+    new RpcTimeoutException(te.getMessage() + ". This timeout is controlled by " + timeoutProp, te)
+  }
+
+  /**
+   * PartialFunction to match a TimeoutException and add the timeout description to the message
+   *
+   * @note This can be used in the recover callback of a Future to add to a TimeoutException
+   * Example:
+   *    val timeout = new RpcTimeout(5 millis, "short timeout")
+   *    Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout)
+   */
+  def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = {
+    // The exception has already been converted to a RpcTimeoutException so just raise it
+    case rte: RpcTimeoutException => throw rte
+    // Any other TimeoutException gets converted to a RpcTimeoutException with modified message
+    case te: TimeoutException => throw createRpcTimeoutException(te)
+  }
+
+  /**
+   * Wait for the completed result and return it. If the result is not available within this
+   * timeout, throw a [[RpcTimeoutException]] to indicate which configuration controls the timeout.
+   * @param awaitable the `Awaitable` to be awaited
+   * @throws RpcTimeoutException if after waiting for the specified time `awaitable`
+   *         is still not ready
+   */
+  def awaitResult[T](awaitable: Awaitable[T]): T = {
+    try {
+      Await.result(awaitable, duration)
+    } catch addMessageIfTimeout
+  }
+}
+
+
+private[spark] object RpcTimeout {
+
+  /**
+   * Lookup the timeout property in the configuration and create
+   * a RpcTimeout with the property key in the description.
+   * @param conf configuration properties containing the timeout
+   * @param timeoutProp property key for the timeout in seconds
+   * @throws NoSuchElementException if property is not set
+   */
+  def apply(conf: SparkConf, timeoutProp: String): RpcTimeout = {
+    val timeout = { conf.getTimeAsSeconds(timeoutProp) seconds }
+    new RpcTimeout(timeout, timeoutProp)
+  }
+
+  /**
+   * Lookup the timeout property in the configuration and create
+   * a RpcTimeout with the property key in the description.
+   * Uses the given default value if property is not set
+   * @param conf configuration properties containing the timeout
+   * @param timeoutProp property key for the timeout in seconds
+   * @param defaultValue default timeout value in seconds if property not found
+   */
+  def apply(conf: SparkConf, timeoutProp: String, defaultValue: String): RpcTimeout = {
+    val timeout = { conf.getTimeAsSeconds(timeoutProp, defaultValue) seconds }
+    new RpcTimeout(timeout, timeoutProp)
+  }
+
+  /**
+   * Lookup prioritized list of timeout properties in the configuration
+   * and create a RpcTimeout with the first set property key in the
+   * description.
+   * Uses the given default value if property is not set
+   * @param conf configuration properties containing the timeout
+   * @param timeoutPropList prioritized list of property keys for the timeout in seconds
+   * @param defaultValue default timeout value in seconds if no properties found
+   */
+  def apply(conf: SparkConf, timeoutPropList: Seq[String], defaultValue: String): RpcTimeout = {
+    require(timeoutPropList.nonEmpty)
+
+    // Find the first set property or use the default value with the first property
+    val itr = timeoutPropList.iterator
+    var foundProp: Option[(String, String)] = None
+    while (itr.hasNext && foundProp.isEmpty){
+      val propKey = itr.next()
+      conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) }
+    }
+    val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue)
+    val timeout = { Utils.timeStringAsSeconds(finalProp._2) seconds }
+    new RpcTimeout(timeout, finalProp._1)
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
index 31ebe5ac5bca3ef663dd86c2990390cab0e55720..f2d87f68341afe90ab2dba06849838caac381d8e 100644
--- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -20,7 +20,6 @@ package org.apache.spark.rpc.akka
 import java.util.concurrent.ConcurrentHashMap
 
 import scala.concurrent.Future
-import scala.concurrent.duration._
 import scala.language.postfixOps
 import scala.reflect.ClassTag
 import scala.util.control.NonFatal
@@ -214,8 +213,11 @@ private[spark] class AkkaRpcEnv private[akka] (
 
   override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
     import actorSystem.dispatcher
-    actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout).
-      map(new AkkaRpcEndpointRef(defaultAddress, _, conf))
+    actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration).
+      map(new AkkaRpcEndpointRef(defaultAddress, _, conf)).
+      // this is just in case there is a timeout from creating the future in resolveOne, we want the
+      // exception to indicate the conf that determines the timeout
+      recover(defaultLookupTimeout.addMessageIfTimeout)
   }
 
   override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = {
@@ -295,8 +297,8 @@ private[akka] class AkkaRpcEndpointRef(
     actorRef ! AkkaMessage(message, false)
   }
 
-  override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = {
-    actorRef.ask(AkkaMessage(message, true))(timeout).flatMap {
+  override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
+    actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap {
       // The function will run in the calling thread, so it should be short and never block.
       case msg @ AkkaMessage(message, reply) =>
         if (reply) {
@@ -307,7 +309,8 @@ private[akka] class AkkaRpcEndpointRef(
         }
       case AkkaFailure(e) =>
         Future.failed(e)
-    }(ThreadUtils.sameThread).mapTo[T]
+    }(ThreadUtils.sameThread).mapTo[T].
+      recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
   }
 
   override def toString: String = s"${getClass.getSimpleName}($actorRef)"
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index a7cf0c23d961313b7fb2eb1934b74be5eb7ea184..6841fa835747fec6867c3373372e88c9ffbbf902 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -35,6 +35,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.rpc.RpcTimeout
 import org.apache.spark.storage._
 import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.util._
@@ -188,7 +189,7 @@ class DAGScheduler(
       blockManagerId: BlockManagerId): Boolean = {
     listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics))
     blockManagerMaster.driverEndpoint.askWithRetry[Boolean](
-      BlockManagerHeartbeat(blockManagerId), 600 seconds)
+      BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat"))
   }
 
   // Called by TaskScheduler when an executor fails.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index 190ff61d689d1d77ed67740bd07d6e47d89b109b..bc67abb5df4468341dbc803061b0c4a2e650f90a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -46,7 +46,7 @@ private[spark] abstract class YarnSchedulerBackend(
   private val yarnSchedulerEndpoint = rpcEnv.setupEndpoint(
     YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv))
 
-  private implicit val askTimeout = RpcUtils.askTimeout(sc.conf)
+  private implicit val askTimeout = RpcUtils.askRpcTimeout(sc.conf)
 
   /**
    * Request executors from the ApplicationMaster by specifying the total number desired.
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 7cdae22b0e253505619a2c0b38534e800ff7e920..f70f701494dbfb0bb3002c9207f9a72134e6fbfe 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -33,7 +33,7 @@ class BlockManagerMaster(
     isDriver: Boolean)
   extends Logging {
 
-  val timeout = RpcUtils.askTimeout(conf)
+  val timeout = RpcUtils.askRpcTimeout(conf)
 
   /** Remove a dead executor from the driver endpoint. This is only called on the driver side. */
   def removeExecutor(execId: String) {
@@ -106,7 +106,7 @@ class BlockManagerMaster(
       logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}", e)
     }(ThreadUtils.sameThread)
     if (blocking) {
-      Await.result(future, timeout)
+      timeout.awaitResult(future)
     }
   }
 
@@ -118,7 +118,7 @@ class BlockManagerMaster(
       logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}", e)
     }(ThreadUtils.sameThread)
     if (blocking) {
-      Await.result(future, timeout)
+      timeout.awaitResult(future)
     }
   }
 
@@ -132,7 +132,7 @@ class BlockManagerMaster(
         s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}", e)
     }(ThreadUtils.sameThread)
     if (blocking) {
-      Await.result(future, timeout)
+      timeout.awaitResult(future)
     }
   }
 
@@ -176,8 +176,8 @@ class BlockManagerMaster(
         CanBuildFrom[Iterable[Future[Option[BlockStatus]]],
          Option[BlockStatus],
          Iterable[Option[BlockStatus]]]]
-    val blockStatus = Await.result(
-      Future.sequence[Option[BlockStatus], Iterable](futures)(cbf, ThreadUtils.sameThread), timeout)
+    val blockStatus = timeout.awaitResult(
+      Future.sequence[Option[BlockStatus], Iterable](futures)(cbf, ThreadUtils.sameThread))
     if (blockStatus == null) {
       throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId)
     }
@@ -199,7 +199,7 @@ class BlockManagerMaster(
       askSlaves: Boolean): Seq[BlockId] = {
     val msg = GetMatchingBlockIds(filter, askSlaves)
     val future = driverEndpoint.askWithRetry[Future[Seq[BlockId]]](msg)
-    Await.result(future, timeout)
+    timeout.awaitResult(future)
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index 96aa2fe16470399c5a9657cfc437bacda32e9818..c179833e5b06a7824787956f73d8e7e37f2c63a2 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -18,8 +18,6 @@
 package org.apache.spark.util
 
 import scala.collection.JavaConversions.mapAsJavaMap
-import scala.concurrent.Await
-import scala.concurrent.duration.FiniteDuration
 
 import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem}
 import akka.pattern.ask
@@ -28,6 +26,7 @@ import com.typesafe.config.ConfigFactory
 import org.apache.log4j.{Level, Logger}
 
 import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException}
+import org.apache.spark.rpc.RpcTimeout
 
 /**
  * Various utility classes for working with Akka.
@@ -147,7 +146,7 @@
   def askWithReply[T](
       message: Any,
       actor: ActorRef,
-      timeout: FiniteDuration): T = {
+      timeout: RpcTimeout): T = {
     askWithReply[T](message, actor, maxAttempts = 1, retryInterval = Int.MaxValue, timeout)
   }
 
@@ -160,7 +159,7 @@
       actor: ActorRef,
       maxAttempts: Int,
       retryInterval: Long,
-      timeout: FiniteDuration): T = {
+      timeout: RpcTimeout): T = {
     // TODO: Consider removing multiple attempts
     if (actor == null) {
       throw new SparkException(s"Error sending message [message = $message]" +
@@ -171,8 +170,8 @@
     while (attempts < maxAttempts) {
       attempts += 1
       try {
-        val future = actor.ask(message)(timeout)
-        val result = Await.result(future, timeout)
+        val future = actor.ask(message)(timeout.duration)
+        val result = timeout.awaitResult(future)
         if (result == null) {
           throw new SparkException("Actor returned null")
         }
@@ -198,9 +197,9 @@
     val driverPort: Int = conf.getInt("spark.driver.port", 7077)
     Utils.checkHost(driverHost, "Expected hostname")
     val url = address(protocol(actorSystem), driverActorSystemName, driverHost, driverPort, name)
-    val timeout = RpcUtils.lookupTimeout(conf)
+    val timeout = RpcUtils.lookupRpcTimeout(conf)
     logInfo(s"Connecting to $name: $url")
-    Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
+    timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration))
   }
 
   def makeExecutorRef(
@@ -212,9 +211,9 @@
     val executorActorSystemName = SparkEnv.executorActorSystemName
     Utils.checkHost(host, "Expected hostname")
     val url = address(protocol(actorSystem), executorActorSystemName, host, port, name)
-    val timeout = RpcUtils.lookupTimeout(conf)
+    val timeout = RpcUtils.lookupRpcTimeout(conf)
     logInfo(s"Connecting to $name: $url")
-    Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
+    timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration))
   }
 
   def protocol(actorSystem: ActorSystem): String = {
diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala
index f16cc8e7e42c62a7d1f0929d98066829ac2e658f..7578a3b1d85f2773a9e61cb669af3769dad8805e 100644
--- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala
@@ -17,11 +17,11 @@
 
 package org.apache.spark.util
 
-import scala.concurrent.duration._
+import scala.concurrent.duration.FiniteDuration
 import scala.language.postfixOps
 
 import org.apache.spark.{SparkEnv, SparkConf}
-import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv}
+import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, RpcTimeout}
 
 object RpcUtils {
 
@@ -47,14 +47,22 @@ object RpcUtils {
   }
 
   /** Returns the default Spark timeout to use for RPC ask operations. */
+  private[spark] def askRpcTimeout(conf: SparkConf): RpcTimeout = {
+    RpcTimeout(conf, Seq("spark.rpc.askTimeout", "spark.network.timeout"), "120s")
+  }
+
+  @deprecated("use askRpcTimeout instead, this method was not intended to be public", "1.5.0")
   def askTimeout(conf: SparkConf): FiniteDuration = {
-    conf.getTimeAsSeconds("spark.rpc.askTimeout",
-      conf.get("spark.network.timeout", "120s")) seconds
+    askRpcTimeout(conf).duration
   }
 
   /** Returns the default Spark timeout to use for RPC remote endpoint lookup. */
+  private[spark] def lookupRpcTimeout(conf: SparkConf): RpcTimeout = {
+    RpcTimeout(conf, Seq("spark.rpc.lookupTimeout", "spark.network.timeout"), "120s")
+  }
+
+  @deprecated("use lookupRpcTimeout instead, this method was not intended to be public", "1.5.0")
   def lookupTimeout(conf: SparkConf): FiniteDuration = {
-    conf.getTimeAsSeconds("spark.rpc.lookupTimeout",
-      conf.get("spark.network.timeout", "120s")) seconds
+    lookupRpcTimeout(conf).duration
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
index 9fbaeb33f97cd83cf07f10750f8954eb4ad56628..90cb7da94e88ac529373ea59b483d10be0fd5401 100644
--- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
@@ -260,10 +260,10 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst
     assert(RpcUtils.retryWaitMs(conf) === 2L)
 
     conf.set("spark.akka.askTimeout", "3")
-    assert(RpcUtils.askTimeout(conf) === (3 seconds))
+    assert(RpcUtils.askRpcTimeout(conf).duration === (3 seconds))
 
     conf.set("spark.akka.lookupTimeout", "4")
-    assert(RpcUtils.lookupTimeout(conf) === (4 seconds))
+    assert(RpcUtils.lookupRpcTimeout(conf).duration === (4 seconds))
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index 1f0aa759b08da8c5ac77d1571309b78e90683e6e..6ceafe4337747170dc8163020fdbb57a086f27c4 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -155,16 +155,21 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
     })
 
     val conf = new SparkConf()
+    val shortProp = "spark.rpc.short.timeout"
     conf.set("spark.rpc.retry.wait", "0")
     conf.set("spark.rpc.numRetries", "1")
     val anotherEnv = createRpcEnv(conf, "remote", 13345)
     // Use anotherEnv to find out the RpcEndpointRef
     val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout")
     try {
-      val e = intercept[Exception] {
-        rpcEndpointRef.askWithRetry[String]("hello", 1 millis)
+      // Any exception thrown in askWithRetry is wrapped with a SparkException and set as the cause
+      val e = intercept[SparkException] {
+        rpcEndpointRef.askWithRetry[String]("hello", new RpcTimeout(1 millis, shortProp))
       }
-      assert(e.isInstanceOf[TimeoutException] || e.getCause.isInstanceOf[TimeoutException])
+      // The SparkException cause should be a RpcTimeoutException with message indicating the
+      // controlling timeout property
+      assert(e.getCause.isInstanceOf[RpcTimeoutException])
+      assert(e.getCause.getMessage.contains(shortProp))
     } finally {
       anotherEnv.shutdown()
       anotherEnv.awaitTermination()
@@ -539,6 +544,92 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
     }
   }
 
+  test("construct RpcTimeout with conf property") {
+    val conf = new SparkConf
+
+    val testProp = "spark.ask.test.timeout"
+    val testDurationSeconds = 30
+    val secondaryProp = "spark.ask.secondary.timeout"
+
+    conf.set(testProp, s"${testDurationSeconds}s")
+    conf.set(secondaryProp, "100s")
+
+    // Construct RpcTimeout with a single property
+    val rt1 = RpcTimeout(conf, testProp)
+    assert( testDurationSeconds === rt1.duration.toSeconds )
+
+    // Construct RpcTimeout with prioritized list of properties
+    val rt2 = RpcTimeout(conf, Seq("spark.ask.invalid.timeout", testProp, secondaryProp), "1s")
+    assert( testDurationSeconds === rt2.duration.toSeconds )
+
+    // Construct RpcTimeout with default value,
+    val defaultProp = "spark.ask.default.timeout"
+    val defaultDurationSeconds = 1
+    val rt3 = RpcTimeout(conf, Seq(defaultProp), defaultDurationSeconds.toString + "s")
+    assert( defaultDurationSeconds === rt3.duration.toSeconds )
+    assert( rt3.timeoutProp.contains(defaultProp) )
+
+    // Try to construct RpcTimeout with an unconfigured property
+    intercept[NoSuchElementException] {
+      RpcTimeout(conf, "spark.ask.invalid.timeout")
+    }
+  }
+
+  test("ask a message timeout on Future using RpcTimeout") {
+    case class NeverReply(msg: String)
+
+    val rpcEndpointRef = env.setupEndpoint("ask-future", new RpcEndpoint {
+      override val rpcEnv = env
+
+      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+        case msg: String => context.reply(msg)
+        case _: NeverReply =>
+      }
+    })
+
+    val longTimeout = new RpcTimeout(1 second, "spark.rpc.long.timeout")
+    val shortTimeout = new RpcTimeout(10 millis, "spark.rpc.short.timeout")
+
+    // Ask with immediate response, should complete successfully
+    val fut1 = rpcEndpointRef.ask[String]("hello", longTimeout)
+    val reply1 = longTimeout.awaitResult(fut1)
+    assert("hello" === reply1)
+
+    // Ask with a delayed response and wait for response immediately that should timeout
+    val fut2 = rpcEndpointRef.ask[String](NeverReply("doh"), shortTimeout)
+    val reply2 =
+      intercept[RpcTimeoutException] {
+        shortTimeout.awaitResult(fut2)
+      }.getMessage
+
+    // RpcTimeout.awaitResult should have added the property to the TimeoutException message
+    assert(reply2.contains(shortTimeout.timeoutProp))
+
+    // Ask with delayed response and allow the Future to timeout before Await.result
+    val fut3 = rpcEndpointRef.ask[String](NeverReply("goodbye"), shortTimeout)
+
+    // Allow future to complete with failure using plain Await.result, this will return
+    // once the future is complete to verify addMessageIfTimeout was invoked
+    val reply3 =
+      intercept[RpcTimeoutException] {
+        Await.result(fut3, 200 millis)
+      }.getMessage
+
+    // When the future timed out, the recover callback should have used
+    // RpcTimeout.addMessageIfTimeout to add the property to the TimeoutException message
+    assert(reply3.contains(shortTimeout.timeoutProp))
+
+    // Use RpcTimeout.awaitResult to process Future, since it has already failed with
+    // RpcTimeoutException, the same RpcTimeoutException should be thrown
+    val reply4 =
+      intercept[RpcTimeoutException] {
+        shortTimeout.awaitResult(fut3)
+      }.getMessage
+
+    // Ensure description is not in message twice after addMessageIfTimeout and awaitResult
+    assert(shortTimeout.timeoutProp.r.findAllIn(reply4).length === 1)
+  }
+
 }
 
 class UnserializableClass
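Usage note (not part of the patch): a minimal sketch of how the RpcTimeout API introduced in RpcEnv.scala above is meant to be used; `endpointRef` stands in for any RpcEndpointRef, and the property keys mirror the defaults wired up in RpcUtils.askRpcTimeout.

import org.apache.spark.SparkConf
import org.apache.spark.rpc.RpcTimeout

val conf = new SparkConf()
// The first property in the list that is set wins; otherwise the default "120s" is used,
// and the chosen key is kept in timeoutProp so it can be named in error messages.
val askTimeout = RpcTimeout(conf, Seq("spark.rpc.askTimeout", "spark.network.timeout"), "120s")

// Blocking wait: a TimeoutException is rethrown as RpcTimeoutException whose message adds
// "This timeout is controlled by spark.rpc.askTimeout" (or whichever property was used).
// val reply = askTimeout.awaitResult(endpointRef.ask[String]("ping", askTimeout))

// Non-blocking: attach the same context to an asynchronous Future via the recover callback.
// endpointRef.ask[String]("ping", askTimeout).recover(askTimeout.addMessageIfTimeout)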