Commit c0c902ae authored by Marco Gaido, committed by Sean Owen

[SPARK-22119][FOLLOWUP][ML] Use spherical KMeans with cosine distance

## What changes were proposed in this pull request?

In #19340, some review comments suggested using spherical KMeans when the cosine distance measure is specified, as MATLAB does, instead of the implementation based on the behavior of other tools/libraries such as RapidMiner, NLTK, and ELKI, i.e. computing the centroids as the plain mean of all the points in the cluster.

This PR introduces the approach used in spherical KMeans. It has the nice property of minimizing the within-cluster cosine distance.
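
For intuition, here is a minimal, self-contained sketch (plain Scala, not code from this patch; the object and method names are illustrative only) of the spherical KMeans update: each point is scaled to unit norm before it is added to the cluster sum, and the mean is then re-normalized, so the resulting centroid always has unit L2 norm.

```scala
object SphericalCentroidSketch {
  // Euclidean (L2) norm of a dense vector.
  private def norm(v: Array[Double]): Double = math.sqrt(v.map(x => x * x).sum)

  /** Spherical KMeans centroid: normalize each point, average, then re-normalize. */
  def sphericalCentroid(points: Seq[Array[Double]]): Array[Double] = {
    require(points.nonEmpty && points.forall(norm(_) > 0),
      "Cosine distance is not defined for zero-length vectors.")
    val dim = points.head.length
    val sum = new Array[Double](dim)
    // Analogue of updateClusterSum: add each point scaled by 1 / its norm.
    points.foreach { p =>
      val n = norm(p)
      var i = 0
      while (i < dim) { sum(i) += p(i) / n; i += 1 }
    }
    // Analogue of centroid: take the mean of the normalized points, then rescale to unit norm.
    val mean = sum.map(_ / points.size)
    val meanNorm = norm(mean)
    mean.map(_ / meanNorm)
  }

  def main(args: Array[String]): Unit = {
    // Two points with the same direction but different magnitudes yield the same unit-norm centroid.
    println(sphericalCentroid(Seq(Array(1.0, 1.0), Array(10.0, 10.0))).mkString("[", ", ", "]"))
  }
}
```

Because only the directions of the points contribute, the centroid depends on where the points "point", not on their magnitudes, which is what makes it minimize the within-cluster cosine distance.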

## How was this patch tested?

Existing and improved unit tests.
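
For context, a sketch of the user-facing behaviour these tests exercise, using only the public ML API (the local SparkSession setup and the dataset are assumptions made for the example, not part of the patch):

```scala
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object CosineKMeansExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("cosine-kmeans").getOrCreate()
    import spark.implicits._

    // Non-zero vectors only: cosine distance is undefined for zero-length vectors,
    // and fitting on such data now fails with an AssertionError wrapped in a SparkException.
    val df = Seq(
      Vectors.dense(1.0, 1.0),
      Vectors.dense(10.0, 10.0),
      Vectors.dense(-1.0, 1.0),
      Vectors.dense(-100.0, 90.0)
    ).map(Tuple1.apply).toDF("features")

    val model = new KMeans()
      .setK(2)
      .setDistanceMeasure("cosine")
      .fit(df)

    // With the spherical update, every returned cluster center has unit L2 norm.
    model.clusterCenters.foreach(c => println(s"$c  norm = ${Vectors.norm(c, 2)}"))
    spark.stop()
  }
}
```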

Author: Marco Gaido <marcogaido91@gmail.com>

Closes #20518 from mgaido91/SPARK-22119_followup.
parent 4bbd7443
@@ -310,8 +310,7 @@ class KMeans private (
         points.foreach { point =>
           val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
           costAccum.add(cost)
-          val sum = sums(bestCenter)
-          axpy(1.0, point.vector, sum)
+          distanceMeasureInstance.updateClusterSum(point, sums(bestCenter))
           counts(bestCenter) += 1
         }
@@ -319,10 +318,9 @@ class KMeans private (
       }.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
         axpy(1.0, sum2, sum1)
         (sum1, count1 + count2)
-      }.mapValues { case (sum, count) =>
-        scal(1.0 / count, sum)
-        new VectorWithNorm(sum)
-      }.collectAsMap()
+      }.collectAsMap().mapValues { case (sum, count) =>
+        distanceMeasureInstance.centroid(sum, count)
+      }
 
       bcCenters.destroy(blocking = false)
@@ -657,6 +655,26 @@ private[spark] abstract class DistanceMeasure extends Serializable {
       v1: VectorWithNorm,
       v2: VectorWithNorm): Double
 
+  /**
+   * Updates the value of `sum` adding the `point` vector.
+   * @param point a `VectorWithNorm` to be added to `sum` of a cluster
+   * @param sum the `sum` for a cluster to be updated
+   */
+  def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
+    axpy(1.0, point.vector, sum)
+  }
+
+  /**
+   * Returns a centroid for a cluster given its `sum` vector and its `count` of points.
+   *
+   * @param sum the `sum` for a cluster
+   * @param count the number of points in the cluster
+   * @return the centroid of the cluster
+   */
+  def centroid(sum: Vector, count: Long): VectorWithNorm = {
+    scal(1.0 / count, sum)
+    new VectorWithNorm(sum)
+  }
 }
 
 @Since("2.4.0")
@@ -743,6 +761,30 @@ private[spark] class CosineDistanceMeasure extends DistanceMeasure {
    * @return the cosine distance between the two input vectors
    */
   override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = {
+    assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.")
     1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm
   }
+
+  /**
+   * Updates the value of `sum` adding the `point` vector.
+   * @param point a `VectorWithNorm` to be added to `sum` of a cluster
+   * @param sum the `sum` for a cluster to be updated
+   */
+  override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
+    axpy(1.0 / point.norm, point.vector, sum)
+  }
+
+  /**
+   * Returns a centroid for a cluster given its `sum` vector and its `count` of points.
+   *
+   * @param sum the `sum` for a cluster
+   * @param count the number of points in the cluster
+   * @return the centroid of the cluster
+   */
+  override def centroid(sum: Vector, count: Long): VectorWithNorm = {
+    scal(1.0 / count, sum)
+    val norm = Vectors.norm(sum, 2)
+    scal(1.0 / norm, sum)
+    new VectorWithNorm(sum, 1)
+  }
 }
@@ -19,7 +19,7 @@ package org.apache.spark.ml.clustering
 import scala.util.Random
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
@@ -179,6 +179,19 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     assert(predictionsMap(Vectors.dense(-1.0, 1.0)) ==
       predictionsMap(Vectors.dense(-100.0, 90.0)))
 
+    model.clusterCenters.forall(Vectors.norm(_, 2) == 1.0)
   }
 
+  test("KMeans with cosine distance is not supported for 0-length vectors") {
+    val model = new KMeans().setDistanceMeasure(DistanceMeasure.COSINE).setK(2)
+    val df = spark.createDataFrame(spark.sparkContext.parallelize(Array(
+      Vectors.dense(0.0, 0.0),
+      Vectors.dense(10.0, 10.0),
+      Vectors.dense(1.0, 0.5)
+    )).map(v => TestRow(v)))
+    val e = intercept[SparkException](model.fit(df))
+    assert(e.getCause.isInstanceOf[AssertionError])
+    assert(e.getCause.getMessage.contains("Cosine distance is not defined"))
+  }
+
   test("read/write") {
......