Skip to content
Snippets Groups Projects
Commit ded67ee9 authored by Tor Myklebust
Browse files

Bindings for linear, Lasso, and ridge regression.

parent 2a41c9aa
No related branches found
No related tags found
No related merge requests found
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.regression._
import org.apache.spark.rdd.RDD
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.DoubleBuffer
......@@ -38,14 +39,45 @@ class PythonMLLibAPI extends Serializable {
return bytes
}
def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]]):
java.util.List[java.lang.Object] = {
val data = dataBytesJRDD.rdd.map(x => deserializeDoubleVector(x))
.map(v => LabeledPoint(v(0), v.slice(1, v.length)))
val model = LinearRegressionWithSGD.train(data, 222)
/**
 * Shared driver behind the regression bindings: decodes the serialized
 * training set and starting weights, runs the supplied trainer, and packs
 * the fitted model into a Java list.
 *
 * Each record is decoded with deserializeDoubleVector; its first element is
 * passed to LabeledPoint as the label and the remaining elements as the
 * features.
 *
 * @param trainFunc        trainer invoked with the decoded points and the
 *                         decoded initial weights
 * @param dataBytesJRDD    serialized training records, one byte array each
 * @param initialWeightsBA serialized initial weight vector
 * @return a list holding the serialized model weights followed by the
 *         intercept boxed as a java.lang.Double
 */
def trainRegressionModel(trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel,
    dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]):
    java.util.LinkedList[java.lang.Object] = {
  val points = dataBytesJRDD.rdd.map { recordBytes =>
    val values = deserializeDoubleVector(recordBytes)
    LabeledPoint(values(0), values.drop(1))
  }
  val startWeights = deserializeDoubleVector(initialWeightsBA)
  val fitted = trainFunc(points, startWeights)

  // Weights first, intercept second: the order the list is filled in is the
  // order the receiver reads it back.
  val result = new java.util.LinkedList[java.lang.Object]()
  result.add(serializeDoubleVector(fitted.weights))
  result.add(fitted.intercept: java.lang.Double)
  result
}
/**
 * Trains a linear regression model with LinearRegressionWithSGD on
 * serialized input and returns the result via trainRegressionModel.
 *
 * @param dataBytesJRDD     serialized training records, one byte array each
 * @param numIterations     forwarded to LinearRegressionWithSGD.train
 * @param stepSize          forwarded to LinearRegressionWithSGD.train
 * @param miniBatchFraction forwarded to LinearRegressionWithSGD.train
 * @param initialWeightsBA  serialized initial weight vector
 * @return whatever trainRegressionModel produces for the fitted model
 */
def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]],
    numIterations: Int, stepSize: Double, miniBatchFraction: Double,
    initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
  // Close over the hyper-parameters so the shared driver only has to supply
  // the decoded data and starting weights.
  val fit: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel =
    (points, startWeights) =>
      LinearRegressionWithSGD.train(points, numIterations, stepSize,
        miniBatchFraction, startWeights)
  trainRegressionModel(fit, dataBytesJRDD, initialWeightsBA)
}
/**
 * Trains a Lasso model with LassoWithSGD on serialized input and returns
 * the result via trainRegressionModel.
 *
 * @param dataBytesJRDD     serialized training records, one byte array each
 * @param numIterations     forwarded to LassoWithSGD.train
 * @param stepSize          forwarded to LassoWithSGD.train
 * @param regParam          forwarded to LassoWithSGD.train
 * @param miniBatchFraction forwarded to LassoWithSGD.train
 * @param initialWeightsBA  serialized initial weight vector
 * @return whatever trainRegressionModel produces for the fitted model
 */
def trainLassoModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
    stepSize: Double, regParam: Double, miniBatchFraction: Double,
    initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
  // Close over the hyper-parameters so the shared driver only has to supply
  // the decoded data and starting weights.
  val fit: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel =
    (points, startWeights) =>
      LassoWithSGD.train(points, numIterations, stepSize, regParam,
        miniBatchFraction, startWeights)
  trainRegressionModel(fit, dataBytesJRDD, initialWeightsBA)
}
/**
 * Trains a ridge regression model with RidgeRegressionWithSGD on serialized
 * input and returns the result via trainRegressionModel.
 *
 * @param dataBytesJRDD     serialized training records, one byte array each
 * @param numIterations     forwarded to RidgeRegressionWithSGD.train
 * @param stepSize          forwarded to RidgeRegressionWithSGD.train
 * @param regParam          forwarded to RidgeRegressionWithSGD.train
 * @param miniBatchFraction forwarded to RidgeRegressionWithSGD.train
 * @param initialWeightsBA  serialized initial weight vector
 * @return whatever trainRegressionModel produces for the fitted model
 */
def trainRidgeModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
    stepSize: Double, regParam: Double, miniBatchFraction: Double,
    initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
  // Close over the hyper-parameters so the shared driver only has to supply
  // the decoded data and starting weights.
  val fit: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel =
    (points, startWeights) =>
      RidgeRegressionWithSGD.train(points, numIterations, stepSize, regParam,
        miniBatchFraction, startWeights)
  trainRegressionModel(fit, dataBytesJRDD, initialWeightsBA)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment