diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index 75748b2b5440078f2dbd132856a432ebdbb82f82..de8fe2dae38f64f826c845af8bad601d201e65b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -114,7 +114,7 @@ private[sql] object CatalystConverter {
         }
       }
       // All other primitive types use the default converter
-      case ctype: NativeType => { // note: need the type tag here!
+      case ctype: PrimitiveType => { // note: need the type tag here!
        new CatalystPrimitiveConverter(parent, fieldIndex)
      }
      case _ => throw new RuntimeException(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index 108f8b6815423b9a895f831ceb2a53fa393d178e..f1953a008a49b7ceae7205c05de2d74184906043 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -191,6 +191,8 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
           value.asInstanceOf[String].getBytes("utf-8")
         )
       )
+      case BinaryType => writer.addBinary(
+        Binary.fromByteArray(value.asInstanceOf[Array[Byte]]))
       case IntegerType => writer.addInteger(value.asInstanceOf[Int])
       case ShortType => writer.addInteger(value.asInstanceOf[Short])
       case LongType => writer.addLong(value.asInstanceOf[Long])
@@ -299,6 +301,8 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
           record(index).asInstanceOf[String].getBytes("utf-8")
         )
       )
+      case BinaryType => writer.addBinary(
+        Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]]))
       case IntegerType => writer.addInteger(record.getInt(index))
       case ShortType => writer.addInteger(record.getShort(index))
       case LongType => writer.addLong(record.getLong(index))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
index 1dc58633a2a68cd910c1bab01c3d5ee1eb4f8709..d4599da71125428fc3b31ba94a18e64f30567ccc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
@@ -58,7 +58,7 @@ private[sql] object ParquetTestData {
     """message myrecord {
       optional boolean myboolean;
       optional int32 myint;
-      optional binary mystring;
+      optional binary mystring (UTF8);
       optional int64 mylong;
       optional float myfloat;
       optional double mydouble;
@@ -87,7 +87,7 @@
       message myrecord {
       required boolean myboolean;
       required int32 myint;
-      required binary mystring;
+      required binary mystring (UTF8);
       required int64 mylong;
       required float myfloat;
       required double mydouble;
@@ -119,14 +119,14 @@ private[sql] object ParquetTestData {
   // so that array types can be translated correctly.
""" message AddressBook { - required binary owner; + required binary owner (UTF8); optional group ownerPhoneNumbers { - repeated binary array; + repeated binary array (UTF8); } optional group contacts { repeated group array { - required binary name; - optional binary phoneNumber; + required binary name (UTF8); + optional binary phoneNumber (UTF8); } } } @@ -181,16 +181,16 @@ private[sql] object ParquetTestData { required int32 x; optional group data1 { repeated group map { - required binary key; + required binary key (UTF8); required int32 value; } } required group data2 { repeated group map { - required binary key; + required binary key (UTF8); required group value { required int64 payload1; - optional binary payload2; + optional binary payload2 (UTF8); } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index f9046368e7cedcfd20294aeca4f64a8f3fd6dd0a..7f6ad908f78ede507776e49c147f3e1eea09a83b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -42,20 +42,22 @@ private[parquet] object ParquetTypesConverter extends Logging { def isPrimitiveType(ctype: DataType): Boolean = classOf[PrimitiveType] isAssignableFrom ctype.getClass - def toPrimitiveDataType(parquetType : ParquetPrimitiveTypeName): DataType = parquetType match { - case ParquetPrimitiveTypeName.BINARY => StringType - case ParquetPrimitiveTypeName.BOOLEAN => BooleanType - case ParquetPrimitiveTypeName.DOUBLE => DoubleType - case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY => ArrayType(ByteType) - case ParquetPrimitiveTypeName.FLOAT => FloatType - case ParquetPrimitiveTypeName.INT32 => IntegerType - case ParquetPrimitiveTypeName.INT64 => LongType - case ParquetPrimitiveTypeName.INT96 => - // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? - sys.error("Potential loss of precision: cannot convert INT96") - case _ => sys.error( - s"Unsupported parquet datatype $parquetType") - } + def toPrimitiveDataType(parquetType: ParquetPrimitiveType): DataType = + parquetType.getPrimitiveTypeName match { + case ParquetPrimitiveTypeName.BINARY + if parquetType.getOriginalType == ParquetOriginalType.UTF8 => StringType + case ParquetPrimitiveTypeName.BINARY => BinaryType + case ParquetPrimitiveTypeName.BOOLEAN => BooleanType + case ParquetPrimitiveTypeName.DOUBLE => DoubleType + case ParquetPrimitiveTypeName.FLOAT => FloatType + case ParquetPrimitiveTypeName.INT32 => IntegerType + case ParquetPrimitiveTypeName.INT64 => LongType + case ParquetPrimitiveTypeName.INT96 => + // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? 
+        sys.error("Potential loss of precision: cannot convert INT96")
+      case _ => sys.error(
+        s"Unsupported parquet datatype $parquetType")
+    }
 
   /**
    * Converts a given Parquet `Type` into the corresponding
@@ -104,7 +106,7 @@ private[parquet] object ParquetTypesConverter extends Logging {
     }
 
     if (parquetType.isPrimitive) {
-      toPrimitiveDataType(parquetType.asPrimitiveType.getPrimitiveTypeName)
+      toPrimitiveDataType(parquetType.asPrimitiveType)
     } else {
       val groupType = parquetType.asGroupType()
       parquetType.getOriginalType match {
@@ -164,18 +166,17 @@
    * @return The name of the corresponding Parquet primitive type
    */
   def fromPrimitiveDataType(ctype: DataType):
-      Option[ParquetPrimitiveTypeName] = ctype match {
-    case StringType => Some(ParquetPrimitiveTypeName.BINARY)
-    case BooleanType => Some(ParquetPrimitiveTypeName.BOOLEAN)
-    case DoubleType => Some(ParquetPrimitiveTypeName.DOUBLE)
-    case ArrayType(ByteType) =>
-      Some(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY)
-    case FloatType => Some(ParquetPrimitiveTypeName.FLOAT)
-    case IntegerType => Some(ParquetPrimitiveTypeName.INT32)
+      Option[(ParquetPrimitiveTypeName, Option[ParquetOriginalType])] = ctype match {
+    case StringType => Some(ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8))
+    case BinaryType => Some(ParquetPrimitiveTypeName.BINARY, None)
+    case BooleanType => Some(ParquetPrimitiveTypeName.BOOLEAN, None)
+    case DoubleType => Some(ParquetPrimitiveTypeName.DOUBLE, None)
+    case FloatType => Some(ParquetPrimitiveTypeName.FLOAT, None)
+    case IntegerType => Some(ParquetPrimitiveTypeName.INT32, None)
     // There is no type for Byte or Short so we promote them to INT32.
-    case ShortType => Some(ParquetPrimitiveTypeName.INT32)
-    case ByteType => Some(ParquetPrimitiveTypeName.INT32)
-    case LongType => Some(ParquetPrimitiveTypeName.INT64)
+    case ShortType => Some(ParquetPrimitiveTypeName.INT32, None)
+    case ByteType => Some(ParquetPrimitiveTypeName.INT32, None)
+    case LongType => Some(ParquetPrimitiveTypeName.INT64, None)
     case _ => None
   }
 
@@ -227,9 +228,10 @@ private[parquet] object ParquetTypesConverter extends Logging {
       if (nullable) Repetition.OPTIONAL else Repetition.REQUIRED
     }
     val primitiveType = fromPrimitiveDataType(ctype)
-    if (primitiveType.isDefined) {
-      new ParquetPrimitiveType(repetition, primitiveType.get, name)
-    } else {
+    primitiveType.map {
+      case (primitiveType, originalType) =>
+        new ParquetPrimitiveType(repetition, primitiveType, name, originalType.orNull)
+    }.getOrElse {
       ctype match {
         case ArrayType(elementType) => {
           val parquetElementType = fromDataType(
@@ -237,7 +239,7 @@ private[parquet] object ParquetTypesConverter extends Logging {
             CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME,
             nullable = false,
             inArray = true)
-           ConversionPatterns.listType(repetition, name, parquetElementType)
+          ConversionPatterns.listType(repetition, name, parquetElementType)
         }
         case StructType(structFields) => {
           val fields = structFields.map {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 8fa143e2deca61f13a8f1eaa455e2f4641e09fa7..3c911e9a4e7b1d938c1b0ac619d620794f069045 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -65,7 +65,8 @@ case class AllDataTypes(
     doubleField: Double,
     shortField: Short,
     byteField: Byte,
-    booleanField: Boolean)
+    booleanField: Boolean,
+    binaryField: Array[Byte])
 
 case class AllDataTypesWithNonPrimitiveType(
     stringField: String,
@@ -76,6 +77,7 @@ case class AllDataTypesWithNonPrimitiveType(
     shortField: Short,
     byteField: Byte,
     booleanField: Boolean,
+    binaryField: Array[Byte],
     array: Seq[Int],
     map: Map[Int, String],
     data: Data)
@@ -116,7 +118,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
     val tempDir = getTempFilePath("parquetTest").getCanonicalPath
     val range = (0 to 255)
     TestSQLContext.sparkContext.parallelize(range)
-      .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0))
+      .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0,
+        (0 to x).map(_.toByte).toArray))
       .saveAsParquetFile(tempDir)
     val result = parquetFile(tempDir).collect()
     range.foreach {
@@ -129,6 +132,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
       assert(result(i).getShort(5) === i.toShort)
       assert(result(i).getByte(6) === i.toByte)
       assert(result(i).getBoolean(7) === (i % 2 == 0))
+      assert(result(i)(8) === (0 to i).map(_.toByte).toArray)
     }
   }
 
@@ -138,6 +142,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
     TestSQLContext.sparkContext.parallelize(range)
       .map(x => AllDataTypesWithNonPrimitiveType(
         s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0,
+        (0 to x).map(_.toByte).toArray,
         (0 until x), (0 until x).map(i => i -> s"$i").toMap, Data((0 until x), Nested(x, s"$x"))))
       .saveAsParquetFile(tempDir)
     val result = parquetFile(tempDir).collect()
@@ -151,9 +156,10 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
       assert(result(i).getShort(5) === i.toShort)
       assert(result(i).getByte(6) === i.toByte)
       assert(result(i).getBoolean(7) === (i % 2 == 0))
-      assert(result(i)(8) === (0 until i))
-      assert(result(i)(9) === (0 until i).map(i => i -> s"$i").toMap)
-      assert(result(i)(10) === new GenericRow(Array[Any]((0 until i), new GenericRow(Array[Any](i, s"$i")))))
+      assert(result(i)(8) === (0 to i).map(_.toByte).toArray)
+      assert(result(i)(9) === (0 until i))
+      assert(result(i)(10) === (0 until i).map(i => i -> s"$i").toMap)
+      assert(result(i)(11) === new GenericRow(Array[Any]((0 until i), new GenericRow(Array[Any](i, s"$i")))))
     }
   }
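
Taken together, the patch makes Catalyst's `BinaryType` round-trip through Parquet as a plain `BINARY` column, while `StringType` is now written as `BINARY` annotated with the `UTF8` original type, and on read only UTF8-annotated binaries come back as strings. Below is a minimal sketch of the resulting behavior, modeled on the test suite above; the record type `BlobRecord` and the output path are illustrative, not part of the patch:

```scala
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._ // implicit RDD-to-SchemaRDD conversion

// Illustrative record: stringField maps to BINARY (UTF8), binaryField to plain BINARY.
case class BlobRecord(stringField: String, binaryField: Array[Byte])

val path = "/tmp/blob-example.parquet" // illustrative output path

TestSQLContext.sparkContext
  .parallelize(0 until 16)
  .map(i => BlobRecord(s"row-$i", Array.fill(i + 1)(i.toByte)))
  .saveAsParquetFile(path)

// The plain BINARY column is read back as Array[Byte]; before this patch,
// every Parquet BINARY column was mapped to StringType on read.
val rows = parquetFile(path).collect()
rows.foreach(row => assert(row(1).isInstanceOf[Array[Byte]]))
```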