diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index d028821534b1afee2be785ccfa8e0819687c42e2..4949d86d20c913ce17dd44e56c1748cf359bec28 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -29,6 +29,7 @@ Collate: 'client.R' 'context.R' 'deserialize.R' + 'mllib.R' 'serialize.R' 'sparkR.R' 'utils.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 331307c2077a5c6f337c22a00c82e39823c70582..5834813319bfdac96628b50cb84d797364e313ae 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -10,6 +10,10 @@ export("sparkR.init") export("sparkR.stop") export("print.jobj") +# MLlib integration +exportMethods("glm", + "predict") + # Job group lifecycle management methods export("setJobGroup", "clearJobGroup", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index ebe6fbd97ce86a37b94ac7fb474ff47c3dd81680..39b5586f7c90ed52378cae885d3aebd0cd67e20a 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -661,3 +661,7 @@ setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) #' @rdname column #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) + +#' @rdname glm +#' @export +setGeneric("glm") diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R new file mode 100644 index 0000000000000000000000000000000000000000..258e354081fc12a5870940a925a8188cb0934830 --- /dev/null +++ b/R/pkg/R/mllib.R @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# mllib.R: Provides methods for MLlib integration + +#' @title S4 class that represents a PipelineModel +#' @param model A Java object reference to the backing Scala PipelineModel +#' @export +setClass("PipelineModel", representation(model = "jobj")) + +#' Fits a generalized linear model +#' +#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. +#' +#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~' and '+'. +#' @param data DataFrame for training +#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. +#' @param lambda Regularization parameter +#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details) +#' @return a fitted MLlib model +#' @rdname glm +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' data(iris) +#' df <- createDataFrame(sqlContext, iris) +#' model <- glm(Sepal_Length ~ Sepal_Width, df) +#'} +setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"), + function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0) { + family <- match.arg(family) + model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "fitRModelFormula", deparse(formula), data@sdf, family, lambda, + alpha) + return(new("PipelineModel", model = model)) + }) + +#' Make predictions from a model +#' +#' Makes predictions from a model produced by glm(), similarly to R's predict(). +#' +#' @param model A fitted MLlib model +#' @param newData DataFrame for testing +#' @return DataFrame containing predicted values +#' @rdname glm +#' @export +#' @examples +#'\dontrun{ +#' model <- glm(y ~ x, trainingData) +#' predicted <- predict(model, testData) +#' showDF(predicted) +#'} +setMethod("predict", signature(object = "PipelineModel"), + function(object, newData) { + return(dataFrame(callJMethod(object@model, "transform", newData@sdf))) + }) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R new file mode 100644 index 0000000000000000000000000000000000000000..a492763344ae666e84c097f8b8296b5fab7f1e8d --- /dev/null +++ b/R/pkg/inst/tests/test_mllib.R @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("MLlib functions") + +# Tests for MLlib functions in SparkR + +sc <- sparkR.init() + +sqlContext <- sparkRSQL.init(sc) + +test_that("glm and predict", { + training <- createDataFrame(sqlContext, iris) + test <- select(training, "Sepal_Length") + model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") + prediction <- predict(model, test) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") +}) + +test_that("predictions match with native glm", { + training <- createDataFrame(sqlContext, iris) + model <- glm(Sepal_Width ~ Sepal_Length, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-9), rVals - vals) +}) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 56169f2a01fc99fd04e92a60caec373268a815dc..f7b46efa10e900ab5930980cdaba6c6d8794b27a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -73,12 +73,16 @@ class RFormula(override val uid: String) val withFeatures = transformFeatures.transformSchema(schema) if (hasLabelCol(schema)) { withFeatures - } else { + } else if (schema.exists(_.name == parsedFormula.get.label)) { val nullable = schema(parsedFormula.get.label).dataType match { case _: NumericType | BooleanType => false case _ => true } StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable)) + } else { + // Ignore the label field. This is a hack so that this transformer can also work on test + // datasets in a Pipeline. + withFeatures } } @@ -92,10 +96,10 @@ class RFormula(override val uid: String) override def toString: String = s"RFormula(${get(formula)})" private def transformLabel(dataset: DataFrame): DataFrame = { + val labelName = parsedFormula.get.label if (hasLabelCol(dataset.schema)) { dataset - } else { - val labelName = parsedFormula.get.label + } else if (dataset.schema.exists(_.name == labelName)) { dataset.schema(labelName).dataType match { case _: NumericType | BooleanType => dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType)) @@ -103,6 +107,10 @@ class RFormula(override val uid: String) case other => throw new IllegalArgumentException("Unsupported type for label: " + other) } + } else { + // Ignore the label field. This is a hack so that this transformer can also work on test + // datasets in a Pipeline. + dataset } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala new file mode 100644 index 0000000000000000000000000000000000000000..1ee080641e3e324b049c8bcefbbaf47b1944b6e6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.api.r + +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.sql.DataFrame + +private[r] object SparkRWrappers { + def fitRModelFormula( + value: String, + df: DataFrame, + family: String, + lambda: Double, + alpha: Double): PipelineModel = { + val formula = new RFormula().setFormula(value) + val estimator = family match { + case "gaussian" => new LinearRegression().setRegParam(lambda).setElasticNetParam(alpha) + case "binomial" => new LogisticRegression().setRegParam(lambda).setElasticNetParam(alpha) + } + val pipeline = new Pipeline().setStages(Array(formula, estimator)) + pipeline.fit(df) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index fa8611b243a9ff33d3e0cd69516e5619a25ab7a4..79c4ccf02d4e0db7b7876309cdbfd7997c36abb2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -74,6 +74,15 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("allow missing label column for test datasets") { + val formula = new RFormula().setFormula("y ~ x").setLabelCol("label") + val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y") + val resultSchema = formula.transformSchema(original.schema) + assert(resultSchema.length == 3) + assert(!resultSchema.exists(_.name == "label")) + assert(resultSchema.toString == formula.transform(original).schema.toString) + } + // TODO(ekl) enable after we implement string label support // test("transform string label") { // val formula = new RFormula().setFormula("name ~ id")