From 7d432af8f3c47973550ea253dae0c23cd2961bde Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=A2=9C=E5=8F=91=E6=89=8D=EF=BC=88Yan=20Facai=EF=BC=89?=
 <facai.yan@gmail.com>
Date: Tue, 28 Mar 2017 16:14:01 -0700
Subject: [PATCH] [SPARK-20043][ML] DecisionTreeModel: ImpurityCalculator
 builder fails for uppercase impurity type Gini
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Fix bug: DecisionTreeModel can't recongnize Impurity "Gini" when loading

TODO:
+ [x] add unit test
+ [x] fix the bug

Author: 颜发才(Yan Facai) <facai.yan@gmail.com>

Closes #17407 from facaiy/BUG/decision_tree_loader_failer_with_Gini_impurity.
---
 .../spark/mllib/tree/impurity/Impurity.scala       |  2 +-
 .../DecisionTreeClassifierSuite.scala              | 14 ++++++++++++++
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index a5bdc2c6d2..98a3021461 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -184,7 +184,7 @@ private[spark] object ImpurityCalculator {
    * the given stats.
    */
   def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = {
-    impurity match {
+    impurity.toLowerCase match {
       case "gini" => new GiniCalculator(stats)
       case "entropy" => new EntropyCalculator(stats)
       case "variance" => new VarianceCalculator(stats)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 10de50306a..964fcfbdd8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -385,6 +385,20 @@ class DecisionTreeClassifierSuite
     testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0),
       allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
   }
+
+  test("SPARK-20043: " +
+       "ImpurityCalculator builder fails for uppercase impurity type Gini in model read/write") {
+    val rdd = TreeTests.getTreeReadWriteData(sc)
+    val data: DataFrame =
+      TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
+
+    val dt = new DecisionTreeClassifier()
+      .setImpurity("Gini")
+      .setMaxDepth(2)
+    val model = dt.fit(data)
+
+    testDefaultReadWrite(model)
+  }
 }
 
 private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {
-- 
GitLab