diff --git a/R/pkg/tests/fulltests/test_mllib_tree.R b/R/pkg/tests/fulltests/test_mllib_tree.R
index 66a0693a59a529765b6d156f9179bb937631a8f9..e31a65f8dfedb9abcd6c77da13cb731c599ea1b9 100644
--- a/R/pkg/tests/fulltests/test_mllib_tree.R
+++ b/R/pkg/tests/fulltests/test_mllib_tree.R
@@ -225,7 +225,7 @@ test_that("spark.randomForest", {
   expect_error(collect(predictions))
   model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
                               maxDepth = 10, maxBins = 10, numTrees = 10,
-                              handleInvalid = "skip")
+                              handleInvalid = "keep")
   predictions <- predict(model, testdf)
   expect_equal(class(collect(predictions)$clicked[1]), "character")
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index bb7acaf118ea5f235ef69ffd92d7d10d6d5f3b9e..c22445467dbc39b0da02006432c89ec6d0eff5d3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -134,16 +134,16 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
   def getFormula: String = $(formula)
 
   /**
-   * Param for how to handle invalid data (unseen labels or NULL values).
-   * Options are 'skip' (filter out rows with invalid data),
+   * Param for how to handle invalid data (unseen or NULL values) in features and label column
+   * of string type. Options are 'skip' (filter out rows with invalid data),
    * 'error' (throw an error), or 'keep' (put invalid data in a special additional
    * bucket, at index numLabels).
    * Default: "error"
    * @group param
    */
   @Since("2.3.0")
-  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
-    "How to handle invalid data (unseen labels or NULL values). " +
+  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to " +
+    "handle invalid data (unseen or NULL values) in features and label column of string type. " +
     "Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
     "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
     ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
@@ -265,6 +265,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
       encoderStages += new StringIndexer()
         .setInputCol(resolvedFormula.label)
         .setOutputCol($(labelCol))
+        .setHandleInvalid($(handleInvalid))
     }
 
     val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 23570d6e0b4cb4135eb270570b201435fbf5c1ee..5d09c90ec6dfa31d54ae3ca250a080b85b9c2416 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.ParamsSuite
@@ -501,4 +501,51 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
       assert(expected.resolvedFormula.hasIntercept === actual.resolvedFormula.hasIntercept)
     }
   }
+
+  test("handle unseen features or labels") {
+    val df1 = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")).toDF("id", "a", "b")
+    val df2 = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zy")).toDF("id", "a", "b")
+
+    // Handle unseen features.
+    val formula1 = new RFormula().setFormula("id ~ a + b")
+    intercept[SparkException] {
+      formula1.fit(df1).transform(df2).collect()
+    }
+    val result1 = formula1.setHandleInvalid("skip").fit(df1).transform(df2)
+    val result2 = formula1.setHandleInvalid("keep").fit(df1).transform(df2)
+
+    val expected1 = Seq(
+      (1, "foo", "zq", Vectors.dense(0.0, 1.0), 1.0),
+      (2, "bar", "zq", Vectors.dense(1.0, 1.0), 2.0)
+    ).toDF("id", "a", "b", "features", "label")
+    val expected2 = Seq(
+      (1, "foo", "zq", Vectors.dense(0.0, 1.0, 1.0, 0.0), 1.0),
+      (2, "bar", "zq", Vectors.dense(1.0, 0.0, 1.0, 0.0), 2.0),
+      (3, "bar", "zy", Vectors.dense(1.0, 0.0, 0.0, 0.0), 3.0)
+    ).toDF("id", "a", "b", "features", "label")
+
+    assert(result1.collect() === expected1.collect())
+    assert(result2.collect() === expected2.collect())
+
+    // Handle unseen labels.
+    val formula2 = new RFormula().setFormula("b ~ a + id")
+    intercept[SparkException] {
+      formula2.fit(df1).transform(df2).collect()
+    }
+    val result3 = formula2.setHandleInvalid("skip").fit(df1).transform(df2)
+    val result4 = formula2.setHandleInvalid("keep").fit(df1).transform(df2)
+
+    val expected3 = Seq(
+      (1, "foo", "zq", Vectors.dense(0.0, 1.0), 0.0),
+      (2, "bar", "zq", Vectors.dense(1.0, 2.0), 0.0)
+    ).toDF("id", "a", "b", "features", "label")
+    val expected4 = Seq(
+      (1, "foo", "zq", Vectors.dense(0.0, 1.0, 1.0), 0.0),
+      (2, "bar", "zq", Vectors.dense(1.0, 0.0, 2.0), 0.0),
+      (3, "bar", "zy", Vectors.dense(1.0, 0.0, 3.0), 2.0)
+    ).toDF("id", "a", "b", "features", "label")
+
+    assert(result3.collect() === expected3.collect())
+    assert(result4.collect() === expected4.collect())
+  }
 }
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 7eb1b9fac2f5a53ef724ff2d373b07e1b944a8c2..54b4026f78bec36bce316f74620ab988208ec11f 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2107,8 +2107,9 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
                             typeConverter=TypeConverters.toString)
 
     handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " +
-                          "labels or NULL values). Options are 'skip' (filter out rows with " +
-                          "invalid data), error (throw an error), or 'keep' (put invalid data " +
+                          "or NULL values) in features and label column of string type. " +
+                          "Options are 'skip' (filter out rows with invalid data), " +
+                          "error (throw an error), or 'keep' (put invalid data " +
                           "in a special additional bucket, at index numLabels).",
                           typeConverter=TypeConverters.toString)
 
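
For reviewers, a minimal sketch of the behavior this patch enables, runnable in spark-shell (a SparkSession named `spark` and its implicits are assumed; the data mirrors df1/df2 from the new test):

  import org.apache.spark.ml.feature.RFormula
  import spark.implicits._

  val df1 = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")).toDF("id", "a", "b")
  val df2 = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zy")).toDF("id", "a", "b")

  // "b" is a string label column, and df2 contains the label "zy", which is
  // unseen when fitting on df1. With this patch, handleInvalid is forwarded
  // to the label StringIndexer as well, so "keep" indexes the unseen label
  // into the extra bucket at index numLabels instead of throwing a
  // SparkException at transform time.
  val model = new RFormula()
    .setFormula("b ~ a + id")
    .setHandleInvalid("keep")
    .fit(df1)
  model.transform(df2).show()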