diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index d47952bec564b9a89be38b608ea431625dccc6ee..2d34244a5ede5783697f55ac4e198913285bd90a 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -96,32 +96,34 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial var total = 0 var multiplier = 3.0 var initialCount = count() - + var maxSelected = 0 + + if (initialCount > Integer.MAX_VALUE) { + maxSelected = Integer.MAX_VALUE + } + else { + maxSelected = initialCount.toInt + } + if (num > initialCount) { - total = Math.min(initialCount, Integer.MAX_VALUE) - total = total.toInt - fraction = 1.0 + total = maxSelected + fraction = Math.min(multiplier*(maxSelected+1)/initialCount, 1.0) } else if (num < 0) { - throw(new IllegalArgumentException()) + throw(new IllegalArgumentException("Negative number of elements requested")) } else { - fraction = Math.min(multiplier*(num+1)/count(), 1.0) + fraction = Math.min(multiplier*(num+1)/initialCount, 1.0) total = num.toInt } - var r = new SampledRDD(this, withReplacement, fraction, seed) - var samples = r.collect() + var samples = this.sample(withReplacement, fraction, seed).collect() while (samples.length < total) { - r = new SampledRDD(this, withReplacement, fraction, seed) + samples = this.sample(withReplacement, fraction, seed).collect() } - var arr = new Array[T](total) - - for (i <- 0 to total - 1) { - arr(i) = samples(i) - } + val arr = samples.take(total) return arr } diff --git a/kmeans_data.txt b/kmeans_data.txt index cd8a45eabe2ca0c747198fddb704eced670ee222..338664f78de50564de5ee6eb2af2e368904628f4 100644 Binary files a/kmeans_data.txt and b/kmeans_data.txt differ