diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 3403ec4259b86f5b8b947ae5efa2bf6ad321091d..e6ec4e2e36ff07964cd003a409ec03d508e98e4b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -47,7 +47,8 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
    * Options are:
    * 'skip': filter out rows with invalid data.
    * 'error': throw an error.
-   * 'keep': put invalid data in a special additional bucket, at index numCategories.
+   * 'keep': put invalid data in a special additional bucket, at the index equal to the
+   * number of categories of the feature.
    * Default value: "error"
    * @group param
    */
@@ -55,7 +56,8 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
   override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
     "How to handle invalid data (unseen labels or NULL values). " +
     "Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), " +
-    "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
+    "or 'keep' (put invalid data in a special additional bucket, at index of the " +
+    "number of categories of the feature).",
     ParamValidators.inArray(VectorIndexer.supportedHandleInvalids))
 
   setDefault(handleInvalid, VectorIndexer.ERROR_INVALID)
@@ -112,7 +114,6 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
  *  - Preserve metadata in transform; if a feature's metadata is already present, do not recompute.
  *  - Specify certain features to not index, either via a parameter or via existing metadata.
  *  - Add warning if a categorical feature has only 1 category.
- *  - Add option for allowing unknown categories.
  */
 @Since("1.4.0")
 class VectorIndexer @Since("1.4.0") (
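For reviewers, here is a minimal Scala sketch of how the new param behaves end to end, mirroring the pyspark doctest added below. It assumes a spark-shell style SparkSession named `spark` and the companion `setHandleInvalid` setter that accompanies this param (not shown in this hunk); the data and column names are illustrative, not part of the patch.

```scala
import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.ml.linalg.Vectors

// Feature 0 has 2 categories {-1.0, 0.0}; feature 1 has 3 distinct values,
// which exceeds maxCategories = 2, so it is treated as continuous.
val train = spark.createDataFrame(Seq(
  Tuple1(Vectors.dense(-1.0, 0.0)),
  Tuple1(Vectors.dense(0.0, 1.0)),
  Tuple1(Vectors.dense(0.0, 2.0)))).toDF("features")
// 3.0 was never seen in feature 0 during fitting.
val test = spark.createDataFrame(Seq(
  Tuple1(Vectors.dense(3.0, 1.0)))).toDF("features")

def fitWith(mode: String) = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexed")
  .setMaxCategories(2)
  .setHandleInvalid(mode)
  .fit(train)

// "keep": the unseen 3.0 goes to the extra bucket at index 2, i.e. the
// number of categories of feature 0; continuous feature 1 passes through.
fitWith("keep").transform(test).show()            // indexed = [2.0,1.0]

// "skip": the row containing the unseen value is filtered out.
assert(fitWith("skip").transform(test).count() == 0)

// "error" (the default): transform throws on the first unseen value.
// fitWith("error").transform(test).collect()     // would fail
```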
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 232ae3ef41166e86dc021df929d524547a6887e2..608f2a57154975a9c5d79a7963ce8fbd5ae0e971 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2490,7 +2490,8 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadabl
 
 
 @inherit_doc
-class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
+class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable,
+                    JavaMLWritable):
     """
     Class for indexing categorical feature columns in a dataset of `Vector`.
 
@@ -2525,7 +2526,6 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
         do not recompute.
       - Specify certain features to not index, either via a parameter or via existing metadata.
       - Add warning if a categorical feature has only 1 category.
-      - Add option for allowing unknown categories.
 
     >>> from pyspark.ml.linalg import Vectors
     >>> df = spark.createDataFrame([(Vectors.dense([-1.0, 0.0]),),
@@ -2556,6 +2556,15 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
     True
     >>> loadedModel.categoryMaps == model.categoryMaps
     True
+    >>> dfWithInvalid = spark.createDataFrame([(Vectors.dense([3.0, 1.0]),)], ["a"])
+    >>> indexer.getHandleInvalid()
+    'error'
+    >>> model3 = indexer.setHandleInvalid("skip").fit(df)
+    >>> model3.transform(dfWithInvalid).count()
+    0
+    >>> model4 = indexer.setParams(handleInvalid="keep", outputCol="indexed").fit(df)
+    >>> model4.transform(dfWithInvalid).head().indexed
+    DenseVector([2.0, 1.0])
 
     .. versionadded:: 1.4.0
     """
@@ -2565,22 +2574,29 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
                           "(>= 2). If a feature is found to have > maxCategories values, then " +
                           "it is declared continuous.", typeConverter=TypeConverters.toInt)
 
+    handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data " +
+                          "(unseen labels or NULL values). Options are 'skip' (filter out " +
+                          "rows with invalid data), 'error' (throw an error), or 'keep' (put " +
+                          "invalid data in a special additional bucket, at index of the number " +
+                          "of categories of the feature).",
+                          typeConverter=TypeConverters.toString)
+
     @keyword_only
-    def __init__(self, maxCategories=20, inputCol=None, outputCol=None):
+    def __init__(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error"):
         """
-        __init__(self, maxCategories=20, inputCol=None, outputCol=None)
+        __init__(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error")
         """
         super(VectorIndexer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid)
-        self._setDefault(maxCategories=20)
+        self._setDefault(maxCategories=20, handleInvalid="error")
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, maxCategories=20, inputCol=None, outputCol=None):
+    def setParams(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error"):
         """
-        setParams(self, maxCategories=20, inputCol=None, outputCol=None)
+        setParams(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error")
         Sets params for this VectorIndexer.
         """
         kwargs = self._input_kwargs