From 1cbdd8991898912a8471a7070c472a0edb92487c Mon Sep 17 00:00:00 2001 From: Eric Liang <ekl@databricks.com> Date: Mon, 20 Jul 2015 20:49:38 -0700 Subject: [PATCH] [SPARK-9201] [ML] Initial integration of MLlib + SparkR using RFormula This exposes the SparkR:::glm() and SparkR:::predict() APIs. It was necessary to change RFormula to silently drop the label column if it was missing from the input dataset, which is kind of a hack but necessary to integrate with the Pipeline API. The umbrella design doc for MLlib + SparkR integration can be viewed here: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit mengxr Author: Eric Liang <ekl@databricks.com> Closes #7483 from ericl/spark-8774 and squashes the following commits: 3dfac0c [Eric Liang] update 17ef516 [Eric Liang] more comments 1753a0f [Eric Liang] make glm generic b0f50f8 [Eric Liang] equivalence test 550d56d [Eric Liang] export methods c015697 [Eric Liang] second pass 117949a [Eric Liang] comments 5afbc67 [Eric Liang] test label columns 6b7f15f [Eric Liang] Fri Jul 17 14:20:22 PDT 2015 3a63ae5 [Eric Liang] Fri Jul 17 13:41:52 PDT 2015 ce61367 [Eric Liang] Fri Jul 17 13:41:17 PDT 2015 0299c59 [Eric Liang] Fri Jul 17 13:40:32 PDT 2015 e37603f [Eric Liang] Fri Jul 17 12:15:03 PDT 2015 d417d0c [Eric Liang] Merge remote-tracking branch 'upstream/master' into spark-8774 29a2ce7 [Eric Liang] Merge branch 'spark-8774-1' into spark-8774 d1959d2 [Eric Liang] clarify comment 2db68aa [Eric Liang] second round of comments dc3c943 [Eric Liang] address comments 5765ec6 [Eric Liang] fix style checks 1f361b0 [Eric Liang] doc d33211b [Eric Liang] r support fb0826b [Eric Liang] [SPARK-8774] Add R model formula with basic support as a transformer --- R/pkg/DESCRIPTION | 1 + R/pkg/NAMESPACE | 4 + R/pkg/R/generics.R | 4 + R/pkg/R/mllib.R | 73 +++++++++++++++++++ R/pkg/inst/tests/test_mllib.R | 42 +++++++++++ .../apache/spark/ml/feature/RFormula.scala | 14 +++- .../apache/spark/ml/r/SparkRWrappers.scala | 41 +++++++++++ .../spark/ml/feature/RFormulaSuite.scala | 9 +++ 8 files changed, 185 insertions(+), 3 deletions(-) create mode 100644 R/pkg/R/mllib.R create mode 100644 R/pkg/inst/tests/test_mllib.R create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index d028821534..4949d86d20 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -29,6 +29,7 @@ Collate: 'client.R' 'context.R' 'deserialize.R' + 'mllib.R' 'serialize.R' 'sparkR.R' 'utils.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 331307c207..5834813319 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -10,6 +10,10 @@ export("sparkR.init") export("sparkR.stop") export("print.jobj") +# MLlib integration +exportMethods("glm", + "predict") + # Job group lifecycle management methods export("setJobGroup", "clearJobGroup", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index ebe6fbd97c..39b5586f7c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -661,3 +661,7 @@ setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) #' @rdname column #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) + +#' @rdname glm +#' @export +setGeneric("glm") diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R new file mode 100644 index 0000000000..258e354081 --- /dev/null +++ b/R/pkg/R/mllib.R @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# mllib.R: Provides methods for MLlib integration + +#' @title S4 class that represents a PipelineModel +#' @param model A Java object reference to the backing Scala PipelineModel +#' @export +setClass("PipelineModel", representation(model = "jobj")) + +#' Fits a generalized linear model +#' +#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. +#' +#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~' and '+'. +#' @param data DataFrame for training +#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. +#' @param lambda Regularization parameter +#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details) +#' @return a fitted MLlib model +#' @rdname glm +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' data(iris) +#' df <- createDataFrame(sqlContext, iris) +#' model <- glm(Sepal_Length ~ Sepal_Width, df) +#'} +setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"), + function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0) { + family <- match.arg(family) + model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "fitRModelFormula", deparse(formula), data@sdf, family, lambda, + alpha) + return(new("PipelineModel", model = model)) + }) + +#' Make predictions from a model +#' +#' Makes predictions from a model produced by glm(), similarly to R's predict(). +#' +#' @param model A fitted MLlib model +#' @param newData DataFrame for testing +#' @return DataFrame containing predicted values +#' @rdname glm +#' @export +#' @examples +#'\dontrun{ +#' model <- glm(y ~ x, trainingData) +#' predicted <- predict(model, testData) +#' showDF(predicted) +#'} +setMethod("predict", signature(object = "PipelineModel"), + function(object, newData) { + return(dataFrame(callJMethod(object@model, "transform", newData@sdf))) + }) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R new file mode 100644 index 0000000000..a492763344 --- /dev/null +++ b/R/pkg/inst/tests/test_mllib.R @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("MLlib functions") + +# Tests for MLlib functions in SparkR + +sc <- sparkR.init() + +sqlContext <- sparkRSQL.init(sc) + +test_that("glm and predict", { + training <- createDataFrame(sqlContext, iris) + test <- select(training, "Sepal_Length") + model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") + prediction <- predict(model, test) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") +}) + +test_that("predictions match with native glm", { + training <- createDataFrame(sqlContext, iris) + model <- glm(Sepal_Width ~ Sepal_Length, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-9), rVals - vals) +}) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 56169f2a01..f7b46efa10 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -73,12 +73,16 @@ class RFormula(override val uid: String) val withFeatures = transformFeatures.transformSchema(schema) if (hasLabelCol(schema)) { withFeatures - } else { + } else if (schema.exists(_.name == parsedFormula.get.label)) { val nullable = schema(parsedFormula.get.label).dataType match { case _: NumericType | BooleanType => false case _ => true } StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable)) + } else { + // Ignore the label field. This is a hack so that this transformer can also work on test + // datasets in a Pipeline. + withFeatures } } @@ -92,10 +96,10 @@ class RFormula(override val uid: String) override def toString: String = s"RFormula(${get(formula)})" private def transformLabel(dataset: DataFrame): DataFrame = { + val labelName = parsedFormula.get.label if (hasLabelCol(dataset.schema)) { dataset - } else { - val labelName = parsedFormula.get.label + } else if (dataset.schema.exists(_.name == labelName)) { dataset.schema(labelName).dataType match { case _: NumericType | BooleanType => dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType)) @@ -103,6 +107,10 @@ class RFormula(override val uid: String) case other => throw new IllegalArgumentException("Unsupported type for label: " + other) } + } else { + // Ignore the label field. This is a hack so that this transformer can also work on test + // datasets in a Pipeline. + dataset } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala new file mode 100644 index 0000000000..1ee080641e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.api.r + +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.sql.DataFrame + +private[r] object SparkRWrappers { + def fitRModelFormula( + value: String, + df: DataFrame, + family: String, + lambda: Double, + alpha: Double): PipelineModel = { + val formula = new RFormula().setFormula(value) + val estimator = family match { + case "gaussian" => new LinearRegression().setRegParam(lambda).setElasticNetParam(alpha) + case "binomial" => new LogisticRegression().setRegParam(lambda).setElasticNetParam(alpha) + } + val pipeline = new Pipeline().setStages(Array(formula, estimator)) + pipeline.fit(df) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index fa8611b243..79c4ccf02d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -74,6 +74,15 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("allow missing label column for test datasets") { + val formula = new RFormula().setFormula("y ~ x").setLabelCol("label") + val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y") + val resultSchema = formula.transformSchema(original.schema) + assert(resultSchema.length == 3) + assert(!resultSchema.exists(_.name == "label")) + assert(resultSchema.toString == formula.transform(original).schema.toString) + } + // TODO(ekl) enable after we implement string label support // test("transform string label") { // val formula = new RFormula().setFormula("name ~ id") -- GitLab