diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 2c3e82830033ecc32edf8eb900ed63bbae4f9a1f..443fc5de5bf045f17a13da5c8adba9cd4db404b6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -67,7 +67,14 @@ class MatrixFactorizationModel(
     }
   }
 
-  def predictJavaRDD(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
+  /**
+   * Predict the rating of many users for many products.
+   * This is a Java stub for python predictAll()
+   *
+   * @param usersProductsJRDD A JavaRDD with serialized tuples (user, product)
+   * @return JavaRDD of serialized Rating objects.
+   */
+  def predict(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
     val pythonAPI = new PythonMLLibAPI()
     val usersProducts = usersProductsJRDD.rdd.map(xBytes => pythonAPI.unpackTuple(xBytes))
     predict(usersProducts).map(rate => pythonAPI.serializeRating(rate))
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index c818fc4d97180d805366e9ae6284c2f1c50eef74..769d88dfb9b56a8c77b7e2fd2779015fa2dbfef3 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -18,6 +18,9 @@
 from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape
 from pyspark import SparkContext
 
+from pyspark.serializers import Serializer
+import struct
+
 # Double vector format:
 #
 # [8-byte 1] [8-byte length] [length*8 bytes of data]
@@ -213,9 +216,21 @@ def _serialize_rating(r):
     intpart[0], intpart[1], doublepart[0] = r
     return ba
 
-def _deserialize_rating(ba):
-    ar = ndarray(shape=(3, ), buffer=ba, dtype="float64", order='C')
-    return ar.copy()
+class RatingDeserializer(Serializer):
+    def loads(self, stream):
+        length = struct.unpack("!i", stream.read(4))[0]
+        ba = stream.read(length)
+        res = ndarray(shape=(3, ), buffer=ba, dtype="float64", offset=4)
+        return int(res[0]), int(res[1]), res[2]
+
+    def load_stream(self, stream):
+        while True:
+            try:
+                yield self.loads(stream)
+            except struct.error:
+                return
+            except EOFError:
+                return
 
 def _serialize_tuple(t):
     ba = bytearray(8)
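
For context, here is a minimal sketch of how the Python side could wire predictAll() through the new Java predict() stub and RatingDeserializer. It is only an illustration under assumptions: the _get_unmangled_rdd helper and the MatrixFactorizationModel constructor shown below are not part of this patch.

# Hypothetical Python-side predictAll(), a sketch of how the pieces in this
# patch fit together. _get_unmangled_rdd is an assumed helper that serializes
# an RDD of (user, product) tuples using _serialize_tuple from _common.py.
from pyspark.rdd import RDD
from pyspark.mllib._common import _serialize_tuple, _get_unmangled_rdd, RatingDeserializer


class MatrixFactorizationModel(object):
    def __init__(self, sc, java_model):
        self._context = sc
        self._java_model = java_model

    def predictAll(self, usersProducts):
        # Serialize each (user, product) pair to bytes, hand the resulting
        # JavaRDD to the Scala predict() stub, and decode the returned
        # serialized Ratings into (user, product, rating) tuples.
        usersProductsJRDD = _get_unmangled_rdd(usersProducts, _serialize_tuple)
        return RDD(self._java_model.predict(usersProductsJRDD._jrdd),
                   self._context, RatingDeserializer())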