From 22249afb4a932a82ff1f7a3befea9fda5a60a3f4 Mon Sep 17 00:00:00 2001
From: Yanbo Liang <ybliang8@gmail.com>
Date: Thu, 31 Mar 2016 23:49:58 -0700
Subject: [PATCH] [SPARK-14303][ML][SPARKR] Define and use KMeansWrapper for SparkR::kmeans

## What changes were proposed in this pull request?
Define and use ```KMeansWrapper``` for ```SparkR::kmeans```. It is purely a code refactor of the original ```KMeans``` wrapper.

## How was this patch tested?
Existing tests.

cc mengxr

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12039 from yanboliang/spark-14059.
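A minimal usage sketch of the refactored API (illustrative only, not part of the patch; it assumes SparkR DataFrames ```training``` and ```testing``` with numeric feature columns):

```r
# Hypothetical setup, e.g.: training <- createDataFrame(sqlContext, iris[, 1:4])
model <- kmeans(training, centers = 2, iter.max = 10, algorithm = "random")

summary(model)                     # centers matrix, cluster sizes, assignments
fitted(model, method = "classes")  # training-set cluster ids as a DataFrame
predicted <- predict(model, testing)
showDF(predicted)                  # input columns plus a "prediction" column
```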
---
 R/pkg/R/mllib.R                                    | 91 +++++++++++++------
 .../org/apache/spark/ml/r/KMeansWrapper.scala      | 85 +++++++++++++++++
 .../apache/spark/ml/r/SparkRWrappers.scala         | 52 +----------
 3 files changed, 148 insertions(+), 80 deletions(-)
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala

diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 33654d5216..f3152cc232 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -32,6 +32,11 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
 #' @export
 setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj"))
 
+#' @title S4 class that represents a KMeansModel
+#' @param jobj a Java object reference to the backing Scala KMeansModel
+#' @export
+setClass("KMeansModel", representation(jobj = "jobj"))
+
 #' Fits a generalized linear model
 #'
 #' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
@@ -154,17 +159,6 @@ setMethod("summary", signature(object = "PipelineModel"),
             colnames(coefficients) <- c("Estimate")
             rownames(coefficients) <- unlist(features)
             return(list(coefficients = coefficients))
-          } else if (modelName == "KMeansModel") {
-            modelSize <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                     "getKMeansModelSize", object@model)
-            cluster <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                   "getKMeansCluster", object@model, "classes")
-            k <- unlist(modelSize)[1]
-            size <- unlist(modelSize)[-1]
-            coefficients <- t(matrix(coefficients, ncol = k))
-            colnames(coefficients) <- unlist(features)
-            rownames(coefficients) <- 1:k
-            return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
           } else {
             stop(paste("Unsupported model", modelName, sep = " "))
           }
@@ -213,21 +207,21 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
 #' @examples
 #' \dontrun{
 #' model <- kmeans(x, centers = 2, algorithm="random")
-#'}
+#' }
 setMethod("kmeans", signature(x = "DataFrame"),
           function(x, centers, iter.max = 10, algorithm = c("random", "k-means||")) {
             columnNames <- as.array(colnames(x))
             algorithm <- match.arg(algorithm)
-            model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitKMeans", x@sdf,
-                                 algorithm, iter.max, centers, columnNames)
-            return(new("PipelineModel", model = model))
+            jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", x@sdf,
+                                centers, iter.max, algorithm, columnNames)
+            return(new("KMeansModel", jobj = jobj))
           })
 
-#' Get fitted result from a model
+#' Get fitted result from a k-means model
 #'
-#' Get fitted result from a model, similarly to R's fitted().
+#' Get fitted result from a k-means model, similarly to R's fitted().
 #'
-#' @param object A fitted MLlib model
+#' @param object A fitted k-means model
 #' @return DataFrame containing fitted values
 #' @rdname fitted
 #' @export
@@ -237,19 +231,58 @@ setMethod("kmeans", signature(x = "DataFrame"),
 #' fitted.model <- fitted(model)
 #' showDF(fitted.model)
 #'}
-setMethod("fitted", signature(object = "PipelineModel"),
+setMethod("fitted", signature(object = "KMeansModel"),
           function(object, method = c("centers", "classes"), ...) {
-            modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                     "getModelName", object@model)
+            method <- match.arg(method)
+            return(dataFrame(callJMethod(object@jobj, "fitted", method)))
+          })
 
-            if (modelName == "KMeansModel") {
-              method <- match.arg(method)
-              fittedResult <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                          "getKMeansCluster", object@model, method)
-              return(dataFrame(fittedResult))
-            } else {
-              stop(paste("Unsupported model", modelName, sep = " "))
-            }
+#' Get the summary of a k-means model
+#'
+#' Returns the summary of a k-means model produced by kmeans(),
+#' similarly to R's summary().
+#'
+#' @param object a fitted k-means model
+#' @return the model's coefficients, size and cluster
+#' @rdname summary
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- kmeans(trainingData, 2)
+#' summary(model)
+#' }
+setMethod("summary", signature(object = "KMeansModel"),
+          function(object, ...) {
+            jobj <- object@jobj
+            features <- callJMethod(jobj, "features")
+            coefficients <- callJMethod(jobj, "coefficients")
+            cluster <- callJMethod(jobj, "cluster")
+            k <- callJMethod(jobj, "k")
+            size <- callJMethod(jobj, "size")
+            coefficients <- t(matrix(coefficients, ncol = k))
+            colnames(coefficients) <- unlist(features)
+            rownames(coefficients) <- 1:k
+            return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
+          })
+
+#' Make predictions from a k-means model
+#'
+#' Make predictions from a model produced by kmeans().
+#'
+#' @param object A fitted k-means model
+#' @param newData DataFrame for testing
+#' @return DataFrame containing predicted labels in a column named "prediction"
+#' @rdname predict
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- kmeans(trainingData, 2)
+#' predicted <- predict(model, testData)
+#' showDF(predicted)
+#' }
+setMethod("predict", signature(object = "KMeansModel"),
+          function(object, newData) {
+            return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
           })
 
 #' Fit a Bernoulli naive Bayes model
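To make the new ```summary()``` contract above concrete, a sketch of the returned list's shape (values are illustrative; ```model``` is the hypothetical fit from the first sketch):

```r
s <- summary(model)
s$coefficients  # k x nFeatures numeric matrix: rownames are the clusters 1:k,
                # colnames are the feature names fetched from the JVM side
s$size          # per-cluster point counts, one entry per cluster
s$cluster       # DataFrame of cluster assignments for the training data
```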
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
new file mode 100644
index 0000000000..d3a0df4063
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
+import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.sql.DataFrame
+
+private[r] class KMeansWrapper private (
+    pipeline: PipelineModel) {
+
+  private val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel]
+
+  lazy val coefficients: Array[Double] = kMeansModel.clusterCenters.flatMap(_.toArray)
+
+  private lazy val attrs = AttributeGroup.fromStructField(
+    kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol))
+
+  lazy val features: Array[String] = attrs.attributes.get.map(_.name.get)
+
+  lazy val k: Int = kMeansModel.getK
+
+  lazy val size: Array[Int] = kMeansModel.summary.size
+
+  lazy val cluster: DataFrame = kMeansModel.summary.cluster
+
+  def fitted(method: String): DataFrame = {
+    if (method == "centers") {
+      kMeansModel.summary.predictions.drop(kMeansModel.getFeaturesCol)
+    } else if (method == "classes") {
+      kMeansModel.summary.cluster
+    } else {
+      throw new UnsupportedOperationException(
+        s"Method (centers or classes) required but $method found.")
+    }
+  }
+
+  def transform(dataset: DataFrame): DataFrame = {
+    pipeline.transform(dataset).drop(kMeansModel.getFeaturesCol)
+  }
+
+}
+
+private[r] object KMeansWrapper {
+
+  def fit(
+      data: DataFrame,
+      k: Double,
+      maxIter: Double,
+      initMode: String,
+      columns: Array[String]): KMeansWrapper = {
+
+    val assembler = new VectorAssembler()
+      .setInputCols(columns)
+      .setOutputCol("features")
+
+    val kMeans = new KMeans()
+      .setK(k.toInt)
+      .setMaxIter(maxIter.toInt)
+      .setInitMode(initMode)
+
+    val pipeline = new Pipeline()
+      .setStages(Array(assembler, kMeans))
+      .fit(data)
+
+    new KMeansWrapper(pipeline)
+  }
+}
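The S4 methods in mllib.R reach this wrapper through SparkR's JVM bridge. A sketch of the equivalent low-level calls, using the same internal helpers (```callJStatic```, ```callJMethod```, ```dataFrame```) that mllib.R itself uses; ```:::``` is needed when calling them from outside the package, and ```training``` is the hypothetical DataFrame from the first sketch:

```r
# Equivalent of kmeans(training, centers = 2, iter.max = 10, algorithm = "random"):
jobj <- SparkR:::callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit",
                             training@sdf, 2, 10, "random",
                             as.array(colnames(training)))
# R numerics cross the bridge as Java doubles, which is why fit() declares
# k: Double and maxIter: Double and narrows them with toInt.
centers <- SparkR:::dataFrame(SparkR:::callJMethod(jobj, "fitted", "centers"))
```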
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
index d23e4fc9d1..551e75dc0a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -20,8 +20,7 @@ package org.apache.spark.ml.api.r
 
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
-import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
-import org.apache.spark.ml.feature.{RFormula, VectorAssembler}
+import org.apache.spark.ml.feature.RFormula
 import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
 import org.apache.spark.sql.DataFrame
@@ -52,22 +51,6 @@ private[r] object SparkRWrappers {
     pipeline.fit(df)
   }
 
-  def fitKMeans(
-      df: DataFrame,
-      initMode: String,
-      maxIter: Double,
-      k: Double,
-      columns: Array[String]): PipelineModel = {
-    val assembler = new VectorAssembler().setInputCols(columns)
-    val kMeans = new KMeans()
-      .setInitMode(initMode)
-      .setMaxIter(maxIter.toInt)
-      .setK(k.toInt)
-      .setFeaturesCol(assembler.getOutputCol)
-    val pipeline = new Pipeline().setStages(Array(assembler, kMeans))
-    pipeline.fit(df)
-  }
-
   def getModelCoefficients(model: PipelineModel): Array[Double] = {
     model.stages.last match {
       case m: LinearRegressionModel => {
@@ -89,8 +72,6 @@ private[r] object SparkRWrappers {
           m.coefficients.toArray
         }
       }
-      case m: KMeansModel =>
-        m.clusterCenters.flatMap(_.toArray)
     }
   }
 
@@ -104,31 +85,6 @@ private[r] object SparkRWrappers {
     }
   }
 
-  def getKMeansModelSize(model: PipelineModel): Array[Int] = {
-    model.stages.last match {
-      case m: KMeansModel => Array(m.getK) ++ m.summary.size
-      case other => throw new UnsupportedOperationException(
-        s"KMeansModel required but ${other.getClass.getSimpleName} found.")
-    }
-  }
-
-  def getKMeansCluster(model: PipelineModel, method: String): DataFrame = {
-    model.stages.last match {
-      case m: KMeansModel =>
-        if (method == "centers") {
-          // Drop the assembled vector for easy-print to R side.
-          m.summary.predictions.drop(m.summary.featuresCol)
-        } else if (method == "classes") {
-          m.summary.cluster
-        } else {
-          throw new UnsupportedOperationException(
-            s"Method (centers or classes) required but $method found.")
-        }
-      case other => throw new UnsupportedOperationException(
-        s"KMeansModel required but ${other.getClass.getSimpleName} found.")
-    }
-  }
-
   def getModelFeatures(model: PipelineModel): Array[String] = {
     model.stages.last match {
       case m: LinearRegressionModel =>
@@ -147,10 +103,6 @@ private[r] object SparkRWrappers {
         } else {
           attrs.attributes.get.map(_.name.get)
         }
-      case m: KMeansModel =>
-        val attrs = AttributeGroup.fromStructField(
-          m.summary.predictions.schema(m.summary.featuresCol))
-        attrs.attributes.get.map(_.name.get)
     }
   }
 
@@ -160,8 +112,6 @@ private[r] object SparkRWrappers {
         "LinearRegressionModel"
       case m: LogisticRegressionModel =>
         "LogisticRegressionModel"
-      case m: KMeansModel =>
-        "KMeansModel"
     }
   }
 }
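The two retrieval modes that ```getKMeansCluster``` implemented survive unchanged in ```KMeansWrapper.fitted```; from R the distinction looks like this (a sketch, reusing the hypothetical ```model``` from the first example):

```r
fitted(model)                      # method = "centers": training predictions with
                                   # the assembled "features" vector column dropped
fitted(model, method = "classes")  # only the cluster assignment column
# Any other value is rejected R-side by match.arg() before reaching the JVM.
```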