From 9f8ce4825e378b6a856ce65cb9986a5a0f0b624e Mon Sep 17 00:00:00 2001
From: Xin Ren <iamshrek@126.com>
Date: Sun, 12 Mar 2017 12:15:19 -0700
Subject: [PATCH] [SPARK-19282][ML][SPARKR] RandomForest Wrapper and GBT
 Wrapper return param "maxDepth" to R models

## What changes were proposed in this pull request?

RandomForest R Wrapper and GBT R Wrapper return param `maxDepth` to R models.

Below 4 R wrappers are changed:
* `RandomForestClassificationWrapper`
* `RandomForestRegressionWrapper`
* `GBTClassificationWrapper`
* `GBTRegressionWrapper`

## How was this patch tested?

Test manually on my local machine.

Author: Xin Ren <iamshrek@126.com>

Closes #17207 from keypointt/SPARK-19282.
---
 R/pkg/R/mllib_tree.R                                  | 11 +++++++----
 R/pkg/inst/tests/testthat/test_mllib_tree.R           | 10 ++++++++++
 .../apache/spark/ml/r/GBTClassificationWrapper.scala  |  1 +
 .../org/apache/spark/ml/r/GBTRegressionWrapper.scala  |  1 +
 .../ml/r/RandomForestClassificationWrapper.scala      |  1 +
 .../spark/ml/r/RandomForestRegressionWrapper.scala    |  1 +
 6 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R
