diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 3403ec4259b86f5b8b947ae5efa2bf6ad321091d..e6ec4e2e36ff07964cd003a409ec03d508e98e4b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -47,7 +47,8 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
    * Options are:
    * 'skip': filter out rows with invalid data.
    * 'error': throw an error.
-   * 'keep': put invalid data in a special additional bucket, at index numCategories.
+   * 'keep': put invalid data in a special additional bucket, at index of the number of
+   *   categories of the feature.
    * Default value: "error"
    * @group param
    */
@@ -55,7 +56,8 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
   override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
     "How to handle invalid data (unseen labels or NULL values). " +
     "Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), " +
-    "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
+    "or 'keep' (put invalid data in a special additional bucket, at index of the " +
+    "number of categories of the feature).",
     ParamValidators.inArray(VectorIndexer.supportedHandleInvalids))
 
   setDefault(handleInvalid, VectorIndexer.ERROR_INVALID)
@@ -112,7 +114,6 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
  *  - Preserve metadata in transform; if a feature's metadata is already present, do not recompute.
  *  - Specify certain features to not index, either via a parameter or via existing metadata.
  *  - Add warning if a categorical feature has only 1 category.
- *  - Add option for allowing unknown categories.
  */
 @Since("1.4.0")
 class VectorIndexer @Since("1.4.0") (
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 232ae3ef41166e86dc021df929d524547a6887e2..608f2a57154975a9c5d79a7963ce8fbd5ae0e971 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2490,7 +2490,8 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadabl
 
 
 @inherit_doc
-class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
+class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable,
+                    JavaMLWritable):
     """
     Class for indexing categorical feature columns in a dataset of `Vector`.
 
@@ -2525,7 +2526,6 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
       do not recompute.
     - Specify certain features to not index, either via a parameter or via existing metadata.
     - Add warning if a categorical feature has only 1 category.
-    - Add option for allowing unknown categories.
 
     >>> from pyspark.ml.linalg import Vectors
     >>> df = spark.createDataFrame([(Vectors.dense([-1.0, 0.0]),),
@@ -2556,6 +2556,15 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
     True
     >>> loadedModel.categoryMaps == model.categoryMaps
     True
+    >>> dfWithInvalid = spark.createDataFrame([(Vectors.dense([3.0, 1.0]),)], ["a"])
+    >>> indexer.getHandleInvalid()
+    'error'
+    >>> model3 = indexer.setHandleInvalid("skip").fit(df)
+    >>> model3.transform(dfWithInvalid).count()
+    0
+    >>> model4 = indexer.setParams(handleInvalid="keep", outputCol="indexed").fit(df)
+    >>> model4.transform(dfWithInvalid).head().indexed
+    DenseVector([2.0, 1.0])
 
     .. versionadded:: 1.4.0
     """
@@ -2565,22 +2574,29 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
                           "(>= 2). If a feature is found to have > maxCategories values, then " +
                           "it is declared continuous.", typeConverter=TypeConverters.toInt)
 
+    handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data " +
+                          "(unseen labels or NULL values). Options are 'skip' (filter out " +
+                          "rows with invalid data), 'error' (throw an error), or 'keep' (put " +
+                          "invalid data in a special additional bucket, at index of the number " +
+                          "of categories of the feature).",
+                          typeConverter=TypeConverters.toString)
+
     @keyword_only
-    def __init__(self, maxCategories=20, inputCol=None, outputCol=None):
+    def __init__(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error"):
         """
-        __init__(self, maxCategories=20, inputCol=None, outputCol=None)
+        __init__(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error")
         """
         super(VectorIndexer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid)
-        self._setDefault(maxCategories=20)
+        self._setDefault(maxCategories=20, handleInvalid="error")
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, maxCategories=20, inputCol=None, outputCol=None):
+    def setParams(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error"):
         """
-        setParams(self, maxCategories=20, inputCol=None, outputCol=None)
+        setParams(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error")
         Sets params for this VectorIndexer.
         """
         kwargs = self._input_kwargs
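A minimal Scala sketch of the same behavior, mirroring the Python doctest above. This is illustrative only: it assumes a SparkSession in scope as `spark` (as in spark-shell), and the data, column names, and `indexer` value are made up for the example rather than taken from the patch.

    import org.apache.spark.ml.feature.VectorIndexer
    import org.apache.spark.ml.linalg.Vectors

    // Training data: feature 0 takes 2 distinct values, so it is categorical
    // under maxCategories=2; feature 1 takes 3, so it is declared continuous.
    val df = spark.createDataFrame(Seq(
      Tuple1(Vectors.dense(-1.0, 0.0)),
      Tuple1(Vectors.dense(0.0, 1.0)),
      Tuple1(Vectors.dense(0.0, 2.0)))).toDF("a")

    // For feature 0, the value 3.0 is a category unseen during fitting.
    val dfWithInvalid = spark.createDataFrame(
      Seq(Tuple1(Vectors.dense(3.0, 1.0)))).toDF("a")

    val indexer = new VectorIndexer()
      .setInputCol("a")
      .setOutputCol("indexed")
      .setMaxCategories(2)

    // 'skip' filters out rows that contain an unseen category.
    indexer.setHandleInvalid("skip").fit(df)
      .transform(dfWithInvalid).count()   // 0

    // 'keep' sends unseen categories to the extra bucket at index
    // numCategories (here 2); the continuous feature passes through.
    indexer.setHandleInvalid("keep").fit(df)
      .transform(dfWithInvalid).head()    // indexed = [2.0, 1.0]

The default remains 'error', so existing pipelines keep the fail-fast behavior unless they opt in to 'skip' or 'keep'.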