Skip to content
Snippets Groups Projects
Commit 1cbdd899 authored by Eric Liang's avatar Eric Liang Committed by Shivaram Venkataraman
Browse files

[SPARK-9201] [ML] Initial integration of MLlib + SparkR using RFormula

This exposes the SparkR:::glm() and SparkR:::predict() APIs. It was necessary to change RFormula to silently drop the label column if it was missing from the input dataset, which is kind of a hack but necessary to integrate with the Pipeline API.

The umbrella design doc for MLlib + SparkR integration can be viewed here: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit

mengxr

Author: Eric Liang <ekl@databricks.com>

Closes #7483 from ericl/spark-8774 and squashes the following commits:

3dfac0c [Eric Liang] update
17ef516 [Eric Liang] more comments
1753a0f [Eric Liang] make glm generic
b0f50f8 [Eric Liang] equivalence test
550d56d [Eric Liang] export methods
c015697 [Eric Liang] second pass
117949a [Eric Liang] comments
5afbc67 [Eric Liang] test label columns
6b7f15f [Eric Liang] Fri Jul 17 14:20:22 PDT 2015
3a63ae5 [Eric Liang] Fri Jul 17 13:41:52 PDT 2015
ce61367 [Eric Liang] Fri Jul 17 13:41:17 PDT 2015
0299c59 [Eric Liang] Fri Jul 17 13:40:32 PDT 2015
e37603f [Eric Liang] Fri Jul 17 12:15:03 PDT 2015
d417d0c [Eric Liang] Merge remote-tracking branch 'upstream/master' into spark-8774
29a2ce7 [Eric Liang] Merge branch 'spark-8774-1' into spark-8774
d1959d2 [Eric Liang] clarify comment
2db68aa [Eric Liang] second round of comments
dc3c943 [Eric Liang] address comments
5765ec6 [Eric Liang] fix style checks
1f361b0 [Eric Liang] doc
d33211b [Eric Liang] r support
fb0826b [Eric Liang] [SPARK-8774] Add R model formula with basic support as a transformer
parent 2bdf9914
No related branches found
No related tags found
No related merge requests found
......@@ -29,6 +29,7 @@ Collate:
'client.R'
'context.R'
'deserialize.R'
'mllib.R'
'serialize.R'
'sparkR.R'
'utils.R'
......@@ -10,6 +10,10 @@ export("sparkR.init")
export("sparkR.stop")
export("print.jobj")
# MLlib integration
exportMethods("glm",
"predict")
# Job group lifecycle management methods
export("setJobGroup",
"clearJobGroup",
......
......@@ -661,3 +661,7 @@ setGeneric("toRadians", function(x) { standardGeneric("toRadians") })
#' @rdname column
#' @export
setGeneric("upper", function(x) { standardGeneric("upper") })
#' @rdname glm
#' @export
setGeneric("glm")
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# mllib.R: Provides methods for MLlib integration
#' @title S4 class that represents a PipelineModel
#' @param model A Java object reference to the backing Scala PipelineModel
#' @export
setClass("PipelineModel", representation(model = "jobj"))
#' Fits a generalized linear model
#'
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
#'
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~' and '+'.
#' @param data DataFrame for training
#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg.
#' @param lambda Regularization parameter
#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details)
#' @return a fitted MLlib model
#' @rdname glm
#' @export
#' @examples
#'\dontrun{
#' sc <- sparkR.init()
#' sqlContext <- sparkRSQL.init(sc)
#' data(iris)
#' df <- createDataFrame(sqlContext, iris)
#' model <- glm(Sepal_Length ~ Sepal_Width, df)
#'}
setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"),
function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0) {
family <- match.arg(family)
model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"fitRModelFormula", deparse(formula), data@sdf, family, lambda,
alpha)
return(new("PipelineModel", model = model))
})
#' Make predictions from a model
#'
#' Makes predictions from a model produced by glm(), similarly to R's predict().
#'
#' @param model A fitted MLlib model
#' @param newData DataFrame for testing
#' @return DataFrame containing predicted values
#' @rdname glm
#' @export
#' @examples
#'\dontrun{
#' model <- glm(y ~ x, trainingData)
#' predicted <- predict(model, testData)
#' showDF(predicted)
#'}
setMethod("predict", signature(object = "PipelineModel"),
function(object, newData) {
return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
})
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
library(testthat)
context("MLlib functions")
# Tests for MLlib functions in SparkR
sc <- sparkR.init()
sqlContext <- sparkRSQL.init(sc)
test_that("glm and predict", {
training <- createDataFrame(sqlContext, iris)
test <- select(training, "Sepal_Length")
model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian")
prediction <- predict(model, test)
expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
})
test_that("predictions match with native glm", {
training <- createDataFrame(sqlContext, iris)
model <- glm(Sepal_Width ~ Sepal_Length, data = training)
vals <- collect(select(predict(model, training), "prediction"))
rVals <- predict(glm(Sepal.Width ~ Sepal.Length, data = iris), iris)
expect_true(all(abs(rVals - vals) < 1e-9), rVals - vals)
})
......@@ -73,12 +73,16 @@ class RFormula(override val uid: String)
val withFeatures = transformFeatures.transformSchema(schema)
if (hasLabelCol(schema)) {
withFeatures
} else {
} else if (schema.exists(_.name == parsedFormula.get.label)) {
val nullable = schema(parsedFormula.get.label).dataType match {
case _: NumericType | BooleanType => false
case _ => true
}
StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable))
} else {
// Ignore the label field. This is a hack so that this transformer can also work on test
// datasets in a Pipeline.
withFeatures
}
}
......@@ -92,10 +96,10 @@ class RFormula(override val uid: String)
override def toString: String = s"RFormula(${get(formula)})"
private def transformLabel(dataset: DataFrame): DataFrame = {
val labelName = parsedFormula.get.label
if (hasLabelCol(dataset.schema)) {
dataset
} else {
val labelName = parsedFormula.get.label
} else if (dataset.schema.exists(_.name == labelName)) {
dataset.schema(labelName).dataType match {
case _: NumericType | BooleanType =>
dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType))
......@@ -103,6 +107,10 @@ class RFormula(override val uid: String)
case other =>
throw new IllegalArgumentException("Unsupported type for label: " + other)
}
} else {
// Ignore the label field. This is a hack so that this transformer can also work on test
// datasets in a Pipeline.
dataset
}
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.api.r
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.DataFrame
private[r] object SparkRWrappers {
def fitRModelFormula(
value: String,
df: DataFrame,
family: String,
lambda: Double,
alpha: Double): PipelineModel = {
val formula = new RFormula().setFormula(value)
val estimator = family match {
case "gaussian" => new LinearRegression().setRegParam(lambda).setElasticNetParam(alpha)
case "binomial" => new LogisticRegression().setRegParam(lambda).setElasticNetParam(alpha)
}
val pipeline = new Pipeline().setStages(Array(formula, estimator))
pipeline.fit(df)
}
}
......@@ -74,6 +74,15 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
test("allow missing label column for test datasets") {
val formula = new RFormula().setFormula("y ~ x").setLabelCol("label")
val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y")
val resultSchema = formula.transformSchema(original.schema)
assert(resultSchema.length == 3)
assert(!resultSchema.exists(_.name == "label"))
assert(resultSchema.toString == formula.transform(original).schema.toString)
}
// TODO(ekl) enable after we implement string label support
// test("transform string label") {
// val formula = new RFormula().setFormula("name ~ id")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment