diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index e64c7a392b93b9bb09ece9349cf17b79b44c47d2..922f8069fac49592f166c7738413ea8a0b962f04 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -19,7 +19,7 @@ import warnings
 
 from pyspark import since
 from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
 from pyspark.ml.param import TypeConverters
 from pyspark.ml.param.shared import *
 from pyspark.ml.regression import (
@@ -272,7 +272,7 @@ class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
         return BinaryLogisticRegressionSummary(java_blr_summary)
 
 
-class LogisticRegressionSummary(JavaCallable):
+class LogisticRegressionSummary(JavaWrapper):
     """
     Abstraction for Logistic Regression Results for a given model.
 
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index c9b95b3bf45d9085386d9846c0f12f9b795bbf35..4b0bade10280241178c8553a51f79bae9ff36767 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -18,7 +18,7 @@
 from abc import abstractmethod, ABCMeta
 
 from pyspark import since
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
 from pyspark.ml.param import Param, Params
 from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol
 from pyspark.ml.util import keyword_only
@@ -81,7 +81,7 @@ class Evaluator(Params):
 
 
 @inherit_doc
-class JavaEvaluator(Evaluator, JavaWrapper):
+class JavaEvaluator(JavaParams, Evaluator):
     """
     Base class for :py:class:`Evaluator`s that wrap Java/Scala
     implementations.
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 2b5504bc2966a81253aa0c0cd8bdb8257b188e9f..9d654e8b0f8d063cf9b1c8ba621c5ea587024313 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -25,7 +25,7 @@ from pyspark import since
 from pyspark.ml import Estimator, Model, Transformer
 from pyspark.ml.param import Param, Params
 from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
 from pyspark.mllib.common import inherit_doc
 
 
@@ -177,7 +177,7 @@ class Pipeline(Estimator, MLReadable, MLWritable):
         # Create a new instance of this stage.
         py_stage = cls()
         # Load information from java_stage to the instance.
-        py_stages = [JavaWrapper._from_java(s) for s in java_stage.getStages()]
+        py_stages = [JavaParams._from_java(s) for s in java_stage.getStages()]
         py_stage.setStages(py_stages)
         py_stage._resetUid(java_stage.uid())
         return py_stage
@@ -195,7 +195,7 @@ class Pipeline(Estimator, MLReadable, MLWritable):
         for idx, stage in enumerate(self.getStages()):
             java_stages[idx] = stage._to_java()
 
-        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
+        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
         _java_obj.setStages(java_stages)
 
         return _java_obj
@@ -275,7 +275,7 @@ class PipelineModel(Model, MLReadable, MLWritable):
         Used for ML persistence.
         """
         # Load information from java_stage to the instance.
-        py_stages = [JavaWrapper._from_java(s) for s in java_stage.stages()]
+        py_stages = [JavaParams._from_java(s) for s in java_stage.stages()]
         # Create a new instance of this stage.
         py_stage = cls(py_stages)
         py_stage._resetUid(java_stage.uid())
@@ -295,6 +295,6 @@ class PipelineModel(Model, MLReadable, MLWritable):
             java_stages[idx] = stage._to_java()
 
         _java_obj =\
-            JavaWrapper._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
+            JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
 
         return _java_obj
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index bc88f88b7f1e3d635decf121632ba7ca2489043c..316d7e30bcf10267b67872545e4df4c49170de1e 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -20,7 +20,7 @@ import warnings
 
 from pyspark import since
 from pyspark.ml.param.shared import *
 from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
 from pyspark.mllib.common import inherit_doc
 from pyspark.sql import DataFrame
@@ -188,7 +188,7 @@ class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
         return LinearRegressionSummary(java_lr_summary)
 
 
-class LinearRegressionSummary(JavaCallable):
+class LinearRegressionSummary(JavaWrapper):
     """
     .. note:: Experimental
 
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 2dcd5eeb52c21266adb5986281d4ee5fc3937efc..bcbeacbe80491ec25d5959323579b6085ab1f214 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -52,7 +52,7 @@ from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor
 from pyspark.ml.tuning import *
 from pyspark.ml.util import keyword_only
 from pyspark.ml.util import MLWritable, MLWriter
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
 from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector
 from pyspark.sql import DataFrame, SQLContext, Row
 from pyspark.sql.functions import rand
@@ -644,7 +644,7 @@ class PersistenceTest(PySparkTestCase):
         """
         self.assertEqual(m1.uid, m2.uid)
         self.assertEqual(type(m1), type(m2))
-        if isinstance(m1, JavaWrapper):
+        if isinstance(m1, JavaParams):
             self.assertEqual(len(m1.params), len(m2.params))
             for p in m1.params:
                 self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p))
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index ea8c61b7efe6cd66625d7e885462f0ca2ae06a2c..456d79d897e003b7a199199b26af2ea78b43aa03 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -24,7 +24,7 @@
 from pyspark.ml import Estimator, Model
 from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasSeed
 from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
 from pyspark.sql.functions import rand
 from pyspark.mllib.common import inherit_doc, _py2java
@@ -148,8 +148,8 @@ class ValidatorParams(HasSeed):
         """
         # Load information from java_stage to the instance.
-        estimator = JavaWrapper._from_java(java_stage.getEstimator())
-        evaluator = JavaWrapper._from_java(java_stage.getEvaluator())
+        estimator = JavaParams._from_java(java_stage.getEstimator())
+        evaluator = JavaParams._from_java(java_stage.getEvaluator())
         epms = [estimator._transfer_param_map_from_java(epm)
                 for epm in java_stage.getEstimatorParamMaps()]
         return estimator, epms, evaluator
 
@@ -329,7 +329,7 @@ class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable):
 
         estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl()
 
-        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
+        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
         _java_obj.setEstimatorParamMaps(epms)
         _java_obj.setEvaluator(evaluator)
         _java_obj.setEstimator(estimator)
@@ -393,7 +393,7 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
         """
         # Load information from java_stage to the instance.
-        bestModel = JavaWrapper._from_java(java_stage.bestModel())
+        bestModel = JavaParams._from_java(java_stage.bestModel())
         estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
         # Create a new instance of this stage.
         py_stage = cls(bestModel=bestModel)\
@@ -410,10 +410,10 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
 
         sc = SparkContext._active_spark_context
 
-        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
-                                              self.uid,
-                                              self.bestModel._to_java(),
-                                              _py2java(sc, []))
+        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
+                                             self.uid,
+                                             self.bestModel._to_java(),
+                                             _py2java(sc, []))
         estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()
 
         _java_obj.set("evaluator", evaluator)
@@ -574,8 +574,8 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable):
 
         estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl()
 
-        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
-                                              self.uid)
+        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
+                                             self.uid)
         _java_obj.setEstimatorParamMaps(epms)
         _java_obj.setEvaluator(evaluator)
         _java_obj.setEstimator(estimator)
@@ -639,7 +639,7 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
         """
         # Load information from java_stage to the instance.
-        bestModel = JavaWrapper._from_java(java_stage.bestModel())
+        bestModel = JavaParams._from_java(java_stage.bestModel())
         estimator, epms, evaluator = \
             super(TrainValidationSplitModel, cls)._from_java_impl(java_stage)
         # Create a new instance of this stage.
@@ -657,7 +657,7 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
 
         sc = SparkContext._active_spark_context
 
-        _java_obj = JavaWrapper._new_java_obj(
+        _java_obj = JavaParams._new_java_obj(
             "org.apache.spark.ml.tuning.TrainValidationSplitModel",
             self.uid,
             self.bestModel._to_java(),
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index d4411fdfb9ddea6e943163557a199b912a312ff1..9dfcef0e40d67bdc9d2ffbdce863ffa30fb43c16 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -99,7 +99,7 @@ class MLWriter(object):
 @inherit_doc
 class JavaMLWriter(MLWriter):
     """
-    (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaWrapper` types
+    (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaParams` types
     """
 
     def __init__(self, instance):
@@ -178,7 +178,7 @@ class MLReader(object):
 @inherit_doc
 class JavaMLReader(MLReader):
     """
-    (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaWrapper` types
+    (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types
     """
 
     def __init__(self, clazz):
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index bbeb6cfe6f58aa19ae4a1e612d2ff36c2b7c2e78..cd0e5b80d555943a1614b8f84f99552c14e22a14 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -25,29 +25,32 @@ from pyspark.ml.util import _jvm
 from pyspark.mllib.common import inherit_doc, _java2py, _py2java
 
 
-@inherit_doc
-class JavaWrapper(Params):
+class JavaWrapper(object):
     """
-    Utility class to help create wrapper classes from Java/Scala
-    implementations of pipeline components.
+    Wrapper class for a Java companion object
     """
+    def __init__(self, java_obj=None):
+        super(JavaWrapper, self).__init__()
+        self._java_obj = java_obj
 
-    __metaclass__ = ABCMeta
-
-    def __init__(self):
+    @classmethod
+    def _create_from_java_class(cls, java_class, *args):
         """
-        Initialize the wrapped java object to None
+        Construct this object from given Java classname and arguments
         """
-        super(JavaWrapper, self).__init__()
-        #: The wrapped Java companion object. Subclasses should initialize
-        #: it properly. The param values in the Java object should be
-        #: synced with the Python wrapper in fit/transform/evaluate/copy.
-        self._java_obj = None
+        java_obj = JavaWrapper._new_java_obj(java_class, *args)
+        return cls(java_obj)
+
+    def _call_java(self, name, *args):
+        m = getattr(self._java_obj, name)
+        sc = SparkContext._active_spark_context
+        java_args = [_py2java(sc, arg) for arg in args]
+        return _java2py(sc, m(*java_args))
 
     @staticmethod
     def _new_java_obj(java_class, *args):
         """
-        Construct a new Java object.
+        Returns a new Java object.
         """
         sc = SparkContext._active_spark_context
         java_obj = _jvm()
@@ -56,6 +59,18 @@ class JavaWrapper(Params):
         java_args = [_py2java(sc, arg) for arg in args]
         return java_obj(*java_args)
 
+
+@inherit_doc
+class JavaParams(JavaWrapper, Params):
+    """
+    Utility class to help create wrapper classes from Java/Scala
+    implementations of pipeline components.
+    """
+    #: The param values in the Java object should be
+    #: synced with the Python wrapper in fit/transform/evaluate/copy.
+
+    __metaclass__ = ABCMeta
+
     def _make_java_param_pair(self, param, value):
         """
         Makes a Java parm pair.
@@ -151,7 +166,7 @@ class JavaWrapper(Params):
         stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark")
         # Generate a default new instance from the stage_name class.
         py_type = __get_class(stage_name)
-        if issubclass(py_type, JavaWrapper):
+        if issubclass(py_type, JavaParams):
             # Load information from java_stage to the instance.
             py_stage = py_type()
             py_stage._java_obj = java_stage
@@ -166,7 +181,7 @@
 
 
 @inherit_doc
-class JavaEstimator(Estimator, JavaWrapper):
+class JavaEstimator(JavaParams, Estimator):
     """
     Base class for :py:class:`Estimator`s that wrap
     Java/Scala implementations.
@@ -199,7 +214,7 @@
 
 
 @inherit_doc
-class JavaTransformer(Transformer, JavaWrapper):
+class JavaTransformer(JavaParams, Transformer):
     """
     Base class for :py:class:`Transformer`s that wrap Java/Scala
     implementations. Subclasses should ensure they have the transformer Java object
@@ -213,30 +228,8 @@
         return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx)
 
 
-class JavaCallable(object):
-    """
-    Wrapper for a plain object in JVM to make Java calls, can be used
-    as a mixin to another class that defines a _java_obj wrapper
-    """
-    def __init__(self, java_obj=None, sc=None):
-        super(JavaCallable, self).__init__()
-        self._sc = sc if sc is not None else SparkContext._active_spark_context
-        # if this class is a mixin and _java_obj is already defined then don't initialize
-        if java_obj is not None or not hasattr(self, "_java_obj"):
-            self._java_obj = java_obj
-
-    def __del__(self):
-        if self._java_obj is not None:
-            self._sc._gateway.detach(self._java_obj)
-
-    def _call_java(self, name, *args):
-        m = getattr(self._java_obj, name)
-        java_args = [_py2java(self._sc, arg) for arg in args]
-        return _java2py(self._sc, m(*java_args))
-
-
 @inherit_doc
-class JavaModel(Model, JavaCallable, JavaTransformer):
+class JavaModel(JavaTransformer, Model):
     """
     Base class for :py:class:`Model`s that wrap Java/Scala
     implementations. Subclasses should inherit this class before
@@ -259,9 +252,8 @@
         these wrappers depend on pyspark.ml.util (both directly and via
         other ML classes).
         """
-        super(JavaModel, self).__init__()
+        super(JavaModel, self).__init__(java_model)
         if java_model is not None:
-            self._java_obj = java_model
            self.uid = java_model.uid()
 
     def copy(self, extra=None):
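Below is a minimal illustrative sketch, not part of the patch, of how the layering introduced in wrapper.py is intended to be used: plain JVM delegation now lives in the new JavaWrapper, while JavaParams adds Params syncing for pipeline stages. The class name and JVM class name below are hypothetical, and an active SparkContext is assumed.

# Illustrative sketch only: a summary-style class built on the new plain
# JavaWrapper, mirroring how LinearRegressionSummary is rebased above.
# `ExampleSummary` and "org.example.ml.ExampleSummary" are hypothetical names.
from pyspark.ml.wrapper import JavaWrapper


class ExampleSummary(JavaWrapper):
    """Thin Python view over a JVM-side summary object."""

    @property
    def totalIterations(self):
        # Delegates to the wrapped Java object; _call_java converts the
        # JVM result back into a Python value.
        return self._call_java("totalIterations")


# A wrapper can also be constructed directly from a Java class name:
# summary = ExampleSummary._create_from_java_class("org.example.ml.ExampleSummary")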