Commit fb3081d3 authored by Shixiong Zhu, committed by Yin Huai

[SPARK-13747][CORE] Fix potential ThreadLocal leaks in RPC when using ForkJoinPool

## What changes were proposed in this pull request?

Some places in SQL may call `RpcEndpointRef.askWithRetry` (e.g., `ParquetFileFormat.buildReader` -> `SparkContext.broadcast` -> ... -> `BlockManagerMaster.updateBlockInfo` -> `RpcEndpointRef.askWithRetry`), which eventually calls `Await.result`. When that happens on a thread of a Scala `ForkJoinPool`, `ThreadLocal` values can leak to other tasks, which may surface as `java.lang.IllegalArgumentException: spark.sql.execution.id is already set`.
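To illustrate the mechanism, here is a minimal, hypothetical sketch (not part of the original commit; `ThreadLocalLeakSketch` and `executionId` are invented names, and whether the leak actually reproduces depends on pool timing):

```scala
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

// Hypothetical sketch of the SPARK-13747 leak; for illustration only.
object ThreadLocalLeakSketch {
  // Stand-in for SparkContext's inheritable local properties, which carry
  // values such as spark.sql.execution.id.
  private val executionId = new InheritableThreadLocal[String]

  // Scala's global ExecutionContext is backed by a ForkJoinPool.
  implicit val ec: ExecutionContext = ExecutionContext.global

  def main(args: Array[String]): Unit = {
    val outer = Future {
      executionId.set("exec-1")
      // Await.result enters BlockingContext; on a ForkJoinPool worker,
      // ForkJoinPool.managedBlock may create a compensation thread while this
      // one blocks, and a thread created here inherits this thread's
      // InheritableThreadLocals, so "exec-1" can leak into unrelated tasks.
      Await.result(Future { 42 }, 10.seconds)
      executionId.remove()
    }
    Await.result(outer, 10.seconds)

    // A later task scheduled on a leaked worker may see "exec-1" instead of
    // null; in SQL this surfaces as "spark.sql.execution.id is already set".
    val probe = Future { Option(executionId.get) }
    println(Await.result(probe, 10.seconds))
  }
}
```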

This PR includes the following changes to fix this issue:

- Remove `ThreadUtils.awaitResult`
- Rename `ThreadUtils.awaitResultInForkJoinSafely` to `ThreadUtils.awaitResult`
- Replace `Await.result` in `RpcTimeout` with `ThreadUtils.awaitResult`

## How was this patch tested?

Jenkins

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #16230 from zsxwing/fix-SPARK-13747.
parent d53f18ca
@@ -24,7 +24,7 @@ import scala.concurrent.duration._
 import scala.util.control.NonFatal
 
 import org.apache.spark.{SparkConf, SparkException}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 /**
  * An exception thrown if RpcTimeout modifies a [[TimeoutException]].
@@ -72,15 +72,9 @@ private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: S
    *         is still not ready
    */
   def awaitResult[T](future: Future[T]): T = {
-    val wrapAndRethrow: PartialFunction[Throwable, T] = {
-      case NonFatal(t) =>
-        throw new SparkException("Exception thrown in awaitResult", t)
-    }
     try {
-      // scalastyle:off awaitresult
-      Await.result(future, duration)
-      // scalastyle:on awaitresult
-    } catch addMessageIfTimeout.orElse(wrapAndRethrow)
+      ThreadUtils.awaitResult(future, duration)
+    } catch addMessageIfTimeout
   }
 }
@@ -19,7 +19,7 @@ package org.apache.spark.util
 
 import java.util.concurrent._
 
-import scala.concurrent.{Await, Awaitable, ExecutionContext, ExecutionContextExecutor}
+import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor}
 import scala.concurrent.duration.Duration
 import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread}
 import scala.util.control.NonFatal
@@ -180,39 +180,30 @@ private[spark] object ThreadUtils {
 
   // scalastyle:off awaitresult
   /**
-   * Preferred alternative to `Await.result()`. This method wraps and re-throws any exceptions
-   * thrown by the underlying `Await` call, ensuring that this thread's stack trace appears in
-   * logs.
-   */
-  @throws(classOf[SparkException])
-  def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = {
-    try {
-      Await.result(awaitable, atMost)
-      // scalastyle:on awaitresult
-    } catch {
-      case NonFatal(t) =>
-        throw new SparkException("Exception thrown in awaitResult: ", t)
-    }
-  }
-
-  /**
-   * Calls `Awaitable.result` directly to avoid using `ForkJoinPool`'s `BlockingContext`, wraps
-   * and re-throws any exceptions with nice stack trace.
+   * Preferred alternative to `Await.result()`.
+   *
+   * This method wraps and re-throws any exceptions thrown by the underlying `Await` call, ensuring
+   * that this thread's stack trace appears in logs.
    *
-   * Codes running in the user's thread may be in a thread of Scala ForkJoinPool. As concurrent
-   * executions in ForkJoinPool may see some [[ThreadLocal]] value unexpectedly, this method
-   * basically prevents ForkJoinPool from running other tasks in the current waiting thread.
+   * In addition, it calls `Awaitable.result` directly to avoid using `ForkJoinPool`'s
+   * `BlockingContext`. Code running in the user's thread may be in a thread of Scala ForkJoinPool.
+   * As concurrent executions in ForkJoinPool may see some [[ThreadLocal]] value unexpectedly, this
+   * method basically prevents ForkJoinPool from running other tasks in the current waiting thread.
+   * In general, we should use this method because many places in Spark use [[ThreadLocal]] and it's
+   * hard to debug when [[ThreadLocal]]s leak to other tasks.
    */
   @throws(classOf[SparkException])
-  def awaitResultInForkJoinSafely[T](awaitable: Awaitable[T], atMost: Duration): T = {
+  def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = {
     try {
       // `awaitPermission` is not actually used anywhere so it's safe to pass in null here.
       // See SPARK-13747.
       val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait]
-      awaitable.result(Duration.Inf)(awaitPermission)
+      awaitable.result(atMost)(awaitPermission)
     } catch {
-      case NonFatal(t) =>
+      // TimeoutException is thrown in the current thread, so no need to wrap the exception.
+      case NonFatal(t) if !t.isInstanceOf[TimeoutException] =>
         throw new SparkException("Exception thrown in awaitResult: ", t)
     }
   }
   // scalastyle:on awaitresult
 }
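For context (not part of the diff): after the rename, callers simply use `ThreadUtils.awaitResult`, and a timeout now propagates as a bare `TimeoutException` rather than a `SparkException` wrapping one, which is exactly what the updated test suites below assert. A minimal usage sketch; the object `AwaitResultUsage` is hypothetical:

```scala
import java.util.concurrent.TimeoutException

import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration._

import org.apache.spark.util.ThreadUtils

// Hypothetical caller, for illustration only.
object AwaitResultUsage {
  implicit val ec: ExecutionContext = ExecutionContext.global

  def main(args: Array[String]): Unit = {
    // Calls Awaitable.result directly, bypassing ForkJoinPool's
    // BlockingContext, so the waiting thread neither runs other tasks nor
    // triggers compensation threads that would inherit its ThreadLocals.
    val n = ThreadUtils.awaitResult(Future { 1 + 1 }, 1.second)
    println(n) // 2

    val slow = Future { Thread.sleep(5000); 0 }
    try {
      ThreadUtils.awaitResult(slow, 10.milliseconds)
    } catch {
      // With this patch a timeout is rethrown as-is instead of being wrapped
      // in SparkException, matching the test changes below.
      case _: TimeoutException => println("timed out")
    }
  }
}
```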
@@ -199,10 +199,9 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim
     val f = sc.parallelize(1 to 100, 4)
       .mapPartitions(itr => { Thread.sleep(20); itr })
       .countAsync()
-    val e = intercept[SparkException] {
+    intercept[TimeoutException] {
       ThreadUtils.awaitResult(f, Duration(20, "milliseconds"))
     }
-    assert(e.getCause.isInstanceOf[TimeoutException])
   }
 
   private def testAsyncAction[R](action: RDD[Int] => FutureAction[R]): Unit = {
@@ -158,10 +158,9 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter {
       0 until rdd.partitions.size, resultHandler, () => Unit)
     // It's an error if the job completes successfully even though no committer was authorized,
     // so throw an exception if the job was allowed to complete.
-    val e = intercept[SparkException] {
+    intercept[TimeoutException] {
       ThreadUtils.awaitResult(futureAction, 5 seconds)
     }
-    assert(e.getCause.isInstanceOf[TimeoutException])
     assert(tempDir.list().size === 0)
   }
@@ -200,7 +200,6 @@ This file is divided into 3 sections:
         // scalastyle:off awaitresult
         Await.result(...)
         // scalastyle:on awaitresult
-        If your codes use ThreadLocal and may run in threads created by the user, use ThreadUtils.awaitResultInForkJoinSafely instead.
       ]]></customMessage>
     </check>
@@ -578,7 +578,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
   }
 
   override def executeCollect(): Array[InternalRow] = {
-    ThreadUtils.awaitResultInForkJoinSafely(relationFuture, Duration.Inf)
+    ThreadUtils.awaitResult(relationFuture, Duration.Inf)
   }
 }
@@ -128,8 +128,7 @@ case class BroadcastExchangeExec(
   }
 
   override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
-    ThreadUtils.awaitResultInForkJoinSafely(relationFuture, timeout)
-      .asInstanceOf[broadcast.Broadcast[T]]
+    ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]]
   }
 }