From 72353311d3a37cb523c5bdd8072ffdff99af9749 Mon Sep 17 00:00:00 2001
From: Holden Karau <holden@us.ibm.com>
Date: Thu, 2 Jun 2016 15:55:14 -0700
Subject: [PATCH] [SPARK-15092][SPARK-15139][PYSPARK][ML] Pyspark TreeEnsemble
 missing methods

## What changes were proposed in this pull request?

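Add `trees`, `getNumTrees`, `totalNumNodes` and `toDebugString` to `TreeEnsembleModels` (with typed `trees` overrides on the random forest and GBT classification/regression models), and add `toDebugString` to `DecisionTreeModel`. Also extend the `featureSubsetStrategy` param description to mention the supported numeric ranges.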

## How was this patch tested?

Extended doc tests.

Author: Holden Karau <holden@us.ibm.com>

Closes #12919 from holdenk/SPARK-15139-pyspark-treeEnsemble-missing-methods.
---
 python/pyspark/ml/classification.py | 20 ++++++++++++
 python/pyspark/ml/regression.py     | 48 ++++++++++++++++++++++++++++-
 2 files changed, 67 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index ea660d7808..177cf9d72c 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -512,6 +512,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
     1
     >>> model.featureImportances
     SparseVector(1, {0: 1.0})
+    >>> print(model.toDebugString)
+    DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes...
     >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> result = model.transform(test0).head()
     >>> result.prediction
@@ -650,6 +652,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
     >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
     >>> model.transform(test1).head().prediction
     1.0
+    >>> model.trees
+    [DecisionTreeClassificationModel (uid=...) of depth..., DecisionTreeClassificationModel...]
     >>> rfc_path = temp_path + "/rfc"
     >>> rf.save(rfc_path)
     >>> rf2 = RandomForestClassifier.load(rfc_path)
@@ -730,6 +734,12 @@ class RandomForestClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaML
         """
         return self._call_java("featureImportances")
 
+    @property
+    @since("2.0.0")
+    def trees(self):
+        """Trees in this ensemble. Warning: These have null parent Estimators."""
+        return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]
+
 
 @inherit_doc
 class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
@@ -772,6 +782,10 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
     >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
     >>> model.transform(test1).head().prediction
     1.0
+    >>> model.totalNumNodes
+    15
+    >>> print(model.toDebugString)
+    GBTClassificationModel (uid=...)...with 5 trees...
     >>> gbtc_path = temp_path + "gbtc"
     >>> gbt.save(gbtc_path)
     >>> gbt2 = GBTClassifier.load(gbtc_path)
@@ -869,6 +883,12 @@ class GBTClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable)
         """
         return self._call_java("featureImportances")
 
+    @property
+    @since("2.0.0")
+    def trees(self):
+        """Trees in this ensemble. Warning: These have null parent Estimators."""
+        return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
+
 
 @inherit_doc
 class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 1b7af7ef59..7c79ab73c7 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -593,7 +593,7 @@ class RandomForestParams(TreeEnsembleParams):
     featureSubsetStrategy = \
         Param(Params._dummy(), "featureSubsetStrategy",
               "The number of features to consider for splits at each tree node. Supported " +
-              "options: " + ", ".join(supportedFeatureSubsetStrategies),
+              "options: " + ", ".join(supportedFeatureSubsetStrategies) + ", (0.0-1.0], [1-n].",
               typeConverter=TypeConverters.toString)
 
     def __init__(self):
@@ -744,6 +744,12 @@ class DecisionTreeModel(JavaModel):
         """Return depth of the decision tree."""
         return self._call_java("depth")
 
+    @property
+    @since("2.0.0")
+    def toDebugString(self):
+        """Full description of model."""
+        return self._call_java("toDebugString")
+
     def __repr__(self):
         return self._call_java("toString")
 
@@ -758,12 +764,36 @@ class TreeEnsembleModels(JavaModel):
     .. versionadded:: 1.5.0
     """
 
+    @property
+    @since("2.0.0")
+    def trees(self):
+        """Trees in this ensemble. Warning: These have null parent Estimators."""
+        return [DecisionTreeModel(m) for m in list(self._call_java("trees"))]
+
+    @property
+    @since("2.0.0")
+    def getNumTrees(self):
+        """Number of trees in ensemble."""
+        return self._call_java("getNumTrees")
+
     @property
     @since("1.5.0")
     def treeWeights(self):
         """Return the weights for each tree"""
         return list(self._call_java("javaTreeWeights"))
 
+    @property
+    @since("2.0.0")
+    def totalNumNodes(self):
+        """Total number of nodes, summed over all trees in the ensemble."""
+        return self._call_java("totalNumNodes")
+
+    @property
+    @since("2.0.0")
+    def toDebugString(self):
+        """Full description of model."""
+        return self._call_java("toDebugString")
+
     def __repr__(self):
         return self._call_java("toString")
 
@@ -825,6 +855,10 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
     >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> model.transform(test0).head().prediction
     0.0
+    >>> model.trees
+    [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
+    >>> model.getNumTrees
+    2
     >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
     >>> model.transform(test1).head().prediction
     0.5
@@ -896,6 +930,12 @@ class RandomForestRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLRead
     .. versionadded:: 1.4.0
     """
 
+    @property
+    @since("2.0.0")
+    def trees(self):
+        """Trees in this ensemble. Warning: These have null parent Estimators."""
+        return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
+
     @property
     @since("2.0.0")
     def featureImportances(self):
@@ -1045,6 +1085,12 @@ class GBTRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
         """
         return self._call_java("featureImportances")
 
+    @property
+    @since("2.0.0")
+    def trees(self):
+        """Trees in this ensemble. Warning: These have null parent Estimators."""
+        return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
+
 
 @inherit_doc
 class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
-- 
GitLab