index 40a806c41b..82279be6fb 100644
--- a/R/pkg/R/mllib_tree.R
+++ b/R/pkg/R/mllib_tree.R
@@ -52,12 +52,14 @@ summary.treeEnsemble <- function(model) {
   numFeatures <- callJMethod(jobj, "numFeatures")
   features <-  callJMethod(jobj, "features")
   featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString")
+  maxDepth <- callJMethod(jobj, "maxDepth")
   numTrees <- callJMethod(jobj, "numTrees")
   treeWeights <- callJMethod(jobj, "treeWeights")
   list(formula = formula,
        numFeatures = numFeatures,
        features = features,
        featureImportances = featureImportances,
+       maxDepth = maxDepth,
        numTrees = numTrees,
        treeWeights = treeWeights,
        jobj = jobj)
@@ -70,6 +72,7 @@ print.summary.treeEnsemble <- function(x) {
   cat("\nNumber of features: ", x$numFeatures)
   cat("\nFeatures: ", unlist(x$features))
   cat("\nFeature importances: ", x$featureImportances)
+  cat("\nMax Depth: ", x$maxDepth)
   cat("\nNumber of trees: ", x$numTrees)
   cat("\nTree weights: ", unlist(x$treeWeights))
 
@@ -197,8 +200,8 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"),
 #' @return \code{summary} returns summary information of the fitted model, which is a list.
 #'         The list of components includes \code{formula} (formula),
 #'         \code{numFeatures} (number of features), \code{features} (list of features),
-#'         \code{featureImportances} (feature importances), \code{numTrees} (number of trees),
-#'         and \code{treeWeights} (tree weights).
+#'         \code{featureImportances} (feature importances), \code{maxDepth} (max depth of trees),
+#'         \code{numTrees} (number of trees), and \code{treeWeights} (tree weights).
 #' @rdname spark.gbt
 #' @aliases summary,GBTRegressionModel-method
 #' @export
@@ -403,8 +406,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
 #' @return \code{summary} returns summary information of the fitted model, which is a list.
 #'         The list of components includes \code{formula} (formula),
 #'         \code{numFeatures} (number of features), \code{features} (list of features),
-#'         \code{featureImportances} (feature importances), \code{numTrees} (number of trees),
-#'         and \code{treeWeights} (tree weights).
+#'         \code{featureImportances} (feature importances), \code{maxDepth} (max depth of trees),
+#'         \code{numTrees} (number of trees), and \code{treeWeights} (tree weights).
 #' @rdname spark.randomForest
 #' @aliases summary,RandomForestRegressionModel-method
 #' @export
diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/inst/tests/testthat/test_mllib_tree.R
index e6fda251eb..e0802a9b02 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_tree.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_tree.R
@@ -39,6 +39,7 @@ test_that("spark.gbt", {
                tolerance = 1e-4)
   stats <- summary(model)
   expect_equal(stats$numTrees, 20)
+  expect_equal(stats$maxDepth, 5)
   expect_equal(stats$formula, "Employed ~ .")
   expect_equal(stats$numFeatures, 6)
   expect_equal(length(stats$treeWeights), 20)
@@ -53,6 +54,7 @@ test_that("spark.gbt", {
   expect_equal(stats$numFeatures, stats2$numFeatures)
   expect_equal(stats$features, stats2$features)
   expect_equal(stats$featureImportances, stats2$featureImportances)
+  expect_equal(stats$maxDepth, stats2$maxDepth)
   expect_equal(stats$numTrees, stats2$numTrees)
   expect_equal(stats$treeWeights, stats2$treeWeights)
 
@@ -66,6 +68,7 @@ test_that("spark.gbt", {
   stats <- summary(model)
   expect_equal(stats$numFeatures, 2)
   expect_equal(stats$numTrees, 20)
+  expect_equal(stats$maxDepth, 5)
   expect_error(capture.output(stats), NA)
   expect_true(length(capture.output(stats)) > 6)
   predictions <- collect(predict(model, data))$prediction
@@ -93,6 +96,7 @@ test_that("spark.gbt", {
   expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction))
   expect_equal(s$numFeatures, 5)
   expect_equal(s$numTrees, 20)
+  expect_equal(stats$maxDepth, 5)
 
   # spark.gbt classification can work on libsvm data
   data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"),
@@ -116,6 +120,7 @@ test_that("spark.randomForest", {
 
   stats <- summary(model)
   expect_equal(stats$numTrees, 1)
+  expect_equal(stats$maxDepth, 5)
   expect_error(capture.output(stats), NA)
   expect_true(length(capture.output(stats)) > 6)
 
@@ -129,6 +134,7 @@ test_that("spark.randomForest", {
                tolerance = 1e-4)
   stats <- summary(model)
   expect_equal(stats$numTrees, 20)
+  expect_equal(stats$maxDepth, 5)
 
   modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp")
   write.ml(model, modelPath)
@@ -141,6 +147,7 @@ test_that("spark.randomForest", {
   expect_equal(stats$features, stats2$features)
   expect_equal(stats$featureImportances, stats2$featureImportances)
   expect_equal(stats$numTrees, stats2$numTrees)
+  expect_equal(stats$maxDepth, stats2$maxDepth)
   expect_equal(stats$treeWeights, stats2$treeWeights)
 
   unlink(modelPath)
@@ -153,6 +160,7 @@ test_that("spark.randomForest", {
   stats <- summary(model)
   expect_equal(stats$numFeatures, 2)
   expect_equal(stats$numTrees, 20)
+  expect_equal(stats$maxDepth, 5)
   expect_error(capture.output(stats), NA)
   expect_true(length(capture.output(stats)) > 6)
   # Test string prediction values
@@ -187,6 +195,8 @@ test_that("spark.randomForest", {
   stats <- summary(model)
   expect_equal(stats$numFeatures, 2)
   expect_equal(stats$numTrees, 20)
+  expect_equal(stats$maxDepth, 5)
+
   # Test numeric prediction values
   predictions <- collect(predict(model, data))$prediction
   expect_equal(length(grep("1.0", predictions)), 50)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
index aacb41ee26..c07eadb30a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala
@@ -44,6 +44,7 @@ private[r] class GBTClassifierWrapper private (
   lazy val featureImportances: Vector = gbtcModel.featureImportances
   lazy val numTrees: Int = gbtcModel.getNumTrees
   lazy val treeWeights: Array[Double] = gbtcModel.treeWeights
+  lazy val maxDepth: Int = gbtcModel.getMaxDepth
 
   def summary: String = gbtcModel.toDebugString
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala
index 585077588e..b568d78592 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala
@@ -42,6 +42,7 @@ private[r] class GBTRegressorWrapper private (
   lazy val featureImportances: Vector = gbtrModel.featureImportances
   lazy val numTrees: Int = gbtrModel.getNumTrees
   lazy val treeWeights: Array[Double] = gbtrModel.treeWeights
+  lazy val maxDepth: Int = gbtrModel.getMaxDepth
 
   def summary: String = gbtrModel.toDebugString
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
index 366f375b58..8a83d4e980 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
@@ -44,6 +44,7 @@ private[r] class RandomForestClassifierWrapper private (
   lazy val featureImportances: Vector = rfcModel.featureImportances
   lazy val numTrees: Int = rfcModel.getNumTrees
   lazy val treeWeights: Array[Double] = rfcModel.treeWeights
+  lazy val maxDepth: Int = rfcModel.getMaxDepth
 
   def summary: String = rfcModel.toDebugString
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
index 4b9a3a731d..038bd79c70 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala
@@ -42,6 +42,7 @@ private[r] class RandomForestRegressorWrapper private (
   lazy val featureImportances: Vector = rfrModel.featureImportances
   lazy val numTrees: Int = rfrModel.getNumTrees
   lazy val treeWeights: Array[Double] = rfrModel.treeWeights
+  lazy val maxDepth: Int = rfrModel.getMaxDepth
 
   def summary: String = rfrModel.toDebugString
 
-- 
GitLab