diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index a492763344ae666e84c097f8b8296b5fab7f1e8d..29152a11688a274f1b9114a35cafa1f5cc223201 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -35,8 +35,8 @@ test_that("glm and predict", { test_that("predictions match with native glm", { training <- createDataFrame(sqlContext, iris) - model <- glm(Sepal_Width ~ Sepal_Length, data = training) + model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ Sepal.Length, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-9), rVals - vals) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) }) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index f7b46efa10e900ab5930980cdaba6c6d8794b27a..0a95b1ee8de6e2daed50064a7b0fece23f17efa9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -17,17 +17,34 @@ package org.apache.spark.ml.feature +import scala.collection.mutable.ArrayBuffer import scala.util.parsing.combinator.RegexParsers import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.Transformer +import org.apache.spark.ml.{Estimator, Model, Transformer, Pipeline, PipelineModel, PipelineStage} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.VectorUDT import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +/** + * Base trait for [[RFormula]] and [[RFormulaModel]]. + */ +private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { + /** @group getParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group getParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + protected def hasLabelCol(schema: StructType): Boolean = { + schema.map(_.name).contains($(labelCol)) + } +} + /** * :: Experimental :: * Implements the transforms required for fitting a dataset against an R model formula. Currently @@ -35,8 +52,7 @@ import org.apache.spark.sql.types._ * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html */ @Experimental -class RFormula(override val uid: String) - extends Transformer with HasFeaturesCol with HasLabelCol { +class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase { def this() = this(Identifiable.randomUID("rFormula")) @@ -62,19 +78,74 @@ class RFormula(override val uid: String) /** @group getParam */ def getFormula: String = $(formula) - /** @group getParam */ - def setFeaturesCol(value: String): this.type = set(featuresCol, value) + override def fit(dataset: DataFrame): RFormulaModel = { + require(parsedFormula.isDefined, "Must call setFormula() first.") + // StringType terms and terms representing interactions need to be encoded before assembly. + // TODO(ekl) add support for feature interactions + var encoderStages = ArrayBuffer[PipelineStage]() + var tempColumns = ArrayBuffer[String]() + val encodedTerms = parsedFormula.get.terms.map { term => + dataset.schema(term) match { + case column if column.dataType == StringType => + val indexCol = term + "_idx_" + uid + val encodedCol = term + "_onehot_" + uid + encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol) + encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol) + tempColumns += indexCol + tempColumns += encodedCol + encodedCol + case _ => + term + } + } + encoderStages += new VectorAssembler(uid) + .setInputCols(encodedTerms.toArray) + .setOutputCol($(featuresCol)) + encoderStages += new ColumnPruner(tempColumns.toSet) + val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset) + copyValues(new RFormulaModel(uid, parsedFormula.get, pipelineModel).setParent(this)) + } - /** @group getParam */ - def setLabelCol(value: String): this.type = set(labelCol, value) + // optimistic schema; does not contain any ML attributes + override def transformSchema(schema: StructType): StructType = { + if (hasLabelCol(schema)) { + StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true)) + } else { + StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true) :+ + StructField($(labelCol), DoubleType, true)) + } + } + + override def copy(extra: ParamMap): RFormula = defaultCopy(extra) + + override def toString: String = s"RFormula(${get(formula)})" +} + +/** + * :: Experimental :: + * A fitted RFormula. Fitting is required to determine the factor levels of formula terms. + * @param parsedFormula a pre-parsed R formula. + * @param pipelineModel the fitted feature model, including factor to index mappings. + */ +@Experimental +class RFormulaModel private[feature]( + override val uid: String, + parsedFormula: ParsedRFormula, + pipelineModel: PipelineModel) + extends Model[RFormulaModel] with RFormulaBase { + + override def transform(dataset: DataFrame): DataFrame = { + checkCanTransform(dataset.schema) + transformLabel(pipelineModel.transform(dataset)) + } override def transformSchema(schema: StructType): StructType = { checkCanTransform(schema) - val withFeatures = transformFeatures.transformSchema(schema) + val withFeatures = pipelineModel.transformSchema(schema) if (hasLabelCol(schema)) { withFeatures - } else if (schema.exists(_.name == parsedFormula.get.label)) { - val nullable = schema(parsedFormula.get.label).dataType match { + } else if (schema.exists(_.name == parsedFormula.label)) { + val nullable = schema(parsedFormula.label).dataType match { case _: NumericType | BooleanType => false case _ => true } @@ -86,24 +157,19 @@ class RFormula(override val uid: String) } } - override def transform(dataset: DataFrame): DataFrame = { - checkCanTransform(dataset.schema) - transformLabel(transformFeatures.transform(dataset)) - } - - override def copy(extra: ParamMap): RFormula = defaultCopy(extra) + override def copy(extra: ParamMap): RFormulaModel = copyValues( + new RFormulaModel(uid, parsedFormula, pipelineModel)) - override def toString: String = s"RFormula(${get(formula)})" + override def toString: String = s"RFormulaModel(${parsedFormula})" private def transformLabel(dataset: DataFrame): DataFrame = { - val labelName = parsedFormula.get.label + val labelName = parsedFormula.label if (hasLabelCol(dataset.schema)) { dataset } else if (dataset.schema.exists(_.name == labelName)) { dataset.schema(labelName).dataType match { case _: NumericType | BooleanType => dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType)) - // TODO(ekl) add support for string-type labels case other => throw new IllegalArgumentException("Unsupported type for label: " + other) } @@ -114,25 +180,32 @@ class RFormula(override val uid: String) } } - private def transformFeatures: Transformer = { - // TODO(ekl) add support for non-numeric features and feature interactions - new VectorAssembler(uid) - .setInputCols(parsedFormula.get.terms.toArray) - .setOutputCol($(featuresCol)) - } - private def checkCanTransform(schema: StructType) { - require(parsedFormula.isDefined, "Must call setFormula() first.") val columnNames = schema.map(_.name) require(!columnNames.contains($(featuresCol)), "Features column already exists.") require( !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType, "Label column already exists and is not of type DoubleType.") } +} - private def hasLabelCol(schema: StructType): Boolean = { - schema.map(_.name).contains($(labelCol)) +/** + * Utility transformer for removing temporary columns from a DataFrame. + * TODO(ekl) make this a public transformer + */ +private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer { + override val uid = Identifiable.randomUID("columnPruner") + + override def transform(dataset: DataFrame): DataFrame = { + val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_)) + dataset.select(columnsToKeep.map(dataset.col) : _*) } + + override def transformSchema(schema: StructType): StructType = { + StructType(schema.fields.filter(col => !columnsToPrune.contains(col.name))) + } + + override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra) } /** @@ -149,7 +222,7 @@ private[ml] object RFormulaParser extends RegexParsers { def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list } def formula: Parser[ParsedRFormula] = - (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) } + (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t.distinct) } def parse(value: String): ParsedRFormula = parseAll(formula, value) match { case Success(result, _) => result diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala index c8d065f37a60561f055233b0533c7cb31fc0b8fe..c4b45aee06384e42ba89328f4cd2e5902108f57a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala @@ -28,6 +28,7 @@ class RFormulaParserSuite extends SparkFunSuite { test("parse simple formulas") { checkParse("y ~ x", "y", Seq("x")) + checkParse("y ~ x + x", "y", Seq("x")) checkParse("y ~ ._foo ", "y", Seq("._foo")) checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123")) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 79c4ccf02d4e0db7b7876309cdbfd7997c36abb2..8148c553e905162f52d2cd62c3f53628e0827ccf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -31,72 +31,78 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { val formula = new RFormula().setFormula("id ~ v1 + v2") val original = sqlContext.createDataFrame( Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") - val result = formula.transform(original) - val resultSchema = formula.transformSchema(original.schema) + val model = formula.fit(original) + val result = model.transform(original) + val resultSchema = model.transformSchema(original.schema) val expected = sqlContext.createDataFrame( Seq( - (0, 1.0, 3.0, Vectors.dense(Array(1.0, 3.0)), 0.0), - (2, 2.0, 5.0, Vectors.dense(Array(2.0, 5.0)), 2.0)) + (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0), + (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0)) ).toDF("id", "v1", "v2", "features", "label") // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString assert(result.schema.toString == resultSchema.toString) assert(resultSchema == expected.schema) - assert(result.collect().toSeq == expected.collect().toSeq) + assert(result.collect() === expected.collect()) } test("features column already exists") { val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x") val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") intercept[IllegalArgumentException] { - formula.transformSchema(original.schema) + formula.fit(original) } intercept[IllegalArgumentException] { - formula.transform(original) + formula.fit(original) } } test("label column already exists") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") - val resultSchema = formula.transformSchema(original.schema) + val model = formula.fit(original) + val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) - assert(resultSchema.toString == formula.transform(original).schema.toString) + assert(resultSchema.toString == model.transform(original).schema.toString) } test("label column already exists but is not double type") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y") + val model = formula.fit(original) intercept[IllegalArgumentException] { - formula.transformSchema(original.schema) + model.transformSchema(original.schema) } intercept[IllegalArgumentException] { - formula.transform(original) + model.transform(original) } } test("allow missing label column for test datasets") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("label") val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y") - val resultSchema = formula.transformSchema(original.schema) + val model = formula.fit(original) + val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) assert(!resultSchema.exists(_.name == "label")) - assert(resultSchema.toString == formula.transform(original).schema.toString) + assert(resultSchema.toString == model.transform(original).schema.toString) } -// TODO(ekl) enable after we implement string label support -// test("transform string label") { -// val formula = new RFormula().setFormula("name ~ id") -// val original = sqlContext.createDataFrame( -// Seq((1, "foo"), (2, "bar"), (3, "bar"))).toDF("id", "name") -// val result = formula.transform(original) -// val resultSchema = formula.transformSchema(original.schema) -// val expected = sqlContext.createDataFrame( -// Seq( -// (1, "foo", Vectors.dense(Array(1.0)), 1.0), -// (2, "bar", Vectors.dense(Array(2.0)), 0.0), -// (3, "bar", Vectors.dense(Array(3.0)), 0.0)) -// ).toDF("id", "name", "features", "label") -// assert(result.schema.toString == resultSchema.toString) -// assert(result.collect().toSeq == expected.collect().toSeq) -// } + test("encodes string terms") { + val formula = new RFormula().setFormula("id ~ a + b") + val original = sqlContext.createDataFrame( + Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val resultSchema = model.transformSchema(original.schema) + val expected = sqlContext.createDataFrame( + Seq( + (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), + (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), + (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0), + (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0)) + ).toDF("id", "a", "b", "features", "label") + assert(result.schema.toString == resultSchema.toString) + assert(result.collect() === expected.collect()) + } }