diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index e64c7a392b93b9bb09ece9349cf17b79b44c47d2..922f8069fac49592f166c7738413ea8a0b962f04 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -19,7 +19,7 @@ import warnings
 
 from pyspark import since
 from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
 from pyspark.ml.param import TypeConverters
 from pyspark.ml.param.shared import *
 from pyspark.ml.regression import (
@@ -272,7 +272,7 @@ class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
         return BinaryLogisticRegressionSummary(java_blr_summary)
 
 
-class LogisticRegressionSummary(JavaCallable):
+class LogisticRegressionSummary(JavaWrapper):
     """
     Abstraction for Logistic Regression Results for a given model.
 
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index c9b95b3bf45d9085386d9846c0f12f9b795bbf35..4b0bade10280241178c8553a51f79bae9ff36767 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -18,7 +18,7 @@
 from abc import abstractmethod, ABCMeta
 
 from pyspark import since
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
 from pyspark.ml.param import Param, Params
 from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol
 from pyspark.ml.util import keyword_only
@@ -81,7 +81,7 @@ class Evaluator(Params):
 
 
 @inherit_doc
-class JavaEvaluator(Evaluator, JavaWrapper):
+class JavaEvaluator(JavaParams, Evaluator):
     """
     Base class for :py:class:`Evaluator`s that wrap Java/Scala
     implementations.
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 2b5504bc2966a81253aa0c0cd8bdb8257b188e9f..9d654e8b0f8d063cf9b1c8ba621c5ea587024313 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -25,7 +25,7 @@ from pyspark import since
 from pyspark.ml import Estimator, Model, Transformer
 from pyspark.ml.param import Param, Params
 from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
 from pyspark.mllib.common import inherit_doc
 
 
@@ -177,7 +177,7 @@ class Pipeline(Estimator, MLReadable, MLWritable):
         # Create a new instance of this stage.
         py_stage = cls()
         # Load information from java_stage to the instance.
-        py_stages = [JavaWrapper._from_java(s) for s in java_stage.getStages()]
+        py_stages = [JavaParams._from_java(s) for s in java_stage.getStages()]
         py_stage.setStages(py_stages)
         py_stage._resetUid(java_stage.uid())
         return py_stage
@@ -195,7 +195,7 @@ class Pipeline(Estimator, MLReadable, MLWritable):
         for idx, stage in enumerate(self.getStages()):
             java_stages[idx] = stage._to_java()
 
-        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
+        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
         _java_obj.setStages(java_stages)
 
         return _java_obj
@@ -275,7 +275,7 @@ class PipelineModel(Model, MLReadable, MLWritable):
         Used for ML persistence.
         """
         # Load information from java_stage to the instance.
-        py_stages = [JavaWrapper._from_java(s) for s in java_stage.stages()]
+        py_stages = [JavaParams._from_java(s) for s in java_stage.stages()]
         # Create a new instance of this stage.
         py_stage = cls(py_stages)
         py_stage._resetUid(java_stage.uid())
@@ -295,6 +295,6 @@ class PipelineModel(Model, MLReadable, MLWritable):
             java_stages[idx] = stage._to_java()
 
         _java_obj =\
-            JavaWrapper._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
+            JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
 
         return _java_obj
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index bc88f88b7f1e3d635decf121632ba7ca2489043c..316d7e30bcf10267b67872545e4df4c49170de1e 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -20,7 +20,7 @@ import warnings
 from pyspark import since
 from pyspark.ml.param.shared import *
 from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
 from pyspark.mllib.common import inherit_doc
 from pyspark.sql import DataFrame
 
@@ -188,7 +188,7 @@ class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
         return LinearRegressionSummary(java_lr_summary)
 
 
-class LinearRegressionSummary(JavaCallable):
+class LinearRegressionSummary(JavaWrapper):
     """
     .. note:: Experimental
 
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 2dcd5eeb52c21266adb5986281d4ee5fc3937efc..bcbeacbe80491ec25d5959323579b6085ab1f214 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -52,7 +52,7 @@ from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor
 from pyspark.ml.tuning import *
 from pyspark.ml.util import keyword_only
 from pyspark.ml.util import MLWritable, MLWriter
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
 from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector
 from pyspark.sql import DataFrame, SQLContext, Row
 from pyspark.sql.functions import rand
@@ -644,7 +644,7 @@ class PersistenceTest(PySparkTestCase):
         """
         self.assertEqual(m1.uid, m2.uid)
         self.assertEqual(type(m1), type(m2))
-        if isinstance(m1, JavaWrapper):
+        if isinstance(m1, JavaParams):
             self.assertEqual(len(m1.params), len(m2.params))
             for p in m1.params:
                 self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p))
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index ea8c61b7efe6cd66625d7e885462f0ca2ae06a2c..456d79d897e003b7a199199b26af2ea78b43aa03 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -24,7 +24,7 @@ from pyspark.ml import Estimator, Model
 from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasSeed
 from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
 from pyspark.sql.functions import rand
 from pyspark.mllib.common import inherit_doc, _py2java
 
@@ -148,8 +148,8 @@ class ValidatorParams(HasSeed):
         """
 
         # Load information from java_stage to the instance.
-        estimator = JavaWrapper._from_java(java_stage.getEstimator())
-        evaluator = JavaWrapper._from_java(java_stage.getEvaluator())
+        estimator = JavaParams._from_java(java_stage.getEstimator())
+        evaluator = JavaParams._from_java(java_stage.getEvaluator())
         epms = [estimator._transfer_param_map_from_java(epm)
                 for epm in java_stage.getEstimatorParamMaps()]
         return estimator, epms, evaluator
@@ -329,7 +329,7 @@ class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable):
 
         estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl()
 
