From acf827232458e87773a71a38f88cb7ba9a6ab77e Mon Sep 17 00:00:00 2001
From: root <root@ip-10-6-154-245.ec2.internal>
Date: Sun, 11 Nov 2012 07:05:22 +0000
Subject: [PATCH] Fix K-means example a little

---
 core/src/main/scala/spark/util/Vector.scala   |  3 ++-
 .../scala/spark/examples/SparkKMeans.scala    | 27 ++++++++-----------
 2 files changed, 13 insertions(+), 17 deletions(-)

diff --git a/core/src/main/scala/spark/util/Vector.scala b/core/src/main/scala/spark/util/Vector.scala
index 4e95ac2ac6..03559751bc 100644
--- a/core/src/main/scala/spark/util/Vector.scala
+++ b/core/src/main/scala/spark/util/Vector.scala
@@ -49,7 +49,7 @@ class Vector(val elements: Array[Double]) extends Serializable {
     return ans
   }
 
-  def +=(other: Vector) {
+  def += (other: Vector): Vector = {
     if (length != other.length)
       throw new IllegalArgumentException("Vectors of different length")
     var ans = 0.0
@@ -58,6 +58,7 @@ class Vector(val elements: Array[Double]) extends Serializable {
       elements(i) += other(i)
       i += 1
     }
+    this
   }
 
   def * (scale: Double): Vector = Vector(length, i => this(i) * scale)
diff --git a/examples/src/main/scala/spark/examples/SparkKMeans.scala b/examples/src/main/scala/spark/examples/SparkKMeans.scala
index adce551322..6375961390 100644
--- a/examples/src/main/scala/spark/examples/SparkKMeans.scala
+++ b/examples/src/main/scala/spark/examples/SparkKMeans.scala
@@ -15,14 +15,13 @@ object SparkKMeans {
       return new Vector(line.split(' ').map(_.toDouble))
   }
   
-  def closestPoint(p: Vector, centers: HashMap[Int, Vector]): Int = {
+  def closestPoint(p: Vector, centers: Array[Vector]): Int = {
     var index = 0
     var bestIndex = 0
     var closest = Double.PositiveInfinity
   
-    for (i <- 1 to centers.size) {
-      val vCurr = centers.get(i).get
-      val tempDist = p.squaredDist(vCurr)
+    for (i <- 0 until centers.length) {
+      val tempDist = p.squaredDist(centers(i))
       if (tempDist < closest) {
         closest = tempDist
         bestIndex = i
@@ -43,32 +42,28 @@ object SparkKMeans {
     val K = args(2).toInt
     val convergeDist = args(3).toDouble
   
-    var points = data.takeSample(false, K, 42)
-    var kPoints = new HashMap[Int, Vector]
+    var kPoints = data.takeSample(false, K, 42).toArray
     var tempDist = 1.0
-    
-    for (i <- 1 to points.size) {
-      kPoints.put(i, points(i-1))
-    }
 
     while(tempDist > convergeDist) {
       var closest = data.map (p => (closestPoint(p, kPoints), (p, 1)))
       
-      var pointStats = closest.reduceByKey {case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2)}
+      var pointStats = closest.reduceByKey{case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2)}
       
-      var newPoints = pointStats.map {pair => (pair._1, pair._2._1 / pair._2._2)}.collect()
+      var newPoints = pointStats.map {pair => (pair._1, pair._2._1 / pair._2._2)}.collectAsMap()
       
       tempDist = 0.0
-      for (pair <- newPoints) {
-        tempDist += kPoints.get(pair._1).get.squaredDist(pair._2)
+      for (i <- 0 until K) {
+        tempDist += kPoints(i).squaredDist(newPoints(i))
       }
       
       for (newP <- newPoints) {
-        kPoints.put(newP._1, newP._2)
+        kPoints(newP._1) = newP._2
       }
     }
 
-    println("Final centers: " + kPoints)
+    println("Final centers:")
+    kPoints.foreach(println)
     System.exit(0)
   }
 }
-- 
GitLab