diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 93aa41e49961e890b03beff4d922fb792e969dd6..43d219a49cf4ebfdbd1b9cb4c57b6a194912c5e1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -22,6 +22,7 @@ import java.lang.{Integer => JavaInteger}
 
 import scala.collection.mutable
 
+import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
 import org.apache.hadoop.fs.Path
 import org.json4s._
@@ -79,6 +80,30 @@ class MatrixFactorizationModel(
     blas.ddot(rank, userVector, 1, productVector, 1)
   }
 
+  /**
+   * Returns the approximate numbers of distinct users and products in the given
+   * (user, product) pairs. It uses the same HyperLogLog++ sketch as `RDD.countApproxDistinct`.
+   *
+   * @param usersProducts  RDD of (user, product) pairs.
+   * @return approximate numbers of distinct users and distinct products.
+   */
+  private[this] def countApproxDistinctUserProduct(usersProducts: RDD[(Int, Int)]): (Long, Long) = {
+    val zeroCounterUser = new HyperLogLogPlus(4, 0)
+    val zeroCounterProduct = new HyperLogLogPlus(4, 0)
+    val aggregated = usersProducts.aggregate((zeroCounterUser, zeroCounterProduct))(
+      (hllTuple: (HyperLogLogPlus, HyperLogLogPlus), v: (Int, Int)) => {
+        hllTuple._1.offer(v._1)
+        hllTuple._2.offer(v._2)
+        hllTuple
+      },
+      (h1: (HyperLogLogPlus, HyperLogLogPlus), h2: (HyperLogLogPlus, HyperLogLogPlus)) => {
+        h1._1.addAll(h2._1)
+        h1._2.addAll(h2._2)
+        h1
+      })
+    (aggregated._1.cardinality(), aggregated._2.cardinality())
+  }
+
   /**
    * Predict the rating of many users for many products.
    * The output RDD has an element per each element in the input RDD (including all duplicates)
@@ -88,12 +113,30 @@ class MatrixFactorizationModel(
    * @return RDD of Ratings.
    */
   def predict(usersProducts: RDD[(Int, Int)]): RDD[Rating] = {
-    val users = userFeatures.join(usersProducts).map {
-      case (user, (uFeatures, product)) => (product, (user, uFeatures))
-    }
-    users.join(productFeatures).map {
-      case (product, ((user, uFeatures), pFeatures)) =>
-        Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1))
+    // Previously the output ratings were partitioned only by product. If the given
+    // usersProducts contained only a few distinct products, or even a single one, the
+    // generated ratings ended up in a few partitions (or a single partition) and the
+    // computation could not use high parallelism.
+    // Here we estimate the numbers of distinct users and products, and then decide
+    // whether to key the joins by users or by products.
+    val (usersCount, productsCount) = countApproxDistinctUserProduct(usersProducts)
+
+    if (usersCount < productsCount) {
+      val users = userFeatures.join(usersProducts).map {
+        case (user, (uFeatures, product)) => (product, (user, uFeatures))
+      }
+      users.join(productFeatures).map {
+        case (product, ((user, uFeatures), pFeatures)) =>
+          Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1))
+      }
+    } else {
+      val products = productFeatures.join(usersProducts.map(_.swap)).map {
+        case (product, (pFeatures, user)) => (user, (product, pFeatures))
+      }
+      products.join(userFeatures).map {
+        case (user, ((product, pFeatures), uFeatures)) =>
+          Rating(user, product, blas.ddot(uFeatures.length, uFeatures, 1, pFeatures, 1))
+      }
     }
   }
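
For context, below is a minimal, self-contained sketch of the same HyperLogLog++ aggregation pattern used by `countApproxDistinctUserProduct`, run against a deliberately skewed (user, product) RDD. The object name `ApproxDistinctSketch`, the local-mode SparkContext setup, and the sample data are illustrative assumptions and not part of this patch; only the `aggregate`/`offer`/`addAll`/`cardinality` calls mirror the new code.

```scala
import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
import org.apache.spark.{SparkConf, SparkContext}

object ApproxDistinctSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("hll-sketch"))

    // A skewed sample: 1000 distinct users but only two distinct products.
    val usersProducts = sc.parallelize(for (u <- 1 to 1000; p <- Seq(1, 2)) yield (u, p))

    // One HyperLogLog++ sketch per side of the pair, combined as in the patch:
    // seqOp offers each user/product id, combOp merges partial sketches with addAll.
    val (userHll, productHll) = usersProducts.aggregate(
        (new HyperLogLogPlus(4, 0), new HyperLogLogPlus(4, 0)))(
      (hlls, pair) => {
        hlls._1.offer(pair._1)
        hlls._2.offer(pair._2)
        hlls
      },
      (h1, h2) => {
        h1._1.addAll(h2._1)
        h1._2.addAll(h2._2)
        h1
      })

    val usersCount = userHll.cardinality()
    val productsCount = productHll.cardinality()
    val joinKey = if (usersCount < productsCount) "user" else "product"
    println(s"~$usersCount users, ~$productsCount products; key the joins by $joinKey")

    sc.stop()
  }
}
```

The heuristic adds one extra pass over `usersProducts`, but the sketches are tiny: with precision 4 a HyperLogLog++ counter uses 2^4 registers and has roughly a 26% relative error, which is more than enough accuracy to pick the smaller side of the two joins in `predict`.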