From 0c8444cf6d0620cd219ddcf5f50b12ff648639e9 Mon Sep 17 00:00:00 2001
From: Yanbo Liang <ybliang8@gmail.com>
Date: Thu, 29 Jun 2017 10:32:32 +0800
Subject: [PATCH] [SPARK-14657][SPARKR][ML] RFormula w/o intercept should
 output reference category when encoding string terms

## What changes were proposed in this pull request?

Please see [SPARK-14657](https://issues.apache.org/jira/browse/SPARK-14657) for details of this bug.
I searched online and tested some other cases, and found that when we fit an R glm model (or another
model powered by R formula) w/o intercept on a dataset that includes string/categorical features,
the reference category of the first categorical feature is also output, i.e. no category is dropped
for that feature. Spark RFormula should keep its semantics consistent with the R formula.

## How was this patch tested?

Added standard unit tests.

cc mengxr

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12414 from yanboliang/spark-14657.
---
 .../apache/spark/ml/feature/RFormula.scala    | 10 ++-
 .../spark/ml/feature/RFormulaSuite.scala      | 83 +++++++++++++++++++
 2 files changed, 92 insertions(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 1fad0a6fc9..4b44878784 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -205,12 +205,20 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
     }.toMap

     // Then we handle one-hot encoding and interactions between terms.
+    var keepReferenceCategory = false
     val encodedTerms = resolvedFormula.terms.map {
       case Seq(term) if dataset.schema(term).dataType == StringType =>
         val encodedCol = tmpColumn("onehot")
-        encoderStages += new OneHotEncoder()
+        var encoder = new OneHotEncoder()
           .setInputCol(indexed(term))
           .setOutputCol(encodedCol)
+        // For a formula w/o intercept, the first encoded string feature keeps all of its
+        // categories (the reference category is not dropped), to match R's behavior.
+        if (!hasIntercept && !keepReferenceCategory) {
+          encoder = encoder.setDropLast(false)
+          keepReferenceCategory = true
+        }
+        encoderStages += encoder
         prefixesToRewrite(encodedCol + "_") = term + "_"
         encodedCol
       case Seq(term) =>
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 41d0062c2c..23570d6e0b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -213,6 +213,89 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     assert(result.collect() === expected.collect())
   }

+  test("formula w/o intercept, we should output reference category when encoding string terms") {
+    /*
+      R code:
+
+      df <- data.frame(id = c(1, 2, 3, 4),
+                       a = c("foo", "bar", "bar", "baz"),
+                       b = c("zq", "zz", "zz", "zz"),
+                       c = c(4, 4, 5, 5))
+      model.matrix(id ~ a + b + c - 1, df)
+
+        abar abaz afoo bzz c
+      1    0    0    1   0 4
+      2    1    0    0   1 4
+      3    1    0    0   1 5
+      4    0    1    0   1 5
+
+      model.matrix(id ~ a:b + c - 1, df)
+
+        c abar:bzq abaz:bzq afoo:bzq abar:bzz abaz:bzz afoo:bzz
+      1 4        0        0        1        0        0        0
+      2 4        0        0        0        1        0        0
+      3 5        0        0        0        1        0        0
+      4 5        0        0        0        0        1        0
+     */
+    val original = Seq((1, "foo", "zq", 4), (2, "bar", "zz", 4), (3, "bar", "zz", 5),
+      (4, "baz", "zz", 5)).toDF("id", "a", "b", "c")
+
+    val formula1 = new RFormula().setFormula("id ~ a + b + c - 1")
+      .setStringIndexerOrderType(StringIndexer.alphabetDesc)
+    val model1 = formula1.fit(original)
+    val result1 = model1.transform(original)
+    val resultSchema1 = model1.transformSchema(original.schema)
+    // Note the column order is different between R and Spark.
+    val expected1 = Seq(
+      (1, "foo", "zq", 4, Vectors.sparse(5, Array(0, 4), Array(1.0, 4.0)), 1.0),
+      (2, "bar", "zz", 4, Vectors.dense(0.0, 0.0, 1.0, 1.0, 4.0), 2.0),
+      (3, "bar", "zz", 5, Vectors.dense(0.0, 0.0, 1.0, 1.0, 5.0), 3.0),
+      (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0)
+    ).toDF("id", "a", "b", "c", "features", "label")
+    assert(result1.schema.toString == resultSchema1.toString)
+    assert(result1.collect() === expected1.collect())
+
+    val attrs1 = AttributeGroup.fromStructField(result1.schema("features"))
+    val expectedAttrs1 = new AttributeGroup(
+      "features",
+      Array[Attribute](
+        new BinaryAttribute(Some("a_foo"), Some(1)),
+        new BinaryAttribute(Some("a_baz"), Some(2)),
+        new BinaryAttribute(Some("a_bar"), Some(3)),
+        new BinaryAttribute(Some("b_zz"), Some(4)),
+        new NumericAttribute(Some("c"), Some(5))))
+    assert(attrs1 === expectedAttrs1)
+
+    // There is no impact for string terms interaction.
+    val formula2 = new RFormula().setFormula("id ~ a:b + c - 1")
+      .setStringIndexerOrderType(StringIndexer.alphabetDesc)
+    val model2 = formula2.fit(original)
+    val result2 = model2.transform(original)
+    val resultSchema2 = model2.transformSchema(original.schema)
+    // Note the column order is different between R and Spark.
+    val expected2 = Seq(
+      (1, "foo", "zq", 4, Vectors.sparse(7, Array(1, 6), Array(1.0, 4.0)), 1.0),
+      (2, "bar", "zz", 4, Vectors.sparse(7, Array(4, 6), Array(1.0, 4.0)), 2.0),
+      (3, "bar", "zz", 5, Vectors.sparse(7, Array(4, 6), Array(1.0, 5.0)), 3.0),
+      (4, "baz", "zz", 5, Vectors.sparse(7, Array(2, 6), Array(1.0, 5.0)), 4.0)
+    ).toDF("id", "a", "b", "c", "features", "label")
+    assert(result2.schema.toString == resultSchema2.toString)
+    assert(result2.collect() === expected2.collect())
+
+    val attrs2 = AttributeGroup.fromStructField(result2.schema("features"))
+    val expectedAttrs2 = new AttributeGroup(
+      "features",
+      Array[Attribute](
+        new NumericAttribute(Some("a_foo:b_zz"), Some(1)),
+        new NumericAttribute(Some("a_foo:b_zq"), Some(2)),
+        new NumericAttribute(Some("a_baz:b_zz"), Some(3)),
+        new NumericAttribute(Some("a_baz:b_zq"), Some(4)),
+        new NumericAttribute(Some("a_bar:b_zz"), Some(5)),
+        new NumericAttribute(Some("a_bar:b_zq"), Some(6)),
+        new NumericAttribute(Some("c"), Some(7))))
+    assert(attrs2 === expectedAttrs2)
+  }
+
   test("index string label") {
     val formula = new RFormula().setFormula("id ~ a + b")
     val original =
-- 
GitLab
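
As a quick illustration of the semantics this patch describes (not part of the change itself), the following minimal sketch fits `RFormula` on a toy DataFrame with and without an intercept and prints the resulting feature attribute names. The object name `RFormulaInterceptExample`, the local SparkSession setup, and the columns `y`, `a`, `c` are assumptions made up for this example; only `RFormula`, `AttributeGroup`, and their shown methods come from Spark ML.

```scala
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.sql.{DataFrame, SparkSession}

object RFormulaInterceptExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("rformula-intercept-example")
      .getOrCreate()
    import spark.implicits._

    // Toy data: a label column, a string column with three categories, and a numeric column.
    val df = Seq((1.0, "foo", 4.0), (2.0, "bar", 5.0), (3.0, "baz", 6.0)).toDF("y", "a", "c")

    // Helper: read the feature attribute names from the metadata of the "features" column.
    def featureNames(out: DataFrame): Seq[String] =
      AttributeGroup.fromStructField(out.schema("features"))
        .attributes.toSeq.flatten.flatMap(_.name)

    // With an intercept (default): the string column "a" is one-hot encoded with its last
    // category dropped, so it contributes 2 dummy columns plus the numeric "c".
    val withIntercept = new RFormula().setFormula("y ~ a + c").fit(df).transform(df)
    println(featureNames(withIntercept))

    // Without an intercept ("- 1"): after this patch, the first string column keeps all of
    // its categories (the reference category is also output), so "a" contributes 3 columns.
    val noIntercept = new RFormula().setFormula("y ~ a + c - 1").fit(df).transform(df)
    println(featureNames(noIntercept))

    spark.stop()
  }
}
```

With three distinct values in `a`, the feature vector should have 3 columns for `y ~ a + c` (one dummy dropped plus `c`) and 4 columns for `y ~ a + c - 1` (all dummies kept plus `c`), mirroring the `model.matrix` output shown in the test's R comment.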