Commit 16a2be1a authored by Xiangrui Meng, committed by DB Tsai

[SPARK-10231] [MLLIB] update @Since annotation for mllib.classification

Update `Since` annotation in `mllib.classification`:

1. add version to classes, objects, constructors, and public variables declared in constructors
2. correct some versions
3. remove `Since` on `toString`
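
As a minimal sketch of the resulting pattern (the class below is hypothetical and only illustrates the convention, it is not part of this patch): the class carries the version it first appeared in, the primary constructor and each public constructor parameter carry their own `@Since` versions, and auxiliary constructors are annotated separately.

```scala
import org.apache.spark.annotation.Since
import org.apache.spark.mllib.linalg.Vector

// Hypothetical model class, shown only to illustrate the annotation pattern.
@Since("0.8.0")                          // version in which the class first appeared
class ExampleModel @Since("1.3.0") (     // version of this (widest) primary constructor
    @Since("1.0.0") val weights: Vector, // public vals in the constructor get versions too
    @Since("1.3.0") val numClasses: Int) {

  @Since("1.0.0")                        // auxiliary constructors are annotated separately
  def this(weights: Vector) = this(weights, 2)
}
```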

cc: MechCoder dbtsai

Author: Xiangrui Meng <meng@databricks.com>

Closes #8421 from mengxr/SPARK-10231 and squashes the following commits:

b2dce80 [Xiangrui Meng] update @Since annotation for mllib.classification
parent 881208a8
@@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD
  * belongs. The categories are represented by double values: 0.0, 1.0, 2.0, etc.
  */
 @Experimental
+@Since("0.8.0")
 trait ClassificationModel extends Serializable {
   /**
    * Predict values for the given data set using the model trained.
@@ -37,7 +38,7 @@ trait ClassificationModel extends Serializable {
    * @param testData RDD representing data points to be predicted
    * @return an RDD[Double] where each entry contains the corresponding prediction
    */
-  @Since("0.8.0")
+  @Since("1.0.0")
   def predict(testData: RDD[Vector]): RDD[Double]

   /**
@@ -46,7 +47,7 @@ trait ClassificationModel extends Serializable {
    * @param testData array representing a single data point
    * @return predicted category from the trained model
    */
-  @Since("0.8.0")
+  @Since("1.0.0")
   def predict(testData: Vector): Double

   /**
@@ -54,7 +55,7 @@ trait ClassificationModel extends Serializable {
    * @param testData JavaRDD representing data points to be predicted
    * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction
    */
-  @Since("0.8.0")
+  @Since("1.0.0")
   def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] =
     predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
 }
...
@@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD
  * Multinomial Logistic Regression. By default, it is binary logistic regression
  * so numClasses will be set to 2.
  */
-class LogisticRegressionModel (
-    override val weights: Vector,
-    override val intercept: Double,
-    val numFeatures: Int,
-    val numClasses: Int)
+@Since("0.8.0")
+class LogisticRegressionModel @Since("1.3.0") (
+    @Since("1.0.0") override val weights: Vector,
+    @Since("1.0.0") override val intercept: Double,
+    @Since("1.3.0") val numFeatures: Int,
+    @Since("1.3.0") val numClasses: Int)
   extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
   with Saveable with PMMLExportable {
@@ -75,6 +76,7 @@ class LogisticRegressionModel (
   /**
    * Constructs a [[LogisticRegressionModel]] with weights and intercept for binary classification.
    */
+  @Since("1.0.0")
   def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2)

   private var threshold: Option[Double] = Some(0.5)
@@ -166,12 +168,12 @@ class LogisticRegressionModel (
   override protected def formatVersion: String = "1.0"

-  @Since("1.4.0")
   override def toString: String = {
     s"${super.toString}, numClasses = ${numClasses}, threshold = ${threshold.getOrElse("None")}"
   }
 }

+@Since("1.3.0")
 object LogisticRegressionModel extends Loader[LogisticRegressionModel] {

   @Since("1.3.0")
@@ -207,6 +209,7 @@ object LogisticRegressionModel extends Loader[LogisticRegressionModel] {
  * for k classes multi-label classification problem.
  * Using [[LogisticRegressionWithLBFGS]] is recommended over this.
  */
+@Since("0.8.0")
 class LogisticRegressionWithSGD private[mllib] (
     private var stepSize: Double,
     private var numIterations: Int,
@@ -216,6 +219,7 @@ class LogisticRegressionWithSGD private[mllib] (
   private val gradient = new LogisticGradient()
   private val updater = new SquaredL2Updater()

+  @Since("0.8.0")
   override val optimizer = new GradientDescent(gradient, updater)
     .setStepSize(stepSize)
     .setNumIterations(numIterations)
@@ -227,6 +231,7 @@ class LogisticRegressionWithSGD private[mllib] (
    * Construct a LogisticRegression object with default parameters: {stepSize: 1.0,
    * numIterations: 100, regParm: 0.01, miniBatchFraction: 1.0}.
    */
+  @Since("0.8.0")
   def this() = this(1.0, 100, 0.01, 1.0)

   override protected[mllib] def createModel(weights: Vector, intercept: Double) = {
@@ -238,6 +243,7 @@ class LogisticRegressionWithSGD private[mllib] (
  * Top-level methods for calling Logistic Regression using Stochastic Gradient Descent.
  * NOTE: Labels used in Logistic Regression should be {0, 1}
  */
+@Since("0.8.0")
 object LogisticRegressionWithSGD {
   // NOTE(shivaram): We use multiple train methods instead of default arguments to support
   // Java programs.
@@ -333,11 +339,13 @@ object LogisticRegressionWithSGD {
  * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1}
  * for k classes multi-label classification problem.
  */
+@Since("1.1.0")
 class LogisticRegressionWithLBFGS
   extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable {

   this.setFeatureScaling(true)

+  @Since("1.1.0")
   override val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater)

   override protected val validators = List(multiLabelValidator)
...
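
For context, a brief usage sketch of the logistic regression entry points annotated above; the `training` and `test` RDDs here are placeholders, not part of this patch.

```scala
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// Assumes existing RDDs `training` and `test` of LabeledPoint.
def example(training: RDD[LabeledPoint], test: RDD[LabeledPoint]): Unit = {
  // LogisticRegressionWithLBFGS is @Since("1.1.0"); setNumClasses enables multinomial LR.
  val model = new LogisticRegressionWithLBFGS()
    .setNumClasses(10)
    .run(training)
  // predict(RDD[Vector]) is the overload re-versioned to @Since("1.0.0") above.
  val predictions = model.predict(test.map(_.features))
  predictions.take(5).foreach(println)
}
```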
@@ -41,11 +41,12 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
  * where D is number of features
  * @param modelType The type of NB model to fit can be "multinomial" or "bernoulli"
  */
+@Since("0.9.0")
 class NaiveBayesModel private[spark] (
-    val labels: Array[Double],
-    val pi: Array[Double],
-    val theta: Array[Array[Double]],
-    val modelType: String)
+    @Since("1.0.0") val labels: Array[Double],
+    @Since("0.9.0") val pi: Array[Double],
+    @Since("0.9.0") val theta: Array[Array[Double]],
+    @Since("1.4.0") val modelType: String)
   extends ClassificationModel with Serializable with Saveable {

   import NaiveBayes.{Bernoulli, Multinomial, supportedModelTypes}
@@ -83,6 +84,7 @@ class NaiveBayesModel private[spark] (
       throw new UnknownError(s"Invalid modelType: $modelType.")
   }

+  @Since("1.0.0")
   override def predict(testData: RDD[Vector]): RDD[Double] = {
     val bcModel = testData.context.broadcast(this)
     testData.mapPartitions { iter =>
@@ -91,6 +93,7 @@ class NaiveBayesModel private[spark] (
     }
   }

+  @Since("1.0.0")
   override def predict(testData: Vector): Double = {
     modelType match {
       case Multinomial =>
@@ -107,6 +110,7 @@ class NaiveBayesModel private[spark] (
    * @return an RDD[Vector] where each entry contains the predicted posterior class probabilities,
    *         in the same order as class labels
    */
+  @Since("1.5.0")
   def predictProbabilities(testData: RDD[Vector]): RDD[Vector] = {
     val bcModel = testData.context.broadcast(this)
     testData.mapPartitions { iter =>
@@ -122,6 +126,7 @@ class NaiveBayesModel private[spark] (
    * @return predicted posterior class probabilities from the trained model,
    *         in the same order as class labels
    */
+  @Since("1.5.0")
   def predictProbabilities(testData: Vector): Vector = {
     modelType match {
       case Multinomial =>
@@ -158,6 +163,7 @@ class NaiveBayesModel private[spark] (
     new DenseVector(scaledProbs.map(_ / probSum))
   }

+  @Since("1.3.0")
   override def save(sc: SparkContext, path: String): Unit = {
     val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType)
     NaiveBayesModel.SaveLoadV2_0.save(sc, path, data)
@@ -166,6 +172,7 @@ class NaiveBayesModel private[spark] (
   override protected def formatVersion: String = "2.0"
 }

+@Since("1.3.0")
 object NaiveBayesModel extends Loader[NaiveBayesModel] {

   import org.apache.spark.mllib.util.Loader._
@@ -199,6 +206,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
     dataRDD.write.parquet(dataPath(path))
   }

+  @Since("1.3.0")
   def load(sc: SparkContext, path: String): NaiveBayesModel = {
     val sqlContext = new SQLContext(sc)
     // Load Parquet data.
@@ -301,30 +309,35 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
  * document classification. By making every vector a 0-1 vector, it can also be used as
  * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative.
  */
+@Since("0.9.0")
 class NaiveBayes private (
     private var lambda: Double,
     private var modelType: String) extends Serializable with Logging {

   import NaiveBayes.{Bernoulli, Multinomial}

+  @Since("1.4.0")
   def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial)

+  @Since("0.9.0")
   def this() = this(1.0, NaiveBayes.Multinomial)

   /** Set the smoothing parameter. Default: 1.0. */
+  @Since("0.9.0")
   def setLambda(lambda: Double): NaiveBayes = {
     this.lambda = lambda
     this
   }

   /** Get the smoothing parameter. */
+  @Since("1.4.0")
   def getLambda: Double = lambda

   /**
    * Set the model type using a string (case-sensitive).
    * Supported options: "multinomial" (default) and "bernoulli".
    */
+  @Since("1.4.0")
   def setModelType(modelType: String): NaiveBayes = {
     require(NaiveBayes.supportedModelTypes.contains(modelType),
       s"NaiveBayes was created with an unknown modelType: $modelType.")
@@ -333,6 +346,7 @@ class NaiveBayes private (
   }

   /** Get the model type. */
+  @Since("1.4.0")
   def getModelType: String = this.modelType

   /**
@@ -340,6 +354,7 @@ class NaiveBayes private (
    *
    * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    */
+  @Since("0.9.0")
   def run(data: RDD[LabeledPoint]): NaiveBayesModel = {
     val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
       val values = v match {
@@ -423,6 +438,7 @@ class NaiveBayes private (
 /**
  * Top-level methods for calling naive Bayes.
  */
+@Since("0.9.0")
 object NaiveBayes {

   /** String name for multinomial model type. */
@@ -485,7 +501,7 @@ object NaiveBayes {
    * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be
    *                  multinomial or bernoulli
    */
-  @Since("0.9.0")
+  @Since("1.4.0")
   def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
     require(supportedModelTypes.contains(modelType),
       s"NaiveBayes was created with an unknown modelType: $modelType.")
...
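
A brief usage sketch of the naive Bayes entry points above; the SparkContext, input data, and save path are placeholders, not part of this patch.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// Assumes an existing SparkContext `sc` and an RDD of labeled points.
def example(sc: SparkContext, data: RDD[LabeledPoint]): Unit = {
  // train(input, lambda, modelType) is the overload corrected to @Since("1.4.0") above.
  val model = NaiveBayes.train(data, lambda = 1.0, modelType = "multinomial")
  model.save(sc, "target/tmp/nbModel")                           // @Since("1.3.0")
  val sameModel = NaiveBayesModel.load(sc, "target/tmp/nbModel") // @Since("1.3.0")
  // predictProbabilities is new in 1.5.0 per the annotations above.
  val probs = sameModel.predictProbabilities(data.map(_.features))
}
```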
@@ -33,9 +33,10 @@ import org.apache.spark.rdd.RDD
  * @param weights Weights computed for every feature.
  * @param intercept Intercept computed for this model.
  */
-class SVMModel (
-    override val weights: Vector,
-    override val intercept: Double)
+@Since("0.8.0")
+class SVMModel @Since("1.1.0") (
+    @Since("1.0.0") override val weights: Vector,
+    @Since("0.8.0") override val intercept: Double)
   extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable
   with Saveable with PMMLExportable {
@@ -47,7 +48,7 @@ class SVMModel (
    * with prediction score greater than or equal to this threshold is identified as an positive,
    * and negative otherwise. The default value is 0.0.
    */
-  @Since("1.3.0")
+  @Since("1.0.0")
   @Experimental
   def setThreshold(threshold: Double): this.type = {
     this.threshold = Some(threshold)
@@ -92,12 +93,12 @@ class SVMModel (
   override protected def formatVersion: String = "1.0"

-  @Since("1.4.0")
   override def toString: String = {
     s"${super.toString}, numClasses = 2, threshold = ${threshold.getOrElse("None")}"
   }
 }

+@Since("1.3.0")
 object SVMModel extends Loader[SVMModel] {

   @Since("1.3.0")
@@ -132,6 +133,7 @@ object SVMModel extends Loader[SVMModel] {
  * regularization is used, which can be changed via [[SVMWithSGD.optimizer]].
  * NOTE: Labels used in SVM should be {0, 1}.
  */
+@Since("0.8.0")
 class SVMWithSGD private (
     private var stepSize: Double,
     private var numIterations: Int,
@@ -141,6 +143,7 @@ class SVMWithSGD private (
   private val gradient = new HingeGradient()
   private val updater = new SquaredL2Updater()

+  @Since("0.8.0")
   override val optimizer = new GradientDescent(gradient, updater)
     .setStepSize(stepSize)
     .setNumIterations(numIterations)
@@ -152,6 +155,7 @@ class SVMWithSGD private (
    * Construct a SVM object with default parameters: {stepSize: 1.0, numIterations: 100,
    * regParm: 0.01, miniBatchFraction: 1.0}.
    */
+  @Since("0.8.0")
   def this() = this(1.0, 100, 0.01, 1.0)

   override protected def createModel(weights: Vector, intercept: Double) = {
@@ -162,6 +166,7 @@ class SVMWithSGD private (
 /**
  * Top-level methods for calling SVM. NOTE: Labels used in SVM should be {0, 1}.
  */
+@Since("0.8.0")
 object SVMWithSGD {

   /**
...
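
A usage sketch for the SVM entry points above; the `training` and `test` RDDs are placeholders.

```scala
import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

def example(training: RDD[LabeledPoint], test: RDD[LabeledPoint]): Unit = {
  // SVMWithSGD.train is one of the @Since("0.8.0")-era top-level methods.
  val model = SVMWithSGD.train(training, numIterations = 100)
  // clearThreshold makes predict return raw margins instead of 0/1 labels.
  model.clearThreshold()
  val scoreAndLabel = test.map(p => (model.predict(p.features), p.label))
}
```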
@@ -17,7 +17,7 @@
 package org.apache.spark.mllib.classification

-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.StreamingLinearAlgorithm
@@ -44,6 +44,7 @@ import org.apache.spark.mllib.regression.StreamingLinearAlgorithm
  * }}}
  */
 @Experimental
+@Since("1.3.0")
 class StreamingLogisticRegressionWithSGD private[mllib] (
     private var stepSize: Double,
     private var numIterations: Int,
@@ -58,6 +59,7 @@ class StreamingLogisticRegressionWithSGD private[mllib] (
    * Initial weights must be set before using trainOn or predictOn
    * (see `StreamingLinearAlgorithm`)
    */
+  @Since("1.3.0")
   def this() = this(0.1, 50, 1.0, 0.0)

   protected val algorithm = new LogisticRegressionWithSGD(
@@ -66,30 +68,35 @@ class StreamingLogisticRegressionWithSGD private[mllib] (
   protected var model: Option[LogisticRegressionModel] = None

   /** Set the step size for gradient descent. Default: 0.1. */
+  @Since("1.3.0")
   def setStepSize(stepSize: Double): this.type = {
     this.algorithm.optimizer.setStepSize(stepSize)
     this
   }

   /** Set the number of iterations of gradient descent to run per update. Default: 50. */
+  @Since("1.3.0")
   def setNumIterations(numIterations: Int): this.type = {
     this.algorithm.optimizer.setNumIterations(numIterations)
     this
   }

   /** Set the fraction of each batch to use for updates. Default: 1.0. */
+  @Since("1.3.0")
   def setMiniBatchFraction(miniBatchFraction: Double): this.type = {
     this.algorithm.optimizer.setMiniBatchFraction(miniBatchFraction)
     this
   }

   /** Set the regularization parameter. Default: 0.0. */
+  @Since("1.3.0")
   def setRegParam(regParam: Double): this.type = {
     this.algorithm.optimizer.setRegParam(regParam)
     this
   }

   /** Set the initial weights. Default: [0.0, 0.0]. */
+  @Since("1.3.0")
   def setInitialWeights(initialWeights: Vector): this.type = {
     this.model = Some(algorithm.createModel(initialWeights, 0.0))
     this
...
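
Finally, a sketch showing the streaming setters above in context; the StreamingContext, input streams, and feature count are placeholders.

```scala
import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.dstream.DStream

def example(ssc: StreamingContext,
            trainingData: DStream[LabeledPoint],
            testData: DStream[LabeledPoint],
            numFeatures: Int): Unit = {
  val model = new StreamingLogisticRegressionWithSGD() // @Since("1.3.0")
    .setStepSize(0.5)
    .setNumIterations(10)
    .setInitialWeights(Vectors.zeros(numFeatures))     // required before trainOn/predictOn
  model.trainOn(trainingData)
  model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()
  ssc.start()
  ssc.awaitTermination()
}
```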