diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 338dff40618cc8469f693e092439f88e7913e8b2..4ffec433a823460866a9200b81dac93821e8595d 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -1,17 +1,17 @@
 package spark
 
 import java.io.EOFException
-import java.net.URL
 import java.io.ObjectInputStream
-import java.util.concurrent.atomic.AtomicLong
+import java.net.URL
 import java.util.Random
 import java.util.Date
 import java.util.{HashMap => JHashMap}
+import java.util.concurrent.atomic.AtomicLong
 
-import scala.collection.mutable.ArrayBuffer
 import scala.collection.Map
-import scala.collection.mutable.HashMap
 import scala.collection.JavaConversions.mapAsScalaMap
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
 
 import org.apache.hadoop.io.BytesWritable
 import org.apache.hadoop.io.NullWritable
@@ -47,7 +47,7 @@ import spark.storage.StorageLevel
 import SparkContext._
 
 /**
- * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable, 
+ * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
  * partitioned collection of elements that can be operated on in parallel. This class contains the
  * basic operations available on all RDDs, such as `map`, `filter`, and `persist`. In addition,
  * [[spark.PairRDDFunctions]] contains operations available only on RDDs of key-value pairs, such
@@ -86,28 +86,28 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
   @transient val dependencies: List[Dependency[_]]
 
   // Methods available on all RDDs:
-  
+
   /** Record user function generating this RDD. */
   private[spark] val origin = Utils.getSparkCallSite
-  
+
   /** Optionally overridden by subclasses to specify how they are partitioned. */
   val partitioner: Option[Partitioner] = None
 
   /** Optionally overridden by subclasses to specify placement preferences. */
   def preferredLocations(split: Split): Seq[String] = Nil
-  
+
   /** The [[spark.SparkContext]] that this RDD was created on. */
   def context = sc
 
   private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
-  
+
   /** A unique ID for this RDD (within its SparkContext). */
   val id = sc.newRddId()
-  
+
   // Variables relating to persistence
   private var storageLevel: StorageLevel = StorageLevel.NONE
-  
-  /** 
+
+  /**
    * Set this RDD's storage level to persist its values across operations after the first time
    * it is computed. Can only be called once on each RDD.
    */
@@ -123,32 +123,32 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
 
   /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
   def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY)
-  
+
   /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
   def cache(): RDD[T] = persist()
 
   /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
   def getStorageLevel = storageLevel
-  
+
   private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = {
     if (!level.useDisk && level.replication < 2) {
       throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")")
-    } 
-    
+    }
+
     // This is a hack. Ideally this should re-use the code used by the CacheTracker
     // to generate the key.
     def getSplitKey(split: Split) = "rdd_%d_%d".format(this.id, split.index)
-    
+
     persist(level)
     sc.runJob(this, (iter: Iterator[T]) => {} )
-    
+
     val p = this.partitioner
-    
+
     new BlockRDD[T](sc, splits.map(getSplitKey).toArray) {
-      override val partitioner = p 
+      override val partitioner = p
     }
   }
-  
+
   /**
    * Internal method to this RDD; will read from cache if applicable, or otherwise compute it.
    * This should ''not'' be called by users directly, but is available for implementors of custom
@@ -161,9 +161,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
       compute(split)
     }
   }
-  
+
   // Transformations (return a new RDD)
-  
+
   /**
    * Return a new RDD by applying a function to all elements of this RDD.
    */
@@ -199,13 +199,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
     var multiplier = 3.0
     var initialCount = count()
     var maxSelected = 0
-    
+
     if (initialCount > Integer.MAX_VALUE - 1) {
       maxSelected = Integer.MAX_VALUE - 1
     } else {
       maxSelected = initialCount.toInt
     }
-    
+
     if (num > initialCount) {
       total = maxSelected
       fraction = math.min(multiplier * (maxSelected + 1) / initialCount, 1.0)
@@ -215,14 +215,14 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
       fraction = math.min(multiplier * (num + 1) / initialCount, 1.0)
       total = num
     }
-    
+
     val rand = new Random(seed)
     var samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
-    
+
     while (samples.length < total) {
       samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
     }
-    
+
     Utils.randomizeInPlace(samples, rand).take(total)
   }
 
@@ -290,8 +290,10 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
   * Return a new RDD by applying a function to each partition of this RDD, while tracking the index
   * of the original partition.
   */
-  def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] =
-    new MapPartitionsWithSplitRDD(this, sc.clean(f))
+  def mapPartitionsWithSplit[U: ClassManifest](
+    f: (Int, Iterator[T]) => Iterator[U],
+    preservesPartitioning: Boolean = false): RDD[U] =
+    new MapPartitionsWithSplitRDD(this, sc.clean(f), preservesPartitioning)
 
   // Actions (launch a job to return a value to the user program)
 
@@ -342,7 +344,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
 
   /**
   * Aggregate the elements of each partition, and then the results for all the partitions, using a
-   * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to 
+   * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
   * modify t1 and return it as its result value to avoid object allocation; however, it should not
   * modify t2.
   */
@@ -443,7 +445,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
     val evaluator = new GroupedCountEvaluator[T](splits.size, confidence)
     sc.runApproximateJob(this, countPartition, evaluator, timeout)
   }
-  
+
   /**
   * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
   * it will be slow if a lot of partitions are required. In that case, use collect() to get the
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
index adc541694e73ba1ef887d61439fd6dc35b2ef350..14e390c43b9cf3922b3678bcc67ab645cb9d3798 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
@@ -12,9 +12,11 @@ import spark.Split
 private[spark]
 class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
     prev: RDD[T],
-    f: (Int, Iterator[T]) => Iterator[U])
+    f: (Int, Iterator[T]) => Iterator[U],
+    preservesPartitioning: Boolean)
   extends RDD[U](prev.context) {
 
+  override val partitioner = if (preservesPartitioning) prev.partitioner else None
   override def splits = prev.splits
   override val dependencies = List(new OneToOneDependency(prev))
   override def compute(split: Split) = f(split.index, prev.iterator(split))
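
The functional change in this patch is the new preservesPartitioning flag: when the caller promises that its function does not move records across keys, the wrapping MapPartitionsWithSplitRDD now reports the parent's partitioner instead of None, so later key-based operations can reuse it rather than shuffle again. The snippet below is a usage sketch, not part of the patch: the object name, master URL, and sample data are made up, and it assumes the spark-0.6-era API shown in the diff (spark.SparkContext, spark.HashPartitioner, and the pair-RDD operations partitionBy and reduceByKey).

// Sketch only: illustrates the intended use of preservesPartitioning.
import spark.HashPartitioner
import spark.SparkContext
import spark.SparkContext._

object PreservesPartitioningSketch {
  def main(args: Array[String]) {
    // Hypothetical local context; master and job name are placeholders.
    val sc = new SparkContext("local", "PreservesPartitioningSketch")

    // Give the pair RDD a known partitioner by hash-partitioning it on the key.
    val pairs = sc.parallelize(1 to 100)
      .map(i => (i % 10, i))
      .partitionBy(new HashPartitioner(4))

    // The function below only rewrites values, never keys, so it is safe to tell
    // Spark that partitioning is preserved; the resulting MapPartitionsWithSplitRDD
    // then keeps pairs' partitioner instead of None.
    val tagged = pairs.mapPartitionsWithSplit(
      (index, iter) => iter.map { case (k, v) => (k, (index, v)) },
      preservesPartitioning = true)

    // Because tagged still reports a partitioner, a key-based operation such as
    // reduceByKey should be able to run without reshuffling the data.
    val reduced = tagged.reduceByKey((a, b) => (a._1, a._2 + b._2))
    println(reduced.collect().mkString(", "))

    sc.stop()
  }
}

Left at its default of false, the flag keeps the old behaviour: the wrapper RDD's partitioner stays None and a downstream reduceByKey would repartition the data, which is exactly what the overridden partitioner in MapPartitionsWithSplitRDD is meant to avoid.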