diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 7e12bbb2128bfb9f97f5e3d1c90161440430843e..3b064a5bc489f222a1be3fb35bf58cf07e27ea3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -239,50 +239,48 @@ private[sql] class DefaultWriterContainer( extends BaseWriterContainer(relation, job, isAppend) { def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { - if (iterator.hasNext) { - executorSideSetup(taskContext) - val configuration = taskAttemptContext.getConfiguration - configuration.set("spark.sql.sources.output.path", outputPath) - var writer = newOutputWriter(getWorkPath) - writer.initConverter(dataSchema) + executorSideSetup(taskContext) + val configuration = taskAttemptContext.getConfiguration + configuration.set("spark.sql.sources.output.path", outputPath) + var writer = newOutputWriter(getWorkPath) + writer.initConverter(dataSchema) - // If anything below fails, we should abort the task. - try { - Utils.tryWithSafeFinallyAndFailureCallbacks { - while (iterator.hasNext) { - val internalRow = iterator.next() - writer.writeInternal(internalRow) - } - commitTask() - }(catchBlock = abortTask()) - } catch { - case t: Throwable => - throw new SparkException("Task failed while writing rows", t) - } + // If anything below fails, we should abort the task. + try { + Utils.tryWithSafeFinallyAndFailureCallbacks { + while (iterator.hasNext) { + val internalRow = iterator.next() + writer.writeInternal(internalRow) + } + commitTask() + }(catchBlock = abortTask()) + } catch { + case t: Throwable => + throw new SparkException("Task failed while writing rows", t) + } - def commitTask(): Unit = { - try { - if (writer != null) { - writer.close() - writer = null - } - super.commitTask() - } catch { - case cause: Throwable => - // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and - // will cause `abortTask()` to be invoked. - throw new RuntimeException("Failed to commit task", cause) + def commitTask(): Unit = { + try { + if (writer != null) { + writer.close() + writer = null } + super.commitTask() + } catch { + case cause: Throwable => + // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and + // will cause `abortTask()` to be invoked. + throw new RuntimeException("Failed to commit task", cause) } + } - def abortTask(): Unit = { - try { - if (writer != null) { - writer.close() - } - } finally { - super.abortTask() + def abortTask(): Unit = { + try { + if (writer != null) { + writer.close() } + } finally { + super.abortTask() } } } @@ -365,87 +363,84 @@ private[sql] class DynamicPartitionWriterContainer( } def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { - if (iterator.hasNext) { - executorSideSetup(taskContext) - - // We should first sort by partition columns, then bucket id, and finally sorting columns. - val sortingExpressions: Seq[Expression] = - partitionColumns ++ bucketIdExpression ++ sortColumns - val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema) - - val sortingKeySchema = StructType(sortingExpressions.map { - case a: Attribute => StructField(a.name, a.dataType, a.nullable) - // The sorting expressions are all `Attribute` except bucket id. 
- case _ => StructField("bucketId", IntegerType, nullable = false) - }) + executorSideSetup(taskContext) + + // We should first sort by partition columns, then bucket id, and finally sorting columns. + val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns + val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema) + + val sortingKeySchema = StructType(sortingExpressions.map { + case a: Attribute => StructField(a.name, a.dataType, a.nullable) + // The sorting expressions are all `Attribute` except bucket id. + case _ => StructField("bucketId", IntegerType, nullable = false) + }) + + // Returns the data columns to be written given an input row + val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) + + // Returns the partition path given a partition key. + val getPartitionString = + UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns) + + // Sorts the data before write, so that we only need one writer at the same time. + // TODO: inject a local sort operator in planning. + val sorter = new UnsafeKVExternalSorter( + sortingKeySchema, + StructType.fromAttributes(dataColumns), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + TaskContext.get().taskMemoryManager().pageSizeBytes) + + while (iterator.hasNext) { + val currentRow = iterator.next() + sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) + } + logInfo(s"Sorting complete. Writing out partition files one at a time.") - // Returns the data columns to be written given an input row - val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) - - // Returns the partition path given a partition key. - val getPartitionString = - UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns) - - // Sorts the data before write, so that we only need one writer at the same time. - // TODO: inject a local sort operator in planning. - val sorter = new UnsafeKVExternalSorter( - sortingKeySchema, - StructType.fromAttributes(dataColumns), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - TaskContext.get().taskMemoryManager().pageSizeBytes) - - while (iterator.hasNext) { - val currentRow = iterator.next() - sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) - } - logInfo(s"Sorting complete. Writing out partition files one at a time.") - - val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) { - identity - } else { - UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map { - case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable) - }) - } + val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) { + identity + } else { + UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map { + case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable) + }) + } - val sortedIterator = sorter.sortedIterator() + val sortedIterator = sorter.sortedIterator() - // If anything below fails, we should abort the task. 
- var currentWriter: OutputWriter = null - try { - Utils.tryWithSafeFinallyAndFailureCallbacks { - var currentKey: UnsafeRow = null - while (sortedIterator.next()) { - val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow] - if (currentKey != nextKey) { - if (currentWriter != null) { - currentWriter.close() - currentWriter = null - } - currentKey = nextKey.copy() - logDebug(s"Writing partition: $currentKey") - - currentWriter = newOutputWriter(currentKey, getPartitionString) + // If anything below fails, we should abort the task. + var currentWriter: OutputWriter = null + try { + Utils.tryWithSafeFinallyAndFailureCallbacks { + var currentKey: UnsafeRow = null + while (sortedIterator.next()) { + val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow] + if (currentKey != nextKey) { + if (currentWriter != null) { + currentWriter.close() + currentWriter = null } - currentWriter.writeInternal(sortedIterator.getValue) - } - if (currentWriter != null) { - currentWriter.close() - currentWriter = null - } + currentKey = nextKey.copy() + logDebug(s"Writing partition: $currentKey") - commitTask() - }(catchBlock = { - if (currentWriter != null) { - currentWriter.close() + currentWriter = newOutputWriter(currentKey, getPartitionString) } - abortTask() - }) - } catch { - case t: Throwable => - throw new SparkException("Task failed while writing rows", t) - } + currentWriter.writeInternal(sortedIterator.getValue) + } + if (currentWriter != null) { + currentWriter.close() + currentWriter = null + } + + commitTask() + }(catchBlock = { + if (currentWriter != null) { + currentWriter.close() + } + abortTask() + }) + } catch { + case t: Throwable => + throw new SparkException("Task failed while writing rows", t) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 706fdbc2604fecf8e58b68ee61c58533b950cb3c..794fe264ead5dbdb9d0b51fc79d7287071fc06ea 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -178,21 +178,19 @@ private[hive] class SparkHiveWriterContainer( // this function is executed on executor side def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = { - if (iterator.hasNext) { - val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite() - executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) - - iterator.foreach { row => - var i = 0 - while (i < fieldOIs.length) { - outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i))) - i += 1 - } - writer.write(serializer.serialize(outputData, standardOI)) - } + val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite() + executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) - close() + iterator.foreach { row => + var i = 0 + while (i < fieldOIs.length) { + outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i))) + i += 1 + } + writer.write(serializer.serialize(outputData, standardOI)) } + + close() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index b25684562075cda8d1705789d45360a271cce128..4a55bcc3b19d0d136b65fa455efdbf85c7ef2116 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.hive import java.io.File +import org.apache.hadoop.hive.conf.HiveConf import org.scalatest.BeforeAndAfter import org.apache.spark.SparkException -import org.apache.spark.sql._ +import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -118,10 +118,10 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef sql( s""" - |CREATE TABLE table_with_partition(c1 string) - |PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) - |location '${tmpDir.toURI.toString}' - """.stripMargin) + |CREATE TABLE table_with_partition(c1 string) + |PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) + |location '${tmpDir.toURI.toString}' + """.stripMargin) sql( """ |INSERT OVERWRITE TABLE table_with_partition @@ -216,35 +216,6 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef sql("DROP TABLE hiveTableWithStructValue") } - test("SPARK-10216: Avoid empty files during overwrite into Hive table with group by query") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { - val testDataset = hiveContext.sparkContext.parallelize( - (1 to 2).map(i => TestData(i, i.toString))).toDF() - testDataset.createOrReplaceTempView("testDataset") - - val tmpDir = Utils.createTempDir() - sql( - s""" - |CREATE TABLE table1(key int,value string) - |location '${tmpDir.toURI.toString}' - """.stripMargin) - sql( - """ - |INSERT OVERWRITE TABLE table1 - |SELECT count(key), value FROM testDataset GROUP BY value - """.stripMargin) - - val overwrittenFiles = tmpDir.listFiles() - .filter(f => f.isFile && !f.getName.endsWith(".crc")) - .sortBy(_.getName) - val overwrittenFilesWithoutEmpty = overwrittenFiles.filter(_.length > 0) - - assert(overwrittenFiles === overwrittenFilesWithoutEmpty) - - sql("DROP TABLE table1") - } - } - test("Reject partitioning that does not match table") { withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index a3183f2977223132063f69bed9ef44c64c9c6f61..62998572eaf4b31971816a693665ae315bbea5d4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -29,7 +29,7 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ import org.apache.spark.sql.execution.DataSourceScanExec -import org.apache.spark.sql.execution.datasources.{FileScanRDD, LocalityTestFileSystem} +import org.apache.spark.sql.execution.datasources.{FileScanRDD, HadoopFsRelation, LocalityTestFileSystem, LogicalRelation} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils @@ -879,26 +879,6 @@ abstract class 
HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } } } - - test("SPARK-10216: Avoid empty files during overwriting with group by query") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") { - withTempPath { path => - val df = spark.range(0, 5) - val groupedDF = df.groupBy("id").count() - groupedDF.write - .format(dataSourceName) - .mode(SaveMode.Overwrite) - .save(path.getCanonicalPath) - - val overwrittenFiles = path.listFiles() - .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_")) - .sortBy(_.getName) - val overwrittenFilesWithoutEmpty = overwrittenFiles.filter(_.length > 0) - - assert(overwrittenFiles === overwrittenFilesWithoutEmpty) - } - } - } } // This class is used to test SPARK-8578. We should not use any custom output committer when
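Note on the net effect of this patch: on both the SQL write path (`DefaultWriterContainer`, `DynamicPartitionWriterContainer`) and the Hive write path (`SparkHiveWriterContainer`), the `if (iterator.hasNext)` guard introduced by SPARK-10216 is removed, so executor-side setup, output-writer creation, and task commit now run even when a task's partition is empty, and empty partitions may again leave zero-row output files behind. The two regression tests that asserted no empty files remain after an overwrite with a group-by query are deleted to match.

The snippet below is only a minimal sketch of the restored control flow, not Spark's actual API: `SimpleWriter`, `newWriter`, and `writeTaskOutput` are invented names standing in for `OutputWriter`, `newOutputWriter`, and `writeRows`, and a plain try/catch stands in for `Utils.tryWithSafeFinallyAndFailureCallbacks` with `commitTask()`/`abortTask()` as shown in the diff.

// Hypothetical, simplified illustration of the control flow restored by this patch.
trait SimpleWriter {
  def write(row: String): Unit
  def close(): Unit
}

def writeTaskOutput(rows: Iterator[String], newWriter: () => SimpleWriter): Unit = {
  // No `if (rows.hasNext)` guard: even for an empty partition a writer is opened,
  // zero rows are written, and the task output is committed, so a zero-row file
  // can appear in the output directory (the behavior the removed tests rejected).
  var writer: SimpleWriter = newWriter()
  try {
    while (rows.hasNext) {
      writer.write(rows.next())
    }
    // commit path: close the writer, then report success to the output committer
    writer.close()
    writer = null
  } catch {
    case t: Throwable =>
      // abort path: close best-effort, then surface the original failure
      try { if (writer != null) writer.close() } catch { case _: Throwable => () }
      throw new RuntimeException("Task failed while writing rows", t)
  }
}

For example, `writeTaskOutput(Iterator.empty, () => someWriterFactory())` (with any writer factory of your choosing) still opens, closes, and commits a writer, which is how an empty part file can end up in the output directory after this change.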