diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 7f7a8a2e4de24b016eb0e516b3fcbc00322edde2..a329e14f25aeb9fd0c21df3bd570f163ad8db3ce 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -12,7 +12,8 @@ export("print.jobj") # MLlib integration exportMethods("glm", - "predict") + "predict", + "summary") # Job group lifecycle management methods export("setJobGroup", diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 6a8bacaa552c6b8715ea3f0c9c97edfbae7ccc4c..efddcc1d8d71c587d4ddcb041ef5480c5729632d 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -71,3 +71,29 @@ setMethod("predict", signature(object = "PipelineModel"), function(object, newData) { return(dataFrame(callJMethod(object@model, "transform", newData@sdf))) }) + +#' Get the summary of a model +#' +#' Returns the summary of a model produced by glm(), similarly to R's summary(). +#' +#' @param model A fitted MLlib model +#' @return a list with a 'coefficient' component, which is the matrix of coefficients. See +#' summary.glm for more information. +#' @rdname glm +#' @export +#' @examples +#'\dontrun{ +#' model <- glm(y ~ x, trainingData) +#' summary(model) +#'} +setMethod("summary", signature(object = "PipelineModel"), + function(object) { + features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelFeatures", object@model) + weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelWeights", object@model) + coefficients <- as.matrix(unlist(weights)) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + return(list(coefficients = coefficients)) + }) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 3bef69324770af8d4e3dc6d1eb52828b18a7203f..f272de78ad4a6e58c0eccc3ab8beccc50fdf9f78 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -48,3 +48,14 @@ test_that("dot minus and intercept vs native glm", { rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) }) + +test_that("summary coefficients match with native glm", { + training <- createDataFrame(sqlContext, iris) + stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) + coefs <- as.vector(stats$coefficients) + rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) + expect_true(all(abs(rCoefs - coefs) < 1e-6)) + expect_true(all( + as.character(stats$features) == + c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica"))) +}) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 38259427956456deaeb1a468bbb664b05531f7be..9c60d4084ec464ee83bfaf44ed98ae67cf1f0837 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -66,7 +66,6 @@ class OneHotEncoder(override val uid: String) extends Transformer def setOutputCol(value: String): this.type = set(outputCol, value) override def transformSchema(schema: StructType): StructType = { - val is = "_is_" val inputColName = $(inputCol) val outputColName = $(outputCol) @@ -79,17 +78,17 @@ class OneHotEncoder(override val uid: String) extends Transformer val outputAttrNames: Option[Array[String]] = inputAttr match { case nominal: NominalAttribute => if (nominal.values.isDefined) { - nominal.values.map(_.map(v => inputColName + is + v)) + nominal.values } else if (nominal.numValues.isDefined) { - nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i)) + nominal.numValues.map(n => Array.tabulate(n)(_.toString)) } else { None } case binary: BinaryAttribute => if (binary.values.isDefined) { - binary.values.map(_.map(v => inputColName + is + v)) + binary.values } else { - Some(Array.tabulate(2)(i => inputColName + is + i)) + Some(Array.tabulate(2)(_.toString)) } case _: NumericAttribute => throw new RuntimeException( @@ -123,7 +122,6 @@ class OneHotEncoder(override val uid: String) extends Transformer override def transform(dataset: DataFrame): DataFrame = { // schema transformation - val is = "_is_" val inputColName = $(inputCol) val outputColName = $(outputCol) val shouldDropLast = $(dropLast) @@ -142,7 +140,7 @@ class OneHotEncoder(override val uid: String) extends Transformer math.max(m0, m1) } ).toInt + 1 - val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i) + val outputAttrNames = Array.tabulate(numAttrs)(_.toString) val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames val outputAttrs: Array[Attribute] = filtered.map(name => BinaryAttribute.defaultAttr.withName(name)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 0b428d278d908835dc0dac1605147d38a12c7aef..d1726917e4517ebc16c86d028a4b3424e4d0183a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.parsing.combinator.RegexParsers @@ -91,11 +92,20 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R // TODO(ekl) add support for feature interactions val encoderStages = ArrayBuffer[PipelineStage]() val tempColumns = ArrayBuffer[String]() + val takenNames = mutable.Set(dataset.columns: _*) val encodedTerms = resolvedFormula.terms.map { term => dataset.schema(term) match { case column if column.dataType == StringType => val indexCol = term + "_idx_" + uid - val encodedCol = term + "_onehot_" + uid + val encodedCol = { + var tmp = term + while (takenNames.contains(tmp)) { + tmp += "_" + } + tmp + } + takenNames.add(indexCol) + takenNames.add(encodedCol) encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol) encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol) tempColumns += indexCol diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 9f70592ccad7e398794fe62a5d5c1d5de6ffc7c9..f5a022c31ed90680d76c96d3b2fa63edf8871d51 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -17,9 +17,10 @@ package org.apache.spark.ml.api.r +import org.apache.spark.ml.attribute._ import org.apache.spark.ml.feature.RFormula -import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.sql.DataFrame @@ -44,4 +45,26 @@ private[r] object SparkRWrappers { val pipeline = new Pipeline().setStages(Array(formula, estimator)) pipeline.fit(df) } + + def getModelWeights(model: PipelineModel): Array[Double] = { + model.stages.last match { + case m: LinearRegressionModel => + Array(m.intercept) ++ m.weights.toArray + case _: LogisticRegressionModel => + throw new UnsupportedOperationException( + "No weights available for LogisticRegressionModel") // SPARK-9492 + } + } + + def getModelFeatures(model: PipelineModel): Array[String] = { + model.stages.last match { + case m: LinearRegressionModel => + val attrs = AttributeGroup.fromStructField( + m.summary.predictions.schema(m.summary.featuresCol)) + Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) + case _: LogisticRegressionModel => + throw new UnsupportedOperationException( + "No features names available for LogisticRegressionModel") // SPARK-9492 + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 89718e0f3e15a39b6623f8a870c6f2d1c928a52f..3b85ba001b128b9bb06006cab678b056d812622c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -36,6 +36,7 @@ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.StructField import org.apache.spark.storage.StorageLevel import org.apache.spark.util.StatCounter @@ -146,9 +147,10 @@ class LinearRegression(override val uid: String) val model = new LinearRegressionModel(uid, weights, intercept) val trainingSummary = new LinearRegressionTrainingSummary( - model.transform(dataset).select($(predictionCol), $(labelCol)), + model.transform(dataset), $(predictionCol), $(labelCol), + $(featuresCol), Array(0D)) return copyValues(model.setSummary(trainingSummary)) } @@ -221,9 +223,10 @@ class LinearRegression(override val uid: String) val model = copyValues(new LinearRegressionModel(uid, weights, intercept)) val trainingSummary = new LinearRegressionTrainingSummary( - model.transform(dataset).select($(predictionCol), $(labelCol)), + model.transform(dataset), $(predictionCol), $(labelCol), + $(featuresCol), objectiveHistory) model.setSummary(trainingSummary) } @@ -300,6 +303,7 @@ class LinearRegressionTrainingSummary private[regression] ( predictions: DataFrame, predictionCol: String, labelCol: String, + val featuresCol: String, val objectiveHistory: Array[Double]) extends LinearRegressionSummary(predictions, predictionCol, labelCol) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 65846a846b7b4cbbb2ff2eadbdbd2d885a93f648..321eeb843941c8392d66c22e14848633e3de6520 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -86,8 +86,8 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { val output = encoder.transform(df) val group = AttributeGroup.fromStructField(output.schema("encoded")) assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1)) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) } test("input column without ML attribute") { @@ -98,7 +98,7 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { val output = encoder.transform(df) val group = AttributeGroup.fromStructField(output.schema("encoded")) assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1)) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 8148c553e905162f52d2cd62c3f53628e0827ccf..6aed3243afce80cced53d551ee212ff96591bdab 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -105,4 +106,21 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(result.schema.toString == resultSchema.toString) assert(result.collect() === expected.collect()) } + + test("attribute generation") { + val formula = new RFormula().setFormula("id ~ a + b") + val original = sqlContext.createDataFrame( + Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array( + new BinaryAttribute(Some("a__bar"), Some(1)), + new BinaryAttribute(Some("a__foo"), Some(2)), + new NumericAttribute(Some("b"), Some(3)))) + assert(attrs === expectedAttrs) + } }