-        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
+        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
         _java_obj.setEstimatorParamMaps(epms)
         _java_obj.setEvaluator(evaluator)
         _java_obj.setEstimator(estimator)
@@ -393,7 +393,7 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
         """
 
         # Load information from java_stage to the instance.
-        bestModel = JavaWrapper._from_java(java_stage.bestModel())
+        bestModel = JavaParams._from_java(java_stage.bestModel())
         estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
         # Create a new instance of this stage.
         py_stage = cls(bestModel=bestModel)\
@@ -410,10 +410,10 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
 
         sc = SparkContext._active_spark_context
 
-        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
-                                              self.uid,
-                                              self.bestModel._to_java(),
-                                              _py2java(sc, []))
+        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
+                                             self.uid,
+                                             self.bestModel._to_java(),
+                                             _py2java(sc, []))
         estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()
 
         _java_obj.set("evaluator", evaluator)
@@ -574,8 +574,8 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable):
 
         estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl()
 
-        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
-                                              self.uid)
+        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
+                                             self.uid)
         _java_obj.setEstimatorParamMaps(epms)
         _java_obj.setEvaluator(evaluator)
         _java_obj.setEstimator(estimator)
@@ -639,7 +639,7 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
         """
 
         # Load information from java_stage to the instance.
-        bestModel = JavaWrapper._from_java(java_stage.bestModel())
+        bestModel = JavaParams._from_java(java_stage.bestModel())
         estimator, epms, evaluator = \
             super(TrainValidationSplitModel, cls)._from_java_impl(java_stage)
         # Create a new instance of this stage.
@@ -657,7 +657,7 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
 
         sc = SparkContext._active_spark_context
 
