diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 7df61601fb1e9b12bab5566f8e4d556a811d0490..f2c70baf472a75c087511880ad3e23932aca4420 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -1201,6 +1201,20 @@ private[python] class PythonMLLibAPI extends Serializable { val spark = SparkSession.builder().sparkContext(sc).getOrCreate() spark.createDataFrame(blockMatrix.blocks) } + + /** + * Python-friendly version of [[MLUtils.convertVectorColumnsToML()]]. + */ + def convertVectorColumnsToML(dataset: DataFrame, cols: JArrayList[String]): DataFrame = { + MLUtils.convertVectorColumnsToML(dataset, cols.asScala: _*) + } + + /** + * Python-friendly version of [[MLUtils.convertVectorColumnsFromML()]] + */ + def convertVectorColumnsFromML(dataset: DataFrame, cols: JArrayList[String]): DataFrame = { + MLUtils.convertVectorColumnsFromML(dataset, cols.asScala: _*) + } } /** diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index a316ee1ad45ffa0ff297675f2a19321f9f02717f..a7e6bcc754dc71ff5793a9230a54d71838d5ea9c 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -26,6 +26,7 @@ if sys.version > '3': from pyspark import SparkContext, since from pyspark.mllib.common import callMLlibFunc, inherit_doc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector +from pyspark.sql import DataFrame class MLUtils(object): @@ -200,6 +201,86 @@ class MLUtils(object): """ return callMLlibFunc("loadVectors", sc, path) + @staticmethod + @since("2.0.0") + def convertVectorColumnsToML(dataset, *cols): + """ + Converts vector columns in an input DataFrame from the + :py:class:`pyspark.mllib.linalg.Vector` type to the new + :py:class:`pyspark.ml.linalg.Vector` type under the `spark.ml` + package. + + :param dataset: + input dataset + :param cols: + a list of vector columns to be converted. + New vector columns will be ignored. If unspecified, all old + vector columns will be converted excepted nested ones. + :return: + the input dataset with old vector columns converted to the + new vector type + + >>> import pyspark + >>> from pyspark.mllib.linalg import Vectors + >>> from pyspark.mllib.util import MLUtils + >>> df = spark.createDataFrame( + ... [(0, Vectors.sparse(2, [1], [1.0]), Vectors.dense(2.0, 3.0))], + ... ["id", "x", "y"]) + >>> r1 = MLUtils.convertVectorColumnsToML(df).first() + >>> isinstance(r1.x, pyspark.ml.linalg.SparseVector) + True + >>> isinstance(r1.y, pyspark.ml.linalg.DenseVector) + True + >>> r2 = MLUtils.convertVectorColumnsToML(df, "x").first() + >>> isinstance(r2.x, pyspark.ml.linalg.SparseVector) + True + >>> isinstance(r2.y, pyspark.mllib.linalg.DenseVector) + True + """ + if not isinstance(dataset, DataFrame): + raise TypeError("Input dataset must be a DataFrame but got {}.".format(type(dataset))) + return callMLlibFunc("convertVectorColumnsToML", dataset, list(cols)) + + @staticmethod + @since("2.0.0") + def convertVectorColumnsFromML(dataset, *cols): + """ + Converts vector columns in an input DataFrame to the + :py:class:`pyspark.mllib.linalg.Vector` type from the new + :py:class:`pyspark.ml.linalg.Vector` type under the `spark.ml` + package. + + :param dataset: + input dataset + :param cols: + a list of vector columns to be converted. + Old vector columns will be ignored. If unspecified, all new + vector columns will be converted except nested ones. + :return: + the input dataset with new vector columns converted to the + old vector type + + >>> import pyspark + >>> from pyspark.ml.linalg import Vectors + >>> from pyspark.mllib.util import MLUtils + >>> df = spark.createDataFrame( + ... [(0, Vectors.sparse(2, [1], [1.0]), Vectors.dense(2.0, 3.0))], + ... ["id", "x", "y"]) + >>> r1 = MLUtils.convertVectorColumnsFromML(df).first() + >>> isinstance(r1.x, pyspark.mllib.linalg.SparseVector) + True + >>> isinstance(r1.y, pyspark.mllib.linalg.DenseVector) + True + >>> r2 = MLUtils.convertVectorColumnsFromML(df, "x").first() + >>> isinstance(r2.x, pyspark.mllib.linalg.SparseVector) + True + >>> isinstance(r2.y, pyspark.ml.linalg.DenseVector) + True + """ + if not isinstance(dataset, DataFrame): + raise TypeError("Input dataset must be a DataFrame but got {}.".format(type(dataset))) + return callMLlibFunc("convertVectorColumnsFromML", dataset, list(cols)) + class Saveable(object): """ @@ -355,6 +436,7 @@ def _test(): .master("local[2]")\ .appName("mllib.util tests")\ .getOrCreate() + globs['spark'] = spark globs['sc'] = spark.sparkContext (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) spark.stop()