From e375456063617cd7000d796024f41e5927f21edd Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley" <joseph@databricks.com>
Date: Tue, 4 Aug 2015 14:54:26 -0700
Subject: [PATCH] [SPARK-9447] [ML] [PYTHON] Added HasRawPredictionCol,
 HasProbabilityCol to RandomForestClassifier

Added HasRawPredictionCol, HasProbabilityCol to RandomForestClassifier, plus doc tests for those columns.

CC: holdenk yanboliang

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #7903 from jkbradley/rf-prob-python and squashes the following commits:

c62a83f [Joseph K. Bradley] made unit test more robust
14eeba2 [Joseph K. Bradley] added HasRawPredictionCol, HasProbabilityCol to RandomForestClassifier in PySpark
---
 python/pyspark/ml/classification.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 291320f881..5978d8f4d3 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -347,6 +347,7 @@ class DecisionTreeClassificationModel(DecisionTreeModel):
 
 @inherit_doc
 class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
+                             HasRawPredictionCol, HasProbabilityCol,
                              DecisionTreeParams, HasCheckpointInterval):
     """
     `http://en.wikipedia.org/wiki/Random_forest  Random Forest`
@@ -354,6 +355,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
     It supports both binary and multiclass labels, as well as both continuous and categorical
     features.
 
+    >>> import numpy
     >>> from numpy import allclose
     >>> from pyspark.mllib.linalg import Vectors
     >>> from pyspark.ml.feature import StringIndexer
@@ -368,8 +370,13 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
     >>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
     True
     >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
-    >>> model.transform(test0).head().prediction
+    >>> result = model.transform(test0).head()
+    >>> result.prediction
     0.0
+    >>> numpy.argmax(result.probability)
+    0
+    >>> numpy.argmax(result.rawPrediction)
+    0
     >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
     >>> model.transform(test1).head().prediction
     1.0
@@ -390,11 +397,13 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
 
     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+                 probabilityCol="probability", rawPredictionCol="rawPrediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
                  numTrees=20, featureSubsetStrategy="auto", seed=None):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+                 probabilityCol="probability", rawPredictionCol="rawPrediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
                  numTrees=20, featureSubsetStrategy="auto", seed=None)
@@ -427,11 +436,13 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
 
     @keyword_only
     def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+                  probabilityCol="probability", rawPredictionCol="rawPrediction",
                   maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                   maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
                   impurity="gini", numTrees=20, featureSubsetStrategy="auto"):
         """
         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+                 probabilityCol="probability", rawPredictionCol="rawPrediction", \
                   maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                   maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \
                   impurity="gini", numTrees=20, featureSubsetStrategy="auto")
-- 
GitLab