Commit f00e949f authored by Matei Zaharia

Added Java unit test, data, and main method for Naive Bayes

Also fixes the main() methods of a few other algorithms to print the final model
parent 4c28a2ba
0, 1 0 0
0, 2 0 0
1, 0 1 0
1, 0 2 0
2, 0 0 1
2, 0 0 2
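
The six lines above are the new sample Naive Bayes data file: one labeled point per line in the format "label, f1 f2 f3", with three classes (0, 1, 2) and each class putting its feature mass on a single distinct coordinate. As a minimal sketch, this is roughly how a loader such as MLUtils.loadLabeledData can be expected to parse the format (an illustration of the format, not the actual MLUtils code):

import org.apache.spark.mllib.regression.LabeledPoint

// Parse one "label, f1 f2 f3" line into a LabeledPoint.
def parseLine(line: String): LabeledPoint = {
  val parts = line.split(',')
  val label = parts(0).toDouble
  val features = parts(1).trim.split(' ').map(_.toDouble)
  LabeledPoint(label, features)
}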
@@ -97,7 +97,7 @@ object LogisticRegressionWithSGD {
    * @param numIterations Number of iterations of gradient descent to run.
    * @param stepSize Step size to be used for each iteration of gradient descent.
    * @param miniBatchFraction Fraction of data to be used per iteration.
    * @param initialWeights Initial set of weights to be used. Array should be equal in size to
    *        the number of features in the data.
    */
   def train(
@@ -183,6 +183,8 @@ object LogisticRegressionWithSGD {
     val sc = new SparkContext(args(0), "LogisticRegression")
     val data = MLUtils.loadLabeledData(sc, args(1))
     val model = LogisticRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble)
+    println("Weights: " + model.weights.mkString("[", ", ", "]"))
+    println("Intercept: " + model.intercept)
     sc.stop()
   }
...
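
The new println calls format the weight vector with Scala's mkString. As a quick illustration (hypothetical weight values, not output from this commit):

val weights = Array(0.81, -0.24)  // hypothetical trained weights
println("Weights: " + weights.mkString("[", ", ", "]"))
// prints: Weights: [0.81, -0.24]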
@@ -21,9 +21,10 @@ import scala.collection.mutable
 import org.jblas.DoubleMatrix

-import org.apache.spark.Logging
+import org.apache.spark.{SparkContext, Logging}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.util.MLUtils

 /**
  * Model for Naive Bayes Classifiers.
@@ -144,4 +145,22 @@ object NaiveBayes {
   def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
     new NaiveBayes(lambda).run(input)
   }
+
+  def main(args: Array[String]) {
+    if (args.length != 2 && args.length != 3) {
+      println("Usage: NaiveBayes <master> <input_dir> [<lambda>]")
+      System.exit(1)
+    }
+    val sc = new SparkContext(args(0), "NaiveBayes")
+    val data = MLUtils.loadLabeledData(sc, args(1))
+    val model = if (args.length == 2) {
+      NaiveBayes.train(data)
+    } else {
+      NaiveBayes.train(data, args(2).toDouble)
+    }
+    println("Pi: " + model.pi.mkString("[", ", ", "]"))
+    println("Theta:\n" + model.theta.map(_.mkString("[", ", ", "]")).mkString("[", "\n ", "]"))
+    sc.stop()
+  }
 }
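
For context, the pi and theta printed by the new NaiveBayes main are the model's log class priors and log conditional feature probabilities. A minimal sketch of the multinomial Naive Bayes decision rule such a model applies (an illustration of the technique, not the MLlib implementation):

// Score each class k as log-prior plus the log-likelihood contribution,
// i.e. pi(k) + sum_i theta(k)(i) * x(i), then predict the highest-scoring class.
def predictSketch(pi: Array[Double], theta: Array[Array[Double]], x: Array[Double]): Int = {
  val scores = pi.indices.map { k =>
    pi(k) + theta(k).zip(x).map { case (t, xi) => t * xi }.sum
  }
  scores.indexOf(scores.max)
}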
@@ -183,6 +183,8 @@ object SVMWithSGD {
     val sc = new SparkContext(args(0), "SVM")
     val data = MLUtils.loadLabeledData(sc, args(1))
     val model = SVMWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
+    println("Weights: " + model.weights.mkString("[", ", ", "]"))
+    println("Intercept: " + model.intercept)
     sc.stop()
   }
...
@@ -121,7 +121,7 @@ object LassoWithSGD {
    * @param stepSize Step size to be used for each iteration of gradient descent.
    * @param regParam Regularization parameter.
    * @param miniBatchFraction Fraction of data to be used per iteration.
    * @param initialWeights Initial set of weights to be used. Array should be equal in size to
    *        the number of features in the data.
    */
   def train(
@@ -205,6 +205,8 @@ object LassoWithSGD {
     val sc = new SparkContext(args(0), "Lasso")
     val data = MLUtils.loadLabeledData(sc, args(1))
     val model = LassoWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
+    println("Weights: " + model.weights.mkString("[", ", ", "]"))
+    println("Intercept: " + model.intercept)
     sc.stop()
   }
...
@@ -162,6 +162,8 @@ object LinearRegressionWithSGD {
     val sc = new SparkContext(args(0), "LinearRegression")
     val data = MLUtils.loadLabeledData(sc, args(1))
     val model = LinearRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble)
+    println("Weights: " + model.weights.mkString("[", ", ", "]"))
+    println("Intercept: " + model.intercept)
     sc.stop()
   }
...
@@ -122,7 +122,7 @@ object RidgeRegressionWithSGD {
    * @param stepSize Step size to be used for each iteration of gradient descent.
    * @param regParam Regularization parameter.
    * @param miniBatchFraction Fraction of data to be used per iteration.
    * @param initialWeights Initial set of weights to be used. Array should be equal in size to
    *        the number of features in the data.
    */
   def train(
@@ -208,6 +208,8 @@ object RidgeRegressionWithSGD {
     val data = MLUtils.loadLabeledData(sc, args(1))
     val model = RidgeRegressionWithSGD.train(data, args(4).toInt, args(2).toDouble,
       args(3).toDouble)
+    println("Weights: " + model.weights.mkString("[", ", ", "]"))
+    println("Intercept: " + model.intercept)
     sc.stop()
   }
...
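
All of these main methods read positional arguments in the same pattern; for the regularized models the mapping, reconstructed from the train(...) calls above, is:

// Hypothetical argument vector for RidgeRegressionWithSGD.main (values assumed):
val args = Array("local", "data/ridge.txt", "1.0", "0.1", "100")
// args(0) = master, args(1) = input directory, args(2) = step size,
// args(3) = regularization parameter, args(4) = number of iterations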
package org.apache.spark.mllib.classification;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

public class JavaNaiveBayesSuite implements Serializable {
  private transient JavaSparkContext sc;

  @Before
  public void setUp() {
    sc = new JavaSparkContext("local", "JavaNaiveBayesSuite");
  }

  @After
  public void tearDown() {
    sc.stop();
    sc = null;
    System.clearProperty("spark.driver.port");
  }

  // Six points in three classes; each class puts all of its feature mass on a
  // single distinct coordinate, so the classes are trivially separable.
  private static final List<LabeledPoint> POINTS = Arrays.asList(
    new LabeledPoint(0, new double[] {1.0, 0.0, 0.0}),
    new LabeledPoint(0, new double[] {2.0, 0.0, 0.0}),
    new LabeledPoint(1, new double[] {0.0, 1.0, 0.0}),
    new LabeledPoint(1, new double[] {0.0, 2.0, 0.0}),
    new LabeledPoint(2, new double[] {0.0, 0.0, 1.0}),
    new LabeledPoint(2, new double[] {0.0, 0.0, 2.0})
  );

  // Returns how many of the given points the model labels correctly.
  private int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) {
    int correct = 0;
    for (LabeledPoint p: points) {
      if (model.predict(p.features()) == p.label()) {
        correct += 1;
      }
    }
    return correct;
  }

  @Test
  public void runUsingConstructor() {
    JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();

    NaiveBayes nb = new NaiveBayes().setLambda(1.0);
    NaiveBayesModel model = nb.run(testRDD.rdd());

    int numAccurate = validatePrediction(POINTS, model);
    Assert.assertEquals(POINTS.size(), numAccurate);
  }

  @Test
  public void runUsingStaticMethods() {
    JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();

    NaiveBayesModel model1 = NaiveBayes.train(testRDD.rdd());
    int numAccurate1 = validatePrediction(POINTS, model1);
    Assert.assertEquals(POINTS.size(), numAccurate1);

    NaiveBayesModel model2 = NaiveBayes.train(testRDD.rdd(), 0.5);
    int numAccurate2 = validatePrediction(POINTS, model2);
    Assert.assertEquals(POINTS.size(), numAccurate2);
  }
}
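
The asserts expect every training point to be classified correctly, which holds because each class concentrates its feature mass on its own coordinate. A back-of-the-envelope check in Scala for the point with label 0 and features [1, 0, 0] under lambda = 1.0 (standard Laplace smoothing; the implementation's exact smoothing may differ slightly):

// Class 0 contributed 3 units of mass to feature 0 (from [1,0,0] and [2,0,0]).
val theta00 = math.log((3.0 + 1.0) / (3.0 + 3.0)) // log P(feature 0 | class 0) ≈ -0.41
val theta10 = math.log((0.0 + 1.0) / (3.0 + 3.0)) // log P(feature 0 | class 1) ≈ -1.79
// With equal class priors, class 0's score for [1, 0, 0] dominates the other
// classes, so predict() returns 0 and validatePrediction counts the point as correct.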