diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 1c18df3b27ab98b28ffc47e2886cf99f91981dfc..bc88f88b7f1e3d635decf121632ba7ca2489043c 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -28,6 +28,7 @@ from pyspark.sql import DataFrame
 __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
            'DecisionTreeRegressor', 'DecisionTreeRegressionModel',
            'GBTRegressor', 'GBTRegressionModel',
+           'GeneralizedLinearRegression', 'GeneralizedLinearRegressionModel',
            'IsotonicRegression', 'IsotonicRegressionModel',
            'LinearRegression', 'LinearRegressionModel',
            'LinearRegressionSummary', 'LinearRegressionTrainingSummary',
@@ -1197,6 +1198,150 @@ class AFTSurvivalRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
         return self._call_java("predict", features)
 
 
+@inherit_doc
+class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, HasPredictionCol,
+                                  HasFitIntercept, HasMaxIter, HasTol, HasRegParam, HasWeightCol,
+                                  HasSolver, JavaMLWritable, JavaMLReadable):
+    """
+    Generalized Linear Regression.
+
+    Fit a Generalized Linear Model specified by giving a symbolic description of the linear
+    predictor (link function) and a description of the error distribution (family). It supports
+    "gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family
+    are listed below. The first link function of each family is the default one.
+    - "gaussian" -> "identity", "log", "inverse"
+    - "binomial" -> "logit", "probit", "cloglog"
+    - "poisson"  -> "log", "identity", "sqrt"
+    - "gamma"    -> "inverse", "identity", "log"
+
+    .. seealso:: `GLM <https://en.wikipedia.org/wiki/Generalized_linear_model>`_
+
+    >>> from pyspark.mllib.linalg import Vectors
+    >>> df = sqlContext.createDataFrame([
+    ...     (1.0, Vectors.dense(0.0, 0.0)),
+    ...     (1.0, Vectors.dense(1.0, 2.0)),
+    ...     (2.0, Vectors.dense(0.0, 0.0)),
+    ...     (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"])
+    >>> glr = GeneralizedLinearRegression(family="gaussian", link="identity")
+    >>> model = glr.fit(df)
+    >>> abs(model.transform(df).head().prediction - 1.5) < 0.001
+    True
+    >>> model.coefficients
+    DenseVector([1.5..., -1.0...])
+    >>> abs(model.intercept - 1.5) < 0.001
+    True
+    >>> glr_path = temp_path + "/glr"
+    >>> glr.save(glr_path)
+    >>> glr2 = GeneralizedLinearRegression.load(glr_path)
+    >>> glr.getFamily() == glr2.getFamily()
+    True
+    >>> model_path = temp_path + "/glr_model"
+    >>> model.save(model_path)
+    >>> model2 = GeneralizedLinearRegressionModel.load(model_path)
+    >>> model.intercept == model2.intercept
+    True
+    >>> model.coefficients[0] == model2.coefficients[0]
+    True
+
+    .. versionadded:: 2.0.0
+    """
+
+    family = Param(Params._dummy(), "family", "The name of family which is a description of " +
+                   "the error distribution to be used in the model. Supported options: " +
+                   "gaussian(default), binomial, poisson and gamma.")
+    link = Param(Params._dummy(), "link", "The name of link function which provides the " +
+                 "relationship between the linear predictor and the mean of the distribution " +
+                 "function. Supported options: identity, log, inverse, logit, probit, cloglog " +
+                 "and sqrt.")
+
+    @keyword_only
+    def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
+                 family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
+                 regParam=0.0, weightCol=None, solver="irls"):
+        """
+        __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
+                 family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
+                 regParam=0.0, weightCol=None, solver="irls")
+        """
+        super(GeneralizedLinearRegression, self).__init__()
+        self._java_obj = self._new_java_obj(
+            "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid)
+        self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls")
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    @since("2.0.0")
+    def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction",
+                  family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
+                  regParam=0.0, weightCol=None, solver="irls"):
+        """
+        setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
+                  family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
+                  regParam=0.0, weightCol=None, solver="irls")
+        Sets params for generalized linear regression.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+    def _create_model(self, java_model):
+        return GeneralizedLinearRegressionModel(java_model)
+
+    @since("2.0.0")
+    def setFamily(self, value):
+        """
+        Sets the value of :py:attr:`family`.
+        """
+        self._paramMap[self.family] = value
+        return self
+
+    @since("2.0.0")
+    def getFamily(self):
+        """
+        Gets the value of family or its default value.
+        """
+        return self.getOrDefault(self.family)
+
+    @since("2.0.0")
+    def setLink(self, value):
+        """
+        Sets the value of :py:attr:`link`.
+        """
+        self._paramMap[self.link] = value
+        return self
+
+    @since("2.0.0")
+    def getLink(self):
+        """
+        Gets the value of link or its default value.
+        """
+        return self.getOrDefault(self.link)
+
+
+class GeneralizedLinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
+    """
+    Model fitted by GeneralizedLinearRegression.
+
+    .. versionadded:: 2.0.0
+    """
+
+    @property
+    @since("2.0.0")
+    def coefficients(self):
+        """
+        Model coefficients.
+        """
+        return self._call_java("coefficients")
+
+    @property
+    @since("2.0.0")
+    def intercept(self):
+        """
+        Model intercept.
+        """
+        return self._call_java("intercept")
+
+
 if __name__ == "__main__":
     import doctest
     import pyspark.ml.regression