Commit cae9414d authored by Cheng Lian, committed by Michael Armbrust

[SPARK-2929][SQL] Refactored Thrift server and CLI suites

Removed most hard-coded timeouts, timing assumptions, and all `Thread.sleep` calls. Simplified IPC and synchronization with `scala.sys.process` and future/promise so that the test suites run more robustly and faster.
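For reference, the core pattern is: fork the external process with `scala.sys.process`, scan its output lines for an expected marker via `ProcessLogger`, and block on a promise with a bounded timeout instead of sleeping. The following is a minimal, self-contained sketch, not code from this commit; the shell command and the `READY` marker are invented for illustration:

    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.concurrent.duration._
    import scala.concurrent.{Await, Future, Promise}
    import scala.sys.process.{Process, ProcessLogger}

    object ProcessWatchSketch extends App {
      val ready = Promise[Unit]()

      // ProcessLogger invokes these callbacks on dedicated reader threads, one per stream,
      // so no manual stream pumping or Thread.sleep polling is needed.
      def watch(line: String): Unit =
        if (line.contains("READY")) ready.trySuccess(())

      // Hypothetical child process that eventually prints the marker line.
      val process = Process(Seq("sh", "-c", "echo READY; sleep 1")).run(
        ProcessLogger(watch, watch))

      // Observe the exit value asynchronously without blocking the main thread.
      Future { println(s"child exited with ${process.exitValue()}") }

      // Fail fast with a bounded timeout if the marker never appears.
      try Await.result(ready.future, 30.seconds)
      finally process.destroy()
    }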

Author: Cheng Lian <lian.cs.zju@gmail.com>

Closes #1856 from liancheng/thriftserver-tests and squashes the following commits:

2d914ca [Cheng Lian] Minor refactoring
0e12e71 [Cheng Lian] Cleaned up test output
0ee921d [Cheng Lian] Refactored Thrift server and CLI suites
parent d299e2bf
CliSuite.scala
@@ -18,41 +18,112 @@
 package org.apache.spark.sql.hive.thriftserver
 
-import java.io.{BufferedReader, InputStreamReader, PrintWriter}
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.concurrent.duration._
+import scala.concurrent.{Await, Future, Promise}
+import scala.sys.process.{Process, ProcessLogger}
+
+import java.io._
+import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.hadoop.hive.conf.HiveConf.ConfVars
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
 
-class CliSuite extends FunSuite with BeforeAndAfterAll with TestUtils {
-  val WAREHOUSE_PATH = TestUtils.getWarehousePath("cli")
-  val METASTORE_PATH = TestUtils.getMetastorePath("cli")
-
-  override def beforeAll() {
-    val jdbcUrl = s"jdbc:derby:;databaseName=$METASTORE_PATH;create=true"
-    val commands =
-      s"""../../bin/spark-sql
-         |  --master local
-         |  --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl
-         |  --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$WAREHOUSE_PATH
-       """.stripMargin.split("\\s+")
-
-    val pb = new ProcessBuilder(commands: _*)
-    process = pb.start()
-    outputWriter = new PrintWriter(process.getOutputStream, true)
-    inputReader = new BufferedReader(new InputStreamReader(process.getInputStream))
-    errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream))
-    waitForOutput(inputReader, "spark-sql>")
-  }
-
-  override def afterAll() {
-    process.destroy()
-    process.waitFor()
-  }
-
-  test("simple commands") {
-    val dataFilePath = getDataFile("data/files/small_kv.txt")
-    executeQuery("create table hive_test1(key int, val string);")
-    executeQuery("load data local inpath '" + dataFilePath+ "' overwrite into table hive_test1;")
-    executeQuery("cache table hive_test1", "Time taken")
-  }
-}
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.util.getTempFilePath
+
+class CliSuite extends FunSuite with BeforeAndAfterAll with Logging {
+  def runCliWithin(
+      timeout: FiniteDuration,
+      extraArgs: Seq[String] = Seq.empty)(
+      queriesAndExpectedAnswers: (String, String)*) {
+    val (queries, expectedAnswers) = queriesAndExpectedAnswers.unzip
+    val warehousePath = getTempFilePath("warehouse")
+    val metastorePath = getTempFilePath("metastore")
+    val cliScript = "../../bin/spark-sql".split("/").mkString(File.separator)
+
+    val command = {
+      val jdbcUrl = s"jdbc:derby:;databaseName=$metastorePath;create=true"
+      s"""$cliScript
+         |  --master local
+         |  --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl
+         |  --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
+       """.stripMargin.split("\\s+").toSeq ++ extraArgs
+    }
+
+    // AtomicInteger is needed because stderr and stdout of the forked process are handled in
+    // different threads.
+    val next = new AtomicInteger(0)
+    val foundAllExpectedAnswers = Promise.apply[Unit]()
+    val queryStream = new ByteArrayInputStream(queries.mkString("\n").getBytes)
+    val buffer = new ArrayBuffer[String]()
+
+    def captureOutput(source: String)(line: String) {
+      buffer += s"$source> $line"
+      if (line.contains(expectedAnswers(next.get()))) {
+        if (next.incrementAndGet() == expectedAnswers.size) {
+          foundAllExpectedAnswers.trySuccess(())
+        }
+      }
+    }
+
+    // Searching expected output line from both stdout and stderr of the CLI process
+    val process = (Process(command) #< queryStream).run(
+      ProcessLogger(captureOutput("stdout"), captureOutput("stderr")))
+
+    Future {
+      val exitValue = process.exitValue()
+      logInfo(s"Spark SQL CLI process exit value: $exitValue")
+    }
+
+    try {
+      Await.result(foundAllExpectedAnswers.future, timeout)
+    } catch { case cause: Throwable =>
+      logError(
+        s"""
+           |=======================
+           |CliSuite failure output
+           |=======================
+           |Spark SQL CLI command line: ${command.mkString(" ")}
+           |
+           |Executed query ${next.get()} "${queries(next.get())}",
+           |But failed to capture expected output "${expectedAnswers(next.get())}" within $timeout.
+           |
+           |${buffer.mkString("\n")}
+           |===========================
+           |End CliSuite failure output
+           |===========================
+         """.stripMargin, cause)
+    } finally {
+      warehousePath.delete()
+      metastorePath.delete()
+      process.destroy()
+    }
+  }
+
+  test("Simple commands") {
+    val dataFilePath =
+      Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt")
+
+    runCliWithin(1.minute)(
+      "CREATE TABLE hive_test(key INT, val STRING);"
+        -> "OK",
+      "SHOW TABLES;"
+        -> "hive_test",
+      s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE hive_test;"
+        -> "OK",
+      "CACHE TABLE hive_test;"
+        -> "Time taken: ",
+      "SELECT COUNT(*) FROM hive_test;"
+        -> "5",
+      "DROP TABLE hive_test"
+        -> "Time taken: "
+    )
+  }
+
+  test("Single command with -e") {
+    runCliWithin(1.minute, Seq("-e", "SHOW TABLES;"))("" -> "OK")
+  }
+}
HiveThriftServer2Suite.scala
@@ -17,32 +17,32 @@
 package org.apache.spark.sql.hive.thriftserver
 
-import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.ExecutionContext.Implicits.global
-import scala.concurrent._
+import scala.concurrent.duration._
+import scala.concurrent.{Await, Future, Promise}
+import scala.sys.process.{Process, ProcessLogger}
 
-import java.io.{BufferedReader, InputStreamReader}
+import java.io.File
 import java.net.ServerSocket
-import java.sql.{Connection, DriverManager, Statement}
+import java.sql.{DriverManager, Statement}
+import java.util.concurrent.TimeoutException
 
 import org.apache.hadoop.hive.conf.HiveConf.ConfVars
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.apache.hive.jdbc.HiveDriver
+import org.scalatest.FunSuite
 
 import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.util.getTempFilePath
 
 /**
- * Test for the HiveThriftServer2 using JDBC.
+ * Tests for the HiveThriftServer2 using JDBC.
  */
-class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUtils with Logging {
-  val WAREHOUSE_PATH = getTempFilePath("warehouse")
-  val METASTORE_PATH = getTempFilePath("metastore")
-  val DRIVER_NAME = "org.apache.hive.jdbc.HiveDriver"
-  val TABLE = "test"
-  val HOST = "localhost"
-  val PORT = {
+class HiveThriftServer2Suite extends FunSuite with Logging {
+  Class.forName(classOf[HiveDriver].getCanonicalName)
+
+  private val listeningHost = "localhost"
+  private val listeningPort = {
     // Let the system to choose a random available port to avoid collision with other parallel
     // builds.
     val socket = new ServerSocket(0)
@@ -51,96 +51,126 @@
     port
   }
 
-  Class.forName(DRIVER_NAME)
-
-  override def beforeAll() { launchServer() }
-
-  override def afterAll() { stopServer() }
+  private val warehousePath = getTempFilePath("warehouse")
+  private val metastorePath = getTempFilePath("metastore")
+  private val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true"
 
-  private def launchServer(args: Seq[String] = Seq.empty) {
-    // Forking a new process to start the Hive Thrift server. The reason to do this is it is
-    // hard to clean up Hive resources entirely, so we just start a new process and kill
-    // that process for cleanup.
-    val jdbcUrl = s"jdbc:derby:;databaseName=$METASTORE_PATH;create=true"
+  def startThriftServerWithin(timeout: FiniteDuration = 30.seconds)(f: Statement => Unit) {
+    val serverScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator)
+
     val command =
-      s"""../../sbin/start-thriftserver.sh
+      s"""$serverScript
          |  --master local
-         |  --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl
-         |  --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$METASTORE_PATH
-         |  --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=$HOST
-         |  --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$PORT
-       """.stripMargin.split("\\s+")
+         |  --hiveconf hive.root.logger=INFO,console
+         |  --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
+         |  --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
+         |  --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=$listeningHost
+         |  --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$listeningPort
+       """.stripMargin.split("\\s+").toSeq
 
-    val pb = new ProcessBuilder(command ++ args: _*)
-    val environment = pb.environment()
-    process = pb.start()
-    inputReader = new BufferedReader(new InputStreamReader(process.getInputStream))
-    errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream))
-    waitForOutput(inputReader, "ThriftBinaryCLIService listening on", 300000)
+    val serverStarted = Promise[Unit]()
+    val buffer = new ArrayBuffer[String]()
 
-    // Spawn a thread to read the output from the forked process.
-    // Note that this is necessary since in some configurations, log4j could be blocked
-    // if its output to stderr are not read, and eventually blocking the entire test suite.
-    future {
-      while (true) {
-        val stdout = readFrom(inputReader)
-        val stderr = readFrom(errorReader)
-        print(stdout)
-        print(stderr)
-        Thread.sleep(50)
-      }
+    def captureOutput(source: String)(line: String) {
+      buffer += s"$source> $line"
+      if (line.contains("ThriftBinaryCLIService listening on")) {
+        serverStarted.success(())
+      }
     }
-  }
 
-  private def stopServer() {
-    process.destroy()
-    process.waitFor()
+    val process = Process(command).run(
+      ProcessLogger(captureOutput("stdout"), captureOutput("stderr")))
+
+    Future {
+      val exitValue = process.exitValue()
+      logInfo(s"Spark SQL Thrift server process exit value: $exitValue")
+    }
+
+    val jdbcUri = s"jdbc:hive2://$listeningHost:$listeningPort/"
+    val user = System.getProperty("user.name")
+
+    try {
+      Await.result(serverStarted.future, timeout)
+
+      val connection = DriverManager.getConnection(jdbcUri, user, "")
+      val statement = connection.createStatement()
+
+      try {
+        f(statement)
+      } finally {
+        statement.close()
+        connection.close()
+      }
+    } catch {
+      case cause: Exception =>
+        cause match {
+          case _: TimeoutException =>
+            logError(s"Failed to start Hive Thrift server within $timeout", cause)
+          case _ =>
+        }
+        logError(
+          s"""
+             |=====================================
+             |HiveThriftServer2Suite failure output
+             |=====================================
+             |HiveThriftServer2 command line: ${command.mkString(" ")}
+             |JDBC URI: $jdbcUri
+             |User: $user
+             |
+             |${buffer.mkString("\n")}
+             |=========================================
+             |End HiveThriftServer2Suite failure output
+             |=========================================
+           """.stripMargin, cause)
+    } finally {
+      warehousePath.delete()
+      metastorePath.delete()
+      process.destroy()
+    }
   }
 
-  test("test query execution against a Hive Thrift server") {
-    Thread.sleep(5 * 1000)
-    val dataFilePath = getDataFile("data/files/small_kv.txt")
-    val stmt = createStatement()
-    stmt.execute("DROP TABLE IF EXISTS test")
-    stmt.execute("DROP TABLE IF EXISTS test_cached")
-    stmt.execute("CREATE TABLE test(key INT, val STRING)")
-    stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test")
-    stmt.execute("CREATE TABLE test_cached AS SELECT * FROM test LIMIT 4")
-    stmt.execute("CACHE TABLE test_cached")
-
-    var rs = stmt.executeQuery("SELECT COUNT(*) FROM test")
-    rs.next()
-    assert(rs.getInt(1) === 5)
-
-    rs = stmt.executeQuery("SELECT COUNT(*) FROM test_cached")
-    rs.next()
-    assert(rs.getInt(1) === 4)
-
-    stmt.close()
+  test("Test JDBC query execution") {
+    startThriftServerWithin() { statement =>
+      val dataFilePath =
+        Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt")
+
+      val queries = Seq(
+        "CREATE TABLE test(key INT, val STRING)",
+        s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test",
+        "CACHE TABLE test")
+
+      queries.foreach(statement.execute)
+
+      assertResult(5, "Row count mismatch") {
+        val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test")
+        resultSet.next()
+        resultSet.getInt(1)
+      }
+    }
   }
 
   test("SPARK-3004 regression: result set containing NULL") {
-    Thread.sleep(5 * 1000)
-    val dataFilePath = getDataFile("data/files/small_kv_with_null.txt")
-    val stmt = createStatement()
-    stmt.execute("DROP TABLE IF EXISTS test_null")
-    stmt.execute("CREATE TABLE test_null(key INT, val STRING)")
-    stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test_null")
-
-    val rs = stmt.executeQuery("SELECT * FROM test_null WHERE key IS NULL")
-    var count = 0
-    while (rs.next()) {
-      count += 1
-    }
-    assert(count === 5)
-
-    stmt.close()
-  }
+    startThriftServerWithin() { statement =>
+      val dataFilePath =
+        Thread.currentThread().getContextClassLoader.getResource(
+          "data/files/small_kv_with_null.txt")
 
-  def getConnection: Connection = {
-    val connectURI = s"jdbc:hive2://localhost:$PORT/"
-    DriverManager.getConnection(connectURI, System.getProperty("user.name"), "")
-  }
+      val queries = Seq(
+        "DROP TABLE IF EXISTS test_null",
+        "CREATE TABLE test_null(key INT, val STRING)",
+        s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test_null")
 
-  def createStatement(): Statement = getConnection.createStatement()
+      queries.foreach(statement.execute)
+
+      val resultSet = statement.executeQuery("SELECT * FROM test_null WHERE key IS NULL")
+
+      (0 until 5).foreach { _ =>
+        resultSet.next()
+        assert(resultSet.getInt(1) === 0)
+        assert(resultSet.wasNull())
+      }
+
+      assert(!resultSet.next())
+    }
+  }
 }
TestUtils.scala (deleted by this commit)

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.hive.thriftserver

import java.io.{BufferedReader, PrintWriter}
import java.text.SimpleDateFormat
import java.util.Date

import org.apache.hadoop.hive.common.LogUtils
import org.apache.hadoop.hive.common.LogUtils.LogInitializationException

object TestUtils {
  val timestamp = new SimpleDateFormat("yyyyMMdd-HHmmss")

  def getWarehousePath(prefix: String): String = {
    System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-warehouse-" +
      timestamp.format(new Date)
  }

  def getMetastorePath(prefix: String): String = {
    System.getProperty("user.dir") + "/test_warehouses/" + prefix + "-metastore-" +
      timestamp.format(new Date)
  }

  // Dummy function for initialize the log4j properties.
  def init() { }

  // initialize log4j
  try {
    LogUtils.initHiveLog4j()
  } catch {
    case e: LogInitializationException => // Ignore the error.
  }
}

trait TestUtils {
  var process : Process = null
  var outputWriter : PrintWriter = null
  var inputReader : BufferedReader = null
  var errorReader : BufferedReader = null

  def executeQuery(
    cmd: String, outputMessage: String = "OK", timeout: Long = 15000): String = {
    println("Executing: " + cmd + ", expecting output: " + outputMessage)
    outputWriter.write(cmd + "\n")
    outputWriter.flush()
    waitForQuery(timeout, outputMessage)
  }

  protected def waitForQuery(timeout: Long, message: String): String = {
    if (waitForOutput(errorReader, message, timeout)) {
      Thread.sleep(500)
      readOutput()
    } else {
      assert(false, "Didn't find \"" + message + "\" in the output:\n" + readOutput())
      null
    }
  }

  // Wait for the specified str to appear in the output.
  protected def waitForOutput(
    reader: BufferedReader, str: String, timeout: Long = 10000): Boolean = {
    val startTime = System.currentTimeMillis
    var out = ""
    while (!out.contains(str) && System.currentTimeMillis < (startTime + timeout)) {
      out += readFrom(reader)
    }
    out.contains(str)
  }

  // Read stdout output and filter out garbage collection messages.
  protected def readOutput(): String = {
    val output = readFrom(inputReader)
    // Remove GC Messages
    val filteredOutput = output.lines.filterNot(x => x.contains("[GC") || x.contains("[Full GC"))
      .mkString("\n")
    filteredOutput
  }

  protected def readFrom(reader: BufferedReader): String = {
    var out = ""
    var c = 0
    while (reader.ready) {
      c = reader.read()
      out += c.asInstanceOf[Char]
    }
    out
  }

  protected def getDataFile(name: String) = {
    Thread.currentThread().getContextClassLoader.getResource(name)
  }
}