diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
index 70bea1ed80fda8e733b4130cbf1dc5db9d5c46b6..3475c2c9db080a31c94d9e8d49410880e4462be1 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
@@ -18,41 +18,112 @@
 
 package org.apache.spark.sql.hive.thriftserver
 
-import java.io.{BufferedReader, InputStreamReader, PrintWriter}
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.concurrent.duration._
+import scala.concurrent.{Await, Future, Promise}
+import scala.sys.process.{Process, ProcessLogger}
+
+import java.io._
+import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.hadoop.hive.conf.HiveConf.ConfVars
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
 
-class CliSuite extends FunSuite with BeforeAndAfterAll with TestUtils {
-  val WAREHOUSE_PATH = TestUtils.getWarehousePath("cli")
-  val METASTORE_PATH = TestUtils.getMetastorePath("cli")
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.util.getTempFilePath
+
+class CliSuite extends FunSuite with BeforeAndAfterAll with Logging {
+  def runCliWithin(
+      timeout: FiniteDuration,
+      extraArgs: Seq[String] = Seq.empty)(
+      queriesAndExpectedAnswers: (String, String)*) {
+
+    val (queries, expectedAnswers) = queriesAndExpectedAnswers.unzip
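+    // Each call gets its own temporary warehouse directory and Derby metastore so runs don't interfere.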
+    val warehousePath = getTempFilePath("warehouse")
+    val metastorePath = getTempFilePath("metastore")
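+    // Relative path to the spark-sql launcher script, rewritten with the platform's file separator.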
+    val cliScript = "../../bin/spark-sql".split("/").mkString(File.separator)
 
-  override def beforeAll() {
-    val jdbcUrl = s"jdbc:derby:;databaseName=$METASTORE_PATH;create=true"
-    val commands =
-      s"""../../bin/spark-sql
+    val command = {
+      val jdbcUrl = s"jdbc:derby:;databaseName=$metastorePath;create=true"
+      s"""$cliScript
          |  --master local
          |  --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl
-         |  --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$WAREHOUSE_PATH
-       """.stripMargin.split("\\s+")
-
-    val pb = new ProcessBuilder(commands: _*)
-    process = pb.start()
-    outputWriter = new PrintWriter(process.getOutputStream, true)
-    inputReader = new BufferedReader(new InputStreamReader(process.getInputStream))
-    errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream))
-    waitForOutput(inputReader, "spark-sql>")
+         |  --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
+       """.stripMargin.split("\\s+").toSeq ++ extraArgs
+    }
+
+    // AtomicInteger is needed because stderr and stdout of the forked process are handled in
+    // different threads.
+    val next = new AtomicInteger(0)
+    val foundAllExpectedAnswers = Promise[Unit]()
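+    // All queries are piped into the CLI process through its stdin, one statement per line.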
+    val queryStream = new ByteArrayInputStream(queries.mkString("\n").getBytes)
+    val buffer = new ArrayBuffer[String]()
+
+    def captureOutput(source: String)(line: String) {
+      buffer += s"$source> $line"
+      // Stop matching once every expected answer has been found, to avoid indexing past the end.
+      if (next.get() < expectedAnswers.size && line.contains(expectedAnswers(next.get()))) {
+        if (next.incrementAndGet() == expectedAnswers.size) {
+          foundAllExpectedAnswers.trySuccess(())
+        }
+      }
+    }
+
+    // Search for the expected output lines in both stdout and stderr of the CLI process.
+    val process = (Process(command) #< queryStream).run(
+      ProcessLogger(captureOutput("stdout"), captureOutput("stderr")))
+
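+    // Wait for the process to exit on a separate thread and log its exit value without blocking the test.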
+    Future {
+      val exitValue = process.exitValue()
+      logInfo(s"Spark SQL CLI process exit value: $exitValue")
+    }
+
+    try {
+      Await.result(foundAllExpectedAnswers.future, timeout)
+    } catch { case cause: Throwable =>
+      logError(
+        s"""
+           |=======================
+           |CliSuite failure output
+           |=======================
+           |Spark SQL CLI command line: ${command.mkString(" ")}
+           |
+           |Executed query ${next.get()} "${queries(next.get())}",
+           |but failed to capture expected output "${expectedAnswers(next.get())}" within $timeout.
+           |
+           |${buffer.mkString("\n")}
+           |===========================
+           |End CliSuite failure output
+           |===========================
+         """.stripMargin, cause)
+    } finally {
+      warehousePath.delete()
+      metastorePath.delete()
+      process.destroy()
+    }
   }
 
-  override def afterAll() {
-    process.destroy()
-    process.waitFor()
+  test("Simple commands") {
+    val dataFilePath =
+      Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt")
+
+    runCliWithin(1.minute)(
+      "CREATE TABLE hive_test(key INT, val STRING);"
+        -> "OK",
+      "SHOW TABLES;"
+        -> "hive_test",
+      s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE hive_test;"
+        -> "OK",
+      "CACHE TABLE hive_test;"
+        -> "Time taken: ",
+      "SELECT COUNT(*) FROM hive_test;"
+        -> "5",
+      "DROP TABLE hive_test"
+        -> "Time taken: "
+    )
   }
 
-  test("simple commands") {
-    val dataFilePath = getDataFile("data/files/small_kv.txt")
-    executeQuery("create table hive_test1(key int, val string);")
-    executeQuery("load data local inpath '" + dataFilePath+ "' overwrite into table hive_test1;")
-    executeQuery("cache table hive_test1", "Time taken")
+  test("Single command with -e") {
+    runCliWithin(1.minute, Seq("-e", "SHOW TABLES;"))("" -> "OK")
   }
 }
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
index 326b0a7275b34a7d4e4eddd7f30a946267229195..38977ff1620976a895fc78c35d863e01a013b3a9 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
@@ -17,32 +17,32 @@
 
 package org.apache.spark.sql.hive.thriftserver
 
-import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.ExecutionContext.Implicits.global
-import scala.concurrent._
+import scala.concurrent.duration._
+import scala.concurrent.{Await, Future, Promise}
+import scala.sys.process.{Process, ProcessLogger}
 
-import java.io.{BufferedReader, InputStreamReader}
+import java.io.File
 import java.net.ServerSocket
-import java.sql.{Connection, DriverManager, Statement}
+import java.sql.{DriverManager, Statement}
+import java.util.concurrent.TimeoutException
 
 import org.apache.hadoop.hive.conf.HiveConf.ConfVars
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.apache.hive.jdbc.HiveDriver
+import org.scalatest.FunSuite
 
 import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.util.getTempFilePath
 
 /**
- * Test for the HiveThriftServer2 using JDBC.
+ * Tests for the HiveThriftServer2 using JDBC.
  */
-class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUtils with Logging {
+class HiveThriftServer2Suite extends FunSuite with Logging {
+  Class.forName(classOf[HiveDriver].getCanonicalName)
 
-  val WAREHOUSE_PATH = getTempFilePath("warehouse")
-  val METASTORE_PATH = getTempFilePath("metastore")
-
-  val DRIVER_NAME  = "org.apache.hive.jdbc.HiveDriver"
-  val TABLE = "test"
-  val HOST = "localhost"
-  val PORT =  {
+  private val listeningHost = "localhost"
+  private val listeningPort = {
     // Let the system to choose a random available port to avoid collision with other parallel
     // builds.
     val socket = new ServerSocket(0)
@@ -51,96 +51,126 @@ class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUt
     port
   }
 
-  Class.forName(DRIVER_NAME)
-
-  override def beforeAll() { launchServer() }
+  private val warehousePath = getTempFilePath("warehouse")
+  private val metastorePath = getTempFilePath("metastore")
+  private val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true"
 
-  override def afterAll() { stopServer() }
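+  // The Thrift server is started in a forked process because Hive resources are hard to clean up
+  // entirely; killing that process afterwards is the most reliable form of cleanup.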
+  def startThriftServerWithin(timeout: FiniteDuration = 30.seconds)(f: Statement => Unit) {
+    val serverScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator)
 
-  private def launchServer(args: Seq[String] = Seq.empty) {
-    // Forking a new process to start the Hive Thrift server. The reason to do this is it is
-    // hard to clean up Hive resources entirely, so we just start a new process and kill
-    // that process for cleanup.
-    val jdbcUrl = s"jdbc:derby:;databaseName=$METASTORE_PATH;create=true"
     val command =
-      s"""../../sbin/start-thriftserver.sh
+      s"""$serverScript
          |  --master local
-         |  --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl
-         |  --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$METASTORE_PATH
-         |  --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=$HOST
-         |  --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$PORT
-       """.stripMargin.split("\\s+")
-
-    val pb = new ProcessBuilder(command ++ args: _*)
-    val environment = pb.environment()
-    process = pb.start()
-    inputReader = new BufferedReader(new InputStreamReader(process.getInputStream))
-    errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream))
-    waitForOutput(inputReader, "ThriftBinaryCLIService listening on", 300000)
-
-    // Spawn a thread to read the output from the forked process.
-    // Note that this is necessary since in some configurations, log4j could be blocked
-    // if its output to stderr are not read, and eventually blocking the entire test suite.
-    future {
-      while (true) {
-        val stdout = readFrom(inputReader)
-        val stderr = readFrom(errorReader)
-        print(stdout)
-        print(stderr)
-        Thread.sleep(50)
+         |  --hiveconf hive.root.logger=INFO,console
+         |  --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
+         |  --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
+         |  --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=$listeningHost
+         |  --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$listeningPort
+       """.stripMargin.split("\\s+").toSeq
+
+    val serverStarted = Promise[Unit]()
+    val buffer = new ArrayBuffer[String]()
+
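+    // Buffer everything the server prints for the failure report, and treat the server as started
+    // once the "ThriftBinaryCLIService listening on" message shows up on either stream.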
+    def captureOutput(source: String)(line: String) {
+      buffer += s"$source> $line"
+      if (line.contains("ThriftBinaryCLIService listening on")) {
+        serverStarted.trySuccess(())
       }
     }
-  }
 
-  private def stopServer() {
-    process.destroy()
-    process.waitFor()
+    val process = Process(command).run(
+      ProcessLogger(captureOutput("stdout"), captureOutput("stderr")))
+
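+    // Reap the forked server's exit value on a separate thread and log it, without blocking the test.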
+    Future {
+      val exitValue = process.exitValue()
+      logInfo(s"Spark SQL Thrift server process exit value: $exitValue")
+    }
+
+    val jdbcUri = s"jdbc:hive2://$listeningHost:$listeningPort/"
+    val user = System.getProperty("user.name")
+
+    try {
+      Await.result(serverStarted.future, timeout)
+
+      val connection = DriverManager.getConnection(jdbcUri, user, "")
+      val statement = connection.createStatement()
+
+      try {
+        f(statement)
+      } finally {
+        statement.close()
+        connection.close()
+      }
+    } catch {
+      case cause: Exception =>
+        cause match {
+          case _: TimeoutException =>
+            logError(s"Failed to start Hive Thrift server within $timeout", cause)
+          case _ =>
+        }
+        logError(
+          s"""
+             |=====================================
+             |HiveThriftServer2Suite failure output
+             |=====================================
+             |HiveThriftServer2 command line: ${command.mkString(" ")}
+             |JDBC URI: $jdbcUri
+             |User: $user
+             |
+             |${buffer.mkString("\n")}
+             |=========================================
+             |End HiveThriftServer2Suite failure output
+             |=========================================
+           """.stripMargin, cause)
+    } finally {
+      warehousePath.delete()
+      metastorePath.delete()
+      process.destroy()
+    }
   }
 
-  test("test query execution against a Hive Thrift server") {
-    Thread.sleep(5 * 1000)
-    val dataFilePath = getDataFile("data/files/small_kv.txt")
-    val stmt = createStatement()
-    stmt.execute("DROP TABLE IF EXISTS test")
-    stmt.execute("DROP TABLE IF EXISTS test_cached")
-    stmt.execute("CREATE TABLE test(key INT, val STRING)")
-    stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test")
-    stmt.execute("CREATE TABLE test_cached AS SELECT * FROM test LIMIT 4")
-    stmt.execute("CACHE TABLE test_cached")
-
-    var rs = stmt.executeQuery("SELECT COUNT(*) FROM test")
-    rs.next()
-    assert(rs.getInt(1) === 5)
-
-    rs = stmt.executeQuery("SELECT COUNT(*) FROM test_cached")
-    rs.next()
-    assert(rs.getInt(1) === 4)
-
-    stmt.close()
+  test("Test JDBC query execution") {
+    startThriftServerWithin() { statement =>
+      val dataFilePath =
+        Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt")
+
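+      // The small_kv.txt test fixture contains five rows, which the COUNT(*) assertion below relies on.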
+      val queries = Seq(
+        "CREATE TABLE test(key INT, val STRING)",
+        s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test",
+        "CACHE TABLE test")
+
+      queries.foreach(statement.execute)
+
+      assertResult(5, "Row count mismatch") {
+        val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test")
+        resultSet.next()
+        resultSet.getInt(1)
+      }
+    }
   }
 
   test("SPARK-3004 regression: result set containing NULL") {
-    Thread.sleep(5 * 1000)
-    val dataFilePath = getDataFile("data/files/small_kv_with_null.txt")
-    val stmt = createStatement()
-    stmt.execute("DROP TABLE IF EXISTS test_null")
-    stmt.execute("CREATE TABLE test_null(key INT, val STRING)")
-    stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test_null")
-
-    val rs = stmt.executeQuery("SELECT * FROM test_null WHERE key IS NULL")
-    var count = 0
-    while (rs.next()) {
-      count += 1
-    }
-    assert(count === 5)
+    startThriftServerWithin() { statement =>
+      val dataFilePath =
+        Thread.currentThread().getContextClassLoader.getResource(
+          "data/files/small_kv_with_null.txt")
 
-    stmt.close()
-  }
+      val queries = Seq(
+        "DROP TABLE IF EXISTS test_null",
+        "CREATE TABLE test_null(key INT, val STRING)",
+        s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test_null")
 
-  def getConnection: Connection = {
-    val connectURI = s"jdbc:hive2://localhost:$PORT/"
-    DriverManager.getConnection(connectURI, System.getProperty("user.name"), "")
-  }
+      queries.foreach(statement.execute)
+
+      val resultSet = statement.executeQuery("SELECT * FROM test_null WHERE key IS NULL")
+
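+      // getInt() returns 0 for a SQL NULL, so wasNull() is what actually verifies that the key is NULL.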
+      (0 until 5).foreach { _ =>
+        resultSet.next()
+        assert(resultSet.getInt(1) === 0)
+        assert(resultSet.wasNull())
+      }
 
-  def createStatement(): Statement = getConnection.createStatement()
+      assert(!resultSet.next())
+    }
+  }
 }
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala
deleted file mode 100644
index bb2242618fbefc7053ec34427882da58595e1960..0000000000000000000000000000000000000000
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/TestUtils.scala
+++ /dev/null
@@ -1,108 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.thriftserver
-
-import java.io.{BufferedReader, PrintWriter}
-import java.text.SimpleDateFormat
-import java.util.Date
-
-import org.apache.hadoop.hive.common.LogUtils
-import org.apache.hadoop.hive.common.LogUtils.LogInitializationException
-
-object TestUtils {
-  val timestamp = new SimpleDateFormat("yyyyMMdd-HHmmss")
-
-  def getWarehousePath(prefix: String): String = {
-    System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-warehouse-" +
-      timestamp.format(new Date)
-  }
-
-  def getMetastorePath(prefix: String): String = {
-    System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-metastore-" +
-      timestamp.format(new Date)
-  }
-
-  // Dummy function for initialize the log4j properties.
-  def init() { }
-
-  // initialize log4j
-  try {
-    LogUtils.initHiveLog4j()
-  } catch {
-    case e: LogInitializationException => // Ignore the error.
-  }
-}
-
-trait TestUtils {
-  var process : Process = null
-  var outputWriter : PrintWriter = null
-  var inputReader : BufferedReader = null
-  var errorReader : BufferedReader = null
-
-  def executeQuery(
-    cmd: String, outputMessage: String = "OK", timeout: Long = 15000): String = {
-    println("Executing: " + cmd + ", expecting output: " + outputMessage)
-    outputWriter.write(cmd + "\n")
-    outputWriter.flush()
-    waitForQuery(timeout, outputMessage)
-  }
-
-  protected def waitForQuery(timeout: Long, message: String): String = {
-    if (waitForOutput(errorReader, message, timeout)) {
-      Thread.sleep(500)
-      readOutput()
-    } else {
-      assert(false, "Didn't find \"" + message + "\" in the output:\n" + readOutput())
-      null
-    }
-  }
-
-  // Wait for the specified str to appear in the output.
-  protected def waitForOutput(
-    reader: BufferedReader, str: String, timeout: Long = 10000): Boolean = {
-    val startTime = System.currentTimeMillis
-    var out = ""
-    while (!out.contains(str) && System.currentTimeMillis < (startTime + timeout)) {
-      out += readFrom(reader)
-    }
-    out.contains(str)
-  }
-
-  // Read stdout output and filter out garbage collection messages.
-  protected def readOutput(): String = {
-    val output = readFrom(inputReader)
-    // Remove GC Messages
-    val filteredOutput = output.lines.filterNot(x => x.contains("[GC") || x.contains("[Full GC"))
-      .mkString("\n")
-    filteredOutput
-  }
-
-  protected def readFrom(reader: BufferedReader): String = {
-    var out = ""
-    var c = 0
-    while (reader.ready) {
-      c = reader.read()
-      out += c.asInstanceOf[Char]
-    }
-    out
-  }
-
-  protected def getDataFile(name: String) = {
-    Thread.currentThread().getContextClassLoader.getResource(name)
-  }
-}