Skip to content
Snippets Groups Projects
Commit 5d473f05 authored by Matei Zaharia's avatar Matei Zaharia
Browse files

Merge pull request #376 from MLnick/python-als

Python ALS example
parents 922c5ec0 a5ba7a9f
No related branches found
No related tags found
No related merge requests found
......@@ -7,6 +7,7 @@ import cern.jet.math._
import cern.colt.matrix._
import cern.colt.matrix.linalg._
import spark._
import scala.Option
object SparkALS {
// Parameters set through command line arguments
......@@ -42,7 +43,7 @@ object SparkALS {
return sqrt(sumSqs / (M * U))
}
def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D],
def update(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D],
R: DoubleMatrix2D) : DoubleMatrix1D =
{
val U = us.size
......@@ -68,50 +69,30 @@ object SparkALS {
return solved2D.viewColumn(0)
}
def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D],
R: DoubleMatrix2D) : DoubleMatrix1D =
{
val M = ms.size
val F = ms(0).size
val XtX = factory2D.make(F, F)
val Xty = factory1D.make(F)
// For each movie that the user rated
for (i <- 0 until M) {
val m = ms(i)
// Add m * m^t to XtX
blas.dger(1, m, m, XtX)
// Add m * rating to Xty
blas.daxpy(R.get(i, j), m, Xty)
}
// Add regularization coefs to diagonal terms
for (d <- 0 until F) {
XtX.set(d, d, XtX.get(d, d) + LAMBDA * M)
}
// Solve it with Cholesky
val ch = new CholeskyDecomposition(XtX)
val Xty2D = factory2D.make(Xty.toArray, F)
val solved2D = ch.solve(Xty2D)
return solved2D.viewColumn(0)
}
def main(args: Array[String]) {
var host = ""
var slices = 0
args match {
case Array(m, u, f, iters, slices_, host_) => {
M = m.toInt
U = u.toInt
F = f.toInt
ITERATIONS = iters.toInt
slices = slices_.toInt
host = host_
(0 to 5).map(i => {
i match {
case a if a < args.length => Option(args(a))
case _ => Option(null)
}
}).toArray match {
case Array(host_, m, u, f, iters, slices_) => {
host = host_ getOrElse "local"
M = (m getOrElse "100").toInt
U = (u getOrElse "500").toInt
F = (f getOrElse "10").toInt
ITERATIONS = (iters getOrElse "5").toInt
slices = (slices_ getOrElse "2").toInt
}
case _ => {
System.err.println("Usage: SparkALS <M> <U> <F> <iters> <slices> <master>")
System.err.println("Usage: SparkALS [<master> <M> <U> <F> <iters> <slices>]")
System.exit(1)
}
}
printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS);
printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS)
val spark = new SparkContext(host, "SparkALS")
val R = generateR()
......@@ -127,11 +108,11 @@ object SparkALS {
for (iter <- 1 to ITERATIONS) {
println("Iteration " + iter + ":")
ms = spark.parallelize(0 until M, slices)
.map(i => updateMovie(i, msc.value(i), usc.value, Rc.value))
.map(i => update(i, msc.value(i), usc.value, Rc.value))
.toArray
msc = spark.broadcast(ms) // Re-broadcast ms because it was updated
us = spark.parallelize(0 until U, slices)
.map(i => updateUser(i, usc.value(i), msc.value, Rc.value))
.map(i => update(i, usc.value(i), msc.value, algebra.transpose(Rc.value)))
.toArray
usc = spark.broadcast(us) // Re-broadcast us because it was updated
println("RMSE = " + rmse(R, ms, us))
......
"""
This example requires numpy (http://www.numpy.org/)
"""
from os.path import realpath
import sys
import numpy as np
from numpy.random import rand
from numpy import matrix
from pyspark import SparkContext
LAMBDA = 0.01 # regularization
np.random.seed(42)
def rmse(R, ms, us):
diff = R - ms * us.T
return np.sqrt(np.sum(np.power(diff, 2)) / M * U)
def update(i, vec, mat, ratings):
uu = mat.shape[0]
ff = mat.shape[1]
XtX = matrix(np.zeros((ff, ff)))
Xty = np.zeros((ff, 1))
for j in range(uu):
v = mat[j, :]
XtX += v.T * v
Xty += v.T * ratings[i, j]
XtX += np.eye(ff, ff) * LAMBDA * uu
return np.linalg.solve(XtX, Xty)
if __name__ == "__main__":
if len(sys.argv) < 2:
print >> sys.stderr, \
"Usage: PythonALS <master> <M> <U> <F> <iters> <slices>"
exit(-1)
sc = SparkContext(sys.argv[1], "PythonALS", pyFiles=[realpath(__file__)])
M = int(sys.argv[2]) if len(sys.argv) > 2 else 100
U = int(sys.argv[3]) if len(sys.argv) > 3 else 500
F = int(sys.argv[4]) if len(sys.argv) > 4 else 10
ITERATIONS = int(sys.argv[5]) if len(sys.argv) > 5 else 5
slices = int(sys.argv[6]) if len(sys.argv) > 6 else 2
print "Running ALS with M=%d, U=%d, F=%d, iters=%d, slices=%d\n" % \
(M, U, F, ITERATIONS, slices)
R = matrix(rand(M, F)) * matrix(rand(U, F).T)
ms = matrix(rand(M ,F))
us = matrix(rand(U, F))
Rb = sc.broadcast(R)
msb = sc.broadcast(ms)
usb = sc.broadcast(us)
for i in range(ITERATIONS):
ms = sc.parallelize(range(M), slices) \
.map(lambda x: update(x, msb.value[x, :], usb.value, Rb.value)) \
.collect()
ms = matrix(np.array(ms)[:, :, 0]) # collect() returns a list, so array ends up being
# a 3-d array, we take the first 2 dims for the matrix
msb = sc.broadcast(ms)
us = sc.parallelize(range(U), slices) \
.map(lambda x: update(x, usb.value[x, :], msb.value, Rb.value.T)) \
.collect()
us = matrix(np.array(us)[:, :, 0])
usb = sc.broadcast(us)
error = rmse(R, ms, us)
print "Iteration %d:" % i
print "\nRMSE: %5.4f\n" % error
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment