Skip to content
Snippets Groups Projects
Commit 754f5300 authored by Hossein Falaki's avatar Hossein Falaki
Browse files

Added predictAll python function to MatrixFactorizationModel

parent 04132ea9
No related branches found
No related tags found
No related merge requests found
...@@ -21,8 +21,7 @@ from pyspark.mllib._common import \ ...@@ -21,8 +21,7 @@ from pyspark.mllib._common import \
_serialize_double_matrix, _deserialize_double_matrix, \ _serialize_double_matrix, _deserialize_double_matrix, \
_serialize_double_vector, _deserialize_double_vector, \ _serialize_double_vector, _deserialize_double_vector, \
_get_initial_weights, _serialize_rating, _regression_train_wrapper, \ _get_initial_weights, _serialize_rating, _regression_train_wrapper, \
_serialize_tuple, _deserialize_rating _serialize_tuple, RatingDeserializer
from pyspark.serializers import BatchedSerializer
from pyspark.rdd import RDD from pyspark.rdd import RDD
class MatrixFactorizationModel(object): class MatrixFactorizationModel(object):
...@@ -36,6 +35,9 @@ class MatrixFactorizationModel(object): ...@@ -36,6 +35,9 @@ class MatrixFactorizationModel(object):
>>> model = ALS.trainImplicit(sc, ratings, 1) >>> model = ALS.trainImplicit(sc, ratings, 1)
>>> model.predict(2,2) is not None >>> model.predict(2,2) is not None
True True
>>> testset = sc.parallelize([(1, 2), (1, 1)])
>>> model.predictAll(testset).count == 2
True
""" """
def __init__(self, sc, java_model): def __init__(self, sc, java_model):
...@@ -50,8 +52,8 @@ class MatrixFactorizationModel(object): ...@@ -50,8 +52,8 @@ class MatrixFactorizationModel(object):
def predictAll(self, usersProducts): def predictAll(self, usersProducts):
usersProductsJRDD = _get_unmangled_rdd(usersProducts, _serialize_tuple) usersProductsJRDD = _get_unmangled_rdd(usersProducts, _serialize_tuple)
return RDD(self._java_model.predictJavaRDD(usersProductsJRDD._jrdd), return RDD(self._java_model.predict(usersProductsJRDD._jrdd),
self._context, BatchedSerializer(_deserialize_rating, self._context._batchSize)) self._context, RatingDeserializer())
class ALS(object): class ALS(object):
@classmethod @classmethod
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment