Commit e1ac2a95 authored by Rene Treffer, committed by Michael Armbrust

[SPARK-6888] [SQL] Make the jdbc driver handling user-definable

Replace the DriverQuirks with JdbcDialect(s) (and MySQLDialect/PostgresDialect)
and allow developers to change the dialects on the fly (for new JDBCRDDs only).

Some types (like an unsigned 64-bit number) cannot be trivially mapped to Java;
the status quo is that the RDD will fail to load. This patch makes it possible
to override the type mapping and read e.g. 64-bit numbers as strings, handling
them afterwards in software.

JDBCSuite has an example that maps all types to String, which should always
work (at the cost of extra code afterwards).

As a side effect it should now be possible to develop simple dialects
out-of-tree and even with spark-shell.

Author: Rene Treffer <treffer@measite.de>

Closes #5555 from rtreffer/jdbc-dialects and squashes the following commits:

3cbafd7 [Rene Treffer] [SPARK-6888] ignore classes belonging to changed API in MIMA report
fe7e2e8 [Rene Treffer] [SPARK-6888] Make the jdbc driver handling user-definable
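The new API makes the motivating case easy to handle from user code. A sketch of the idea (not code from this patch; the type name "BIGINT UNSIGNED" is what the MySQL driver reports for unsigned 64-bit columns):

import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.types._

// Hypothetical dialect: read unsigned 64-bit columns as strings so that
// values above Long.MaxValue no longer make the load fail.
val unsignedAsString = new JdbcDialect {
  def canHandle(url: String): Boolean = url.startsWith("jdbc:mysql")
  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
    if (typeName.equals("BIGINT UNSIGNED")) Some(StringType) else None
}

// Register before creating the DataFrame; dialects apply to new RDDs only.
JdbcDialects.registerDialect(unsignedAsString)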
@@ -137,6 +137,14 @@ object MimaExcludes {
// implementing this interface in Java. Note that ShuffleWriter is private[spark].
ProblemFilters.exclude[IncompatibleTemplateDefProblem](
"org.apache.spark.shuffle.ShuffleWriter")
) ++ Seq(
// SPARK-6888 make jdbc driver handling user definable
// This patch renames some classes to API friendly names.
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.PostgresQuirks"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.NoQuirks"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.MySQLQuirks")
)
case v if v.startsWith("1.3") =>
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.jdbc
import org.apache.spark.sql.types._
import java.sql.Types
/**
* Encapsulates workarounds for the extensions, quirks, and bugs in various
* databases. Lots of databases define types that aren't explicitly supported
* by the JDBC spec. Some JDBC drivers also report inaccurate
* information---for instance, BIT(n>1) being reported as a BIT type is quite
* common, even though BIT in JDBC is meant for single-bit values. Also, there
* does not appear to be a standard name for an unbounded string or binary
* type; we use BLOB and CLOB by default but override with database-specific
* alternatives when these are absent or do not behave correctly.
*
* Currently, the only thing DriverQuirks does is handle type mapping.
* `getCatalystType` is used when reading from a JDBC table and `getJDBCType`
* is used when writing to a JDBC table. If `getCatalystType` returns `null`,
* the default type handling is used for the given JDBC type. Similarly,
* if `getJDBCType` returns `(null, None)`, the default type handling is used
* for the given Catalyst type.
*/
private[sql] abstract class DriverQuirks {
def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType
def getJDBCType(dt: DataType): (String, Option[Int])
}
private[sql] object DriverQuirks {
/**
* Fetch the DriverQuirks class corresponding to a given database url.
*/
def get(url: String): DriverQuirks = {
if (url.startsWith("jdbc:mysql")) {
new MySQLQuirks()
} else if (url.startsWith("jdbc:postgresql")) {
new PostgresQuirks()
} else {
new NoQuirks()
}
}
}
private[sql] class NoQuirks extends DriverQuirks {
def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType =
null
def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None)
}
private[sql] class PostgresQuirks extends DriverQuirks {
def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
BinaryType
} else if (sqlType == Types.OTHER && typeName.equals("cidr")) {
StringType
} else if (sqlType == Types.OTHER && typeName.equals("inet")) {
StringType
} else null
}
def getJDBCType(dt: DataType): (String, Option[Int]) = dt match {
case StringType => ("TEXT", Some(java.sql.Types.CHAR))
case BinaryType => ("BYTEA", Some(java.sql.Types.BINARY))
case BooleanType => ("BOOLEAN", Some(java.sql.Types.BOOLEAN))
case _ => (null, None)
}
}
private[sql] class MySQLQuirks extends DriverQuirks {
def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
// This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as
// byte arrays instead of longs.
md.putLong("binarylong", 1)
LongType
} else if (sqlType == Types.BIT && typeName.equals("TINYINT")) {
BooleanType
} else null
}
def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None)
}
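The contract change from the old quirks API to the new dialect API is mechanical; the following side-by-side summary (not part of the patch, derived from the signatures above and below) may help when porting custom quirks:

// Old DriverQuirks contract                  New JdbcDialect contract
// getCatalystType(...): DataType             getCatalystType(...): Option[DataType]
//   null         => use default mapping        None => use default mapping
// getJDBCType(dt): (String, Option[Int])     getJDBCType(dt): Option[JdbcType]
//   (null, None) => use default mapping        None => use default mapping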
@@ -41,7 +41,7 @@ private[sql] object JDBCRDD extends Logging {
/**
* Maps a JDBC type to a Catalyst type. This function is called only when
-   * the DriverQuirks class corresponding to your database driver returns null.
+   * the JdbcDialect class corresponding to your database driver returns null.
*
* @param sqlType - A field of java.sql.Types
* @return The Catalyst type corresponding to sqlType.
@@ -51,7 +51,7 @@ private[sql] object JDBCRDD extends Logging {
case java.sql.Types.ARRAY => null
case java.sql.Types.BIGINT => LongType
case java.sql.Types.BINARY => BinaryType
-    case java.sql.Types.BIT => BooleanType // Per JDBC; Quirks handles quirky drivers.
+    case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks
case java.sql.Types.BLOB => BinaryType
case java.sql.Types.BOOLEAN => BooleanType
case java.sql.Types.CHAR => StringType
@@ -108,7 +108,7 @@ private[sql] object JDBCRDD extends Logging {
* @throws SQLException if the table contains an unsupported type.
*/
def resolveTable(url: String, table: String, properties: Properties): StructType = {
-    val quirks = DriverQuirks.get(url)
+    val dialect = JdbcDialects.get(url)
val conn: Connection = DriverManager.getConnection(url, properties)
try {
val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery()
@@ -125,8 +125,9 @@ private[sql] object JDBCRDD extends Logging {
val fieldScale = rsmd.getScale(i + 1)
val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
val metadata = new MetadataBuilder().putString("name", columnName)
-      var columnType = quirks.getCatalystType(dataType, typeName, fieldSize, metadata)
-      if (columnType == null) columnType = getCatalystType(dataType, fieldSize, fieldScale)
+      val columnType =
+        dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
+          getCatalystType(dataType, fieldSize, fieldScale))
fields(i) = StructField(columnName, columnType, nullable, metadata.build())
i = i + 1
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.jdbc
import org.apache.spark.sql.types._
import org.apache.spark.annotation.DeveloperApi
import java.sql.Types
/**
* :: DeveloperApi ::
* A database type definition coupled with the jdbc type needed to send null
* values to the database.
* @param databaseTypeDefinition The database type definition
* @param jdbcNullType The jdbc type (as defined in java.sql.Types) used to
* send a null value to the database.
*/
@DeveloperApi
case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int)
/**
* :: DeveloperApi ::
* Encapsulates everything (extensions, workarounds, quirks) to handle the
* SQL dialect of a certain database or jdbc driver.
* Lots of databases define types that aren't explicitly supported
* by the JDBC spec. Some JDBC drivers also report inaccurate
* information---for instance, BIT(n>1) being reported as a BIT type is quite
* common, even though BIT in JDBC is meant for single-bit values. Also, there
* does not appear to be a standard name for an unbounded string or binary
* type; we use BLOB and CLOB by default but override with database-specific
* alternatives when these are absent or do not behave correctly.
*
* Currently, the only thing done by the dialect is type mapping.
* `getCatalystType` is used when reading from a JDBC table and `getJDBCType`
* is used when writing to a JDBC table. If `getCatalystType` returns `None`,
* the default type handling is used for the given JDBC type. Similarly,
* if `getJDBCType` returns `None`, the default type handling is used
* for the given Catalyst type.
*/
@DeveloperApi
abstract class JdbcDialect {
/**
* Check if this dialect instance can handle a certain jdbc url.
* @param url the jdbc url.
* @return True if the dialect can be applied on the given jdbc url.
* @throws NullPointerException if the url is null.
*/
def canHandle(url : String): Boolean
/**
* Get the custom datatype mapping for the given jdbc meta information.
* @param sqlType The sql type (see java.sql.Types)
* @param typeName The sql type name (e.g. "BIGINT UNSIGNED")
* @param size The size of the type.
* @param md Result metadata associated with this type.
* @return The actual DataType (subclasses of [[org.apache.spark.sql.types.DataType]])
* or None if the default type mapping should be used.
*/
def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = None
/**
* Retrieve the jdbc / sql type for a given datatype.
* @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
* @return The new JdbcType if there is an override for this DataType, or None otherwise.
*/
def getJDBCType(dt: DataType): Option[JdbcType] = None
}
/**
* :: DeveloperApi ::
* Registry of dialects that apply to every new jdbc [[org.apache.spark.sql.DataFrame]].
*
* If multiple matching dialects are registered then they are tried in the
* reverse order of registration, so a user-added dialect is consulted first
* and overrides the defaults.
*
* Note that all new dialects are applied to new jdbc DataFrames only. Make
* sure to register your dialects first.
*/
@DeveloperApi
object JdbcDialects {
private var dialects = List[JdbcDialect]()
/**
* Register a dialect for use on all new matching jdbc [[org.apache.spark.sql.DataFrame]].
* Re-registering an existing dialect moves it to the front of the list.
* @param dialect The new dialect.
*/
def registerDialect(dialect: JdbcDialect) : Unit = {
dialects = dialect :: dialects.filterNot(_ == dialect)
}
/**
* Unregister a dialect. Does nothing if the dialect is not registered.
* @param dialect The jdbc dialect.
*/
def unregisterDialect(dialect : JdbcDialect) : Unit = {
dialects = dialects.filterNot(_ == dialect)
}
registerDialect(MySQLDialect)
registerDialect(PostgresDialect)
/**
* Fetch the JdbcDialect class corresponding to a given database url.
*/
private[sql] def get(url: String): JdbcDialect = {
val matchingDialects = dialects.filter(_.canHandle(url))
matchingDialects.length match {
case 0 => NoopDialect
case 1 => matchingDialects.head
case _ => new AggregatedDialect(matchingDialects)
}
}
}
/**
* :: DeveloperApi ::
* AggregatedDialect can unify multiple dialects into one virtual Dialect.
* Dialects are tried in order, and the first dialect that does not return a
* neutral element wins.
* @param dialects List of dialects.
*/
@DeveloperApi
class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect {
require(!dialects.isEmpty)
def canHandle(url : String): Boolean =
dialects.map(_.canHandle(url)).reduce(_ && _)
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
dialects.map(_.getCatalystType(sqlType, typeName, size, md)).flatten.headOption
override def getJDBCType(dt: DataType): Option[JdbcType] =
dialects.map(_.getJDBCType(dt)).flatten.headOption
}
/**
* :: DeveloperApi ::
* NOOP dialect object, always returning the neutral element.
*/
@DeveloperApi
case object NoopDialect extends JdbcDialect {
def canHandle(url : String): Boolean = true
}
/**
* :: DeveloperApi ::
* Default postgres dialect, mapping bit/cidr/inet on read and string/binary/boolean on write.
*/
@DeveloperApi
case object PostgresDialect extends JdbcDialect {
def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
Some(BinaryType)
} else if (sqlType == Types.OTHER && typeName.equals("cidr")) {
Some(StringType)
} else if (sqlType == Types.OTHER && typeName.equals("inet")) {
Some(StringType)
} else None
}
override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR))
case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY))
case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN))
case _ => None
}
}
/**
* :: DeveloperApi ::
* Default mysql dialect to read bit/bitsets correctly.
*/
@DeveloperApi
case object MySQLDialect extends JdbcDialect {
def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql")
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
// This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as
// byte arrays instead of longs.
md.putLong("binarylong", 1)
Some(LongType)
} else if (sqlType == Types.BIT && typeName.equals("TINYINT")) {
Some(BooleanType)
} else None
}
}
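A note on why JdbcType pairs the type definition with a null type: JDBC's PreparedStatement.setNull requires an explicit java.sql.Types constant, which cannot in general be derived from the definition string. A minimal sketch (the statement and position are assumed; this is not code from the patch):

import java.sql.PreparedStatement
import org.apache.spark.sql.jdbc.JdbcType

// Per PostgresDialect above, TEXT columns are bound as CHAR when NULL.
val textType = JdbcType("TEXT", java.sql.Types.CHAR)

def bindNullableString(stmt: PreparedStatement, pos: Int, value: String): Unit =
  if (value == null) stmt.setNull(pos, textType.jdbcNullType)
  else stmt.setString(pos, value)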
@@ -129,25 +129,26 @@ package object jdbc {
*/
def schemaString(df: DataFrame, url: String): String = {
val sb = new StringBuilder()
-    val quirks = DriverQuirks.get(url)
+    val dialect = JdbcDialects.get(url)
df.schema.fields foreach { field => {
val name = field.name
-      var typ: String = quirks.getJDBCType(field.dataType)._1
-      if (typ == null) typ = field.dataType match {
-        case IntegerType => "INTEGER"
-        case LongType => "BIGINT"
-        case DoubleType => "DOUBLE PRECISION"
-        case FloatType => "REAL"
-        case ShortType => "INTEGER"
-        case ByteType => "BYTE"
-        case BooleanType => "BIT(1)"
-        case StringType => "TEXT"
-        case BinaryType => "BLOB"
-        case TimestampType => "TIMESTAMP"
-        case DateType => "DATE"
-        case DecimalType.Unlimited => "DECIMAL(40,20)"
-        case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
-      }
+      val typ: String =
+        dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
+          field.dataType match {
+            case IntegerType => "INTEGER"
+            case LongType => "BIGINT"
+            case DoubleType => "DOUBLE PRECISION"
+            case FloatType => "REAL"
+            case ShortType => "INTEGER"
+            case ByteType => "BYTE"
+            case BooleanType => "BIT(1)"
+            case StringType => "TEXT"
+            case BinaryType => "BLOB"
+            case TimestampType => "TIMESTAMP"
+            case DateType => "DATE"
+            case DecimalType.Unlimited => "DECIMAL(40,20)"
+            case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
+          })
val nullable = if (field.nullable) "" else "NOT NULL"
sb.append(s", $name $typ $nullable")
}}
@@ -162,10 +163,9 @@ package object jdbc {
url: String,
table: String,
properties: Properties = new Properties()) {
-    val quirks = DriverQuirks.get(url)
+    val dialect = JdbcDialects.get(url)
val nullTypes: Array[Int] = df.schema.fields.map { field =>
-      val nullType: Option[Int] = quirks.getJDBCType(field.dataType)._2
-      if (nullType.isEmpty) {
+      dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
field.dataType match {
case IntegerType => java.sql.Types.INTEGER
case LongType => java.sql.Types.BIGINT
@@ -181,8 +181,7 @@ package object jdbc {
case DecimalType.Unlimited => java.sql.Types.DECIMAL
case _ => throw new IllegalArgumentException(
s"Can't translate null value for field $field")
-        }
-      } else nullType.get
+        })
}
val rddSchema = df.schema
......
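Both the read path and the write path above follow the same pattern: consult the dialect first, fall back to the generic mapping otherwise. A small illustration against the default registrations (assumed to run with MySQLDialect and PostgresDialect in place):

import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.types.BooleanType

val dialect = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db")
// PostgresDialect overrides BooleanType to BOOLEAN; without the override
// the generic fallback in schemaString would emit BIT(1).
assert(dialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition)
  == Some("BOOLEAN"))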
@@ -35,6 +35,13 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte)
val testH2Dialect = new JdbcDialect {
def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2")
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
Some(StringType)
}
before {
Class.forName("org.h2.Driver")
// Extra properties that will be specified for our database. We need these to test
@@ -353,4 +360,46 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
""".stripMargin.replaceAll("\n", " "))
}
}
test("Remap types via JdbcDialects") {
JdbcDialects.registerDialect(testH2Dialect)
val df = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties)
assert(df.schema.filter(
_.dataType != org.apache.spark.sql.types.StringType
).isEmpty)
val rows = df.collect()
assert(rows(0).get(0).isInstanceOf[String])
assert(rows(0).get(1).isInstanceOf[String])
JdbcDialects.unregisterDialect(testH2Dialect)
}
test("Default jdbc dialect registration") {
assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") == MySQLDialect)
assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == PostgresDialect)
assert(JdbcDialects.get("test.invalid") == NoopDialect)
}
test("Dialect unregister") {
JdbcDialects.registerDialect(testH2Dialect)
JdbcDialects.unregisterDialect(testH2Dialect)
assert(JdbcDialects.get(urlWithUserAndPass) == NoopDialect)
}
test("Aggregated dialects") {
val agg = new AggregatedDialect(List(new JdbcDialect {
def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:")
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
if (sqlType % 2 == 0) {
Some(LongType)
} else {
None
}
}, testH2Dialect))
assert(agg.canHandle("jdbc:h2:xxx"))
assert(!agg.canHandle("jdbc:h2"))
assert(agg.getCatalystType(0,"",1,null) == Some(LongType))
assert(agg.getCatalystType(1,"",1,null) == Some(StringType))
}
}