diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index f97fc017fceeae4d3c859657b0138dcfe8ddf8de..d47952bec564b9a89be38b608ea431625dccc6ee 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -95,9 +95,10 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
     var fraction = 0.0
     var total = 0
     var multiplier = 3.0
+    var initialCount = count()
 
-    if (num > count()) {
-      total = Math.min(count().toInt, Integer.MAX_VALUE)
+    if (num > initialCount) {
+      total = Math.min(initialCount, Integer.MAX_VALUE).toInt
       fraction = 1.0
     }
     else if (num < 0) {
@@ -109,12 +110,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
     }
 
     var r = new SampledRDD(this, withReplacement, fraction, seed)
-    while (r.count() < total) {
+    var samples = r.collect()
+    while (samples.length < total) {
       r = new SampledRDD(this, withReplacement, fraction, seed)
+      samples = r.collect()
     }
-    var samples = r.collect()
 
     var arr = new Array[T](total)
     for (i <- 0 to total - 1) {
diff --git a/examples/src/main/scala/spark/examples/SparkLocalKMeans.scala b/examples/src/main/scala/spark/examples/SparkLocalKMeans.scala
deleted file mode 100644
index 8d9527b7c14fdf559dbde698a151edb188fa023e..0000000000000000000000000000000000000000
--- a/examples/src/main/scala/spark/examples/SparkLocalKMeans.scala
+++ /dev/null
@@ -1,73 +0,0 @@
-package spark.examples
-
-import java.util.Random
-import Vector._
-import spark.SparkContext
-import spark.SparkContext._
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-
-object SparkLocalKMeans {
-  val R = 1000 // Scaling factor
-  val rand = new Random(42)
-
-  def parseVector(line: String): Vector = {
-    return new Vector(line.split(' ').map(_.toDouble))
-  }
-
-  def closestPoint(p: Vector, centers: HashMap[Int, Vector]): Int = {
-    var index = 0
-    var bestIndex = 0
-    var closest = Double.PositiveInfinity
-
-    for (i <- 1 to centers.size) {
-      val vCurr = centers.get(i).get
-      val tempDist = p.squaredDist(vCurr)
-      if (tempDist < closest) {
-        closest = tempDist
-        bestIndex = i
-      }
-    }
-
-    return bestIndex
-  }
-
-  def main(args: Array[String]) {
-    if (args.length < 4) {
-      System.err.println("Usage: SparkLocalKMeans <master> <file> <k> <convergeDist>")
-      System.exit(1)
-    }
-    val sc = new SparkContext(args(0), "SparkLocalKMeans")
-    val lines = sc.textFile(args(1))
-    val data = lines.map(parseVector _).cache()
-    val K = args(2).toInt
-    val convergeDist = args(3).toDouble
-
-    var points = data.sample(false, (K+1)/data.count().toDouble, 42).collect
-    var kPoints = new HashMap[Int, Vector]
-    var tempDist = 1.0
-
-    for (i <- 1 to points.size) {
-      kPoints.put(i, points(i-1))
-    }
-
-    while(tempDist > convergeDist) {
-      var closest = data.map (p => (closestPoint(p, kPoints), (p, 1)))
-
-      var pointStats = closest.reduceByKey {case ((x1, y1), (x2, y2)) => (x1 + x2, y1+y2)}
-
-      var newPoints = pointStats.map {mapping => (mapping._1, mapping._2._1/mapping._2._2)}.collect()
-
-      tempDist = 0.0
-      for (mapping <- newPoints) {
-        tempDist += kPoints.get(mapping._1).get.squaredDist(mapping._2)
-      }
-
-      for (newP <- newPoints) {
-        kPoints.put(newP._1, newP._2)
-      }
-    }
-
-    println("Final centers: " + kPoints)
-  }
-}
diff --git a/kmeans_data.txt b/kmeans_data.txt
index 06e5c9b45afa5f598ad0536b8fef1a02befbc7bf..cd8a45eabe2ca0c747198fddb704eced670ee222 100644
Binary files a/kmeans_data.txt and b/kmeans_data.txt differ