Skip to content
Snippets Groups Projects
Commit c94ecdfc authored by MechCoder's avatar MechCoder Committed by Joseph K. Bradley
Browse files

[SPARK-9906] [ML] User guide for LogisticRegressionSummary

User guide for LogisticRegression summaries

Author: MechCoder <manojkumarsivaraj334@gmail.com>
Author: Manoj Kumar <mks542@nyu.edu>
Author: Feynman Liang <fliang@databricks.com>

Closes #8197 from MechCoder/log_summary_user_guide.
parent 6185cdd2
No related branches found
No related tags found
No related merge requests found
...@@ -23,20 +23,41 @@ displayTitle: <a href="ml-guide.html">ML</a> - Linear Methods ...@@ -23,20 +23,41 @@ displayTitle: <a href="ml-guide.html">ML</a> - Linear Methods
\]` \]`
In MLlib, we implement popular linear methods such as logistic regression and linear least squares with L1 or L2 regularization. Refer to [the linear methods in mllib](mllib-linear-methods.html) for details. In `spark.ml`, we also include Pipelines API for [Elastic net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid of L1 and L2 regularization proposed in [this paper](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). Mathematically it is defined as a linear combination of the L1-norm and the L2-norm: In MLlib, we implement popular linear methods such as logistic
regression and linear least squares with $L_1$ or $L_2$ regularization.
Refer to [the linear methods in mllib](mllib-linear-methods.html) for
details. In `spark.ml`, we also include Pipelines API for [Elastic
net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid
of $L_1$ and $L_2$ regularization proposed in [Zou et al, Regularization
and variable selection via the elastic
net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf).
Mathematically, it is defined as a convex combination of the $L_1$ and
the $L_2$ regularization terms:
`\[ `\[
\alpha \|\wv\|_1 + (1-\alpha) \frac{1}{2}\|\wv\|_2^2, \alpha \in [0, 1]. \alpha~\lambda \|\wv\|_1 + (1-\alpha) \frac{\lambda}{2}\|\wv\|_2^2, \alpha \in [0, 1], \lambda \geq 0.
\]` \]`
By setting $\alpha$ properly, it contains both L1 and L2 regularization as special cases. For example, if a [linear regression](https://en.wikipedia.org/wiki/Linear_regression) model is trained with the elastic net parameter $\alpha$ set to $1$, it is equivalent to a [Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. On the other hand, if $\alpha$ is set to $0$, the trained model reduces to a [ridge regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. We implement Pipelines API for both linear regression and logistic regression with elastic net regularization. By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$
regularization as special cases. For example, if a [linear
**Examples** regression](https://en.wikipedia.org/wiki/Linear_regression) model is
trained with the elastic net parameter $\alpha$ set to $1$, it is
equivalent to a
[Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model.
On the other hand, if $\alpha$ is set to $0$, the trained model reduces
to a [ridge
regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model.
We implement Pipelines API for both linear regression and logistic
regression with elastic net regularization.
## Example: Logistic Regression
The following example shows how to train a logistic regression model
with elastic net regularization. `elasticNetParam` corresponds to
$\alpha$ and `regParam` corresponds to $\lambda$.
<div class="codetabs"> <div class="codetabs">
<div data-lang="scala" markdown="1"> <div data-lang="scala" markdown="1">
{% highlight scala %} {% highlight scala %}
import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.mllib.util.MLUtils import org.apache.spark.mllib.util.MLUtils
...@@ -53,15 +74,11 @@ val lrModel = lr.fit(training) ...@@ -53,15 +74,11 @@ val lrModel = lr.fit(training)
// Print the weights and intercept for logistic regression // Print the weights and intercept for logistic regression
println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}") println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}")
{% endhighlight %} {% endhighlight %}
</div> </div>
<div data-lang="java" markdown="1"> <div data-lang="java" markdown="1">
{% highlight java %} {% highlight java %}
import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel; import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.regression.LabeledPoint;
...@@ -99,9 +116,7 @@ public class LogisticRegressionWithElasticNetExample { ...@@ -99,9 +116,7 @@ public class LogisticRegressionWithElasticNetExample {
</div> </div>
<div data-lang="python" markdown="1"> <div data-lang="python" markdown="1">
{% highlight python %} {% highlight python %}
from pyspark.ml.classification import LogisticRegression from pyspark.ml.classification import LogisticRegression
from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.util import MLUtils from pyspark.mllib.util import MLUtils
...@@ -118,12 +133,114 @@ lrModel = lr.fit(training) ...@@ -118,12 +133,114 @@ lrModel = lr.fit(training)
print("Weights: " + str(lrModel.weights)) print("Weights: " + str(lrModel.weights))
print("Intercept: " + str(lrModel.intercept)) print("Intercept: " + str(lrModel.intercept))
{% endhighlight %} {% endhighlight %}
</div>
</div> </div>
The `spark.ml` implementation of logistic regression also supports
extracting a summary of the model over the training set. Note that the
predictions and metrics which are stored as `Dataframe` in
`BinaryLogisticRegressionSummary` are annotated `@transient` and hence
only available on the driver.
<div class="codetabs">
<div data-lang="scala" markdown="1">
[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary)
provides a summary for a
[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel).
Currently, only binary classification is supported and the
summary must be explicitly cast to
[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary).
This will likely change when multiclass classification is supported.
Continuing the earlier example:
{% highlight scala %}
// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example
val trainingSummary = lrModel.summary
// Obtain the loss per iteration.
val objectiveHistory = trainingSummary.objectiveHistory
objectiveHistory.foreach(loss => println(loss))
// Obtain the metrics useful to judge performance on test data.
// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a
// binary classification problem.
val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]
// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
val roc = binarySummary.roc
roc.show()
roc.select("FPR").show()
println(binarySummary.areaUnderROC)
// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
// this selected threshold.
val fMeasure = binarySummary.fMeasureByThreshold
val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure).
select("threshold").head().getDouble(0)
logReg.setThreshold(bestThreshold)
logReg.fit(logRegDataFrame)
{% endhighlight %}
</div> </div>
### Optimization <div data-lang="java" markdown="1">
[`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html)
provides a summary for a
[`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html).
Currently, only binary classification is supported and the
summary must be explicitly cast to
[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html).
This will likely change when multiclass classification is supported.
Continuing the earlier example:
{% highlight java %}
// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example
LogisticRegressionTrainingSummary trainingSummary = logRegModel.summary();
// Obtain the loss per iteration.
double[] objectiveHistory = trainingSummary.objectiveHistory();
for (double lossPerIteration : objectiveHistory) {
System.out.println(lossPerIteration);
}
// Obtain the metrics useful to judge performance on test data.
// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a
// binary classification problem.
BinaryLogisticRegressionSummary binarySummary = (BinaryLogisticRegressionSummary) trainingSummary;
// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
DataFrame roc = binarySummary.roc();
roc.show();
roc.select("FPR").show();
System.out.println(binarySummary.areaUnderROC());
// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with
// this selected threshold.
DataFrame fMeasure = binarySummary.fMeasureByThreshold();
double maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0);
double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure)).
select("threshold").head().getDouble(0);
logReg.setThreshold(bestThreshold);
logReg.fit(logRegDataFrame);
{% endhighlight %}
</div>
<div data-lang="python" markdown="1">
Logistic regression model summary is not yet supported in Python.
</div>
</div>
# Optimization
The optimization algorithm underlying the implementation is called
[Orthant-Wise Limited-memory
QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf)
(OWL-QN). It is an extension of L-BFGS that can effectively handle L1
regularization and elastic net.
The optimization algorithm underlies the implementation is called [Orthant-Wise Limited-memory QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf)
(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 regularization and elastic net.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment