Skip to content
Snippets Groups Projects
Commit e71cd96b authored by Holden Karau's avatar Holden Karau Committed by Nick Pentreath
Browse files

[SPARK-15316][PYSPARK][ML] Add linkPredictionCol to GeneralizedLinearRegression

## What changes were proposed in this pull request?

Add linkPredictionCol to GeneralizedLinearRegression and fix the PyDoc to generate the bullet list

## How was this patch tested?

doctests & built docs locally

Author: Holden Karau <holden@us.ibm.com>

Closes #13106 from holdenk/SPARK-15316-add-linkPredictionCol-toGeneralizedLinearRegression.
parent f5065abf
No related branches found
No related tags found
No related merge requests found
......@@ -1245,10 +1245,14 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
predictor (link function) and a description of the error distribution (family). It supports
"gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family
is listed below. The first link function of each family is the default one.
- "gaussian" -> "identity", "log", "inverse"
- "binomial" -> "logit", "probit", "cloglog"
- "poisson" -> "log", "identity", "sqrt"
- "gamma" -> "inverse", "identity", "log"
* "gaussian" -> "identity", "log", "inverse"
* "binomial" -> "logit", "probit", "cloglog"
* "poisson" -> "log", "identity", "sqrt"
* "gamma" -> "inverse", "identity", "log"
.. seealso:: `GLM <https://en.wikipedia.org/wiki/Generalized_linear_model>`_
......@@ -1258,9 +1262,12 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
... (1.0, Vectors.dense(1.0, 2.0)),
... (2.0, Vectors.dense(0.0, 0.0)),
... (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"])
>>> glr = GeneralizedLinearRegression(family="gaussian", link="identity")
>>> glr = GeneralizedLinearRegression(family="gaussian", link="identity", linkPredictionCol="p")
>>> model = glr.fit(df)
>>> abs(model.transform(df).head().prediction - 1.5) < 0.001
>>> transformed = model.transform(df)
>>> abs(transformed.head().prediction - 1.5) < 0.001
True
>>> abs(transformed.head().p - 1.5) < 0.001
True
>>> model.coefficients
DenseVector([1.5..., -1.0...])
......@@ -1290,20 +1297,23 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
"relationship between the linear predictor and the mean of the distribution " +
"function. Supported options: identity, log, inverse, logit, probit, cloglog " +
"and sqrt.", typeConverter=TypeConverters.toString)
linkPredictionCol = Param(Params._dummy(), "linkPredictionCol", "link prediction (linear " +
"predictor) column name", typeConverter=TypeConverters.toString)
@keyword_only
def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
regParam=0.0, weightCol=None, solver="irls"):
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=""):
"""
__init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
regParam=0.0, weightCol=None, solver="irls")
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol="")
"""
super(GeneralizedLinearRegression, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid)
self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls")
self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls",
linkPredictionCol="")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
......@@ -1311,11 +1321,11 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
@since("2.0.0")
def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction",
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
regParam=0.0, weightCol=None, solver="irls"):
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=""):
"""
setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
regParam=0.0, weightCol=None, solver="irls")
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol="")
Sets params for generalized linear regression.
"""
kwargs = self.setParams._input_kwargs
......@@ -1338,6 +1348,20 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
"""
return self.getOrDefault(self.family)
@since("2.0.0")
def setLinkPredictionCol(self, value):
"""
Sets the value of :py:attr:`linkPredictionCol`.
"""
return self._set(linkPredictionCol=value)
@since("2.0.0")
def getLinkPredictionCol(self):
"""
Gets the value of linkPredictionCol or its default value.
"""
return self.getOrDefault(self.linkPredictionCol)
@since("2.0.0")
def setLink(self, value):
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment