Commit b11887c0 authored by sethah, committed by Joseph K. Bradley

[SPARK-14264][PYSPARK][ML] Add feature importance for GBTs in pyspark

## What changes were proposed in this pull request?

Feature importances are now exposed in the Python API for GBTs.

Other changes:
* Update the random forest feature importance documentation so it references the decision tree docstring instead of repeating it.

## How was this patch tested?

Python doc tests were updated to validate GBT feature importance.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #12056 from sethah/Pyspark_GBT_feature_importance.
parent e7854028
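For reference, a minimal sketch of the new API in use, mirroring the doctest added in the diff below (assumes a live `sqlContext` as the doctests do; the toy one-feature data is illustrative):

```python
from pyspark.mllib.linalg import Vectors
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.feature import StringIndexer

# Toy one-feature dataset, as in the GBTClassifier doctest.
df = sqlContext.createDataFrame([
    (1.0, Vectors.dense(1.0)),
    (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
si_model = StringIndexer(inputCol="label", outputCol="indexed").fit(df)
td = si_model.transform(df)

model = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42).fit(td)

# New in this patch: per-feature importance estimates for GBT models.
print(model.featureImportances)  # SparseVector(1, {0: 1.0}) on this toy data
```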
python/pyspark/ml/classification.py

@@ -396,7 +396,7 @@ class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLR
           - Normalize importances for tree to sum to 1.
 
         Note: Feature importance for single decision trees can have high variance due to
-              correlated predictor variables. Consider using a :class:`RandomForestClassifier`
+              correlated predictor variables. Consider using a :py:class:`RandomForestClassifier`
               to determine feature importance instead.
         """
         return self._call_java("featureImportances")
@@ -500,16 +500,12 @@ class RandomForestClassificationModel(TreeEnsembleModels):
         """
         Estimate of the importance of each feature.
 
-        This generalizes the idea of "Gini" importance to other losses,
-        following the explanation of Gini importance from "Random Forests" documentation
-        by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
-
-        This feature importance is calculated as follows:
-         - Average over trees:
-            - importance(feature j) = sum (over nodes which split on feature j) of the gain,
-              where gain is scaled by the number of instances passing through node
-            - Normalize importances for tree to sum to 1.
-         - Normalize feature importance vector to sum to 1.
+        Each feature's importance is the average of its importance across all trees in the ensemble
+        The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+        (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+        and follows the implementation from scikit-learn.
+
+        .. seealso:: :py:attr:`DecisionTreeClassificationModel.featureImportances`
         """
         return self._call_java("featureImportances")
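The rewritten docstring above compresses the old bullet list; the computation it describes is, roughly, the following (an illustrative NumPy sketch of the averaging scheme, not Spark's actual Scala implementation):

```python
import numpy as np

def ensemble_feature_importances(per_tree_importances):
    """Average per-tree importance vectors (each already normalized to
    sum to 1), then renormalize so the ensemble vector also sums to 1."""
    avg = np.mean(np.asarray(per_tree_importances), axis=0)
    total = avg.sum()
    return avg / total if total > 0 else avg

# Two toy trees over three features:
ensemble_feature_importances([[0.7, 0.3, 0.0], [0.5, 0.1, 0.4]])
# -> array([0.6, 0.2, 0.2])
```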
@@ -534,6 +530,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
     >>> td = si_model.transform(df)
     >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42)
     >>> model = gbt.fit(td)
+    >>> model.featureImportances
+    SparseVector(1, {0: 1.0})
     >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
     True
     >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -613,6 +611,21 @@ class GBTClassificationModel(TreeEnsembleModels):
     .. versionadded:: 1.4.0
     """
 
+    @property
+    @since("2.0.0")
+    def featureImportances(self):
+        """
+        Estimate of the importance of each feature.
+
+        Each feature's importance is the average of its importance across all trees in the ensemble
+        The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+        (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+        and follows the implementation from scikit-learn.
+
+        .. seealso:: :py:attr:`DecisionTreeClassificationModel.featureImportances`
+        """
+        return self._call_java("featureImportances")
+
 
 @inherit_doc
 class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
python/pyspark/ml/regression.py

@@ -533,7 +533,7 @@ class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReada
           - Normalize importances for tree to sum to 1.
 
         Note: Feature importance for single decision trees can have high variance due to
-              correlated predictor variables. Consider using a :class:`RandomForestRegressor`
+              correlated predictor variables. Consider using a :py:class:`RandomForestRegressor`
               to determine feature importance instead.
         """
         return self._call_java("featureImportances")
@@ -626,16 +626,12 @@ class RandomForestRegressionModel(TreeEnsembleModels):
         """
         Estimate of the importance of each feature.
 
-        This generalizes the idea of "Gini" importance to other losses,
-        following the explanation of Gini importance from "Random Forests" documentation
-        by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
-
-        This feature importance is calculated as follows:
-         - Average over trees:
-            - importance(feature j) = sum (over nodes which split on feature j) of the gain,
-              where gain is scaled by the number of instances passing through node
-            - Normalize importances for tree to sum to 1.
-         - Normalize feature importance vector to sum to 1.
+        Each feature's importance is the average of its importance across all trees in the ensemble
+        The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+        (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+        and follows the implementation from scikit-learn.
+
+        .. seealso:: :py:attr:`DecisionTreeRegressionModel.featureImportances`
         """
         return self._call_java("featureImportances")
@@ -655,6 +651,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
     ...     (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
     >>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42)
     >>> model = gbt.fit(df)
+    >>> model.featureImportances
+    SparseVector(1, {0: 1.0})
     >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
     True
     >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -734,6 +732,21 @@ class GBTRegressionModel(TreeEnsembleModels):
     .. versionadded:: 1.4.0
     """
 
+    @property
+    @since("2.0.0")
+    def featureImportances(self):
+        """
+        Estimate of the importance of each feature.
+
+        Each feature's importance is the average of its importance across all trees in the ensemble
+        The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+        (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+        and follows the implementation from scikit-learn.
+
+        .. seealso:: :py:attr:`DecisionTreeRegressionModel.featureImportances`
+        """
+        return self._call_java("featureImportances")
+
 
 @inherit_doc
 class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
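One practical note on consuming the new property: `featureImportances` returns a Vector indexed by feature position. A small hypothetical helper to pair it with column names (`importances_by_name` is not part of this patch):

```python
def importances_by_name(model, feature_names):
    """Map feature names to their importance weights.
    Works for any model exposing featureImportances as a Vector."""
    vec = model.featureImportances
    return dict(zip(feature_names, vec.toArray().tolist()))

# e.g. with the one-feature doctest model:
# importances_by_name(model, ["f0"])  ->  {'f0': 1.0}
```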