diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1c1a0cad496254675e4761579664d4cc31e2bb48..54756edd9345dd13b2990446950e98b291eaef4a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3018,8 +3018,8 @@ class ArrowTests(ReusedPySparkTestCase): self.assertTrue(df_without.equals(df_with_arrow), msg=msg) def test_unsupported_datatype(self): - schema = StructType([StructField("array", ArrayType(IntegerType(), False), True)]) - df = self.spark.createDataFrame([([1, 2, 3],)], schema=schema) + schema = StructType([StructField("dt", DateType(), True)]) + df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema) with QuietTest(self.sc): self.assertRaises(Exception, lambda: df.toPandas()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 71ab0ddf2d6f4b611118b38adeb2906e5b9f7be2..9007367f5aa8fc5a61a64c714806b56029f3b3ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -27,6 +27,7 @@ import scala.util.control.NonFatal import org.apache.commons.lang3.StringUtils +import org.apache.spark.TaskContext import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.function._ @@ -3090,7 +3091,8 @@ class Dataset[T] private[sql]( val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch queryExecution.toRdd.mapPartitionsInternal { iter => - ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch) + val context = TaskContext.get() + ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch, context) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index c913efe52a41c50afe5fa555dd13f1a60513e6e7..240f38f5bfeb4942a1692a850944cd8f29ef608f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -20,18 +20,13 @@ package org.apache.spark.sql.execution.arrow import java.io.ByteArrayOutputStream import java.nio.channels.Channels -import scala.collection.JavaConverters._ - -import io.netty.buffer.ArrowBuf -import org.apache.arrow.memory.{BufferAllocator, RootAllocator} +import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector._ -import org.apache.arrow.vector.BaseValueVector.BaseMutator import org.apache.arrow.vector.file._ -import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} -import org.apache.arrow.vector.types.FloatingPointPrecision -import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.arrow.vector.schema.ArrowRecordBatch import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel +import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -55,19 +50,6 @@ private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Se def asPythonSerializable: Array[Byte] = payload } -private[sql] object ArrowPayload { - - /** - * Create an ArrowPayload from an ArrowRecordBatch and Spark schema. 
- */ - def apply( - batch: ArrowRecordBatch, - schema: StructType, - allocator: BufferAllocator): ArrowPayload = { - new ArrowPayload(ArrowConverters.batchToByteArray(batch, schema, allocator)) - } -} - private[sql] object ArrowConverters { /** @@ -77,95 +59,55 @@ private[sql] object ArrowConverters { private[sql] def toPayloadIterator( rowIter: Iterator[InternalRow], schema: StructType, - maxRecordsPerBatch: Int): Iterator[ArrowPayload] = { - new Iterator[ArrowPayload] { - private val _allocator = new RootAllocator(Long.MaxValue) - private var _nextPayload = if (rowIter.nonEmpty) convert() else null + maxRecordsPerBatch: Int, + context: TaskContext): Iterator[ArrowPayload] = { - override def hasNext: Boolean = _nextPayload != null - - override def next(): ArrowPayload = { - val obj = _nextPayload - if (hasNext) { - if (rowIter.hasNext) { - _nextPayload = convert() - } else { - _allocator.close() - _nextPayload = null - } - } - obj - } - - private def convert(): ArrowPayload = { - val batch = internalRowIterToArrowBatch(rowIter, schema, _allocator, maxRecordsPerBatch) - ArrowPayload(batch, schema, _allocator) - } - } - } + val arrowSchema = ArrowUtils.toArrowSchema(schema) + val allocator = + ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue) - /** - * Iterate over InternalRows and write to an ArrowRecordBatch, stopping when rowIter is consumed - * or the number of records in the batch equals maxRecordsInBatch. If maxRecordsPerBatch is 0, - * then rowIter will be fully consumed. - */ - private def internalRowIterToArrowBatch( - rowIter: Iterator[InternalRow], - schema: StructType, - allocator: BufferAllocator, - maxRecordsPerBatch: Int = 0): ArrowRecordBatch = { + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val arrowWriter = ArrowWriter.create(root) - val columnWriters = schema.fields.zipWithIndex.map { case (field, ordinal) => - ColumnWriter(field.dataType, ordinal, allocator).init() - } + var closed = false - val writerLength = columnWriters.length - var recordsInBatch = 0 - while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || recordsInBatch < maxRecordsPerBatch)) { - val row = rowIter.next() - var i = 0 - while (i < writerLength) { - columnWriters(i).write(row) - i += 1 + context.addTaskCompletionListener { _ => + if (!closed) { + root.close() + allocator.close() } - recordsInBatch += 1 } - val (fieldNodes, bufferArrays) = columnWriters.map(_.finish()).unzip - val buffers = bufferArrays.flatten - - val rowLength = if (fieldNodes.nonEmpty) fieldNodes.head.getLength else 0 - val recordBatch = new ArrowRecordBatch(rowLength, - fieldNodes.toList.asJava, buffers.toList.asJava) + new Iterator[ArrowPayload] { - buffers.foreach(_.release()) - recordBatch - } + override def hasNext: Boolean = rowIter.hasNext || { + root.close() + allocator.close() + closed = true + false + } - /** - * Convert an ArrowRecordBatch to a byte array and close batch to release resources. Once closed, - * the batch can no longer be used. 
- */ - private[arrow] def batchToByteArray( - batch: ArrowRecordBatch, - schema: StructType, - allocator: BufferAllocator): Array[Byte] = { - val arrowSchema = ArrowUtils.toArrowSchema(schema) - val root = VectorSchemaRoot.create(arrowSchema, allocator) - val out = new ByteArrayOutputStream() - val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) + override def next(): ArrowPayload = { + val out = new ByteArrayOutputStream() + val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) + + Utils.tryWithSafeFinally { + var rowCount = 0 + while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { + val row = rowIter.next() + arrowWriter.write(row) + rowCount += 1 + } + arrowWriter.finish() + writer.writeBatch() + } { + arrowWriter.reset() + writer.close() + } - // Write a batch to byte stream, ensure the batch, allocator and writer are closed - Utils.tryWithSafeFinally { - val loader = new VectorLoader(root) - loader.load(batch) - writer.writeBatch() // writeBatch can throw IOException - } { - batch.close() - root.close() - writer.close() + new ArrowPayload(out.toByteArray) + } } - out.toByteArray } /** @@ -188,214 +130,3 @@ private[sql] object ArrowConverters { } } } - -/** - * Interface for writing InternalRows to Arrow Buffers. - */ -private[arrow] trait ColumnWriter { - def init(): this.type - def write(row: InternalRow): Unit - - /** - * Clear the column writer and return the ArrowFieldNode and ArrowBuf. - * This should be called only once after all the data is written. - */ - def finish(): (ArrowFieldNode, Array[ArrowBuf]) -} - -/** - * Base class for flat arrow column writer, i.e., column without children. - */ -private[arrow] abstract class PrimitiveColumnWriter(val ordinal: Int) - extends ColumnWriter { - - def getFieldType(dtype: ArrowType): FieldType = FieldType.nullable(dtype) - - def valueVector: BaseDataValueVector - def valueMutator: BaseMutator - - def setNull(): Unit - def setValue(row: InternalRow): Unit - - protected var count = 0 - protected var nullCount = 0 - - override def init(): this.type = { - valueVector.allocateNew() - this - } - - override def write(row: InternalRow): Unit = { - if (row.isNullAt(ordinal)) { - setNull() - nullCount += 1 - } else { - setValue(row) - } - count += 1 - } - - override def finish(): (ArrowFieldNode, Array[ArrowBuf]) = { - valueMutator.setValueCount(count) - val fieldNode = new ArrowFieldNode(count, nullCount) - val valueBuffers = valueVector.getBuffers(true) - (fieldNode, valueBuffers) - } -} - -private[arrow] class BooleanColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableBitVector - = new NullableBitVector("BooleanValue", getFieldType(dtype), allocator) - override val valueMutator: NullableBitVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, if (row.getBoolean(ordinal)) 1 else 0 ) -} - -private[arrow] class ShortColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableSmallIntVector - = new NullableSmallIntVector("ShortValue", getFieldType(dtype: ArrowType), allocator) - override val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: 
InternalRow): Unit - = valueMutator.setSafe(count, row.getShort(ordinal)) -} - -private[arrow] class IntegerColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableIntVector - = new NullableIntVector("IntValue", getFieldType(dtype), allocator) - override val valueMutator: NullableIntVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getInt(ordinal)) -} - -private[arrow] class LongColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableBigIntVector - = new NullableBigIntVector("LongValue", getFieldType(dtype), allocator) - override val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getLong(ordinal)) -} - -private[arrow] class FloatColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableFloat4Vector - = new NullableFloat4Vector("FloatValue", getFieldType(dtype), allocator) - override val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getFloat(ordinal)) -} - -private[arrow] class DoubleColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableFloat8Vector - = new NullableFloat8Vector("DoubleValue", getFieldType(dtype), allocator) - override val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getDouble(ordinal)) -} - -private[arrow] class ByteColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableUInt1Vector - = new NullableUInt1Vector("ByteValue", getFieldType(dtype), allocator) - override val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getByte(ordinal)) -} - -private[arrow] class UTF8StringColumnWriter( - dtype: ArrowType, - ordinal: Int, - allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableVarCharVector - = new NullableVarCharVector("UTF8StringValue", getFieldType(dtype), allocator) - override val valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - val str = row.getUTF8String(ordinal) - valueMutator.setSafe(count, str.getByteBuffer, 0, str.numBytes) - } -} - -private[arrow] class BinaryColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableVarBinaryVector - = new NullableVarBinaryVector("BinaryValue", getFieldType(dtype), allocator) - override 
val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - val bytes = row.getBinary(ordinal) - valueMutator.setSafe(count, bytes, 0, bytes.length) - } -} - -private[arrow] class DateColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableDateDayVector - = new NullableDateDayVector("DateValue", getFieldType(dtype), allocator) - override val valueMutator: NullableDateDayVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - valueMutator.setSafe(count, row.getInt(ordinal)) - } -} - -private[arrow] class TimeStampColumnWriter( - dtype: ArrowType, - ordinal: Int, - allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableTimeStampMicroVector - = new NullableTimeStampMicroVector("TimeStampValue", getFieldType(dtype), allocator) - override val valueMutator: NullableTimeStampMicroVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - valueMutator.setSafe(count, row.getLong(ordinal)) - } -} - -private[arrow] object ColumnWriter { - - /** - * Create an Arrow ColumnWriter given the type and ordinal of row. - */ - def apply(dataType: DataType, ordinal: Int, allocator: BufferAllocator): ColumnWriter = { - val dtype = ArrowUtils.toArrowType(dataType) - dataType match { - case BooleanType => new BooleanColumnWriter(dtype, ordinal, allocator) - case ShortType => new ShortColumnWriter(dtype, ordinal, allocator) - case IntegerType => new IntegerColumnWriter(dtype, ordinal, allocator) - case LongType => new LongColumnWriter(dtype, ordinal, allocator) - case FloatType => new FloatColumnWriter(dtype, ordinal, allocator) - case DoubleType => new DoubleColumnWriter(dtype, ordinal, allocator) - case ByteType => new ByteColumnWriter(dtype, ordinal, allocator) - case StringType => new UTF8StringColumnWriter(dtype, ordinal, allocator) - case BinaryType => new BinaryColumnWriter(dtype, ordinal, allocator) - case DateType => new DateColumnWriter(dtype, ordinal, allocator) - case TimestampType => new TimeStampColumnWriter(dtype, ordinal, allocator) - case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala new file mode 100644 index 0000000000000000000000000000000000000000..11ba04d2ce9a7603167b71a0cb00ca2d089f6f65 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -0,0 +1,323 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.arrow + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector._ +import org.apache.arrow.vector.complex._ +import org.apache.arrow.vector.util.DecimalUtility + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.types._ + +object ArrowWriter { + + def create(schema: StructType): ArrowWriter = { + val arrowSchema = ArrowUtils.toArrowSchema(schema) + val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) + create(root) + } + + def create(root: VectorSchemaRoot): ArrowWriter = { + val children = root.getFieldVectors().asScala.map { vector => + vector.allocateNew() + createFieldWriter(vector) + } + new ArrowWriter(root, children.toArray) + } + + private def createFieldWriter(vector: ValueVector): ArrowFieldWriter = { + val field = vector.getField() + (ArrowUtils.fromArrowField(field), vector) match { + case (BooleanType, vector: NullableBitVector) => new BooleanWriter(vector) + case (ByteType, vector: NullableTinyIntVector) => new ByteWriter(vector) + case (ShortType, vector: NullableSmallIntVector) => new ShortWriter(vector) + case (IntegerType, vector: NullableIntVector) => new IntegerWriter(vector) + case (LongType, vector: NullableBigIntVector) => new LongWriter(vector) + case (FloatType, vector: NullableFloat4Vector) => new FloatWriter(vector) + case (DoubleType, vector: NullableFloat8Vector) => new DoubleWriter(vector) + case (StringType, vector: NullableVarCharVector) => new StringWriter(vector) + case (BinaryType, vector: NullableVarBinaryVector) => new BinaryWriter(vector) + case (ArrayType(_, _), vector: ListVector) => + val elementVector = createFieldWriter(vector.getDataVector()) + new ArrayWriter(vector, elementVector) + case (StructType(_), vector: NullableMapVector) => + val children = (0 until vector.size()).map { ordinal => + createFieldWriter(vector.getChildByOrdinal(ordinal)) + } + new StructWriter(vector, children.toArray) + case (dt, _) => + throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}") + } + } +} + +class ArrowWriter( + val root: VectorSchemaRoot, + fields: Array[ArrowFieldWriter]) { + + def schema: StructType = StructType(fields.map { f => + StructField(f.name, f.dataType, f.nullable) + }) + + private var count: Int = 0 + + def write(row: InternalRow): Unit = { + var i = 0 + while (i < fields.size) { + fields(i).write(row, i) + i += 1 + } + count += 1 + } + + def finish(): Unit = { + root.setRowCount(count) + fields.foreach(_.finish()) + } + + def reset(): Unit = { + root.setRowCount(0) + count = 0 + fields.foreach(_.reset()) + } +} + +private[arrow] abstract class ArrowFieldWriter { + + def valueVector: ValueVector + def valueMutator: ValueVector.Mutator + + def name: String = valueVector.getField().getName() + def dataType: DataType = ArrowUtils.fromArrowField(valueVector.getField()) + def nullable: Boolean = valueVector.getField().isNullable() + + def setNull(): Unit + def setValue(input: SpecializedGetters, ordinal: Int): Unit + + 
private[arrow] var count: Int = 0 + + def write(input: SpecializedGetters, ordinal: Int): Unit = { + if (input.isNullAt(ordinal)) { + setNull() + } else { + setValue(input, ordinal) + } + count += 1 + } + + def finish(): Unit = { + valueMutator.setValueCount(count) + } + + def reset(): Unit = { + valueMutator.reset() + count = 0 + } +} + +private[arrow] class BooleanWriter(val valueVector: NullableBitVector) extends ArrowFieldWriter { + + override def valueMutator: NullableBitVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, if (input.getBoolean(ordinal)) 1 else 0) + } +} + +private[arrow] class ByteWriter(val valueVector: NullableTinyIntVector) extends ArrowFieldWriter { + + override def valueMutator: NullableTinyIntVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getByte(ordinal)) + } +} + +private[arrow] class ShortWriter(val valueVector: NullableSmallIntVector) extends ArrowFieldWriter { + + override def valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getShort(ordinal)) + } +} + +private[arrow] class IntegerWriter(val valueVector: NullableIntVector) extends ArrowFieldWriter { + + override def valueMutator: NullableIntVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getInt(ordinal)) + } +} + +private[arrow] class LongWriter(val valueVector: NullableBigIntVector) extends ArrowFieldWriter { + + override def valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getLong(ordinal)) + } +} + +private[arrow] class FloatWriter(val valueVector: NullableFloat4Vector) extends ArrowFieldWriter { + + override def valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getFloat(ordinal)) + } +} + +private[arrow] class DoubleWriter(val valueVector: NullableFloat8Vector) extends ArrowFieldWriter { + + override def valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getDouble(ordinal)) + } +} + +private[arrow] class StringWriter(val valueVector: NullableVarCharVector) extends ArrowFieldWriter { + + override def valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val utf8 = input.getUTF8String(ordinal) + // todo: for off-heap 
UTF8String, how to pass in to arrow without copy? + valueMutator.setSafe(count, utf8.getByteBuffer, 0, utf8.numBytes()) + } +} + +private[arrow] class BinaryWriter( + val valueVector: NullableVarBinaryVector) extends ArrowFieldWriter { + + override def valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val bytes = input.getBinary(ordinal) + valueMutator.setSafe(count, bytes, 0, bytes.length) + } +} + +private[arrow] class ArrayWriter( + val valueVector: ListVector, + val elementWriter: ArrowFieldWriter) extends ArrowFieldWriter { + + override def valueMutator: ListVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val array = input.getArray(ordinal) + var i = 0 + valueMutator.startNewValue(count) + while (i < array.numElements()) { + elementWriter.write(array, i) + i += 1 + } + valueMutator.endValue(count, array.numElements()) + } + + override def finish(): Unit = { + super.finish() + elementWriter.finish() + } + + override def reset(): Unit = { + super.reset() + elementWriter.reset() + } +} + +private[arrow] class StructWriter( + val valueVector: NullableMapVector, + children: Array[ArrowFieldWriter]) extends ArrowFieldWriter { + + override def valueMutator: NullableMapVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + var i = 0 + while (i < children.length) { + children(i).setNull() + children(i).count += 1 + i += 1 + } + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val struct = input.getStruct(ordinal, children.length) + var i = 0 + while (i < struct.numFields) { + children(i).write(struct, i) + i += 1 + } + valueMutator.setIndexDefined(count) + } + + override def finish(): Unit = { + super.finish() + children.foreach(_.finish()) + } + + override def reset(): Unit = { + super.reset() + children.foreach(_.reset()) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 55b465578a42d28092b3ab33db87b04479ec4b40..4893b52f240ec783fc23b768b829dee20359367f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -32,7 +32,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{BinaryType, StructField, StructType} +import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -857,6 +857,449 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { collectAndValidate(df, json, "nanData-floating_point.json") } + test("array type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_arr", + | "nullable" : true, + | "type" : { + | "name" : "list" + | }, + | "children" : [ { + | "name" : "element", + | "nullable" : false, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + 
| "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b_arr", + | "nullable" : true, + | "type" : { + | "name" : "list" + | }, + | "children" : [ { + | "name" : "element", + | "nullable" : false, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "c_arr", + | "nullable" : true, + | "type" : { + | "name" : "list" + | }, + | "children" : [ { + | "name" : "element", + | "nullable" : true, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "d_arr", + | "nullable" : true, + | "type" : { + | "name" : "list" + | }, + | "children" : [ { + | "name" : "element", + | "nullable" : true, + | "type" : { + | "name" : "list" + | }, + | "children" : [ { + | "name" : "element", + | "nullable" : false, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "a_arr", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "OFFSET" : [ 0, 2, 4, 4, 5 ], + | "children" : [ { + | "name" : "element", + | "count" : 5, + | "VALIDITY" : [ 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5 ] + | } ] + | }, { + | "name" : "b_arr", + | "count" : 4, + | "VALIDITY" : [ 1, 0, 1, 0 ], + | "OFFSET" : [ 0, 2, 2, 2, 2 ], + | "children" : [ { + | "name" : "element", + | "count" : 2, + | "VALIDITY" : [ 1, 1 ], + | "DATA" : [ 1, 2 ] + | } ] + | }, { + | "name" : "c_arr", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "OFFSET" : [ 0, 2, 4, 4, 5 ], + | "children" : [ { + | "name" : "element", + | "count" : 5, + | "VALIDITY" : [ 1, 1, 1, 0, 1 ], + | "DATA" : [ 1, 2, 3, 0, 5 ] + | } ] + | }, { + | "name" : "d_arr", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "OFFSET" : [ 0, 1, 3, 3, 4 ], + | "children" : [ { + | "name" : "element", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "OFFSET" : [ 0, 2, 3, 3, 4 ], + | "children" : [ { + | "name" : "element", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 5 ] + | } ] + | } ] + | } ] + | } ] + |} 
+ """.stripMargin + + val aArr = Seq(Seq(1, 2), Seq(3, 4), Seq(), Seq(5)) + val bArr = Seq(Some(Seq(1, 2)), None, Some(Seq()), None) + val cArr = Seq(Seq(Some(1), Some(2)), Seq(Some(3), None), Seq(), Seq(Some(5))) + val dArr = Seq(Seq(Seq(1, 2)), Seq(Seq(3), Seq()), Seq(), Seq(Seq(5))) + + val df = aArr.zip(bArr).zip(cArr).zip(dArr).map { + case (((a, b), c), d) => (a, b, c, d) + }.toDF("a_arr", "b_arr", "c_arr", "d_arr") + + collectAndValidate(df, json, "arrayData.json") + } + + test("struct type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_struct", + | "nullable" : false, + | "type" : { + | "name" : "struct" + | }, + | "children" : [ { + | "name" : "i", + | "nullable" : false, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | } ] + | } + | }, { + | "name" : "b_struct", + | "nullable" : true, + | "type" : { + | "name" : "struct" + | }, + | "children" : [ { + | "name" : "i", + | "nullable" : false, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | } ] + | } + | }, { + | "name" : "c_struct", + | "nullable" : false, + | "type" : { + | "name" : "struct" + | }, + | "children" : [ { + | "name" : "i", + | "nullable" : true, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | } ] + | } + | }, { + | "name" : "d_struct", + | "nullable" : true, + | "type" : { + | "name" : "struct" + | }, + | "children" : [ { + | "name" : "nested", + | "nullable" : true, + | "type" : { + | "name" : "struct" + | }, + | "children" : [ { + | "name" : "i", + | "nullable" : true, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "a_struct", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "children" : [ { + | "name" : "i", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 1, 2, 3 ] + | } ] + | }, { + | "name" : "b_struct", + | "count" : 3, + | "VALIDITY" : [ 1, 0, 1 ], + | "children" : [ { + | "name" : "i", + | "count" : 3, + | "VALIDITY" : [ 1, 0, 1 ], + | "DATA" : [ 1, 2, 3 ] + | } ] + | }, { + | "name" : "c_struct", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "children" : [ { + | "name" : "i", + | "count" : 3, + 
| "VALIDITY" : [ 1, 0, 1 ], + | "DATA" : [ 1, 2, 3 ] + | } ] + | }, { + | "name" : "d_struct", + | "count" : 3, + | "VALIDITY" : [ 1, 0, 1 ], + | "children" : [ { + | "name" : "nested", + | "count" : 3, + | "VALIDITY" : [ 1, 0, 0 ], + | "children" : [ { + | "name" : "i", + | "count" : 3, + | "VALIDITY" : [ 1, 0, 0 ], + | "DATA" : [ 1, 2, 0 ] + | } ] + | } ] + | } ] + | } ] + |} + """.stripMargin + + val aStruct = Seq(Row(1), Row(2), Row(3)) + val bStruct = Seq(Row(1), null, Row(3)) + val cStruct = Seq(Row(1), Row(null), Row(3)) + val dStruct = Seq(Row(Row(1)), null, Row(null)) + val data = aStruct.zip(bStruct).zip(cStruct).zip(dStruct).map { + case (((a, b), c), d) => Row(a, b, c, d) + } + + val rdd = sparkContext.parallelize(data) + val schema = new StructType() + .add("a_struct", new StructType().add("i", IntegerType, nullable = false), nullable = false) + .add("b_struct", new StructType().add("i", IntegerType, nullable = false), nullable = true) + .add("c_struct", new StructType().add("i", IntegerType, nullable = true), nullable = false) + .add("d_struct", new StructType().add("nested", new StructType().add("i", IntegerType))) + val df = spark.createDataFrame(rdd, schema) + + collectAndValidate(df, json, "structData.json") + } + test("partitioned DataFrame") { val json1 = s""" @@ -1015,6 +1458,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch) val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i") val arrowPayloads = df.toArrowPayload.collect() + assert(arrowPayloads.length >= 4) val allocator = new RootAllocator(Long.MaxValue) val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) var recordCount = 0 @@ -1039,7 +1483,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } runUnsupported { decimalData.toArrowPayload.collect() } - runUnsupported { arrayData.toDF().toArrowPayload.collect() } runUnsupported { mapData.toDF().toArrowPayload.collect() } runUnsupported { complexData.toArrowPayload.collect() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..e9a629315f5f41520a96cac310b447123a980c6e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.arrow + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.execution.vectorized.ArrowColumnVector +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class ArrowWriterSuite extends SparkFunSuite { + + test("simple") { + def check(dt: DataType, data: Seq[Any]): Unit = { + val schema = new StructType().add("value", dt, nullable = true) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + data.foreach { datum => + writer.write(InternalRow(datum)) + } + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + data.zipWithIndex.foreach { + case (null, rowId) => assert(reader.isNullAt(rowId)) + case (datum, rowId) => + val value = dt match { + case BooleanType => reader.getBoolean(rowId) + case ByteType => reader.getByte(rowId) + case ShortType => reader.getShort(rowId) + case IntegerType => reader.getInt(rowId) + case LongType => reader.getLong(rowId) + case FloatType => reader.getFloat(rowId) + case DoubleType => reader.getDouble(rowId) + case StringType => reader.getUTF8String(rowId) + case BinaryType => reader.getBinary(rowId) + } + assert(value === datum) + } + + writer.root.close() + } + check(BooleanType, Seq(true, null, false)) + check(ByteType, Seq(1.toByte, 2.toByte, null, 4.toByte)) + check(ShortType, Seq(1.toShort, 2.toShort, null, 4.toShort)) + check(IntegerType, Seq(1, 2, null, 4)) + check(LongType, Seq(1L, 2L, null, 4L)) + check(FloatType, Seq(1.0f, 2.0f, null, 4.0f)) + check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d)) + check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString)) + check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes())) + } + + test("get multiple") { + def check(dt: DataType, data: Seq[Any]): Unit = { + val schema = new StructType().add("value", dt, nullable = false) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + data.foreach { datum => + writer.write(InternalRow(datum)) + } + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + val values = dt match { + case BooleanType => reader.getBooleans(0, data.size) + case ByteType => reader.getBytes(0, data.size) + case ShortType => reader.getShorts(0, data.size) + case IntegerType => reader.getInts(0, data.size) + case LongType => reader.getLongs(0, data.size) + case FloatType => reader.getFloats(0, data.size) + case DoubleType => reader.getDoubles(0, data.size) + } + assert(values === data) + + writer.root.close() + } + check(BooleanType, Seq(true, false)) + check(ByteType, (0 until 10).map(_.toByte)) + check(ShortType, (0 until 10).map(_.toShort)) + check(IntegerType, (0 until 10)) + check(LongType, (0 until 10).map(_.toLong)) + check(FloatType, (0 until 10).map(_.toFloat)) + check(DoubleType, (0 until 10).map(_.toDouble)) + } + + test("array") { + val schema = new StructType() + .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + writer.write(InternalRow(ArrayData.toArrayData(Array(1, 2, 3)))) + writer.write(InternalRow(ArrayData.toArrayData(Array(4, 5)))) + writer.write(InternalRow(null)) + writer.write(InternalRow(ArrayData.toArrayData(Array.empty[Int]))) + writer.write(InternalRow(ArrayData.toArrayData(Array(6, null, 8)))) + writer.finish() + + 
val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + + val array0 = reader.getArray(0) + assert(array0.numElements() === 3) + assert(array0.getInt(0) === 1) + assert(array0.getInt(1) === 2) + assert(array0.getInt(2) === 3) + + val array1 = reader.getArray(1) + assert(array1.numElements() === 2) + assert(array1.getInt(0) === 4) + assert(array1.getInt(1) === 5) + + assert(reader.isNullAt(2)) + + val array3 = reader.getArray(3) + assert(array3.numElements() === 0) + + val array4 = reader.getArray(4) + assert(array4.numElements() === 3) + assert(array4.getInt(0) === 6) + assert(array4.isNullAt(1)) + assert(array4.getInt(2) === 8) + + writer.root.close() + } + + test("nested array") { + val schema = new StructType().add("nested", ArrayType(ArrayType(IntegerType))) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + writer.write(InternalRow(ArrayData.toArrayData(Array( + ArrayData.toArrayData(Array(1, 2, 3)), + ArrayData.toArrayData(Array(4, 5)), + null, + ArrayData.toArrayData(Array.empty[Int]), + ArrayData.toArrayData(Array(6, null, 8)))))) + writer.write(InternalRow(null)) + writer.write(InternalRow(ArrayData.toArrayData(Array.empty))) + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + + val array0 = reader.getArray(0) + assert(array0.numElements() === 5) + + val array00 = array0.getArray(0) + assert(array00.numElements() === 3) + assert(array00.getInt(0) === 1) + assert(array00.getInt(1) === 2) + assert(array00.getInt(2) === 3) + + val array01 = array0.getArray(1) + assert(array01.numElements() === 2) + assert(array01.getInt(0) === 4) + assert(array01.getInt(1) === 5) + + assert(array0.isNullAt(2)) + + val array03 = array0.getArray(3) + assert(array03.numElements() === 0) + + val array04 = array0.getArray(4) + assert(array04.numElements() === 3) + assert(array04.getInt(0) === 6) + assert(array04.isNullAt(1)) + assert(array04.getInt(2) === 8) + + assert(reader.isNullAt(1)) + + val array2 = reader.getArray(2) + assert(array2.numElements() === 0) + + writer.root.close() + } + + test("struct") { + val schema = new StructType() + .add("struct", new StructType().add("i", IntegerType).add("str", StringType)) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + writer.write(InternalRow(InternalRow(1, UTF8String.fromString("str1")))) + writer.write(InternalRow(InternalRow(null, null))) + writer.write(InternalRow(null)) + writer.write(InternalRow(InternalRow(4, null))) + writer.write(InternalRow(InternalRow(null, UTF8String.fromString("str5")))) + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + + val struct0 = reader.getStruct(0, 2) + assert(struct0.getInt(0) === 1) + assert(struct0.getUTF8String(1) === UTF8String.fromString("str1")) + + val struct1 = reader.getStruct(1, 2) + assert(struct1.isNullAt(0)) + assert(struct1.isNullAt(1)) + + assert(reader.isNullAt(2)) + + val struct3 = reader.getStruct(3, 2) + assert(struct3.getInt(0) === 4) + assert(struct3.isNullAt(1)) + + val struct4 = reader.getStruct(4, 2) + assert(struct4.isNullAt(0)) + assert(struct4.getUTF8String(1) === UTF8String.fromString("str5")) + + writer.root.close() + } + + test("nested struct") { + val schema = new StructType().add("struct", + new StructType().add("nested", new StructType().add("i", IntegerType).add("str", StringType))) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + writer.write(InternalRow(InternalRow(InternalRow(1, 
UTF8String.fromString("str1"))))) + writer.write(InternalRow(InternalRow(InternalRow(null, null)))) + writer.write(InternalRow(InternalRow(null))) + writer.write(InternalRow(null)) + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + + val struct00 = reader.getStruct(0, 1).getStruct(0, 2) + assert(struct00.getInt(0) === 1) + assert(struct00.getUTF8String(1) === UTF8String.fromString("str1")) + + val struct10 = reader.getStruct(1, 1).getStruct(0, 2) + assert(struct10.isNullAt(0)) + assert(struct10.isNullAt(1)) + + val struct2 = reader.getStruct(2, 1) + assert(struct2.isNullAt(0)) + + assert(reader.isNullAt(3)) + + writer.root.close() + } +}
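
For context on how the new `ArrowWriter` introduced by this patch is meant to be driven, the sketch below mirrors the `simple` and `array` cases in `ArrowWriterSuite`: create a writer from a Spark schema, feed it `InternalRow`s, call `finish()`, and read the values back through `ArrowColumnVector`. The wrapping object and the sample data are illustrative only (not part of the patch), and `ArrowWriter`, `ArrowColumnVector`, and `InternalRow` are internal Spark classes rather than public API.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowWriter
import org.apache.spark.sql.execution.vectorized.ArrowColumnVector
import org.apache.spark.sql.types._

// Illustrative driver; the object name and sample values are assumptions.
object ArrowWriterExample {
  def main(args: Array[String]): Unit = {
    val schema = new StructType().add("value", IntegerType, nullable = true)
    val writer = ArrowWriter.create(schema)

    // Write a handful of rows, including a null, then seal the batch.
    Seq[Any](1, 2, null, 4).foreach(v => writer.write(InternalRow(v)))
    writer.finish()

    // Read the batch back through the columnar reader used by the test suite.
    val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
    assert(reader.getInt(0) == 1)
    assert(reader.isNullAt(2))

    // reset() clears the shared root so the same writer can build the next
    // batch; this is how the rewritten toPayloadIterator reuses one
    // VectorSchemaRoot across record batches instead of reallocating per batch.
    writer.reset()
    writer.root.close()
  }
}
```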
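On the Dataset side, `toArrowPayload` (a `private[sql]` method, so only reachable from code in the `org.apache.spark.sql` package such as `ArrowConvertersSuite`) now routes each partition through the rewritten `toPayloadIterator`, which reuses a single root and child allocator per task and registers a `TaskContext` completion listener so Arrow memory is released even if the iterator is not fully consumed. A rough sketch of how the resulting payloads are read back, following the pattern in the suite's round-trip and max-records-per-batch tests; the `spark` session, row count, and column name are assumptions:

```scala
import org.apache.arrow.memory.RootAllocator

// Assumes a running SparkSession named `spark` and test code living in the
// org.apache.spark.sql package, since toArrowPayload is private[sql].
import spark.implicits._

val df = spark.sparkContext.parallelize(1 to 100, 2).toDF("i")
val payloads = df.toArrowPayload.collect()          // one or more ArrowPayloads per partition

val allocator = new RootAllocator(Long.MaxValue)
val batches = payloads.map(_.loadBatch(allocator))  // deserialize back into ArrowRecordBatches
val rowCount = batches.map(_.getLength).sum
assert(rowCount == 100)

batches.foreach(_.close())
allocator.close()
```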