diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index 7432a1538ce97eed2bed4112ca4686c4e5fbfe23..1419d69f983ab47f3d7cf98efd4c412733d1c0c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -43,6 +43,12 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect } override def isCascadingTruncateTable(): Option[Boolean] = { - dialects.flatMap(_.isCascadingTruncateTable()).reduceOption(_ || _) + // If any dialect claims cascading truncate, this dialect is also cascading truncate. + // Otherwise, if any dialect has unknown cascading truncate, this dialect is also unknown. + dialects.flatMap(_.isCascadingTruncateTable()).reduceOption(_ || _) match { + case Some(true) => Some(true) + case _ if dialects.exists(_.isCascadingTruncateTable().isEmpty) => None + case _ => Some(false) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index fd12bb9e530b833f9947cf08a029db72157384a5..34205e0b2bf08092f6c3f6ab06b2a2f4cf6299d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -749,6 +749,31 @@ class JDBCSuite extends SparkFunSuite assert(agg.isCascadingTruncateTable() === Some(true)) } + test("Aggregated dialects: isCascadingTruncateTable") { + def genDialect(cascadingTruncateTable: Option[Boolean]): JdbcDialect = new JdbcDialect { + override def canHandle(url: String): Boolean = true + override def getCatalystType( + sqlType: Int, + typeName: String, + size: Int, + md: MetadataBuilder): Option[DataType] = None + override def isCascadingTruncateTable(): Option[Boolean] = cascadingTruncateTable + } + + def testDialects(cascadings: List[Option[Boolean]], expected: Option[Boolean]): Unit = { + val dialects = cascadings.map(genDialect(_)) + val agg = new AggregatedDialect(dialects) + assert(agg.isCascadingTruncateTable() === expected) + } + + testDialects(List(Some(true), Some(false), None), Some(true)) + testDialects(List(Some(true), Some(true), None), Some(true)) + testDialects(List(Some(false), Some(false), None), None) + testDialects(List(Some(true), Some(true)), Some(true)) + testDialects(List(Some(false), Some(false)), Some(false)) + testDialects(List(None, None), None) + } + test("DB2Dialect type mapping") { val db2Dialect = JdbcDialects.get("jdbc:db2://127.0.0.1/db") assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB")