diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index ba3e55fc061a74442960ff8a7aa13a909f7ce92a..656e7ecdab0bbf60cc404b995dc927e27289b7b8 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1086,6 +1086,13 @@ the following case-sensitive options:
     </td>
   </tr>
 
+  <tr>
+    <td><code>maxConnections</code></td>
+    <td>
+      The maximum number of concurrent JDBC connections that can be used, if set. Only applies when writing. It works by limiting the operation's parallelism, which depends on the input's partition count. If its partition count exceeds this limit, the operation will coalesce the input to fewer partitions before writing.
+    </td>
+  </tr>
+
   <tr>
     <td><code>isolationLevel</code></td>
     <td>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 7f419b5788c4fe2ebc70caf30616373128af467a..d416eec6ddaec1eec8b7d0dc7b305a00a78d9e90 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -122,6 +122,11 @@ class JDBCOptions(
       case "REPEATABLE_READ" => Connection.TRANSACTION_REPEATABLE_READ
       case "SERIALIZABLE" => Connection.TRANSACTION_SERIALIZABLE
     }
+  // the maximum number of connections
+  val maxConnections = parameters.get(JDBC_MAX_CONNECTIONS).map(_.toInt)
+  require(maxConnections.isEmpty || maxConnections.get > 0,
+    s"Invalid value `${maxConnections.get}` for parameter `$JDBC_MAX_CONNECTIONS`. " +
+      "The minimum value is 1.")
 }
 
 object JDBCOptions {
@@ -144,4 +149,5 @@ object JDBCOptions {
   val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions")
   val JDBC_BATCH_INSERT_SIZE = newOption("batchsize")
   val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel")
+  val JDBC_MAX_CONNECTIONS = newOption("maxConnections")
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 41edb6511c2ce3dd8bf090c5ea044563df3f5ae5..cdc3c99daa1ab3774e5b8acb765edb5259c01062 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -667,7 +667,14 @@ object JdbcUtils extends Logging {
     val getConnection: () => Connection = createConnectionFactory(options)
     val batchSize = options.batchSize
     val isolationLevel = options.isolationLevel
-    df.foreachPartition(iterator => savePartition(
+    val maxConnections = options.maxConnections
+    val repartitionedDF =
+      if (maxConnections.isDefined && maxConnections.get < df.rdd.getNumPartitions) {
+        df.coalesce(maxConnections.get)
+      } else {
+        df
+      }
+    repartitionedDF.foreachPartition(iterator => savePartition(
       getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
     )
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index e3d3c6c3a887c3d765c8cd6899f9ab06f1073da7..5795b4d860cb179a39c40e234d1f05d440f70763 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -312,4 +312,16 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
       .options(properties.asScala)
       .save()
   }
+
+  test("SPARK-18413: Add `maxConnections` JDBCOption") {
+    val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+    val e = intercept[IllegalArgumentException] {
+      df.write.format("jdbc")
+        .option("dbtable", "TEST.SAVETEST")
+        .option("url", url1)
+        .option(s"${JDBCOptions.JDBC_MAX_CONNECTIONS}", "0")
+        .save()
+    }.getMessage
+    assert(e.contains("Invalid value `0` for parameter `maxConnections`. The minimum value is 1"))
+  }
 }
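
For context, a minimal usage sketch (not part of the patch) of how a caller could exercise the new option once it is merged. The H2 URL, table name, and partition counts are illustrative placeholders, and the coalescing behavior described in the comments follows the JdbcUtils.saveTable change above.

import org.apache.spark.sql.SparkSession

object MaxConnectionsDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("maxConnections demo")
      .master("local[4]")
      .getOrCreate()
    import spark.implicits._

    // Build an input with more partitions (16) than the connection budget (4).
    val df = spark.sparkContext
      .parallelize(1 to 1000, numSlices = 16)
      .map(i => (i, s"row_$i"))
      .toDF("id", "value")

    // With maxConnections=4, the write path coalesces the 16 input partitions
    // down to 4, so at most 4 JDBC connections are open concurrently.
    df.write
      .format("jdbc")
      .option("url", "jdbc:h2:mem:testdb")     // placeholder JDBC URL
      .option("dbtable", "TEST.MAXCONN_DEMO")  // placeholder table name
      .option("maxConnections", "4")
      .save()

    spark.stop()
  }
}

Setting maxConnections to "0" here would instead fail fast with the IllegalArgumentException asserted in the JDBCWriteSuite test above.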