From 22249afb4a932a82ff1f7a3befea9fda5a60a3f4 Mon Sep 17 00:00:00 2001
From: Yanbo Liang <ybliang8@gmail.com>
Date: Thu, 31 Mar 2016 23:49:58 -0700
Subject: [PATCH] [SPARK-14303][ML][SPARKR] Define and use KMeansWrapper for
 SparkR::kmeans

## What changes were proposed in this pull request?
Define and use ```KMeansWrapper``` for ```SparkR::kmeans```. It's only the code refactor for the original ```KMeans``` wrapper.

## How was this patch tested?
Existing tests.

cc mengxr

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12039 from yanboliang/spark-14059.
---
 R/pkg/R/mllib.R                               | 91 +++++++++++++------
 .../org/apache/spark/ml/r/KMeansWrapper.scala | 85 +++++++++++++++++
 .../apache/spark/ml/r/SparkRWrappers.scala    | 52 +----------
 3 files changed, 148 insertions(+), 80 deletions(-)
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala

diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 33654d5216..f3152cc232 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -32,6 +32,11 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
 #' @export
 setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj"))
 
+#' @title S4 class that represents a KMeansModel
+#' @param jobj a Java object reference to the backing Scala KMeansModel
+#' @export
+setClass("KMeansModel", representation(jobj = "jobj"))
+
 #' Fits a generalized linear model
 #'
 #' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
@@ -154,17 +159,6 @@ setMethod("summary", signature(object = "PipelineModel"),
               colnames(coefficients) <- c("Estimate")
               rownames(coefficients) <- unlist(features)
               return(list(coefficients = coefficients))
-            } else if (modelName == "KMeansModel") {
-              modelSize <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                       "getKMeansModelSize", object@model)
-              cluster <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                     "getKMeansCluster", object@model, "classes")
-              k <- unlist(modelSize)[1]
-              size <- unlist(modelSize)[-1]
-              coefficients <- t(matrix(coefficients, ncol = k))
-              colnames(coefficients) <- unlist(features)
-              rownames(coefficients) <- 1:k
-              return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
             } else {
               stop(paste("Unsupported model", modelName, sep = " "))
             }
@@ -213,21 +207,21 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
 #' @examples
 #' \dontrun{
 #' model <- kmeans(x, centers = 2, algorithm="random")
-#'}
+#' }
 setMethod("kmeans", signature(x = "DataFrame"),
           function(x, centers, iter.max = 10, algorithm = c("random", "k-means||")) {
             columnNames <- as.array(colnames(x))
             algorithm <- match.arg(algorithm)
-            model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitKMeans", x@sdf,
-                                 algorithm, iter.max, centers, columnNames)
-            return(new("PipelineModel", model = model))
+            jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", x@sdf,
+                                centers, iter.max, algorithm, columnNames)
+            return(new("KMeansModel", jobj = jobj))
          })
 
-#' Get fitted result from a model
+#' Get fitted result from a k-means model
 #'
-#' Get fitted result from a model, similarly to R's fitted().
+#' Get fitted result from a k-means model, similarly to R's fitted().
 #'
-#' @param object A fitted MLlib model
+#' @param object A fitted k-means model
 #' @return DataFrame containing fitted values
 #' @rdname fitted
 #' @export
