diff --git a/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
index 2abaf2f2dd8a32cdcb6fb4a1334c7632ef09999a..4c18cbdc6bc32494e5fb8f82de7f67be13863f32 100644
--- a/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/spark/mllib/recommendation/ALS.scala
@@ -6,8 +6,10 @@ import scala.util.Sorting
 
 import spark.{HashPartitioner, Partitioner, SparkContext, RDD}
 import spark.storage.StorageLevel
+import spark.KryoRegistrator
 import spark.SparkContext._
 
+import com.esotericsoftware.kryo.Kryo
 import org.jblas.{DoubleMatrix, SimpleBlas, Solve}
 
 
@@ -98,8 +100,9 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
 
     val partitioner = new HashPartitioner(numBlocks)
 
-    val ratingsByUserBlock = ratings.map{ case (u, p, r) => (u % numBlocks, (u, p, r)) }
-    val ratingsByProductBlock = ratings.map{ case (u, p, r) => (p % numBlocks, (p, u, r)) }
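+    // Use the Rating case class instead of tuples; it is registered with Kryo below for compact serialization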
+    val ratingsByUserBlock = ratings.map{ case (u, p, r) => (u % numBlocks, Rating(u, p, r)) }
+    val ratingsByProductBlock = ratings.map{ case (u, p, r) => (p % numBlocks, Rating(p, u, r)) }
 
     val (userInLinks, userOutLinks) = makeLinkRDDs(numBlocks, ratingsByUserBlock)
     val (productInLinks, productOutLinks) = makeLinkRDDs(numBlocks, ratingsByProductBlock)
@@ -179,12 +182,12 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
- * the users (or (blockId, (p, u, r)) for the products). We create these simultaneously to avoid
- * having to shuffle the (blockId, (u, p, r)) RDD twice, or to cache it.
+ * the users (or (blockId, Rating(p, u, r)) for the products). We create these simultaneously to
+ * avoid having to shuffle the (blockId, Rating(u, p, r)) RDD twice, or to cache it.
    */
-  private def makeLinkRDDs(numBlocks: Int, ratings: RDD[(Int, (Int, Int, Double))])
+  private def makeLinkRDDs(numBlocks: Int, ratings: RDD[(Int, Rating)])
     : (RDD[(Int, InLinkBlock)], RDD[(Int, OutLinkBlock)]) =
   {
     val grouped = ratings.partitionBy(new HashPartitioner(numBlocks))
     val links = grouped.mapPartitionsWithIndex((blockId, elements) => {
-      val ratings = elements.map{case (k, t) => Rating(t._1, t._2, t._3)}.toArray
+      val ratings = elements.map{_._2}.toArray
       val inLinkBlock = makeInLinkBlock(numBlocks, ratings)
       val outLinkBlock = makeOutLinkBlock(numBlocks, ratings)
       Iterator.single((blockId, (inLinkBlock, outLinkBlock)))
@@ -383,6 +386,13 @@ object ALS {
     train(ratings, rank, iterations, 0.01, -1)
   }
 
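+  /** KryoRegistrator that registers the Rating class with Kryo, so ratings serialize compactly. */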
+  private class ALSRegistrator extends KryoRegistrator {
+    override def registerClasses(kryo: Kryo) {
+      kryo.register(classOf[Rating])
+    }
+  }
+
   def main(args: Array[String]) {
     if (args.length != 5 && args.length != 6) {
       println("Usage: ALS <master> <ratings_file> <rank> <iterations> <output_dir> [<blocks>]")
@@ -392,6 +402,10 @@ object ALS {
       (args(0), args(1), args(2).toInt, args(3).toInt, args(4))
     val blocks = if (args.length == 6) args(5).toInt else -1
     System.setProperty("spark.serializer", "spark.KryoSerializer")
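+    // Register Rating through ALSRegistrator, and disable Kryo reference tracking:
+    // Rating objects contain no cyclic references, so tracking is pure overhead.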
+    System.setProperty("spark.kryo.registrator", classOf[ALSRegistrator].getName)
+    System.setProperty("spark.kryo.referenceTracking", "false")
     System.setProperty("spark.locality.wait", "10000")
     val sc = new SparkContext(master, "ALS")
     val ratings = sc.textFile(ratingsFile).map { line =>