Skip to content
Snippets Groups Projects
Commit a2b47dae authored by Reynold Xin
Browse files

Merge pull request #499 from jianpingjwang/dev1

Replace commons-math with jblas in SVDPlusPlus
parents a1cd1851 19a01c1b
No related branches found
No related tags found
No related merge requests found
...@@ -38,15 +38,14 @@ ...@@ -38,15 +38,14 @@
<version>${project.version}</version> <version>${project.version}</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.apache.commons</groupId> <groupId>org.jblas</groupId>
<artifactId>commons-math3</artifactId> <artifactId>jblas</artifactId>
<version>3.2</version> <version>1.2.3</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.eclipse.jetty</groupId> <groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-server</artifactId> <artifactId>jetty-server</artifactId>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.scalatest</groupId> <groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.binary.version}</artifactId> <artifactId>scalatest_${scala.binary.version}</artifactId>
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
package org.apache.spark.graphx.lib package org.apache.spark.graphx.lib
import scala.util.Random import scala.util.Random
import org.apache.commons.math3.linear._ import org.jblas.DoubleMatrix
import org.apache.spark.rdd._ import org.apache.spark.rdd._
import org.apache.spark.graphx._ import org.apache.spark.graphx._
...@@ -52,15 +52,15 @@ object SVDPlusPlus { ...@@ -52,15 +52,15 @@ object SVDPlusPlus {
* @return a graph with vertex attributes containing the trained model * @return a graph with vertex attributes containing the trained model
*/ */
def run(edges: RDD[Edge[Double]], conf: Conf) def run(edges: RDD[Edge[Double]], conf: Conf)
: (Graph[(RealVector, RealVector, Double, Double), Double], Double) = : (Graph[(DoubleMatrix, DoubleMatrix, Double, Double), Double], Double) =
{ {
// Generate default vertex attribute // Generate default vertex attribute
def defaultF(rank: Int): (RealVector, RealVector, Double, Double) = { def defaultF(rank: Int): (DoubleMatrix, DoubleMatrix, Double, Double) = {
val v1 = new ArrayRealVector(rank) val v1 = new DoubleMatrix(rank)
val v2 = new ArrayRealVector(rank) val v2 = new DoubleMatrix(rank)
for (i <- 0 until rank) { for (i <- 0 until rank) {
v1.setEntry(i, Random.nextDouble()) v1.put(i, Random.nextDouble())
v2.setEntry(i, Random.nextDouble()) v2.put(i, Random.nextDouble())
} }
(v1, v2, 0.0, 0.0) (v1, v2, 0.0, 0.0)
} }
...@@ -76,31 +76,32 @@ object SVDPlusPlus { ...@@ -76,31 +76,32 @@ object SVDPlusPlus {
// Calculate initial bias and norm // Calculate initial bias and norm
val t0 = g.mapReduceTriplets( val t0 = g.mapReduceTriplets(
et => Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr))), et => Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr))),
(g1: (Long, Double), g2: (Long, Double)) => (g1._1 + g2._1, g1._2 + g2._2)) (g1: (Long, Double), g2: (Long, Double)) => (g1._1 + g2._1, g1._2 + g2._2))
g = g.outerJoinVertices(t0) { g = g.outerJoinVertices(t0) {
(vid: VertexId, vd: (RealVector, RealVector, Double, Double), msg: Option[(Long, Double)]) => (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double),
msg: Option[(Long, Double)]) =>
(vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1)) (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1))
} }
def mapTrainF(conf: Conf, u: Double) def mapTrainF(conf: Conf, u: Double)
(et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double]) (et: EdgeTriplet[(DoubleMatrix, DoubleMatrix, Double, Double), Double])
: Iterator[(VertexId, (RealVector, RealVector, Double))] = { : Iterator[(VertexId, (DoubleMatrix, DoubleMatrix, Double))] = {
val (usr, itm) = (et.srcAttr, et.dstAttr) val (usr, itm) = (et.srcAttr, et.dstAttr)
val (p, q) = (usr._1, itm._1) val (p, q) = (usr._1, itm._1)
var pred = u + usr._3 + itm._3 + q.dotProduct(usr._2) var pred = u + usr._3 + itm._3 + q.dot(usr._2)
pred = math.max(pred, conf.minVal) pred = math.max(pred, conf.minVal)
pred = math.min(pred, conf.maxVal) pred = math.min(pred, conf.maxVal)
val err = et.attr - pred val err = et.attr - pred
val updateP = q.mapMultiply(err) val updateP = q.mul(err)
.subtract(p.mapMultiply(conf.gamma7)) .subColumnVector(p.mul(conf.gamma7))
.mapMultiply(conf.gamma2) .mul(conf.gamma2)
val updateQ = usr._2.mapMultiply(err) val updateQ = usr._2.mul(err)
.subtract(q.mapMultiply(conf.gamma7)) .subColumnVector(q.mul(conf.gamma7))
.mapMultiply(conf.gamma2) .mul(conf.gamma2)
val updateY = q.mapMultiply(err * usr._4) val updateY = q.mul(err * usr._4)
.subtract(itm._2.mapMultiply(conf.gamma7)) .subColumnVector(itm._2.mul(conf.gamma7))
.mapMultiply(conf.gamma2) .mul(conf.gamma2)
Iterator((et.srcId, (updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1)), Iterator((et.srcId, (updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1)),
(et.dstId, (updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1))) (et.dstId, (updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1)))
} }
...@@ -110,34 +111,37 @@ object SVDPlusPlus { ...@@ -110,34 +111,37 @@ object SVDPlusPlus {
g.cache() g.cache()
val t1 = g.mapReduceTriplets( val t1 = g.mapReduceTriplets(
et => Iterator((et.srcId, et.dstAttr._2)), et => Iterator((et.srcId, et.dstAttr._2)),
(g1: RealVector, g2: RealVector) => g1.add(g2)) (g1: DoubleMatrix, g2: DoubleMatrix) => g1.addColumnVector(g2))
g = g.outerJoinVertices(t1) { g = g.outerJoinVertices(t1) {
(vid: VertexId, vd: (RealVector, RealVector, Double, Double), msg: Option[RealVector]) => (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double),
if (msg.isDefined) (vd._1, vd._1.add(msg.get.mapMultiply(vd._4)), vd._3, vd._4) else vd msg: Option[DoubleMatrix]) =>
if (msg.isDefined) (vd._1, vd._1
.addColumnVector(msg.get.mul(vd._4)), vd._3, vd._4) else vd
} }
// Phase 2, update p for user nodes and q, y for item nodes // Phase 2, update p for user nodes and q, y for item nodes
g.cache() g.cache()
val t2 = g.mapReduceTriplets( val t2 = g.mapReduceTriplets(
mapTrainF(conf, u), mapTrainF(conf, u),
(g1: (RealVector, RealVector, Double), g2: (RealVector, RealVector, Double)) => (g1: (DoubleMatrix, DoubleMatrix, Double), g2: (DoubleMatrix, DoubleMatrix, Double)) =>
(g1._1.add(g2._1), g1._2.add(g2._2), g1._3 + g2._3)) (g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3))
g = g.outerJoinVertices(t2) { g = g.outerJoinVertices(t2) {
(vid: VertexId, (vid: VertexId,
vd: (RealVector, RealVector, Double, Double), vd: (DoubleMatrix, DoubleMatrix, Double, Double),
msg: Option[(RealVector, RealVector, Double)]) => msg: Option[(DoubleMatrix, DoubleMatrix, Double)]) =>
(vd._1.add(msg.get._1), vd._2.add(msg.get._2), vd._3 + msg.get._3, vd._4) (vd._1.addColumnVector(msg.get._1), vd._2.addColumnVector(msg.get._2),
vd._3 + msg.get._3, vd._4)
} }
} }
// calculate error on training set // calculate error on training set
def mapTestF(conf: Conf, u: Double) def mapTestF(conf: Conf, u: Double)
(et: EdgeTriplet[(RealVector, RealVector, Double, Double), Double]) (et: EdgeTriplet[(DoubleMatrix, DoubleMatrix, Double, Double), Double])
: Iterator[(VertexId, Double)] = : Iterator[(VertexId, Double)] =
{ {
val (usr, itm) = (et.srcAttr, et.dstAttr) val (usr, itm) = (et.srcAttr, et.dstAttr)
val (p, q) = (usr._1, itm._1) val (p, q) = (usr._1, itm._1)
var pred = u + usr._3 + itm._3 + q.dotProduct(usr._2) var pred = u + usr._3 + itm._3 + q.dot(usr._2)
pred = math.max(pred, conf.minVal) pred = math.max(pred, conf.minVal)
pred = math.min(pred, conf.maxVal) pred = math.min(pred, conf.maxVal)
val err = (et.attr - pred) * (et.attr - pred) val err = (et.attr - pred) * (et.attr - pred)
...@@ -146,7 +150,7 @@ object SVDPlusPlus { ...@@ -146,7 +150,7 @@ object SVDPlusPlus {
g.cache() g.cache()
val t3 = g.mapReduceTriplets(mapTestF(conf, u), (g1: Double, g2: Double) => g1 + g2) val t3 = g.mapReduceTriplets(mapTestF(conf, u), (g1: Double, g2: Double) => g1 + g2)
g = g.outerJoinVertices(t3) { g = g.outerJoinVertices(t3) {
(vid: VertexId, vd: (RealVector, RealVector, Double, Double), msg: Option[Double]) => (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[Double]) =>
if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd
} }
......
...@@ -318,7 +318,7 @@ object SparkBuild extends Build { ...@@ -318,7 +318,7 @@ object SparkBuild extends Build {
def graphxSettings = sharedSettings ++ Seq( def graphxSettings = sharedSettings ++ Seq(
name := "spark-graphx", name := "spark-graphx",
libraryDependencies ++= Seq( libraryDependencies ++= Seq(
"org.apache.commons" % "commons-math3" % "3.2" "org.jblas" % "jblas" % "1.2.3"
) )
) )
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment