From 8d5bb5283c3cc9180ef34b05be4a715d83073b1e Mon Sep 17 00:00:00 2001 From: Eric Liang <ekl@databricks.com> Date: Tue, 28 Jul 2015 14:16:57 -0700 Subject: [PATCH] [SPARK-9391] [ML] Support minus, dot, and intercept operators in SparkR RFormula Adds '.', '-', and intercept parsing to RFormula. Also splits RFormulaParser into a separate file. Umbrella design doc here: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit?usp=sharing mengxr Author: Eric Liang <ekl@databricks.com> Closes #7707 from ericl/string-features-2 and squashes the following commits: 8588625 [Eric Liang] exclude complex types for . 8106ffe [Eric Liang] comments a9350bb [Eric Liang] s/var/val 9c50d4d [Eric Liang] Merge branch 'string-features' into string-features-2 581afb2 [Eric Liang] Merge branch 'master' into string-features 08ae539 [Eric Liang] Merge branch 'string-features' into string-features-2 f99131a [Eric Liang] comments cecec43 [Eric Liang] Merge branch 'string-features' into string-features-2 0bf3c26 [Eric Liang] update docs 4592df2 [Eric Liang] intercept supports 7412a2e [Eric Liang] Fri Jul 24 14:56:51 PDT 2015 3cf848e [Eric Liang] fix the parser 0556c2b [Eric Liang] Merge branch 'string-features' into string-features-2 c302a2c [Eric Liang] fix tests 9d1ac82 [Eric Liang] Merge remote-tracking branch 'upstream/master' into string-features e713da3 [Eric Liang] comments cd231a9 [Eric Liang] Wed Jul 22 17:18:44 PDT 2015 4d79193 [Eric Liang] revert to seq + distinct 169a085 [Eric Liang] tweak functional test a230a47 [Eric Liang] Merge branch 'master' into string-features 72bd6f3 [Eric Liang] fix merge d841cec [Eric Liang] Merge branch 'master' into string-features 5b2c4a2 [Eric Liang] Mon Jul 20 18:45:33 PDT 2015 b01c7c5 [Eric Liang] add test 8a637db [Eric Liang] encoder wip a1d03f4 [Eric Liang] refactor into estimator --- R/pkg/R/mllib.R | 2 +- R/pkg/inst/tests/test_mllib.R | 8 ++ .../apache/spark/ml/feature/RFormula.scala | 52 +++---- 
.../spark/ml/feature/RFormulaParser.scala | 129 ++++++++++++++++++ .../apache/spark/ml/r/SparkRWrappers.scala | 10 +- .../ml/feature/RFormulaParserSuite.scala | 55 +++++++- 6 files changed, 215 insertions(+), 41 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 258e354081..6a8bacaa55 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -27,7 +27,7 @@ setClass("PipelineModel", representation(model = "jobj")) #' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. #' #' @param formula A symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~' and '+'. +#' operators are supported, including '~', '+', '-', and '.'. #' @param data DataFrame for training #' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. #' @param lambda Regularization parameter diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 29152a1168..3bef693247 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -40,3 +40,11 @@ test_that("predictions match with native glm", { rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) }) + +test_that("dot minus and intercept vs native glm", { + training <- createDataFrame(sqlContext, iris) + model <- glm(Sepal_Width ~ . - Species + 0, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ . 
- Species + 0, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 0a95b1ee8d..0b428d278d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -78,13 +78,20 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R /** @group getParam */ def getFormula: String = $(formula) + /** Whether the formula specifies fitting an intercept. */ + private[ml] def hasIntercept: Boolean = { + require(parsedFormula.isDefined, "Must call setFormula() first.") + parsedFormula.get.hasIntercept + } + override def fit(dataset: DataFrame): RFormulaModel = { require(parsedFormula.isDefined, "Must call setFormula() first.") + val resolvedFormula = parsedFormula.get.resolve(dataset.schema) // StringType terms and terms representing interactions need to be encoded before assembly. 
// TODO(ekl) add support for feature interactions - var encoderStages = ArrayBuffer[PipelineStage]() - var tempColumns = ArrayBuffer[String]() - val encodedTerms = parsedFormula.get.terms.map { term => + val encoderStages = ArrayBuffer[PipelineStage]() + val tempColumns = ArrayBuffer[String]() + val encodedTerms = resolvedFormula.terms.map { term => dataset.schema(term) match { case column if column.dataType == StringType => val indexCol = term + "_idx_" + uid @@ -103,7 +110,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R .setOutputCol($(featuresCol)) encoderStages += new ColumnPruner(tempColumns.toSet) val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset) - copyValues(new RFormulaModel(uid, parsedFormula.get, pipelineModel).setParent(this)) + copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this)) } // optimistic schema; does not contain any ML attributes @@ -124,13 +131,13 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R /** * :: Experimental :: * A fitted RFormula. Fitting is required to determine the factor levels of formula terms. - * @param parsedFormula a pre-parsed R formula. + * @param resolvedFormula the fitted R formula. * @param pipelineModel the fitted feature model, including factor to index mappings. 
*/ @Experimental class RFormulaModel private[feature]( override val uid: String, - parsedFormula: ParsedRFormula, + resolvedFormula: ResolvedRFormula, pipelineModel: PipelineModel) extends Model[RFormulaModel] with RFormulaBase { @@ -144,8 +151,8 @@ class RFormulaModel private[feature]( val withFeatures = pipelineModel.transformSchema(schema) if (hasLabelCol(schema)) { withFeatures - } else if (schema.exists(_.name == parsedFormula.label)) { - val nullable = schema(parsedFormula.label).dataType match { + } else if (schema.exists(_.name == resolvedFormula.label)) { + val nullable = schema(resolvedFormula.label).dataType match { case _: NumericType | BooleanType => false case _ => true } @@ -158,12 +165,12 @@ class RFormulaModel private[feature]( } override def copy(extra: ParamMap): RFormulaModel = copyValues( - new RFormulaModel(uid, parsedFormula, pipelineModel)) + new RFormulaModel(uid, resolvedFormula, pipelineModel)) - override def toString: String = s"RFormulaModel(${parsedFormula})" + override def toString: String = s"RFormulaModel(${resolvedFormula})" private def transformLabel(dataset: DataFrame): DataFrame = { - val labelName = parsedFormula.label + val labelName = resolvedFormula.label if (hasLabelCol(dataset.schema)) { dataset } else if (dataset.schema.exists(_.name == labelName)) { @@ -207,26 +214,3 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer { override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra) } - -/** - * Represents a parsed R formula. - */ -private[ml] case class ParsedRFormula(label: String, terms: Seq[String]) - -/** - * Limited implementation of R formula parsing. Currently supports: '~', '+'. 
- */ -private[ml] object RFormulaParser extends RegexParsers { - def term: Parser[String] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r - - def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list } - - def formula: Parser[ParsedRFormula] = - (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t.distinct) } - - def parse(value: String): ParsedRFormula = parseAll(formula, value) match { - case Success(result, _) => result - case failure: NoSuccess => throw new IllegalArgumentException( - "Could not parse formula: " + value) - } -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala new file mode 100644 index 0000000000..1ca3b92a7d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.util.parsing.combinator.RegexParsers + +import org.apache.spark.mllib.linalg.VectorUDT +import org.apache.spark.sql.types._ + +/** + * Represents a parsed R formula. 
+ */ +private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { + /** + * Resolves formula terms into column names. A schema is necessary for inferring the meaning + * of the special '.' term. Duplicate terms will be removed during resolution. + */ + def resolve(schema: StructType): ResolvedRFormula = { + var includedTerms = Seq[String]() + terms.foreach { + case Dot => + includedTerms ++= simpleTypes(schema).filter(_ != label.value) + case ColumnRef(value) => + includedTerms :+= value + case Deletion(term: Term) => + term match { + case ColumnRef(value) => + includedTerms = includedTerms.filter(_ != value) + case Dot => + // e.g. "- .", which removes all first-order terms + val fromSchema = simpleTypes(schema) + includedTerms = includedTerms.filterNot(fromSchema.contains(_)) + case _: Deletion => + assert(false, "Deletion terms cannot be nested") + case _: Intercept => + } + case _: Intercept => + } + ResolvedRFormula(label.value, includedTerms.distinct) + } + + /** Whether this formula specifies fitting with an intercept term. */ + def hasIntercept: Boolean = { + var intercept = true + terms.foreach { + case Intercept(enabled) => + intercept = enabled + case Deletion(Intercept(enabled)) => + intercept = !enabled + case _ => + } + intercept + } + + // the dot operator excludes complex column types + private def simpleTypes(schema: StructType): Seq[String] = { + schema.fields.filter(_.dataType match { + case _: NumericType | StringType | BooleanType | _: VectorUDT => true + case _ => false + }).map(_.name) + } +} + +/** + * Represents a fully evaluated and simplified R formula. + */ +private[ml] case class ResolvedRFormula(label: String, terms: Seq[String]) + +/** + * R formula terms. See the R formula docs here for more information: + * http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + */ +private[ml] sealed trait Term + +/* R formula reference to all available columns, e.g. "." 
in a formula */ +private[ml] case object Dot extends Term + +/* R formula reference to a column, e.g. "+ Species" in a formula */ +private[ml] case class ColumnRef(value: String) extends Term + +/* R formula intercept toggle, e.g. "+ 0" in a formula */ +private[ml] case class Intercept(enabled: Boolean) extends Term + +/* R formula deletion of a variable, e.g. "- Species" in a formula */ +private[ml] case class Deletion(term: Term) extends Term + +/** + * Limited implementation of R formula parsing. Currently supports: '~', '+', '-', '.'. + */ +private[ml] object RFormulaParser extends RegexParsers { + def intercept: Parser[Intercept] = + "([01])".r ^^ { case a => Intercept(a == "1") } + + def columnRef: Parser[ColumnRef] = + "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) } + + def term: Parser[Term] = intercept | columnRef | "\\.".r ^^ { case _ => Dot } + + def terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ { + case op ~ list => list.foldLeft(List(op)) { + case (left, "+" ~ right) => left ++ Seq(right) + case (left, "-" ~ right) => left ++ Seq(Deletion(right)) + } + } + + def formula: Parser[ParsedRFormula] = + (columnRef ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) } + + def parse(value: String): ParsedRFormula = parseAll(formula, value) match { + case Success(result, _) => result + case failure: NoSuccess => throw new IllegalArgumentException( + "Could not parse formula: " + value) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 1ee080641e..9f70592cca 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -32,8 +32,14 @@ private[r] object SparkRWrappers { alpha: Double): PipelineModel = { val formula = new RFormula().setFormula(value) val estimator = family match { - case "gaussian" => new 
LinearRegression().setRegParam(lambda).setElasticNetParam(alpha) - case "binomial" => new LogisticRegression().setRegParam(lambda).setElasticNetParam(alpha) + case "gaussian" => new LinearRegression() + .setRegParam(lambda) + .setElasticNetParam(alpha) + .setFitIntercept(formula.hasIntercept) + case "binomial" => new LogisticRegression() + .setRegParam(lambda) + .setElasticNetParam(alpha) + .setFitIntercept(formula.hasIntercept) } val pipeline = new Pipeline().setStages(Array(formula, estimator)) pipeline.fit(df) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala index c4b45aee06..436e66bab0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala @@ -18,12 +18,17 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ class RFormulaParserSuite extends SparkFunSuite { - private def checkParse(formula: String, label: String, terms: Seq[String]) { - val parsed = RFormulaParser.parse(formula) - assert(parsed.label == label) - assert(parsed.terms == terms) + private def checkParse( + formula: String, + label: String, + terms: Seq[String], + schema: StructType = null) { + val resolved = RFormulaParser.parse(formula).resolve(schema) + assert(resolved.label == label) + assert(resolved.terms == terms) } test("parse simple formulas") { @@ -32,4 +37,46 @@ class RFormulaParserSuite extends SparkFunSuite { checkParse("y ~ ._foo ", "y", Seq("._foo")) checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123")) } + + test("parse dot") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + checkParse("a ~ .", "a", Seq("b", "c"), schema) + } + + test("parse deletion") { + val schema = (new StructType) + .add("a", "int", true) + 
.add("b", "long", false) + .add("c", "string", true) + checkParse("a ~ c - b", "a", Seq("c"), schema) + } + + test("parse additions and deletions in order") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + checkParse("a ~ . - b + . - c", "a", Seq("b"), schema) + } + + test("dot ignores complex column types") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "tinyint", false) + .add("c", "map<string, string>", true) + checkParse("a ~ .", "a", Seq("b"), schema) + } + + test("parse intercept") { + assert(RFormulaParser.parse("a ~ b").hasIntercept) + assert(RFormulaParser.parse("a ~ b + 1").hasIntercept) + assert(RFormulaParser.parse("a ~ b - 0").hasIntercept) + assert(RFormulaParser.parse("a ~ b - 1 + 1").hasIntercept) + assert(!RFormulaParser.parse("a ~ b + 0").hasIntercept) + assert(!RFormulaParser.parse("a ~ b - 1").hasIntercept) + assert(!RFormulaParser.parse("a ~ b + 1 - 1").hasIntercept) + } } -- GitLab