@@ -237,19 +231,58 @@ setMethod("kmeans", signature(x = "DataFrame"),
 #' fitted.model <- fitted(model)
 #' showDF(fitted.model)
 #'}
-setMethod("fitted", signature(object = "PipelineModel"),
+setMethod("fitted", signature(object = "KMeansModel"),
           function(object, method = c("centers", "classes"), ...) {
-            modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                     "getModelName", object@model)
+            method <- match.arg(method)
+            return(dataFrame(callJMethod(object@jobj, "fitted", method)))
+          })
 
-            if (modelName == "KMeansModel") {
-              method <- match.arg(method)
-              fittedResult <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                                          "getKMeansCluster", object@model, method)
-              return(dataFrame(fittedResult))
-            } else {
-              stop(paste("Unsupported model", modelName, sep = " "))
-            }
+#' Get the summary of a k-means model
+#'
+#' Returns the summary of a k-means model produced by kmeans(),
+#' similarly to R's summary().
+#'
+#' @param object a fitted k-means model
+#' @return the model's coefficients, size and cluster
+#' @rdname summary
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- kmeans(trainingData, 2)
+#' summary(model)
+#' }
+setMethod("summary", signature(object = "KMeansModel"),
+          function(object, ...) {
+            jobj <- object@jobj
+            features <- callJMethod(jobj, "features")
+            coefficients <- callJMethod(jobj, "coefficients")
+            cluster <- callJMethod(jobj, "cluster")
+            k <- callJMethod(jobj, "k")
+            size <- callJMethod(jobj, "size")
+            coefficients <- t(matrix(coefficients, ncol = k))
+            colnames(coefficients) <- unlist(features)
+            rownames(coefficients) <- 1:k
+            return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
+          })
+
+#' Make predictions from a k-means model
+#'
+#' Make predictions from a model produced by kmeans().
+#'
+#' @param object A fitted k-means model
+#' @param newData DataFrame for testing
+#' @return DataFrame containing predicted labels in a column named "prediction"
+#' @rdname predict
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- kmeans(trainingData, 2)
+#' predicted <- predict(model, testData)
+#' showDF(predicted)
+#' }
+setMethod("predict", signature(object = "KMeansModel"),
+          function(object, newData) {
+            return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
           })
 
 #' Fit a Bernoulli naive Bayes model
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
new file mode 100644
index 0000000000..d3a0df4063
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
+import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.sql.DataFrame
+
+private[r] class KMeansWrapper private (
+    pipeline: PipelineModel) {
+
+  private val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel]
+
+  lazy val coefficients: Array[Double] = kMeansModel.clusterCenters.flatMap(_.toArray)
+
+  private lazy val attrs = AttributeGroup.fromStructField(
+    kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol))
+
+  lazy val features: Array[String] = attrs.attributes.get.map(_.name.get)
+
+  lazy val k: Int = kMeansModel.getK
+
+  lazy val size: Array[Int] = kMeansModel.summary.size
+
+  lazy val cluster: DataFrame = kMeansModel.summary.cluster
+
+  def fitted(method: String): DataFrame = {
+    if (method == "centers") {
+      kMeansModel.summary.predictions.drop(kMeansModel.getFeaturesCol)
+    } else if (method == "classes") {
+      kMeansModel.summary.cluster
+    } else {
+      throw new UnsupportedOperationException(
+        s"Method (centers or classes) required but $method found.")
+    }
+  }
+
+  def transform(dataset: DataFrame): DataFrame = {
+    pipeline.transform(dataset).drop(kMeansModel.getFeaturesCol)
+  }
+
+}
+
+private[r] object KMeansWrapper {
+
+  def fit(
+      data: DataFrame,
+      k: Double,
+      maxIter: Double,
+      initMode: String,
+      columns: Array[String]): KMeansWrapper = {
+
+    val assembler = new VectorAssembler()
+      .setInputCols(columns)
+      .setOutputCol("features")
+
+    val kMeans = new KMeans()
+      .setK(k.toInt)
+      .setMaxIter(maxIter.toInt)
+      .setInitMode(initMode)
+
+    val pipeline = new Pipeline()
+      .setStages(Array(assembler, kMeans))
+      .fit(data)
+
+    new KMeansWrapper(pipeline)
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
index d23e4fc9d1..551e75dc0a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -20,8 +20,7 @@ package org.apache.spark.ml.api.r
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
-import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
-import org.apache.spark.ml.feature.{RFormula, VectorAssembler}
+import org.apache.spark.ml.feature.RFormula
 import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
 import org.apache.spark.sql.DataFrame
 
@@ -52,22 +51,6 @@ private[r] object SparkRWrappers {
     pipeline.fit(df)
   }
 
-  def fitKMeans(
-      df: DataFrame,
-      initMode: String,
-      maxIter: Double,
-      k: Double,
-      columns: Array[String]): PipelineModel = {
-    val assembler = new VectorAssembler().setInputCols(columns)
-    val kMeans = new KMeans()
-      .setInitMode(initMode)
-      .setMaxIter(maxIter.toInt)
-      .setK(k.toInt)
-      .setFeaturesCol(assembler.getOutputCol)
-    val pipeline = new Pipeline().setStages(Array(assembler, kMeans))
-    pipeline.fit(df)
-  }
-
   def getModelCoefficients(model: PipelineModel): Array[Double] = {
     model.stages.last match {
       case m: LinearRegressionModel => {
@@ -89,8 +72,6 @@ private[r] object SparkRWrappers {
           m.coefficients.toArray
         }
       }
-      case m: KMeansModel =>
-        m.clusterCenters.flatMap(_.toArray)
     }
   }
 
@@ -104,31 +85,6 @@ private[r] object SparkRWrappers {
     }
   }
 
-  def getKMeansModelSize(model: PipelineModel): Array[Int] = {
-    model.stages.last match {
-      case m: KMeansModel => Array(m.getK) ++ m.summary.size
-      case other => throw new UnsupportedOperationException(
-        s"KMeansModel required but ${other.getClass.getSimpleName} found.")
-    }
-  }
-
-  def getKMeansCluster(model: PipelineModel, method: String): DataFrame = {
-    model.stages.last match {
-      case m: KMeansModel =>
-        if (method == "centers") {
-          // Drop the assembled vector for easy-print to R side.
-          m.summary.predictions.drop(m.summary.featuresCol)
-        } else if (method == "classes") {
-          m.summary.cluster
-        } else {
-          throw new UnsupportedOperationException(
-            s"Method (centers or classes) required but $method found.")
-        }
-      case other => throw new UnsupportedOperationException(
-        s"KMeansModel required but ${other.getClass.getSimpleName} found.")
-    }
-  }
-
   def getModelFeatures(model: PipelineModel): Array[String] = {
     model.stages.last match {
       case m: LinearRegressionModel =>
@@ -147,10 +103,6 @@ private[r] object SparkRWrappers {
         } else {
           attrs.attributes.get.map(_.name.get)
         }
-      case m: KMeansModel =>
-        val attrs = AttributeGroup.fromStructField(
-          m.summary.predictions.schema(m.summary.featuresCol))
-        attrs.attributes.get.map(_.name.get)
     }
   }
 
@@ -160,8 +112,6 @@ private[r] object SparkRWrappers {
         "LinearRegressionModel"
       case m: LogisticRegressionModel =>
         "LogisticRegressionModel"
-      case m: KMeansModel =>
-        "KMeansModel"
     }
   }
 }
-- 
GitLab