Skip to content
Snippets Groups Projects
Commit c5c38d19 authored by Matei Zaharia's avatar Matei Zaharia
Browse files

Some optimizations to loading phase of ALS

parent b91a218c
No related branches found
No related tags found
No related merge requests found
...@@ -2,6 +2,7 @@ package spark.mllib.recommendation ...@@ -2,6 +2,7 @@ package spark.mllib.recommendation
import scala.collection.mutable.{ArrayBuffer, BitSet}
import scala.util.Random
import scala.util.Sorting

import spark.{HashPartitioner, Partitioner, SparkContext, RDD}
import spark.storage.StorageLevel
...@@ -33,6 +34,12 @@ private[recommendation] case class InLinkBlock( ...@@ -33,6 +34,12 @@ private[recommendation] case class InLinkBlock(
    elementIds: Array[Int], ratingsForBlock: Array[Array[(Array[Int], Array[Double])]])
/**
 * A more compact class to represent a rating than Tuple3[Int, Int, Double].
 *
 * @param user    ID of the user giving the rating
 * @param product ID of the product being rated
 * @param rating  the rating value
 */
private[recommendation] case class Rating(user: Int, product: Int, rating: Double)
/**
 * Alternating Least Squares matrix factorization.
 *
...@@ -126,13 +133,13 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l ...@@ -126,13 +133,13 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
 * Make the out-links table for a block of the users (or products) dataset given the list of
 * (user, product, rating) values for the users in that block (or the opposite for products).
 */
private def makeOutLinkBlock(numBlocks: Int, ratings: Array[Rating]): OutLinkBlock = {
  // Sorted, de-duplicated user IDs appearing in this block's ratings.
  val userIds = ratings.map(_.user).distinct.sorted
  val numUsers = userIds.length
  // Map each user ID to its position in the sorted userIds array.
  val userIdToPos = userIds.zipWithIndex.toMap
  // shouldSend(i)(b) is set when user i has at least one rating in product block b,
  // i.e. that user's factor vector must be shipped to block b.
  val shouldSend = Array.fill(numUsers)(new BitSet(numBlocks))
  for (r <- ratings) {
    shouldSend(userIdToPos(r.user))(r.product % numBlocks) = true
  }
  OutLinkBlock(userIds, shouldSend)
}
...@@ -141,18 +148,28 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l ...@@ -141,18 +148,28 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
 * Make the in-links table for a block of the users (or products) dataset given a list of
 * (user, product, rating) values for the users in that block (or the opposite for products).
 */
private def makeInLinkBlock(numBlocks: Int, ratings: Array[Rating]): InLinkBlock = {
  // Sorted, de-duplicated user IDs appearing in this block's ratings.
  val userIds = ratings.map(_.user).distinct.sorted
  val numUsers = userIds.length
  // Map each user ID to its position in the sorted userIds array.
  val userIdToPos = userIds.zipWithIndex.toMap
  // Split out our ratings by product block
  val blockRatings = Array.fill(numBlocks)(new ArrayBuffer[Rating])
  for (r <- ratings) {
    blockRatings(r.product % numBlocks) += r
  }
  val ratingsForBlock = new Array[Array[(Array[Int], Array[Double])]](numBlocks)
  for (productBlock <- 0 until numBlocks) {
    // Create an array of (product, Seq(Rating)) ratings
    val groupedRatings = blockRatings(productBlock).groupBy(_.product).toArray
    // Sort them by product ID (the first tuple element is the product ID;
    // the original comment said "user ID", which was incorrect)
    val ordering = new Ordering[(Int, ArrayBuffer[Rating])] {
      def compare(a: (Int, ArrayBuffer[Rating]), b: (Int, ArrayBuffer[Rating])): Int = a._1 - b._1
    }
    Sorting.quickSort(groupedRatings)(ordering)
    // Translate the user IDs to indices based on userIdToPos; .view avoids an
    // intermediate collection before the final toArray
    ratingsForBlock(productBlock) = groupedRatings.map { case (p, rs) =>
      (rs.view.map(r => userIdToPos(r.user)).toArray, rs.view.map(_.rating).toArray)
    }
  }
  InLinkBlock(userIds, ratingsForBlock)
}
...@@ -167,7 +184,7 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l ...@@ -167,7 +184,7 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
{
  val grouped = ratings.partitionBy(new HashPartitioner(numBlocks))
  val links = grouped.mapPartitionsWithIndex((blockId, elements) => {
    val ratings = elements.map{case (k, t) => Rating(t._1, t._2, t._3)}.toArray
    val inLinkBlock = makeInLinkBlock(numBlocks, ratings)
    val outLinkBlock = makeOutLinkBlock(numBlocks, ratings)
    Iterator.single((blockId, (inLinkBlock, outLinkBlock)))
...@@ -373,6 +390,8 @@ object ALS { ...@@ -373,6 +390,8 @@ object ALS {
    }
    val (master, ratingsFile, rank, iters, outputDir) =
      (args(0), args(1), args(2).toInt, args(3).toInt, args(4))
    System.setProperty("spark.serializer", "spark.KryoSerializer")
    System.setProperty("spark.locality.wait", "10000")
    val sc = new SparkContext(master, "ALS")
    val ratings = sc.textFile(ratingsFile).map { line =>
      val fields = line.split(',')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment