diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
new file mode 100644
index 0000000000000000000000000000000000000000..629f7074c17c50bc27a90ebafaa9d9c5f72ab7fd
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.util.Random
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{TaskContext, Partition}
+import org.apache.spark.util.random.RandomSampler
+
+private[spark]
+class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
+  extends Partition with Serializable {
+  override val index: Int = prev.index
+}
+
+/**
+ * An RDD sampled from its parent RDD partition-wise. For each partition of the parent RDD,
+ * a user-specified [[org.apache.spark.util.random.RandomSampler]] instance is used to obtain
+ * a random sample of the records in the partition. The random seeds assigned to the samplers
+ * are drawn from a single generator seeded by the user, so collisions are very unlikely.
+ *
+ * @param prev RDD to be sampled
+ * @param sampler a random sampler
+ * @param seed random seed, defaults to System.nanoTime
+ * @tparam T input RDD item type
+ * @tparam U sampled RDD item type
+ */
+class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
+    prev: RDD[T],
+    sampler: RandomSampler[T, U],
+    seed: Long = System.nanoTime)
+  extends RDD[U](prev) {
+
+  override def getPartitions: Array[Partition] = {
+    val random = new Random(seed)
+    firstParent[T].partitions.map(x => new PartitionwiseSampledRDDPartition(x, random.nextLong()))
+  }
+
+  override def getPreferredLocations(split: Partition): Seq[String] =
+    firstParent[T].preferredLocations(split.asInstanceOf[PartitionwiseSampledRDDPartition].prev)
+
+  override def compute(splitIn: Partition, context: TaskContext): Iterator[U] = {
+    val split = splitIn.asInstanceOf[PartitionwiseSampledRDDPartition]
+    val thisSampler = sampler.clone
+    thisSampler.setSeed(split.seed)
+    thisSampler.sample(firstParent[T].iterator(split.prev, context))
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 1472c92b6031d8d9123d6fa0667e246572836acf..033d334079b59edf34763e33db146457a7d8baa8 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -45,6 +45,7 @@
 import org.apache.spark.util.{Utils, BoundedPriorityQueue, SerializableHyperLogLog}
 import org.apache.spark.SparkContext._
 import org.apache.spark._
+import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler}
 
 /**
  * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -319,8 +320,29 @@
   /**
    * Return a sampled subset of this RDD.
    */
-  def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] =
-    new SampledRDD(this, withReplacement, fraction, seed)
+  def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = {
+    if (withReplacement) {
+      new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed)
+    } else {
+      new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), seed)
+    }
+  }
+
+  /**
+   * Randomly splits this RDD with the provided weights.
+   *
+   * @param weights weights for the splits, will be normalized if they don't sum to 1
+   * @param seed random seed, defaults to System.nanoTime
+   *
+   * @return an array of the split RDDs
+   */
+  def randomSplit(weights: Array[Double], seed: Long = System.nanoTime): Array[RDD[T]] = {
+    val sum = weights.sum
+    val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
+    normalizedCumWeights.sliding(2).map { x =>
+      new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), seed)
+    }.toArray
+  }
 
   def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
     var fraction = 0.0
diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
index d433670cc2b7fe9862ee9261c82b938e5469439d..08534b6f1db3e79133ce4ed2a35cc45643b474a2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
@@ -25,11 +25,13 @@ import cern.jet.random.engine.DRand
 
 import org.apache.spark.{Partition, TaskContext}
 
+@deprecated("Replaced by PartitionwiseSampledRDDPartition", "1.0")
 private[spark]
 class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable {
   override val index: Int = prev.index
 }
 
+@deprecated("Replaced by PartitionwiseSampledRDD", "1.0")
 class SampledRDD[T: ClassTag](
     prev: RDD[T],
     withReplacement: Boolean,
diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala
index fcdf8486371a40e64ac73110aa77a0edfc8013f9..83fa0bf1e583f579bf66632494c34d509446a292 100644
--- a/core/src/main/scala/org/apache/spark/util/Vector.scala
+++ b/core/src/main/scala/org/apache/spark/util/Vector.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.util
 
 import scala.util.Random
+import org.apache.spark.util.random.XORShiftRandom
 
 class Vector(val elements: Array[Double]) extends Serializable {
   def length = elements.length
diff --git a/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala b/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala
new file mode 100644
index 0000000000000000000000000000000000000000..98569143ee1e31542e1472139d0ffc5ca00c0337
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.random
+
+/**
+ * A trait for pseudorandom behavior.
+ */
+trait Pseudorandom {
+  /** Set random seed. */
+  def setSeed(seed: Long)
+}
diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
new file mode 100644
index 0000000000000000000000000000000000000000..6b66d54751987fe8920f0677c872c3a677bf735d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.random
+
+import java.util.Random
+
+import cern.jet.random.Poisson
+import cern.jet.random.engine.DRand
+
+/**
+ * A pseudorandom sampler. The sampled item type may differ from the input type; for example, we
+ * might want to add weights for stratified sampling or importance sampling. Only transformations
+ * that are tied to the sampler and cannot be applied after sampling belong here.
+ *
+ * @tparam T item type
+ * @tparam U sampled item type
+ */
+trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable {
+
+  /** Take a random sample. */
+  def sample(items: Iterator[T]): Iterator[U]
+
+  override def clone: RandomSampler[T, U] =
+    throw new NotImplementedError("clone() is not implemented.")
+}
+
+/**
+ * A sampler based on Bernoulli trials.
+ *
+ * @param lb lower bound of the acceptance range
+ * @param ub upper bound of the acceptance range
+ * @param complement whether to use the complement of the range specified, defaults to false
+ * @tparam T item type
+ */
+class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
+  (implicit random: Random = new XORShiftRandom)
+  extends RandomSampler[T, T] {
+
+  def this(ratio: Double)(implicit random: Random = new XORShiftRandom) =
+    this(0.0d, ratio)(random)
+
+  override def setSeed(seed: Long) = random.setSeed(seed)
+
+  override def sample(items: Iterator[T]): Iterator[T] = {
+    items.filter { item =>
+      val x = random.nextDouble()
+      (x >= lb && x < ub) ^ complement
+    }
+  }
+
+  override def clone = new BernoulliSampler[T](lb, ub, complement)
+}
+
+/**
+ * A sampler based on values drawn from a Poisson distribution.
+ *
+ * @param mean mean of the Poisson distribution
+ * @tparam T item type
+ */
+class PoissonSampler[T](mean: Double)
+  (implicit var poisson: Poisson = new Poisson(mean, new DRand))
+  extends RandomSampler[T, T] {
+
+  override def setSeed(seed: Long) {
+    poisson = new Poisson(mean, new DRand(seed.toInt))
+  }
+
+  override def sample(items: Iterator[T]): Iterator[T] = {
+    items.flatMap { item =>
+      val count = poisson.nextInt()
+      if (count == 0) {
+        Iterator.empty
+      } else {
+        Iterator.fill(count)(item)
+      }
+    }
+  }
+
+  override def clone = new PoissonSampler[T](mean)
+}
diff --git a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
similarity index 97%
rename from core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala
rename to core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
index 08b31ac64f290561d5c5b21edb032e81315099f1..20d32d01b5e1918ed9473db4daaf551b08f8c59a 100644
--- a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.util
+package org.apache.spark.util.random
 
 import java.util.{Random => JavaRandom}
 import org.apache.spark.util.Utils.timeIt
@@ -46,6 +46,10 @@ private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) {
     seed = nextSeed
     (nextSeed & ((1L << bits) -1)).asInstanceOf[Int]
   }
+
+  override def setSeed(s: Long) {
+    seed = s
+  }
 }
 
 /** Contains benchmark method and main method to run benchmark of the RNG */
diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..cfe96fb3f7b953d05fa4950506c73b4468d9fe80
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.scalatest.FunSuite
+import org.apache.spark.SharedSparkContext
+import org.apache.spark.util.random.RandomSampler
+
+/** A sampler that outputs its seed. */
+class MockSampler extends RandomSampler[Long, Long] {
+
+  private var s: Long = _
+
+  override def setSeed(seed: Long) {
+    s = seed
+  }
+
+  override def sample(items: Iterator[Long]): Iterator[Long] = {
+    Iterator(s)
+  }
+
+  override def clone = new MockSampler
+}
+
+class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext {
+
+  test("seedDistribution") {
+    val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2)
+    val sampler = new MockSampler
+    val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, 0L)
+    assert(sample.distinct.count == 2, "Seeds must be different.")
+  }
+}
+
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 559ea051d35335de1dfc53a8ce8a973e86b46d1e..cd01303bad0a08d9a19d8c972908f727e4f49ba1 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -486,6 +486,21 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     }
   }
 
+  test("randomSplit") {
+    val n = 600
+    val data = sc.parallelize(1 to n, 2)
+    for (seed <- 1 to 5) {
+      val splits = data.randomSplit(Array(1.0, 2.0, 3.0), seed)
+      assert(splits.size == 3, "wrong number of splits")
+      assert(splits.flatMap(_.collect).sorted.toList == data.collect.toList,
+        "incomplete or wrong split")
+      val s = splits.map(_.count)
+      assert(math.abs(s(0) - 100) < 50) // std = 9.13
+      assert(math.abs(s(1) - 200) < 50) // std = 11.55
+      assert(math.abs(s(2) - 300) < 50) // std = 12.25
+    }
+  }
+
   test("runJob on an invalid partition") {
     intercept[IllegalArgumentException] {
       sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2), false)
diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
index 2f7bd370fc4ab1e4e601a044d71bfd90561beca0..e8361199421f117cc11168fa1321fa14ad58ca21 100644
--- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
@@ -98,10 +98,10 @@ class SortingSuite extends FunSuite with SharedSparkContext with ShouldMatchers
     assert(sorted.collect() === pairArr.sortBy(_._1))
     val partitions = sorted.collectPartitions()
     logInfo("Partition lengths: " + partitions.map(_.length).mkString(", "))
-    partitions(0).length should be > 180
-    partitions(1).length should be > 180
-    partitions(2).length should be > 180
-    partitions(3).length should be > 180
+    val lengthArr = partitions.map(_.length)
+    lengthArr.foreach { len =>
+      assert(len > 100 && len < 400)
+    }
     partitions(0).last should be < partitions(1).head
     partitions(1).last should be < partitions(2).head
     partitions(2).last should be < partitions(3).head
@@ -113,10 +113,10 @@
     assert(sorted.collect() === pairArr.sortBy(_._1).reverse)
     val partitions = sorted.collectPartitions()
     logInfo("partition lengths: " + partitions.map(_.length).mkString(", "))
-    partitions(0).length should be > 180
-    partitions(1).length should be > 180
-    partitions(2).length should be > 180
-    partitions(3).length should be > 180
+    val lengthArr = partitions.map(_.length)
+    lengthArr.foreach { len =>
+      assert(len > 100 && len < 400)
+    }
     partitions(0).last should be > partitions(1).head
     partitions(1).last should be > partitions(2).head
     partitions(2).last should be > partitions(3).head
diff --git a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..0f4792cd3bdb31fbe6a8d9ba14c26db29a435928
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.random
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.mock.EasyMockSugar
+
+import java.util.Random
+import cern.jet.random.Poisson
+
+class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar {
+
+  val a = List(1, 2, 3, 4, 5, 6, 7, 8, 9)
+
+  var random: Random = _
+  var poisson: Poisson = _
+
+  before {
+    random = mock[Random]
+    poisson = mock[Poisson]
+  }
+
+  test("BernoulliSamplerWithRange") {
+    expecting {
+      for (x <- Seq(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) {
+        random.nextDouble().andReturn(x)
+      }
+    }
+    whenExecuting(random) {
+      val sampler = new BernoulliSampler[Int](0.25, 0.55)(random)
+      assert(sampler.sample(a.iterator).toList == List(3, 4, 5))
+    }
+  }
+
+  test("BernoulliSamplerWithRatio") {
+    expecting {
+      for (x <- Seq(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) {
+        random.nextDouble().andReturn(x)
+      }
+    }
+    whenExecuting(random) {
+      val sampler = new BernoulliSampler[Int](0.35)(random)
+      assert(sampler.sample(a.iterator).toList == List(1, 2, 3))
+    }
+  }
+
+  test("BernoulliSamplerWithComplement") {
+    expecting {
+      for (x <- Seq(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) {
+        random.nextDouble().andReturn(x)
+      }
+    }
+    whenExecuting(random) {
+      val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random)
+      assert(sampler.sample(a.iterator).toList == List(1, 2, 6, 7, 8, 9))
+    }
+  }
+
+  test("BernoulliSamplerSetSeed") {
+    expecting {
+      random.setSeed(10L)
+    }
+    whenExecuting(random) {
+      val sampler = new BernoulliSampler[Int](0.2)(random)
+      sampler.setSeed(10L)
+    }
+  }
+
+  test("PoissonSampler") {
+    expecting {
+      for (x <- Seq(0, 1, 2, 0, 1, 1, 0, 0, 0)) {
+        poisson.nextInt().andReturn(x)
+      }
+    }
+    whenExecuting(poisson) {
+      val sampler = new PoissonSampler[Int](0.2)(poisson)
+      assert(sampler.sample(a.iterator).toList == List(2, 3, 3, 5, 6))
+    }
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala
similarity index 97%
rename from core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala
rename to core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala
index f1d7b61b31e635ba816fdabbcda780363a03ae49..352aa94219c2f116455effa21d571d720d02b1e3 100644
--- a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala
@@ -15,10 +15,8 @@
  * limitations under the License.
  */
 
-package org.apache.spark.util
+package org.apache.spark.util.random
 
-import java.util.Random
-import org.scalatest.FlatSpec
 import org.scalatest.FunSuite
 import org.scalatest.matchers.ShouldMatchers
 import org.apache.spark.util.Utils.times
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 0dee9399a86ea2bd5a9d9292f5f3d550fdce76c1..e508b76c3f8c5028948f5c1901c1238b4321098c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -26,8 +26,7 @@ import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.Logging
 import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.util.XORShiftRandom
-
+import org.apache.spark.util.random.XORShiftRandom
 /**
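
For reference, a minimal usage sketch of the sampling APIs this patch introduces. This is not part of the patch itself; the object name, the "local" master string, and the fractions, weights, and seeds are illustrative assumptions.

import org.apache.spark.SparkContext
import org.apache.spark.rdd.PartitionwiseSampledRDD
import org.apache.spark.util.random.BernoulliSampler

object SamplingSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "SamplingSketch")
    val data = sc.parallelize(1 to 1000, 4)

    // Without replacement: keep each record with probability ~0.1 (Bernoulli trials).
    val subsample = data.sample(withReplacement = false, fraction = 0.1, seed = 42)

    // With replacement: emit each record Poisson(0.1) times.
    val withRep = data.sample(withReplacement = true, fraction = 0.1, seed = 42)

    // randomSplit normalizes the weights, so this is a 70%/30% split; because the
    // splits use disjoint acceptance ranges of one Bernoulli sampler per partition,
    // together they cover the input exactly once.
    val Array(train, test) = data.randomSplit(Array(0.7, 0.3), seed = 42L)

    // The samplers can also be driven directly; PartitionwiseSampledRDD clones the
    // sampler once per partition and gives each clone a distinct seed.
    val direct = new PartitionwiseSampledRDD(data, new BernoulliSampler[Int](0.05), 42L)

    println("subsample=" + subsample.count + " withRep=" + withRep.count +
      " train=" + train.count + " test=" + test.count + " direct=" + direct.count)
    sc.stop()
  }
}

Note that sample() keeps the Int seed of the old API while randomSplit takes a Long, and both now route through PartitionwiseSampledRDD; SampledRDD remains only as a deprecated compatibility path.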