diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala
index 10b35c74f473e37440840b8dea924acedccf6139..efec51d09745fce333d8afecd9bc9888ab41c631 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala
@@ -55,7 +55,7 @@ class KafkaSourceOffsetSuite extends OffsetSuite with SharedSQLContext {
   }
 
 
-  testWithUninterruptibleThread("OffsetSeqLog serialization - deserialization") {
+  test("OffsetSeqLog serialization - deserialization") {
     withTempDir { temp =>
       // use non-existent directory to test whether the log creates the dir
       val dir = new File(temp, "dir")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
index 3155ce04a11095a849e5f1103259929773dd05a4..f9e1f7de9ec08a2f7491f080b037ed5d6aba35ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
@@ -32,7 +32,6 @@ import org.json4s.jackson.Serialization
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.util.UninterruptibleThread
 
 
 /**
@@ -109,39 +108,12 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path:
   override def add(batchId: Long, metadata: T): Boolean = {
     get(batchId).map(_ => false).getOrElse {
       // Only write metadata when the batch has not yet been written
-      if (fileManager.isLocalFileSystem) {
-        Thread.currentThread match {
-          case ut: UninterruptibleThread =>
-            // When using a local file system, "writeBatch" must be called on a
-            // [[org.apache.spark.util.UninterruptibleThread]] so that interrupts can be disabled
-            // while writing the batch file.
-            //
-            // This is because Hadoop "Shell.runCommand" swallows InterruptException (HADOOP-14084).
-            // If the user tries to stop a query, and the thread running "Shell.runCommand" is
-            // interrupted, then InterruptException will be dropped and the query will be still
-            // running. (Note: `writeBatch` creates a file using HDFS APIs and will call
-            // "Shell.runCommand" to set the file permission if using the local file system)
-            //
-            // Hence, we make sure that "writeBatch" is called on [[UninterruptibleThread]] which
-            // allows us to disable interrupts here, in order to propagate the interrupt state
-            // correctly. Also see SPARK-19599.
-            ut.runUninterruptibly { writeBatch(batchId, metadata) }
-          case _ =>
-            throw new IllegalStateException(
-              "HDFSMetadataLog.add() on a local file system must be executed on " +
-                "a o.a.spark.util.UninterruptibleThread")
-        }
-      } else {
-        // For a distributed file system, such as HDFS or S3, if the network is broken, write
-        // operations may just hang until timeout. We should enable interrupts to allow stopping
-        // the query fast.
-        writeBatch(batchId, metadata)
-      }
+      writeBatch(batchId, metadata)
       true
     }
   }
 
-  def writeTempBatch(metadata: T): Option[Path] = {
+  private def writeTempBatch(metadata: T): Option[Path] = {
     while (true) {
       val tempPath = new Path(metadataPath, s".${UUID.randomUUID.toString}.tmp")
       try {
@@ -327,9 +299,6 @@ object HDFSMetadataLog {
 
     /** Recursively delete a path if it exists. Should not throw exception if file doesn't exist. */
     def delete(path: Path): Unit
-
-    /** Whether the file system is a local FS. */
-    def isLocalFileSystem: Boolean
   }
 
   /**
@@ -374,13 +343,6 @@ object HDFSMetadataLog {
         // ignore if file has already been deleted
       }
     }
-
-    override def isLocalFileSystem: Boolean = fc.getDefaultFileSystem match {
-      case _: local.LocalFs | _: local.RawLocalFs =>
-        // LocalFs = RawLocalFs + ChecksumFs
-        true
-      case _ => false
-    }
   }
 
   /**
@@ -437,12 +399,5 @@ object HDFSMetadataLog {
           // ignore if file has already been deleted
       }
     }
-
-    override def isLocalFileSystem: Boolean = fs match {
-      case _: LocalFileSystem | _: RawLocalFileSystem =>
-        // LocalFileSystem = RawLocalFileSystem + ChecksumFileSystem
-        true
-      case _ => false
-    }
   }
 }
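
Context for the HDFSMetadataLog hunks above, as a minimal sketch of the calling contract they remove. Before this patch, on a local file system add() had to run inside UninterruptibleThread.runUninterruptibly so that Shell.runCommand (HADOOP-14084) could not swallow a stop() interrupt mid-write; after it, the write simply runs on whatever thread calls add(). The metadataLog value below is hypothetical, and UninterruptibleThread is package-private to org.apache.spark, so the sketch only compiles inside that package:

    import org.apache.spark.util.UninterruptibleThread

    // Before: callers (StreamExecution, the metadata-log test suites) had to
    // wrap add() so that interrupts were disabled while the batch file was
    // being written on a local file system.
    val writer = new UninterruptibleThread("metadata-writer") {
      override def run(): Unit = runUninterruptibly {
        // metadataLog.add(0, "batch0")   // hypothetical log instance
      }
    }
    writer.start()
    writer.join()

    // After: any thread may call add() directly, and an interrupt issued by
    // stop() now reaches a hung write instead of being masked:
    // metadataLog.add(0, "batch0")

This is also why the test suites further down can replace testWithUninterruptibleThread with plain test.
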
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index e1af420a697b123266c70e28839c7b91745c37ea..4bd6431cbe11090189b0dc204641a455e80b14f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming
 
 import java.util.UUID
 import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.atomic.AtomicReference
 import java.util.concurrent.locks.ReentrantLock
 
 import scala.collection.mutable.ArrayBuffer
@@ -157,8 +158,7 @@ class StreamExecution(
   }
 
   /** Defines the internal state of execution */
-  @volatile
-  private var state: State = INITIALIZING
+  private val state = new AtomicReference[State](INITIALIZING)
 
   @volatile
   var lastExecution: IncrementalExecution = _
@@ -178,8 +178,8 @@ class StreamExecution(
 
   /**
    * The thread that runs the micro-batches of this stream. Note that this thread must be
-   * [[org.apache.spark.util.UninterruptibleThread]] to avoid swallowing `InterruptException` when
-   * using [[HDFSMetadataLog]]. See SPARK-19599 for more details.
+   * [[org.apache.spark.util.UninterruptibleThread]] to work around KAFKA-1894: interrupting a
+   * running `KafkaConsumer` may cause an endless loop.
    */
   val microBatchThread =
     new StreamExecutionThread(s"stream execution thread for $prettyIdString") {
@@ -200,10 +200,10 @@ class StreamExecution(
   val offsetLog = new OffsetSeqLog(sparkSession, checkpointFile("offsets"))
 
   /** Whether all fields of the query have been initialized */
-  private def isInitialized: Boolean = state != INITIALIZING
+  private def isInitialized: Boolean = state.get != INITIALIZING
 
   /** Whether the query is currently active or not */
-  override def isActive: Boolean = state != TERMINATED
+  override def isActive: Boolean = state.get != TERMINATED
 
   /** Returns the [[StreamingQueryException]] if the query was terminated by an exception. */
   override def exception: Option[StreamingQueryException] = Option(streamDeathCause)
@@ -249,53 +249,56 @@ class StreamExecution(
       updateStatusMessage("Initializing sources")
       // force initialization of the logical plan so that the sources can be created
       logicalPlan
-      state = ACTIVE
-      // Unblock `awaitInitialization`
-      initializationLatch.countDown()
-
-      triggerExecutor.execute(() => {
-        startTrigger()
-
-        val isTerminated =
-          if (isActive) {
-            reportTimeTaken("triggerExecution") {
-              if (currentBatchId < 0) {
-                // We'll do this initialization only once
-                populateStartOffsets()
-                logDebug(s"Stream running from $committedOffsets to $availableOffsets")
-              } else {
-                constructNextBatch()
+      if (state.compareAndSet(INITIALIZING, ACTIVE)) {
+        // Unblock `awaitInitialization`
+        initializationLatch.countDown()
+
+        triggerExecutor.execute(() => {
+          startTrigger()
+
+          val continueToRun =
+            if (isActive) {
+              reportTimeTaken("triggerExecution") {
+                if (currentBatchId < 0) {
+                  // We'll do this initialization only once
+                  populateStartOffsets()
+                  logDebug(s"Stream running from $committedOffsets to $availableOffsets")
+                } else {
+                  constructNextBatch()
+                }
+                if (dataAvailable) {
+                  currentStatus = currentStatus.copy(isDataAvailable = true)
+                  updateStatusMessage("Processing new data")
+                  runBatch()
+                }
               }
+
+              // Report trigger as finished and construct progress object.
+              finishTrigger(dataAvailable)
               if (dataAvailable) {
-                currentStatus = currentStatus.copy(isDataAvailable = true)
-                updateStatusMessage("Processing new data")
-                runBatch()
+                // We'll increase currentBatchId after we complete processing the current batch's data
+                currentBatchId += 1
+              } else {
+                currentStatus = currentStatus.copy(isDataAvailable = false)
+                updateStatusMessage("Waiting for data to arrive")
+                Thread.sleep(pollingDelayMs)
               }
-            }
-
-            // Report trigger as finished and construct progress object.
-            finishTrigger(dataAvailable)
-            if (dataAvailable) {
-              // We'll increase currentBatchId after we complete processing current batch's data
-              currentBatchId += 1
+              true
             } else {
-              currentStatus = currentStatus.copy(isDataAvailable = false)
-              updateStatusMessage("Waiting for data to arrive")
-              Thread.sleep(pollingDelayMs)
+              false
             }
-            true
-          } else {
-            false
-          }
 
-        // Update committed offsets.
-        committedOffsets ++= availableOffsets
-        updateStatusMessage("Waiting for next trigger")
-        isTerminated
-      })
-      updateStatusMessage("Stopped")
+          // Update committed offsets.
+          committedOffsets ++= availableOffsets
+          updateStatusMessage("Waiting for next trigger")
+          continueToRun
+        })
+        updateStatusMessage("Stopped")
+      } else {
+        // `stop()` has already been called. Let the `finally` block finish the cleanup.
+      }
     } catch {
-      case _: InterruptedException if state == TERMINATED => // interrupted by stop()
+      case _: InterruptedException if state.get == TERMINATED => // interrupted by stop()
         updateStatusMessage("Stopped")
       case e: Throwable =>
         streamDeathCause = new StreamingQueryException(
@@ -318,7 +321,7 @@ class StreamExecution(
       initializationLatch.countDown()
 
       try {
-        state = TERMINATED
+        state.set(TERMINATED)
         currentStatus = status.copy(isTriggerActive = false, isDataAvailable = false)
 
         // Update metrics and status
@@ -562,7 +565,7 @@ class StreamExecution(
   override def stop(): Unit = {
     // Set the state to TERMINATED so that the batching thread knows that it was interrupted
     // intentionally
-    state = TERMINATED
+    state.set(TERMINATED)
     if (microBatchThread.isAlive) {
       microBatchThread.interrupt()
       microBatchThread.join()
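
For the StreamExecution changes above: the race being fixed is start() and stop() arriving back to back, where stop() flips the state and interrupts the micro-batch thread before run() has entered the trigger loop. Replacing the @volatile var with an AtomicReference lets run() claim the INITIALIZING to ACTIVE transition with a single compareAndSet and skip the loop entirely when stop() already won. A self-contained sketch of that pattern, with illustrative names rather than StreamExecution's actual members:

    import java.util.concurrent.atomic.AtomicReference

    object StateRaceSketch {
      sealed trait State
      case object INITIALIZING extends State
      case object ACTIVE extends State
      case object TERMINATED extends State

      private val state = new AtomicReference[State](INITIALIZING)

      // Query thread: enter the trigger loop only if stop() has not already
      // moved the state to TERMINATED.
      def run(): Unit = {
        if (state.compareAndSet(INITIALIZING, ACTIVE)) {
          while (state.get == ACTIVE) {
            // one micro-batch trigger would run here
            Thread.sleep(10)
          }
        }
        // Mirrors the `finally` block: cleanup happens whether or not the
        // loop was ever entered.
        state.set(TERMINATED)
      }

      // Caller thread: in the real code this also interrupts and joins the
      // micro-batch thread.
      def stop(): Unit = state.set(TERMINATED)
    }
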
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala
index 435d874d75b92d340e868d24d56c66cd0c78801d..24d92a96237e3d8d4451c20dd27c72b3f2ab1f5f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala
@@ -156,7 +156,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext
       })
   }
 
-  testWithUninterruptibleThread("compact") {
+  test("compact") {
     withFakeCompactibleFileStreamLog(
       fileCleanupDelayMs = Long.MaxValue,
       defaultCompactInterval = 3,
@@ -174,7 +174,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext
       })
   }
 
-  testWithUninterruptibleThread("delete expired file") {
+  test("delete expired file") {
     // Set `fileCleanupDelayMs` to 0 so that we can detect the deleting behaviour deterministically
     withFakeCompactibleFileStreamLog(
       fileCleanupDelayMs = 0,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala
index 7e0de5e2657be4f23ccae68cab7501d2a883c831..340d2945acd4a164925159b98cfbe70d9266e5e6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala
@@ -129,7 +129,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext {
     }
   }
 
-  testWithUninterruptibleThread("compact") {
+  test("compact") {
     withSQLConf(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> "3") {
       withFileStreamSinkLog { sinkLog =>
         for (batchId <- 0 to 10) {
@@ -149,7 +149,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext {
     }
   }
 
-  testWithUninterruptibleThread("delete expired file") {
+  test("delete expired file") {
     // Set FILE_SINK_LOG_CLEANUP_DELAY to 0 so that we can detect the deleting behaviour
     // deterministically, and set the minimum number of batches to retain to one
     withSQLConf(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala
index d556861a487fd1217e07dd3bcb493bad6543d332..55750b9202982f3d9950a029f1f92df60ff74ba5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala
@@ -57,7 +57,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
     }
   }
 
-  testWithUninterruptibleThread("HDFSMetadataLog: basic") {
+  test("HDFSMetadataLog: basic") {
     withTempDir { temp =>
       val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir
       val metadataLog = new HDFSMetadataLog[String](spark, dir.getAbsolutePath)
@@ -82,8 +82,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
     }
   }
 
-  testWithUninterruptibleThread(
-    "HDFSMetadataLog: fallback from FileContext to FileSystem", quietly = true) {
+  testQuietly("HDFSMetadataLog: fallback from FileContext to FileSystem") {
     spark.conf.set(
       s"fs.$scheme.impl",
       classOf[FakeFileSystem].getName)
@@ -103,7 +102,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
     }
   }
 
-  testWithUninterruptibleThread("HDFSMetadataLog: purge") {
+  test("HDFSMetadataLog: purge") {
     withTempDir { temp =>
       val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath)
       assert(metadataLog.add(0, "batch0"))
@@ -128,7 +127,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext {
     }
   }
 
-  testWithUninterruptibleThread("HDFSMetadataLog: restart") {
+  test("HDFSMetadataLog: restart") {
     withTempDir { temp =>
       val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath)
       assert(metadataLog.add(0, "batch0"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala
index bb4274a162e8e2e98797dc2e654e8db077d614ff..5ae8b2484d2ef7ce7135213fb07149484abc78f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala
@@ -36,7 +36,7 @@ class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext {
         OffsetSeqMetadata("""{"batchWatermarkMs":1,"batchTimestampMs":2}"""))
   }
 
-  testWithUninterruptibleThread("OffsetSeqLog - serialization - deserialization") {
+  test("OffsetSeqLog - serialization - deserialization") {
     withTempDir { temp =>
       val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir
       val metadataLog = new OffsetSeqLog(spark, dir.getAbsolutePath)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
index 8a9fa94bea60174aa1885c0ab7fbc930f43ff6bb..5110d89c85b15c1e725a38500c526e85809f83da 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -1174,7 +1174,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest {
     assert(map.isNewFile("b", 10))
   }
 
-  testWithUninterruptibleThread("do not recheck that files exist during getBatch") {
+  test("do not recheck that files exist during getBatch") {
     withTempDir { temp =>
       spark.conf.set(
         s"fs.$scheme.impl",