Skip to content
Snippets Groups Projects
Commit 8d0c2f73 authored by Hossein Falaki's avatar Hossein Falaki
Browse files

Added python binding for bulk recommendation

parent dfe57fa8
No related branches found
No related tags found
No related merge requests found
...@@ -206,6 +206,24 @@ class PythonMLLibAPI extends Serializable { ...@@ -206,6 +206,24 @@ class PythonMLLibAPI extends Serializable {
return new Rating(user, product, rating) return new Rating(user, product, rating)
} }
private[spark] def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = {
val bb = ByteBuffer.wrap(tupleBytes)
bb.order(ByteOrder.nativeOrder())
val v1 = bb.getInt()
val v2 = bb.getInt()
(v1, v2)
}
private[spark] def serializeRating(rate: Rating): Array[Byte] = {
val bytes = new Array[Byte](24)
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
bb.putDouble(rate.user.toDouble)
bb.putDouble(rate.product.toDouble)
bb.putDouble(rate.rating)
bytes
}
/** /**
* Java stub for Python mllib ALS.train(). This stub returns a handle * Java stub for Python mllib ALS.train(). This stub returns a handle
* to the Java object instead of the content of the Java object. Extra care * to the Java object instead of the content of the Java object. Extra care
......
...@@ -19,9 +19,11 @@ package org.apache.spark.mllib.recommendation ...@@ -19,9 +19,11 @@ package org.apache.spark.mllib.recommendation
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._ import org.apache.spark.SparkContext._
import org.apache.spark.mllib.api.python.PythonMLLibAPI
import org.jblas._ import org.jblas._
import java.nio.{ByteOrder, ByteBuffer} import org.apache.spark.api.java.JavaRDD
/** /**
* Model representing the result of matrix factorization. * Model representing the result of matrix factorization.
...@@ -65,6 +67,12 @@ class MatrixFactorizationModel( ...@@ -65,6 +67,12 @@ class MatrixFactorizationModel(
} }
} }
def predictJavaRDD(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
val pythonAPI = new PythonMLLibAPI()
val usersProducts = usersProductsJRDD.rdd.map(xBytes => pythonAPI.unpackTuple(xBytes))
predict(usersProducts).map(rate => pythonAPI.serializeRating(rate))
}
// TODO: Figure out what other good bulk prediction methods would look like. // TODO: Figure out what other good bulk prediction methods would look like.
// Probably want a way to get the top users for a product or vice-versa. // Probably want a way to get the top users for a product or vice-versa.
} }
...@@ -213,6 +213,16 @@ def _serialize_rating(r): ...@@ -213,6 +213,16 @@ def _serialize_rating(r):
intpart[0], intpart[1], doublepart[0] = r intpart[0], intpart[1], doublepart[0] = r
return ba return ba
def _deserialize_rating(ba):
ar = ndarray(shape=(3, ), buffer=ba, dtype="float64", order='C')
return ar.copy()
def _serialize_tuple(t):
ba = bytearray(8)
intpart = ndarray(shape=[2], buffer=ba, dtype=int32)
intpart[0], intpart[1] = t
return ba
def _test(): def _test():
import doctest import doctest
globs = globals().copy() globs = globals().copy()
......
...@@ -20,7 +20,10 @@ from pyspark.mllib._common import \ ...@@ -20,7 +20,10 @@ from pyspark.mllib._common import \
_get_unmangled_rdd, _get_unmangled_double_vector_rdd, \ _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
_serialize_double_matrix, _deserialize_double_matrix, \ _serialize_double_matrix, _deserialize_double_matrix, \
_serialize_double_vector, _deserialize_double_vector, \ _serialize_double_vector, _deserialize_double_vector, \
_get_initial_weights, _serialize_rating, _regression_train_wrapper _get_initial_weights, _serialize_rating, _regression_train_wrapper, \
_serialize_tuple, _deserialize_rating
from pyspark.serializers import BatchedSerializer
from pyspark.rdd import RDD
class MatrixFactorizationModel(object): class MatrixFactorizationModel(object):
"""A matrix factorisation model trained by regularized alternating """A matrix factorisation model trained by regularized alternating
...@@ -45,6 +48,11 @@ class MatrixFactorizationModel(object): ...@@ -45,6 +48,11 @@ class MatrixFactorizationModel(object):
def predict(self, user, product): def predict(self, user, product):
return self._java_model.predict(user, product) return self._java_model.predict(user, product)
def predictAll(self, usersProducts):
usersProductsJRDD = _get_unmangled_rdd(usersProducts, _serialize_tuple)
return RDD(self._java_model.predictJavaRDD(usersProductsJRDD._jrdd),
self._context, BatchedSerializer(_deserialize_rating, self._context._batchSize))
class ALS(object): class ALS(object):
@classmethod @classmethod
def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1): def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment