From 2998e38a942351974da36cb619e863c6f0316e7a Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley" <joseph.kurata.bradley@gmail.com>
Date: Sun, 3 Aug 2014 10:36:52 -0700
Subject: [PATCH] [SPARK-2197] [mllib] Java DecisionTree bug fix and
 easy-of-use

Bug fix: Before, when an RDD was created in Java and passed to DecisionTree.train(), the fake class tag caused problems.
* Fix: DecisionTree: Used new RDD.retag() method to allow passing RDDs from Java.

Other improvements to Decision Trees for easy-of-use with Java:
* impurity classes: Added instance() methods to help with Java interface.
* Strategy: Added Java-friendly constructor
--> Note: I removed quantileCalculationStrategy from the Java-friendly constructor since (a) it is a special class and (b) there is only 1 option currently.  I suspect we will redo the API before the other options are included.

CC: mengxr

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes #1740 from jkbradley/dt-java-new and squashes the following commits:

0805dc6 [Joseph K. Bradley] Changed Strategy to use JavaConverters instead of JavaConversions
519b1b7 [Joseph K. Bradley] * Organized imports in JavaDecisionTreeSuite.java * Using JavaConverters instead of JavaConversions in DecisionTreeSuite.scala
f7b5ca1 [Joseph K. Bradley] Improvements to make it easier to run DecisionTree from Java. * DecisionTree: Used new RDD.retag() method to allow passing RDDs from Java. * impurity classes: Added instance() methods to help with Java interface. * Strategy: Added Java-friendly constructor ** Note: I removed quantileCalculationStrategy from the Java-friendly constructor since (a) it is a special class and (b) there is only 1 option currently.  I suspect we will redo the API before the other options are included.
d78ada6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-java
320853f [Joseph K. Bradley] Added JavaDecisionTreeSuite, partly written
13a585e [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-java
f1a8283 [Joseph K. Bradley] Added old JavaDecisionTreeSuite, to be updated later
225822f [Joseph K. Bradley] Bug: In DecisionTree, the method sequentialBinSearchForOrderedCategoricalFeatureInClassification() indexed bins from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). This upper bound is the bound for unordered categorical features, not ordered ones. The upper bound should be the arity (i.e., max value) of the feature.
---
 .../spark/mllib/tree/DecisionTree.scala       |   8 +-
 .../mllib/tree/configuration/Strategy.scala   |  29 +++++
 .../spark/mllib/tree/impurity/Entropy.scala   |   7 ++
 .../spark/mllib/tree/impurity/Gini.scala      |   7 ++
 .../spark/mllib/tree/impurity/Variance.scala  |   7 ++
 .../mllib/tree/JavaDecisionTreeSuite.java     | 102 ++++++++++++++++++
 .../spark/mllib/tree/DecisionTreeSuite.scala  |   6 ++
 7 files changed, 162 insertions(+), 4 deletions(-)
 create mode 100644 mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 382e76a9b7..1d03e6e3b3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -48,12 +48,12 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
   def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
 
     // Cache input RDD for speedup during multiple passes.
-    input.cache()
+    val retaggedInput = input.retag(classOf[LabeledPoint]).cache()
     logDebug("algo = " + strategy.algo)
 
     // Find the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
-    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy)
     val numBins = bins(0).length
     logDebug("numBins = " + numBins)
 
@@ -70,7 +70,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     // dummy value for top node (updated during first split calculation)
     val nodes = new Array[Node](maxNumNodes)
     // num features
-    val numFeatures = input.take(1)(0).features.size
+    val numFeatures = retaggedInput.take(1)(0).features.size
 
     // Calculate level for single group construction
 
@@ -107,7 +107,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
       logDebug("#####################################")
 
       // Find best split for all nodes at a level.
-      val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities,
+      val splitsStatsForLevel = DecisionTree.findBestSplits(retaggedInput, parentImpurities,
         strategy, level, filters, splits, bins, maxLevelForSingleGroup)
 
       for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index fdad4f029a..4ee4bcd0bc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.mllib.tree.configuration
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.tree.impurity.Impurity
 import org.apache.spark.mllib.tree.configuration.Algo._
@@ -61,4 +63,31 @@ class Strategy (
   val isMulticlassWithCategoricalFeatures
     = isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
 
+  /**
+   * Java-friendly constructor.
+   *
+   * @param algo classification or regression
+   * @param impurity criterion used for information gain calculation
+   * @param maxDepth Maximum depth of the tree.
+   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+   * @param numClassesForClassification number of classes for classification. Default value is 2
+   *                                    leads to binary classification
+   * @param maxBins maximum number of bins used for splitting features
+   * @param categoricalFeaturesInfo A map storing information about the categorical variables and
+   *                                the number of discrete values they take. For example, an entry
+   *                                (n -> k) implies the feature n is categorical with k categories
+   *                                0, 1, 2, ... , k-1. It's important to note that features are
+   *                                zero-indexed.
+   */
+  def this(
+      algo: Algo,
+      impurity: Impurity,
+      maxDepth: Int,
+      numClassesForClassification: Int,
+      maxBins: Int,
+      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]) {
+    this(algo, impurity, maxDepth, numClassesForClassification, maxBins, Sort,
+      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
+  }
+
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 9297c20596..96d2471e1f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -66,4 +66,11 @@ object Entropy extends Impurity {
   @DeveloperApi
   override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
     throw new UnsupportedOperationException("Entropy.calculate")
+
+  /**
+   * Get this impurity instance.
+   * This is useful for passing impurity parameters to a Strategy in Java.
+   */
+  def instance = this
+
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 2874bcf496..d586f44904 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -62,4 +62,11 @@ object Gini extends Impurity {
   @DeveloperApi
   override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
     throw new UnsupportedOperationException("Gini.calculate")
+
+  /**
+   * Get this impurity instance.
+   * This is useful for passing impurity parameters to a Strategy in Java.
+   */
+  def instance = this
+
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 698a1a2a8e..f7d99a40eb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -53,4 +53,11 @@ object Variance extends Impurity {
     val squaredLoss = sumSquares - (sum * sum) / count
     squaredLoss / count
   }
+
+  /**
+   * Get this impurity instance.
+   * This is useful for passing impurity parameters to a Strategy in Java.
+   */
+  def instance = this
+
 }
diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
new file mode 100644
index 0000000000..2c281a1ee7
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.tree.configuration.Algo;
+import org.apache.spark.mllib.tree.configuration.Strategy;
+import org.apache.spark.mllib.tree.impurity.Gini;
+import org.apache.spark.mllib.tree.model.DecisionTreeModel;
+
+
+public class JavaDecisionTreeSuite implements Serializable {
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaDecisionTreeSuite");
+  }
+
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null;
+  }
+
+  int validatePrediction(List<LabeledPoint> validationData, DecisionTreeModel model) {
+    int numCorrect = 0;
+    for (LabeledPoint point: validationData) {
+      Double prediction = model.predict(point.features());
+      if (prediction == point.label()) {
+        numCorrect++;
+      }
+    }
+    return numCorrect;
+  }
+
+  @Test
+  public void runDTUsingConstructor() {
+    List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
+    JavaRDD<LabeledPoint> rdd = sc.parallelize(arr);
+    HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
+    categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories
+
+    int maxDepth = 4;
+    int numClasses = 2;
+    int maxBins = 100;
+    Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
+        maxBins, categoricalFeaturesInfo);
+
+    DecisionTree learner = new DecisionTree(strategy);
+    DecisionTreeModel model = learner.train(rdd.rdd());
+
+    int numCorrect = validatePrediction(arr, model);
+    Assert.assertTrue(numCorrect == rdd.count());
+  }
+
+  @Test
+  public void runDTUsingStaticMethods() {
+    List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
+    JavaRDD<LabeledPoint> rdd = sc.parallelize(arr);
+    HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
+    categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories
+
+    int maxDepth = 4;
+    int numClasses = 2;
+    int maxBins = 100;
+    Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses,
+        maxBins, categoricalFeaturesInfo);
+
+    DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy);
+
+    int numCorrect = validatePrediction(arr, model);
+    Assert.assertTrue(numCorrect == rdd.count());
+  }
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 8665a00f3b..70ca7c8a26 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.mllib.tree
 
+import scala.collection.JavaConverters._
+
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
@@ -815,6 +817,10 @@ object DecisionTreeSuite {
     arr
   }
 
+  def generateCategoricalDataPointsAsJavaList(): java.util.List[LabeledPoint] = {
+    generateCategoricalDataPoints().toList.asJava
+  }
+
   def generateCategoricalDataPointsForMulticlass(): Array[LabeledPoint] = {
     val arr = new Array[LabeledPoint](3000)
     for (i <- 0 until 3000) {
-- 
GitLab