diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 8247c1ebc5d2b32910cc4efc045d32c134296bc4..be2628fac581780a3aa12568961ceb0e81f3dbee 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -206,6 +206,24 @@ class PythonMLLibAPI extends Serializable {
     return new Rating(user, product, rating)
   }
 
+  private[spark] def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = {
+    val bb = ByteBuffer.wrap(tupleBytes)
+    bb.order(ByteOrder.nativeOrder())
+    val v1 = bb.getInt()
+    val v2 = bb.getInt()
+    (v1, v2)
+  }
+
+  private[spark] def serializeRating(rate: Rating): Array[Byte] = {
+    val bytes = new Array[Byte](24)
+    val bb = ByteBuffer.wrap(bytes)
+    bb.order(ByteOrder.nativeOrder())
+    bb.putDouble(rate.user.toDouble)
+    bb.putDouble(rate.product.toDouble)
+    bb.putDouble(rate.rating)
+    bytes
+  }
+
   /**
    * Java stub for Python mllib ALS.train(). This stub returns a handle
    * to the Java object instead of the content of the Java object. Extra care
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 8caecf0fa1ba85534fe7d7fce470ee6a7b524c7f..2c3e82830033ecc32edf8eb900ed63bbae4f9a1f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -19,9 +19,11 @@ package org.apache.spark.mllib.recommendation
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.api.python.PythonMLLibAPI
 
 import org.jblas._
-import java.nio.{ByteOrder, ByteBuffer}
+import org.apache.spark.api.java.JavaRDD
+
 
 /**
  * Model representing the result of matrix factorization.
@@ -65,6 +67,12 @@ class MatrixFactorizationModel(
     }
   }
 
+  def predictJavaRDD(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
+    val pythonAPI = new PythonMLLibAPI()
+    val usersProducts = usersProductsJRDD.rdd.map(xBytes => pythonAPI.unpackTuple(xBytes))
+    predict(usersProducts).map(rate => pythonAPI.serializeRating(rate))
+  }
+
   // TODO: Figure out what other good bulk prediction methods would look like.
   // Probably want a way to get the top users for a product or vice-versa.
 }
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index e74ba0fabc09ce8142c4eee10b8ff09aeea4f812..c818fc4d97180d805366e9ae6284c2f1c50eef74 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -213,6 +213,16 @@ def _serialize_rating(r):
     intpart[0], intpart[1], doublepart[0] = r
     return ba
 
+def _deserialize_rating(ba):
+    ar = ndarray(shape=(3, ), buffer=ba, dtype="float64", order='C')
+    return ar.copy()
+
+def _serialize_tuple(t):
+    ba = bytearray(8)
+    intpart = ndarray(shape=[2], buffer=ba, dtype=int32)
+    intpart[0], intpart[1] = t
+    return ba
+
 def _test():
     import doctest
     globs = globals().copy()
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 14d06cba2137fc6125bc419c467ed4b679ab3df4..c81b482a87ef7a1f70f58cc9f09d60a1b399bd65 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -20,7 +20,10 @@ from pyspark.mllib._common import \
     _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
     _serialize_double_matrix, _deserialize_double_matrix, \
     _serialize_double_vector, _deserialize_double_vector, \
-    _get_initial_weights, _serialize_rating, _regression_train_wrapper
+    _get_initial_weights, _serialize_rating, _regression_train_wrapper, \
+    _serialize_tuple, _deserialize_rating
+from pyspark.serializers import BatchedSerializer
+from pyspark.rdd import RDD
 
 class MatrixFactorizationModel(object):
     """A matrix factorisation model trained by regularized alternating
@@ -45,6 +48,11 @@ class MatrixFactorizationModel(object):
     def predict(self, user, product):
         return self._java_model.predict(user, product)
 
+    def predictAll(self, usersProducts):
+        usersProductsJRDD = _get_unmangled_rdd(usersProducts, _serialize_tuple)
+        return RDD(self._java_model.predictJavaRDD(usersProductsJRDD._jrdd),
+                   self._context, BatchedSerializer(_deserialize_rating, self._context._batchSize))
+
 class ALS(object):
     @classmethod
     def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
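Usage sketch (not part of the patch): a minimal example of how the predictAll path added above might be exercised from PySpark built with this change, assuming a local SparkContext; the app name, sample ratings, and (user, product) pairs are illustrative only.

    from pyspark import SparkContext
    from pyspark.mllib.recommendation import ALS

    sc = SparkContext("local", "ALS predictAll sketch")

    # Ratings are (user, product, rating) triples, matching _serialize_rating.
    ratings = sc.parallelize([(1, 1, 5.0), (1, 2, 1.0), (2, 1, 1.0), (2, 2, 5.0)])
    model = ALS.train(sc, ratings, rank=10, iterations=5)

    # predictAll takes an RDD of (user, product) int pairs, serialized with
    # _serialize_tuple, and yields elements deserialized by _deserialize_rating:
    # float64 arrays of [user, product, predicted rating].
    users_products = sc.parallelize([(1, 2), (2, 1)])
    for p in model.predictAll(users_products).collect():
        print(p)

    sc.stop()

Each returned element is the 24-byte (user, product, rating) double triple written by serializeRating on the Scala side, so in this version callers wanting a richer Rating-like object would need to convert the array themselves.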