Skip to content
Snippets Groups Projects
Commit c823ee1e authored by Xinghao's avatar Xinghao
Browse files

Replace map-reduce with dot operator using DoubleMatrix

parent 96e04f4c
No related branches found
No related tags found
No related merge requests found
...@@ -26,7 +26,7 @@ object LassoGenerator { ...@@ -26,7 +26,7 @@ object LassoGenerator {
val sc = new SparkContext(sparkMaster, "LassoGenerator") val sc = new SparkContext(sparkMaster, "LassoGenerator")
val globalRnd = new Random(94720) val globalRnd = new Random(94720)
val trueWeights = Array.fill[Double](nfeatures + 1) { globalRnd.nextGaussian() } val trueWeights = new DoubleMatrix(1, nfeatures+1, Array.fill[Double](nfeatures + 1) { globalRnd.nextGaussian() }:_*)
val data: RDD[(Double, Array[Double])] = sc.parallelize(0 until nexamples, parts).map { idx => val data: RDD[(Double, Array[Double])] = sc.parallelize(0 until nexamples, parts).map { idx =>
val rnd = new Random(42 + idx) val rnd = new Random(42 + idx)
...@@ -34,7 +34,7 @@ object LassoGenerator { ...@@ -34,7 +34,7 @@ object LassoGenerator {
val x = Array.fill[Double](nfeatures) { val x = Array.fill[Double](nfeatures) {
rnd.nextDouble() * 2.0 - 1.0 rnd.nextDouble() * 2.0 - 1.0
} }
val y = ((1.0 +: x) zip trueWeights).map{wx => wx._1 * wx._2}.reduceLeft(_+_) + rnd.nextGaussian() * 0.1 val y = (new DoubleMatrix(1, x.length, x:_*)).dot(trueWeights) + rnd.nextGaussian() * 0.1
(y, x) (y, x)
} }
......
...@@ -8,6 +8,8 @@ import org.jblas.DoubleMatrix ...@@ -8,6 +8,8 @@ import org.jblas.DoubleMatrix
import spark.{RDD, SparkContext} import spark.{RDD, SparkContext}
import spark.mllib.util.MLUtils import spark.mllib.util.MLUtils
import org.jblas.DoubleMatrix
object SVMGenerator { object SVMGenerator {
def main(args: Array[String]) { def main(args: Array[String]) {
...@@ -27,7 +29,7 @@ object SVMGenerator { ...@@ -27,7 +29,7 @@ object SVMGenerator {
val sc = new SparkContext(sparkMaster, "SVMGenerator") val sc = new SparkContext(sparkMaster, "SVMGenerator")
val globalRnd = new Random(94720) val globalRnd = new Random(94720)
val trueWeights = Array.fill[Double](nfeatures + 1) { globalRnd.nextGaussian() } val trueWeights = new DoubleMatrix(1, nfeatures+1, Array.fill[Double](nfeatures + 1) { globalRnd.nextGaussian() }:_*)
val data: RDD[(Double, Array[Double])] = sc.parallelize(0 until nexamples, parts).map { idx => val data: RDD[(Double, Array[Double])] = sc.parallelize(0 until nexamples, parts).map { idx =>
val rnd = new Random(42 + idx) val rnd = new Random(42 + idx)
...@@ -35,7 +37,7 @@ object SVMGenerator { ...@@ -35,7 +37,7 @@ object SVMGenerator {
val x = Array.fill[Double](nfeatures) { val x = Array.fill[Double](nfeatures) {
rnd.nextDouble() * 2.0 - 1.0 rnd.nextDouble() * 2.0 - 1.0
} }
val y = signum(((1.0 +: x) zip trueWeights).map{wx => wx._1 * wx._2}.reduceLeft(_+_) + rnd.nextGaussian() * 0.1) val y = signum((new DoubleMatrix(1, x.length, x:_*)).dot(trueWeights) + rnd.nextGaussian() * 0.1)
(y, x) (y, x)
} }
......
...@@ -25,6 +25,8 @@ import org.scalatest.FunSuite ...@@ -25,6 +25,8 @@ import org.scalatest.FunSuite
import spark.SparkContext import spark.SparkContext
import org.jblas.DoubleMatrix
class SVMSuite extends FunSuite with BeforeAndAfterAll { class SVMSuite extends FunSuite with BeforeAndAfterAll {
val sc = new SparkContext("local", "test") val sc = new SparkContext("local", "test")
...@@ -38,16 +40,17 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll { ...@@ -38,16 +40,17 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
intercept: Double, intercept: Double,
weights: Array[Double], weights: Array[Double],
nPoints: Int, nPoints: Int,
seed: Int): Seq[(Double, Array[Double])] = { seed: Int): Seq[(Int, Array[Double])] = {
val rnd = new Random(seed) val rnd = new Random(seed)
val weightsMat = new DoubleMatrix(1, weights.length, weights:_*)
val x = Array.fill[Array[Double]](nPoints)(Array.fill[Double](weights.length)(rnd.nextGaussian())) val x = Array.fill[Array[Double]](nPoints)(Array.fill[Double](weights.length)(rnd.nextGaussian()))
val y = x.map(xi => val y = x.map(xi =>
signum((xi zip weights).map(xw => xw._1*xw._2).reduce(_+_) + intercept + 0.1 * rnd.nextGaussian()) signum((new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) + intercept + 0.1 * rnd.nextGaussian()).toInt
) )
y zip x y zip x
} }
def validatePrediction(predictions: Seq[Double], input: Seq[(Double, Array[Double])]) { def validatePrediction(predictions: Seq[Int], input: Seq[(Int, Array[Double])]) {
val numOffPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) => val numOffPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) =>
// A prediction is off if the prediction is more than 0.5 away from expected value. // A prediction is off if the prediction is more than 0.5 away from expected value.
math.abs(prediction - expected) > 0.5 math.abs(prediction - expected) > 0.5
......
...@@ -24,6 +24,8 @@ import org.scalatest.FunSuite ...@@ -24,6 +24,8 @@ import org.scalatest.FunSuite
import spark.SparkContext import spark.SparkContext
import org.jblas.DoubleMatrix
class LassoSuite extends FunSuite with BeforeAndAfterAll { class LassoSuite extends FunSuite with BeforeAndAfterAll {
val sc = new SparkContext("local", "test") val sc = new SparkContext("local", "test")
...@@ -40,8 +42,11 @@ class LassoSuite extends FunSuite with BeforeAndAfterAll { ...@@ -40,8 +42,11 @@ class LassoSuite extends FunSuite with BeforeAndAfterAll {
nPoints: Int, nPoints: Int,
seed: Int): Seq[(Double, Array[Double])] = { seed: Int): Seq[(Double, Array[Double])] = {
val rnd = new Random(seed) val rnd = new Random(seed)
val weightsMat = new DoubleMatrix(1, weights.length, weights:_*)
val x = Array.fill[Array[Double]](nPoints)(Array.fill[Double](weights.length)(rnd.nextGaussian())) val x = Array.fill[Array[Double]](nPoints)(Array.fill[Double](weights.length)(rnd.nextGaussian()))
val y = x.map(xi => (xi zip weights).map(xw => xw._1*xw._2).reduce(_+_) + intercept + 0.1 * rnd.nextGaussian()) val y = x.map(xi =>
(new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) + intercept + 0.1 * rnd.nextGaussian()
)
y zip x y zip x
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment