Skip to content
Snippets Groups Projects
Commit 935fe65f authored by Joseph K. Bradley's avatar Joseph K. Bradley Committed by Xiangrui Meng
Browse files

SPARK-1215 [MLLIB]: Clustering: Index out of bounds error (2)

Added check to LocalKMeans.scala: kMeansPlusPlus initialization to handle case with fewer distinct data points than clusters k.  Added two related unit tests to KMeansSuite.  (Re-submitting PR after tangling commits in PR 1407 https://github.com/apache/spark/pull/1407 )

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes #1468 from jkbradley/kmeans-fix and squashes the following commits:

4e9bd1e [Joseph K. Bradley] Updated PR per comments from mengxr
6c7a2ec [Joseph K. Bradley] Added check to LocalKMeans.scala: kMeansPlusPlus initialization to handle case with fewer distinct data points than clusters k.  Added two related unit tests to KMeansSuite.
parent 1fcd5dcd
No related branches found
No related tags found
No related merge requests found
......@@ -59,7 +59,13 @@ private[mllib] object LocalKMeans extends Logging {
cumulativeScore += weights(j) * KMeans.pointCost(curCenters, points(j))
j += 1
}
centers(i) = points(j-1).toDense
if (j == 0) {
logWarning("kMeansPlusPlus initialization ran out of distinct points for centers." +
s" Using duplicate point for center k = $i.")
centers(i) = points(0).toDense
} else {
centers(i) = points(j - 1).toDense
}
}
// Run up to maxIterations iterations of Lloyd's algorithm
......
......@@ -61,6 +61,32 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
assert(model.clusterCenters.head === center)
}
test("no distinct points") {
val data = sc.parallelize(
Array(
Vectors.dense(1.0, 2.0, 3.0),
Vectors.dense(1.0, 2.0, 3.0),
Vectors.dense(1.0, 2.0, 3.0)),
2)
val center = Vectors.dense(1.0, 2.0, 3.0)
// Make sure code runs.
var model = KMeans.train(data, k=2, maxIterations=1)
assert(model.clusterCenters.size === 2)
}
test("more clusters than points") {
val data = sc.parallelize(
Array(
Vectors.dense(1.0, 2.0, 3.0),
Vectors.dense(1.0, 3.0, 4.0)),
2)
// Make sure code runs.
var model = KMeans.train(data, k=3, maxIterations=1)
assert(model.clusterCenters.size === 3)
}
test("single cluster with big dataset") {
val smallData = Array(
Vectors.dense(1.0, 2.0, 6.0),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment