Commit 8de038eb authored by Michael Armbrust, committed by Reynold Xin

[SQL] SPARK-1366 Consistent sql function across different types of SQLContexts

Now users who want to use HiveQL should explicitly say `hiveql` or `hql`.

Author: Michael Armbrust <michael@databricks.com>

Closes #319 from marmbrus/standardizeSqlHql and squashes the following commits:

de68d0e [Michael Armbrust] Fix sampling test.
fbe4a54 [Michael Armbrust] Make `sql` always use spark sql parser, users of hive context can now use hql or hiveql to run queries using HiveQL instead.
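
For illustration only (not part of this patch): a minimal sketch of what the split looks like to callers after this change, assuming a SparkContext named `sc` is already in scope; the table name and query strings mirror the example and tests touched below.

  import org.apache.spark.sql.hive.LocalHiveContext

  val hiveContext = new LocalHiveContext(sc)
  import hiveContext._

  // `sql` now always goes through the Spark SQL parser, even on a Hive context.
  sql("SELECT 1").collect().foreach(println)

  // HiveQL must be requested explicitly, via `hiveql` or its shorter alias `hql`.
  hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
  hiveql("FROM src SELECT key").collect().foreach(println)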
parent b50ddfde
@@ -33,20 +33,20 @@ object HiveFromSpark {
val hiveContext = new LocalHiveContext(sc)
import hiveContext._
sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
sql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src")
hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
hql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src")
// Queries are expressed in HiveQL
println("Result of 'SELECT *': ")
sql("SELECT * FROM src").collect.foreach(println)
hql("SELECT * FROM src").collect.foreach(println)
// Aggregation queries are also supported.
- val count = sql("SELECT COUNT(*) FROM src").collect().head.getInt(0)
+ val count = hql("SELECT COUNT(*) FROM src").collect().head.getInt(0)
println(s"COUNT(*): $count")
// The results of SQL queries are themselves RDDs and support all normal RDD functions. The
// items in the RDD are of type Row, which allows you to access each column by ordinal.
- val rddFromSql = sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key")
+ val rddFromSql = hql("SELECT key, value FROM src WHERE key < 10 ORDER BY key")
println("Result of RDD.map:")
val rddAsStrings = rddFromSql.map {
@@ -59,6 +59,6 @@ object HiveFromSpark {
// Queries can then join RDD data with data stored in Hive.
println("Result of SELECT *:")
sql("SELECT * FROM records r JOIN src s ON r.key = s.key").collect().foreach(println)
hql("SELECT * FROM records r JOIN src s ON r.key = s.key").collect().foreach(println)
}
}
@@ -67,14 +67,13 @@ class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) {
class HiveContext(sc: SparkContext) extends SQLContext(sc) {
self =>
- override def parseSql(sql: String): LogicalPlan = HiveQl.parseSql(sql)
- override def executePlan(plan: LogicalPlan): this.QueryExecution =
+ override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution =
new this.QueryExecution { val logical = plan }
/**
* Executes a query expressed in HiveQL using Spark, returning the result as a SchemaRDD.
*/
- def hql(hqlQuery: String): SchemaRDD = {
+ def hiveql(hqlQuery: String): SchemaRDD = {
val result = new SchemaRDD(this, HiveQl.parseSql(hqlQuery))
// We force query optimization to happen right away instead of letting it happen lazily like
// when using the query DSL. This is so DDL commands behave as expected. This is only
@@ -83,6 +82,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
result
}
+ /** An alias for `hiveql`. */
+ def hql(hqlQuery: String): SchemaRDD = hiveql(hqlQuery)
// Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur.
@transient
protected val outputBuffer = new java.io.OutputStream {
@@ -120,7 +122,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
/* A catalyst metadata catalog that points to the Hive Metastore. */
@transient
- override lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog {
+ override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog {
override def lookupRelation(
databaseName: Option[String],
tableName: String,
@@ -132,7 +134,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
/* An analyzer that uses the Hive metastore. */
@transient
- override lazy val analyzer = new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false)
+ override protected[sql] lazy val analyzer =
+   new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false)
/**
* Runs the specified SQL query using Hive.
@@ -214,14 +217,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
}
@transient
- override val planner = hivePlanner
+ override protected[sql] val planner = hivePlanner
@transient
protected lazy val emptyResult =
sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1)
/** Extends QueryExecution with hive specific features. */
- abstract class QueryExecution extends super.QueryExecution {
+ protected[sql] abstract class QueryExecution extends super.QueryExecution {
// TODO: Create mixin for the analyzer instead of overriding things here.
override lazy val optimizedPlan =
optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed)))
@@ -110,10 +110,10 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) {
val describedTable = "DESCRIBE (\\w+)".r
- class SqlQueryExecution(sql: String) extends this.QueryExecution {
-   lazy val logical = HiveQl.parseSql(sql)
-   def hiveExec() = runSqlHive(sql)
-   override def toString = sql + "\n" + super.toString
+ protected[hive] class HiveQLQueryExecution(hql: String) extends this.QueryExecution {
+   lazy val logical = HiveQl.parseSql(hql)
+   def hiveExec() = runSqlHive(hql)
+   override def toString = hql + "\n" + super.toString
}
/**
@@ -140,8 +140,8 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) {
case class TestTable(name: String, commands: (()=>Unit)*)
- implicit class SqlCmd(sql: String) {
-   def cmd = () => new SqlQueryExecution(sql).stringResult(): Unit
+ protected[hive] implicit class SqlCmd(sql: String) {
+   def cmd = () => new HiveQLQueryExecution(sql).stringResult(): Unit
}
/**
@@ -125,7 +125,7 @@ abstract class HiveComparisonTest
}
protected def prepareAnswer(
- hiveQuery: TestHive.type#SqlQueryExecution,
+ hiveQuery: TestHive.type#HiveQLQueryExecution,
answer: Seq[String]): Seq[String] = {
val orderedAnswer = hiveQuery.logical match {
// Clean out non-deterministic time schema info.
@@ -227,7 +227,7 @@ abstract class HiveComparisonTest
try {
// MINOR HACK: You must run a query before calling reset the first time.
- TestHive.sql("SHOW TABLES")
+ TestHive.hql("SHOW TABLES")
if (reset) { TestHive.reset() }
val hiveCacheFiles = queryList.zipWithIndex.map {
@@ -256,7 +256,7 @@ abstract class HiveComparisonTest
hiveCachedResults
} else {
- val hiveQueries = queryList.map(new TestHive.SqlQueryExecution(_))
+ val hiveQueries = queryList.map(new TestHive.HiveQLQueryExecution(_))
// Make sure we can at least parse everything before attempting hive execution.
hiveQueries.foreach(_.logical)
val computedResults = (queryList.zipWithIndex, hiveQueries, hiveCacheFiles).zipped.map {
@@ -302,7 +302,7 @@ abstract class HiveComparisonTest
// Run w/ catalyst
val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) =>
- val query = new TestHive.SqlQueryExecution(queryString)
+ val query = new TestHive.HiveQLQueryExecution(queryString)
try { (query, prepareAnswer(query, query.stringResult())) } catch {
case e: Exception =>
val errorMessage =
@@ -359,7 +359,7 @@ abstract class HiveComparisonTest
// When we encounter an error we check to see if the environment is still okay by running a simple query.
// If this fails then we halt testing since something must have gone seriously wrong.
try {
new TestHive.SqlQueryExecution("SELECT key FROM src").stringResult()
new TestHive.HiveQLQueryExecution("SELECT key FROM src").stringResult()
TestHive.runSqlHive("SELECT key FROM src")
} catch {
case e: Exception =>
@@ -23,6 +23,16 @@ import org.apache.spark.sql.hive.TestHive._
* A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution.
*/
class HiveQuerySuite extends HiveComparisonTest {
test("Query expressed in SQL") {
assert(sql("SELECT 1").collect() === Array(Seq(1)))
}
test("Query expressed in HiveQL") {
hql("FROM src SELECT key").collect()
hiveql("FROM src SELECT key").collect()
}
createQueryTest("Simple Average",
"SELECT AVG(key) FROM src")
@@ -133,7 +143,7 @@ class HiveQuerySuite extends HiveComparisonTest {
"SELECT * FROM src LATERAL VIEW explode(map(key+3,key+4)) D as k, v")
test("sampling") {
sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s")
hql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s")
}
}
@@ -56,7 +56,7 @@ class HiveResolutionSuite extends HiveComparisonTest {
TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2)) :: Nil)
.registerAsTable("caseSensitivityTest")
sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
hql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
}
/**
@@ -136,7 +136,7 @@ class PruningSuite extends HiveComparisonTest {
expectedScannedColumns: Seq[String],
expectedPartValues: Seq[Seq[String]]) = {
test(s"$testCaseName - pruning test") {
- val plan = new TestHive.SqlQueryExecution(sql).executedPlan
+ val plan = new TestHive.HiveQLQueryExecution(sql).executedPlan
val actualOutputColumns = plan.output.map(_.name)
val (actualScannedColumns, actualPartValues) = plan.collect {
case p @ HiveTableScan(columns, relation, _) =>
@@ -57,34 +57,34 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft
}
test("SELECT on Parquet table") {
val rdd = sql("SELECT * FROM testsource").collect()
val rdd = hql("SELECT * FROM testsource").collect()
assert(rdd != null)
assert(rdd.forall(_.size == 6))
}
test("Simple column projection + filter on Parquet table") {
val rdd = sql("SELECT myboolean, mylong FROM testsource WHERE myboolean=true").collect()
val rdd = hql("SELECT myboolean, mylong FROM testsource WHERE myboolean=true").collect()
assert(rdd.size === 5, "Filter returned incorrect number of rows")
assert(rdd.forall(_.getBoolean(0)), "Filter returned incorrect Boolean field value")
}
test("Converting Hive to Parquet Table via saveAsParquetFile") {
sql("SELECT * FROM src").saveAsParquetFile(dirname.getAbsolutePath)
hql("SELECT * FROM src").saveAsParquetFile(dirname.getAbsolutePath)
parquetFile(dirname.getAbsolutePath).registerAsTable("ptable")
- val rddOne = sql("SELECT * FROM src").collect().sortBy(_.getInt(0))
- val rddTwo = sql("SELECT * from ptable").collect().sortBy(_.getInt(0))
+ val rddOne = hql("SELECT * FROM src").collect().sortBy(_.getInt(0))
+ val rddTwo = hql("SELECT * from ptable").collect().sortBy(_.getInt(0))
compareRDDs(rddOne, rddTwo, "src (Hive)", Seq("key:Int", "value:String"))
}
test("INSERT OVERWRITE TABLE Parquet table") {
sql("SELECT * FROM testsource").saveAsParquetFile(dirname.getAbsolutePath)
hql("SELECT * FROM testsource").saveAsParquetFile(dirname.getAbsolutePath)
parquetFile(dirname.getAbsolutePath).registerAsTable("ptable")
// let's do three overwrites for good measure
sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
sql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
val rddCopy = sql("SELECT * FROM ptable").collect()
val rddOrig = sql("SELECT * FROM testsource").collect()
hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
val rddCopy = hql("SELECT * FROM ptable").collect()
val rddOrig = hql("SELECT * FROM testsource").collect()
assert(rddCopy.size === rddOrig.size, "INSERT OVERWRITE changed size of table??")
compareRDDs(rddOrig, rddCopy, "testsource", ParquetTestData.testSchemaFieldNames)
}
@@ -93,13 +93,13 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft
createParquetFile(dirname.getAbsolutePath, ("key", IntegerType), ("value", StringType))
.registerAsTable("tmp")
val rddCopy =
sql("INSERT INTO TABLE tmp SELECT * FROM src")
hql("INSERT INTO TABLE tmp SELECT * FROM src")
.collect()
.sortBy[Int](_.apply(0) match {
case x: Int => x
case _ => 0
})
- val rddOrig = sql("SELECT * FROM src")
+ val rddOrig = hql("SELECT * FROM src")
.collect()
.sortBy(_.getInt(0))
compareRDDs(rddOrig, rddCopy, "src (Hive)", Seq("key:Int", "value:String"))
@@ -108,22 +108,22 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft
test("Appending to Parquet table") {
createParquetFile(dirname.getAbsolutePath, ("key", IntegerType), ("value", StringType))
.registerAsTable("tmpnew")
sql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
sql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
sql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
val rddCopies = sql("SELECT * FROM tmpnew").collect()
val rddOrig = sql("SELECT * FROM src").collect()
hql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
hql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
hql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
val rddCopies = hql("SELECT * FROM tmpnew").collect()
val rddOrig = hql("SELECT * FROM src").collect()
assert(rddCopies.size === 3 * rddOrig.size, "number of copied rows via INSERT INTO did not match correct number")
}
test("Appending to and then overwriting Parquet table") {
createParquetFile(dirname.getAbsolutePath, ("key", IntegerType), ("value", StringType))
.registerAsTable("tmp")
sql("INSERT INTO TABLE tmp SELECT * FROM src").collect()
sql("INSERT INTO TABLE tmp SELECT * FROM src").collect()
sql("INSERT OVERWRITE TABLE tmp SELECT * FROM src").collect()
val rddCopies = sql("SELECT * FROM tmp").collect()
val rddOrig = sql("SELECT * FROM src").collect()
hql("INSERT INTO TABLE tmp SELECT * FROM src").collect()
hql("INSERT INTO TABLE tmp SELECT * FROM src").collect()
hql("INSERT OVERWRITE TABLE tmp SELECT * FROM src").collect()
val rddCopies = hql("SELECT * FROM tmp").collect()
val rddOrig = hql("SELECT * FROM src").collect()
assert(rddCopies.size === rddOrig.size, "INSERT OVERWRITE did not actually overwrite")
}