From 04132ea9b20a95cd68482605d4022f692bb556e5 Mon Sep 17 00:00:00 2001 From: Hossein Falaki <falaki@gmail.com> Date: Mon, 6 Jan 2014 12:19:08 -0800 Subject: [PATCH] Added Rating deserializer --- .../MatrixFactorizationModel.scala | 9 +++++++- python/pyspark/mllib/_common.py | 21 ++++++++++++++++--- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 2c3e828300..443fc5de5b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -67,7 +67,14 @@ class MatrixFactorizationModel( } } - def predictJavaRDD(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = { + /** + * Predict the rating of many users for many products. + * This is a Java stub for python predictAll() + * + * @param usersProductsJRDD A JavaRDD with serialized tuples (user, product) + * @return JavaRDD of serialized Rating objects. + */ + def predict(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = { val pythonAPI = new PythonMLLibAPI() val usersProducts = usersProductsJRDD.rdd.map(xBytes => pythonAPI.unpackTuple(xBytes)) predict(usersProducts).map(rate => pythonAPI.serializeRating(rate)) diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index c818fc4d97..769d88dfb9 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -18,6 +18,9 @@ from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape from pyspark import SparkContext +from pyspark.serializers import Serializer +import struct + # Double vector format: # # [8-byte 1] [8-byte length] [length*8 bytes of data] @@ -213,9 +216,21 @@ def _serialize_rating(r): intpart[0], intpart[1], doublepart[0] = r return ba -def _deserialize_rating(ba): - ar = ndarray(shape=(3, ), buffer=ba, dtype="float64", order='C') - return ar.copy() +class RatingDeserializer(Serializer): + def loads(self, stream): + length = struct.unpack("!i", stream.read(4))[0] + ba = stream.read(length) + res = ndarray(shape=(3, ), buffer=ba, dtype="float64", offset=4) + return int(res[0]), int(res[1]), res[2] + + def load_stream(self, stream): + while True: + try: + yield self.loads(stream) + except struct.error: + return + except EOFError: + return def _serialize_tuple(t): ba = bytearray(8) -- GitLab