From 5a23213c148bfe362514f9c71f5273ebda0a848a Mon Sep 17 00:00:00 2001
From: Holden Karau <holden@pigscanfly.ca>
Date: Tue, 4 Aug 2015 10:12:22 -0700
Subject: [PATCH] [SPARK-8069] [ML] Add multiclass thresholds for
 ProbabilisticClassifier

This PR replaces the old "threshold" Param with a generalized "thresholds" Param.
We keep getThreshold/setThreshold for backwards compatibility in binary
classification.

Note that the primary author of this PR is holdenk.

Author: Holden Karau <holden@pigscanfly.ca>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #7909 from jkbradley/holdenk-SPARK-8069-add-cutoff-aka-threshold-to-random-forest and squashes the following commits:

3952977 [Joseph K. Bradley] fixed pyspark doc test
85febc8 [Joseph K. Bradley] made python unit tests a little more robust
7eb1d86 [Joseph K. Bradley] small cleanups
6cc2ed8 [Joseph K. Bradley] Fixed remaining merge issues.
0255e44 [Joseph K. Bradley] Many cleanups for thresholds, some more tests
7565a60 [Holden Karau] fix pep8 style checks, add a getThreshold method similar to our LogisticRegression.scala one for API compat
be87f26 [Holden Karau] Convert threshold to thresholds in the python code, add specialized support for Array[Double] to shared params codegen, etc.
6747dad [Holden Karau] Override raw2prediction for ProbabilisticClassifier, fix some tests
25df168 [Holden Karau] Fix handling of thresholds in LogisticRegression
c02d6c0 [Holden Karau] No default for thresholds
5e43628 [Holden Karau] CR feedback and fixed the renamed test
f3fbbd1 [Holden Karau] revert the changes to random forest :(
51f581c [Holden Karau] Add explicit types to public methods, fix long line
f7032eb [Holden Karau] Fix a java test bug, remove some unnecessary changes
adf15b4 [Holden Karau] rename the classifier suite test to ProbabilisticClassifierSuite now that we only have it in Probabilistic
398078a [Holden Karau] move the thresholding around a bunch based on the design doc
4893bdc [Holden Karau] Use numTrees of 3 since the previous result was tied (one tree for each) and the switch between max methods picked a different element (since they were equal I think this is ok)
638854c [Holden Karau] Add a Scala RandomForestClassifierSuite test based on the corresponding python test
e09919c [Holden Karau] Fix return type, I need more coffee....
8d92cac [Holden Karau] Use ClassifierParams as the head
3456ed3 [Holden Karau] Add explicit return types even though it's just a test
a0f3b0c [Holden Karau] scala style fixes
6f14314 [Holden Karau] Since HasThreshold/HasThresholds is in the root classifier now
ffc8dab [Holden Karau] Update the sharedParams
0420290 [Holden Karau] Allow us to override the get methods selectively
978e77a [Holden Karau] Move HasThreshold into classifier params and start defining the overloaded getThreshold/getThresholds functions
1433e52 [Holden Karau] Revert "try and hide threshold but changes the API so no dice there"
1f09a2e [Holden Karau] try and hide threshold but changes the API so no dice there
efb9084 [Holden Karau] move setThresholds only to where it's used
6b34809 [Holden Karau] Add a test with thresholding for the RFCS
74f54c3 [Holden Karau] Fix creation of vote array
1986fa8 [Holden Karau] Setting the thresholds only makes sense if the underlying class hasn't overridden predict, so let's push it down.
2f44b18 [Holden Karau] Add a global default of null for thresholds param
f338cfc [Holden Karau] Wait, that wasn't a good idea; Revert "Some progress towards unifying threshold and thresholds"
634b06f [Holden Karau] Some progress towards unifying threshold and thresholds
85c9e01 [Holden Karau] Test passes again... little fnur
099c0f3 [Holden Karau] Move thresholds around some more (set on model, not trainer)
0f46836 [Holden Karau] Start adding a ClassifierSuite
f70eb5e [Holden Karau] Fix test compile issues
a7d59c8 [Holden Karau] Move thresholding into Classifier trait
5d999d2 [Holden Karau] Some more progress, start adding a test (maybe try and see if we can find a better thing to use for the base of the test)
1fed644 [Holden Karau] Use thresholds to scale scores in random forest classification
31d6bf2 [Holden Karau] Start threading the threshold info through
0ef228c [Holden Karau] Add HasThresholds
---
 .../examples/ml/JavaSimpleParamsExample.java  |  3 +-
 .../main/python/ml/simple_params_example.py   |  2 +-
 .../examples/ml/SimpleParamsExample.scala     |  2 +-
 .../spark/ml/classification/Classifier.scala  |  3 +-
 .../classification/LogisticRegression.scala   | 47 ++++++++++--
 .../ProbabilisticClassifier.scala             | 41 +++++++++--
 .../ml/param/shared/SharedParamsCodeGen.scala | 19 ++++-
 .../spark/ml/param/shared/sharedParams.scala  | 17 ++++-
 .../org/apache/spark/ml/tree/treeParams.scala |  3 +-
 .../JavaLogisticRegressionSuite.java          |  9 ++-
 .../LogisticRegressionSuite.scala             | 28 +++++++-
 .../ml/classification/OneVsRestSuite.scala    |  2 +-
 .../ProbabilisticClassifierSuite.scala        | 57 +++++++++++++++
 python/pyspark/ml/classification.py           | 72 +++++++++++++++----
 14 files changed, 265 insertions(+), 40 deletions(-)
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala

diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index dac649d1d5..94beeced3d 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -77,7 +77,8 @@ public class JavaSimpleParamsExample {
     ParamMap paramMap = new ParamMap();
     paramMap.put(lr.maxIter().w(20)); // Specify 1 Param.
     paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter.
-    paramMap.put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params.
+    double[] thresholds = {0.45, 0.55};
+    paramMap.put(lr.regParam().w(0.1), lr.thresholds().w(thresholds)); // Specify multiple Params.

     // One can also combine ParamMaps.
     ParamMap paramMap2 = new ParamMap();
diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py
index a9f29dab2d..2d6d115d54 100644
--- a/examples/src/main/python/ml/simple_params_example.py
+++ b/examples/src/main/python/ml/simple_params_example.py
@@ -70,7 +70,7 @@ if __name__ == "__main__":

     # We may alternatively specify parameters using a parameter map.
     # paramMap overrides all lr parameters set earlier.
-    paramMap = {lr.maxIter: 20, lr.threshold: 0.55, lr.probabilityCol: "myProbability"}
+    paramMap = {lr.maxIter: 20, lr.thresholds: [0.45, 0.55], lr.probabilityCol: "myProbability"}

     # Now learn a new model using the new parameters.
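+    # thresholds [0.45, 0.55] is equivalent to the old binary threshold of 0.55: the
+    # predicted class is the one maximizing p/t, a class's probability scaled by its
+    # threshold, so the larger threshold for class 1 makes class 1 harder to predict.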
     model2 = lr.fit(training, paramMap)
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index 58d7b67674..f4d1fe5785 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -70,7 +70,7 @@ object SimpleParamsExample {
     // which supports several methods for specifying parameters.
     val paramMap = ParamMap(lr.maxIter -> 20)
     paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter.
-    paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params.
+    paramMap.put(lr.regParam -> 0.1, lr.thresholds -> Array(0.45, 0.55)) // Specify multiple Params.

     // One can also combine ParamMaps.
     val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 581d8fa774..45df557a89 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -18,14 +18,13 @@
 package org.apache.spark.ml.classification

 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor}
 import org.apache.spark.ml.param.shared.HasRawPredictionCol
 import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+import org.apache.spark.sql.types.{DataType, StructType}


 /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 8fc9199fb4..c937b9602b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -41,7 +41,39 @@ import org.apache.spark.storage.StorageLevel
  */
 private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
   with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
-  with HasThreshold with HasStandardization
+  with HasStandardization {
+
+  /**
+   * Version of setThresholds() for binary classification, available for backwards
+   * compatibility.
+   *
+   * Calling this with threshold p will effectively call `setThresholds(Array(1-p, p))`.
+   *
+   * Default is effectively 0.5.
+   * @group setParam
+   */
+  def setThreshold(value: Double): this.type = set(thresholds, Array(1.0 - value, value))
+
+  /**
+   * Version of [[getThresholds()]] for binary classification, available for backwards
+   * compatibility.
+   *
+   * Param thresholds must have length 2 (or not be specified).
+   * This returns {{{1 / (1 + thresholds(0) / thresholds(1))}}}.
+   * @group getParam
+   */
+  def getThreshold: Double = {
+    if (isDefined(thresholds)) {
+      val thresholdValues = $(thresholds)
+      assert(thresholdValues.length == 2, "Logistic Regression getThreshold only applies to" +
+        " binary classification, but thresholds has length != 2."
+ + s" thresholds: ${thresholdValues.mkString(",")}") + 1.0 / (1.0 + thresholdValues(0) / thresholdValues(1)) + } else { + 0.5 + } + } +} /** * :: Experimental :: @@ -110,9 +142,9 @@ class LogisticRegression(override val uid: String) def setStandardization(value: Boolean): this.type = set(standardization, value) setDefault(standardization -> true) - /** @group setParam */ - def setThreshold(value: Double): this.type = set(threshold, value) - setDefault(threshold -> 0.5) + override def setThreshold(value: Double): this.type = super.setThreshold(value) + + override def getThreshold: Double = super.getThreshold override protected def train(dataset: DataFrame): LogisticRegressionModel = { // Extract columns from data. If dataset is persisted, do not persist oldDataset. @@ -270,8 +302,9 @@ class LogisticRegressionModel private[ml] ( extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] with LogisticRegressionParams { - /** @group setParam */ - def setThreshold(value: Double): this.type = set(threshold, value) + override def setThreshold(value: Double): this.type = super.setThreshold(value) + + override def getThreshold: Double = super.getThreshold /** Margin (rawPrediction) for class label 1. For binary classification only. */ private val margin: Vector => Double = (features) => { @@ -288,7 +321,7 @@ class LogisticRegressionModel private[ml] ( /** * Predict label for the given feature vector. - * The behavior of this can be adjusted using [[threshold]]. + * The behavior of this can be adjusted using [[thresholds]]. */ override protected def predict(features: Vector): Double = { if (score(features) > getThreshold) 1 else 0 diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index f9c9c2371f..1e50a895a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -20,17 +20,16 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils -import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT} +import org.apache.spark.mllib.linalg.{DenseVector, Vector, VectorUDT, Vectors} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DoubleType, DataType, StructType} +import org.apache.spark.sql.types.{DataType, StructType} /** * (private[classification]) Params for probabilistic classification. 
  */
 private[classification] trait ProbabilisticClassifierParams
-  extends ClassifierParams with HasProbabilityCol {
-
+  extends ClassifierParams with HasProbabilityCol with HasThresholds {
   override protected def validateAndTransformSchema(
       schema: StructType,
       fitting: Boolean,
@@ -59,6 +58,9 @@ private[spark] abstract class ProbabilisticClassifier[

   /** @group setParam */
   def setProbabilityCol(value: String): E = set(probabilityCol, value).asInstanceOf[E]
+
+  /** @group setParam */
+  def setThresholds(value: Array[Double]): E = set(thresholds, value).asInstanceOf[E]
 }

@@ -80,6 +82,9 @@ private[spark] abstract class ProbabilisticClassificationModel[
   /** @group setParam */
   def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M]

+  /** @group setParam */
+  def setThresholds(value: Array[Double]): M = set(thresholds, value).asInstanceOf[M]
+
   /**
    * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
    * parameters:
@@ -92,6 +97,11 @@ private[spark] abstract class ProbabilisticClassificationModel[
    */
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
+    if (isDefined(thresholds)) {
+      require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+        ".transform() called with non-matching numClasses and thresholds.length." +
+        s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+    }

     // Output selected columns only.
     // This is a bit complicated since it tries to avoid repeated computation.
@@ -155,6 +165,14 @@ private[spark] abstract class ProbabilisticClassificationModel[
     raw2probabilityInPlace(probs)
   }

+  override protected def raw2prediction(rawPrediction: Vector): Double = {
+    if (!isDefined(thresholds)) {
+      rawPrediction.argmax
+    } else {
+      probability2prediction(raw2probability(rawPrediction))
+    }
+  }
+
   /**
    * Predict the probability of each class given the features.
    * These predictions are also called class conditional probabilities.
@@ -170,10 +188,21 @@ private[spark] abstract class ProbabilisticClassificationModel[

   /**
    * Given a vector of class conditional probabilities, select the predicted label.
-   * This may be overridden to support thresholds which favor particular labels.
+   * This supports thresholds which favor particular labels.
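+   * The predicted label maximizes p/t over classes, where p is a class's probability and
+   * t is its threshold; a threshold of 0.0 gives its class an infinite scaled score.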
    * @return predicted label
    */
-  protected def probability2prediction(probability: Vector): Double = probability.argmax
+  protected def probability2prediction(probability: Vector): Double = {
+    if (!isDefined(thresholds)) {
+      probability.argmax
+    } else {
+      val thresholds: Array[Double] = getThresholds
+      val scaledProbability: Array[Double] =
+        probability.toArray.zip(thresholds).map { case (p, t) =>
+          if (t == 0.0) Double.PositiveInfinity else p / t
+        }
+      Vectors.dense(scaledProbability).argmax
+    }
+  }
 }

 private[ml] object ProbabilisticClassificationModel {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index f7ae1de522..a97c8059b8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -46,7 +46,13 @@ private[shared] object SharedParamsCodeGen {
         Some("\"probability\"")),
       ParamDesc[Double]("threshold",
         "threshold in binary classification prediction, in range [0, 1]",
-        isValid = "ParamValidators.inRange(0, 1)"),
+        isValid = "ParamValidators.inRange(0, 1)", finalMethods = false),
+      ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class classification" +
+        " to adjust the probability of predicting each class." +
+        " Array must have length equal to the number of classes, with values >= 0." +
+        " The class with largest value p/t is predicted, where p is the original probability" +
+        " of that class and t is the class' threshold.",
+        isValid = "(t: Array[Double]) => t.forall(_ >= 0)"),
       ParamDesc[String]("inputCol", "input column name"),
       ParamDesc[Array[String]]("inputCols", "input column names"),
       ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),
@@ -74,7 +80,8 @@ private[shared] object SharedParamsCodeGen {
       name: String,
       doc: String,
       defaultValueStr: Option[String] = None,
-      isValid: String = "") {
+      isValid: String = "",
+      finalMethods: Boolean = true) {

     require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.")
     require(doc.nonEmpty) // TODO: more rigorous on doc
@@ -88,6 +95,7 @@ private[shared] object SharedParamsCodeGen {
         case _ if c == classOf[Double] => "DoubleParam"
         case _ if c == classOf[Boolean] => "BooleanParam"
         case _ if c.isArray && c.getComponentType == classOf[String] => s"StringArrayParam"
+        case _ if c.isArray && c.getComponentType == classOf[Double] => s"DoubleArrayParam"
         case _ => s"Param[${getTypeString(c)}]"
       }
     }
@@ -131,6 +139,11 @@ private[shared] object SharedParamsCodeGen {
     } else {
       ""
     }
+    val methodStr = if (param.finalMethods) {
+      "final def"
+    } else {
+      "def"
+    }

     s"""
       |/**
@@ -145,7 +158,7 @@ private[shared] object SharedParamsCodeGen {
       |  final val $name: $Param = new $Param(this, "$name", "$doc"$isValid)
       |$setDefault
       |  /** @group getParam */
-      |  final def get$Name: $T = $$($name)
+      |  $methodStr get$Name: $T = $$($name)
       |}
       |""".stripMargin
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 65e48e4ee5..f332630c32 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -150,7 +150,22 @@
 private[ml] trait HasThreshold extends Params {

   final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1))

   /** @group getParam */
-  final def getThreshold: Double = $(threshold)
+  def getThreshold: Double = $(threshold)
+}
+
+/**
+ * Trait for shared param thresholds.
+ */
+private[ml] trait HasThresholds extends Params {
+
+  /**
+   * Param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold..
+   * @group param
+   */
+  final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", (t: Array[Double]) => t.forall(_ >= 0))
+
+  /** @group getParam */
+  final def getThresholds: Array[Double] = $(thresholds)
 }

 /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index a0c5238d96..e817090f8a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -17,9 +17,10 @@
 package org.apache.spark.ml.tree

+import org.apache.spark.ml.classification.ClassifierParams
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed}
+import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasThresholds}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
 import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index f75e024a71..fb1de51163 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -87,6 +87,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
     LogisticRegression parent = (LogisticRegression) model.parent();
     assert(parent.getMaxIter() == 10);
     assert(parent.getRegParam() == 1.0);
+    assert(parent.getThresholds()[0] == 0.4);
+    assert(parent.getThresholds()[1] == 0.6);
     assert(parent.getThreshold() == 0.6);
     assert(model.getThreshold() == 0.6);

@@ -98,7 +100,9 @@ public class JavaLogisticRegressionSuite implements Serializable {
       assert(r.getDouble(0) == 0.0);
     }
     // Call transform with params, and check that the params worked.
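+    // thresholds {1.0, 0.0} force every prediction to class 1: a threshold of 0.0 scales
+    // that class's probability to positive infinity, mirroring the old threshold of 0.0.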
-    model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
+    double[] thresholds = {1.0, 0.0};
+    model.transform(
+        dataset, model.thresholds().w(thresholds), model.probabilityCol().w("myProb"))
       .registerTempTable("predNotAllZero");
     DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
     boolean foundNonZero = false;
@@ -108,8 +112,9 @@ public class JavaLogisticRegressionSuite implements Serializable {
     assert(foundNonZero);

     // Call fit() with new params, and check as many params as we can.
+    double[] thresholds2 = {0.6, 0.4};
     LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
-      lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
+      lr.thresholds().w(thresholds2), lr.probabilityCol().w("theProb"));
     LogisticRegression parent2 = (LogisticRegression) model2.parent();
     assert(parent2.getMaxIter() == 5);
     assert(parent2.getRegParam() == 0.1);
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index b7dd447538..da13dcb42d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -91,6 +91,28 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(model.hasParent)
   }

+  test("setThreshold, getThreshold") {
+    val lr = new LogisticRegression
+    // default
+    withClue("LogisticRegression should not have thresholds set by default") {
+      intercept[java.util.NoSuchElementException] {
+        lr.getThresholds
+      }
+    }
+    // Set via thresholds.
+    // Intuition: Large threshold or large thresholds(1) makes class 0 more likely.
+    lr.setThreshold(1.0)
+    assert(lr.getThresholds === Array(0.0, 1.0))
+    lr.setThreshold(0.0)
+    assert(lr.getThresholds === Array(1.0, 0.0))
+    lr.setThreshold(0.5)
+    assert(lr.getThresholds === Array(0.5, 0.5))
+    // Test getThreshold
+    lr.setThresholds(Array(0.3, 0.7))
+    val expectedThreshold = 1.0 / (1.0 + 0.3 / 0.7)
+    assert(lr.getThreshold ~== expectedThreshold relTol 1E-7)
+  }
+
   test("logistic regression doesn't fit intercept when fitIntercept is off") {
     val lr = new LogisticRegression
     lr.setFitIntercept(false)
@@ -123,14 +145,16 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
       s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.")
     // Call transform with params, and check that the params worked.
     val predNotAllZero =
-      model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb")
+      model.transform(dataset, model.thresholds -> Array(1.0, 0.0),
+        model.probabilityCol -> "myProb")
         .select("prediction", "myProb")
         .collect()
         .map { case Row(pred: Double, prob: Vector) => pred }
     assert(predNotAllZero.exists(_ !== 0.0))

     // Call fit() with new params, and check as many params as we can.
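+    // thresholds Array(0.6, 0.4) is equivalent to the old threshold of 0.4, since
+    // getThreshold returns 1 / (1 + thresholds(0) / thresholds(1)) = 0.4.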
-    val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4,
+    val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1,
+      lr.thresholds -> Array(0.6, 0.4),
       lr.probabilityCol -> "theProb")
     val parent2 = model2.parent.asInstanceOf[LogisticRegression]
     assert(parent2.getMaxIter === 5)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 3775292f6d..bd8e819f69 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -151,7 +151,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
     require(ovr1.getClassifier.getOrDefault(lr.maxIter) === 10,
       "copy should handle extra classifier params")

-    val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.threshold -> 0.1))
+    val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.thresholds -> Array(0.9, 0.1)))
     ovrModel.models.foreach { case m: LogisticRegressionModel =>
       require(m.getThreshold === 0.1, "copy should handle extra model params")
     }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
new file mode 100644
index 0000000000..8f50cb924e
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+
+final class TestProbabilisticClassificationModel(
+    override val uid: String,
+    override val numClasses: Int)
+  extends ProbabilisticClassificationModel[Vector, TestProbabilisticClassificationModel] {
+
+  override def copy(extra: org.apache.spark.ml.param.ParamMap): this.type = defaultCopy(extra)
+
+  override protected def predictRaw(input: Vector): Vector = {
+    input
+  }
+
+  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+    rawPrediction
+  }
+
+  def friendlyPredict(input: Vector): Double = {
+    predict(input)
+  }
+}
+
+
+class ProbabilisticClassifierSuite extends SparkFunSuite {
+
+  test("test thresholding") {
+    val thresholds = Array(0.5, 0.2)
+    val testModel = new TestProbabilisticClassificationModel("myuid", 2).setThresholds(thresholds)
+    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0)
+    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0)
+  }
+
+  test("test thresholding not required") {
+    val testModel = new TestProbabilisticClassificationModel("myuid", 2)
+    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0)
+  }
+}
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index b5814f76de..291320f881 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -69,17 +69,25 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
                      "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " +
                      "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
     fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.")
-    threshold = Param(Params._dummy(), "threshold",
-                      "threshold in binary classification prediction, in range [0, 1].")
+    thresholds = Param(Params._dummy(), "thresholds",
+                       "Thresholds in multi-class classification" +
+                       " to adjust the probability of predicting each class." +
+                       " Array must have length equal to the number of classes, with values >= 0." +
+                       " The class with largest value p/t is predicted, where p is the original" +
+                       " probability of that class and t is the class' threshold.")

     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
-                 threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction"):
+                 threshold=None, thresholds=None,
+                 probabilityCol="probability", rawPredictionCol="rawPrediction"):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
-                 threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction")
+                 threshold=None, thresholds=None, \
+                 probabilityCol="probability", rawPredictionCol="rawPrediction")
+
+        Param thresholds overrides Param threshold; threshold is provided
+        for backwards compatibility and only applies to binary classification.
         """
         super(LogisticRegression, self).__init__()
         self._java_obj = self._new_java_obj(
@@ -93,23 +101,35 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
         #: param for whether to fit an intercept term.
         self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.")
-        #: param for threshold in binary classification prediction, in range [0, 1].
-        self.threshold = Param(self, "threshold",
-                               "threshold in binary classification prediction, in range [0, 1].")
+        #: param for thresholds in multi-class classification prediction.
+        self.thresholds = \
+            Param(self, "thresholds",
+                  "Thresholds in multi-class classification" +
+                  " to adjust the probability of predicting each class." +
+                  " Array must have length equal to the number of classes, with values >= 0." +
+                  " The class with largest value p/t is predicted, where p is the original" +
+                  " probability of that class and t is the class' threshold.")
         self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6,
-                         fitIntercept=True, threshold=0.5)
+                         fitIntercept=True)
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)

     @keyword_only
     def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                   maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
-                  threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction"):
+                  threshold=None, thresholds=None,
+                  probabilityCol="probability", rawPredictionCol="rawPrediction"):
         """
         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
-                  threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction")
+                  threshold=None, thresholds=None, \
+                  probabilityCol="probability", rawPredictionCol="rawPrediction")
         Sets params for logistic regression.
+
+        Param thresholds overrides Param threshold; threshold is provided
+        for backwards compatibility and only applies to binary classification.
         """
+        kwargs = self.setParams._input_kwargs
+        # Under the hood we only keep thresholds, so translate threshold to thresholds
+        # if applicable.
+        if thresholds is None and threshold is not None:
+            kwargs["thresholds"] = [1 - threshold, threshold]
         return self._set(**kwargs)

@@ -144,16 +164,44 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti

     def setThreshold(self, value):
         """
-        Sets the value of :py:attr:`threshold`.
+        Sets the value of :py:attr:`thresholds` using [1-value, value].
+
+        >>> lr = LogisticRegression()
+        >>> lr.getThreshold()
+        0.5
+        >>> lr.setThreshold(0.6)
+        LogisticRegression_...
+        >>> abs(lr.getThreshold() - 0.6) < 1e-5
+        True
+        """
+        return self.setThresholds([1 - value, value])
+
+    def setThresholds(self, value):
+        """
+        Sets the value of :py:attr:`thresholds`.
         """
-        self._paramMap[self.threshold] = value
+        self._paramMap[self.thresholds] = value
         return self

+    def getThresholds(self):
+        """
+        Gets the value of thresholds or its default value.
+        """
+        return self.getOrDefault(self.thresholds)
+
     def getThreshold(self):
         """
         Gets the value of threshold or its default value.
         """
-        return self.getOrDefault(self.threshold)
+        if self.isDefined(self.thresholds):
+            thresholds = self.getOrDefault(self.thresholds)
+            if len(thresholds) != 2:
+                raise ValueError("Logistic Regression getThreshold only applies to" +
+                                 " binary classification, but thresholds has length != 2." +
+                                 " thresholds: " + ",".join(map(str, thresholds)))
+            return 1.0 / (1.0 + thresholds[0] / thresholds[1])
+        else:
+            return 0.5


 class LogisticRegressionModel(JavaModel):
--
GitLab
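Example usage (reviewer sketch): the snippet below is a minimal illustration of the API
added by this patch, not part of the change itself. It only exercises the param
setters/getters (no model fitting), so it assumes nothing beyond this branch's
spark-mllib on the classpath.

  import org.apache.spark.ml.classification.LogisticRegression

  val lr = new LogisticRegression()
  // Backwards-compatible binary API: setThreshold(p) stores thresholds = Array(1 - p, p).
  lr.setThreshold(0.6)
  println(lr.getThresholds.mkString(","))  // 0.4,0.6
  // Generalized API: one threshold per class; prediction maximizes p/t for each class.
  lr.setThresholds(Array(0.4, 0.6))
  println(lr.getThreshold)                 // 0.6 (up to floating point), from 1 / (1 + t0 / t1)

Note that getThresholds throws java.util.NoSuchElementException until thresholds is set,
while getThreshold falls back to 0.5; both behaviors are exercised in the test suites above.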