diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 8b5bb79a18dbaeee47ed4780dbf2578ffbc93ca8..3d73d95909270e5e7ebe68a64c8444bfff6a49c8 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -43,9 +43,10 @@ from pyspark.rdd import RDD
 from pyspark.files import SparkFiles
 from pyspark.storagelevel import StorageLevel
 from pyspark.mllib import LinearRegressionModel, LassoModel, \
-    RidgeRegressionModel, LogisticRegressionModel, SVMModel, KMeansModel
+    RidgeRegressionModel, LogisticRegressionModel, SVMModel, KMeansModel, \
+    ALSModel
 
 
 __all__ = ["SparkContext", "RDD", "SparkFiles", "StorageLevel",
     "LinearRegressionModel", "LassoModel", "RidgeRegressionModel",
-    "LogisticRegressionModel", "SVMModel", "KMeansModel"];
+    "LogisticRegressionModel", "SVMModel", "KMeansModel", "ALSModel"];
diff --git a/python/pyspark/mllib.py b/python/pyspark/mllib.py
index 8848284a5e927c44989d1aaac47840ec78a05540..22187eb4dd67a451ba830b113662a14ae5e8ab19 100644
--- a/python/pyspark/mllib.py
+++ b/python/pyspark/mllib.py
@@ -164,14 +164,19 @@ class LinearRegressionModelBase(LinearModel):
         _linear_predictor_typecheck(x, self._coeff)
         return dot(self._coeff, x) + self._intercept
 
-# Map a pickled Python RDD of numpy double vectors to a Java RDD of
-# _serialized_double_vectors
-def _get_unmangled_double_vector_rdd(data):
-    dataBytes = data.map(_serialize_double_vector)
+def _get_unmangled_rdd(data, serializer):
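+    """Map data through the given per-element serializer and return a cached
+    RDD of bytearrays that bypasses PySpark's own serializer."""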
+    dataBytes = data.map(serializer)
     dataBytes._bypass_serializer = True
     dataBytes.cache()
     return dataBytes
 
+# Map a pickled Python RDD of numpy double vectors to a Java RDD of
+# _serialized_double_vectors
+def _get_unmangled_double_vector_rdd(data):
+    return _get_unmangled_rdd(data, _serialize_double_vector)
+
 # If we weren't given initial weights, take a zero vector of the appropriate
 # length.
 def _get_initial_weights(initial_weights, data):
@@ -317,7 +320,7 @@ class KMeansModel(object):
         return best
 
     @classmethod
-    def train(cls, sc, data, k, maxIterations = 100, runs = 1,
+    def train(cls, sc, data, k, maxIterations=100, runs=1,
             initialization_mode="k-means||"):
         """Train a k-means clustering model."""
         dataBytes = _get_unmangled_double_vector_rdd(data)
@@ -330,12 +333,64 @@
                     + type(ans[0]) + " which is not bytearray")
         return KMeansModel(_deserialize_double_matrix(ans[0]))
 
+def _serialize_rating(r):
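+    """Serialize a (user, product, rating) tuple into a 16-byte record:
+    two int32s (user, product) followed by one float64 (rating)."""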
+    ba = bytearray(16)
+    intpart = ndarray(shape=[2], buffer=ba, dtype=int32)
+    doublepart = ndarray(shape=[1], buffer=ba, dtype=float64, offset=8)
+    intpart[0], intpart[1], doublepart[0] = r
+    return ba
+
+class ALSModel(object):
+    """A matrix factorisation model trained by regularized alternating
+    least-squares.
+
+    >>> r1 = (1, 1, 1.0)
+    >>> r2 = (1, 2, 2.0)
+    >>> r3 = (2, 1, 2.0)
+    >>> ratings = sc.parallelize([r1, r2, r3])
+    >>> model = ALSModel.trainImplicit(sc, ratings, 1)
+    >>> model.predict(2,2) is not None
+    True
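+    >>> model = ALSModel.train(sc, ratings, 1)
+    >>> model.predict(2,2) is not None
+    True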
+    """
+
+    def __init__(self, sc, java_model):
+        self._context = sc
+        self._java_model = java_model
+
+    # def __del__(self):
+    #     self._gateway.detach(self._java_model)
+
+    def predict(self, user, product):
+        return self._java_model.predict(user, product)
+
+    @classmethod
+    def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
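+        """Train an ALS model on an RDD of (user, product, rating) tuples."""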
+        ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
+        mod = sc._jvm.PythonMLLibAPI().trainALSModel(ratingBytes._jrdd,
+                rank, iterations, lambda_, blocks)
+        return ALSModel(sc, mod)
+
+    @classmethod
+    def trainImplicit(cls, sc, ratings, rank, iterations=5, lambda_=0.01,
+            blocks=-1, alpha=0.01):
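+        """Train an ALS model from an RDD of implicit-feedback ratings."""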
+        ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
+        mod = sc._jvm.PythonMLLibAPI().trainImplicitALSModel(ratingBytes._jrdd,
+                rank, iterations, lambda_, blocks, alpha)
+        return ALSModel(sc, mod)
+
 def _test():
     import doctest
     globs = globals().copy()
     globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
     (failure_count, test_count) = doctest.testmod(globs=globs,
-        optionflags=doctest.ELLIPSIS)
+            optionflags=doctest.ELLIPSIS)
     globs['sc'].stop()
     print failure_count,"failures among",test_count,"tests"
     if failure_count: