diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 487062a31f77f310363146128343b7b20f082c77..513bbaf98d804693c034305c8219f62967bf644d 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -137,6 +137,14 @@ object MimaExcludes { // implementing this interface in Java. Note that ShuffleWriter is private[spark]. ProblemFilters.exclude[IncompatibleTemplateDefProblem]( "org.apache.spark.shuffle.ShuffleWriter") + ) ++ Seq( + // SPARK-6888 make jdbc driver handling user definable + // This patch renames some classes to API friendly names. + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.PostgresQuirks"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.NoQuirks"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.MySQLQuirks") ) case v if v.startsWith("1.3") => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala deleted file mode 100644 index 0feabc4282f4a24d8e406ac178b09a8ac7219f85..0000000000000000000000000000000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc - -import org.apache.spark.sql.types._ - -import java.sql.Types - - -/** - * Encapsulates workarounds for the extensions, quirks, and bugs in various - * databases. Lots of databases define types that aren't explicitly supported - * by the JDBC spec. Some JDBC drivers also report inaccurate - * information---for instance, BIT(n>1) being reported as a BIT type is quite - * common, even though BIT in JDBC is meant for single-bit values. Also, there - * does not appear to be a standard name for an unbounded string or binary - * type; we use BLOB and CLOB by default but override with database-specific - * alternatives when these are absent or do not behave correctly. - * - * Currently, the only thing DriverQuirks does is handle type mapping. - * `getCatalystType` is used when reading from a JDBC table and `getJDBCType` - * is used when writing to a JDBC table. If `getCatalystType` returns `null`, - * the default type handling is used for the given JDBC type. Similarly, - * if `getJDBCType` returns `(null, None)`, the default type handling is used - * for the given Catalyst type. - */ -private[sql] abstract class DriverQuirks { - def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType - def getJDBCType(dt: DataType): (String, Option[Int]) -} - -private[sql] object DriverQuirks { - /** - * Fetch the DriverQuirks class corresponding to a given database url. - */ - def get(url: String): DriverQuirks = { - if (url.startsWith("jdbc:mysql")) { - new MySQLQuirks() - } else if (url.startsWith("jdbc:postgresql")) { - new PostgresQuirks() - } else { - new NoQuirks() - } - } -} - -private[sql] class NoQuirks extends DriverQuirks { - def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = - null - def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None) -} - -private[sql] class PostgresQuirks extends DriverQuirks { - def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = { - if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { - BinaryType - } else if (sqlType == Types.OTHER && typeName.equals("cidr")) { - StringType - } else if (sqlType == Types.OTHER && typeName.equals("inet")) { - StringType - } else null - } - - def getJDBCType(dt: DataType): (String, Option[Int]) = dt match { - case StringType => ("TEXT", Some(java.sql.Types.CHAR)) - case BinaryType => ("BYTEA", Some(java.sql.Types.BINARY)) - case BooleanType => ("BOOLEAN", Some(java.sql.Types.BOOLEAN)) - case _ => (null, None) - } -} - -private[sql] class MySQLQuirks extends DriverQuirks { - def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = { - if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { - // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as - // byte arrays instead of longs. - md.putLong("binarylong", 1) - LongType - } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) { - BooleanType - } else null - } - def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 4189dfcf956c04698cd2c153f4e7a397f85b72c3..f7b19096eaacb84a8fa43373e1a5a1f599e5807e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -41,7 +41,7 @@ private[sql] object JDBCRDD extends Logging { /** * Maps a JDBC type to a Catalyst type. This function is called only when - * the DriverQuirks class corresponding to your database driver returns null. + * the JdbcDialect class corresponding to your database driver returns null. * * @param sqlType - A field of java.sql.Types * @return The Catalyst type corresponding to sqlType. @@ -51,7 +51,7 @@ private[sql] object JDBCRDD extends Logging { case java.sql.Types.ARRAY => null case java.sql.Types.BIGINT => LongType case java.sql.Types.BINARY => BinaryType - case java.sql.Types.BIT => BooleanType // Per JDBC; Quirks handles quirky drivers. + case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks case java.sql.Types.BLOB => BinaryType case java.sql.Types.BOOLEAN => BooleanType case java.sql.Types.CHAR => StringType @@ -108,7 +108,7 @@ private[sql] object JDBCRDD extends Logging { * @throws SQLException if the table contains an unsupported type. */ def resolveTable(url: String, table: String, properties: Properties): StructType = { - val quirks = DriverQuirks.get(url) + val dialect = JdbcDialects.get(url) val conn: Connection = DriverManager.getConnection(url, properties) try { val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery() @@ -125,8 +125,9 @@ private[sql] object JDBCRDD extends Logging { val fieldScale = rsmd.getScale(i + 1) val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls val metadata = new MetadataBuilder().putString("name", columnName) - var columnType = quirks.getCatalystType(dataType, typeName, fieldSize, metadata) - if (columnType == null) columnType = getCatalystType(dataType, fieldSize, fieldScale) + val columnType = + dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse( + getCatalystType(dataType, fieldSize, fieldScale)) fields(i) = StructField(columnName, columnType, nullable, metadata.build()) i = i + 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala new file mode 100644 index 0000000000000000000000000000000000000000..6a169e106b96821a55fdf85b816d9bcb93a51dfe --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import org.apache.spark.sql.types._ +import org.apache.spark.annotation.DeveloperApi + +import java.sql.Types + +/** + * :: DeveloperApi :: + * A database type definition coupled with the jdbc type needed to send null + * values to the database. + * @param databaseTypeDefinition The database type definition + * @param jdbcNullType The jdbc type (as defined in java.sql.Types) used to + * send a null value to the database. + */ +@DeveloperApi +case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int) + +/** + * :: DeveloperApi :: + * Encapsulates everything (extensions, workarounds, quirks) to handle the + * SQL dialect of a certain database or jdbc driver. + * Lots of databases define types that aren't explicitly supported + * by the JDBC spec. Some JDBC drivers also report inaccurate + * information---for instance, BIT(n>1) being reported as a BIT type is quite + * common, even though BIT in JDBC is meant for single-bit values. Also, there + * does not appear to be a standard name for an unbounded string or binary + * type; we use BLOB and CLOB by default but override with database-specific + * alternatives when these are absent or do not behave correctly. + * + * Currently, the only thing done by the dialect is type mapping. + * `getCatalystType` is used when reading from a JDBC table and `getJDBCType` + * is used when writing to a JDBC table. If `getCatalystType` returns `null`, + * the default type handling is used for the given JDBC type. Similarly, + * if `getJDBCType` returns `(null, None)`, the default type handling is used + * for the given Catalyst type. + */ +@DeveloperApi +abstract class JdbcDialect { + /** + * Check if this dialect instance can handle a certain jdbc url. + * @param url the jdbc url. + * @return True if the dialect can be applied on the given jdbc url. + * @throws NullPointerException if the url is null. + */ + def canHandle(url : String): Boolean + + /** + * Get the custom datatype mapping for the given jdbc meta information. + * @param sqlType The sql type (see java.sql.Types) + * @param typeName The sql type name (e.g. "BIGINT UNSIGNED") + * @param size The size of the type. + * @param md Result metadata associated with this type. + * @return The actual DataType (subclasses of [[org.apache.spark.sql.types.DataType]]) + * or null if the default type mapping should be used. + */ + def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = None + + /** + * Retrieve the jdbc / sql type for a given datatype. + * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]]) + * @return The new JdbcType if there is an override for this DataType + */ + def getJDBCType(dt: DataType): Option[JdbcType] = None +} + +/** + * :: DeveloperApi :: + * Registry of dialects that apply to every new jdbc [[org.apache.spark.sql.DataFrame]]. + * + * If multiple matching dialects are registered then all matching ones will be + * tried in reverse order. A user-added dialect will thus be applied first, + * overwriting the defaults. + * + * Note that all new dialects are applied to new jdbc DataFrames only. Make + * sure to register your dialects first. + */ +@DeveloperApi +object JdbcDialects { + + private var dialects = List[JdbcDialect]() + + /** + * Register a dialect for use on all new matching jdbc [[org.apache.spark.sql.DataFrame]]. + * Readding an existing dialect will cause a move-to-front. + * @param dialect The new dialect. + */ + def registerDialect(dialect: JdbcDialect) : Unit = { + dialects = dialect :: dialects.filterNot(_ == dialect) + } + + /** + * Unregister a dialect. Does nothing if the dialect is not registered. + * @param dialect The jdbc dialect. + */ + def unregisterDialect(dialect : JdbcDialect) : Unit = { + dialects = dialects.filterNot(_ == dialect) + } + + registerDialect(MySQLDialect) + registerDialect(PostgresDialect) + + /** + * Fetch the JdbcDialect class corresponding to a given database url. + */ + private[sql] def get(url: String): JdbcDialect = { + val matchingDialects = dialects.filter(_.canHandle(url)) + matchingDialects.length match { + case 0 => NoopDialect + case 1 => matchingDialects.head + case _ => new AggregatedDialect(matchingDialects) + } + } +} + +/** + * :: DeveloperApi :: + * AggregatedDialect can unify multiple dialects into one virtual Dialect. + * Dialects are tried in order, and the first dialect that does not return a + * neutral element will will. + * @param dialects List of dialects. + */ +@DeveloperApi +class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect { + + require(!dialects.isEmpty) + + def canHandle(url : String): Boolean = + dialects.map(_.canHandle(url)).reduce(_ && _) + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = + dialects.map(_.getCatalystType(sqlType, typeName, size, md)).flatten.headOption + + override def getJDBCType(dt: DataType): Option[JdbcType] = + dialects.map(_.getJDBCType(dt)).flatten.headOption + +} + +/** + * :: DeveloperApi :: + * NOOP dialect object, always returning the neutral element. + */ +@DeveloperApi +case object NoopDialect extends JdbcDialect { + def canHandle(url : String): Boolean = true +} + +/** + * :: DeveloperApi :: + * Default postgres dialect, mapping bit/cidr/inet on read and string/binary/boolean on write. + */ +@DeveloperApi +case object PostgresDialect extends JdbcDialect { + def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { + Some(BinaryType) + } else if (sqlType == Types.OTHER && typeName.equals("cidr")) { + Some(StringType) + } else if (sqlType == Types.OTHER && typeName.equals("inet")) { + Some(StringType) + } else None + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR)) + case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY)) + case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) + case _ => None + } +} + +/** + * :: DeveloperApi :: + * Default mysql dialect to read bit/bitsets correctly. + */ +@DeveloperApi +case object MySQLDialect extends JdbcDialect { + def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { + // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as + // byte arrays instead of longs. + md.putLong("binarylong", 1) + Some(LongType) + } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) { + Some(BooleanType) + } else None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index a61790b8472c8a5da40890c194269e9e80a762b0..f21dd29aca37fc5acfc32f8726628e202b6b3f22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -129,25 +129,26 @@ package object jdbc { */ def schemaString(df: DataFrame, url: String): String = { val sb = new StringBuilder() - val quirks = DriverQuirks.get(url) + val dialect = JdbcDialects.get(url) df.schema.fields foreach { field => { val name = field.name - var typ: String = quirks.getJDBCType(field.dataType)._1 - if (typ == null) typ = field.dataType match { - case IntegerType => "INTEGER" - case LongType => "BIGINT" - case DoubleType => "DOUBLE PRECISION" - case FloatType => "REAL" - case ShortType => "INTEGER" - case ByteType => "BYTE" - case BooleanType => "BIT(1)" - case StringType => "TEXT" - case BinaryType => "BLOB" - case TimestampType => "TIMESTAMP" - case DateType => "DATE" - case DecimalType.Unlimited => "DECIMAL(40,20)" - case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") - } + val typ: String = + dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse( + field.dataType match { + case IntegerType => "INTEGER" + case LongType => "BIGINT" + case DoubleType => "DOUBLE PRECISION" + case FloatType => "REAL" + case ShortType => "INTEGER" + case ByteType => "BYTE" + case BooleanType => "BIT(1)" + case StringType => "TEXT" + case BinaryType => "BLOB" + case TimestampType => "TIMESTAMP" + case DateType => "DATE" + case DecimalType.Unlimited => "DECIMAL(40,20)" + case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") + }) val nullable = if (field.nullable) "" else "NOT NULL" sb.append(s", $name $typ $nullable") }} @@ -162,10 +163,9 @@ package object jdbc { url: String, table: String, properties: Properties = new Properties()) { - val quirks = DriverQuirks.get(url) + val dialect = JdbcDialects.get(url) val nullTypes: Array[Int] = df.schema.fields.map { field => - val nullType: Option[Int] = quirks.getJDBCType(field.dataType)._2 - if (nullType.isEmpty) { + dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse( field.dataType match { case IntegerType => java.sql.Types.INTEGER case LongType => java.sql.Types.BIGINT @@ -181,8 +181,7 @@ package object jdbc { case DecimalType.Unlimited => java.sql.Types.DECIMAL case _ => throw new IllegalArgumentException( s"Can't translate null value for field $field") - } - } else nullType.get + }) } val rddSchema = df.schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5a7b6f0aac6f71518efd89eb24ac65b6700eb49c..a8dddfb9b68584f16c07851e5372014855499ddf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -35,6 +35,13 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) + val testH2Dialect = new JdbcDialect { + def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = + Some(StringType) + } + before { Class.forName("org.h2.Driver") // Extra properties that will be specified for our database. We need these to test @@ -353,4 +360,46 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { """.stripMargin.replaceAll("\n", " ")) } } + + test("Remap types via JdbcDialects") { + JdbcDialects.registerDialect(testH2Dialect) + val df = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + assert(df.schema.filter( + _.dataType != org.apache.spark.sql.types.StringType + ).isEmpty) + val rows = df.collect() + assert(rows(0).get(0).isInstanceOf[String]) + assert(rows(0).get(1).isInstanceOf[String]) + JdbcDialects.unregisterDialect(testH2Dialect) + } + + test("Default jdbc dialect registration") { + assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") == MySQLDialect) + assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == PostgresDialect) + assert(JdbcDialects.get("test.invalid") == NoopDialect) + } + + test("Dialect unregister") { + JdbcDialects.registerDialect(testH2Dialect) + JdbcDialects.unregisterDialect(testH2Dialect) + assert(JdbcDialects.get(urlWithUserAndPass) == NoopDialect) + } + + test("Aggregated dialects") { + val agg = new AggregatedDialect(List(new JdbcDialect { + def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = + if (sqlType % 2 == 0) { + Some(LongType) + } else { + None + } + }, testH2Dialect)) + assert(agg.canHandle("jdbc:h2:xxx")) + assert(!agg.canHandle("jdbc:h2")) + assert(agg.getCatalystType(0,"",1,null) == Some(LongType)) + assert(agg.getCatalystType(1,"",1,null) == Some(StringType)) + } + }