diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index cbc3569795d9704516df17e587f7b80ee5e74ce3..394d1a04e09c3136607e73e6b512ec8c2ead71fe 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1370,9 +1370,8 @@ test_that("column functions", {
   # passing option
   df <- as.DataFrame(list(list("col" = "{\"date\":\"21/10/2014\"}")))
   schema2 <- structType(structField("date", "date"))
-  expect_error(tryCatch(collect(select(df, from_json(df$col, schema2))),
-                        error = function(e) { stop(e) }),
-               paste0(".*(java.lang.NumberFormatException: For input string:).*"))
+  s <- collect(select(df, from_json(df$col, schema2)))
+  expect_equal(s[[1]][[1]], NA)
   s <- collect(select(df, from_json(df$col, schema2, dateFormat = "dd/MM/yyyy")))
   expect_is(s[[1]][[1]]$date, "Date")
   expect_equal(as.character(s[[1]][[1]]$date), "2014-10-21")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index e4e08a8665a5a558588f75db16c9e2647c224fda..08af5522d822d0cbbcfcb58c1a20d10d0f84f43c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.json._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, ParseModes}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, GenericArrayData, ParseModes}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
@@ -583,7 +583,7 @@ case class JsonToStructs(
         CreateJacksonParser.utf8String,
         identity[UTF8String]))
     } catch {
-      case _: SparkSQLJsonProcessingException => null
+      case _: BadRecordException => null
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index 5f222ec602c99af531d9e224bbe77011a99e24d4..355c26afa6f0d0c3209e388316f8a29678cca7de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -65,7 +65,7 @@ private[sql] class JSONOptions(
   val allowBackslashEscapingAnyCharacter =
     parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false)
   val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName)
-  private val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
+  val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
   val columnNameOfCorruptRecord =
     parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord)
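Not part of the patch: a rough Scala counterpart of the updated R expectation above. With the new behavior, `from_json` yields a null result for a record whose value cannot be parsed instead of failing the query with a `NumberFormatException`, while an explicit `dateFormat` still parses it. The session setup, object name, and column name below are illustrative.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.from_json
import org.apache.spark.sql.types.{DateType, StructField, StructType}

object FromJsonSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("from_json sketch").getOrCreate()
    import spark.implicits._

    val df = Seq("""{"date":"21/10/2014"}""").toDF("col")
    val schema = StructType(StructField("date", DateType) :: Nil)

    // The default date format cannot parse "21/10/2014", so the struct comes back as null
    // instead of the query failing.
    df.select(from_json($"col", schema)).show()

    // With a matching dateFormat the same record parses into a proper Date value.
    df.select(from_json($"col", schema, Map("dateFormat" -> "dd/MM/yyyy"))).show()

    spark.stop()
  }
}
```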
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
index 9b80c0fc87c9384506f868fee3c3add22e80c886..fdb7d88d5bd7f52d31b9f0ffe042f2ec62928cc9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
@@ -32,17 +32,14 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
-private[sql] class SparkSQLJsonProcessingException(msg: String) extends RuntimeException(msg)
-
 /**
  * Constructs a parser for a given schema that translates a json string to an [[InternalRow]].
  */
 class JacksonParser(
     schema: StructType,
-    options: JSONOptions) extends Logging {
+    val options: JSONOptions) extends Logging {
 
   import JacksonUtils._
-  import ParseModes._
   import com.fasterxml.jackson.core.JsonToken._
 
   // A `ValueConverter` is responsible for converting a value from `JsonParser`
@@ -55,108 +52,6 @@ class JacksonParser(
   private val factory = new JsonFactory()
   options.setJacksonOptions(factory)
 
-  private val emptyRow: Seq[InternalRow] = Seq(new GenericInternalRow(schema.length))
-
-  private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord)
-  corruptFieldIndex.foreach { corrFieldIndex =>
-    require(schema(corrFieldIndex).dataType == StringType)
-    require(schema(corrFieldIndex).nullable)
-  }
-
-  @transient
-  private[this] var isWarningPrinted: Boolean = false
-
-  @transient
-  private def printWarningForMalformedRecord(record: () => UTF8String): Unit = {
-    def sampleRecord: String = {
-      if (options.wholeFile) {
-        ""
-      } else {
-        s"Sample record: ${record()}\n"
-      }
-    }
-
-    def footer: String = {
-      s"""Code example to print all malformed records (scala):
-         |===================================================
-         |// The corrupted record exists in column ${options.columnNameOfCorruptRecord}.
-         |val parsedJson = spark.read.json("/path/to/json/file/test.json")
-         |
-       """.stripMargin
-    }
-
-    if (options.permissive) {
-      logWarning(
-        s"""Found at least one malformed record. The JSON reader will replace
-           |all malformed records with placeholder null in current $PERMISSIVE_MODE parser mode.
-           |To find out which corrupted records have been replaced with null, please use the
-           |default inferred schema instead of providing a custom schema.
-           |
-           |${sampleRecord ++ footer}
-           |
-         """.stripMargin)
-    } else if (options.dropMalformed) {
-      logWarning(
-        s"""Found at least one malformed record. The JSON reader will drop
-           |all malformed records in current $DROP_MALFORMED_MODE parser mode. To find out which
-           |corrupted records have been dropped, please switch the parser mode to $PERMISSIVE_MODE
-           |mode and use the default inferred schema.
-           |
-           |${sampleRecord ++ footer}
-           |
-         """.stripMargin)
-    }
-  }
-
-  @transient
-  private def printWarningIfWholeFile(): Unit = {
-    if (options.wholeFile && corruptFieldIndex.isDefined) {
-      logWarning(
-        s"""Enabling wholeFile mode and defining columnNameOfCorruptRecord may result
-           |in very large allocations or OutOfMemoryExceptions being raised.
-           |
-         """.stripMargin)
-    }
-  }
-
-  /**
-   * This function deals with the cases it fails to parse. This function will be called
-   * when exceptions are caught during converting. This functions also deals with `mode` option.
-   */
-  private def failedRecord(record: () => UTF8String): Seq[InternalRow] = {
-    corruptFieldIndex match {
-      case _ if options.failFast =>
-        if (options.wholeFile) {
-          throw new SparkSQLJsonProcessingException("Malformed line in FAILFAST mode")
-        } else {
-          throw new SparkSQLJsonProcessingException(s"Malformed line in FAILFAST mode: ${record()}")
-        }
-
-      case _ if options.dropMalformed =>
-        if (!isWarningPrinted) {
-          printWarningForMalformedRecord(record)
-          isWarningPrinted = true
-        }
-        Nil
-
-      case None =>
-        if (!isWarningPrinted) {
-          printWarningForMalformedRecord(record)
-          isWarningPrinted = true
-        }
-        emptyRow
-
-      case Some(corruptIndex) =>
-        if (!isWarningPrinted) {
-          printWarningIfWholeFile()
-          isWarningPrinted = true
-        }
-        val row = new GenericInternalRow(schema.length)
-        row.update(corruptIndex, record())
-        Seq(row)
-    }
-  }
-
   /**
    * Create a converter which converts the JSON documents held by the `JsonParser`
    * to a value according to a desired schema. This is a wrapper for the method
@@ -239,7 +134,7 @@ class JacksonParser(
             lowerCaseValue.equals("-inf")) {
             value.toFloat
           } else {
-            throw new SparkSQLJsonProcessingException(s"Cannot parse $value as FloatType.")
+            throw new RuntimeException(s"Cannot parse $value as FloatType.")
           }
         }
 
@@ -259,7 +154,7 @@ class JacksonParser(
             lowerCaseValue.equals("-inf")) {
             value.toDouble
           } else {
-            throw new SparkSQLJsonProcessingException(s"Cannot parse $value as DoubleType.")
+            throw new RuntimeException(s"Cannot parse $value as DoubleType.")
          }
        }
 
@@ -391,9 +286,8 @@ class JacksonParser(
 
     case token =>
       // We cannot parse this token based on the given data type. So, we throw a
-      // SparkSQLJsonProcessingException and this exception will be caught by
-      // `parse` method.
-      throw new SparkSQLJsonProcessingException(
+      // RuntimeException and this exception will be caught by `parse` method.
+      throw new RuntimeException(
         s"Failed to parse a value for data type $dataType (current token: $token).")
   }
 
@@ -466,14 +360,14 @@ class JacksonParser(
         parser.nextToken() match {
           case null => Nil
           case _ => rootConverter.apply(parser) match {
-            case null => throw new SparkSQLJsonProcessingException("Root converter returned null")
+            case null => throw new RuntimeException("Root converter returned null")
            case rows => rows
          }
        }
      }
    } catch {
-      case _: JsonProcessingException | _: SparkSQLJsonProcessingException =>
-        failedRecord(() => recordLiteral(record))
+      case e @ (_: RuntimeException | _: JsonProcessingException) =>
+        throw BadRecordException(() => recordLiteral(record), () => None, e)
    }
  }
 }
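Not part of the patch: a minimal, self-contained sketch of the contract the rewritten `JacksonParser` now follows. The low-level parser no longer interprets PERMISSIVE/DROPMALFORMED/FAILFAST itself; it wraps any failure in an exception that lazily carries the raw record and the underlying cause, leaving the policy decision to the caller. `ToyBadRecordException` and `parseIntCsvLine` are simplified stand-ins, not the Spark classes.

```scala
// Toy analogue of BadRecordException: the raw record is a thunk, so it is only
// materialized if a caller (for example a permissive-mode handler) actually needs it.
final case class ToyBadRecordException(
    record: () => String,
    partialResult: () => Option[Seq[Any]],
    cause: Throwable) extends Exception(cause)

object ToyParser {
  // The low-level parser only parses; on failure it reports what went wrong and for
  // which record, without deciding how to recover.
  def parseIntCsvLine(line: String): List[Int] =
    try line.split(",").toList.map(_.trim.toInt)
    catch {
      case e: NumberFormatException =>
        throw ToyBadRecordException(() => line, () => None, e)
    }

  def main(args: Array[String]): Unit = {
    println(parseIntCsvLine("1, 2, 3")) // List(1, 2, 3)
    try parseIntCsvLine("1, two, 3")
    catch {
      case e: ToyBadRecordException =>
        println(s"bad record: ${e.record()}; cause: ${e.cause}")
    }
  }
}
```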
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala
new file mode 100644
index 0000000000000000000000000000000000000000..e8da10d65ecb9c884bc8a7815714ca5fdf18657c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.types.UTF8String
+
+class FailureSafeParser[IN](
+    rawParser: IN => Seq[InternalRow],
+    mode: String,
+    schema: StructType,
+    columnNameOfCorruptRecord: String) {
+
+  private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord)
+  private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord))
+  private val resultRow = new GenericInternalRow(schema.length)
+  private val nullResult = new GenericInternalRow(schema.length)
+
+  // This function takes 2 parameters: an optional partial result, and the bad record. If the given
+  // schema doesn't contain a field for corrupted record, we just return the partial result or a
+  // row with all fields null. If the given schema contains a field for corrupted record, we will
+  // set the bad record to this field, and set other fields according to the partial result or null.
+  private val toResultRow: (Option[InternalRow], () => UTF8String) => InternalRow = {
+    if (corruptFieldIndex.isDefined) {
+      (row, badRecord) => {
+        var i = 0
+        while (i < actualSchema.length) {
+          val from = actualSchema(i)
+          resultRow(schema.fieldIndex(from.name)) = row.map(_.get(i, from.dataType)).orNull
+          i += 1
+        }
+        resultRow(corruptFieldIndex.get) = badRecord()
+        resultRow
+      }
+    } else {
+      (row, _) => row.getOrElse(nullResult)
+    }
+  }
+
+  def parse(input: IN): Iterator[InternalRow] = {
+    try {
+      rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null))
+    } catch {
+      case e: BadRecordException if ParseModes.isPermissiveMode(mode) =>
+        Iterator(toResultRow(e.partialResult(), e.record))
+      case _: BadRecordException if ParseModes.isDropMalformedMode(mode) =>
+        Iterator.empty
+      case e: BadRecordException => throw e.cause
+    }
+  }
+}
+
+/**
+ * Exception thrown when the underlying parser meet a bad record and can't parse it.
+ * @param record a function to return the record that cause the parser to fail
+ * @param partialResult a function that returns an optional row, which is the partial result of
+ *                      parsing this bad record.
+ * @param cause the actual exception about why the record is bad and can't be parsed.
+ */
+case class BadRecordException(
+    record: () => UTF8String,
+    partialResult: () => Option[InternalRow],
+    cause: Throwable) extends Exception(cause)
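Not part of the patch: a sketch of how the `FailureSafeParser` behavior above surfaces through the public JSON reader. The `mode` and `columnNameOfCorruptRecord` reader options already exist; the data, schema, and object name here are illustrative, and the expected outcomes reflect my reading of the new code paths.

```scala
import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object JsonParseModeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("json parse modes").getOrCreate()
    import spark.implicits._

    val records = Seq("""{"a": "one"}""", """{"a":""").toDS() // the second record is malformed
    val schema = StructType(Seq(
      StructField("a", StringType),
      StructField("_corrupt_record", StringType)))

    // PERMISSIVE (the default): the bad record surfaces in the corrupt-record column.
    spark.read.schema(schema).json(records).show(false)

    // DROPMALFORMED: the bad record is dropped.
    spark.read.schema(schema).option("mode", "DROPMALFORMED").json(records).show(false)

    // FAILFAST: the read fails, with the underlying Jackson error as the cause.
    try spark.read.schema(schema).option("mode", "FAILFAST").json(records).collect()
    catch { case e: SparkException => println(e.getMessage) }

    spark.stop()
  }
}
```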
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 88fbfb4c92a000ed268b8ba62bfc08edaaf1ef52..767a636d707319019c3e6c0acbe0300b55735fe4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -27,6 +27,7 @@ import org.apache.spark.Partition
 import org.apache.spark.annotation.InterfaceStability
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
+import org.apache.spark.sql.catalyst.util.FailureSafeParser
 import org.apache.spark.sql.execution.LogicalRDD
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.csv._
@@ -382,11 +383,18 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     }
 
     verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
+    val actualSchema =
+      StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
 
     val createParser = CreateJacksonParser.string _
     val parsed = jsonDataset.rdd.mapPartitions { iter =>
-      val parser = new JacksonParser(schema, parsedOptions)
-      iter.flatMap(parser.parse(_, createParser, UTF8String.fromString))
+      val rawParser = new JacksonParser(actualSchema, parsedOptions)
+      val parser = new FailureSafeParser[String](
+        input => rawParser.parse(input, createParser, UTF8String.fromString),
+        parsedOptions.parseMode,
+        schema,
+        parsedOptions.columnNameOfCorruptRecord)
+      iter.flatMap(parser.parse)
     }
 
     Dataset.ofRows(
@@ -435,14 +443,21 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     }
 
     verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
+    val actualSchema =
+      StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
 
     val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
       filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions))
     }.getOrElse(filteredLines.rdd)
 
     val parsed = linesWithoutHeader.mapPartitions { iter =>
-      val parser = new UnivocityParser(schema, parsedOptions)
-      iter.flatMap(line => parser.parse(line))
+      val rawParser = new UnivocityParser(actualSchema, parsedOptions)
+      val parser = new FailureSafeParser[String](
+        input => Seq(rawParser.parse(input)),
+        parsedOptions.parseMode,
+        schema,
+        parsedOptions.columnNameOfCorruptRecord)
+      iter.flatMap(parser.parse)
     }
 
     Dataset.ofRows(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 35ff924f27ce5745711356a3665e363e7c2a0ba6..63af18ec5b8eba3757926fc884ce7ad11b680e65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -49,7 +49,7 @@ abstract class CSVDataSource extends Serializable {
       conf: Configuration,
       file: PartitionedFile,
       parser: UnivocityParser,
-      parsedOptions: CSVOptions): Iterator[InternalRow]
+      schema: StructType): Iterator[InternalRow]
 
   /**
    * Infers the schema from `inputPaths` files.
@@ -115,17 +115,17 @@
       conf: Configuration,
       file: PartitionedFile,
       parser: UnivocityParser,
-      parsedOptions: CSVOptions): Iterator[InternalRow] = {
+      schema: StructType): Iterator[InternalRow] = {
     val lines = {
       val linesReader = new HadoopFileLinesReader(file, conf)
       Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
       linesReader.map { line =>
-        new String(line.getBytes, 0, line.getLength, parsedOptions.charset)
+        new String(line.getBytes, 0, line.getLength, parser.options.charset)
       }
     }
 
-    val shouldDropHeader = parsedOptions.headerFlag && file.start == 0
-    UnivocityParser.parseIterator(lines, shouldDropHeader, parser)
+    val shouldDropHeader = parser.options.headerFlag && file.start == 0
+    UnivocityParser.parseIterator(lines, shouldDropHeader, parser, schema)
   }
 
   override def infer(
@@ -192,11 +192,12 @@ object WholeFileCSVDataSource extends CSVDataSource {
       conf: Configuration,
       file: PartitionedFile,
       parser: UnivocityParser,
-      parsedOptions: CSVOptions): Iterator[InternalRow] = {
+      schema: StructType): Iterator[InternalRow] = {
     UnivocityParser.parseStream(
       CodecStreams.createInputStreamWithCloseResource(conf, file.filePath),
-      parsedOptions.headerFlag,
-      parser)
+      parser.options.headerFlag,
+      parser,
+      schema)
   }
 
   override def infer(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 29c41455279e6237c2640c3bac5fb3994d0de93a..eef43c7629c12f3c8d7891ef5b98e5fa4cb9576c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -113,8 +113,11 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
 
     (file: PartitionedFile) => {
       val conf = broadcastedHadoopConf.value.value
-      val parser = new UnivocityParser(dataSchema, requiredSchema, parsedOptions)
-      CSVDataSource(parsedOptions).readFile(conf, file, parser, parsedOptions)
+      val parser = new UnivocityParser(
+        StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
+        StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
+        parsedOptions)
+      CSVDataSource(parsedOptions).readFile(conf, file, parser, requiredSchema)
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 2632e87971d6865fef5566532898a7f4be535320..f6c6b6f56cd9d301d0c362054e4e72da88024d99 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -82,7 +82,7 @@ class CSVOptions(
   val delimiter = CSVUtils.toChar(
     parameters.getOrElse("sep", parameters.getOrElse("delimiter", ",")))
-  private val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
+  val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
   val charset = parameters.getOrElse("encoding",
     parameters.getOrElse("charset", StandardCharsets.UTF_8.name()))
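Not part of the patch: the CSV counterpart of the wiring above, sketched from the user's side. DROPMALFORMED should skip the short line, and FAILFAST should fail with the generic "Malformed CSV record" message asserted in the updated CSVSuite test near the end of this diff. The data, schema, and object name are illustrative.

```scala
import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object CsvParseModeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("csv parse modes").getOrCreate()
    import spark.implicits._

    // The second line has only two columns.
    val lines = Seq("2012,Tesla,S", "2015,Chevy", "2013,Ford,Fiesta").toDS()
    val schema = StructType(Seq(
      StructField("year", IntegerType),
      StructField("make", StringType),
      StructField("model", StringType)))

    // DROPMALFORMED: the two-column line is dropped.
    spark.read.schema(schema).option("mode", "DROPMALFORMED").csv(lines).show()

    // FAILFAST: the read fails, and the message now says "Malformed CSV record"
    // rather than echoing the offending line.
    try spark.read.schema(schema).option("mode", "FAILFAST").csv(lines).collect()
    catch { case e: SparkException => println(e.getMessage) }

    spark.stop()
  }
}
```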
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index e42ea3fa391f5c39606a18d70460a953575e14b3..263f77e11c4da9f9022e5719e5690069c492ced7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -30,14 +30,14 @@ import com.univocity.parsers.csv.CsvParser
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils, FailureSafeParser}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 class UnivocityParser(
     schema: StructType,
     requiredSchema: StructType,
-    private val options: CSVOptions) extends Logging {
+    val options: CSVOptions) extends Logging {
   require(requiredSchema.toSet.subsetOf(schema.toSet),
     "requiredSchema should be the subset of schema.")
 
@@ -46,39 +46,26 @@ class UnivocityParser(
   // A `ValueConverter` is responsible for converting the given value to a desired type.
   private type ValueConverter = String => Any
 
-  private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord)
-  corruptFieldIndex.foreach { corrFieldIndex =>
-    require(schema(corrFieldIndex).dataType == StringType)
-    require(schema(corrFieldIndex).nullable)
-  }
-
-  private val dataSchema = StructType(schema.filter(_.name != options.columnNameOfCorruptRecord))
-
   private val tokenizer = new CsvParser(options.asParserSettings)
 
-  private var numMalformedRecords = 0
-
   private val row = new GenericInternalRow(requiredSchema.length)
 
-  // In `PERMISSIVE` parse mode, we should be able to put the raw malformed row into the field
-  // specified in `columnNameOfCorruptRecord`. The raw input is retrieved by this method.
-  private def getCurrentInput(): String = tokenizer.getContext.currentParsedContent().stripLineEnd
+  // Retrieve the raw record string.
+  private def getCurrentInput: UTF8String = {
+    UTF8String.fromString(tokenizer.getContext.currentParsedContent().stripLineEnd)
+  }
 
-  // This parser loads an `tokenIndexArr`-th position value in input tokens,
-  // then put the value in `row(rowIndexArr)`.
+  // This parser first picks some tokens from the input tokens, according to the required schema,
+  // then parse these tokens and put the values in a row, with the order specified by the required
+  // schema.
   //
   // For example, let's say there is CSV data as below:
   //
   //   a,b,c
   //   1,2,A
   //
-  // Also, let's say `columnNameOfCorruptRecord` is set to "_unparsed", `header` is `true`
-  // by user and the user selects "c", "b", "_unparsed" and "a" fields. In this case, we need
-  // to map those values below:
-  //
-  //   required schema - ["c", "b", "_unparsed", "a"]
-  //   CSV data schema - ["a", "b", "c"]
-  //   required CSV data schema - ["c", "b", "a"]
+  // So the CSV data schema is: ["a", "b", "c"]
+  // And let's say the required schema is: ["c", "b"]
   //
   // with the input tokens,
   //
@@ -86,45 +73,12 @@
   //
   // Each input token is placed in each output row's position by mapping these. In this case,
   //
-  //   output row - ["A", 2, null, 1]
-  //
-  // In more details,
-  //   - `valueConverters`, input tokens - CSV data schema
-  //     `valueConverters` keeps the positions of input token indices (by its index) to each
-  //     value's converter (by its value) in an order of CSV data schema. In this case,
-  //     [string->int, string->int, string->string].
-  //
-  //   - `tokenIndexArr`, input tokens - required CSV data schema
-  //     `tokenIndexArr` keeps the positions of input token indices (by its index) to reordered
-  //     fields given the required CSV data schema (by its value). In this case, [2, 1, 0].
-  //
-  //   - `rowIndexArr`, input tokens - required schema
-  //     `rowIndexArr` keeps the positions of input token indices (by its index) to reordered
-  //     field indices given the required schema (by its value). In this case, [0, 1, 3].
+  //   output row - ["A", 2]
   private val valueConverters: Array[ValueConverter] =
-    dataSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray
-
-  // Only used to create both `tokenIndexArr` and `rowIndexArr`. This variable means
-  // the fields that we should try to convert.
-  private val reorderedFields = if (options.dropMalformed) {
-    // If `dropMalformed` is enabled, then it needs to parse all the values
-    // so that we can decide which row is malformed.
-    requiredSchema ++ schema.filterNot(requiredSchema.contains(_))
-  } else {
-    requiredSchema
-  }
+    schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray
 
   private val tokenIndexArr: Array[Int] = {
-    reorderedFields
-      .filter(_.name != options.columnNameOfCorruptRecord)
-      .map(f => dataSchema.indexOf(f)).toArray
-  }
-
-  private val rowIndexArr: Array[Int] = if (corruptFieldIndex.isDefined) {
-    val corrFieldIndex = corruptFieldIndex.get
-    reorderedFields.indices.filter(_ != corrFieldIndex).toArray
-  } else {
-    reorderedFields.indices.toArray
+    requiredSchema.map(f => schema.indexOf(f)).toArray
   }
 
   /**
@@ -205,7 +159,7 @@ class UnivocityParser(
       }
 
     case _: StringType => (d: String) =>
-      nullSafeDatum(d, name, nullable, options)(UTF8String.fromString(_))
+      nullSafeDatum(d, name, nullable, options)(UTF8String.fromString)
 
     case udt: UserDefinedType[_] => (datum: String) =>
       makeConverter(name, udt.sqlType, nullable, options)
@@ -233,81 +187,41 @@ class UnivocityParser(
    * Parses a single CSV string and turns it into either one resulting row or no row (if the
    * the record is malformed).
    */
-  def parse(input: String): Option[InternalRow] = convert(tokenizer.parseLine(input))
-
-  private def convert(tokens: Array[String]): Option[InternalRow] = {
-    convertWithParseMode(tokens) { tokens =>
-      var i: Int = 0
-      while (i < tokenIndexArr.length) {
-        // It anyway needs to try to parse since it decides if this row is malformed
-        // or not after trying to cast in `DROPMALFORMED` mode even if the casted
-        // value is not stored in the row.
-        val from = tokenIndexArr(i)
-        val to = rowIndexArr(i)
-        val value = valueConverters(from).apply(tokens(from))
-        if (i < requiredSchema.length) {
-          row(to) = value
-        }
-        i += 1
-      }
-      row
-    }
-  }
-
-  private def convertWithParseMode(
-      tokens: Array[String])(convert: Array[String] => InternalRow): Option[InternalRow] = {
-    if (options.dropMalformed && dataSchema.length != tokens.length) {
-      if (numMalformedRecords < options.maxMalformedLogPerPartition) {
-        logWarning(s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}")
-      }
-      if (numMalformedRecords == options.maxMalformedLogPerPartition - 1) {
-        logWarning(
-          s"More than ${options.maxMalformedLogPerPartition} malformed records have been " +
-            "found on this partition. Malformed records from now on will not be logged.")
-      }
+  def parse(input: String): InternalRow = convert(tokenizer.parseLine(input))
+
+  private def convert(tokens: Array[String]): InternalRow = {
+    if (tokens.length != schema.length) {
+      // If the number of tokens doesn't match the schema, we should treat it as a malformed record.
+      // However, we still have chance to parse some of the tokens, by adding extra null tokens in
+      // the tail if the number is smaller, or by dropping extra tokens if the number is larger.
+      val checkedTokens = if (schema.length > tokens.length) {
+        tokens ++ new Array[String](schema.length - tokens.length)
+      } else {
+        tokens.take(schema.length)
       }
-      numMalformedRecords += 1
-      None
-    } else if (options.failFast && dataSchema.length != tokens.length) {
-      throw new RuntimeException(s"Malformed line in FAILFAST mode: " +
-        s"${tokens.mkString(options.delimiter.toString)}")
-    } else {
-      // If a length of parsed tokens is not equal to expected one, it makes the length the same
-      // with the expected. If the length is shorter, it adds extra tokens in the tail.
-      // If longer, it drops extra tokens.
-      //
-      // TODO: Revisit this; if a length of tokens does not match an expected length in the schema,
-      // we probably need to treat it as a malformed record.
-      // See an URL below for related discussions:
-      // https://github.com/apache/spark/pull/16928#discussion_r102657214
-      val checkedTokens = if (options.permissive && dataSchema.length != tokens.length) {
-        if (dataSchema.length > tokens.length) {
-          tokens ++ new Array[String](dataSchema.length - tokens.length)
-        } else {
-          tokens.take(dataSchema.length)
+      def getPartialResult(): Option[InternalRow] = {
+        try {
+          Some(convert(checkedTokens))
+        } catch {
+          case _: BadRecordException => None
         }
-      } else {
-        tokens
       }
-
+      throw BadRecordException(
+        () => getCurrentInput,
+        getPartialResult,
+        new RuntimeException("Malformed CSV record"))
+    } else {
       try {
-        Some(convert(checkedTokens))
+        var i = 0
+        while (i < requiredSchema.length) {
+          val from = tokenIndexArr(i)
+          row(i) = valueConverters(from).apply(tokens(from))
+          i += 1
+        }
+        row
       } catch {
-        case NonFatal(e) if options.permissive =>
-          val row = new GenericInternalRow(requiredSchema.length)
-          corruptFieldIndex.foreach(row(_) = UTF8String.fromString(getCurrentInput()))
-          Some(row)
-        case NonFatal(e) if options.dropMalformed =>
-          if (numMalformedRecords < options.maxMalformedLogPerPartition) {
-            logWarning("Parse exception. " +
-              s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}")
-          }
-          if (numMalformedRecords == options.maxMalformedLogPerPartition - 1) {
-            logWarning(
-              s"More than ${options.maxMalformedLogPerPartition} malformed records have been " +
-                "found on this partition. Malformed records from now on will not be logged.")
-          }
-          numMalformedRecords += 1
-          None
+        case NonFatal(e) =>
+          throw BadRecordException(() => getCurrentInput, () => None, e)
       }
     }
   }
@@ -331,10 +245,16 @@ private[csv] object UnivocityParser {
   def parseStream(
       inputStream: InputStream,
       shouldDropHeader: Boolean,
-      parser: UnivocityParser): Iterator[InternalRow] = {
+      parser: UnivocityParser,
+      schema: StructType): Iterator[InternalRow] = {
     val tokenizer = parser.tokenizer
+    val safeParser = new FailureSafeParser[Array[String]](
+      input => Seq(parser.convert(input)),
+      parser.options.parseMode,
+      schema,
+      parser.options.columnNameOfCorruptRecord)
     convertStream(inputStream, shouldDropHeader, tokenizer) { tokens =>
-      parser.convert(tokens)
+      safeParser.parse(tokens)
     }.flatten
   }
 
@@ -368,7 +288,8 @@ private[csv] object UnivocityParser {
   def parseIterator(
       lines: Iterator[String],
       shouldDropHeader: Boolean,
-      parser: UnivocityParser): Iterator[InternalRow] = {
+      parser: UnivocityParser,
+      schema: StructType): Iterator[InternalRow] = {
     val options = parser.options
 
     val linesWithoutHeader = if (shouldDropHeader) {
@@ -381,6 +302,12 @@ private[csv] object UnivocityParser {
 
     val filteredLines: Iterator[String] =
       CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options)
-    filteredLines.flatMap(line => parser.parse(line))
+
+    val safeParser = new FailureSafeParser[String](
+      input => Seq(parser.parse(input)),
+      parser.options.parseMode,
+      schema,
+      parser.options.columnNameOfCorruptRecord)
+    filteredLines.flatMap(safeParser.parse)
   }
 }
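Not part of the patch: a sketch of the partial-result behavior implemented by `convert` and `getPartialResult` above. When the token count does not match the schema, the line is reported as malformed, but under PERMISSIVE mode the fields that could still be converted are kept and the raw line lands in the corrupt-record column, provided that column is declared in the schema. The data and names are illustrative, and the expected row shape reflects my reading of the new code.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object CsvPartialResultSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("csv partial result").getOrCreate()
    import spark.implicits._

    // The second line is missing the "model" column.
    val lines = Seq("2012,Tesla,S", "2015,Chevy").toDS()
    val schema = StructType(Seq(
      StructField("year", IntegerType),
      StructField("make", StringType),
      StructField("model", StringType),
      StructField("_corrupt_record", StringType)))

    // Expected shape of the malformed row under PERMISSIVE mode:
    // (2015, Chevy, null, "2015,Chevy") -- the missing token is padded with null for
    // parsing, and the raw line is preserved in the corrupt-record column.
    spark.read
      .schema(schema)
      .option("mode", "PERMISSIVE")
      .option("columnNameOfCorruptRecord", "_corrupt_record")
      .csv(lines)
      .show(false)

    spark.stop()
  }
}
```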
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index 84f026620d9070c0aedf9075d57a4fb03a780881..51e952c12202e1e50bffe920e187eac2f3d14a4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources.json
 
+import java.io.InputStream
+
 import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
 import com.google.common.io.ByteStreams
 import org.apache.hadoop.conf.Configuration
@@ -31,6 +33,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD}
 import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
+import org.apache.spark.sql.catalyst.util.FailureSafeParser
 import org.apache.spark.sql.execution.datasources.{CodecStreams, DataSource, HadoopFileLinesReader, PartitionedFile}
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.types.StructType
@@ -49,7 +52,8 @@ abstract class JsonDataSource extends Serializable {
   def readFile(
     conf: Configuration,
     file: PartitionedFile,
-    parser: JacksonParser): Iterator[InternalRow]
+    parser: JacksonParser,
+    schema: StructType): Iterator[InternalRow]
 
   final def inferSchema(
       sparkSession: SparkSession,
@@ -127,10 +131,16 @@ object TextInputJsonDataSource extends JsonDataSource {
   override def readFile(
       conf: Configuration,
       file: PartitionedFile,
-      parser: JacksonParser): Iterator[InternalRow] = {
+      parser: JacksonParser,
+      schema: StructType): Iterator[InternalRow] = {
     val linesReader = new HadoopFileLinesReader(file, conf)
     Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
-    linesReader.flatMap(parser.parse(_, CreateJacksonParser.text, textToUTF8String))
+    val safeParser = new FailureSafeParser[Text](
+      input => parser.parse(input, CreateJacksonParser.text, textToUTF8String),
+      parser.options.parseMode,
+      schema,
+      parser.options.columnNameOfCorruptRecord)
+    linesReader.flatMap(safeParser.parse)
   }
 
   private def textToUTF8String(value: Text): UTF8String = {
@@ -180,7 +190,8 @@ object WholeFileJsonDataSource extends JsonDataSource {
   override def readFile(
       conf: Configuration,
       file: PartitionedFile,
-      parser: JacksonParser): Iterator[InternalRow] = {
+      parser: JacksonParser,
+      schema: StructType): Iterator[InternalRow] = {
     def partitionedFileString(ignored: Any): UTF8String = {
       Utils.tryWithResource {
         CodecStreams.createInputStreamWithCloseResource(conf, file.filePath)
@@ -189,9 +200,13 @@ object WholeFileJsonDataSource extends JsonDataSource {
       }
     }
 
-    parser.parse(
-      CodecStreams.createInputStreamWithCloseResource(conf, file.filePath),
-      CreateJacksonParser.inputStream,
-      partitionedFileString).toIterator
+    val safeParser = new FailureSafeParser[InputStream](
+      input => parser.parse(input, CreateJacksonParser.inputStream, partitionedFileString),
+      parser.options.parseMode,
+      schema,
+      parser.options.columnNameOfCorruptRecord)
+
+    safeParser.parse(
+      CodecStreams.createInputStreamWithCloseResource(conf, file.filePath))
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index a9dd91eba6f72c0f4b807c0acaacef7d24276c74..53d62d88b04c68cf177856527f84de44a9d59ae2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -102,6 +102,8 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
 
+    val actualSchema =
+      StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
     // Check a field requirement for corrupt records here to throw an exception in a driver side
     dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
       val f = dataSchema(corruptFieldIndex)
@@ -112,11 +114,12 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
     }
 
     (file: PartitionedFile) => {
-      val parser = new JacksonParser(requiredSchema, parsedOptions)
+      val parser = new JacksonParser(actualSchema, parsedOptions)
       JsonDataSource(parsedOptions).readFile(
         broadcastedHadoopConf.value.value,
         file,
-        parser)
+        parser,
+        requiredSchema)
     }
   }
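Not part of the patch: a sketch of the per-file JSON path (`WholeFileJsonDataSource.readFile` above), where an entire file is a single record and a malformed file becomes one row whose corrupt-record column holds the file's content. The option is still called `wholeFile` at this point in the codebase; the temp-file setup, schema, and names below are illustrative.

```scala
import java.nio.charset.StandardCharsets
import java.nio.file.Files

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object WholeFileJsonSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("whole file json").getOrCreate()

    // One well-formed and one malformed file; with wholeFile enabled each file is one record.
    val dir = Files.createTempDirectory("json-sketch")
    Files.write(dir.resolve("good.json"), """{"a": "1"}""".getBytes(StandardCharsets.UTF_8))
    Files.write(dir.resolve("bad.json"), """{"a": """.getBytes(StandardCharsets.UTF_8))

    val schema = StructType(Seq(
      StructField("a", StringType),
      StructField("_corrupt_record", StringType)))

    // Expected: the malformed file yields a single row with null fields and the raw
    // file content in the corrupt-record column.
    spark.read
      .schema(schema)
      .option("wholeFile", "true")
      .json(dir.toString)
      .show(false)

    spark.stop()
  }
}
```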
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 95dfdf5b298e6c77e751187b62714890d2d609e1..598babfe0e7ade9942f964e3b8d4bd4183c6b2f4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -293,7 +293,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
           .load(testFile(carsFile)).collect()
       }
 
-      assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt"))
+      assert(exception.getMessage.contains("Malformed CSV record"))
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 9b0efcbdaf5c37fa6ee508c60583cdad1a95fe10..56fcf773f7dd99946f98d3736a787532b4e6c118 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -1043,7 +1043,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
         .json(corruptRecords)
         .collect()
     }
-    assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode: {"))
+    assert(exceptionOne.getMessage.contains("JsonParseException"))
 
     val exceptionTwo = intercept[SparkException] {
       spark.read
@@ -1052,7 +1052,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
         .json(corruptRecords)
         .collect()
     }
-    assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode: {"))
+    assert(exceptionTwo.getMessage.contains("JsonParseException"))
   }
 
   test("Corrupt records: DROPMALFORMED mode") {
@@ -1929,7 +1929,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
         .json(path)
         .collect()
     }
-    assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode"))
+    assert(exceptionOne.getMessage.contains("Failed to parse a value"))
 
     val exceptionTwo = intercept[SparkException] {
       spark.read
@@ -1939,7 +1939,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
         .json(path)
         .collect()
     }
-    assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode"))
+    assert(exceptionTwo.getMessage.contains("Failed to parse a value"))
   }
 }