From 935fe65ff6559a0e3b481e7508fa14337b23020b Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" <joseph.kurata.bradley@gmail.com> Date: Thu, 17 Jul 2014 15:05:02 -0700 Subject: [PATCH] SPARK-1215 [MLLIB]: Clustering: Index out of bounds error (2) Added check to LocalKMeans.scala: kMeansPlusPlus initialization to handle case with fewer distinct data points than clusters k. Added two related unit tests to KMeansSuite. (Re-submitting PR after tangling commits in PR 1407 https://github.com/apache/spark/pull/1407 ) Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com> Closes #1468 from jkbradley/kmeans-fix and squashes the following commits: 4e9bd1e [Joseph K. Bradley] Updated PR per comments from mengxr 6c7a2ec [Joseph K. Bradley] Added check to LocalKMeans.scala: kMeansPlusPlus initialization to handle case with fewer distinct data points than clusters k. Added two related unit tests to KMeansSuite. --- .../spark/mllib/clustering/LocalKMeans.scala | 8 +++++- .../spark/mllib/clustering/KMeansSuite.scala | 26 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala index 2e3a4ce783..f0722d7c14 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala @@ -59,7 +59,13 @@ private[mllib] object LocalKMeans extends Logging { cumulativeScore += weights(j) * KMeans.pointCost(curCenters, points(j)) j += 1 } - centers(i) = points(j-1).toDense + if (j == 0) { + logWarning("kMeansPlusPlus initialization ran out of distinct points for centers." + + s" Using duplicate point for center k = $i.") + centers(i) = points(0).toDense + } else { + centers(i) = points(j - 1).toDense + } } // Run up to maxIterations iterations of Lloyd's algorithm diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 560a4ad71a..76a3bdf9b1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -61,6 +61,32 @@ class KMeansSuite extends FunSuite with LocalSparkContext { assert(model.clusterCenters.head === center) } + test("no distinct points") { + val data = sc.parallelize( + Array( + Vectors.dense(1.0, 2.0, 3.0), + Vectors.dense(1.0, 2.0, 3.0), + Vectors.dense(1.0, 2.0, 3.0)), + 2) + val center = Vectors.dense(1.0, 2.0, 3.0) + + // Make sure code runs. + var model = KMeans.train(data, k=2, maxIterations=1) + assert(model.clusterCenters.size === 2) + } + + test("more clusters than points") { + val data = sc.parallelize( + Array( + Vectors.dense(1.0, 2.0, 3.0), + Vectors.dense(1.0, 3.0, 4.0)), + 2) + + // Make sure code runs. + var model = KMeans.train(data, k=3, maxIterations=1) + assert(model.clusterCenters.size === 3) + } + test("single cluster with big dataset") { val smallData = Array( Vectors.dense(1.0, 2.0, 6.0), -- GitLab