From 5a23213c148bfe362514f9c71f5273ebda0a848a Mon Sep 17 00:00:00 2001
From: Holden Karau <holden@pigscanfly.ca>
Date: Tue, 4 Aug 2015 10:12:22 -0700
Subject: [PATCH] [SPARK-8069] [ML] Add multiclass thresholds for
 ProbabilisticClassifier

This PR replaces the old "threshold" Param with a generalized "thresholds" Param. We keep getThreshold and setThreshold for backwards compatibility in binary classification.

Note that the primary author of this PR is holdenk
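
A minimal usage sketch of the new API (a sketch only: `training` stands in for
any DataFrame of (label, features) rows and is not part of this patch):

    import org.apache.spark.ml.classification.LogisticRegression

    val lr = new LogisticRegression()
    lr.setThresholds(Array(0.4, 0.6))  // one entry per class, all values >= 0
    lr.setThreshold(0.6)  // binary shortcut, equivalent to the line above
    val model = lr.fit(training)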

Author: Holden Karau <holden@pigscanfly.ca>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #7909 from jkbradley/holdenk-SPARK-8069-add-cutoff-aka-threshold-to-random-forest and squashes the following commits:

3952977 [Joseph K. Bradley] fixed pyspark doc test
85febc8 [Joseph K. Bradley] made python unit tests a little more robust
7eb1d86 [Joseph K. Bradley] small cleanups
6cc2ed8 [Joseph K. Bradley] Fixed remaining merge issues.
0255e44 [Joseph K. Bradley] Many cleanups for thresholds, some more tests
7565a60 [Holden Karau] fix pep8 style checks, add a getThreshold method similar to our LogisticRegression.scala one for API compat
be87f26 [Holden Karau] Convert threshold to thresholds in the python code, add specialized support for Array[Double] to shared params codegen, etc.
6747dad [Holden Karau] Override raw2prediction for ProbabilisticClassifier, fix some tests
25df168 [Holden Karau] Fix handling of thresholds in LogisticRegression
c02d6c0 [Holden Karau] No default for thresholds
5e43628 [Holden Karau] CR feedback and fixed the renamed test
f3fbbd1 [Holden Karau] revert the changes to random forest :(
51f581c [Holden Karau] Add explicit types to public methods, fix long line
f7032eb [Holden Karau] Fix a java test bug, remove some unnecessary changes
adf15b4 [Holden Karau] rename the classifier suite test to ProbabilisticClassifierSuite now that we only have it in Probabilistic
398078a [Holden Karau] move the thresholding around a bunch based on the design doc
4893bdc [Holden Karau] Use numTrees of 3 since the previous result was tied (one tree for each) and the switch between different max methods picked a different element (since they were equal, I think this is ok)
638854c [Holden Karau] Add a scala RandomForestClassifierSuite test based on corresponding python test
e09919c [Holden Karau] Fix return type, I need more coffee....
8d92cac [Holden Karau] Use ClassifierParams as the head
3456ed3 [Holden Karau] Add explicit return types even though just test
a0f3b0c [Holden Karau] scala style fixes
6f14314 [Holden Karau] Since HasThreshold/HasThresholds is in the root classifier now
ffc8dab [Holden Karau] Update the sharedParams
0420290 [Holden Karau] Allow us to override the get methods selectively
978e77a [Holden Karau] Move HasThreshold into classifier params and start defining the overloaded getThreshold/getThresholds functions
1433e52 [Holden Karau] Revert "try and hide threshold but it changes the API so no dice there"
1f09a2e [Holden Karau] try and hide threshold but it changes the API so no dice there
efb9084 [Holden Karau] move setThresholds only to where it's used
6b34809 [Holden Karau] Add a test with thresholding for the RFCS
74f54c3 [Holden Karau] Fix creation of vote array
1986fa8 [Holden Karau] Setting the thresholds only makes sense if the underlying class hasn't overridden predict, so let's push it down.
2f44b18 [Holden Karau] Add a global default of null for thresholds param
f338cfc [Holden Karau] Wait that wasn't a good idea, Revert "Some progress towards unifying threshold and thresholds"
634b06f [Holden Karau] Some progress towards unifying threshold and thresholds
85c9e01 [Holden Karau] Test passes again... little fnur
099c0f3 [Holden Karau] Move thresholds around some more (set on model not trainer)
0f46836 [Holden Karau] Start adding a classifiersuite
f70eb5e [Holden Karau] Fix test compile issues
a7d59c8 [Holden Karau] Move thresholding into Classifier trait
5d999d2 [Holden Karau] Some more progress, start adding a test (maybe try and see if we can find a better thing to use for the base of the test)
1fed644 [Holden Karau] Use thresholds to scale scores in random forest classification
31d6bf2 [Holden Karau] Start threading the threshold info through
0ef228c [Holden Karau] Add HasThresholds
---
 .../examples/ml/JavaSimpleParamsExample.java  |  3 +-
 .../main/python/ml/simple_params_example.py   |  2 +-
 .../examples/ml/SimpleParamsExample.scala     |  2 +-
 .../spark/ml/classification/Classifier.scala  |  3 +-
 .../classification/LogisticRegression.scala   | 47 ++++++++++--
 .../ProbabilisticClassifier.scala             | 41 +++++++++--
 .../ml/param/shared/SharedParamsCodeGen.scala | 19 ++++-
 .../spark/ml/param/shared/sharedParams.scala  | 17 ++++-
 .../org/apache/spark/ml/tree/treeParams.scala |  3 +-
 .../JavaLogisticRegressionSuite.java          |  9 ++-
 .../LogisticRegressionSuite.scala             | 28 +++++++-
 .../ml/classification/OneVsRestSuite.scala    |  2 +-
 .../ProbabilisticClassifierSuite.scala        | 57 +++++++++++++++
 python/pyspark/ml/classification.py           | 72 +++++++++++++++----
 14 files changed, 265 insertions(+), 40 deletions(-)
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala

diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index dac649d1d5..94beeced3d 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -77,7 +77,8 @@ public class JavaSimpleParamsExample {
     ParamMap paramMap = new ParamMap();
     paramMap.put(lr.maxIter().w(20)); // Specify 1 Param.
     paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter.
-    paramMap.put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params.
+    double[] thresholds = {0.45, 0.55};
+    paramMap.put(lr.regParam().w(0.1), lr.thresholds().w(thresholds)); // Specify multiple Params.
 
     // One can also combine ParamMaps.
     ParamMap paramMap2 = new ParamMap();
diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py
index a9f29dab2d..2d6d115d54 100644
--- a/examples/src/main/python/ml/simple_params_example.py
+++ b/examples/src/main/python/ml/simple_params_example.py
@@ -70,7 +70,7 @@ if __name__ == "__main__":
 
     # We may alternatively specify parameters using a parameter map.
     # paramMap overrides all lr parameters set earlier.
-    paramMap = {lr.maxIter: 20, lr.threshold: 0.55, lr.probabilityCol: "myProbability"}
+    paramMap = {lr.maxIter: 20, lr.thresholds: [0.45, 0.55], lr.probabilityCol: "myProbability"}
 
     # Now learn a new model using the new parameters.
     model2 = lr.fit(training, paramMap)
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index 58d7b67674..f4d1fe5785 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -70,7 +70,7 @@ object SimpleParamsExample {
     // which supports several methods for specifying parameters.
     val paramMap = ParamMap(lr.maxIter -> 20)
     paramMap.put(lr.maxIter, 30) // Specify 1 Param.  This overwrites the original maxIter.
-    paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params.
+    paramMap.put(lr.regParam -> 0.1, lr.thresholds -> Array(0.45, 0.55)) // Specify multiple Params.
 
     // One can also combine ParamMaps.
     val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 581d8fa774..45df557a89 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -18,14 +18,13 @@
 package org.apache.spark.ml.classification
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor}
 import org.apache.spark.ml.param.shared.HasRawPredictionCol
 import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+import org.apache.spark.sql.types.{DataType, StructType}
 
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 8fc9199fb4..c937b9602b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -41,7 +41,39 @@ import org.apache.spark.storage.StorageLevel
  */
 private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
   with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
-  with HasThreshold with HasStandardization
+  with HasStandardization {
+
+  /**
+   * Version of setThresholds() for binary classification, available for backwards
+   * compatibility.
+   *
+   * Calling this with threshold p will effectively call `setThresholds(Array(1-p, p))`.
+   *
+   * Default is effectively 0.5.
+   * @group setParam
+   */
+  def setThreshold(value: Double): this.type = set(thresholds, Array(1.0 - value, value))
+
+  /**
+   * Version of [[getThresholds()]] for binary classification, available for backwards
+   * compatibility.
+   *
+   * Param thresholds must have length 2 (or not be specified).
+   * This returns {{{1 / (1 + thresholds(0) / thresholds(1))}}}.
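+   * Equivalently thresholds(1) / (thresholds(0) + thresholds(1)), which inverts
+   * [[setThreshold()]]: for thresholds = Array(1-p, p) it returns p.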
+   * @group getParam
+   */
+  def getThreshold: Double = {
+    if (isDefined(thresholds)) {
+      val thresholdValues = $(thresholds)
+      assert(thresholdValues.length == 2, "Logistic Regression getThreshold only applies to" +
+        " binary classification, but thresholds has length != 2." +
+        s"  thresholds: ${thresholdValues.mkString(",")}")
+      1.0 / (1.0 + thresholdValues(0) / thresholdValues(1))
+    } else {
+      0.5
+    }
+  }
+}
 
 /**
  * :: Experimental ::
@@ -110,9 +142,9 @@ class LogisticRegression(override val uid: String)
   def setStandardization(value: Boolean): this.type = set(standardization, value)
   setDefault(standardization -> true)
 
-  /** @group setParam */
-  def setThreshold(value: Double): this.type = set(threshold, value)
-  setDefault(threshold -> 0.5)
+  override def setThreshold(value: Double): this.type = super.setThreshold(value)
+
+  override def getThreshold: Double = super.getThreshold
 
   override protected def train(dataset: DataFrame): LogisticRegressionModel = {
     // Extract columns from data.  If dataset is persisted, do not persist oldDataset.
@@ -270,8 +302,9 @@ class LogisticRegressionModel private[ml] (
   extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
   with LogisticRegressionParams {
 
-  /** @group setParam */
-  def setThreshold(value: Double): this.type = set(threshold, value)
+  override def setThreshold(value: Double): this.type = super.setThreshold(value)
+
+  override def getThreshold: Double = super.getThreshold
 
   /** Margin (rawPrediction) for class label 1.  For binary classification only. */
   private val margin: Vector => Double = (features) => {
@@ -288,7 +321,7 @@ class LogisticRegressionModel private[ml] (
 
   /**
    * Predict label for the given feature vector.
-   * The behavior of this can be adjusted using [[threshold]].
+   * The behavior of this can be adjusted using [[thresholds]].
    */
   override protected def predict(features: Vector): Double = {
     if (score(features) > getThreshold) 1 else 0
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index f9c9c2371f..1e50a895a9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -20,17 +20,16 @@ package org.apache.spark.ml.classification
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
-import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT}
+import org.apache.spark.mllib.linalg.{DenseVector, Vector, VectorUDT, Vectors}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, DataType, StructType}
+import org.apache.spark.sql.types.{DataType, StructType}
 
 /**
  * (private[classification])  Params for probabilistic classification.
  */
 private[classification] trait ProbabilisticClassifierParams
-  extends ClassifierParams with HasProbabilityCol {
-
+  extends ClassifierParams with HasProbabilityCol with HasThresholds {
   override protected def validateAndTransformSchema(
       schema: StructType,
       fitting: Boolean,
@@ -59,6 +58,9 @@ private[spark] abstract class ProbabilisticClassifier[
 
   /** @group setParam */
   def setProbabilityCol(value: String): E = set(probabilityCol, value).asInstanceOf[E]
+
+  /** @group setParam */
+  def setThresholds(value: Array[Double]): E = set(thresholds, value).asInstanceOf[E]
 }
 
 
@@ -80,6 +82,9 @@ private[spark] abstract class ProbabilisticClassificationModel[
   /** @group setParam */
   def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M]
 
+  /** @group setParam */
+  def setThresholds(value: Array[Double]): M = set(thresholds, value).asInstanceOf[M]
+
   /**
    * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
    * parameters:
@@ -92,6 +97,11 @@ private[spark] abstract class ProbabilisticClassificationModel[
    */
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
+    if (isDefined(thresholds)) {
+      require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+        ".transform() called with non-matching numClasses and thresholds.length." +
+        s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+    }
 
     // Output selected columns only.
     // This is a bit complicated since it tries to avoid repeated computation.
@@ -155,6 +165,14 @@ private[spark] abstract class ProbabilisticClassificationModel[
     raw2probabilityInPlace(probs)
   }
 
+  override protected def raw2prediction(rawPrediction: Vector): Double = {
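+    // With no thresholds set, the argmax over raw scores suffices; otherwise convert
+    // to probabilities first, since the per-class thresholds are defined on that scale.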
+    if (!isDefined(thresholds)) {
+      rawPrediction.argmax
+    } else {
+      probability2prediction(raw2probability(rawPrediction))
+    }
+  }
+
   /**
    * Predict the probability of each class given the features.
    * These predictions are also called class conditional probabilities.
@@ -170,10 +188,21 @@ private[spark] abstract class ProbabilisticClassificationModel[
 
   /**
    * Given a vector of class conditional probabilities, select the predicted label.
-   * This may be overridden to support thresholds which favor particular labels.
+   * This supports thresholds which favor particular labels.
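+   * For example, with thresholds = Array(0.5, 0.2) and probabilities [0.6, 0.4], the
+   * scaled values are [0.6/0.5, 0.4/0.2] = [1.2, 2.0], so class 1 is predicted.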
    * @return  predicted label
    */
-  protected def probability2prediction(probability: Vector): Double = probability.argmax
+  protected def probability2prediction(probability: Vector): Double = {
+    if (!isDefined(thresholds)) {
+      probability.argmax
+    } else {
+      val thresholds: Array[Double] = getThresholds
+      val scaledProbability: Array[Double] =
+        probability.toArray.zip(thresholds).map { case (p, t) =>
+          if (t == 0.0) Double.PositiveInfinity else p / t
+        }
+      Vectors.dense(scaledProbability).argmax
+    }
+  }
 }
 
 private[ml] object ProbabilisticClassificationModel {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index f7ae1de522..a97c8059b8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -46,7 +46,13 @@ private[shared] object SharedParamsCodeGen {
         Some("\"probability\"")),
       ParamDesc[Double]("threshold",
         "threshold in binary classification prediction, in range [0, 1]",
-        isValid = "ParamValidators.inRange(0, 1)"),
+        isValid = "ParamValidators.inRange(0, 1)", finalMethods = false),
+      ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class classification" +
+        " to adjust the probability of predicting each class." +
+        " Array must have length equal to the number of classes, with values >= 0." +
+        " The class with largest value p/t is predicted, where p is the original probability" +
+        " of that class and t is the class' threshold.",
+        isValid = "(t: Array[Double]) => t.forall(_ >= 0)"),
       ParamDesc[String]("inputCol", "input column name"),
       ParamDesc[Array[String]]("inputCols", "input column names"),
       ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),
@@ -74,7 +80,8 @@ private[shared] object SharedParamsCodeGen {
       name: String,
       doc: String,
       defaultValueStr: Option[String] = None,
-      isValid: String = "") {
+      isValid: String = "",
+      finalMethods: Boolean = true) {
 
     require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.")
     require(doc.nonEmpty) // TODO: more rigorous on doc
@@ -88,6 +95,7 @@ private[shared] object SharedParamsCodeGen {
         case _ if c == classOf[Double] => "DoubleParam"
         case _ if c == classOf[Boolean] => "BooleanParam"
         case _ if c.isArray && c.getComponentType == classOf[String] => s"StringArrayParam"
+        case _ if c.isArray && c.getComponentType == classOf[Double] => s"DoubleArrayParam"
         case _ => s"Param[${getTypeString(c)}]"
       }
     }
@@ -131,6 +139,11 @@ private[shared] object SharedParamsCodeGen {
     } else {
       ""
     }
+    val methodStr = if (param.finalMethods) {
+      "final def"
+    } else {
+      "def"
+    }
 
     s"""
       |/**
@@ -145,7 +158,7 @@ private[shared] object SharedParamsCodeGen {
       |  final val $name: $Param = new $Param(this, "$name", "$doc"$isValid)
       |$setDefault
       |  /** @group getParam */
-      |  final def get$Name: $T = $$($name)
+      |  $methodStr get$Name: $T = $$($name)
       |}
       |""".stripMargin
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 65e48e4ee5..f332630c32 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -150,7 +150,22 @@ private[ml] trait HasThreshold extends Params {
   final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1))
 
   /** @group getParam */
-  final def getThreshold: Double = $(threshold)
+  def getThreshold: Double = $(threshold)
+}
+
+/**
+ * Trait for shared param thresholds.
+ */
+private[ml] trait HasThresholds extends Params {
+
+  /**
+   * Param for thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.
+   * @group param
+   */
+  final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", (t: Array[Double]) => t.forall(_ >= 0))
+
+  /** @group getParam */
+  final def getThresholds: Array[Double] = $(thresholds)
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index a0c5238d96..e817090f8a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.ml.tree
 
+import org.apache.spark.ml.classification.ClassifierParams
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed}
+import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasThresholds}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
 import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index f75e024a71..fb1de51163 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -87,6 +87,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
     LogisticRegression parent = (LogisticRegression) model.parent();
     assert(parent.getMaxIter() == 10);
     assert(parent.getRegParam() == 1.0);
+    assert(parent.getThresholds()[0] == 0.4);
+    assert(parent.getThresholds()[1] == 0.6);
     assert(parent.getThreshold() == 0.6);
     assert(model.getThreshold() == 0.6);
 
@@ -98,7 +100,9 @@ public class JavaLogisticRegressionSuite implements Serializable {
       assert(r.getDouble(0) == 0.0);
     }
     // Call transform with params, and check that the params worked.
-    model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
+    double[] thresholds = {1.0, 0.0};
+    model.transform(
+      dataset, model.thresholds().w(thresholds), model.probabilityCol().w("myProb"))
       .registerTempTable("predNotAllZero");
     DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
     boolean foundNonZero = false;
@@ -108,8 +112,9 @@ public class JavaLogisticRegressionSuite implements Serializable {
     assert(foundNonZero);
 
     // Call fit() with new params, and check as many params as we can.
+    double[] thresholds2 = {0.6, 0.4};
     LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
-        lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
+        lr.thresholds().w(thresholds2), lr.probabilityCol().w("theProb"));
     LogisticRegression parent2 = (LogisticRegression) model2.parent();
     assert(parent2.getMaxIter() == 5);
     assert(parent2.getRegParam() == 0.1);
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index b7dd447538..da13dcb42d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -91,6 +91,28 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(model.hasParent)
   }
 
+  test("setThreshold, getThreshold") {
+    val lr = new LogisticRegression
+    // default
+    withClue("LogisticRegression should not have thresholds set by default") {
+      intercept[java.util.NoSuchElementException] {
+        lr.getThresholds
+      }
+    }
+    // Set via thresholds.
+    // Intuition: Large threshold or large thresholds(1) makes class 0 more likely.
+    lr.setThreshold(1.0)
+    assert(lr.getThresholds === Array(0.0, 1.0))
+    lr.setThreshold(0.0)
+    assert(lr.getThresholds === Array(1.0, 0.0))
+    lr.setThreshold(0.5)
+    assert(lr.getThresholds === Array(0.5, 0.5))
+    // Test getThreshold
+    lr.setThresholds(Array(0.3, 0.7))
+    val expectedThreshold = 1.0 / (1.0 + 0.3 / 0.7)
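+    // = 0.7: getThreshold returns thresholds(1) / (thresholds(0) + thresholds(1)).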
+    assert(lr.getThreshold ~== expectedThreshold relTol 1E-7)
+  }
+
   test("logistic regression doesn't fit intercept when fitIntercept is off") {
     val lr = new LogisticRegression
     lr.setFitIntercept(false)
@@ -123,14 +145,16 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
       s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.")
     // Call transform with params, and check that the params worked.
     val predNotAllZero =
-      model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb")
+      model.transform(dataset, model.thresholds -> Array(1.0, 0.0),
+        model.probabilityCol -> "myProb")
         .select("prediction", "myProb")
         .collect()
         .map { case Row(pred: Double, prob: Vector) => pred }
     assert(predNotAllZero.exists(_ !== 0.0))
 
     // Call fit() with new params, and check as many params as we can.
-    val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4,
+    val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1,
+      lr.thresholds -> Array(0.6, 0.4),
       lr.probabilityCol -> "theProb")
     val parent2 = model2.parent.asInstanceOf[LogisticRegression]
     assert(parent2.getMaxIter === 5)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 3775292f6d..bd8e819f69 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -151,7 +151,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
     require(ovr1.getClassifier.getOrDefault(lr.maxIter) === 10,
       "copy should handle extra classifier params")
 
-    val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.threshold -> 0.1))
+    val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.thresholds -> Array(0.9, 0.1)))
     ovrModel.models.foreach { case m: LogisticRegressionModel =>
       require(m.getThreshold === 0.1, "copy should handle extra model params")
     }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
new file mode 100644
index 0000000000..8f50cb924e
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+
+final class TestProbabilisticClassificationModel(
+    override val uid: String,
+    override val numClasses: Int)
+  extends ProbabilisticClassificationModel[Vector, TestProbabilisticClassificationModel] {
+
+  override def copy(extra: ParamMap): this.type = defaultCopy(extra)
+
+  override protected def predictRaw(input: Vector): Vector = {
+    input
+  }
+
+  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+    rawPrediction
+  }
+
+  def friendlyPredict(input: Vector): Double = {
+    predict(input)
+  }
+}
+
+
+class ProbabilisticClassifierSuite extends SparkFunSuite {
+
+  test("test thresholding") {
+    val thresholds = Array(0.5, 0.2)
+    val testModel = new TestProbabilisticClassificationModel("myuid", 2).setThresholds(thresholds)
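+    // predictRaw and raw2probabilityInPlace are identity maps here, so prediction is the
+    // argmax of p/t: [1.0/0.5, 1.0/0.2] favors class 1, [1.0/0.5, 0.2/0.2] favors class 0.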
+    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0)
+    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0)
+  }
+
+  test("test thresholding not required") {
+    val testModel = new TestProbabilisticClassificationModel("myuid", 2)
+    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0)
+  }
+}
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index b5814f76de..291320f881 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -69,17 +69,25 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
               "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " +
               "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
     fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.")
-    threshold = Param(Params._dummy(), "threshold",
-                      "threshold in binary classification prediction, in range [0, 1].")
+    thresholds = Param(Params._dummy(), "thresholds",
+                       "Thresholds in multi-class classification" +
+                       " to adjust the probability of predicting each class." +
+                       " Array must have length equal to the number of classes, with values >= 0." +
+                       " The class with largest value p/t is predicted, where p is the original" +
+                       " probability of that class and t is the class' threshold.")
 
     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
-                 threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction"):
+                 threshold=None, thresholds=None,
+                 probabilityCol="probability", rawPredictionCol="rawPrediction"):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
-                 threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction")
+                 threshold=None, thresholds=None, \
+                 probabilityCol="probability", rawPredictionCol="rawPrediction")
+        Param thresholds overrides Param threshold; threshold is provided
+        for backwards compatibility and only applies to binary classification.
         """
         super(LogisticRegression, self).__init__()
         self._java_obj = self._new_java_obj(
@@ -93,23 +101,35 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
         #: param for whether to fit an intercept term.
         self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.")
         #: param for threshold in binary classification prediction, in range [0, 1].
-        self.threshold = Param(self, "threshold",
-                               "threshold in binary classification prediction, in range [0, 1].")
+        self.thresholds = \
+            Param(self, "thresholds",
+                  "Thresholds in multi-class classification" +
+                  " to adjust the probability of predicting each class." +
+                  " Array must have length equal to the number of classes, with values >= 0." +
+                  " The class with largest value p/t is predicted, where p is the original" +
+                  " probability of that class and t is the class' threshold.")
         self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6,
-                         fitIntercept=True, threshold=0.5)
+                         fitIntercept=True)
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                   maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
-                  threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction"):
+                  threshold=None, thresholds=None,
+                  probabilityCol="probability", rawPredictionCol="rawPrediction"):
         """
         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
-                 threshold=0.5, probabilityCol="probability", rawPredictionCol="rawPrediction")
+                  threshold=None, thresholds=None, \
+                  probabilityCol="probability", rawPredictionCol="rawPrediction")
         Sets params for logistic regression.
+        Param thresholds overrides Param threshold; threshold is provided
+        for backwards compatibility and only applies to binary classification.
         """
+        kwargs = self.setParams._input_kwargs
+        # Under the hood we only keep thresholds, so translate threshold into
+        # thresholds when given and drop it (this class has no threshold Param).
+        if thresholds is None and threshold is not None:
+            kwargs["thresholds"] = [1 - threshold, threshold]
+        kwargs.pop("threshold", None)
-        kwargs = self.setParams._input_kwargs
         return self._set(**kwargs)
 
@@ -144,16 +164,44 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
 
     def setThreshold(self, value):
         """
-        Sets the value of :py:attr:`threshold`.
+        Sets the value of :py:attr:`thresholds` using [1-value, value].
+
+        >>> lr = LogisticRegression()
+        >>> lr.getThreshold()
+        0.5
+        >>> lr.setThreshold(0.6)
+        LogisticRegression_...
+        >>> abs(lr.getThreshold() - 0.6) < 1e-5
+        True
+        """
+        return self.setThresholds([1-value, value])
+
+    def setThresholds(self, value):
+        """
+        Sets the value of :py:attr:`thresholds`.
         """
-        self._paramMap[self.threshold] = value
+        self._paramMap[self.thresholds] = value
         return self
 
+    def getThresholds(self):
+        """
+        Gets the value of thresholds or its default value.
+        """
+        return self.getOrDefault(self.thresholds)
+
     def getThreshold(self):
         """
         Gets the value of threshold or its default value.
         """
-        return self.getOrDefault(self.threshold)
+        if self.isDefined(self.thresholds):
+            thresholds = self.getOrDefault(self.thresholds)
+            if len(thresholds) != 2:
+                raise ValueError("Logistic Regression getThreshold only applies to" +
+                                 " binary classification, but thresholds has length != 2." +
+                                 "  thresholds: " + ",".join(map(str, thresholds)))
+            # 1/(1 + t0/t1) inverts setThreshold: for thresholds [1-p, p] this returns p.
+            return 1.0 / (1.0 + thresholds[0] / thresholds[1])
+        else:
+            return 0.5
 
 
 class LogisticRegressionModel(JavaModel):
-- 
GitLab