Commit 79636054 authored by Eric Liang, committed by Reynold Xin

[SPARK-20148][SQL] Extend the file commit API to allow subscribing to task commit messages

## What changes were proposed in this pull request?

The internal FileCommitProtocol interface currently returns all task commit messages to the implementation in bulk once a job finishes. However, it is sometimes useful to access those messages as individual tasks commit, so that the driver can report incremental progress before the job finishes.

This adds an `onTaskCommit` listener to the internal API.
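
For illustration, a driver-side subscriber might look like the following minimal sketch. The class name, counter, and progress message are hypothetical; the `onTaskCommit` hook, the `TaskCommitMessage` type, and the `HadoopMapReduceCommitProtocol(jobId, path)` constructor are the ones that appear in the diff below.

```scala
import java.util.concurrent.atomic.AtomicInteger

import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol

object ProgressReportingCommitProtocol {
  // Driver-side counter: onTaskCommit is only invoked on the driver.
  val committedTasks = new AtomicInteger(0)
}

// Hypothetical committer that reports incremental progress as each task commits,
// before commitJob() eventually receives the same messages in bulk.
class ProgressReportingCommitProtocol(jobId: String, path: String)
    extends HadoopMapReduceCommitProtocol(jobId, path) {

  override def onTaskCommit(taskCommit: TaskCommitMessage): Unit = {
    val done = ProgressReportingCommitProtocol.committedTasks.incrementAndGet()
    println(s"Job $jobId: $done task(s) committed so far")
  }
}
```

The new unit test in this patch wires a similar message-capturing protocol into the SQL write path via the `spark.sql.sources.commitProtocolClass` configuration.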

## How was this patch tested?

Unit tests.

cc rxin

Author: Eric Liang <ekl@databricks.com>

Closes #17475 from ericl/file-commit-api-ext.
parent 60977889
@@ -121,6 +121,13 @@ abstract class FileCommitProtocol {
   def deleteWithJob(fs: FileSystem, path: Path, recursive: Boolean): Boolean = {
     fs.delete(path, recursive)
   }
+
+  /**
+   * Called on the driver after a task commits. This can be used to access task commit messages
+   * before the job has finished. These same task commit messages will be passed to commitJob()
+   * if the entire job succeeds.
+   */
+  def onTaskCommit(taskCommit: TaskCommitMessage): Unit = {}
 }
......
@@ -80,6 +80,9 @@ object FileFormatWriter extends Logging {
         """.stripMargin)
   }
 
+  /** The result of a successful write task. */
+  private case class WriteTaskResult(commitMsg: TaskCommitMessage, updatedPartitions: Set[String])
+
   /**
    * Basic work flow of this command is:
    * 1. Driver side setup, including output committer initialization and data source specific
@@ -172,8 +175,9 @@ object FileFormatWriter extends Logging {
           global = false,
           child = queryExecution.executedPlan).execute()
       }
-      val ret = sparkSession.sparkContext.runJob(rdd,
+      val ret = new Array[WriteTaskResult](rdd.partitions.length)
+      sparkSession.sparkContext.runJob(
+        rdd,
         (taskContext: TaskContext, iter: Iterator[InternalRow]) => {
           executeTask(
             description = description,
@@ -182,10 +186,16 @@ object FileFormatWriter extends Logging {
            sparkAttemptNumber = taskContext.attemptNumber(),
            committer,
            iterator = iter)
+        },
+        0 until rdd.partitions.length,
+        (index, res: WriteTaskResult) => {
+          committer.onTaskCommit(res.commitMsg)
+          ret(index) = res
        })
 
-      val commitMsgs = ret.map(_._1)
-      val updatedPartitions = ret.flatMap(_._2).distinct.map(PartitioningUtils.parsePathFragment)
+      val commitMsgs = ret.map(_.commitMsg)
+      val updatedPartitions = ret.flatMap(_.updatedPartitions)
+        .distinct.map(PartitioningUtils.parsePathFragment)
 
       committer.commitJob(job, commitMsgs)
       logInfo(s"Job ${job.getJobID} committed.")
@@ -205,7 +215,7 @@ object FileFormatWriter extends Logging {
       sparkPartitionId: Int,
       sparkAttemptNumber: Int,
       committer: FileCommitProtocol,
-      iterator: Iterator[InternalRow]): (TaskCommitMessage, Set[String]) = {
+      iterator: Iterator[InternalRow]): WriteTaskResult = {
    val jobId = SparkHadoopWriterUtils.createJobID(new Date, sparkStageId)
    val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId)
@@ -238,7 +248,7 @@ object FileFormatWriter extends Logging {
      // Execute the task to write rows out and commit the task.
      val outputPartitions = writeTask.execute(iterator)
      writeTask.releaseResources()
-      (committer.commitTask(taskAttemptContext), outputPartitions)
+      WriteTaskResult(committer.commitTask(taskAttemptContext), outputPartitions)
    })(catchBlock = {
      // If there is an error, release resource and then abort the task
      try {
......
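
The FileFormatWriter change above relies on the `SparkContext.runJob` overload that takes a result handler: the handler runs on the driver as each partition's result arrives, rather than after the whole job, which is what lets `committer.onTaskCommit` fire incrementally. In the patch, `0 until rdd.partitions.length` and the `(index, res: WriteTaskResult) => ...` closure are the `partitions` and `resultHandler` arguments. A minimal, self-contained sketch of the same pattern (the object name, data, and println are illustrative, not part of the patch):

```scala
import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.sql.SparkSession

// Standalone example of runJob with a per-result handler: the handler is invoked
// on the driver as each partition finishes, giving incremental results.
object IncrementalResultsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[4]").appName("runJob-handler").getOrCreate()
    val sc: SparkContext = spark.sparkContext

    val rdd = sc.parallelize(1 to 100, numSlices = 10)
    val results = new Array[Int](rdd.partitions.length)

    sc.runJob(
      rdd,
      (ctx: TaskContext, iter: Iterator[Int]) => iter.sum,            // per-partition work
      0 until rdd.partitions.length,                                   // partitions to run
      (index: Int, partitionSum: Int) => {                             // driver-side callback per result
        results(index) = partitionSum
        println(s"partition $index finished with sum $partitionSum")   // incremental progress
      })

    println(s"total = ${results.sum}")
    spark.stop()
  }
}
```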
@@ -18,9 +18,12 @@
 package org.apache.spark.sql.test
 
 import java.io.File
+import java.util.concurrent.ConcurrentLinkedQueue
 
 import org.scalatest.BeforeAndAfter
 
+import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
+import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.sources._
@@ -41,7 +44,6 @@ object LastOptions {
   }
 }
 
-
 /** Dummy provider. */
 class DefaultSource
   extends RelationProvider
@@ -107,6 +109,20 @@ class DefaultSourceWithoutUserSpecifiedSchema
   }
 }
 
+object MessageCapturingCommitProtocol {
+  val commitMessages = new ConcurrentLinkedQueue[TaskCommitMessage]()
+}
+
+class MessageCapturingCommitProtocol(jobId: String, path: String)
+    extends HadoopMapReduceCommitProtocol(jobId, path) {
+
+  // captures commit messages for testing
+  override def onTaskCommit(msg: TaskCommitMessage): Unit = {
+    MessageCapturingCommitProtocol.commitMessages.offer(msg)
+  }
+}
+
 class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
   import testImplicits._
@@ -291,6 +307,19 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
     Option(dir).map(spark.read.format("org.apache.spark.sql.test").load)
   }
 
+  test("write path implements onTaskCommit API correctly") {
+    withSQLConf(
+        "spark.sql.sources.commitProtocolClass" ->
+          classOf[MessageCapturingCommitProtocol].getCanonicalName) {
+      withTempDir { dir =>
+        val path = dir.getCanonicalPath
+        MessageCapturingCommitProtocol.commitMessages.clear()
+        spark.range(10).repartition(10).write.mode("overwrite").parquet(path)
+        assert(MessageCapturingCommitProtocol.commitMessages.size() == 10)
+      }
+    }
+  }
+
   test("read a data source that does not extend SchemaRelationProvider") {
     val dfReader = spark.read
       .option("from", "1")
......
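
As a usage sketch, the hypothetical `ProgressReportingCommitProtocol` from the description above could be wired into an ordinary write much like the new test does, by overriding the internal `spark.sql.sources.commitProtocolClass` configuration. The session setup and output path here are illustrative.

```scala
import org.apache.spark.sql.SparkSession

// Assumes the ProgressReportingCommitProtocol class sketched earlier is on the classpath.
object OnTaskCommitUsage {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[4]")
      .appName("onTaskCommit-usage")
      .getOrCreate()

    // Route SQL file writes through the custom commit protocol
    // (same internal conf key the new unit test overrides).
    spark.conf.set(
      "spark.sql.sources.commitProtocolClass",
      classOf[ProgressReportingCommitProtocol].getCanonicalName)

    // Each of the 10 write tasks triggers one onTaskCommit callback on the driver
    // before commitJob() finalizes the output.
    spark.range(10).repartition(10).write.mode("overwrite").parquet("/tmp/onTaskCommit-demo")

    spark.stop()
  }
}
```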