Commit 2ba3ff04 authored by Michael Armbrust

[SPARK-10216][SQL] Revert "[] Avoid creating empty files during overwrit…

This reverts commit 8d05a7a9 from #12855, which seems to have caused regressions when working with empty DataFrames.

Author: Michael Armbrust <michael@databricks.com>

Closes #13181 from marmbrus/revert12855.
parent dfa61f7b
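For context, the regression mentioned above concerns writes of empty DataFrames. Below is a minimal, hypothetical sketch of that scenario, assuming a local SparkSession and Parquet output; the object name, path, and schema are illustrative and not taken from this commit or #12855.

```scala
import org.apache.spark.sql.{SaveMode, SparkSession}

// Hypothetical reproduction sketch (not part of this commit): overwrite a target
// directory with a DataFrame that has zero rows. Commit 8d05a7a9 skipped the
// task-side write when the row iterator was empty; this revert restores the
// unconditional write path shown in the hunks below.
object EmptyOverwriteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("empty-overwrite-sketch")
      .getOrCreate()
    import spark.implicits._

    val empty = Seq.empty[(Int, String)].toDF("key", "value")  // zero-row DataFrame
    val target = "/tmp/empty-overwrite-sketch"                 // illustrative path

    empty.write.mode(SaveMode.Overwrite).parquet(target)

    // With output files written even for the empty result, the location can be
    // read back with its schema intact and yields zero rows.
    val readBack = spark.read.parquet(target)
    assert(readBack.schema.fieldNames.sameElements(Array("key", "value")))
    assert(readBack.count() == 0)

    spark.stop()
  }
}
```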
@@ -239,50 +239,48 @@ private[sql] class DefaultWriterContainer(
    extends BaseWriterContainer(relation, job, isAppend) {

  def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
-    if (iterator.hasNext) {
-      executorSideSetup(taskContext)
-      val configuration = taskAttemptContext.getConfiguration
-      configuration.set("spark.sql.sources.output.path", outputPath)
-      var writer = newOutputWriter(getWorkPath)
-      writer.initConverter(dataSchema)
+    executorSideSetup(taskContext)
+    val configuration = taskAttemptContext.getConfiguration
+    configuration.set("spark.sql.sources.output.path", outputPath)
+    var writer = newOutputWriter(getWorkPath)
+    writer.initConverter(dataSchema)

-      // If anything below fails, we should abort the task.
-      try {
-        Utils.tryWithSafeFinallyAndFailureCallbacks {
-          while (iterator.hasNext) {
-            val internalRow = iterator.next()
-            writer.writeInternal(internalRow)
-          }
-          commitTask()
-        }(catchBlock = abortTask())
-      } catch {
-        case t: Throwable =>
-          throw new SparkException("Task failed while writing rows", t)
-      }
+    // If anything below fails, we should abort the task.
+    try {
+      Utils.tryWithSafeFinallyAndFailureCallbacks {
+        while (iterator.hasNext) {
+          val internalRow = iterator.next()
+          writer.writeInternal(internalRow)
+        }
+        commitTask()
+      }(catchBlock = abortTask())
+    } catch {
+      case t: Throwable =>
+        throw new SparkException("Task failed while writing rows", t)
+    }

-      def commitTask(): Unit = {
-        try {
-          if (writer != null) {
-            writer.close()
-            writer = null
-          }
-          super.commitTask()
-        } catch {
-          case cause: Throwable =>
-            // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and
-            // will cause `abortTask()` to be invoked.
-            throw new RuntimeException("Failed to commit task", cause)
-        }
-      }
+    def commitTask(): Unit = {
+      try {
+        if (writer != null) {
+          writer.close()
+          writer = null
+        }
+        super.commitTask()
+      } catch {
+        case cause: Throwable =>
+          // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and
+          // will cause `abortTask()` to be invoked.
+          throw new RuntimeException("Failed to commit task", cause)
+      }
+    }

-      def abortTask(): Unit = {
-        try {
-          if (writer != null) {
-            writer.close()
-          }
-        } finally {
-          super.abortTask()
-        }
-      }
-    }
+    def abortTask(): Unit = {
+      try {
+        if (writer != null) {
+          writer.close()
+        }
+      } finally {
+        super.abortTask()
+      }
+    }
  }
@@ -365,87 +363,84 @@ private[sql] class DynamicPartitionWriterContainer(
  }

  def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
-    if (iterator.hasNext) {
-      executorSideSetup(taskContext)
+    executorSideSetup(taskContext)

-      // We should first sort by partition columns, then bucket id, and finally sorting columns.
-      val sortingExpressions: Seq[Expression] =
-        partitionColumns ++ bucketIdExpression ++ sortColumns
-      val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)
-      val sortingKeySchema = StructType(sortingExpressions.map {
-        case a: Attribute => StructField(a.name, a.dataType, a.nullable)
-        // The sorting expressions are all `Attribute` except bucket id.
-        case _ => StructField("bucketId", IntegerType, nullable = false)
-      })
+    // We should first sort by partition columns, then bucket id, and finally sorting columns.
+    val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns
+    val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)
+    val sortingKeySchema = StructType(sortingExpressions.map {
+      case a: Attribute => StructField(a.name, a.dataType, a.nullable)
+      // The sorting expressions are all `Attribute` except bucket id.
+      case _ => StructField("bucketId", IntegerType, nullable = false)
+    })

-      // Returns the data columns to be written given an input row
-      val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
-      // Returns the partition path given a partition key.
-      val getPartitionString =
-        UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)
-      // Sorts the data before write, so that we only need one writer at the same time.
-      // TODO: inject a local sort operator in planning.
-      val sorter = new UnsafeKVExternalSorter(
-        sortingKeySchema,
-        StructType.fromAttributes(dataColumns),
-        SparkEnv.get.blockManager,
-        SparkEnv.get.serializerManager,
-        TaskContext.get().taskMemoryManager().pageSizeBytes)
-      while (iterator.hasNext) {
-        val currentRow = iterator.next()
-        sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
-      }
-      logInfo(s"Sorting complete. Writing out partition files one at a time.")
+    // Returns the data columns to be written given an input row
+    val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
+    // Returns the partition path given a partition key.
+    val getPartitionString =
+      UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)
+    // Sorts the data before write, so that we only need one writer at the same time.
+    // TODO: inject a local sort operator in planning.
+    val sorter = new UnsafeKVExternalSorter(
+      sortingKeySchema,
+      StructType.fromAttributes(dataColumns),
+      SparkEnv.get.blockManager,
+      SparkEnv.get.serializerManager,
+      TaskContext.get().taskMemoryManager().pageSizeBytes)
+    while (iterator.hasNext) {
+      val currentRow = iterator.next()
+      sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
+    }
+    logInfo(s"Sorting complete. Writing out partition files one at a time.")

-      val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
-        identity
-      } else {
-        UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
-          case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
-        })
-      }
+    val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
+      identity
+    } else {
+      UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
+        case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
+      })
+    }

-      val sortedIterator = sorter.sortedIterator()
+    val sortedIterator = sorter.sortedIterator()

-      // If anything below fails, we should abort the task.
-      var currentWriter: OutputWriter = null
-      try {
-        Utils.tryWithSafeFinallyAndFailureCallbacks {
-          var currentKey: UnsafeRow = null
-          while (sortedIterator.next()) {
-            val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
-            if (currentKey != nextKey) {
-              if (currentWriter != null) {
-                currentWriter.close()
-                currentWriter = null
-              }
-              currentKey = nextKey.copy()
-              logDebug(s"Writing partition: $currentKey")
-              currentWriter = newOutputWriter(currentKey, getPartitionString)
-            }
-            currentWriter.writeInternal(sortedIterator.getValue)
-          }
-          if (currentWriter != null) {
-            currentWriter.close()
-            currentWriter = null
-          }
-          commitTask()
-        }(catchBlock = {
-          if (currentWriter != null) {
-            currentWriter.close()
-          }
-          abortTask()
-        })
-      } catch {
-        case t: Throwable =>
-          throw new SparkException("Task failed while writing rows", t)
-      }
-    }
+    // If anything below fails, we should abort the task.
+    var currentWriter: OutputWriter = null
+    try {
+      Utils.tryWithSafeFinallyAndFailureCallbacks {
+        var currentKey: UnsafeRow = null
+        while (sortedIterator.next()) {
+          val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
+          if (currentKey != nextKey) {
+            if (currentWriter != null) {
+              currentWriter.close()
+              currentWriter = null
+            }
+            currentKey = nextKey.copy()
+            logDebug(s"Writing partition: $currentKey")
+            currentWriter = newOutputWriter(currentKey, getPartitionString)
+          }
+          currentWriter.writeInternal(sortedIterator.getValue)
+        }
+        if (currentWriter != null) {
+          currentWriter.close()
+          currentWriter = null
+        }
+        commitTask()
+      }(catchBlock = {
+        if (currentWriter != null) {
+          currentWriter.close()
+        }
+        abortTask()
+      })
+    } catch {
+      case t: Throwable =>
+        throw new SparkException("Task failed while writing rows", t)
+    }
  }
}
@@ -178,21 +178,19 @@ private[hive] class SparkHiveWriterContainer(

  // this function is executed on executor side
  def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = {
-    if (iterator.hasNext) {
-      val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite()
-      executorSideSetup(context.stageId, context.partitionId, context.attemptNumber)
+    val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite()
+    executorSideSetup(context.stageId, context.partitionId, context.attemptNumber)

-      iterator.foreach { row =>
-        var i = 0
-        while (i < fieldOIs.length) {
-          outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i)))
-          i += 1
-        }
-        writer.write(serializer.serialize(outputData, standardOI))
-      }
+    iterator.foreach { row =>
+      var i = 0
+      while (i < fieldOIs.length) {
+        outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i)))
+        i += 1
+      }
+      writer.write(serializer.serialize(outputData, standardOI))
+    }

-      close()
-    }
+    close()
  }
@@ -19,13 +19,13 @@ package org.apache.spark.sql.hive

import java.io.File

import org.apache.hadoop.hive.conf.HiveConf
import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkException
-import org.apache.spark.sql._
+import org.apache.spark.sql.{QueryTest, _}
import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -118,10 +118,10 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
    sql(
      s"""
         |CREATE TABLE table_with_partition(c1 string)
         |PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string)
         |location '${tmpDir.toURI.toString}'
       """.stripMargin)
    sql(
      """
        |INSERT OVERWRITE TABLE table_with_partition
@@ -216,35 +216,6 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
    sql("DROP TABLE hiveTableWithStructValue")
  }

-  test("SPARK-10216: Avoid empty files during overwrite into Hive table with group by query") {
-    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
-      val testDataset = hiveContext.sparkContext.parallelize(
-        (1 to 2).map(i => TestData(i, i.toString))).toDF()
-      testDataset.createOrReplaceTempView("testDataset")
-
-      val tmpDir = Utils.createTempDir()
-      sql(
-        s"""
-           |CREATE TABLE table1(key int,value string)
-           |location '${tmpDir.toURI.toString}'
-         """.stripMargin)
-      sql(
-        """
-          |INSERT OVERWRITE TABLE table1
-          |SELECT count(key), value FROM testDataset GROUP BY value
-        """.stripMargin)
-
-      val overwrittenFiles = tmpDir.listFiles()
-        .filter(f => f.isFile && !f.getName.endsWith(".crc"))
-        .sortBy(_.getName)
-      val overwrittenFilesWithoutEmpty = overwrittenFiles.filter(_.length > 0)
-
-      assert(overwrittenFiles === overwrittenFilesWithoutEmpty)
-
-      sql("DROP TABLE table1")
-    }
-  }
-
  test("Reject partitioning that does not match table") {
    withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
      sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)")
@@ -29,7 +29,7 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql._
import org.apache.spark.sql.execution.DataSourceScanExec
-import org.apache.spark.sql.execution.datasources.{FileScanRDD, LocalityTestFileSystem}
+import org.apache.spark.sql.execution.datasources.{FileScanRDD, HadoopFsRelation, LocalityTestFileSystem, LogicalRelation}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
@@ -879,26 +879,6 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
        }
      }
    }

-  test("SPARK-10216: Avoid empty files during overwriting with group by query") {
-    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
-      withTempPath { path =>
-        val df = spark.range(0, 5)
-        val groupedDF = df.groupBy("id").count()
-        groupedDF.write
-          .format(dataSourceName)
-          .mode(SaveMode.Overwrite)
-          .save(path.getCanonicalPath)
-
-        val overwrittenFiles = path.listFiles()
-          .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_"))
-          .sortBy(_.getName)
-        val overwrittenFilesWithoutEmpty = overwrittenFiles.filter(_.length > 0)
-
-        assert(overwrittenFiles === overwrittenFilesWithoutEmpty)
-      }
-    }
-  }
}

// This class is used to test SPARK-8578. We should not use any custom output committer when