diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 90fa4fbbc604db8ed81547861a1f7b4f2b588427..076cca6016ecb660fb6ea54aa302017fc098ab87 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -27,8 +27,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{linalg => newlinalg}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData}
 import org.apache.spark.sql.types._
 
 /**
@@ -194,9 +193,9 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
         row.setByte(0, 0)
         row.setInt(1, sm.numRows)
         row.setInt(2, sm.numCols)
-        row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any])))
-        row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any])))
-        row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any])))
+        row.update(3, UnsafeArrayData.fromPrimitiveArray(sm.colPtrs))
+        row.update(4, UnsafeArrayData.fromPrimitiveArray(sm.rowIndices))
+        row.update(5, UnsafeArrayData.fromPrimitiveArray(sm.values))
         row.setBoolean(6, sm.isTransposed)
 
       case dm: DenseMatrix =>
@@ -205,7 +204,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
         row.setInt(2, dm.numCols)
         row.setNullAt(3)
         row.setNullAt(4)
-        row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any])))
+        row.update(5, UnsafeArrayData.fromPrimitiveArray(dm.values))
         row.setBoolean(6, dm.isTransposed)
     }
     row
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 6e3da6b701cb07491f1d7eec2a48d1f08d443035..132e54a8c3de25d57a8c3fadc908db7b7e1af1dc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -33,8 +33,7 @@ import org.apache.spark.annotation.{AlphaComponent, Since}
 import org.apache.spark.ml.{linalg => newlinalg}
 import org.apache.spark.mllib.util.NumericParser
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData}
 import org.apache.spark.sql.types._
 
 /**
@@ -216,15 +215,15 @@ class VectorUDT extends UserDefinedType[Vector] {
         val row = new GenericMutableRow(4)
         row.setByte(0, 0)
         row.setInt(1, size)
-        row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any])))
-        row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
+        row.update(2, UnsafeArrayData.fromPrimitiveArray(indices))
+        row.update(3, UnsafeArrayData.fromPrimitiveArray(values))
         row
       case DenseVector(values) =>
         val row = new GenericMutableRow(4)
         row.setByte(0, 1)
         row.setNullAt(1)
         row.setNullAt(2)
-        row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
+        row.update(3, UnsafeArrayData.fromPrimitiveArray(values))
         row
     }
   }
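Note on the hunks above: the old path boxed every primitive element into `Any` and stored it in an `Object[]` behind `GenericArrayData`; the new path hands the primitive array straight to `UnsafeArrayData.fromPrimitiveArray`, which performs one bulk copy into a flat `byte[]`. A minimal sketch of the two paths (illustrative only; it reuses the APIs visible in the diff, not new ones):

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import org.apache.spark.sql.catalyst.util.GenericArrayData

val values: Array[Double] = Array(0.1, 0.2, 0.3)

// Old path: each Double is boxed to java.lang.Double and kept in an Object[].
val boxed = new GenericArrayData(values.map(_.asInstanceOf[Any]))

// New path: one bulk copy of the primitive array, no per-element boxing.
val unsafe = UnsafeArrayData.fromPrimitiveArray(values)

// Both expose the same ArrayData read API.
assert(boxed.getDouble(1) == unsafe.getDouble(1))
```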
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
new file mode 100644
index 0000000000000000000000000000000000000000..be7110ad6bbf0773fc45928a341ff3f436ae0efa
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.linalg
+
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.util.Benchmark
+
+/**
+ * Serialization benchmark for VectorUDT.
+ */
+object UDTSerializationBenchmark {
+
+  def main(args: Array[String]): Unit = {
+    val iters = 1e2.toInt
+    val numRows = 1e3.toInt
+
+    val encoder = ExpressionEncoder[Vector].defaultBinding
+
+    val vectors = (1 to numRows).map { i =>
+      Vectors.dense(Array.fill(1e5.toInt)(1.0 * i))
+    }.toArray
+    val rows = vectors.map(encoder.toRow)
+
+    val benchmark = new Benchmark("VectorUDT de/serialization", numRows, iters)
+
+    benchmark.addCase("serialize") { _ =>
+      var sum = 0
+      var i = 0
+      while (i < numRows) {
+        sum += encoder.toRow(vectors(i)).numFields
+        i += 1
+      }
+    }
+
+    benchmark.addCase("deserialize") { _ =>
+      var sum = 0
+      var i = 0
+      while (i < numRows) {
+        sum += encoder.fromRow(rows(i)).numActives
+        i += 1
+      }
+    }
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4
+    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+
+    VectorUDT de/serialization:         Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    serialize                                 380 /  392          0.0      379730.0       1.0X
+    deserialize                               138 /  142          0.0      137816.6       2.8X
+    */
+    benchmark.run()
+  }
+}
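The benchmark above measures only throughput; a quick correctness sanity check of the same encoder round-trip can be sketched as follows (assumes the same APIs the benchmark already uses: `ExpressionEncoder.defaultBinding`, `toRow`, `fromRow`):

```scala
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

val encoder = ExpressionEncoder[Vector].defaultBinding

val v = Vectors.sparse(5, Array(1, 3), Array(1.0, 2.0))
val row = encoder.toRow(v)        // VectorUDT.serialize -> UnsafeArrayData fields
assert(encoder.fromRow(row) == v) // VectorUDT.deserialize restores an equal vector
```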
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 648625b2cc5d2303fc36e22718a7231870453e0a..02a863b2bb498f9ae3cf6193d9d68303acf9088c 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -47,7 +47,7 @@ import org.apache.spark.unsafe.types.UTF8String;
  * Instances of `UnsafeArrayData` act as pointers to row data stored in this format.
  */
 // todo: there is a lof of duplicated code between UnsafeRow and UnsafeArrayData.
-public class UnsafeArrayData extends ArrayData {
+public final class UnsafeArrayData extends ArrayData {
 
   private Object baseObject;
   private long baseOffset;
@@ -81,7 +81,7 @@ public class UnsafeArrayData extends ArrayData {
   }
 
   public Object[] array() {
-    throw new UnsupportedOperationException("Only supported on GenericArrayData.");
+    throw new UnsupportedOperationException("Not supported on UnsafeArrayData.");
   }
 
   /**
@@ -336,4 +336,64 @@ public class UnsafeArrayData extends ArrayData {
     arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
     return arrayCopy;
   }
+
+  public static UnsafeArrayData fromPrimitiveArray(int[] arr) {
+    if (arr.length > (Integer.MAX_VALUE - 4) / 8) {
+      throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " +
+        "it's too big.");
+    }
+
+    final int offsetRegionSize = 4 * arr.length;
+    final int valueRegionSize = 4 * arr.length;
+    final int totalSize = 4 + offsetRegionSize + valueRegionSize;
+    final byte[] data = new byte[totalSize];
+
+    Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length);
+
+    int offsetPosition = Platform.BYTE_ARRAY_OFFSET + 4;
+    int valueOffset = 4 + offsetRegionSize;
+    for (int i = 0; i < arr.length; i++) {
+      Platform.putInt(data, offsetPosition, valueOffset);
+      offsetPosition += 4;
+      valueOffset += 4;
+    }
+
+    Platform.copyMemory(arr, Platform.INT_ARRAY_OFFSET, data,
+      Platform.BYTE_ARRAY_OFFSET + 4 + offsetRegionSize, valueRegionSize);
+
+    UnsafeArrayData result = new UnsafeArrayData();
+    result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize);
+    return result;
+  }
+
+  public static UnsafeArrayData fromPrimitiveArray(double[] arr) {
+    if (arr.length > (Integer.MAX_VALUE - 4) / 12) {
+      throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " +
+        "it's too big.");
+    }
+
+    final int offsetRegionSize = 4 * arr.length;
+    final int valueRegionSize = 8 * arr.length;
+    final int totalSize = 4 + offsetRegionSize + valueRegionSize;
+    final byte[] data = new byte[totalSize];
+
+    Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length);
+
+    int offsetPosition = Platform.BYTE_ARRAY_OFFSET + 4;
+    int valueOffset = 4 + offsetRegionSize;
+    for (int i = 0; i < arr.length; i++) {
+      Platform.putInt(data, offsetPosition, valueOffset);
+      offsetPosition += 4;
+      valueOffset += 8;
+    }
+
+    Platform.copyMemory(arr, Platform.DOUBLE_ARRAY_OFFSET, data,
+      Platform.BYTE_ARRAY_OFFSET + 4 + offsetRegionSize, valueRegionSize);
+
+    UnsafeArrayData result = new UnsafeArrayData();
+    result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize);
+    return result;
+  }
+
+  // TODO: add more specialized methods.
 }
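To make the format concrete: for `Array(1, 10, 100)`, `fromPrimitiveArray` writes a 4-byte element count, then one 4-byte offset per element (relative to the start of the array data), then the 4-byte values themselves, for 4 + 4*3 + 4*3 = 28 bytes in total. A small sketch mirroring the assertions in `UnsafeArraySuite` below:

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData

val unsafe = UnsafeArrayData.fromPrimitiveArray(Array(1, 10, 100))

// Header: numElements = 3 in the first 4 bytes.
assert(unsafe.numElements == 3)
// Offset region holds 16, 20, 24: the value region starts at 4 + 4*3 = 16.
// Value region holds the ints themselves, 4 bytes each.
assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 4 * 3)
assert((0 until unsafe.numElements).map(unsafe.getInt) == Seq(1, 10, 100))
```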
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
index 651eb1ff0c561d31253780e4405a4a64d59a4246..0700148becabadcb343c4c4b5f9ef0f3a4773497 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
@@ -30,7 +30,7 @@ import org.apache.spark.unsafe.Platform;
  * [unsafe key array numBytes] [unsafe key array] [unsafe value array]
  */
 // TODO: Use a more efficient format which doesn't depend on unsafe array.
-public class UnsafeMapData extends MapData {
+public final class UnsafeMapData extends MapData {
 
   private Object baseObject;
   private long baseOffset;
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..1685276ff1201be48db9ba7d1498541f3b22f033
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
+
+class UnsafeArraySuite extends SparkFunSuite {
+
+  test("from primitive int array") {
+    val array = Array(1, 10, 100)
+    val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
+    assert(unsafe.numElements == 3)
+    assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 4 * 3)
+    assert(unsafe.getInt(0) == 1)
+    assert(unsafe.getInt(1) == 10)
+    assert(unsafe.getInt(2) == 100)
+  }
+
+  test("from primitive double array") {
+    val array = Array(1.1, 2.2, 3.3)
+    val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
+    assert(unsafe.numElements == 3)
+    assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 8 * 3)
+    assert(unsafe.getDouble(0) == 1.1)
+    assert(unsafe.getDouble(1) == 2.2)
+    assert(unsafe.getDouble(2) == 3.3)
+  }
+}