diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index ea843a10137f260977e80cabc56821f28e40dbfd..a26a8084b63bfff4d2b03abb3f9e58c7fe7b06c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.datasources.csv
 
 import java.math.BigDecimal
-import java.text.NumberFormat
+import java.text.{NumberFormat, SimpleDateFormat}
 import java.util.Locale
 
 import scala.util.control.Exception._
@@ -41,11 +41,10 @@ private[csv] object CSVInferSchema {
   def infer(
       tokenRdd: RDD[Array[String]],
       header: Array[String],
-      nullValue: String = ""): StructType = {
-
+      options: CSVOptions): StructType = {
     val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
     val rootTypes: Array[DataType] =
-      tokenRdd.aggregate(startType)(inferRowType(nullValue), mergeRowTypes)
+      tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes)
 
     val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
       val dType = rootType match {
@@ -58,11 +57,11 @@ private[csv] object CSVInferSchema {
     StructType(structFields)
   }
 
-  private def inferRowType(nullValue: String)
+  private def inferRowType(options: CSVOptions)
      (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
     var i = 0
     while (i < math.min(rowSoFar.length, next.length)) {  // May have columns on right missing.
-      rowSoFar(i) = inferField(rowSoFar(i), next(i), nullValue)
+      rowSoFar(i) = inferField(rowSoFar(i), next(i), options)
       i+=1
     }
     rowSoFar
@@ -78,17 +77,17 @@ private[csv] object CSVInferSchema {
    * Infer type of string field. Given known type Double, and a string "1", there is no
    * point checking if it is an Int, as the final type must be Double or higher.
    */
-  def inferField(typeSoFar: DataType, field: String, nullValue: String = ""): DataType = {
-    if (field == null || field.isEmpty || field == nullValue) {
+  def inferField(typeSoFar: DataType, field: String, options: CSVOptions): DataType = {
+    if (field == null || field.isEmpty || field == options.nullValue) {
       typeSoFar
     } else {
       typeSoFar match {
-        case NullType => tryParseInteger(field)
-        case IntegerType => tryParseInteger(field)
-        case LongType => tryParseLong(field)
-        case DoubleType => tryParseDouble(field)
-        case TimestampType => tryParseTimestamp(field)
-        case BooleanType => tryParseBoolean(field)
+        case NullType => tryParseInteger(field, options)
+        case IntegerType => tryParseInteger(field, options)
+        case LongType => tryParseLong(field, options)
+        case DoubleType => tryParseDouble(field, options)
+        case TimestampType => tryParseTimestamp(field, options)
+        case BooleanType => tryParseBoolean(field, options)
         case StringType => StringType
         case other: DataType =>
           throw new UnsupportedOperationException(s"Unexpected data type $other")
@@ -96,35 +95,49 @@ private[csv] object CSVInferSchema {
     }
   }
 
-  private def tryParseInteger(field: String): DataType = if ((allCatch opt field.toInt).isDefined) {
-    IntegerType
-  } else {
-    tryParseLong(field)
+  private def tryParseInteger(field: String, options: CSVOptions): DataType = {
+    if ((allCatch opt field.toInt).isDefined) {
+      IntegerType
+    } else {
+      tryParseLong(field, options)
+    }
   }
 
-  private def tryParseLong(field: String): DataType = if ((allCatch opt field.toLong).isDefined) {
-    LongType
-  } else {
-    tryParseDouble(field)
+  private def tryParseLong(field: String, options: CSVOptions): DataType = {
+    if ((allCatch opt field.toLong).isDefined) {
+      LongType
+    } else {
+      tryParseDouble(field, options)
+    }
   }
 
-  private def tryParseDouble(field: String): DataType = {
+  private def tryParseDouble(field: String, options: CSVOptions): DataType = {
     if ((allCatch opt field.toDouble).isDefined) {
       DoubleType
     } else {
-      tryParseTimestamp(field)
+      tryParseTimestamp(field, options)
     }
   }
 
-  def tryParseTimestamp(field: String): DataType = {
-    if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) {
-      TimestampType
+  private def tryParseTimestamp(field: String, options: CSVOptions): DataType = {
+    if (options.dateFormat != null) {
+      // This case applies when a custom `dateFormat` is set.
+      if ((allCatch opt options.dateFormat.parse(field)).isDefined) {
+        TimestampType
+      } else {
+        tryParseBoolean(field, options)
+      }
     } else {
-      tryParseBoolean(field)
+      // We keep this for backwards compatibility.
+      if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) {
+        TimestampType
+      } else {
+        tryParseBoolean(field, options)
+      }
     }
   }
 
-  def tryParseBoolean(field: String): DataType = {
+  private def tryParseBoolean(field: String, options: CSVOptions): DataType = {
     if ((allCatch opt field.toBoolean).isDefined) {
       BooleanType
     } else {
@@ -177,7 +190,8 @@ private[csv] object CSVTypeCast {
       datum: String,
       castType: DataType,
       nullable: Boolean = true,
-      nullValue: String = ""): Any = {
+      nullValue: String = "",
+      dateFormat: SimpleDateFormat = null): Any = {
 
     if (datum == nullValue && nullable && (!castType.isInstanceOf[StringType])) {
       null
@@ -195,12 +209,16 @@ private[csv] object CSVTypeCast {
        case dt: DecimalType =>
          val value = new BigDecimal(datum.replaceAll(",", ""))
          Decimal(value, dt.precision, dt.scale)
-        // TODO(hossein): would be good to support other common timestamp formats
+        case _: TimestampType if dateFormat != null =>
+          // This one will lose microseconds parts.
+          // See https://issues.apache.org/jira/browse/SPARK-10681.
+          dateFormat.parse(datum).getTime * 1000L
         case _: TimestampType =>
          // This one will lose microseconds parts.
          // See https://issues.apache.org/jira/browse/SPARK-10681.
          DateTimeUtils.stringToTime(datum).getTime * 1000L
-        // TODO(hossein): would be good to support other common date formats
+        case _: DateType if dateFormat != null =>
+          DateTimeUtils.millisToDays(dateFormat.parse(datum).getTime)
         case _: DateType =>
          DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime)
         case _: StringType => UTF8String.fromString(datum)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 80a0ad785629ffaaa614579aa279ca705da4517f..b87d19f7cf657b9638710a6322b44465e7707795 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.datasources.csv
 
 import java.nio.charset.StandardCharsets
+import java.text.SimpleDateFormat
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.execution.datasources.{CompressionCodecs, ParseModes}
@@ -94,6 +95,12 @@ private[sql] class CSVOptions(@transient private val parameters: Map[String, Str
     name.map(CompressionCodecs.getCodecClassName)
   }
 
+  // Share the date format object, as parsing a date pattern is expensive.
+  val dateFormat: SimpleDateFormat = {
+    val dateFormat = parameters.get("dateFormat")
+    dateFormat.map(new SimpleDateFormat(_)).orNull
+  }
+
   val maxColumns = getInt("maxColumns", 20480)
 
   val maxCharsPerColumn = getInt("maxCharsPerColumn", 1000000)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index ed40cd0c812ab8bf43098712bf45ac053d3747a4..9a723630de7dbbfa4a892ca2d790df16cec0b3e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -99,7 +99,8 @@ object CSVRelation extends Logging {
            indexSafeTokens(index),
            field.dataType,
            field.nullable,
-            params.nullValue)
+            params.nullValue,
+            params.dateFormat)
          if (subIndex < requiredSize) {
            row(subIndex) = value
          }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
index 8ca105d92375f8bc945f9454535e09d380bdded4..75143e609aaf70f0197a5ef97e6df645e5123394 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
@@ -68,7 +68,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
 
     val parsedRdd = tokenRdd(sparkSession, csvOptions, header, paths)
     val schema = if (csvOptions.inferSchemaFlag) {
-      CSVInferSchema.infer(parsedRdd, header, csvOptions.nullValue)
+      CSVInferSchema.infer(parsedRdd, header, csvOptions)
     } else {
       // By default fields are assumed to be StringType
       val schemaFields = header.map { fieldName =>
diff --git a/sql/core/src/test/resources/dates.csv b/sql/core/src/test/resources/dates.csv
new file mode 100644
index 0000000000000000000000000000000000000000..9ee99c31b334a06043b8e3b295e67e15d021c740
--- /dev/null
+++ b/sql/core/src/test/resources/dates.csv
@@ -0,0 +1,4 @@
+date
+26/08/2015 18:00
+27/10/2014 18:30
+28/01/2016 20:00
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
index 23d422635b0a93bc579f23939c590d14006380a8..daf85be56f3d25e4b437385ecae58430de2f4d97 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
@@ -17,45 +17,58 @@
 
 package org.apache.spark.sql.execution.datasources.csv
 
+import java.text.SimpleDateFormat
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.types._
 
 class CSVInferSchemaSuite extends SparkFunSuite {
 
   test("String fields types are inferred correctly from null types") {
-    assert(CSVInferSchema.inferField(NullType, "") == NullType)
-    assert(CSVInferSchema.inferField(NullType, null) == NullType)
-    assert(CSVInferSchema.inferField(NullType, "100000000000") == LongType)
-    assert(CSVInferSchema.inferField(NullType, "60") == IntegerType)
-    assert(CSVInferSchema.inferField(NullType, "3.5") == DoubleType)
-    assert(CSVInferSchema.inferField(NullType, "test") == StringType)
-    assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00") == TimestampType)
-    assert(CSVInferSchema.inferField(NullType, "True") == BooleanType)
-    assert(CSVInferSchema.inferField(NullType, "FAlSE") == BooleanType)
+    val options = new CSVOptions(Map.empty[String, String])
+    assert(CSVInferSchema.inferField(NullType, "", options) == NullType)
+    assert(CSVInferSchema.inferField(NullType, null, options) == NullType)
+    assert(CSVInferSchema.inferField(NullType, "100000000000", options) == LongType)
+    assert(CSVInferSchema.inferField(NullType, "60", options) == IntegerType)
+    assert(CSVInferSchema.inferField(NullType, "3.5", options) == DoubleType)
+    assert(CSVInferSchema.inferField(NullType, "test", options) == StringType)
+    assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00", options) == TimestampType)
+    assert(CSVInferSchema.inferField(NullType, "True", options) == BooleanType)
+    assert(CSVInferSchema.inferField(NullType, "FAlSE", options) == BooleanType)
   }
 
   test("String fields types are inferred correctly from other types") {
-    assert(CSVInferSchema.inferField(LongType, "1.0") == DoubleType)
-    assert(CSVInferSchema.inferField(LongType, "test") == StringType)
-    assert(CSVInferSchema.inferField(IntegerType, "1.0") == DoubleType)
-    assert(CSVInferSchema.inferField(DoubleType, null) == DoubleType)
-    assert(CSVInferSchema.inferField(DoubleType, "test") == StringType)
-    assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00") == TimestampType)
-    assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00") == TimestampType)
-    assert(CSVInferSchema.inferField(LongType, "True") == BooleanType)
-    assert(CSVInferSchema.inferField(IntegerType, "FALSE") == BooleanType)
-    assert(CSVInferSchema.inferField(TimestampType, "FALSE") == BooleanType)
+    val options = new CSVOptions(Map.empty[String, String])
+    assert(CSVInferSchema.inferField(LongType, "1.0", options) == DoubleType)
+    assert(CSVInferSchema.inferField(LongType, "test", options) == StringType)
+    assert(CSVInferSchema.inferField(IntegerType, "1.0", options) == DoubleType)
+    assert(CSVInferSchema.inferField(DoubleType, null, options) == DoubleType)
+    assert(CSVInferSchema.inferField(DoubleType, "test", options) == StringType)
+    assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00", options) == TimestampType)
+    assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00", options) == TimestampType)
+    assert(CSVInferSchema.inferField(LongType, "True", options) == BooleanType)
+    assert(CSVInferSchema.inferField(IntegerType, "FALSE", options) == BooleanType)
+    assert(CSVInferSchema.inferField(TimestampType, "FALSE", options) == BooleanType)
+  }
+
+  test("Timestamp field types are inferred correctly via custom date format") {
+    var options = new CSVOptions(Map("dateFormat" -> "yyyy-mm"))
+    assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType)
+    options = new CSVOptions(Map("dateFormat" -> "yyyy"))
+    assert(CSVInferSchema.inferField(TimestampType, "2015", options) == TimestampType)
   }
 
   test("Timestamp field types are inferred correctly from other types") {
-    assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14") == StringType)
-    assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10") == StringType)
-    assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00") == StringType)
+    val options = new CSVOptions(Map.empty[String, String])
+    assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14", options) == StringType)
+    assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10", options) == StringType)
+    assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00", options) == StringType)
   }
 
   test("Boolean fields types are inferred correctly from other types") {
-    assert(CSVInferSchema.inferField(LongType, "Fale") == StringType)
-    assert(CSVInferSchema.inferField(DoubleType, "TRUEe") == StringType)
+    val options = new CSVOptions(Map.empty[String, String])
+    assert(CSVInferSchema.inferField(LongType, "Fale", options) == StringType)
+    assert(CSVInferSchema.inferField(DoubleType, "TRUEe", options) == StringType)
   }
 
   test("Type arrays are merged to highest common type") {
@@ -71,13 +84,16 @@ class CSVInferSchemaSuite extends SparkFunSuite {
   }
 
   test("Null fields are handled properly when a nullValue is specified") {
-    assert(CSVInferSchema.inferField(NullType, "null", "null") == NullType)
-    assert(CSVInferSchema.inferField(StringType, "null", "null") == StringType)
-    assert(CSVInferSchema.inferField(LongType, "null", "null") == LongType)
-    assert(CSVInferSchema.inferField(IntegerType, "\\N", "\\N") == IntegerType)
-    assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType)
-    assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType)
-    assert(CSVInferSchema.inferField(BooleanType, "\\N", "\\N") == BooleanType)
+    var options = new CSVOptions(Map("nullValue" -> "null"))
+    assert(CSVInferSchema.inferField(NullType, "null", options) == NullType)
+    assert(CSVInferSchema.inferField(StringType, "null", options) == StringType)
+    assert(CSVInferSchema.inferField(LongType, "null", options) == LongType)
+
+    options = new CSVOptions(Map("nullValue" -> "\\N"))
+    assert(CSVInferSchema.inferField(IntegerType, "\\N", options) == IntegerType)
+    assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType)
+    assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType)
+    assert(CSVInferSchema.inferField(BooleanType, "\\N", options) == BooleanType)
   }
 
   test("Merging Nulltypes should yield Nulltype.") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index ceda920ddc86895d00e5ab8d14d8d1d09a27c4a0..8847c7632fcf8857fe1431af0ce95b6822889120 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.datasources.csv
 
 import java.io.File
 import java.nio.charset.UnsupportedCharsetException
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
+import java.text.SimpleDateFormat
 
 import scala.collection.JavaConverters._
 
@@ -45,6 +46,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   private val disableCommentsFile = "disable_comments.csv"
   private val boolFile = "bool.csv"
   private val simpleSparseFile = "simple_sparse.csv"
+  private val datesFile = "dates.csv"
   private val unescapedQuotesFile = "unescaped-quotes.csv"
 
   private def testFile(fileName: String): String = {
@@ -367,6 +369,54 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     assert(results.toSeq.map(_.toSeq) === expected)
   }
 
+  test("inferring timestamp types via custom date format") {
+    val options = Map(
+      "header" -> "true",
+      "inferSchema" -> "true",
+      "dateFormat" -> "dd/MM/yyyy hh:mm")
+    val results = sqlContext.read
+      .format("csv")
+      .options(options)
+      .load(testFile(datesFile))
+      .select("date")
+      .collect()
+
+    val dateFormat = new SimpleDateFormat("dd/MM/yyyy hh:mm")
+    val expected =
+      Seq(Seq(new Timestamp(dateFormat.parse("26/08/2015 18:00").getTime)),
+        Seq(new Timestamp(dateFormat.parse("27/10/2014 18:30").getTime)),
+        Seq(new Timestamp(dateFormat.parse("28/01/2016 20:00").getTime)))
+    assert(results.toSeq.map(_.toSeq) === expected)
+  }
+
+  test("load date types via custom date format") {
+    val customSchema = new StructType(Array(StructField("date", DateType, true)))
+    val options = Map(
+      "header" -> "true",
+      "inferSchema" -> "false",
+      "dateFormat" -> "dd/MM/yyyy hh:mm")
+    val results = sqlContext.read
+      .format("csv")
+      .options(options)
+      .schema(customSchema)
+      .load(testFile(datesFile))
+      .select("date")
+      .collect()
+
+    val dateFormat = new SimpleDateFormat("dd/MM/yyyy hh:mm")
+    val expected = Seq(
+      new Date(dateFormat.parse("26/08/2015 18:00").getTime),
+      new Date(dateFormat.parse("27/10/2014 18:30").getTime),
+      new Date(dateFormat.parse("28/01/2016 20:00").getTime))
+    val dates = results.toSeq.map(_.toSeq.head)
+    expected.zip(dates).foreach {
+      case (expectedDate, date) =>
+        // As it truncates hours, minutes, etc., we only check that the dates
+        // (days, months and years) are the same via `toString()`.
+        assert(expectedDate.toString === date.toString)
+    }
+  }
+
   test("setting comment to null disables comment support") {
     val results = sqlContext.read
       .format("csv")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
index 5702a1b4ea1f756bf4133205dce95c8f483572d2..8b59bc148fcd505a4e8db36477c9fefde3a9d8c4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
@@ -18,6 +18,8 @@
 package org.apache.spark.sql.execution.datasources.csv
 
 import java.math.BigDecimal
+import java.sql.{Date, Timestamp}
+import java.text.SimpleDateFormat
 import java.util.Locale
 
 import org.apache.spark.SparkFunSuite
@@ -87,6 +89,15 @@ class CSVTypeCastSuite extends SparkFunSuite {
     assert(CSVTypeCast.castTo("1.00", FloatType) == 1.0)
     assert(CSVTypeCast.castTo("1.00", DoubleType) == 1.0)
     assert(CSVTypeCast.castTo("true", BooleanType) == true)
+
+    val dateFormat = new SimpleDateFormat("dd/MM/yyyy hh:mm")
+    val customTimestamp = "31/01/2015 00:00"
+    val expectedTime = dateFormat.parse("31/01/2015 00:00").getTime
+    assert(CSVTypeCast.castTo(customTimestamp, TimestampType, dateFormat = dateFormat)
+      == expectedTime * 1000L)
+    assert(CSVTypeCast.castTo(customTimestamp, DateType, dateFormat = dateFormat) ==
+      DateTimeUtils.millisToDays(expectedTime))
+
     val timestamp = "2015-01-01 00:00:00"
     assert(CSVTypeCast.castTo(timestamp, TimestampType) ==
       DateTimeUtils.stringToTime(timestamp).getTime * 1000L)
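
Usage sketch (not part of the patch): with this change applied, the new `dateFormat` option can be exercised end to end roughly as below, mirroring the options used in the CSVSuite tests above. The file path is illustrative only; the `dates.csv` fixture added by this patch has a single `date` column in `dd/MM/yyyy hh:mm` format.

    val df = sqlContext.read
      .format("csv")
      .option("header", "true")
      .option("inferSchema", "true")
      .option("dateFormat", "dd/MM/yyyy hh:mm")
      .load("path/to/dates.csv")
    // With inferSchema enabled, the "date" column is inferred as TimestampType and each
    // value is parsed with the custom pattern; with an explicit schema that declares the
    // column as DateType, the same pattern is used and the value is truncated to days.
    df.printSchema()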