-        _java_obj = JavaWrapper._new_java_obj(
+        _java_obj = JavaParams._new_java_obj(
             "org.apache.spark.ml.tuning.TrainValidationSplitModel",
             self.uid,
             self.bestModel._to_java(),
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index d4411fdfb9ddea6e943163557a199b912a312ff1..9dfcef0e40d67bdc9d2ffbdce863ffa30fb43c16 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -99,7 +99,7 @@ class MLWriter(object):
 @inherit_doc
 class JavaMLWriter(MLWriter):
     """
-    (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaWrapper` types
+    (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaParams` types
     """
 
     def __init__(self, instance):
@@ -178,7 +178,7 @@ class MLReader(object):
 @inherit_doc
 class JavaMLReader(MLReader):
     """
-    (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaWrapper` types
+    (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types
     """
 
     def __init__(self, clazz):
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index bbeb6cfe6f58aa19ae4a1e612d2ff36c2b7c2e78..cd0e5b80d555943a1614b8f84f99552c14e22a14 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -25,29 +25,32 @@ from pyspark.ml.util import _jvm
 from pyspark.mllib.common import inherit_doc, _java2py, _py2java
 
 
-@inherit_doc
-class JavaWrapper(Params):
+class JavaWrapper(object):
     """
-    Utility class to help create wrapper classes from Java/Scala
-    implementations of pipeline components.
+    Wrapper class for a Java companion object
     """
+    def __init__(self, java_obj=None):
+        super(JavaWrapper, self).__init__()
+        self._java_obj = java_obj
 
-    __metaclass__ = ABCMeta
-
-    def __init__(self):
+    @classmethod
+    def _create_from_java_class(cls, java_class, *args):
         """
-        Initialize the wrapped java object to None
+        Construct this object from given Java classname and arguments
         """
-        super(JavaWrapper, self).__init__()
-        #: The wrapped Java companion object. Subclasses should initialize
-        #: it properly. The param values in the Java object should be
-        #: synced with the Python wrapper in fit/transform/evaluate/copy.
-        self._java_obj = None
+        java_obj = JavaWrapper._new_java_obj(java_class, *args)
+        return cls(java_obj)
+
+    def _call_java(self, name, *args):
+        m = getattr(self._java_obj, name)
+        sc = SparkContext._active_spark_context
+        java_args = [_py2java(sc, arg) for arg in args]
+        return _java2py(sc, m(*java_args))
 
     @staticmethod
     def _new_java_obj(java_class, *args):
         """
-        Construct a new Java object.
+        Returns a new Java object.
         """
         sc = SparkContext._active_spark_context
         java_obj = _jvm()
@@ -56,6 +59,18 @@ class JavaWrapper(Params):
         java_args = [_py2java(sc, arg) for arg in args]
         return java_obj(*java_args)
 
+
+@inherit_doc
+class JavaParams(JavaWrapper, Params):
+    """
+    Utility class to help create wrapper classes from Java/Scala
+    implementations of pipeline components.
+    """
+    #: The param values in the Java object should be
+    #: synced with the Python wrapper in fit/transform/evaluate/copy.
+
+    __metaclass__ = ABCMeta
+
     def _make_java_param_pair(self, param, value):
         """
         Makes a Java parm pair.
@@ -151,7 +166,7 @@ class JavaWrapper(Params):
         stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark")
         # Generate a default new instance from the stage_name class.
         py_type = __get_class(stage_name)
-        if issubclass(py_type, JavaWrapper):
+        if issubclass(py_type, JavaParams):
             # Load information from java_stage to the instance.
             py_stage = py_type()
             py_stage._java_obj = java_stage
@@ -166,7 +181,7 @@ class JavaWrapper(Params):
 
 
 @inherit_doc
-class JavaEstimator(Estimator, JavaWrapper):
+class JavaEstimator(JavaParams, Estimator):
     """
     Base class for :py:class:`Estimator`s that wrap Java/Scala
     implementations.
@@ -199,7 +214,7 @@ class JavaEstimator(Estimator, JavaWrapper):
 
 
 @inherit_doc
-class JavaTransformer(Transformer, JavaWrapper):
+class JavaTransformer(JavaParams, Transformer):
     """
     Base class for :py:class:`Transformer`s that wrap Java/Scala
     implementations. Subclasses should ensure they have the transformer Java object
@@ -213,30 +228,8 @@ class JavaTransformer(Transformer, JavaWrapper):
         return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx)
 
 
-class JavaCallable(object):
-    """
-    Wrapper for a plain object in JVM to make Java calls, can be used
-    as a mixin to another class that defines a _java_obj wrapper
-    """
-    def __init__(self, java_obj=None, sc=None):
-        super(JavaCallable, self).__init__()
-        self._sc = sc if sc is not None else SparkContext._active_spark_context
-        # if this class is a mixin and _java_obj is already defined then don't initialize
-        if java_obj is not None or not hasattr(self, "_java_obj"):
-            self._java_obj = java_obj
-
-    def __del__(self):
-        if self._java_obj is not None:
-            self._sc._gateway.detach(self._java_obj)
-
-    def _call_java(self, name, *args):
-        m = getattr(self._java_obj, name)
-        java_args = [_py2java(self._sc, arg) for arg in args]
-        return _java2py(self._sc, m(*java_args))
-
-
 @inherit_doc
-class JavaModel(Model, JavaCallable, JavaTransformer):
+class JavaModel(JavaTransformer, Model):
     """
     Base class for :py:class:`Model`s that wrap Java/Scala
     implementations. Subclasses should inherit this class before
@@ -259,9 +252,8 @@ class JavaModel(Model, JavaCallable, JavaTransformer):
         these wrappers depend on pyspark.ml.util (both directly and via
         other ML classes).
         """
-        super(JavaModel, self).__init__()
+        super(JavaModel, self).__init__(java_model)
         if java_model is not None:
-            self._java_obj = java_model
             self.uid = java_model.uid()
 
     def copy(self, extra=None):