From 8d1dec4fa4798bb48b8947446d306ec9ba6bddb5 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley" <joseph.kurata.bradley@gmail.com>
Date: Thu, 7 Aug 2014 00:20:38 -0700
Subject: [PATCH] [mllib] DecisionTree Strategy parameter checks

Added some checks to Strategy to print out meaningful error messages when given invalid DecisionTree parameters.
CC mengxr

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes #1821 from jkbradley/dt-robustness and squashes the following commits:

4dc449a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-robustness
7a61f7b [Joseph K. Bradley] Added some checks to Strategy to print out meaningful error messages when given invalid DecisionTree parameters
---
 .../spark/mllib/tree/DecisionTree.scala       | 10 ++++--
 .../mllib/tree/configuration/Strategy.scala   | 31 ++++++++++++++++++-
 2 files changed, 38 insertions(+), 3 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index c8a8656596..bb50f07be5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -44,6 +44,8 @@ import org.apache.spark.util.random.XORShiftRandom
 @Experimental
 class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
 
+  strategy.assertValid()
+
   /**
    * Method to train a decision tree model over an RDD
    * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
@@ -1465,10 +1467,14 @@ object DecisionTree extends Serializable with Logging {
 
 
     /*
-     * Ensure #bins is always greater than the categories. For multiclass classification,
-     * #bins should be greater than 2^(maxCategories - 1) - 1.
+     * Ensure numBins is always greater than the categories. For multiclass classification,
+     * numBins should be greater than 2^(maxCategories - 1) - 1.
      * It's a limitation of the current implementation but a reasonable trade-off since features
      * with large number of categories get favored over continuous features.
+     *
+     * This needs to be checked here instead of in Strategy since numBins can be determined
+     * by the number of training examples.
+     * TODO: Allow this case, where we simply will know nothing about some categories.
      */
     if (strategy.categoricalFeaturesInfo.size > 0) {
       val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 4ee4bcd0bc..f31a503608 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.configuration
 import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.tree.impurity.Impurity
+import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity}
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
 
@@ -90,4 +90,33 @@ class Strategy (
       categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
   }
 
+  private[tree] def assertValid(): Unit = {
+    algo match {
+      case Classification =>
+        require(numClassesForClassification >= 2,
+          s"DecisionTree Strategy for Classification must have numClassesForClassification >= 2," +
+          s" but numClassesForClassification = $numClassesForClassification.")
+        require(Set(Gini, Entropy).contains(impurity),
+          s"DecisionTree Strategy given invalid impurity for Classification: $impurity." +
+          s"  Valid settings: Gini, Entropy")
+      case Regression =>
+        require(impurity == Variance,
+          s"DecisionTree Strategy given invalid impurity for Regression: $impurity." +
+          s"  Valid settings: Variance")
+      case _ =>
+        throw new IllegalArgumentException(
+          s"DecisionTree Strategy given invalid algo parameter: $algo." +
+          s"  Valid settings are: Classification, Regression.")
+    }
+    require(maxDepth >= 0, s"DecisionTree Strategy given invalid maxDepth parameter: $maxDepth." +
+      s"  Valid values are integers >= 0.")
+    require(maxBins >= 2, s"DecisionTree Strategy given invalid maxBins parameter: $maxBins." +
+      s"  Valid values are integers >= 2.")
+    categoricalFeaturesInfo.foreach { case (feature, arity) =>
+      require(arity >= 2,
+        s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" +
+        s" feature $feature has $arity categories.  The number of categories should be >= 2.")
+    }
+  }
+
 }
-- 
GitLab