From fc3cd2f5090b3ba1cfde0fca3b3ce632d0b2f9c4 Mon Sep 17 00:00:00 2001
From: Bryan Cutler <cutlerb@gmail.com>
Date: Wed, 13 Apr 2016 14:08:57 -0700
Subject: [PATCH] [SPARK-14472][PYSPARK][ML] Cleanup ML JavaWrapper and related
 class hierarchy

Currently, JavaWrapper is only a wrapper class for pipeline classes that have
Params, and JavaCallable is a separate mixin that provides methods to make Java
calls. This change simplifies the class structure by defining the Java wrapper
in a plain base class along with the methods to make Java calls. It also renames
the Java wrapper classes to better reflect their purpose.

Ran the existing Python ML tests and generated the documentation to verify this
change.

Author: Bryan Cutler <cutlerb@gmail.com>

Closes #12304 from BryanCutler/pyspark-cleanup-JavaWrapper-SPARK-14472.
---
 python/pyspark/ml/classification.py |  4 +-
 python/pyspark/ml/evaluation.py     |  4 +-
 python/pyspark/ml/pipeline.py       | 10 ++--
 python/pyspark/ml/regression.py     |  4 +-
 python/pyspark/ml/tests.py          |  4 +-
 python/pyspark/ml/tuning.py         | 26 +++++-----
 python/pyspark/ml/util.py           |  4 +-
 python/pyspark/ml/wrapper.py        | 76 +++++++++++++----------------
 8 files changed, 62 insertions(+), 70 deletions(-)

diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index e64c7a392b..922f8069fa 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -19,7 +19,7 @@ import warnings
 
 from pyspark import since
 from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
 from pyspark.ml.param import TypeConverters
 from pyspark.ml.param.shared import *
 from pyspark.ml.regression import (
@@ -272,7 +272,7 @@ class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
         return BinaryLogisticRegressionSummary(java_blr_summary)
 
 
-class LogisticRegressionSummary(JavaCallable):
+class LogisticRegressionSummary(JavaWrapper):
     """
     Abstraction for Logistic Regression Results for a given model.
 
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index c9b95b3bf4..4b0bade102 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -18,7 +18,7 @@
 from abc import abstractmethod, ABCMeta
 
 from pyspark import since
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
 from pyspark.ml.param import Param, Params
 from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol
 from pyspark.ml.util import keyword_only
@@ -81,7 +81,7 @@ class Evaluator(Params):
 
 
 @inherit_doc
-class JavaEvaluator(Evaluator, JavaWrapper):
+class JavaEvaluator(JavaParams, Evaluator):
     """
     Base class for :py:class:`Evaluator`s that wrap Java/Scala
     implementations.
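To make the new split concrete, here is a shape-only sketch of the hierarchy this
patch ends up with. The classes below are stand-ins rather than the real pyspark
ones (Params and Evaluator are reduced to empty bases, and _call_java omits the
_py2java/_java2py conversions the real helper performs):

    class Params(object):            # stand-in for pyspark.ml.param.Params
        pass

    class Evaluator(Params):         # stand-in for pyspark.ml.evaluation.Evaluator
        pass

    class JavaWrapper(object):
        """Plain holder for a Java companion object plus the call helper."""
        def __init__(self, java_obj=None):
            self._java_obj = java_obj

        def _call_java(self, name, *args):
            return getattr(self._java_obj, name)(*args)

    class JavaParams(JavaWrapper, Params):
        """Params-aware wrapper used by pipeline components."""

    class JavaEvaluator(JavaParams, Evaluator):
        """Evaluators keep their Params support through JavaParams."""

    class LogisticRegressionSummary(JavaWrapper):
        """Summaries carry no Params, so the plain wrapper is enough."""

The rename keeps the name JavaWrapper for the simplest layer, which is why
LogisticRegressionSummary's base changes from JavaCallable to JavaWrapper rather
than to a new name.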
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 2b5504bc29..9d654e8b0f 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -25,7 +25,7 @@ from pyspark import since
 from pyspark.ml import Estimator, Model, Transformer
 from pyspark.ml.param import Param, Params
 from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
 from pyspark.mllib.common import inherit_doc
 
 
@@ -177,7 +177,7 @@ class Pipeline(Estimator, MLReadable, MLWritable):
         # Create a new instance of this stage.
         py_stage = cls()
         # Load information from java_stage to the instance.
-        py_stages = [JavaWrapper._from_java(s) for s in java_stage.getStages()]
+        py_stages = [JavaParams._from_java(s) for s in java_stage.getStages()]
         py_stage.setStages(py_stages)
         py_stage._resetUid(java_stage.uid())
         return py_stage
@@ -195,7 +195,7 @@ class Pipeline(Estimator, MLReadable, MLWritable):
         for idx, stage in enumerate(self.getStages()):
             java_stages[idx] = stage._to_java()
 
-        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
+        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
         _java_obj.setStages(java_stages)
 
         return _java_obj
@@ -275,7 +275,7 @@ class PipelineModel(Model, MLReadable, MLWritable):
         Used for ML persistence.
         """
         # Load information from java_stage to the instance.
-        py_stages = [JavaWrapper._from_java(s) for s in java_stage.stages()]
+        py_stages = [JavaParams._from_java(s) for s in java_stage.stages()]
         # Create a new instance of this stage.
         py_stage = cls(py_stages)
         py_stage._resetUid(java_stage.uid())
@@ -295,6 +295,6 @@ class PipelineModel(Model, MLReadable, MLWritable):
             java_stages[idx] = stage._to_java()
 
         _java_obj =\
-            JavaWrapper._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
+            JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
 
         return _java_obj
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index bc88f88b7f..316d7e30bc 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -20,7 +20,7 @@ import warnings
 from pyspark import since
 from pyspark.ml.param.shared import *
 from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
 from pyspark.mllib.common import inherit_doc
 from pyspark.sql import DataFrame
 
@@ -188,7 +188,7 @@ class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
         return LinearRegressionSummary(java_lr_summary)
 
 
-class LinearRegressionSummary(JavaCallable):
+class LinearRegressionSummary(JavaWrapper):
     """
     .. note:: Experimental
 
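These pipeline hunks are the main call sites of the renamed helpers: saving builds
the JVM twin of each stage through JavaParams._new_java_obj and stage._to_java(),
while loading maps each Java stage back through JavaParams._from_java. A hedged
usage sketch of the round trip (an active SparkContext, a DataFrame df, an
assembled Pipeline pipeline, and a writable path are assumed; none of these names
come from the diff):

    from pyspark.ml import PipelineModel

    model = pipeline.fit(df)                             # a PipelineModel backed by Java stages
    model.save("/tmp/pipeline_model")                    # each stage converted via stage._to_java()
    loaded = PipelineModel.load("/tmp/pipeline_model")   # stages rebuilt via JavaParams._from_java(...)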
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 2dcd5eeb52..bcbeacbe80 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -52,7 +52,7 @@ from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor
 from pyspark.ml.tuning import *
 from pyspark.ml.util import keyword_only
 from pyspark.ml.util import MLWritable, MLWriter
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
 from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector
 from pyspark.sql import DataFrame, SQLContext, Row
 from pyspark.sql.functions import rand
@@ -644,7 +644,7 @@ class PersistenceTest(PySparkTestCase):
         """
         self.assertEqual(m1.uid, m2.uid)
         self.assertEqual(type(m1), type(m2))
-        if isinstance(m1, JavaWrapper):
+        if isinstance(m1, JavaParams):
             self.assertEqual(len(m1.params), len(m2.params))
             for p in m1.params:
                 self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p))
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index ea8c61b7ef..456d79d897 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -24,7 +24,7 @@ from pyspark.ml import Estimator, Model
 from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasSeed
 from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
 from pyspark.sql.functions import rand
 from pyspark.mllib.common import inherit_doc, _py2java
 
@@ -148,8 +148,8 @@ class ValidatorParams(HasSeed):
 
         """
         # Load information from java_stage to the instance.
-        estimator = JavaWrapper._from_java(java_stage.getEstimator())
-        evaluator = JavaWrapper._from_java(java_stage.getEvaluator())
+        estimator = JavaParams._from_java(java_stage.getEstimator())
+        evaluator = JavaParams._from_java(java_stage.getEvaluator())
         epms = [estimator._transfer_param_map_from_java(epm)
                 for epm in java_stage.getEstimatorParamMaps()]
         return estimator, epms, evaluator
@@ -329,7 +329,7 @@ class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable):
 
         estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl()
 
-        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
+        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
         _java_obj.setEstimatorParamMaps(epms)
         _java_obj.setEvaluator(evaluator)
         _java_obj.setEstimator(estimator)
@@ -393,7 +393,7 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
 
         """
         # Load information from java_stage to the instance.
-        bestModel = JavaWrapper._from_java(java_stage.bestModel())
+        bestModel = JavaParams._from_java(java_stage.bestModel())
         estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
         # Create a new instance of this stage.
         py_stage = cls(bestModel=bestModel)\
@@ -410,10 +410,10 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
 
         sc = SparkContext._active_spark_context
 
-        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
-                                              self.uid,
-                                              self.bestModel._to_java(),
-                                              _py2java(sc, []))
+        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
+                                             self.uid,
+                                             self.bestModel._to_java(),
+                                             _py2java(sc, []))
         estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()
 
         _java_obj.set("evaluator", evaluator)
@@ -574,8 +574,8 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable):
 
         estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl()
 
-        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
-                                              self.uid)
+        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
+                                             self.uid)
         _java_obj.setEstimatorParamMaps(epms)
         _java_obj.setEvaluator(evaluator)
         _java_obj.setEstimator(estimator)
@@ -639,7 +639,7 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
 
         """
         # Load information from java_stage to the instance.
-        bestModel = JavaWrapper._from_java(java_stage.bestModel())
+        bestModel = JavaParams._from_java(java_stage.bestModel())
         estimator, epms, evaluator = \
             super(TrainValidationSplitModel, cls)._from_java_impl(java_stage)
         # Create a new instance of this stage.
@@ -657,7 +657,7 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
 
         sc = SparkContext._active_spark_context
 
-        _java_obj = JavaWrapper._new_java_obj(
+        _java_obj = JavaParams._new_java_obj(
             "org.apache.spark.ml.tuning.TrainValidationSplitModel",
             self.uid,
             self.bestModel._to_java(),
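All of these _to_java implementations lean on one small helper. The following
condenses JavaParams._new_java_obj into a standalone sketch (its body appears as
context in the wrapper.py hunks further below); the class name and uid are only
examples, and an active SparkContext is assumed:

    from pyspark import SparkContext
    from pyspark.ml.util import _jvm           # imported the same way wrapper.py does
    from pyspark.mllib.common import _py2java

    sc = SparkContext._active_spark_context
    java_obj = _jvm()                                        # py4j view of the JVM
    for name in "org.apache.spark.ml.tuning.CrossValidator".split("."):
        java_obj = getattr(java_obj, name)                   # walk the package path to the class
    java_args = [_py2java(sc, arg) for arg in ("cv_example_uid",)]
    java_cv = java_obj(*java_args)                           # invoke the Scala constructor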
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index d4411fdfb9..9dfcef0e40 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -99,7 +99,7 @@ class MLWriter(object):
 @inherit_doc
 class JavaMLWriter(MLWriter):
     """
-    (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaWrapper` types
+    (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaParams` types
     """
 
     def __init__(self, instance):
@@ -178,7 +178,7 @@ class MLReader(object):
 @inherit_doc
 class JavaMLReader(MLReader):
     """
-    (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaWrapper` types
+    (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types
     """
 
     def __init__(self, clazz):
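The two docstring updates above only track the rename, but they also state the
contract: the Java writer/reader pair is meant for JavaParams types, because it
delegates persistence to the wrapped Java object. A simplified, illustrative
sketch of that delegation (the real JavaMLWriter in util.py keeps the Scala
writer as instance state; this standalone function is not Spark code):

    from pyspark.ml.wrapper import JavaParams

    def save_java_ml_instance(instance, path):
        """Persist a Java-backed ML stage by delegating to Scala's MLWriter."""
        if not isinstance(instance, JavaParams):
            raise TypeError("persistence requires a Java-backed (JavaParams) stage")
        java_obj = instance._to_java()       # sync Python params onto the JVM twin and return it
        java_obj.write().save(path)          # Scala MLWritable.write().save(path)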
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index bbeb6cfe6f..cd0e5b80d5 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -25,29 +25,32 @@ from pyspark.ml.util import _jvm
 from pyspark.mllib.common import inherit_doc, _java2py, _py2java
 
 
-@inherit_doc
-class JavaWrapper(Params):
+class JavaWrapper(object):
     """
-    Utility class to help create wrapper classes from Java/Scala
-    implementations of pipeline components.
+    Wrapper class for a Java companion object
     """
+    def __init__(self, java_obj=None):
+        super(JavaWrapper, self).__init__()
+        self._java_obj = java_obj
 
-    __metaclass__ = ABCMeta
-
-    def __init__(self):
+    @classmethod
+    def _create_from_java_class(cls, java_class, *args):
         """
-        Initialize the wrapped java object to None
+        Construct this object from given Java classname and arguments
         """
-        super(JavaWrapper, self).__init__()
-        #: The wrapped Java companion object. Subclasses should initialize
-        #: it properly. The param values in the Java object should be
-        #: synced with the Python wrapper in fit/transform/evaluate/copy.
-        self._java_obj = None
+        java_obj = JavaWrapper._new_java_obj(java_class, *args)
+        return cls(java_obj)
+
+    def _call_java(self, name, *args):
+        m = getattr(self._java_obj, name)
+        sc = SparkContext._active_spark_context
+        java_args = [_py2java(sc, arg) for arg in args]
+        return _java2py(sc, m(*java_args))
 
     @staticmethod
     def _new_java_obj(java_class, *args):
         """
-        Construct a new Java object.
+        Returns a new Java object.
         """
         sc = SparkContext._active_spark_context
         java_obj = _jvm()
@@ -56,6 +59,18 @@ class JavaWrapper(Params):
         java_args = [_py2java(sc, arg) for arg in args]
         return java_obj(*java_args)
 
+
+@inherit_doc
+class JavaParams(JavaWrapper, Params):
+    """
+    Utility class to help create wrapper classes from Java/Scala
+    implementations of pipeline components.
+    """
+    #: The param values in the Java object should be
+    #: synced with the Python wrapper in fit/transform/evaluate/copy.
+
+    __metaclass__ = ABCMeta
+
     def _make_java_param_pair(self, param, value):
         """
         Makes a Java parm pair.
@@ -151,7 +166,7 @@ class JavaWrapper(Params):
         stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark")
         # Generate a default new instance from the stage_name class.
         py_type = __get_class(stage_name)
-        if issubclass(py_type, JavaWrapper):
+        if issubclass(py_type, JavaParams):
             # Load information from java_stage to the instance.
             py_stage = py_type()
             py_stage._java_obj = java_stage
@@ -166,7 +181,7 @@ class JavaWrapper(Params):
 
 
 @inherit_doc
-class JavaEstimator(Estimator, JavaWrapper):
+class JavaEstimator(JavaParams, Estimator):
     """
     Base class for :py:class:`Estimator`s that wrap
     Java/Scala implementations.
@@ -199,7 +214,7 @@ class JavaEstimator(Estimator, JavaWrapper):
 
 
 @inherit_doc
-class JavaTransformer(Transformer, JavaWrapper):
+class JavaTransformer(JavaParams, Transformer):
     """
     Base class for :py:class:`Transformer`s that wrap Java/Scala
     implementations. Subclasses should ensure they have the transformer Java object
@@ -213,30 +228,8 @@ class JavaTransformer(Transformer, JavaWrapper):
         return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx)
 
 
-class JavaCallable(object):
-    """
-    Wrapper for a plain object in JVM to make Java calls, can be used
-    as a mixin to another class that defines a _java_obj wrapper
-    """
-    def __init__(self, java_obj=None, sc=None):
-        super(JavaCallable, self).__init__()
-        self._sc = sc if sc is not None else SparkContext._active_spark_context
-        # if this class is a mixin and _java_obj is already defined then don't initialize
-        if java_obj is not None or not hasattr(self, "_java_obj"):
-            self._java_obj = java_obj
-
-    def __del__(self):
-        if self._java_obj is not None:
-            self._sc._gateway.detach(self._java_obj)
-
-    def _call_java(self, name, *args):
-        m = getattr(self._java_obj, name)
-        java_args = [_py2java(self._sc, arg) for arg in args]
-        return _java2py(self._sc, m(*java_args))
-
-
 @inherit_doc
-class JavaModel(Model, JavaCallable, JavaTransformer):
+class JavaModel(JavaTransformer, Model):
     """
     Base class for :py:class:`Model`s that wrap Java/Scala
     implementations. Subclasses should inherit this class before
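All three pipeline bases flip their parent order so the Java-aware class comes
first (JavaEstimator(JavaParams, Estimator), JavaTransformer(JavaParams,
Transformer), JavaModel(JavaTransformer, Model)). A toy illustration of what that
ordering buys when cooperative super() calls walk the MRO (stand-in classes, not
Spark code):

    class Params(object):                       # stand-in for pyspark.ml.param.Params
        def __init__(self):
            super(Params, self).__init__()
            self.params_initialized = True

    class JavaWrapper(object):                  # stand-in for the plain wrapper above
        def __init__(self, java_obj=None):
            super(JavaWrapper, self).__init__()
            self._java_obj = java_obj

    class JavaParams(JavaWrapper, Params):      # Java side first, Params second
        pass

    print([c.__name__ for c in JavaParams.__mro__])
    # -> ['JavaParams', 'JavaWrapper', 'Params', 'object']
    jp = JavaParams("a jvm handle")
    assert jp._java_obj == "a jvm handle" and jp.params_initialized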
""" - super(JavaModel, self).__init__() + super(JavaModel, self).__init__(java_model) if java_model is not None: - self._java_obj = java_model self.uid = java_model.uid() def copy(self, extra=None): -- GitLab